# EchoNet-Dynamic Multi-Task Model (EF, EDV, ESV)

**Google Colab Version** — Optimized for fast training with local SSD storage.

This notebook implements a complete pipeline for training a 3D CNN model to predict:
- **EF** (Ejection Fraction)
- **EDV** (End-Diastolic Volume)
- **ESV** (End-Systolic Volume)

from echocardiogram videos.

---

## Setup Flow (Each Session)
1. Mount Google Drive (where raw videos are stored)
2. Preprocess videos → save directly to local SSD (~1-2 hrs)
3. Train model (fast I/O from local disk)
4. Model saved to Google Drive (persistent)


In [None]:
# ============================================================
# GOOGLE COLAB SETUP
# ============================================================
# Run this cell first in every Colab session: it mounts Google Drive and
# installs the imageio packages used by the preprocessing cells.

from google.colab import drive
# Mount point is /content/drive; the path-configuration cell builds all of its
# Drive paths under /content/drive/MyDrive.
drive.mount('/content/drive')

# Install required packages
# imageio is the video reader used by the preprocessing functions below.
# NOTE(review): versions are not pinned -- consider pinning for reproducibility.
%pip install -q imageio imageio-ffmpeg

print('\n‚úÖ Drive mounted and dependencies installed!')


In [21]:
# ============================================================
# 1) IMPORTS
# ============================================================

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import resize
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import os
import warnings
warnings.filterwarnings('ignore')

# Try to import video reading libraries.
# Each flag records whether the backend imported cleanly so that later cells
# can choose a reader without attempting the import again.
try:
    from torchvision.io import read_video
    USE_TORCHVISION = True
except Exception:
    # The import may fail with something other than ImportError, so catch
    # Exception -- but not a bare `except:`, which would also swallow
    # KeyboardInterrupt and SystemExit.
    USE_TORCHVISION = False

try:
    import imageio
    USE_IMAGEIO = True
except ImportError:
    USE_IMAGEIO = False
    print("Warning: imageio not installed. Install with: pip install imageio imageio-ffmpeg")

# Set device (CUDA when available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cpu


## 2) Load Data


In [None]:
# ============================================================
# 2) PATH CONFIGURATION (GOOGLE DRIVE)
# ============================================================

# Google Drive paths (persistent storage)
DRIVE_BASE = "/content/drive/MyDrive/EchoNet-Dynamic"
VIDEOS_DIR = os.path.join(DRIVE_BASE, "Videos")
LABELS_CSV = os.path.join(DRIVE_BASE, "FileList.csv")
PREPROCESSED_DIR_DRIVE = os.path.join(DRIVE_BASE, "PreprocessedVideos")

# Local SSD paths (fast I/O, cleared each session)
LOCAL_BASE = "/content/echonet_local"
PREPROCESSED_DIR_LOCAL = os.path.join(LOCAL_BASE, "PreprocessedVideos")

# Verify data exists BEFORE creating anything: creating directories under the
# Drive path while Drive is not mounted would leave a stray local directory
# tree at the mount point.
print("📁 Checking Google Drive paths...")
print(f"   DRIVE_BASE: {DRIVE_BASE}")
print(f"   Videos: {'✅' if os.path.exists(VIDEOS_DIR) else '❌'} {VIDEOS_DIR}")
print(f"   Labels: {'✅' if os.path.exists(LABELS_CSV) else '❌'} {LABELS_CSV}")

# Fail fast with an actionable message instead of an opaque read_csv error.
if not os.path.exists(LABELS_CSV):
    raise FileNotFoundError(
        f"Labels file not found: {LABELS_CSV}. "
        "Mount Google Drive (setup cell) and check that DRIVE_BASE points at "
        "the EchoNet-Dynamic folder."
    )

# Create directories
os.makedirs(PREPROCESSED_DIR_DRIVE, exist_ok=True)
os.makedirs(LOCAL_BASE, exist_ok=True)

# Load labels CSV (later cells read the FileName, EF, EDV, ESV columns and,
# when present, the Split column)
print("\n📊 Loading labels...")
df = pd.read_csv(LABELS_CSV)

# Quick structural overview of the label table
print(f"Dataset shape: {df.shape}")
print(f"Columns: {df.columns.tolist()}")
print("\nFirst few rows:")
print(df.head())
print("\nData types:")
print(df.dtypes)
print("\nMissing values:")
print(df.isnull().sum())


Loading labels from: /Volumes/Crucial X6/medical_ai_extra/EchoNet-Dynamic/FileList.csv
Dataset shape: (10030, 9)

Columns: ['FileName', 'EF', 'ESV', 'EDV', 'FrameHeight', 'FrameWidth', 'FPS', 'NumberOfFrames', 'Split']

First few rows:
             FileName         EF         ESV         EDV  FrameHeight  \
0  0X100009310A3BD7FC  78.498406   14.881368   69.210534          112   
1  0X1002E8FBACD08477  59.101988   40.383876   98.742884          112   
2  0X1005D03EED19C65B  62.363798   14.267784   37.909734          112   
3  0X10075961BC11C88E  54.545097   33.143084   72.914210          112   
4  0X10094BA0A028EAC3  24.887742  127.581945  169.855024          112   

   FrameWidth  FPS  NumberOfFrames  Split  
0         112   50             174    VAL  
1         112   50             215  TRAIN  
2         112   50             104  TRAIN  
3         112   55             122  TRAIN  
4         112   52             207    VAL  

Data types:
FileName           object
EF                floa

## 2.5) Preprocess All Videos (Saves to Local SSD)

This cell preprocesses all videos and saves them as `.pt` files **to local SSD** (fast I/O).
- Runs **each Colab session** (local storage is cleared when session ends)
- Takes ~1-2 hours
- Training is MUCH faster from local SSD than from Drive


In [None]:
# ============================================================
# 2.5) PREPROCESS ALL VIDEOS (SAVES TO LOCAL SSD)
# ============================================================
# This converts raw .avi videos to .pt tensors for faster loading
# Saves to LOCAL SSD for fast training I/O
# Needs to run each Colab session (local storage is ephemeral)

import shutil
import gc

# Save preprocessed videos to LOCAL SSD (fast I/O for training)
# PREPROCESSED_DIR is the name the later cells (verification, dataset
# construction) read the cached .pt tensors from.
PREPROCESSED_DIR = PREPROCESSED_DIR_LOCAL
os.makedirs(PREPROCESSED_DIR, exist_ok=True)

def preprocess_single_video(video_path, num_frames=32, target_size=(112, 112)):
    """
    Decode a video into a fixed-length, resized clip tensor.

    Only the sampled frames are decoded (one ``reader.get_data`` call per
    output frame), so the full video is never held in memory.

    Args:
        video_path: Path to the video file.
        num_frames: Number of frames in the output clip. Frames are sampled
            uniformly when the video is long enough; otherwise the last frame
            is repeated to pad the clip.
        target_size: Output frame size as (H, W).

    Returns:
        float32 tensor of shape (3, num_frames, H, W) with pixel values
        divided by 255, or ``None`` if the video could not be processed
        (the error is printed).
    """
    reader = None
    try:
        reader = imageio.get_reader(video_path)
        total_frames = reader.count_frames()
        if total_frames <= 0:
            raise ValueError("video contains no frames")

        # Calculate which frame indices to sample
        if total_frames >= num_frames:
            indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        else:
            # Short video: take every frame, then repeat the last one
            padding = [total_frames - 1] * (num_frames - total_frames)
            indices = np.array(list(range(total_frames)) + padding)

        # Pre-allocate output tensor (C, T, H, W)
        video_tensor = np.zeros(
            (3, num_frames, target_size[0], target_size[1]), dtype=np.float32
        )

        for out_idx, frame_idx in enumerate(indices):
            frame = reader.get_data(int(frame_idx))  # single (H, W, C) frame

            # HWC -> CHW, resize, then scale to [0, 1]
            frame_t = torch.from_numpy(np.transpose(frame, (2, 0, 1))).float()
            frame_resized = resize(frame_t, target_size, antialias=True).numpy()
            video_tensor[:, out_idx, :, :] = frame_resized / 255.0

            # Free intermediate memory
            del frame, frame_t, frame_resized

        return torch.from_numpy(video_tensor)

    except Exception as e:
        print(f"Error processing {video_path}: {e}")
        return None

    finally:
        # Close the reader on every path, including failures part-way through
        # decoding, so reader resources are not leaked across the long
        # preprocessing run.
        if reader is not None:
            reader.close()

# Check how many are already preprocessed.
# splitext removes only the trailing ".pt"; str.replace would strip every
# occurrence of ".pt" inside the name.
existing = set(os.path.splitext(f)[0] for f in os.listdir(PREPROCESSED_DIR) if f.endswith('.pt'))
to_process = df[~df['FileName'].isin(existing)]

print("📊 Preprocessing Status:")
print(f"   Already preprocessed: {len(existing):,}")
print(f"   Remaining to process: {len(to_process):,}")
print(f"   Total videos: {len(df):,}")

if len(to_process) > 0:
    print("\n🔄 Starting preprocessing...")

    failed = []
    progress = tqdm(to_process['FileName'], total=len(to_process), desc="Preprocessing")
    for visited, filename in enumerate(progress, start=1):
        # Accept FileName values with or without the .avi extension, matching
        # how EchoDataset resolves video paths.
        video_name = filename if filename.endswith('.avi') else filename + '.avi'
        video_path = os.path.join(VIDEOS_DIR, video_name)
        output_path = os.path.join(PREPROCESSED_DIR, f"{filename}.pt")

        if os.path.exists(output_path):
            continue

        video_tensor = preprocess_single_video(video_path)

        if video_tensor is not None:
            # Write to a temporary name and rename into place, so an
            # interrupted session never leaves a truncated .pt file that a
            # later run would treat as already preprocessed.
            tmp_path = output_path + ".tmp"
            torch.save(video_tensor, tmp_path)
            os.replace(tmp_path, output_path)
            del video_tensor  # Free immediately after saving
        else:
            failed.append(filename)

        # Memory cleanup every 10 videos; counted with the loop counter
        # because the DataFrame index labels of the filtered frame are not
        # contiguous.
        if visited % 10 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    print("\n✅ Preprocessing complete!")
    print(f"   Successful: {len(to_process) - len(failed):,}")
    print(f"   Failed: {len(failed)}")
else:
    print("\n✅ All videos already preprocessed!")


## 2.6) Verify Preprocessing Complete

Quick check that all videos were preprocessed successfully.


In [None]:
# ============================================================
# 2.6) VERIFY PREPROCESSING COMPLETE
# ============================================================
# Quick check that preprocessing completed successfully

local_pt_files = [f for f in os.listdir(PREPROCESSED_DIR) if f.endswith('.pt')]

# Count only tensors that belong to a video listed in the labels table, so
# stray or leftover .pt files cannot make an incomplete run look complete.
expected_names = set(df['FileName'])
matched_count = sum(os.path.splitext(f)[0] in expected_names for f in local_pt_files)

print("📊 Preprocessing Status:")
print(f"   Preprocessed videos: {matched_count:,}")
print(f"   Total in dataset: {len(df):,}")
print(f"   Location: {PREPROCESSED_DIR} (local SSD)")

if matched_count >= len(df) * 0.95:  # Allow 5% failure
    print("\n✅ Preprocessing complete! Ready to train.")
else:
    print(f"\n⚠️ Only {matched_count}/{len(df)} videos preprocessed.")
    print("   Run the preprocessing cell above to complete.")


In [30]:
# ============================================================
# 3) PREPROCESSING
# ============================================================

def _read_all_frames_imageio(video_path):
    """Decode every frame with imageio and return a (T, C, H, W) array."""
    reader = imageio.get_reader(video_path)
    try:
        # Each frame arrives as (H, W, C); convert to (C, H, W)
        frames = [np.transpose(frame, (2, 0, 1)) for frame in reader]
    finally:
        # Close the reader even when decoding fails part-way through
        reader.close()
    return np.stack(frames, axis=0)


def preprocess_video(video_path, num_frames=32, target_size=(112, 112)):
    """
    Preprocess video: sample frames, resize, normalize.

    The video is read with torchvision when it imported successfully, falling
    back to imageio if torchvision fails on this file (or is unavailable).

    Args:
        video_path: Path to video file
        num_frames: Number of frames to sample (default: 32). Frames are
            sampled uniformly; shorter videos are padded by repeating the
            last frame.
        target_size: Target frame size (H, W) (default: (112, 112))

    Returns:
        float32 tensor of shape (C, num_frames, H, W) with pixel values
        divided by 255. If anything goes wrong the error is printed and an
        all-zero tensor of shape (3, num_frames, H, W) is returned instead.
        NOTE(review): callers cannot tell that fallback from real data.
    """
    try:
        # Try to read video using torchvision first, fallback to imageio
        if USE_TORCHVISION:
            try:
                video, _audio, _info = read_video(video_path, output_format="TCHW")
                video_np = video.numpy()  # (T, C, H, W)
            except Exception:
                if not USE_IMAGEIO:
                    raise
                video_np = _read_all_frames_imageio(video_path)
        elif USE_IMAGEIO:
            video_np = _read_all_frames_imageio(video_path)
        else:
            raise ImportError("Neither torchvision nor imageio available for video reading")

        total = video_np.shape[0]
        if total == 0:
            raise ValueError("video contains no frames")

        # Sample exactly num_frames uniformly
        if total >= num_frames:
            indices = np.linspace(0, total - 1, num_frames, dtype=int)
        else:
            # If video has fewer frames, repeat last frame
            indices = list(range(total)) + [total - 1] * (num_frames - total)

        sampled_frames = video_np[indices]  # (num_frames, C, H, W)

        # Resize all sampled frames in one call; `resize` operates on the
        # trailing (H, W) dimensions, so the leading frame axis is a batch.
        resized = resize(torch.from_numpy(sampled_frames).float(), target_size, antialias=True)

        # (num_frames, C, H, W) -> (C, num_frames, H, W), then scale to [0, 1]
        return (resized.permute(1, 0, 2, 3) / 255.0).contiguous()

    except Exception as e:
        print(f"Error processing {video_path}: {e}")
        # Return zero tensor as fallback
        return torch.zeros((3, num_frames, target_size[0], target_size[1]), dtype=torch.float32)

# Test preprocessing function
print("Preprocessing function defined successfully!")


Preprocessing function defined successfully!


In [31]:
# ============================================================
# 4) BUILD PYTORCH DATASET
# ============================================================

class EchoDataset(Dataset):
    """Dataset of (video clip, [EF, EDV, ESV]) pairs.

    Clips are read from cached ``.pt`` tensors when possible and are
    otherwise built from the raw ``.avi`` file on demand.
    """

    def __init__(self, dataframe, videos_dir, num_frames=32, target_size=(112, 112), 
                 preprocessed_dir=None, use_preprocessed=True):
        """
        Args:
            dataframe: DataFrame with columns: FileName, EF, EDV, ESV
            videos_dir: Base directory containing video files
            num_frames: Number of frames to sample per video
            target_size: Target frame size (H, W)
            preprocessed_dir: Directory with preprocessed .pt files (optional)
            use_preprocessed: If True, use preprocessed files if available
        """
        # Positional lookups below rely on a clean 0..N-1 index
        self.dataframe = dataframe.reset_index(drop=True)
        self.videos_dir = videos_dir
        self.num_frames = num_frames
        self.target_size = target_size
        self.preprocessed_dir = preprocessed_dir
        self.use_preprocessed = use_preprocessed

    def __len__(self):
        return len(self.dataframe)

    def _load_cached(self, name):
        """Return the cached clip for ``name``, or None if it is unusable.

        A clip is unusable when caching is disabled, the file is missing, or
        the stored tensor does not have the configured shape.
        """
        if not (self.use_preprocessed and self.preprocessed_dir):
            return None

        cache_path = os.path.join(self.preprocessed_dir, f"{name}.pt")
        if not os.path.exists(cache_path):
            return None

        clip = torch.load(cache_path)
        wanted_shape = (3, self.num_frames, self.target_size[0], self.target_size[1])
        return clip if clip.shape == wanted_shape else None

    def __getitem__(self, idx):
        record = self.dataframe.iloc[idx]

        # Cached tensor first; decode the raw video only as a fallback
        video_tensor = self._load_cached(record['FileName'])
        if video_tensor is None:
            avi_name = record['FileName']
            if not avi_name.endswith('.avi'):
                avi_name = avi_name + '.avi'
            video_tensor = preprocess_video(
                os.path.join(self.videos_dir, avi_name), self.num_frames, self.target_size
            )

        # Label order is fixed: EF, EDV, ESV
        targets = [float(record[column]) for column in ('EF', 'EDV', 'ESV')]
        labels = torch.tensor(targets, dtype=torch.float32)

        return video_tensor, labels

print("EchoDataset class defined successfully!")


EchoDataset class defined successfully!


## 5) Train/Val/Test Split


In [None]:
# ============================================================
# 5) SPLITTING
# ============================================================

# Check if Split column exists
if 'Split' in df.columns:
    print("Using existing Split column from CSV")
    # Handle both uppercase (TRAIN/VAL/TEST) and lowercase (train/val/test) values
    df['Split'] = df['Split'].str.upper()
    train_df = df[df['Split'].isin(['TRAIN', 'TRAINING'])].copy()
    val_df = df[df['Split'].isin(['VAL', 'VALIDATION'])].copy()
    test_df = df[df['Split'].isin(['TEST', 'TESTING'])].copy()

    print(f"Train: {len(train_df)} samples")
    print(f"Val: {len(val_df)} samples")
    print(f"Test: {len(test_df)} samples")
else:
    print("Split column not found. Creating train/val/test splits...")
    # First split: train (70%) and temp (30%)
    train_df, temp_df = train_test_split(
        df, test_size=0.3, random_state=42, shuffle=True
    )
    # Second split: val (15%) and test (15%) from temp
    val_df, test_df = train_test_split(
        temp_df, test_size=0.5, random_state=42, shuffle=True
    )

    print(f"Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
    print(f"Val: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
    print(f"Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")

# Create datasets (PREPROCESSED_DIR is set in the preprocessing cell, section 2.5,
# to the local SSD for fast I/O).
# Only .pt files count as preprocessed videos; if there are none, fall back to
# on-the-fly processing (slow).
if PREPROCESSED_DIR is not None and os.path.isdir(PREPROCESSED_DIR):
    preprocessed_files = [f for f in os.listdir(PREPROCESSED_DIR) if f.endswith('.pt')]
else:
    preprocessed_files = []
use_preprocessed = len(preprocessed_files) > 0

if use_preprocessed:
    print(f"✅ Using preprocessed videos from: {PREPROCESSED_DIR}")
    print(f"   Found {len(preprocessed_files):,} preprocessed files (local SSD - fast!)")
else:
    print("⚠️  No preprocessed videos found. Videos will be processed on-the-fly (slow).")
    print("   Run the preprocessing cells above first!")

train_dataset = EchoDataset(train_df, VIDEOS_DIR, preprocessed_dir=PREPROCESSED_DIR,
                            use_preprocessed=use_preprocessed)
val_dataset = EchoDataset(val_df, VIDEOS_DIR, preprocessed_dir=PREPROCESSED_DIR,
                          use_preprocessed=use_preprocessed)
test_dataset = EchoDataset(test_df, VIDEOS_DIR, preprocessed_dir=PREPROCESSED_DIR,
                           use_preprocessed=use_preprocessed)

print("\nDatasets created successfully!")


Using existing Split column from CSV
Train: 7465 samples
Val: 1288 samples
Test: 1277 samples
‚úÖ Using preprocessed videos from: /Volumes/Crucial X6/medical_ai_extra/EchoNet-Dynamic/PreprocessedVideos
   Found 10030 preprocessed files

Datasets created successfully!


## 6) 3D CNN Model


In [33]:
# ============================================================
# 6) MODEL (3D CNN)
# ============================================================

class EchoNet3DCNN(nn.Module):
    """3D-CNN encoder with a shared projection and three regression heads.

    ``forward`` takes a clip batch of shape (B, 3, T, H, W) and returns the
    tuple ``(EF, EDV, ESV)``, each of shape (B,).
    """

    def __init__(self, feature_dim=256):
        super().__init__()

        # (in_channels, out_channels, pooling window) per encoder stage.
        # The first stage pools only spatially; the last stage has no max-pool
        # because global average pooling follows it.
        stage_specs = (
            (3, 64, (1, 2, 2)),
            (64, 128, (2, 2, 2)),
            (128, 256, (2, 2, 2)),
            (256, 512, None),
        )

        encoder_layers = []
        for in_channels, out_channels, pool_window in stage_specs:
            encoder_layers.append(
                nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1))
            )
            encoder_layers.append(nn.BatchNorm3d(out_channels))
            encoder_layers.append(nn.ReLU(inplace=True))
            if pool_window is not None:
                encoder_layers.append(nn.MaxPool3d(kernel_size=pool_window, stride=pool_window))

        # Collapse whatever (T, H, W) remains into one 512-d vector per clip
        encoder_layers.append(nn.AdaptiveAvgPool3d((1, 1, 1)))
        self.encoder = nn.Sequential(*encoder_layers)

        self.flatten = nn.Flatten()

        # Shared projection from the 512 encoder channels to feature_dim
        self.feature_proj = nn.Sequential(
            nn.Linear(512, feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )

        # One identically shaped regression head per target
        self.EF_head = self._make_head(feature_dim)
        self.EDV_head = self._make_head(feature_dim)
        self.ESV_head = self._make_head(feature_dim)

    @staticmethod
    def _make_head(feature_dim):
        """Build a two-layer MLP mapping feature_dim features to one scalar."""
        return nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        # (B, 3, T, H, W) -> (B, 512, 1, 1, 1) -> (B, 512) -> (B, feature_dim)
        shared = self.feature_proj(self.flatten(self.encoder(x)))

        # Each head yields (B, 1); drop the trailing axis to get (B,)
        heads = (self.EF_head, self.EDV_head, self.ESV_head)
        return tuple(head(shared).squeeze(-1) for head in heads)

# Initialize model
model = EchoNet3DCNN(feature_dim=256).to(device)
print(f"Model created. Total parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test forward pass.
# Run it in eval mode and without autograd: a train-mode pass would fold this
# random-noise batch into the BatchNorm running statistics and keep the
# activations alive for a backward pass that never happens.
test_input = torch.randn(2, 3, 32, 112, 112).to(device)
model.eval()
with torch.no_grad():
    EF_test, EDV_test, ESV_test = model(test_input)
model.train()  # restore the default mode of a freshly constructed module
print("Test forward pass successful!")
print(f"  EF output shape: {EF_test.shape}")
print(f"  EDV output shape: {EDV_test.shape}")
print(f"  ESV output shape: {ESV_test.shape}")


Model created. Total parameters: 4,883,331


Test forward pass successful!
  EF output shape: torch.Size([2])
  EDV output shape: torch.Size([2])
  ESV output shape: torch.Size([2])


In [None]:
# ============================================================
# 7) TRAINING SETUP
# ============================================================

# Hyperparameters
# Update device if GPU becomes available (device already set in imports)
mps_backend = getattr(torch.backends, "mps", None)  # absent on older PyTorch builds
if device.type == 'cpu' and mps_backend is not None and mps_backend.is_available():
    device = torch.device('mps')
    print("Switching to MPS (Apple Silicon GPU)")
elif device.type == 'cpu' and torch.cuda.is_available():
    device = torch.device('cuda')
    print("Switching to CUDA GPU")

# The model was placed on the device chosen in the imports cell. Move it to the
# (possibly updated) device BEFORE building the optimizer; otherwise the
# training loop sends batches to a device the model is not on.
model = model.to(device)

BATCH_SIZE = 8 if device.type != 'cpu' else 4  # Larger batch on GPU
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5  # Reduced for testing - change back to 40 for full training
NUM_WORKERS = 0  # Use main process to avoid multiprocessing issues on Mac

# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

# Loss function (MAE for each task)
criterion = nn.L1Loss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler: halve the LR after 5 epochs without improvement in
# the monitored value. The `verbose` argument is deliberately not passed: it is
# deprecated and may be rejected by newer PyTorch releases.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

print("Training setup complete!")
print(f"  Device: {device}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Number of epochs: {NUM_EPOCHS}")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Workers: {NUM_WORKERS}")

# The slow-training warning applies only when running on CPU
if device.type == 'cpu':
    print("\n⚠️  NOTE: Training on CPU with on-the-fly video preprocessing is very slow.")
    print(f"    Estimated time per epoch: ~{len(train_loader) * 100 / 3600:.1f} hours (at ~100s/iter)")
    print("    Consider: 1) Using GPU, 2) Preprocessing videos ahead of time, or 3) Reducing epochs for testing")


Training setup complete!
  Device: cpu
  Batch size: 4
  Learning rate: 0.0001
  Number of epochs: 40
  Train batches: 1867
  Val batches: 322
  Workers: 0

‚ö†Ô∏è  NOTE: Training on CPU with on-the-fly video preprocessing is very slow.
    Estimated time per epoch: ~51.9 hours (at ~100s/iter)
    Consider: 1) Using GPU, 2) Preprocessing videos ahead of time, or 3) Reducing epochs for testing


## 8) Training Loop


In [43]:
# ============================================================
# 8) TRAINING LOOP
# ============================================================

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    The model must return ``(EF, EDV, ESV)`` predictions and each batch must
    provide labels of shape (B, 3) in the same column order. The three
    per-task losses are summed without weighting and back-propagated together.

    Returns:
        dict with keys 'total_loss', 'EF_loss', 'EDV_loss', 'ESV_loss', each
        the mean of the per-batch values over ``len(loader)`` batches.
    """
    model.train()
    running = {'total_loss': 0.0, 'EF_loss': 0.0, 'EDV_loss': 0.0, 'ESV_loss': 0.0}

    for clips, targets in tqdm(loader, desc="Training"):
        clips, targets = clips.to(device), targets.to(device)  # targets: (B, 3) [EF, EDV, ESV]

        ef_out, edv_out, esv_out = model(clips)

        # One loss per task, compared against its label column
        ef_term = criterion(ef_out, targets[:, 0])
        edv_term = criterion(edv_out, targets[:, 1])
        esv_term = criterion(esv_out, targets[:, 2])
        batch_total = ef_term + edv_term + esv_term

        optimizer.zero_grad()
        batch_total.backward()
        optimizer.step()

        # `running` keeps insertion order: total, EF, EDV, ESV
        for key, term in zip(running, (batch_total, ef_term, edv_term, esv_term)):
            running[key] += term.item()

    num_batches = len(loader)
    return {key: accumulated / num_batches for key, accumulated in running.items()}

def validate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Uses eval mode and disables gradient tracking. The model must return
    ``(EF, EDV, ESV)`` predictions and labels must be (B, 3) in that order.

    Returns:
        dict with keys 'total_loss', 'EF_loss', 'EDV_loss', 'ESV_loss', each
        the mean of the per-batch values over ``len(loader)`` batches.
    """
    model.eval()
    running = dict.fromkeys(('total_loss', 'EF_loss', 'EDV_loss', 'ESV_loss'), 0.0)

    with torch.no_grad():
        for clips, targets in tqdm(loader, desc="Validating"):
            clips, targets = clips.to(device), targets.to(device)

            ef_out, edv_out, esv_out = model(clips)

            # One loss per task, compared against its label column
            ef_term = criterion(ef_out, targets[:, 0])
            edv_term = criterion(edv_out, targets[:, 1])
            esv_term = criterion(esv_out, targets[:, 2])
            batch_total = ef_term + edv_term + esv_term

            # `running` keeps insertion order: total, EF, EDV, ESV
            for key, term in zip(running, (batch_total, ef_term, edv_term, esv_term)):
                running[key] += term.item()

    num_batches = len(loader)
    return {key: accumulated / num_batches for key, accumulated in running.items()}

print("Training functions defined!")


Training functions defined!


In [None]:
# Training history: one list per tracked metric, one entry per finished epoch
METRIC_KEYS = ('total_loss', 'EF_loss', 'EDV_loss', 'ESV_loss')
train_history = {key: [] for key in METRIC_KEYS}
val_history = {key: [] for key in METRIC_KEYS}

best_val_loss = float('inf')
# Create output directory on Google Drive (persistent across sessions)
OUTPUT_DIR = os.path.join(DRIVE_BASE, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
best_model_path = os.path.join(OUTPUT_DIR, "echonet_multi_task.pth")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Model will be saved to: {best_model_path}")

print("Starting training...")
print("=" * 60)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 60)

    # Train (recorded immediately so the numbers survive an interruption
    # during validation)
    train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
    for key in METRIC_KEYS:
        train_history[key].append(train_metrics[key])

    # Validate
    val_metrics = validate_epoch(model, val_loader, criterion, device)
    for key in METRIC_KEYS:
        val_history[key].append(val_metrics[key])

    # Print metrics
    print(f"\nTrain Loss - Total: {train_metrics['total_loss']:.4f}, "
          f"EF: {train_metrics['EF_loss']:.4f}, "
          f"EDV: {train_metrics['EDV_loss']:.4f}, "
          f"ESV: {train_metrics['ESV_loss']:.4f}")
    print(f"Val Loss   - Total: {val_metrics['total_loss']:.4f}, "
          f"EF: {val_metrics['EF_loss']:.4f}, "
          f"EDV: {val_metrics['EDV_loss']:.4f}, "
          f"ESV: {val_metrics['ESV_loss']:.4f}")

    # Update learning rate (the scheduler monitors the total validation loss)
    scheduler.step(val_metrics['total_loss'])

    # Save best model: only the weights (state_dict) are written, and only
    # when the total validation loss improves
    if val_metrics['total_loss'] < best_val_loss:
        best_val_loss = val_metrics['total_loss']
        torch.save(model.state_dict(), best_model_path)
        print(f"✓ Saved best model (val_loss: {best_val_loss:.4f})")

print("\n" + "=" * 60)
print("Training completed!")
print(f"Best validation loss: {best_val_loss:.4f}")
print(f"Best model saved to: {best_model_path}")


Output directory: /Volumes/Crucial X6/medical_ai_extra/echonet_outputs
Model will be saved to: /Volumes/Crucial X6/medical_ai_extra/echonet_outputs/echonet_multi_task.pth
Starting training...

Epoch 1/40
------------------------------------------------------------


Training:   0%|          | 0/1867 [00:00<?, ?it/s]

Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1867/1867 [7:26:46<00:00, 14.36s/it]  
Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 322/322 [26:32<00:00,  4.95s/it]



Train Loss - Total: 71.5037, EF: 14.1937, EDV: 35.6947, ESV: 21.6153
Val Loss   - Total: 59.1177, EF: 9.9926, EDV: 29.4616, ESV: 19.6634
‚úì Saved best model (val_loss: 59.1177)

Epoch 2/40
------------------------------------------------------------


Training:  36%|‚ñà‚ñà‚ñà‚ñå      | 664/1867 [2:34:24<4:39:45, 13.95s/it]


KeyboardInterrupt: 

## 9) Testing/Evaluation


In [None]:
# ============================================================
# 9) TESTING
# ============================================================

# Load best model
print(f"Loading best model from: {best_model_path}")
model.load_state_dict(torch.load(best_model_path, map_location=device))
model.eval()

# Create test loader
test_loader = DataLoader(
    test_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=False, 
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available()
)

# Evaluate on test set
print("\nEvaluating on test set...")
print("=" * 60)

# Column order of the label tensor and of the model's output tuple
TASKS = ('EF', 'EDV', 'ESV')
collected_pred = {task: [] for task in TASKS}
collected_true = {task: [] for task in TASKS}

with torch.no_grad():
    for videos, labels in tqdm(test_loader, desc="Testing"):
        outputs = model(videos.to(device))

        # Store predictions and ground truth per task (labels stay on the CPU;
        # they are only needed as numpy values here)
        for column, (task, output) in enumerate(zip(TASKS, outputs)):
            collected_pred[task].extend(output.cpu().numpy())
            collected_true[task].extend(labels[:, column].cpu().numpy())

# Convert to numpy arrays (these names are reused by the plotting cell below)
all_EF_pred = np.array(collected_pred['EF'])
all_EDV_pred = np.array(collected_pred['EDV'])
all_ESV_pred = np.array(collected_pred['ESV'])
all_EF_true = np.array(collected_true['EF'])
all_EDV_true = np.array(collected_true['EDV'])
all_ESV_true = np.array(collected_true['ESV'])

# Calculate MAE for each task (per-sample mean over the whole test set)
EF_mae = np.mean(np.abs(all_EF_pred - all_EF_true))
EDV_mae = np.mean(np.abs(all_EDV_pred - all_EDV_true))
ESV_mae = np.mean(np.abs(all_ESV_pred - all_ESV_true))

# Print results
print("\n" + "=" * 60)
print("TEST SET RESULTS")
print("=" * 60)
print(f"EF  MAE: {EF_mae:.4f}")
print(f"EDV MAE: {EDV_mae:.4f}")
print(f"ESV MAE: {ESV_mae:.4f}")
print("=" * 60)

# Additional metrics
from sklearn.metrics import mean_squared_error, r2_score

EF_rmse = np.sqrt(mean_squared_error(all_EF_true, all_EF_pred))
EDV_rmse = np.sqrt(mean_squared_error(all_EDV_true, all_EDV_pred))
ESV_rmse = np.sqrt(mean_squared_error(all_ESV_true, all_ESV_pred))

EF_r2 = r2_score(all_EF_true, all_EF_pred)
EDV_r2 = r2_score(all_EDV_true, all_EDV_pred)
ESV_r2 = r2_score(all_ESV_true, all_ESV_pred)

print("\nAdditional Metrics:")
print(f"EF  - RMSE: {EF_rmse:.4f}, R²: {EF_r2:.4f}")
print(f"EDV - RMSE: {EDV_rmse:.4f}, R²: {EDV_r2:.4f}")
print(f"ESV - RMSE: {ESV_rmse:.4f}, R²: {ESV_r2:.4f}")
print("=" * 60)


In [None]:
# Optional: Plot training curves
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# One panel per tracked metric, filled row by row:
# (history key, panel title)
curve_panels = (
    ('total_loss', 'Total Loss'),
    ('EF_loss', 'EF Loss'),
    ('EDV_loss', 'EDV Loss'),
    ('ESV_loss', 'ESV Loss'),
)

for ax, (history_key, panel_title) in zip(axes.flat, curve_panels):
    ax.plot(train_history[history_key], label='Train')
    ax.plot(val_history[history_key], label='Val')
    ax.set_title(panel_title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)

plt.tight_layout()

# Save next to the model checkpoint, then display inline
training_curves_path = os.path.join(OUTPUT_DIR, 'training_curves.png')
plt.savefig(training_curves_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Training curves saved to '{training_curves_path}'")


In [None]:
# Optional: Plot predictions vs ground truth
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (task name, ground truth, predictions, MAE, R²) for each panel, left to right.
# All values come from the test-evaluation cell above.
scatter_panels = (
    ('EF', all_EF_true, all_EF_pred, EF_mae, EF_r2),
    ('EDV', all_EDV_true, all_EDV_pred, EDV_mae, EDV_r2),
    ('ESV', all_ESV_true, all_ESV_pred, ESV_mae, ESV_r2),
)

for ax, (task, y_true, y_pred, mae, r2) in zip(axes, scatter_panels):
    ax.scatter(y_true, y_pred, alpha=0.5)

    # Dashed identity line spanning the ground-truth range: a perfect model
    # would put every point on it
    low, high = y_true.min(), y_true.max()
    ax.plot([low, high], [low, high], 'r--', lw=2)

    ax.set_xlabel(f'True {task}')
    ax.set_ylabel(f'Predicted {task}')
    ax.set_title(f'{task} (MAE: {mae:.4f}, R²: {r2:.4f})')
    ax.grid(True)

plt.tight_layout()
predictions_plot_path = os.path.join(OUTPUT_DIR, 'predictions_vs_ground_truth.png')
plt.savefig(predictions_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"Predictions vs ground truth plots saved to '{predictions_plot_path}'")
