In [4]:
# You may need to install the einops and timm libraries before running this notebook (uncomment the line below).
# %pip install einops timm

In [2]:
from git import Repo  # pip install gitpython
import os

repo_dir = '/kaggle/working/AttentionMIL'
# Clone only once. `git clone` refuses a destination that already exists and is not
# empty, so a plain clone_from would break "Restart & Run All" in a warm session.
if os.path.isdir(os.path.join(repo_dir, '.git')):
    repo = Repo(repo_dir)
else:
    repo = Repo.clone_from("https://github.com/AIMLab-UBC/EC-p53abnlike-AIclassifier.git", repo_dir)
repo

<git.repo.base.Repo '/kaggle/working/AttentionMIL/.git'>

In [3]:
import sys
# Make the cloned repository's packages (e.g. utils, augmentation) importable.
# The membership check keeps this cell idempotent when it is re-run.
if '/kaggle/working/AttentionMIL' not in sys.path:
    sys.path.append('/kaggle/working/AttentionMIL')

In [4]:
# Sanity check: the cloned repository should now appear on the import search path.
print(list(sys.path))

['/kaggle/lib/kagglegym', '/kaggle/lib', '/kaggle/input/prostate-cancer-grade-assessment', '/opt/conda/lib/python310.zip', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/root/.local/lib/python3.10/site-packages', '/opt/conda/lib/python3.10/site-packages', '/root/src/BigQuery_Helper', '/kaggle/working/AttentionMIL']


In [5]:
import os
import pickle
import enum

import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

from utils.utils import print_color
from augmentation.ToTensor import ToTensor

In [6]:
import os
import pandas as pd

# Competition label table: one row per slide (image_id, data_provider, isup_grade, gleason_score).
train_labels = pd.read_csv(os.path.join('/kaggle/input/prostate-cancer-grade-assessment', 'train.csv'))
test_images_dir, train_images_dir = (
    os.path.join('/kaggle/input/prostate-cancer-grade-assessment', subdir)
    for subdir in ('test_images', 'train_images')
)

# Quick look at the labels and at where the training slides live.
print(train_labels.head())
print(f"Train images path: {train_images_dir}")

                           image_id data_provider  isup_grade gleason_score
0  0005f7aaab2800f6170c399693a96917    karolinska           0           0+0
1  000920ad0b612851f8e01bcc880d9b3d    karolinska           0           0+0
2  0018ae58b01bdadc8e347995b69f99aa       radboud           4           4+4
3  001c62abd11fa4b57bf7a6c603a11bb9    karolinska           4           4+4
4  001d865e65ef5d2579c190a0e0350d8f    karolinska           0           0+0
Train images path: /kaggle/input/prostate-cancer-grade-assessment/train_images


In [7]:
import pandas as pd

def load_chunks(chunk_file_location, chunk_ids, patch_pattern):
    """Load slide/patch records from a chunk file.

    The format is chosen by file extension:

    * ``.csv``: a table with ``image_id`` and ``isup_grade`` columns. Every row becomes
      ``{"image_id": ..., "isup_grade": ...}``. ``chunk_ids`` is ignored in this branch.
    * anything else: a JSON document with a top-level ``"chunks"`` list. The entries
      at the positions listed in ``chunk_ids`` are returned, in that order.

    ``patch_pattern`` is accepted for interface compatibility and is not used.

    Returns:
        list: the selected records.

    Raises:
        KeyError: if the CSV lacks one of the required columns or the JSON lacks ``"chunks"``.
        IndexError: if a chunk id is out of range.
    """
    # `json` is not imported anywhere else in this notebook; without this local import
    # the JSON branch fails with NameError.
    import json

    if chunk_file_location.endswith('.csv'):
        data = pd.read_csv(chunk_file_location)
        # Vectorised equivalent of iterating rows; add more columns here as necessary.
        return data[["image_id", "isup_grade"]].to_dict("records")

    # Default JSON handling
    with open(chunk_file_location, encoding="utf-8") as f:
        chunks = json.load(f)['chunks']
    return [chunks[i] for i in chunk_ids]

In [8]:
import openslide
from torchvision import transforms
import torch
from torch.utils.data import Dataset
import os

class ProstateCancerOpenSlideDataset(Dataset):
    """Slide-level dataset that returns one randomly located patch per whole-slide image.

    Item ``idx`` is ``(patch, label)``:

    * ``patch`` is an RGB crop of ``patch_size x patch_size`` pixels read at pyramid
      level ``level``, passed through ``transform`` when one is given.
    * ``label`` is the row's ``isup_grade``.

    The patch position is re-sampled on every access.
    """

    def __init__(self, csv_file, img_dir, patch_size=224, level=0, transform=None):
        """
        Args:
            csv_file (str): Path to the CSV file with image IDs and labels.
            img_dir (str): Directory with the whole-slide images.
            patch_size (int): Size of the patches to extract.
            level (int): Level of the downsampled image to use.
            transform (callable, optional): Optional transform to apply to patches.
        """
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.patch_size = patch_size
        self.level = level
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = os.path.join(self.img_dir, f"{row['image_id']}.tiff")
        label = row['isup_grade']  # Update based on the target column

        slide = openslide.OpenSlide(img_path)
        try:
            width, height = slide.level_dimensions[self.level]

            # torch.randint's upper bound is exclusive. The "+ 1" keeps the right-most and
            # bottom-most valid offsets reachable and avoids an empty range (RuntimeError)
            # when the level is exactly patch_size wide. Levels smaller than the patch
            # clamp the offset to 0.
            x = torch.randint(0, max(width - self.patch_size, 0) + 1, (1,)).item()
            y = torch.randint(0, max(height - self.patch_size, 0) + 1, (1,)).item()

            # read_region takes the top-left corner in level-0 coordinates, while the
            # offsets above were drawn in this level's pixel grid, so rescale them.
            downsample = slide.level_downsamples[self.level]
            location = (int(x * downsample), int(y * downsample))
            patch = slide.read_region(location, self.level,
                                      (self.patch_size, self.patch_size)).convert('RGB')
        finally:
            # The slide stays open until closed; without this every item leaks a handle.
            slide.close()

        if self.transform is not None:
            patch = self.transform(patch)

        return patch, label


In [9]:
# Training-time augmentation followed by tensor conversion and ImageNet mean/std normalisation.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [10]:
def create_data_set(cfg, chunk_id, state=None, slide_id=None, training_set=False):
    """
    Custom logic to create dataset from Kaggle data.

    Args:
        cfg (dict): must contain ``"chunk_file_location"``, a CSV with ``image_id`` and
            ``isup_grade`` columns. May contain ``"img_dir"`` to override the default
            PANDA ``train_images`` directory.
        chunk_id, state, slide_id, training_set: accepted for compatibility with the
            Data_Loader subclasses; not used by this Kaggle adaptation. In particular,
            every split gets the same global ``transform``.

    Returns:
        tuple: ``(dataset, labels)``, a ``ProstateCancerOpenSlideDataset`` and its
        ``isup_grade`` values as a list.
    """
    # The previous version referenced an undefined `ProstateCancerDataset` (NameError).
    patch_dataset = ProstateCancerOpenSlideDataset(
        csv_file=cfg["chunk_file_location"],
        img_dir=cfg.get("img_dir", "/kaggle/input/prostate-cancer-grade-assessment/train_images"),
        transform=transform,
    )
    # Reuse the table the dataset already loaded instead of reading the CSV a second time.
    return patch_dataset, patch_dataset.data["isup_grade"].tolist()

In [11]:
from utils.utils import my_collate
from torch.utils.data import DataLoader

class Data_Loader(object):
    """Base class that resolves a split name (``state``) to that split's chunks in ``cfg``.

    Subclasses build a dataset for the selected split and return a DataLoader from ``run()``.
    """

    # split name -> (cfg key holding that split's chunks, whether it is the training split)
    _SPLITS = {
        'train': ('training_chunks', True),
        'validation': ('validation_chunks', False),
        'test': ('test_chunks', False),
        'external': ('external_chunks', False),
    }

    def __init__(self,
                 cfg: dict,
                 state: str):
        self.cfg = cfg
        self.state = state
        # Resolving eagerly validates `state`; the result is kept for subclasses and callers.
        self.chunk, self.training_set = self.handle_chunk()

    def handle_chunk(self) -> tuple[object, bool]:
        """Return ``(chunks, training_set)`` for ``self.state``.

        Raises:
            ValueError: if ``self.state`` is not train, validation, test or external.
            KeyError: if ``self.cfg`` lacks the chunk key of a valid split.
        """
        if self.state not in self._SPLITS:
            # Previously formatted an undefined `state`, raising NameError instead of this.
            raise ValueError(f'{self.state} should be either train, validation, test or external!')
        cfg_key, training_set = self._SPLITS[self.state]
        return self.cfg[cfg_key], training_set

    def run(self):
        raise NotImplementedError()


class RepresentationDataset(Data_Loader):
    """Builds the patch dataset of one split and serves it unshuffled, with the
    evaluation batch size."""

    def __init__(self,
                 cfg: dict,
                 state: str,
                 slide_id: int) -> None:
        self.cfg, self.state = cfg, state
        split_chunks, _is_training = self.handle_chunk()
        self.patch_dataset, _labels = create_data_set(
            cfg, split_chunks, state=state, slide_id=slide_id)

    def run(self):
        loader_options = {
            "batch_size": self.cfg["eval_batch_size"],
            "shuffle": False,
            "pin_memory": True,
            "num_workers": self.cfg["num_patch_workers"],
        }
        return DataLoader(self.patch_dataset, **loader_options)

class Dataset(Data_Loader):
    """Builds the patch dataset (and labels) of one split and serves it with my_collate.

    NOTE(review): this name shadows ``torch.utils.data.Dataset`` imported in earlier
    cells. Any later cell that subclasses or type-checks against ``Dataset`` gets this
    class instead. Consider renaming; it is kept as-is so existing callers keep working.
    """

    def __init__(self,
                 cfg: dict,
                 state: str = 'train') -> None:
        self.cfg = cfg
        self.state = state
        chunk, training_set = self.handle_chunk()
        self.patch_dataset, self.labels = create_data_set(cfg, chunk, state=state,
                                                          training_set=training_set)

    def run(self):
        is_train = self.state == 'train'
        batch_size = self.cfg["batch_size"] if is_train else self.cfg["eval_batch_size"]
        # Only the training split is shuffled. Evaluation batches keep dataset order so
        # predictions can be matched back to dataset rows and runs are comparable.
        return DataLoader(self.patch_dataset, batch_size=batch_size,
                          shuffle=is_train, pin_memory=True, collate_fn=my_collate,
                          num_workers=self.cfg["num_patch_workers"])

In [12]:
# PANDA training table and slide folder.
csv_file = '/kaggle/input/prostate-cancer-grade-assessment/train.csv'
img_dir = '/kaggle/input/prostate-cancer-grade-assessment/train_images'

# One item per slide: a random 224x224 crop read at level 0 (the highest resolution)
# together with the slide's ISUP grade.
dataset = ProstateCancerOpenSlideDataset(csv_file=csv_file, img_dir=img_dir,
                                         patch_size=224, level=0, transform=transform)

# Batched, shuffled view over the full dataset.
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=32)

# # Iterate through the DataLoader
# for patches, labels in dataloader:
#     print(patches.shape, labels)

In [13]:
import timm
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange, repeat
from torchvision.models import ResNet50_Weights
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights

####################################################
# Channel width of each backbone's final feature map, i.e. the per-patch feature size
# that VarMIL's attention and classifier layers are built for.
out_channel = {'alexnet': 256, 'vgg16': 512, 'vgg19': 512, 'vgg16_bn': 512, 'vgg19_bn': 512,
               'resnet18': 512, 'resnet34': 512, 'resnet50': 2048, 'resnext50_32x4d': 2048,
               'resnext101_32x8d': 2048, 'mobilenet_v2': 1280, 'mobilenet_v3_small': 576,
               'mobilenet_v3_large': 960, 'mnasnet1_3': 1280, 'shufflenet_v2_x1_5': 1024,
               'squeezenet1_1': 512, 'efficientnet-b0': 1280, 'efficientnet-l2': 5504,
               'efficientnet-b1': 1280, 'efficientnet-b2': 1408, 'efficientnet-b3': 1536,
               'efficientnet-b4': 1792, 'efficientnet-b5': 2048, 'efficientnet-b6': 2304,
               'efficientnet-b7': 2560, 'efficientnet-b8': 2816, 'vit_deit_small_patch16_224': 384,
               # torchvision spells EfficientNets `efficientnet_bN` (underscore). That is the
               # name VanillaModel resolves via torchvision.models, so these keys are needed
               # for e.g. backbone='efficientnet_b7' to work. Hyphenated keys are kept above.
               'efficientnet_b0': 1280, 'efficientnet_b1': 1280, 'efficientnet_b2': 1408,
               'efficientnet_b3': 1536, 'efficientnet_b4': 1792, 'efficientnet_b5': 2048,
               'efficientnet_b6': 2304, 'efficientnet_b7': 2560}

# How VanillaModel cuts a torchvision model's children() into a feature trunk:
#    0 -> use the first child (itself a Sequential of feature blocks)
#   <0 -> keep children()[:value], dropping the trailing pooling/classifier children
# ViT backbones are built through timm and never read this table ('inf' is a placeholder).
feature_map = {'alexnet': -2, 'vgg16': -2,  'vgg19': -2, 'vgg16_bn': -2,  'vgg19_bn': -2,
               'resnet18': -2, 'resnet34': -2, 'resnet50': -2, 'resnext50_32x4d': -2,
               'resnext101_32x8d': -2, 'mobilenet_v2': 0, 'mobilenet_v3_large': -2,
               'mobilenet_v3_small': -2, 'mnasnet1_3': 0, 'shufflenet_v2_x1_5': -1,
               'squeezenet1_1': 0, 'vit_deit_small_patch16_224': 'inf',
               # torchvision EfficientNet children are (features, avgpool, classifier).
               'efficientnet_b0': -2, 'efficientnet_b1': -2, 'efficientnet_b2': -2,
               'efficientnet_b3': -2, 'efficientnet_b4': -2, 'efficientnet_b5': -2,
               'efficientnet_b6': -2, 'efficientnet_b7': -2}
####################################################

class VanillaModel(nn.Module):
    """Image backbone that maps a batch of images to one feature vector per image.

    Backbones whose name contains ``'vit'`` are created through ``timm`` with
    ``num_classes=0`` and ``pretrained=False``. All others come from ``torchvision.models``
    and are truncated according to ``feature_map``. Only ``resnet50`` and
    ``efficientnet_b7`` load ImageNet weights; every other torchvision backbone is
    randomly initialised.

    Args:
        backbone: model name, used as a timm model name or as a ``torchvision.models``
            attribute, and as the key into ``feature_map``.
    """

    def __init__(self,
                 backbone: str) -> None:
        super().__init__()

        self.backbone = backbone
        if 'vit' in self.backbone:
            # Vision Transformer
            self.feature_extract = timm.create_model(self.backbone, pretrained=False, num_classes=0)
        else:
            model_fn = getattr(models, self.backbone)
            if self.backbone == "resnet50":
                model = model_fn(weights=ResNet50_Weights.IMAGENET1K_V1)  # Pretrained weights
            elif self.backbone == "efficientnet_b7":
                model = efficientnet_b7(weights=EfficientNet_B7_Weights.IMAGENET1K_V1)  # Pretrained weights
            else:
                model = model_fn(weights=None)  # No pretrained weights for other models

            # Separate feature and classifier layers.
            cut = feature_map[self.backbone]
            children = list(model.children())
            # cut == 0: the first child is already the feature trunk (a Sequential); unpack it.
            # cut  < 0: drop the last |cut| children.
            layers = children[0] if cut == 0 else children[:cut]
            self.feature_extract = nn.Sequential(*layers)

    def forward(self,
                x: torch.Tensor) -> torch.Tensor:
        """Return a ``(B, C)`` feature matrix for the image batch ``x``."""
        feature = self.feature_extract(x)
        # CNN trunks produce (B, C, H, W) maps that still need global pooling. timm models
        # built with num_classes=0 already return pooled (B, C) features, and
        # adaptive_avg_pool2d rejects 2-D input, so pool only when spatial dims remain.
        if feature.dim() > 2:
            feature = F.adaptive_avg_pool2d(feature, 1)
        return torch.flatten(feature, 1)

class VarMIL(nn.Module):
    """
    Our modified implementation of https://arxiv.org/abs/2107.09405

    Computes attention-weighted mean and variance pooling over the patches of each bag,
    then applies a two-layer MLP classifier to their concatenation.
    """
    def __init__(self,
                 cfg: dict) -> None:
        super().__init__()
        dim = 128
        # Kept for backward compatibility; forward() now uses the input tensor's device.
        self.device     = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.attention  = nn.Sequential(nn.Linear(out_channel[cfg['backbone']], dim),
                                        nn.Tanh(),
                                        nn.Linear(dim, 1))
        self.classifier = nn.Sequential(nn.Linear(2*out_channel[cfg['backbone']], dim),
                                        nn.ReLU(),
                                        nn.Linear(dim, cfg['num_classes']))

    def forward(self,
                x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        x   (input)            : B (batch size) x K (nb_patch) x out_channel; rows that are
                                 entirely zero are treated as padding
        A   (attention weights): B (batch size) x K (nb_patch) x 1
        M   (weighted mean)    : B (batch size) x out_channel
        S   (squared deviation): B (batch size) x K (nb_patch) x out_channel
        V   (weighted variance): B (batch size) x out_channel
        nb_patch (nb of patch) : B (batch size)
        M_V (concate M and V)  : B (batch size) x 2*out_channel
        out (final output)     : B (batch size) x num_classes

        Returns:
            (A, out)
        """
        b, k, c = x.shape
        padded = (x == 0).all(dim=2)                                    # B x K, True on padded rows
        A = self.attention(x)
        A = A.masked_fill(padded.reshape(A.shape), -9e15)               # filter padded rows
        A = F.softmax(A, dim=1)                                         # softmax over K
        M = torch.einsum('b k d, b k o -> b o', A, x)                   # d is 1 here
        S = torch.pow(x - M.reshape(b, 1, c), 2)
        V = torch.einsum('b k d, b k o -> b o', A, S)
        # I / (I - 1) correction, I = number of non-padded patches. Built on x.device so it
        # follows the module after .to(...), not the device guessed at construction.
        nb_patch = torch.full((b,), k, dtype=torch.long, device=x.device) - padded.sum(dim=1)
        nb_patch = nb_patch / (nb_patch - 1)
        nb_patch = torch.nan_to_num(nb_patch, posinf=1)                 # only 1 patch -> inf -> 1
        V = V * nb_patch[:, None]                                       # broadcasting
        M_V = torch.cat((M, V), dim=1)
        out = self.classifier(M_V)
        return A, out



class Attention(nn.Module):
    """Backbone feature extractor followed by a MIL aggregation head (currently VarMIL)."""

    def __init__(self, cfg: dict) -> None:
        super().__init__()
        self.feature_extractor = VanillaModel(cfg['backbone'])  # Feature extractor
        if cfg['model'] == 'VarMIL':
            self.model = VarMIL(cfg)
        else:
            raise NotImplementedError()

    def trainable_parameters(self) -> None:
        """Print the number of trainable parameters of the whole network (backbone + head)."""
        # Previously only the head was counted, although the message says "Total" and the
        # optimizer updates the backbone as well.
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f'Total trainable parameters are {params}.')

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(attention, logits)`` for a batch of images or a batch of bags.

        x (input): either B x 3 x H x W, where each image is treated as a bag with a
                   single patch, or B x K x 3 x H x W, with K patches per bag.
        """
        if x.dim() == 5:
            b, k = x.shape[:2]
            features = self.feature_extractor(x.flatten(0, 1))  # (B*K) x out_channel
            features = features.reshape(b, k, -1)                # B x K x out_channel
        else:
            # Add a pseudo-patch dimension: B x 1 x out_channel
            features = self.feature_extractor(x).unsqueeze(1)

        attention, out = self.model(features)
        return attention, out

In [14]:
import torch

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Configuration read by Attention / VanillaModel / VarMIL.
cfg = dict(
    backbone='resnet50',  # Replace with your desired backbone
    model='VarMIL',
    num_classes=6,        # ISUP grades 0-5
)

# Build backbone + MIL head and move all parameters to `device` in one step.
model = Attention(cfg).to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 191MB/s]


In [15]:
from torch.utils.data import random_split, DataLoader

import copy
from torch.utils.data import Subset

SPLIT_SEED = 42        # fixes which slides land in train vs. validation
TRAIN_FRACTION = 0.8   # 80% for training, remaining 20% for validation

train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size

# Seeded split so train/validation membership is reproducible across kernel restarts.
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# `transform` applies random flips, rotations and colour jitter, which should not be
# applied to validation. Point the validation indices at a shallow copy of `dataset`
# with a deterministic transform. The patch location is still sampled at random in
# __getitem__.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_base = copy.copy(dataset)
val_base.transform = eval_transform
val_dataset = Subset(val_base, val_dataset.indices)

# Create DataLoaders for both subsets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [16]:
# Slides per ISUP grade; the counts are imbalanced, which motivates a class-weighted loss.
print(dataset.data.loc[:, 'isup_grade'].value_counts())

isup_grade
0    2892
1    2666
2    1343
4    1249
3    1242
5    1224
Name: count, dtype: int64


In [17]:
from sklearn.utils.class_weight import compute_class_weight

# Weight the loss by inverse class frequency ('balanced') so rarer ISUP grades count more.
# Fit the weights on the training split only; held-out validation labels should not
# influence training choices.
train_grades = dataset.data['isup_grade'].iloc[train_dataset.indices]
class_weights = compute_class_weight(
    class_weight='balanced',
    # One weight per logit, in logit order. compute_class_weight raises if a grade is
    # missing from the split, instead of silently misaligning weights and classes.
    classes=np.arange(cfg['num_classes']),
    y=train_grades,
)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights)

In [18]:
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

# TensorBoard event files for this run are written under runs/training_logs.
writer = SummaryWriter('runs/training_logs')

In [19]:
NUM_EPOCHS = 10  # Adjust number of epochs

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=4e-5)
# NOTE: StepLR multiplies the LR by gamma every `step_size` scheduler steps. With
# step_size == NUM_EPOCHS the decay only happens after the final epoch, so the LR is
# constant for this run.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


def _run_epoch(loader, epoch, train):
    """Make one pass over ``loader`` and collect per-sample results.

    With ``train=True`` the model is in train mode, weights are updated after every
    batch, and the per-step loss is logged as 'Loss/train'. With ``train=False`` the
    model is in eval mode and gradients are disabled.

    Returns:
        (mean_batch_loss, labels, predictions, probabilities). The last three are NumPy
        arrays; ``probabilities`` is the softmax over the logits (N x num_classes).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train(train)
    loss_sum = 0.0
    labels_seen, preds_seen, probs_seen = [], [], []
    with torch.set_grad_enabled(train):
        for i, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            if train:
                optimizer.zero_grad()

            _, outputs = model(inputs)
            loss = criterion(outputs, labels)

            if train:
                loss.backward()
                optimizer.step()
                writer.add_scalar('Loss/train', loss.item(), epoch * len(loader) + i)

            loss_sum += loss.item()
            labels_seen.append(labels.cpu())
            preds_seen.append(outputs.argmax(dim=1).cpu())
            # float64 softmax so rows sum to 1 tightly enough for roc_auc_score's check.
            probs_seen.append(outputs.detach().double().softmax(dim=1).cpu())

    if not labels_seen:
        raise ValueError('loader yielded no batches')
    return (loss_sum / len(loader),
            torch.cat(labels_seen).numpy(),
            torch.cat(preds_seen).numpy(),
            torch.cat(probs_seen).numpy())


def _weighted_metrics(y_true, y_pred):
    """Return (accuracy in %, weighted precision, weighted recall, weighted F1)."""
    accuracy = 100 * np.sum(y_true == y_pred) / len(y_true)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    return accuracy, precision, recall, f1


for epoch in range(NUM_EPOCHS):
    # Training phase
    train_loss, y_train_true, y_train_pred, _ = _run_epoch(train_loader, epoch, train=True)
    train_accuracy, train_precision, train_recall, train_f1 = _weighted_metrics(y_train_true, y_train_pred)

    writer.add_scalar('Accuracy/train', train_accuracy, epoch)
    writer.add_scalar('Precision/train', train_precision, epoch)
    writer.add_scalar('Recall/train', train_recall, epoch)
    writer.add_scalar('F1/train', train_f1, epoch)

    print(f"Epoch {epoch+1}, Training Loss: {train_loss:.4f}, "
          f"Accuracy: {train_accuracy:.2f}%, Precision: {train_precision:.4f}, "
          f"Recall: {train_recall:.4f}, F1 Score: {train_f1:.4f}")

    # Validation phase
    val_loss, y_val_true, y_val_pred, y_val_prob = _run_epoch(val_loader, epoch, train=False)
    val_accuracy, val_precision, val_recall, val_f1 = _weighted_metrics(y_val_true, y_val_pred)

    # One-vs-rest AUC from the probabilities of *all* validation samples, not just the
    # last batch. It is undefined when a grade is missing from the split; skip it then.
    try:
        val_auc = roc_auc_score(y_val_true, y_val_prob, multi_class='ovr')
    except ValueError:
        val_auc = None

    writer.add_scalar('Accuracy/val', val_accuracy, epoch)
    writer.add_scalar('Precision/val', val_precision, epoch)
    writer.add_scalar('Recall/val', val_recall, epoch)
    writer.add_scalar('F1/val', val_f1, epoch)
    if val_auc is not None:
        writer.add_scalar('AUC/val', val_auc, epoch)

    print(f"Epoch {epoch+1}, Validation Loss: {val_loss:.4f}, "
          f"Accuracy: {val_accuracy:.2f}%, Precision: {val_precision:.4f}, "
          f"Recall: {val_recall:.4f}, F1 Score: {val_f1:.4f}")

    # Step the scheduler
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    writer.add_scalar('Learning Rate', current_lr, epoch)
    print(f"Epoch {epoch+1}, Learning Rate: {current_lr:.6f}")

Epoch 1, Training Loss: 1.7897, Accuracy: 21.41%, Precision: 0.2040, Recall: 0.2141, F1 Score: 0.2005
Epoch 1, Validation Loss: 1.7923, Accuracy: 16.90%, Precision: 0.1897, Recall: 0.1690, F1 Score: 0.1544
Epoch 1, Learning Rate: 0.000050
Epoch 2, Training Loss: 1.7852, Accuracy: 22.36%, Precision: 0.2086, Recall: 0.2236, F1 Score: 0.2117
Epoch 2, Validation Loss: 1.7815, Accuracy: 20.39%, Precision: 0.2186, Recall: 0.2039, F1 Score: 0.1989
Epoch 2, Learning Rate: 0.000050
Epoch 3, Training Loss: 1.7783, Accuracy: 24.25%, Precision: 0.2201, Recall: 0.2425, F1 Score: 0.2195
Epoch 3, Validation Loss: 1.7821, Accuracy: 25.56%, Precision: 0.2150, Recall: 0.2556, F1 Score: 0.2083
Epoch 3, Learning Rate: 0.000050
Epoch 4, Training Loss: 1.7799, Accuracy: 24.27%, Precision: 0.2160, Recall: 0.2427, F1 Score: 0.2167
Epoch 4, Validation Loss: 1.7735, Accuracy: 26.74%, Precision: 0.2623, Recall: 0.2674, F1 Score: 0.2081
Epoch 4, Learning Rate: 0.000050
Epoch 5, Training Loss: 1.7769, Accuracy: 25

In [20]:
# Close the TensorBoard SummaryWriter so any buffered events are written to disk.
writer.close()