In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from torch.cuda.amp import autocast, GradScaler
import os

def dice_coefficient(pred, target, smooth=1e-6, threshold=None, dims=(2, 3)):
    """Mean Dice coefficient between predicted logits and a target mask.

    The computation is always done in float32. Under mixed precision the
    model emits float16 logits, and summing float16 probabilities over an
    image loses precision (and overflows to inf past 65504), which would
    silently drive the score towards 0.

    Args:
        pred: raw logits; sigmoid is applied here. Same shape as ``target``.
        target: ground-truth mask; any numeric or bool dtype is accepted.
        smooth: additive smoothing on numerator and denominator, so an empty
            prediction against an empty target scores 1 instead of 0/0.
        threshold: if given, probabilities are binarised with ``> threshold``
            (hard Dice). ``None`` (default) keeps the original soft Dice.
        dims: dimensions summed per sample/channel before averaging; the
            default ``(2, 3)`` matches the original behaviour.

    Returns:
        0-dim float32 tensor: Dice averaged over all non-reduced dimensions.
    """
    probs = torch.sigmoid(pred.float())
    if threshold is not None:
        probs = (probs > threshold).float()
    target = target.float()

    intersection = (probs * target).sum(dim=dims)
    union = probs.sum(dim=dims) + target.sum(dim=dims)
    dice = (2. * intersection + smooth) / (union + smooth)
    return dice.mean()

def combined_loss(pred, target, smooth=1e-6, alpha=0.5):
    """Weighted sum of soft-Dice loss and BCE-with-logits loss.

    loss = alpha * (1 - dice) + (1 - alpha) * BCE

    Both terms are evaluated in float32 so the result does not depend on
    whether the logits arrive as float16 (autocast) or float32. The target is
    cast to the same dtype because ``binary_cross_entropy_with_logits`` rejects
    integer/bool targets.

    Args:
        pred: raw logits; reduced over dims (2, 3) for the Dice term.
        target: ground-truth mask with the same shape as ``pred``.
        smooth: additive smoothing for the Dice ratio.
        alpha: weight of the Dice term; ``1 - alpha`` weights the BCE term.

    Returns:
        0-dim float32 tensor with the combined loss (differentiable w.r.t. pred).
    """
    logits = pred.float()
    target = target.to(dtype=logits.dtype)

    # Soft Dice loss
    probs = torch.sigmoid(logits)
    intersection = (probs * target).sum(dim=(2, 3))
    union = probs.sum(dim=(2, 3)) + target.sum(dim=(2, 3))
    dice = (2. * intersection + smooth) / (union + smooth)
    dice_loss = 1 - dice.mean()

    # Pixel-wise BCE on the logits (numerically stable form)
    bce_loss = F.binary_cross_entropy_with_logits(logits, target)

    return alpha * dice_loss + (1 - alpha) * bce_loss

def validate_model(model, val_loader, device, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Each loader item is a (clip, mask) pair. Per the shapes noted for the
    batch_size=1 loader, a clip is (batch, num_frames, 4, H, W) and its mask
    is (batch, num_frames, 1, H, W). Batch and frame axes are merged so the
    per-frame model sees (batch * num_frames, C, H, W). For batch_size=1 this
    is exactly the previous ``squeeze(0)``; unlike ``squeeze(0)``, it also
    works for batch_size > 1.

    Args:
        model: network mapping a stack of frames to mask logits.
        val_loader: iterable of (input, mask) tensor pairs.
        device: where the tensors are moved before the forward pass.
        criterion: callable ``criterion(logits, mask)`` returning a scalar tensor.

    Returns:
        (mean loss, mean Dice) averaged over loader items. Both are NaN if the
        loader yields nothing.

    Note: leaves the model in eval mode.
    """
    model.eval()
    val_losses = []
    dice_scores = []
    with torch.no_grad():
        for batch_inp, batch_mask in val_loader:
            # (B, num_frames, C, H, W) -> (B * num_frames, C, H, W)
            batch_inp = batch_inp.flatten(0, 1).to(device)
            batch_mask = batch_mask.flatten(0, 1).to(device)

            pred_mask = model(batch_inp)
            val_losses.append(criterion(pred_mask, batch_mask).item())

            # Dice is logged alongside the loss for monitoring only.
            dice_scores.append(dice_coefficient(pred_mask, batch_mask).item())

    if not val_losses:
        # np.mean([]) would emit a RuntimeWarning; make the empty case explicit.
        return float('nan'), float('nan')
    return np.mean(val_losses), np.mean(dice_scores)


def _train_one_epoch(model, loader, device, optimizer, scaler, criterion, use_amp):
    """Run one optimisation pass over ``loader``; return (mean loss, mean Dice)."""
    model.train()
    losses = []
    dice_scores = []
    for batch_inp, batch_mask in loader:
        # (B, num_frames, C, H, W) -> (B * num_frames, C, H, W); equals squeeze(0) for B=1.
        batch_inp = batch_inp.flatten(0, 1).to(device)
        batch_mask = batch_mask.flatten(0, 1).to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            pred_mask = model(batch_inp)
            loss = criterion(pred_mask, batch_mask)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        losses.append(loss.item())
        # Metric only: don't build an autograd graph for it.
        with torch.no_grad():
            dice_scores.append(dice_coefficient(pred_mask.detach(), batch_mask).item())
    return np.mean(losses), np.mean(dice_scores)


def train_model(model, train_loader, val_loader, device, num_epochs=10, lr=1e-4, log_dir=None, save_path='best_model.pth'):
    """Train ``model`` with Adam + Dice/BCE loss, validating after every epoch.

    - Mixed precision (CUDA autocast + GradScaler) is enabled only when
      ``device`` is a CUDA device; on CPU it is disabled instead of warning.
    - The learning rate is halved by ReduceLROnPlateau after 3 epochs without
      validation-loss improvement.
    - Whenever the validation loss improves, a checkpoint dict is written to
      ``save_path``. Its keys are 'epoch', 'model_state_dict',
      'optimizer_state_dict', 'scheduler_state_dict', 'scaler_state_dict'
      and 'loss'.
    - If ``log_dir`` is given, per-epoch loss/Dice scalars go to TensorBoard.
      The writer is closed even if training is interrupted.

    Args:
        model: network to train (already on ``device``).
        train_loader / val_loader: iterables of (input, mask) tensor pairs.
        device: torch.device or device string the batches are moved to.
        num_epochs: number of passes over ``train_loader``.
        lr: initial Adam learning rate.
        log_dir: TensorBoard log directory, or None to disable logging.
        save_path: file path for the best checkpoint.

    Returns:
        The model in its final-epoch state. The best-validation weights are
        in the checkpoint at ``save_path``, not necessarily in the returned model.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = combined_loss
    # `verbose=` is deprecated on LR schedulers; LR reductions are printed below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    # Optional TensorBoard logging
    writer = None
    if log_dir is not None:
        from torch.utils.tensorboard import SummaryWriter
        writer = SummaryWriter(log_dir=log_dir)

    best_val_loss = float('inf')
    try:
        for epoch in range(num_epochs):
            train_loss_mean, train_dice_mean = _train_one_epoch(
                model, train_loader, device, optimizer, scaler, criterion, use_amp)
            val_loss, val_dice = validate_model(model, val_loader, device, criterion)

            if writer is not None:
                writer.add_scalar('Loss/train', train_loss_mean, epoch)
                writer.add_scalar('Loss/val', val_loss, epoch)
                writer.add_scalar('Dice/train', train_dice_mean, epoch)
                writer.add_scalar('Dice/val', val_dice, epoch)

            print(f"Epoch {epoch+1}/{num_epochs} | "
                  f"Train Loss: {train_loss_mean:.4f}, Train Dice: {train_dice_mean:.4f} | "
                  f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")

            # LR scheduler step (report reductions explicitly instead of verbose=True)
            prev_lrs = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(val_loss)
            for idx, (old_lr, group) in enumerate(zip(prev_lrs, optimizer.param_groups)):
                if group['lr'] < old_lr:
                    print(f"Reducing learning rate of group {idx} to {group['lr']:.4e}.")

            # Model checkpointing (NaN val_loss never counts as an improvement)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'epoch': epoch+1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'loss': best_val_loss,
                }, save_path)
                print(f"Best model updated at epoch {epoch+1} with Val Loss: {best_val_loss:.4f}")
    finally:
        if writer is not None:
            writer.close()
    return model


In [2]:
import torch
from torch.utils.data import DataLoader
from dataloaderv2 import EchoVideoDataset
from modelv6 import MobileNetV3UNet
# from utils import train_model
import os
import numpy as np

# --- Configuration: everything a reader might tune ---------------------------
GPU_ID = "5"
DATA_ROOT = "./data/echodynamic"
SEED = 42
NUM_EPOCHS = 10
LR = 1e-4
LOG_DIR = "./logs"          # set to None to disable TensorBoard logging
SAVE_PATH = "best_model.pth"

# Must be set before anything initialises CUDA in this kernel (e.g. the first
# torch.cuda.is_available() call); changing it afterwards has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed global RNGs so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Initialize datasets
train_dataset = EchoVideoDataset(root=DATA_ROOT, split='train')
val_dataset = EchoVideoDataset(root=DATA_ROOT, split='val')

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Initialize the model with 4 channels and a pretrained backbone.
# The model architecture handles the extra mask channel internally,
# so no manual weight adaptation is needed.
model = MobileNetV3UNet(in_channels=4, out_channels=1, backbone_pretrained=True).to(device)

# Train; the best-validation checkpoint is written to SAVE_PATH.
model = train_model(
    model,
    train_loader,
    val_loader,
    device,
    num_epochs=NUM_EPOCHS,
    lr=LR,
    log_dir=LOG_DIR,
    save_path=SAVE_PATH,
)


  scaler = GradScaler()
  m = torch.load(mask_path)
  with autocast():


KeyboardInterrupt: 