In [11]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, models
!pip install h5py boto3 matplotlib numpy opencv-python
# !pip install -q tensorflow-addons



In [12]:

import boto3
from botocore.handlers import disable_signing
import os

# Create an S3 client with anonymous access
s3 = boto3.resource('s3')
s3.meta.client.meta.events.register('choose-signer.s3.*', disable_signing)

bucket = s3.Bucket('sevir')
prefix = 'data/'

# Create directory to store files
os.makedirs('sevir_data', exist_ok=True)

# # List available files
# print("📦 Listing available .h5 files...")
# for obj in bucket.objects.filter(Prefix=prefix):
#     if obj.key.endswith('.h5'):
#         print(obj.key)

In [13]:
import os
import boto3
from botocore.handlers import disable_signing

# Connect anonymously to the SEVIR S3 bucket
s3 = boto3.resource('s3')
s3.meta.client.meta.events.register('choose-signer.s3.*', disable_signing)
bucket = s3.Bucket('sevir')

# Delete all files in a folder
def clear_folder(folder_path):
    if not os.path.exists(folder_path):
        return
    for f in os.listdir(folder_path):
        file_path = os.path.join(folder_path, f)
        os.remove(file_path)
    print(f"🧹 Cleared folder: {folder_path}")

# Download a single .h5 file
def download_one_ir107():
    prefix = 'data/ir107/'
    local_folder = 'SEVIR_data/ir107'
    os.makedirs(local_folder, exist_ok=True)

    print(f"\n📥 Searching for one .h5 file in {prefix} ...")
    for obj in bucket.objects.filter(Prefix=prefix):
        if obj.key.endswith('.h5'):
            filename = os.path.basename(obj.key)
            local_path = os.path.join(local_folder, filename)
            size_mb = obj.size / (1024 * 1024)
            print(f"➡️ Downloading {filename} ({size_mb:.2f} MB)...")
            bucket.download_file(obj.key, local_path)
            break  # Only download one file

# STEP 1: Clear ir069 and ir107 folders
clear_folder('SEVIR_data/ir069')
clear_folder('SEVIR_data/ir107')

# STEP 2: Download one file from ir107
download_one_ir107()

print("\n✅ Done: One ir107 file downloaded. ir069 and ir107 folders cleared before download.")


🧹 Cleared folder: SEVIR_data/ir107

📥 Searching for one .h5 file in data/ir107/ ...
➡️ Downloading SEVIR_IR107_RANDOMEVENTS_2018_0101_0430.h5 (1905.27 MB)...

✅ Done: One ir107 file downloaded. ir069 and ir107 folders cleared before download.


In [14]:
import h5py
import numpy as np

# file_path = '/content/SEVIR_data/ir107/SEVIR_IR107_RANDOMEVENTS_2018_0101_0430.h5'
file_path = '/kaggle/working/SEVIR_data/ir107/SEVIR_IR107_RANDOMEVENTS_2018_0101_0430.h5'

with h5py.File(file_path, 'r') as f:
    print("Top-level keys:", list(f.keys()))
    data = f['ir107'][:]  # Load entire ir107 dataset into memory

# Transpose to (event, time, height, width) => (553, 49, 192, 192)
data = data.transpose(0, 3, 1, 2)
print("Transposed shape:", data.shape)


Top-level keys: ['id', 'ir107']
Transposed shape: (553, 49, 192, 192)


In [15]:
# with h5py.File(file_path, 'r') as f:
#     def plot_event(sequence):
#         def normalize(frame):
#             return np.clip((frame - 180) / (300 - 180), 0, 1)
#         plt.figure(figsize=(15, 3))
#         for i, idx in enumerate([0, 8, 16, 24, 32, 40, 48]):
#             norm_frame = normalize(sequence[idx])
#             plt.subplot(1, 7, i + 1)
#             plt.imshow(norm_frame, cmap='inferno')
#             plt.title(f'Frame {idx}')
#             plt.axis('off')
#         plt.tight_layout()
#         plt.show()

#     for i in range(10):
#         print(f"🌀 Event {i}")
#         plot_event(f['ir107'][i])

In [16]:
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
import gc

def simple_memory_efficient_preprocessing(data):
    """Simplified preprocessing that won't crash your notebook"""
    print(f"Input data shape: {data.shape}")
    print(f"Data type: {data.dtype}")
    print(f"Temperature range: [{data.min():.1f}, {data.max():.1f}] K")
    
    # Memory management
    gc.collect()
    
    # Simple normalization
    def normalize(frames):
        return np.clip((frames - 180) / (300 - 180), 0, 1)
    
    # Simple activity check
    def has_activity(x, threshold=0.01):
        return np.var(x) > threshold
    
    X_inputs, Y_outputs = [], []
    
    for i, event in enumerate(data):
        if i % 50 == 0:
            print(f"Processing event {i}/{len(data)}")
            gc.collect()  # Force garbage collection every 50 events
        
        try:
            # Normalize
            norm_event = normalize(event.astype(np.float32))
            
            # Create single sequence (no overlapping to save memory)
            if norm_event.shape[0] >= 18:
                input_seq = norm_event[:12]      # First 12 frames
                target_seq = norm_event[12:18]   # Next 6 frames
                
                # Check activity
                if has_activity(input_seq) and has_activity(target_seq):
                    X_inputs.append(input_seq)
                    Y_outputs.append(target_seq)
        except Exception as e:
            print(f"Error processing event {i}: {e}")
            continue
    
    # Convert to arrays with memory management
    print("Converting to arrays...")
    X_inputs = np.array(X_inputs, dtype=np.float32)
    Y_outputs = np.array(Y_outputs, dtype=np.float32)
    
    print(f"✅ Preprocessing complete!")
    print(f"Valid sequences: {len(X_inputs)}")
    print(f"Input shape: {X_inputs.shape}")
    print(f"Output shape: {Y_outputs.shape}")
    print(f"Memory usage: {(X_inputs.nbytes + Y_outputs.nbytes) / 1e9:.2f} GB")
    
    # Simple train/val split
    X_train, X_val, Y_train, Y_val = train_test_split(
        X_inputs, Y_outputs, test_size=0.2, random_state=42
    )
    
    # Add channel dimension: (N, T, H, W, 1)
    X_train = X_train[..., np.newaxis]
    X_val = X_val[..., np.newaxis]
    Y_train = Y_train[..., np.newaxis]
    Y_val = Y_val[..., np.newaxis]
    
    print(f"\n📊 Data Split Summary:")
    print(f"Train: {X_train.shape[0]} samples")
    print(f"Val:   {X_val.shape[0]} samples")
    print(f"Final shapes: X{X_train.shape}, Y{Y_train.shape}")
    
    # Final memory cleanup
    del X_inputs, Y_outputs
    gc.collect()
    
    return X_train, X_val, Y_train, Y_val

# Memory management settings
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Limit TensorFlow memory growth
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Run preprocessing
X_train, X_val, Y_train, Y_val = simple_memory_efficient_preprocessing(data)


Input data shape: (553, 49, 192, 192)
Data type: int16
Temperature range: [-32768.0, 4590.0] K
Processing event 0/553
Processing event 50/553
Processing event 100/553
Processing event 150/553
Processing event 200/553
Processing event 250/553
Processing event 300/553
Processing event 350/553
Processing event 400/553
Processing event 450/553
Processing event 500/553
Processing event 550/553
Converting to arrays...
✅ Preprocessing complete!
Valid sequences: 259
Input shape: (259, 12, 192, 192)
Output shape: (259, 6, 192, 192)
Memory usage: 0.69 GB

📊 Data Split Summary:
Train: 207 samples
Val:   52 samples
Final shapes: X(207, 12, 192, 192, 1), Y(207, 6, 192, 192, 1)


In [17]:
# (X_train, X_val, X_test), (Y_train, Y_val, Y_test) = main_preprocessing_pipeline(data)

In [18]:
import os
os.environ['TF_ENABLE_LAYOUT_OPTIMIZER'] = '0'

import tensorflow as tf
from tensorflow.keras import layers, models, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.regularizers import l2
import numpy as np

# Simplified constants
L2_REG = 1e-5  # Reduced regularization
DROPOUT_RATE = 0.2  # Reduced dropout

def conv3d_block(x, filters, kernel_regularizer=None, dropout_rate=0.0):
    x = layers.Conv3D(filters, (3, 3, 3), padding='same', kernel_regularizer=kernel_regularizer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    if dropout_rate > 0:
        x = layers.Dropout(dropout_rate)(x)
    x = layers.Conv3D(filters, (3, 3, 3), padding='same', kernel_regularizer=kernel_regularizer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    return x

def simple_decoder3d_block(x, skip, filters, kernel_regularizer=None, dropout_rate=0.0):
    x = layers.UpSampling3D(size=(1, 2, 2))(x)
    x = layers.Conv3D(filters, (3, 3, 3), padding='same', kernel_regularizer=kernel_regularizer)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Concatenate()([x, skip])
    x = conv3d_block(x, filters, kernel_regularizer, dropout_rate)
    return x

def build_simple_effective_model(input_shape=(12, 192, 192, 1)):
    inputs = Input(shape=input_shape)

    # Simpler encoder
    e1 = conv3d_block(inputs, 16, l2(L2_REG), DROPOUT_RATE * 0.5)
    p1 = layers.MaxPooling3D((1, 2, 2))(e1)

    e2 = conv3d_block(p1, 32, l2(L2_REG), DROPOUT_RATE * 0.7)
    p2 = layers.MaxPooling3D((1, 2, 2))(e2)

    e3 = conv3d_block(p2, 64, l2(L2_REG), DROPOUT_RATE)
    p3 = layers.MaxPooling3D((1, 2, 2))(e3)

    # Simple ConvLSTM bottleneck (single direction)
    b = layers.Reshape((12, 24, 24, 64))(p3)
    b = layers.ConvLSTM2D(96, (3, 3), padding="same", return_sequences=True,
                          activation='relu', kernel_regularizer=l2(L2_REG))(b)
    b = layers.BatchNormalization()(b)
    b = layers.Dropout(DROPOUT_RATE)(b)
    b = layers.Reshape((12, 24, 24, 96))(b)

    # Simple decoder
    d1 = simple_decoder3d_block(b, e3, 64, l2(L2_REG), DROPOUT_RATE * 0.6)
    d2 = simple_decoder3d_block(d1, e2, 32, l2(L2_REG), DROPOUT_RATE * 0.4)
    d3 = simple_decoder3d_block(d2, e1, 16, l2(L2_REG), DROPOUT_RATE * 0.2)

    # Output layer
    output_conv = layers.Conv3D(1, (1, 1, 1), padding='same', activation='sigmoid')(d3)
    outputs = layers.Lambda(lambda t: t[:, -6:, :, :, :])(output_conv)

    return models.Model(inputs, outputs)

# FIXED: Simple, effective loss with proper type casting
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    # CRITICAL FIX: Cast both tensors to float32 to avoid type mismatch
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def simple_combined_loss(y_true, y_pred):
    # CRITICAL FIX: Cast both tensors to float32 at the start
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)

    # Just focal + dice, no connectivity loss
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice_loss = 1 - dice_coefficient(y_true, y_pred)
    return 0.6 * tf.reduce_mean(bce) + 0.4 * dice_loss

# Build and compile
model = build_simple_effective_model()
model.compile(
    optimizer=Adam(learning_rate=1e-4),  # Higher learning rate
    loss=simple_combined_loss,
    metrics=[dice_coefficient]
)

# Simple callbacks
callbacks = [
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1),
    EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True, verbose=1)
]

print("Simple, effective model ready for training!")
# model.summary()


Simple, effective model ready for training!


In [19]:
print("X_train shape:", X_val.shape)
print("Y_train shape:", Y_val.shape)


X_train shape: (52, 12, 192, 192, 1)
Y_train shape: (52, 6, 192, 192, 1)


In [20]:
# --- Training Hyperparameters ---
EPOCHS = 30
BATCH_SIZE = 8  # Reduced for memory efficiency with your model

# --- Start Training ---
print("Starting model training...")

history = model.fit(
    x=X_train,
    y=Y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, Y_val),
    callbacks=callbacks,  # Use the callbacks defined in your model code
    verbose=1
)

print("\nTraining finished.")
print("Model saved as 'simple_cloud_model.h5'")

# Analyze results immediately after training
plot_training_history(history)


Starting model training...
Epoch 1/30
[1m26/26[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m105s[0m 3s/step - dice_coefficient: 0.5535 - loss: 0.4622 - val_dice_coefficient: 0.3879 - val_loss: 0.6425 - learning_rate: 1.0000e-04
Epoch 2/30
[1m26/26[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m48s[0m 2s/step - dice_coefficient: 0.7783 - loss: 0.2393 - val_dice_coefficient: 0.3940 - val_loss: 0.6300 - learning_rate: 1.0000e-04
Epoch 3/30
[1m26/26[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 2s/step - dice_coefficient: 0.8053 - loss: 0.2004 - val_dice_coefficient: 0.4056 - val_loss: 0.6140 - learning_rate: 1.0000e-04
Epoch 4/30
[1m26/26[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 2s/step - dice_coefficient: 0.8241 - loss: 0.1797 - val_dice_coefficient: 0.4117 - val_loss: 0.5978 - learning_rate: 1.0000e-04
Epoch 5/30
[1m26/26[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 2s/step - dice_coefficient: 0.8386 - loss: 0.1721 - val_dice_coefficient: 0.4444

NameError: name 'plot_training_history' is not defined

In [None]:
model.save("cloud_predictor_model.h5")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
from scipy.ndimage import gaussian_filter
from skimage import morphology

def plot_training_history(history):
    """Enhanced training history plotting"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss
    ax1.plot(history.history['loss'], label='Train Loss', linewidth=2)
    ax1.plot(history.history['val_loss'], label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title("Training vs Validation Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Dice Score
    ax2.plot(history.history['dice_coefficient'], label='Train Dice', linewidth=2)
    ax2.plot(history.history['val_dice_coefficient'], label='Val Dice', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Dice Coefficient')
    ax2.set_title('Dice Score (Train vs Val)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Learning Rate
    if 'lr' in history.history:
        ax3.plot(history.history['lr'], linewidth=2)
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_title('Learning Rate Schedule')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)
    else:
        ax3.text(0.5, 0.5, 'Learning Rate\nNot Available', ha='center', va='center', transform=ax3.transAxes)
        ax3.set_title('Learning Rate Schedule')
    
    # Overfitting Analysis
    train_loss = np.array(history.history['loss'])
    val_loss = np.array(history.history['val_loss'])
    overfitting_metric = val_loss - train_loss
    
    ax4.plot(overfitting_metric, linewidth=2, color='red')
    ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Val Loss - Train Loss')
    ax4.set_title('Overfitting Indicator')
    ax4.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

def gentle_post_processing(pred, threshold=0.3, min_size=20):
    """Gentle post-processing that preserves predictions"""
    pred_squeezed = pred.squeeze()
    
    # Convert to float32 if needed
    if pred_squeezed.dtype == np.float16:
        pred_squeezed = pred_squeezed.astype(np.float32)
    
    out = np.zeros_like(pred_squeezed)
    
    for t in range(pred_squeezed.shape[0]):
        # Lower threshold to keep more predictions
        binary = pred_squeezed[t] > threshold
        
        # Minimal cleanup
        binary = morphology.remove_small_objects(binary, min_size=min_size)
        
        out[t] = binary.astype(np.uint8)
    
    return out

def calculate_metrics(y_true, y_pred):
    """Calculate comprehensive metrics"""
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    
    intersection = np.sum(y_true_f * y_pred_f)
    union = np.sum(y_true_f) + np.sum(y_pred_f) - intersection
    
    dice = (2. * intersection) / (np.sum(y_true_f) + np.sum(y_pred_f) + 1e-6)
    iou = intersection / (union + 1e-6)
    
    tp = intersection
    fp = np.sum(y_pred_f) - intersection
    fn = np.sum(y_true_f) - intersection
    
    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
    
    return {'dice': dice, 'iou': iou, 'precision': precision, 'recall': recall, 'f1': f1}

def visualize_predictions_enhanced(X_sample, Y_true, Y_pred_raw, Y_pred_processed):
    """Enhanced visualization with metrics"""
    fig, axs = plt.subplots(4, 6, figsize=(20, 12))
    
    metrics_list = []
    
    for i in range(6):
        # Input frames
        axs[0, i].imshow(X_sample[i+6].squeeze(), cmap='gray')
        axs[0, i].set_title(f"Input t{i+6}")
        axs[0, i].axis('off')
        
        # Raw predictions
        axs[1, i].imshow(Y_pred_raw[i].squeeze(), cmap='viridis', vmin=0, vmax=1)
        axs[1, i].set_title(f"Raw Pred t{i+12}")
        axs[1, i].axis('off')
        
        # Processed predictions
        axs[2, i].imshow(Y_pred_processed[i], cmap='gray')
        axs[2, i].set_title(f"Processed t{i+12}")
        axs[2, i].axis('off')
        
        # Ground truth
        axs[3, i].imshow(Y_true[i].squeeze(), cmap='gray')
        axs[3, i].set_title(f"Ground Truth t{i+12}")
        axs[3, i].axis('off')
        
        # Calculate metrics for this frame
        metrics = calculate_metrics(Y_true[i].squeeze(), Y_pred_processed[i])
        metrics_list.append(metrics)
    
    plt.suptitle("Cloud Prediction Analysis", fontsize=16)
    plt.tight_layout()
    
    # Print average metrics
    avg_metrics = {}
    for key in metrics_list[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in metrics_list])
    
    print("Average Metrics:")
    for key, value in avg_metrics.items():
        print(f"{key.upper()}: {value:.4f}")
    
    plt.show()
    
    return avg_metrics

# Analysis code
def analyze_model_results(model, X_val, Y_val, sample_idx=3):
    """Complete analysis function"""
    # 1. Plot training history
    plot_training_history(history)
    
    # 2. Choose a validation sample
    X_sample = X_val[sample_idx]
    Y_true = Y_val[sample_idx]
    
    print(f"Analyzing sample {sample_idx}...")
    
    # 3. Get model predictions
    print("Getting model predictions...")
    Y_pred_raw = model.predict(X_sample[np.newaxis, ...])[0].astype(np.float32)
    
    print(f"Prediction shape: {Y_pred_raw.shape}, dtype: {Y_pred_raw.dtype}")
    
    # 4. Apply gentle post-processing
    print("Applying gentle post-processing...")
    Y_pred_processed = gentle_post_processing(
        Y_pred_raw,
        threshold=0.3,
        min_size=20
    )
    
    # 5. Calculate metrics
    print("Calculating metrics...")
    metrics_list = []
    for i in range(6):
        metrics = calculate_metrics(
            Y_true[i].squeeze(),
            Y_pred_processed[i]
        )
        metrics_list.append(metrics)
        print(f"Time step {i+12}: Dice = {metrics['dice']:.4f}, IoU = {metrics['iou']:.4f}")
    
    # 6. Visualize results
    print("Creating visualization...")
    avg_metrics = visualize_predictions_enhanced(
        X_sample,
        Y_true,
        Y_pred_raw,
        Y_pred_processed
    )
    
    print("\n" + "="*50)
    print("ANALYSIS COMPLETE - SUMMARY RESULTS")
    print("="*50)
    for key, value in avg_metrics.items():
        print(f"  {key.upper()}: {value:.4f}")
    
    return avg_metrics

print("✅ All analysis functions loaded successfully!")
print("Run: analyze_model_results(model, X_val, Y_val) to analyze your results")


In [None]:
# 1. Plot training history
plot_training_history(history)

# 2. Choose a validation sample
sample_idx = 3
X_sample = X_val[sample_idx]
Y_true = Y_val[sample_idx]

print(f"Analyzing sample {sample_idx}...")

# 3. Get model predictions and fix dtype
print("Getting model predictions...")
Y_pred_raw = model.predict(X_sample[np.newaxis, ...])[0]

# CRITICAL FIX: Convert to float32 for scipy compatibility
Y_pred_raw = Y_pred_raw.astype(np.float32)

print(f"Prediction shape: {Y_pred_raw.shape}, dtype: {Y_pred_raw.dtype}")

# 4. Apply post-processing (now works!)
print("Applying post-processing...")
Y_pred_processed = gentle_post_processing(
    Y_pred_raw,
    threshold=0.3,
    min_size=20
)

# 5. Calculate metrics
print("Calculating metrics...")
metrics_list = []
for i in range(6):
    metrics = calculate_metrics(
        Y_true[i].squeeze(),
        Y_pred_processed[i]
    )
    metrics_list.append(metrics)
    print(f"Time step {i+12}: Dice = {metrics['dice']:.4f}, IoU = {metrics['iou']:.4f}")

# 6. Visualize results
print("Creating visualization...")
avg_metrics = visualize_predictions_enhanced(
    X_sample,
    Y_true,
    Y_pred_raw,
    Y_pred_processed
)

print("\n" + "="*50)
print("ANALYSIS COMPLETE - SUMMARY RESULTS")
print("="*50)
for key, value in avg_metrics.items():
    print(f"  {key.upper()}: {value:.4f}")


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
from scipy.ndimage import gaussian_filter, binary_erosion, binary_dilation
from skimage import morphology, measure
import tensorflow as tf

def plot_training_history(history):
    """Enhanced training history plotting"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Loss
    ax1.plot(history.history['loss'], label='Train Loss', linewidth=2)
    ax1.plot(history.history['val_loss'], label='Val Loss', linewidth=2)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title("Training vs Validation Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Dice Score
    ax2.plot(history.history['dice_coefficient'], label='Train Dice', linewidth=2)
    ax2.plot(history.history['val_dice_coefficient'], label='Val Dice', linewidth=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Dice Coefficient')
    ax2.set_title('Dice Score (Train vs Val)')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Learning Rate
    if 'lr' in history.history:
        ax3.plot(history.history['lr'], linewidth=2)
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('Learning Rate')
        ax3.set_title('Learning Rate Schedule')
        ax3.set_yscale('log')
        ax3.grid(True, alpha=0.3)

    # Overfitting Analysis
    train_loss = np.array(history.history['loss'])
    val_loss = np.array(history.history['val_loss'])
    overfitting_metric = val_loss - train_loss

    ax4.plot(overfitting_metric, linewidth=2, color='red')
    ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Val Loss - Train Loss')
    ax4.set_title('Overfitting Indicator')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def advanced_post_processing(pred, sigma=0.8, threshold=0.5, min_size=50):
    """Advanced post-processing with morphological operations"""
    pred_squeezed = pred.squeeze()
    out = np.zeros_like(pred_squeezed)

    for t in range(pred_squeezed.shape[0]):
        # Gaussian smoothing
        smoothed = gaussian_filter(pred_squeezed[t], sigma=sigma)

        # Adaptive thresholding
        binary = smoothed > threshold

        # Remove small objects
        binary = morphology.remove_small_objects(binary, min_size=min_size)

        # Fill holes
        binary = morphology.remove_small_holes(binary, area_threshold=min_size//2)

        # Morphological operations for cleaner boundaries
        kernel = morphology.disk(2)
        binary = morphology.opening(binary, kernel)
        binary = morphology.closing(binary, kernel)

        out[t] = binary.astype(np.uint8)

    return out

def calculate_metrics(y_true, y_pred):
    """Calculate comprehensive metrics"""
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()

    # Basic metrics
    intersection = np.sum(y_true_f * y_pred_f)
    union = np.sum(y_true_f) + np.sum(y_pred_f) - intersection

    dice = (2. * intersection) / (np.sum(y_true_f) + np.sum(y_pred_f) + 1e-6)
    iou = intersection / (union + 1e-6)

    # Precision and Recall
    tp = intersection
    fp = np.sum(y_pred_f) - intersection
    fn = np.sum(y_true_f) - intersection

    precision = tp / (tp + fp + 1e-6)
    recall = tp / (tp + fn + 1e-6)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

    return {
        'dice': dice,
        'iou': iou,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def visualize_predictions_enhanced(X_sample, Y_true, Y_pred_raw, Y_pred_processed):
    """Enhanced visualization with metrics"""
    fig, axs = plt.subplots(4, 6, figsize=(20, 12))

    metrics_list = []

    for i in range(6):
        # Input frames
        axs[0, i].imshow(X_sample[i+6].squeeze(), cmap='gray')
        axs[0, i].set_title(f"Input t{i+6}")
        axs[0, i].axis('off')

        # Raw predictions
        axs[1, i].imshow(Y_pred_raw[i].squeeze(), cmap='viridis', vmin=0, vmax=1)
        axs[1, i].set_title(f"Raw Pred t{i+12}")
        axs[1, i].axis('off')

        # Processed predictions
        axs[2, i].imshow(Y_pred_processed[i], cmap='gray')
        axs[2, i].set_title(f"Processed t{i+12}")
        axs[2, i].axis('off')

        # Ground truth
        axs[3, i].imshow(Y_true[i].squeeze(), cmap='gray')
        axs[3, i].set_title(f"Ground Truth t{i+12}")
        axs[3, i].axis('off')

        # Calculate metrics for this frame
        metrics = calculate_metrics(Y_true[i].squeeze(), Y_pred_processed[i])
        metrics_list.append(metrics)

    plt.suptitle("Enhanced Cloud Prediction Analysis", fontsize=16)
    plt.tight_layout()

    # Print average metrics
    avg_metrics = {}
    for key in metrics_list[0].keys():
        avg_metrics[key] = np.mean([m[key] for m in metrics_list])

    print("Average Metrics:")
    for key, value in avg_metrics.items():
        print(f"{key.upper()}: {value:.4f}")

    plt.show()

    return avg_metrics

# Data Augmentation Class
class DataAugmentation:
    def __init__(self, rotation_range=10, zoom_range=0.1, horizontal_flip=True):
        self.rotation_range = rotation_range
        self.zoom_range = zoom_range
        self.horizontal_flip = horizontal_flip

    def augment_batch(self, X_batch, Y_batch):
        """Apply augmentation to a batch"""
        augmented_X = []
        augmented_Y = []

        for i in range(len(X_batch)):
            # Random rotation
            if np.random.random() < 0.5:
                angle = np.random.uniform(-self.rotation_range, self.rotation_range)
                X_aug = self.rotate_3d(X_batch[i], angle)
                Y_aug = self.rotate_3d(Y_batch[i], angle)
            else:
                X_aug, Y_aug = X_batch[i], Y_batch[i]

            # Random horizontal flip
            if self.horizontal_flip and np.random.random() < 0.5:
                X_aug = np.flip(X_aug, axis=2)
                Y_aug = np.flip(Y_aug, axis=2)

            augmented_X.append(X_aug)
            augmented_Y.append(Y_aug)

        return np.array(augmented_X), np.array(augmented_Y)

    def rotate_3d(self, volume, angle):
        """Rotate 3D volume"""
        from scipy.ndimage import rotate
        return rotate(volume, angle, axes=(1, 2), reshape=False, mode='nearest')


In [None]:
# 2. Then all your analysis functions
plot_training_history(history)  # Analyze how training went
advanced_post_processing()      # Process predictions
calculate_metrics()             # Evaluate performance
visualize_predictions_enhanced() # Show results