In [49]:
# %% Imports
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
from sklearn.metrics import f1_score
import torch.nn as nn

# %% Dataset Class
class PatientDataset(Dataset):
    def __init__(self, json_path, root_dir, id_range, transform=None):
        self.root_dir = root_dir
        self.start_id, self.end_id = id_range
        self.transform = transform

        with open(json_path, 'r') as f:
            data = json.load(f)
        self.data, self.entry_count = self._prepare_data(data)

    def _prepare_data(self, data):
        prepared_data = []
        entry_count = 0

        for patient_id, patient_data in data.items():
            if self.start_id <= patient_id <= self.end_id:
                for side in ['Right', 'Left', 'Right1', 'Left1', 'Right2', 'Left2', 'Right3', 'Left3']:
                    if side in patient_data:
                        side_data = patient_data[side]
                        label = side_data['Label']
                        image_paths = side_data.get("Paths", {})

                        images = []
                        skip_entry = False

                        for img_type in ["deep", "surface"]:
                            if img_type in image_paths:
                                path = image_paths[img_type]
                                full_path = os.path.abspath(os.path.join(self.root_dir, path))
                                if os.path.exists(full_path):
                                    images.append(full_path)
                                else:
                                    skip_entry = True
                                    break
                            else:
                                skip_entry = True
                                break

                        if skip_entry:
                            continue

                        prepared_data.append({
                            'patient_id': patient_id,
                            'side': side,
                            'images': images,
                            'label': label
                        })
                        entry_count += 1

        return prepared_data, entry_count

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
     item = self.data[idx]
     images = []
    
     for img_path in item['images']:
        img = Image.open(img_path).convert("RGB")  # Convert to grayscale (1 channel)
        if self.transform:
            img = self.transform(img)
        images.append(img)
    
     if len(images) != 2:
        raise ValueError(f"Expected 2 images per entry, but got {len(images)} for patient {item['patient_id']} side {item['side']}")
    
    # Stack the two images along the channel dimension to create a 2-channel tensor
     images = torch.stack(images)  # Shape: [2, H, W]
     label_tensor = torch.tensor(item['label'], dtype=torch.float32)

     return images, label_tensor, item['patient_id'], item['side']



# %% Define paths, ID range, and transforms
json_path = "train.json"
root_dir = ""
id_range = ("20230402140053", "20230708145810")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create dataset and DataLoader
dataset = PatientDataset(json_path, root_dir, id_range, transform)

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, pin_memory=True)

# %% Model


In [73]:
from collections import Counter

# Initialize a counter to count unique labels
label_counter = Counter()

# Loop through the entire dataset to count unique labels
for _, label, _, _ in dataset:
    # Convert one-hot encoded label to a unique string identifier
    active_indices = [str(i) for i, value in enumerate(label) if value == 1]
    unique_label = "_".join(active_indices)  # Join active indices with "_"
    label_counter[unique_label] += 1

# Compute total samples and distribution
total_samples = len(dataset)
distribution = {label: count / total_samples * 100 for label, count in label_counter.items()}

# Display results
print("Total Number of Cases per Unique Label:")
for label, count in label_counter.items():
    print(f"Unique Label {label}: {count}")

print("\nUnique Label Distribution (%):")
for label, dist in distribution.items():
    print(f"Unique Label {label}: {dist:.2f}%")


Total Number of Cases per Unique Label:
Unique Label 1: 3
Unique Label 0_3: 22
Unique Label 0_1: 36
Unique Label : 151
Unique Label 0: 31
Unique Label 0_1_3: 27
Unique Label 3: 21
Unique Label 0_3_4: 3
Unique Label 2_3: 4
Unique Label 1_2_3: 6
Unique Label 0_1_3_4: 4
Unique Label 0_4: 5
Unique Label 1_2: 11
Unique Label 4: 4
Unique Label 3_4: 3
Unique Label 0_1_4: 4
Unique Label 2: 1

Unique Label Distribution (%):
Unique Label 1: 0.89%
Unique Label 0_3: 6.55%
Unique Label 0_1: 10.71%
Unique Label : 44.94%
Unique Label 0: 9.23%
Unique Label 0_1_3: 8.04%
Unique Label 3: 6.25%
Unique Label 0_3_4: 0.89%
Unique Label 2_3: 1.19%
Unique Label 1_2_3: 1.79%
Unique Label 0_1_3_4: 1.19%
Unique Label 0_4: 1.49%
Unique Label 1_2: 3.27%
Unique Label 4: 1.19%
Unique Label 3_4: 0.89%
Unique Label 0_1_4: 1.19%
Unique Label 2: 0.30%


In [76]:
from imblearn.over_sampling import SMOTE
import torch
import numpy as np
from torch.utils.data import DataLoader

# Assuming PatientDataset is already defined and `json_path`, `root_dir`, `id_range`, `transform` are provided
dataset = PatientDataset(json_path, root_dir, id_range, transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 1. Extract features and labels from the dataset
features = []
labels = []

for images, targets, _, _ in dataloader:  # Ensure your dataset's __getitem__ returns correct structure
    images = images.view(images.size(0), -1)  # Flatten the images for SMOTE
    features.append(images.cpu().numpy())  # Convert to numpy
    labels.append(targets.cpu().numpy())  # Convert to numpy

# Concatenate all batches
features = np.vstack(features)
labels = np.concatenate(labels)

# 2. Apply SMOTE to balance the dataset
smote = SMOTE(sampling_strategy='auto', random_state=42)
features_resampled, labels_resampled = smote.fit_resample(features, labels)

# Check the distribution of resampled labels
unique_labels, counts = np.unique(labels_resampled, return_counts=True)
print("Resampled Label Distribution:")
for label, count in zip(unique_labels, counts):
    print(f"Label {label}: {count}")

# 3. Reshape the resampled features back to original image dimensions
# Assuming the original images have dimensions [channels, height, width] (e.g., [3, 224, 224])
channels, height, width = 3, 224, 224
features_resampled = features_resampled.reshape(-1, channels, height, width)

# Convert resampled features and labels to tensors
features_resampled_tensor = torch.tensor(features_resampled, dtype=torch.float32)
labels_resampled_tensor = torch.tensor(labels_resampled, dtype=torch.long)

# 4. Create a new Dataset class for resampled data
class ResampledPatientDataset(torch.utils.data.Dataset):
    def __init__(self, features, labels, transform=None):
        self.features = features
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        image = self.features[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Create a dataloader for the resampled dataset
resampled_dataset = ResampledPatientDataset(features_resampled_tensor, labels_resampled_tensor)
resampled_dataloader = DataLoader(resampled_dataset, batch_size=32, shuffle=True)

# Now, `resampled_dataloader` can be used for training your model.


ValueError: Imbalanced-learn currently supports binary, multiclass and binarized encoded multiclasss targets. Multilabel and multioutput targets are not supported.

In [71]:
from collections import Counter
import torch

# Initialize a counter for disease occurrences
disease_counter = Counter()

# Loop through the dataset to count diseases
for _, label_tensor, _, _ in dataset:
    # If the label is a one-hot vector, convert it to indices
    if label_tensor.ndim > 0:
        diseases = torch.nonzero(label_tensor).squeeze().tolist()
        if isinstance(diseases, int):  # Handle single disease case
            diseases = [diseases]
        for disease in diseases:
            disease_counter[disease] += 1
    else:  # Handle case where labels are single integers
        disease_counter[int(label_tensor.item())] += 1

# Total samples
total_samples = len(dataset)

# Compute disease distribution
disease_distribution = {disease: count / total_samples * 100 for disease, count in disease_counter.items()}

# Display results
print("Total Number of Cases per Disease:")
for disease, count in disease_counter.items():
    print(f"Disease {disease}: {count}")

print("\nDisease Distribution (%):")
for disease, dist in disease_distribution.items():
    print(f"Disease {disease}: {dist:.2f}%")


Total Number of Cases per Disease:
Disease 1: 91
Disease 0: 132
Disease 3: 90
Disease 4: 23
Disease 2: 22

Disease Distribution (%):
Disease 1: 27.08%
Disease 0: 39.29%
Disease 3: 26.79%
Disease 4: 6.85%
Disease 2: 6.55%


In [75]:
!pip install imblearn

Collecting imblearn
  Downloading imblearn-0.0-py2.py3-none-any.whl.metadata (355 bytes)
Collecting imbalanced-learn (from imblearn)
  Downloading imbalanced_learn-0.13.0-py3-none-any.whl.metadata (8.8 kB)
Collecting sklearn-compat<1,>=0.1 (from imbalanced-learn->imblearn)
  Downloading sklearn_compat-0.1.3-py3-none-any.whl.metadata (18 kB)
Downloading imblearn-0.0-py2.py3-none-any.whl (1.9 kB)
Downloading imbalanced_learn-0.13.0-py3-none-any.whl (238 kB)
Downloading sklearn_compat-0.1.3-py3-none-any.whl (18 kB)
Installing collected packages: sklearn-compat, imbalanced-learn, imblearn
Successfully installed imbalanced-learn-0.13.0 imblearn-0.0 sklearn-compat-0.1.3


In [47]:
for x,i,_,_ in train_loader:
    print(x.shape)
     
    break

torch.Size([4, 2, 1, 224, 224])


In [3]:
dataset.__len__()

336

In [66]:
import torch
import torch.nn as nn

class CrossModalFusion(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_modalities=2, num_patches=16, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_modalities = num_modalities
        self.num_patches = num_patches
        
        # Linear projection layer
        self.projection = nn.Linear(embed_dim, embed_dim)
        
        # Positional and modality embeddings
        self.pos_embedding = nn.Parameter(torch.randn(1, (num_patches + 1), embed_dim))
        self.modality_embeddings = nn.Parameter(torch.randn(num_modalities, 1, embed_dim))
        
        # Multi-head attention
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        
        # MLP block
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout)
        )
        
        # Layer norms
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, z0):
        """
        Input: z0 - Tensor of shape [(num_modalities * num_patches + 1), embed_dim]
        """
        B, _, D = z0.shape  # Batch size, tokens, embedding dimension
        
        # Linear projection with positional and modality embeddings
        z_proj = self.projection(z0)
        z_proj += self.pos_embedding[:, :z0.shape[1], :]
        z_proj += torch.cat(
            [self.modality_embeddings[i].repeat(1, z_proj.shape[1] // self.num_modalities, 1) for i in range(self.num_modalities)],
            dim=0
        )
        
        # Split into queries, keys, and values
        queries = z_proj[:, :self.num_patches + 1, :]  # CLS and patches of the first modality
        keys = z_proj.reshape(B, -1, D)  # All modalities and patches combined
        values = keys.clone()
        
        # Multi-head attention with residual connection
        attn_output, _ = self.multihead_attn(queries, keys, values)
        z_proj = z_proj + attn_output
        
        # Layer norm and MLP with residual connection
        z_proj = z_proj + self.mlp(self.norm2(z_proj))
        
        return z_proj


In [67]:
class MultiModalOCT(nn.Module):
    def __init__(self, embed_dim=512):
        super().__init__()
        
        # Use ResNet18 as encoder
        resnet = models.resnet18(pretrained=True)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        
        # Cross Modal Fusion block
        self.cmf = CrossModalFusion(embed_dim=embed_dim)
        
        # Final classifier
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embed_dim, 5),
            nn.Sigmoid()
        )

    def forward(self, x):
        B, M, C, H, W = x.shape  # [B, 2, 3, 224, 224]
        
        # Reshape for encoder
        x = x.view(B * M, C, H, W)
        
        # Encode each image
        features = self.encoder(x)
        features = features.squeeze(-1).squeeze(-1)
        features = features.view(B, M, -1)
        
        # Apply cross-modal fusion
        fused_features = self.cmf(features)
        
        # Flatten for classifier
        fused_features = fused_features.reshape(B, -1)
        
        # Classification
        output = self.classifier(fused_features)
        output = output.squeeze()  # Remove all singular dimensions
        
        return output



In [70]:
def train_model(model, train_loader, val_loader, num_epochs=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model = model.to(device)
    
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
    
    best_val_acc = 0.0
    
    # Print first batch shapes
    for images, labels, patient_ids, sides in train_loader:
        print(f"Input images shape: {images.shape}")
        print(f"Input labels shape: {labels.shape}")
        outputs = model(images.to(device))
        print(f"Model output shape: {outputs.shape}")
        break
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        # Training loop
        for i in range(len(train_loader)):
            images, labels, patient_ids, sides = next(iter(train_loader))
            images = images.to(device)
            labels = labels.float().to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            
            # Check if we need to adjust output shape to match labels
            if outputs.shape != labels.shape:
                if len(outputs.shape) < len(labels.shape):
                    outputs = outputs.unsqueeze(-1)
                elif len(labels.shape) < len(outputs.shape):
                    labels = labels.unsqueeze(-1)
            
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            # Calculate accuracy
            predictions = (outputs > 0.5).float()
            train_correct += torch.sum(predictions == labels).item()
            train_total += labels.numel()
            
            train_loss += loss.item()
        
        # Calculate training metrics
        train_loss = train_loss / len(train_loader)
        train_accuracy = (train_correct / train_total) * 100
        
        # Validation loop
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for i in range(len(val_loader)):
                images, labels, patient_ids, sides = next(iter(val_loader))
                images = images.to(device)
                labels = labels.float().to(device)
                
                outputs = model(images)
                
                # Match dimensions
                if outputs.shape != labels.shape:
                    if len(outputs.shape) < len(labels.shape):
                        outputs = outputs.unsqueeze(-1)
                    elif len(labels.shape) < len(outputs.shape):
                        labels = labels.unsqueeze(-1)
                
                loss = criterion(outputs, labels)
                
                # Calculate accuracy
                predictions = (outputs > 0.5).float()
                val_correct += torch.sum(predictions == labels).item()
                val_total += labels.numel()
                
                val_loss += loss.item()
        
        val_loss = val_loss / len(val_loader)
        val_accuracy = (val_correct / val_total) * 100
        
        # Learning rate scheduling based on validation accuracy
        scheduler.step(val_accuracy)
        
        # Save best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')
        
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 60)

# Initialize and train model
model = MultiModalOCT()
train_model(model, train_loader, val_loader)

Using device: cuda
Input images shape: torch.Size([4, 2, 3, 224, 224])
Input labels shape: torch.Size([4, 5])
Model output shape: torch.Size([4, 5])
Epoch 1/5
Train Loss: 0.4878, Train Accuracy: 79.55%
Val Loss: 0.6185, Val Accuracy: 60.00%
Learning Rate: 0.001000
------------------------------------------------------------
Epoch 2/5
Train Loss: 0.4549, Train Accuracy: 77.24%
Val Loss: 0.5749, Val Accuracy: 80.00%
Learning Rate: 0.001000
------------------------------------------------------------
Epoch 3/5
Train Loss: 0.4174, Train Accuracy: 82.01%
Val Loss: 0.6403, Val Accuracy: 70.00%
Learning Rate: 0.001000
------------------------------------------------------------
Epoch 4/5
Train Loss: 0.4179, Train Accuracy: 80.97%
Val Loss: 0.4040, Val Accuracy: 80.00%
Learning Rate: 0.001000
------------------------------------------------------------
Epoch 5/5
Train Loss: 0.4406, Train Accuracy: 80.15%
Val Loss: 0.5646, Val Accuracy: 80.00%
Learning Rate: 0.001000
---------------------------

In [69]:
import torch
import torch.nn as nn

class CrossModalFusion(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8, num_modalities=2, num_patches=16, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_modalities = num_modalities
        self.num_patches = num_patches
        
        # Linear projection layer
        self.projection = nn.Linear(embed_dim, embed_dim)
        
        # Positional and modality embeddings
        self.pos_embedding = nn.Parameter(torch.randn(1, (num_patches + 1) * num_modalities, embed_dim))
        self.modality_embeddings = nn.Parameter(torch.randn(num_modalities, 1, embed_dim))
        
        # Multi-head attention
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        
        # MLP block
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout)
        )
        
        # Layer norms
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, z0):
        """
        Input: z0 - Tensor of shape [batch_size, (num_modalities * num_patches + 1), embed_dim]
        """
        B, T, D = z0.shape  # Batch size, tokens, embedding dimension

        # Linear projection with positional embeddings
        z_proj = self.projection(z0)
        z_proj += self.pos_embedding[:, :T, :]

        # Modality-specific embeddings
        modality_tokens = T // self.num_modalities
        modality_embeds = torch.cat(
            [self.modality_embeddings[i].expand(B, modality_tokens, D) for i in range(self.num_modalities)],
            dim=1
        )
        z_proj += modality_embeds

        # Split into queries, keys, and values
        queries = z_proj[:, :self.num_patches + 1, :]  # CLS and patches of the first modality
        keys = z_proj  # All modalities and patches combined
        values = keys.clone()

        # Multi-head attention with residual connection
        attn_output, _ = self.multihead_attn(queries, keys, values)
        z_proj = z_proj + attn_output

        # Layer norm and MLP with residual connection
        z_proj = self.norm1(z_proj)
        z_proj = z_proj + self.mlp(self.norm2(z_proj))

        return z_proj
