In [5]:
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score
from tqdm import tqdm

class PatientDataset(Dataset):
    def __init__(self, json_path, root_dir, id_range, modalities, transform=None):
        self.root_dir = root_dir
        self.start_id, self.end_id = id_range
        self.transform = transform
        self.modalities = modalities
        
        with open(json_path, 'r') as f:
            data = json.load(f)
        self.data, self.entry_count = self._prepare_data(data)
        
    def _prepare_data(self, data):
        prepared_data = []
        entry_count = 0
        
        for patient_id in range(self.start_id, self.end_id + 1):
            if str(patient_id) in data:
                patient_data = data[str(patient_id)]
                for side in ['Right', 'Left', 'Right1', 'Left1', 'Right2', 'Left2', 'Right3', 'Left3']:
                    if side in patient_data:
                        side_data = patient_data[side]
                        label = side_data['Label']
                        image_paths = side_data.get("Paths", {})
                        
                        images = []
                        skip_entry = False
                        
                        for img_type in self.modalities:
                            if img_type in image_paths:
                                path = image_paths[img_type]
                                full_path = os.path.join(self.root_dir, path)
                                if os.path.exists(full_path):
                                    images.append(full_path)
                                else:
                                    skip_entry = True
                                    break
                            else:
                                skip_entry = True
                                break
                        
                        if skip_entry or len(images) != len(self.modalities):
                            continue
                        
                        prepared_data.append({
                            'patient_id': patient_id,
                            'side': side,
                            'images': images,
                            'label': label
                        })
                        entry_count += 1
        
        return prepared_data, entry_count
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        images = []
        
        for img_path in item['images']:
            img = Image.open(img_path).convert("RGB")
            if self.transform:
                img = self.transform(img)
            images.append(img)
        
        images = torch.stack(images)
        label_tensor = torch.tensor(item['label'], dtype=torch.float32)
        
        return images, label_tensor, item['patient_id'], item['side']

class OCTEncoder(nn.Module):
    def __init__(self, output_dim):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.project = nn.Linear(512, output_dim)
    
    def forward(self, x):
        x = self.encoder(x)
        x = x.squeeze()
        x = self.project(x)
        return x

class CrossModalFusion(nn.Module):
    def __init__(self, num_modalities, num_patches, d_model, num_heads=8):
        super().__init__()
        self.num_modalities = num_modalities
        self.num_patches = num_patches
        self.d_model = d_model
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, num_modalities * num_patches + 1, d_model))
        self.modality_embed = nn.Parameter(torch.randn(1, num_modalities, d_model))
        
        self.norm1 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        
        self.classifier = nn.Linear(d_model, 1)
    
    def prepare_inputs(self, features):
        B = features.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, features], dim=1)
        x = x + self.pos_embed
        modality_embeddings = self.modality_embed.repeat_interleave(self.num_patches, dim=1)
        x[:, 1:] = x[:, 1:] + modality_embeddings.expand(B, -1, -1)
        return x
    
    def forward(self, features):
        x = self.prepare_inputs(features)
        
        residual = x
        x = self.norm1(x)
        x, _ = self.mha(x, x, x)
        x = x + residual
        
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        
        cls_token = x[:, 0]
        output = self.classifier(cls_token)
        return torch.sigmoid(output)

class OCTMultiModalModel(nn.Module):
    def __init__(self, num_modalities, num_patches=1, d_model=256):
        super().__init__()
        self.encoder = OCTEncoder(d_model)
        self.fusion = CrossModalFusion(num_modalities, num_patches, d_model)
    
    def forward(self, x):
        B, M, C, H, W = x.shape
        features = []
        for i in range(M):
            modality_features = self.encoder(x[:, i])
            features.append(modality_features.unsqueeze(1))
        features = torch.cat(features, dim=1)
        output = self.fusion(features)
        return output

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    progress_bar = tqdm(train_loader, desc='Training')
    for images, labels, _, _ in progress_bar:
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        preds = (outputs.squeeze() > 0.5).cpu().detach().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())
        
        # Update progress bar
        accuracy = accuracy_score(all_labels, all_preds)
        progress_bar.set_postfix({'loss': total_loss / (progress_bar.n + 1),
                                'accuracy': accuracy})
    
    epoch_loss = total_loss / len(train_loader)
    epoch_accuracy = accuracy_score(all_labels, all_preds)
    epoch_auc = roc_auc_score(all_labels, all_preds)
    
    return epoch_loss, epoch_accuracy, epoch_auc

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc='Validation')
        for images, labels, _, _ in progress_bar:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels)
            total_loss += loss.item()
            
            preds = (outputs.squeeze() > 0.5).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
            
            # Update progress bar
            accuracy = accuracy_score(all_labels, all_preds)
            progress_bar.set_postfix({'loss': total_loss / (progress_bar.n + 1),
                                    'accuracy': accuracy})
    
    val_loss = total_loss / len(val_loader)
    val_accuracy = accuracy_score(all_labels, all_preds)
    val_auc = roc_auc_score(all_labels, all_preds)
    
    return val_loss, val_accuracy, val_auc

def main():
    # Configuration
    json_path = "train.json"
    root_dir = ""
    modalities = ["modality1", "modality2", "modality3"]  # Replace with your modality names
    
    # Data transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225])
    ])
    
    # Create datasets
    train_dataset = PatientDataset(
        json_path=json_path,
        root_dir=root_dir,
        id_range=(1, 800),  # Training patient IDs
        modalities=modalities,
        transform=transform
    )
    
    val_dataset = PatientDataset(
        json_path=json_path,
        root_dir=root_dir,
        id_range=(801, 1000),  # Validation patient IDs
        modalities=modalities,
        transform=transform
    )
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)
    
    # Initialize model and training components
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = OCTMultiModalModel(
        num_modalities=len(modalities),
        d_model=256
    ).to(device)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    num_epochs = 50
    best_val_accuracy = 0
    
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        
        # Train
        train_loss, train_accuracy, train_auc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_accuracy:.4f}, AUC: {train_auc:.4f}")
        
        # Validate
        val_loss, val_accuracy, val_auc = validate(
            model, val_loader, criterion, device
        )
        print(f"Val Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}, AUC: {val_auc:.4f}")
        
        # Save best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Saved new best model with validation accuracy: {val_accuracy:.4f}")

if __name__ == "__main__":
    main()

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [6]:
# %% Imports
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import f1_score
import torch.nn as nn

# %% Dataset Class
class PatientDataset(Dataset):
    def __init__(self, json_path, root_dir, id_range, transform=None):
        self.root_dir = root_dir
        self.start_id, self.end_id = id_range
        self.transform = transform

        with open(json_path, 'r') as f:
            data = json.load(f)
        self.data, self.entry_count = self._prepare_data(data)

    def _prepare_data(self, data):
        prepared_data = []
        entry_count = 0

        for patient_id, patient_data in data.items():
            if self.start_id <= patient_id <= self.end_id:
                for side in ['Right', 'Left', 'Right1', 'Left1', 'Right2', 'Left2', 'Right3', 'Left3']:
                    if side in patient_data:
                        side_data = patient_data[side]
                        label = side_data['Label']
                        image_paths = side_data.get("Paths", {})

                        images = []
                        skip_entry = False

                        for img_type in ["deep", "surface"]:
                            if img_type in image_paths:
                                path = image_paths[img_type]
                                full_path = os.path.abspath(os.path.join(self.root_dir, path))
                                if os.path.exists(full_path):
                                    images.append(full_path)
                                else:
                                    skip_entry = True
                                    break
                            else:
                                skip_entry = True
                                break

                        if skip_entry:
                            continue

                        prepared_data.append({
                            'patient_id': patient_id,
                            'side': side,
                            'images': images,
                            'label': label
                        })
                        entry_count += 1

        return prepared_data, entry_count

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
     item = self.data[idx]
     images = []
    
     for img_path in item['images']:
        img = Image.open(img_path).convert("RGB")  # Convert to grayscale (1 channel)
        if self.transform:
            img = self.transform(img)
        images.append(img)
    
     if len(images) != 2:
        raise ValueError(f"Expected 2 images per entry, but got {len(images)} for patient {item['patient_id']} side {item['side']}")
    
    # Stack the two images along the channel dimension to create a 2-channel tensor
     images = torch.stack(images)  # Shape: [2, H, W]
     label_tensor = torch.tensor(item['label'], dtype=torch.float32)

     return images, label_tensor, item['patient_id'], item['side']



# %% Define paths, ID range, and transforms
json_path = "train.json"
root_dir = "C:\\Users\\manoj\\OneDrive\\Desktop\\INTERN\\train"
id_range = ("20230402140053", "20230708145810")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create dataset and DataLoader
dataset = PatientDataset(json_path, root_dir, id_range, transform)

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, pin_memory=True)

# %% Model

In [7]:
class OCTEncoder(nn.Module):
    def __init__(self, output_dim=256):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        # Modify first conv layer to accept grayscale images
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.project = nn.Linear(512, output_dim)
    
    def forward(self, x):
        x = self.encoder(x)
        x = x.squeeze()
        x = self.project(x)
        return x

class CrossModalFusion(nn.Module):
    def __init__(self, num_modalities=2, d_model=256, num_heads=8):
        super().__init__()
        self.num_modalities = num_modalities
        self.d_model = d_model
        
        # CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        
        # Positional and modality embeddings
        self.pos_embed = nn.Parameter(torch.randn(1, num_modalities + 1, d_model))
        self.modality_embed = nn.Parameter(torch.randn(1, num_modalities, d_model))
        
        # Multi-head attention
        self.norm1 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        
        # MLP
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(0.1)
        )
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, 1)
        )
        
    def prepare_inputs(self, features):
        B = features.shape[0]
        
        # Expand CLS token for batch
        cls_tokens = self.cls_token.expand(B, -1, -1)
        
        # Concatenate CLS token with features
        x = torch.cat([cls_tokens, features], dim=1)
        
        # Add positional embeddings
        x = x + self.pos_embed
        
        # Add modality embeddings to features (not CLS token)
        x[:, 1:] = x[:, 1:] + self.modality_embed
        
        return x
    
    def forward(self, features):
        x = self.prepare_inputs(features)
        
        # Self-attention block
        residual = x
        x = self.norm1(x)
        x, _ = self.mha(x, x, x)
        x = x + residual
        
        # MLP block
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        
        # Use CLS token for classification
        cls_token = x[:, 0]
        output = self.classifier(cls_token)
        return torch.sigmoid(output)

class OCTMultiModalModel(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()
        self.encoder = OCTEncoder(d_model)
        self.fusion = CrossModalFusion(num_modalities=2, d_model=d_model)
        
    def forward(self, x):
        B, M, C, H, W = x.shape  # Batch, Modalities (2), Channels, Height, Width
        
        # Encode each modality
        features = []
        for i in range(M):
            modality_features = self.encoder(x[:, i])  # [B, d_model]
            features.append(modality_features.unsqueeze(1))  # Add modality dimension
            
        # Concatenate features from both modalities
        features = torch.cat(features, dim=1)  # [B, 2, d_model]
        
        # Pass through fusion module
        output = self.fusion(features)
        return output

# Training function
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for batch_idx, (images, labels, _, _) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.squeeze(), labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # Calculate metrics
        preds = (outputs.squeeze() > 0.5).cpu().detach().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())
        
        if (batch_idx + 1) % 10 == 0:
            print(f'Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}')
    
    epoch_loss = running_loss / len(train_loader)
    f1 = f1_score(all_labels, all_preds)
    
    return epoch_loss, f1

# Validation function
def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels, _, _ in val_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels)
            
            running_loss += loss.item()
            
            preds = (outputs.squeeze() > 0.5).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    val_loss = running_loss / len(val_loader)
    f1 = f1_score(all_labels, all_preds)
    
    return val_loss, f1

# Training setup code
def setup_training(train_loader, val_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = OCTMultiModalModel(d_model=256).to(device)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    
    num_epochs = 50
    best_val_f1 = 0
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        
        train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_f1 = validate(model, val_loader, criterion, device)
        
        print(f'Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}')
        
        scheduler.step(val_f1)
        
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Saved new best model with validation F1: {val_f1:.4f}')
    
    return model

# Start training
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = setup_training(train_loader, val_loader)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to C:\Users\manoj/.cache\torch\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:25<00:00, 1.87MB/s]



Epoch 1/50


ValueError: Using a target size (torch.Size([16, 5])) that is different to the input size (torch.Size([16])) is deprecated. Please ensure they have the same size.

In [8]:
class OCTEncoder(nn.Module):
    def __init__(self, output_dim=256):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.project = nn.Linear(512, output_dim)
    
    def forward(self, x):
        x = self.encoder(x)
        x = x.squeeze()
        x = self.project(x)
        return x

class CrossModalFusion(nn.Module):
    def __init__(self, num_modalities=2, d_model=256, num_heads=8):
        super().__init__()
        self.num_modalities = num_modalities
        self.d_model = d_model
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, num_modalities + 1, d_model))
        self.modality_embed = nn.Parameter(torch.randn(1, num_modalities, d_model))
        
        self.norm1 = nn.LayerNorm(d_model)
        self.mha = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        
        self.norm2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(0.1)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, 1)
        )
        
    def prepare_inputs(self, features):
        B = features.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, features], dim=1)
        x = x + self.pos_embed
        x[:, 1:] = x[:, 1:] + self.modality_embed
        return x
    
    def forward(self, features):
        x = self.prepare_inputs(features)
        
        residual = x
        x = self.norm1(x)
        x, _ = self.mha(x, x, x)
        x = x + residual
        
        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        
        cls_token = x[:, 0]
        output = self.classifier(cls_token)
        return output  # Shape: [B, 1]

class OCTMultiModalModel(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()
        self.encoder = OCTEncoder(d_model)
        self.fusion = CrossModalFusion(num_modalities=2, d_model=d_model)
        
    def forward(self, x):
        B, M, C, H, W = x.shape
        
        features = []
        for i in range(M):
            modality_features = self.encoder(x[:, i])
            features.append(modality_features.unsqueeze(1))
            
        features = torch.cat(features, dim=1)
        output = self.fusion(features)
        return output  # Shape: [B, 1]

def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for batch_idx, (images, labels, _, _) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs.squeeze(), labels)  # Now both will be [B]
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        preds = (outputs.squeeze() > 0.5).cpu().detach().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())
        
        if (batch_idx + 1) % 10 == 0:
            print(f'Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}')
    
    epoch_loss = running_loss / len(train_loader)
    f1 = f1_score(all_labels, all_preds)
    
    return epoch_loss, f1

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels, _, _ in val_loader:
            images, labels = images.to(device), labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels)  # Now both will be [B]
            
            running_loss += loss.item()
            
            preds = (outputs.squeeze() > 0.5).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    val_loss = running_loss / len(val_loader)
    f1 = f1_score(all_labels, all_preds)
    
    return val_loss, f1

def setup_training(train_loader, val_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    model = OCTMultiModalModel(d_model=256).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    
    num_epochs = 50
    best_val_f1 = 0
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        
        train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_f1 = validate(model, val_loader, criterion, device)
        
        print(f'Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}')
        
        scheduler.step(val_f1)
        
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'Saved new best model with validation F1: {val_f1:.4f}')
    
    return model

# Start training
if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = setup_training(train_loader, val_loader)

Using device: cuda
Using device: cuda

Epoch 1/50


ValueError: Using a target size (torch.Size([16, 5])) that is different to the input size (torch.Size([16])) is deprecated. Please ensure they have the same size.

In [9]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for batch_idx, (images, labels, _, _) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        # Print shapes for debugging
        print(f"Images shape: {images.shape}")
        print(f"Labels shape: {labels.shape}")
        
        optimizer.zero_grad()
        outputs = model(images)
        
        # Print output shape
        print(f"Raw outputs shape: {outputs.shape}")
        print(f"Squeezed outputs shape: {outputs.squeeze().shape}")
        
        # Ensure labels are the right shape
        labels = labels.view(-1)  # Reshape to [B]
        outputs = outputs.view(-1)  # Reshape to [B]
        
        print(f"Final outputs shape: {outputs.shape}")
        print(f"Final labels shape: {labels.shape}")
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        preds = (outputs > 0.5).cpu().detach().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.cpu().numpy())
        
        if (batch_idx + 1) % 10 == 0:
            print(f'Batch {batch_idx + 1}/{len(train_loader)}, Loss: {loss.item():.4f}')
            
        # Break after first batch during debugging
        break
    
    epoch_loss = running_loss / len(train_loader)
    f1 = f1_score(all_labels, all_preds)
    
    return epoch_loss, f1

def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for images, labels, _, _ in val_loader:
            images, labels = images.to(device), labels.to(device)
            
            labels = labels.view(-1)  # Reshape to [B]
            outputs = model(images).view(-1)  # Reshape to [B]
            
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            
            preds = (outputs > 0.5).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())
    
    val_loss = running_loss / len(val_loader)
    f1 = f1_score(all_labels, all_preds)
    
    return val_loss, f1

def setup_training(train_loader, val_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Print dataset sizes
    print(f"Train dataset size: {len(train_loader.dataset)}")
    print(f"Val dataset size: {len(val_loader.dataset)}")
    
    # Get a sample batch
    sample_images, sample_labels, _, _ = next(iter(train_loader))
    print(f"Sample batch images shape: {sample_images.shape}")
    print(f"Sample batch labels shape: {sample_labels.shape}")
    
    model = OCTMultiModalModel(d_model=256).to(device)
    criterion = nn.BCEWithLogitsLoss()  # Changed from BCELoss to handle raw logits
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    
    num_epochs = 50
    best_val_f1 = 0
    
    try:
        for epoch in range(num_epochs):
            print(f'\nEpoch {epoch+1}/{num_epochs}')
            
            train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_f1 = validate(model, val_loader, criterion, device)
            
            print(f'Train Loss: {train_loss:.4f}, Train F1: {train_f1:.4f}')
            print(f'Val Loss: {val_loss:.4f}, Val F1: {val_f1:.4f}')
            
            scheduler.step(val_f1)
            
            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                torch.save(model.state_dict(), 'best_model.pth')
                print(f'Saved new best model with validation F1: {val_f1:.4f}')
    except Exception as e:
        print(f"Error during training: {str(e)}")
        print("Stack trace:")
        import traceback
        traceback.print_exc()
    
    return model

# Data shape checking function
def check_data_shapes(loader):
    print("\nChecking data shapes...")
    for batch_idx, (images, labels, patient_ids, sides) in enumerate(loader):
        print(f"Batch {batch_idx}:")
        print(f"Images shape: {images.shape}")
        print(f"Labels shape: {labels.shape}")
        print(f"Sample patient ID: {patient_ids[0]}")
        print(f"Sample side: {sides[0]}")
        break

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Check data shapes before training
    print("\nChecking training data:")
    check_data_shapes(train_loader)
    print("\nChecking validation data:")
    check_data_shapes(val_loader)
    
    model = setup_training(train_loader, val_loader)

Using device: cuda

Checking training data:

Checking data shapes...
Batch 0:
Images shape: torch.Size([16, 2, 3, 224, 224])
Labels shape: torch.Size([16, 5])
Sample patient ID: 20230424165605
Sample side: Right

Checking validation data:

Checking data shapes...
Batch 0:
Images shape: torch.Size([16, 2, 3, 224, 224])
Labels shape: torch.Size([16, 5])
Sample patient ID: 20230531162152
Sample side: Left
Using device: cuda
Train dataset size: 126
Val dataset size: 32
Sample batch images shape: torch.Size([16, 2, 3, 224, 224])
Sample batch labels shape: torch.Size([16, 5])

Epoch 1/50
Images shape: torch.Size([16, 2, 3, 224, 224])
Labels shape: torch.Size([16, 5])
Raw outputs shape: torch.Size([16, 1])
Squeezed outputs shape: torch.Size([16])
Final outputs shape: torch.Size([16])
Final labels shape: torch.Size([80])
Error during training: Target size (torch.Size([80])) must be the same as input size (torch.Size([16]))
Stack trace:


Traceback (most recent call last):
  File "C:\Users\manoj\AppData\Local\Temp\ipykernel_32892\2077344965.py", line 100, in setup_training
    train_loss, train_f1 = train_epoch(model, train_loader, criterion, optimizer, device)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\manoj\AppData\Local\Temp\ipykernel_32892\2077344965.py", line 28, in train_epoch
    loss = criterion(outputs, labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\manoj\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\manoj\AppData\Local\Programs\Python\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\manoj\AppData\Local\Programs\Python\Python31