In [None]:
import torch
from torch.utils.data import DataLoader
from dataloaderv2 import EchoVideoDataset
from modelv7 import MobileNetV3UNet
from utils import validate_model
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import numpy as np
import os

# Restrict this process to the GPU with index 5. The variable is assigned
# before any CUDA query below so that the restriction is in place when the
# CUDA runtime is first touched.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# Run on the GPU when one is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fix Dice Coefficient

def dice_coefficient(pred, target, smooth=1e-6):
    """
    Return the mean soft Dice score of a batch of predictions.

    A sigmoid is always applied to ``pred`` inside this function, so callers
    should pass raw logits; passing probabilities would squash them a second
    time.

    Args:
        pred: Raw model outputs (logits). The reduction runs over
            dimensions 2 and 3, so at least a 4-D tensor is required.
        target: Ground-truth masks broadcastable against ``pred``.
        smooth: Small constant added to numerator and denominator so an
            empty prediction/target pair does not divide by zero.
    Returns:
        Scalar tensor: the Dice score averaged over every remaining
        (non-reduced) entry. It lies in [0, 1] when ``target`` values lie
        in [0, 1].
    """
    spatial_axes = (2, 3)
    probs = pred.sigmoid()

    overlap = torch.sum(probs * target, dim=spatial_axes)
    total_mass = torch.sum(probs, dim=spatial_axes) + torch.sum(target, dim=spatial_axes)

    per_map_dice = (2. * overlap + smooth) / (total_mass + smooth)
    return torch.mean(per_map_dice)

# Update Loss Function
def combined_loss(pred, target, smooth=1e-6, alpha=0.5):
    """
    Weighted sum of soft Dice loss and binary cross-entropy.

    Args:
        pred: Raw model outputs (logits). The Dice term reduces over
            dimensions 2 and 3, so at least a 4-D tensor is required.
        target: Ground-truth masks broadcastable against ``pred``. Values
            are expected to lie in [0, 1]; values outside that range make
            the BCE term (and therefore the total loss) go negative.
        smooth: Small constant that keeps the Dice ratio finite when both
            prediction and target are empty.
        alpha: Weight of the Dice term; the BCE term is weighted
            ``1 - alpha``.
    Returns:
        Scalar tensor ``alpha * dice_loss + (1 - alpha) * bce_loss``.
    """
    # Probabilities are needed only for the Dice term. The raw logits are
    # kept untouched because binary_cross_entropy_with_logits applies the
    # sigmoid itself; feeding it probabilities would squash them twice.
    probs = torch.sigmoid(pred)
    intersection = (probs * target).sum(dim=(2, 3))
    union = probs.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = (2. * intersection + smooth) / (union + smooth)
    dice_loss = 1 - dice.mean()

    bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(pred, target)
    return alpha * dice_loss + (1 - alpha) * bce_loss

# ---- Hyperparameters -------------------------------------------------------
max_frames_per_step = 64  # frames per optimisation step; lower it if GPU memory runs out
num_epochs = 10
lr = 1e-4
weight_decay = 1e-4
save_path = "best_model.pth"  # checkpoint file for the best validation loss

# ---- Data ------------------------------------------------------------------
data_root = "./data/echodynamic"
train_dataset = EchoVideoDataset(root=data_root, split='train')
val_dataset = EchoVideoDataset(root=data_root, split='val')

# One item per batch: the training loop squeezes the batch dimension away and
# chunks the frames itself.
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False)

# ---- Model -----------------------------------------------------------------
model = MobileNetV3UNet(
    in_channels=4,
    out_channels=1,
    config_name="large",
    backbone=True,
)
model = model.to(device)

# ---- Optimisation ----------------------------------------------------------
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
# Halve the learning rate once the monitored value (the validation loss
# passed to scheduler.step) has not improved for 3 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True,
)
scaler = GradScaler()  # loss scaler for mixed-precision training

# Lowest validation loss seen so far; any finite loss beats the initial value.
best_val_loss = float('inf')

# Main train/validate loop. Each epoch: optimise over every training video in
# chunks of at most `max_frames_per_step` frames, validate, step the LR
# scheduler on the validation loss, and checkpoint when that loss improves.
for epoch in range(num_epochs):
    model.train()
    train_losses = []
    train_dice_scores = []

    # Training loop with progress bar (one tick per loader batch, not per chunk)
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs} [Train]", unit="batch") as pbar:
        for batch_inp, batch_mask in train_loader:
            # The loader yields one item per batch; drop that leading batch
            # dimension so the first dimension indexes frames.
            batch_inp = batch_inp.squeeze(0).to(device)   # (num_frames,4,H,W)
            batch_mask = batch_mask.squeeze(0).to(device) # (num_frames,1,H,W)

            num_frames = batch_inp.shape[0]
            # Walk the frames in fixed-size chunks; slicing clamps the last
            # chunk to the remaining frames.
            for start_idx in range(0, num_frames, max_frames_per_step):
                end_idx = start_idx + max_frames_per_step
                inp_chunk = batch_inp[start_idx:end_idx]
                mask_chunk = batch_mask[start_idx:end_idx]

                optimizer.zero_grad()
                with autocast():
                    pred_mask = model(inp_chunk)
                    loss = combined_loss(pred_mask, mask_chunk)

                scaler.scale(loss).backward()
                # Gradients are still multiplied by the loss scale here.
                # Unscale them first so max_norm applies to the true gradient
                # norm; scaler.step() detects the explicit unscale and does
                # not unscale a second time.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
                scaler.step(optimizer)
                scaler.update()

                train_losses.append(loss.item())
                # Metric only: no autograd graph is needed for it.
                with torch.no_grad():
                    dice_train = dice_coefficient(pred_mask, mask_chunk)  # applies sigmoid internally
                train_dice_scores.append(dice_train.item())

            pbar.update(1)

    # Every chunk contributes equally to the epoch means, whatever its length.
    train_loss_mean = np.mean(train_losses)
    train_dice_mean = np.mean(train_dice_scores)

    # Validation loop
    val_loss_mean, val_dice_mean = validate_model(model, val_loader, device, combined_loss)

    # Logging
    print(f"Epoch {epoch+1}/{num_epochs} Summary: Train Loss: {train_loss_mean:.4f}, Train Dice: {train_dice_mean:.4f} | Val Loss: {val_loss_mean:.4f}, Val Dice: {val_dice_mean:.4f}")

    # Learning rate scheduler (monitors the validation loss)
    scheduler.step(val_loss_mean)

    # Save best model
    if val_loss_mean < best_val_loss:
        best_val_loss = val_loss_mean
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_val_loss,
        }, save_path)
        print(f"Best model updated at epoch {epoch+1} with Val Loss: {best_val_loss:.4f}")

    # Clear cache
    torch.cuda.empty_cache()

  scaler = GradScaler()
  m = torch.load(mask_path)
  with autocast():
Epoch 1/10 [Train]: 100%|██████████| 7465/7465 [28:02<00:00,  4.44batch/s]


Epoch 1/10 Summary: Train Loss: -12.1326, Train Dice: 1.8905 | Val Loss: -12.1303, Val Dice: 1.9154
Best model updated at epoch 1 with Val Loss: -12.1303


Epoch 2/10 [Train]: 100%|██████████| 7465/7465 [28:13<00:00,  4.41batch/s]


Epoch 2/10 Summary: Train Loss: -12.3575, Train Dice: 1.9209 | Val Loss: -12.1298, Val Dice: 1.9151


Epoch 3/10 [Train]: 100%|██████████| 7465/7465 [28:18<00:00,  4.39batch/s]


Epoch 3/10 Summary: Train Loss: -12.3617, Train Dice: 1.9220 | Val Loss: -11.8434, Val Dice: 1.8733


Epoch 4/10 [Train]: 100%|██████████| 7465/7465 [29:01<00:00,  4.29batch/s]


Epoch 4/10 Summary: Train Loss: -12.2740, Train Dice: 1.9092 | Val Loss: -12.2234, Val Dice: 1.9297
Best model updated at epoch 4 with Val Loss: -12.2234


Epoch 5/10 [Train]: 100%|██████████| 7465/7465 [28:43<00:00,  4.33batch/s]
