In [1]:
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import timm
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
import os
import numpy as np
import pandas as pd
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score, average_precision_score
from sklearn.metrics import roc_curve, auc, precision_recall_curve
import matplotlib.pyplot as plt


In [2]:
tile_size= 224
batch_size = 3
n_tiles = 400
num_workers = 4

In [3]:
def crop_wsi_with_otsu(img_pil, output_path=None):
    """
    Load a large image (e.g., a WSI thumbnail or a TMA),
    generate an Otsu mask in grayscale, find the bounding box
    of the tissue region, and crop the entire image to that box.

    input_path: full path to the .png (or .tif, .jpg, etc.) image
    output_path: if provided, save the cropped image to this path
                 if None, just return the cropped PIL Image object
    """
    # # 1) Load image in RGB (PIL -> NumPy)
    # img_pil = Image.open(input_path).convert("RGB")
    img_rgb = np.array(img_pil)  # shape: (H, W, 3)

    # 2) Convert to grayscale for Otsu
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)

    # 3) Otsu threshold - background vs. tissue
    
    _, raw_mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    fraction_white = (raw_mask > 0).mean()  # fraction of pixels that are 255

    if fraction_white > 0.5:
        mask = 255 - raw_mask  # invert
    else:
        mask = raw_mask
    
    # Convert to binary {0,1}
    mask_binary = (mask > 0).astype(np.uint8)
    # 4) Find bounding box
    #    We want the min/max row/col where mask is non-zero
    coords = np.where(mask > 0)  # returns (row_indices, col_indices)
    if len(coords[0]) == 0:
        # No tissue found, fallback or skip
        print("Warning: Otsu found no tissue; returning original image")
        cropped_img = img_pil
    else:
        y_min, y_max = coords[0].min(), coords[0].max()
        x_min, x_max = coords[1].min(), coords[1].max()

        # 5) Crop the original RGB
        cropped_img = img_pil.crop((x_min, y_min, x_max+1, y_max+1))
        # Note: +1 to include that pixel.

    # 6) Save or return
    # if output_path:
    #     cropped_img.save(output_path)
    #     print(f"Cropped image saved to: {output_path}")
    return cropped_img

# Example usage:
# cropped_image = crop_wsi_with_otsu("slide1.png", "slide1_cropped.png")

In [4]:
def get_cropped_image(image, th_area=1000):
    # Calculate the aspect ratio
    as_ratio = image.size[0] / image.size[1]
    
    if as_ratio >= 1.5:
        # Create a mask using maximum value condition
        mask = np.max(np.array(image) > 0, axis=-1).astype(np.uint8)
        
        # Find connected components in the mask
        retval, labels = cv2.connectedComponents(mask)
        
        for label in range(1, retval):
            # Skip small components
            area = np.sum(labels == label)
            if area < th_area:
                continue
            
            # Get coordinates of the first valid connected component
            x, y = np.meshgrid(np.arange(image.size[0]), np.arange(image.size[1]))
            xs, ys = x[labels == label], y[labels == label]
            
            # Calculate cropping boundaries
            sx, ex = np.min(xs), np.max(xs)
            cx = (sx + ex) // 2
            crop_size = image.size[1]
            sx = max(0, cx - crop_size // 2)
            ex = min(sx + crop_size - 1, image.size[0] - 1)
            sx = ex - crop_size + 1
            sy, ey = 0, image.size[1] - 1
            
            # Crop the image and return
            cropped_image = image.crop((sx, sy, ex + 1, ey + 1))
            return cropped_image
    else:
        # If aspect ratio is less than 1.5, use the entire image
        return image

In [5]:

########################################
# Dataset
########################################
class WSI_TMA_MILDataset(Dataset):
    def __init__(self, df, data_dir, n_tiles, tile_size, transform=None):
        """
        df: DataFrame with columns ['image_id', 'label', 'is_tma']
        data_dir: directory containing 'train_images' and 'train_thumbnails'
        n_tiles: number of tiles per image
        tile_size: size of each tile
        transform: torchvision transforms
        """
        self.df = df.reset_index(drop=True)
        self.data_dir = data_dir
        self.n_tiles = n_tiles
        self.tile_size = tile_size
        self.transform = transform
        # Adjust mapping as per your classes
        self.label_mapping = { 'HGSC':0, 'EC':1, 'CC':2, 'LGSC':3, 'MC':4 }

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_id = row['image_id']
        label_str = row['label']
        label = self.label_mapping[label_str]
        is_tma = row['is_tma']

        if is_tma:
            # TMA image
            img_path = os.path.join(self.data_dir, 'train_images', f"{image_id}.png")
        else:
            # WSI thumbnail
            img_path = os.path.join(self.data_dir, 'train_thumbnails', f"{image_id}_thumbnail.png")

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            raise FileNotFoundError(f"Image not found: {img_path}")
        img_pil = Image.fromarray(img)
        cropped1 = get_cropped_image(img_pil)
        cropped2 = crop_wsi_with_otsu(cropped1)
        img_rgb = cv2.cvtColor(np.array(cropped2), cv2.COLOR_BGR2RGB)
        h, w, c = img_rgb.shape

        tiles = []
        # Resize the image to 4096x4096
        img_resized = Image.fromarray(img_rgb).resize((8192, 8192))
        img_resized = np.array(img_resized)

        # Determine the size of each tile (assume 20x20 grid for 400 tiles)
        tile_size = 8192 // 20

        # Create 400 tiles of equal size
        for i in range(20):  # Loop over rows
            for j in range(20):  # Loop over columns
                # Calculate the coordinates of the current tile
                x_start = j * tile_size
                y_start = i * tile_size
                tile = img_resized[y_start:y_start + tile_size, x_start:x_start + tile_size, :]
        
                # Convert the tile to a PIL Image
                tile_img = Image.fromarray(tile)
        
                # Apply transformation if specified
                if self.transform:
                    tile_img = self.transform(tile_img)
        
                tiles.append(tile_img)

        tiles_tensor = torch.stack(tiles, dim=0)
        return tiles_tensor, torch.tensor(label, dtype=torch.long)

In [6]:
########################################
# MultiPatchViTExtractor
########################################
class MultiPatchViTExtractor:
    """
    Loads two ViT models (patch8 and patch16) from timm, extracts features from each,
    and concatenates them into a single vector.
    """
    def __init__(self, device="cpu"):
        self.device = device
        
        # Example model names
        self.model_patch8  = timm.create_model("vit_small_patch8_224",  pretrained=False)
        self.model_patch16 = timm.create_model("vit_small_patch16_224", pretrained=False)

        # load lunit weights
        self.model_patch8.load_state_dict(
            torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch8_ep200.torch", 
                       map_location="cpu"), strict=False)
        self.model_patch16.load_state_dict(
            torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch16_ep200.torch", 
                       map_location="cpu"), strict=False)

        # Remove classification heads
        self.model_patch8.head  = nn.Identity()
        self.model_patch16.head = nn.Identity()
        
        self.model_patch8.to(device).eval()
        self.model_patch16.to(device).eval()
    
    @torch.no_grad()
    def extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Expects image_tensor shape: (1, 3, 224, 224).
        Returns concatenated features from patch8 & patch16, e.g. shape: (768+768,).
        """
        feats8  = self.model_patch8(image_tensor).squeeze(0)   # shape (768,)
        feats16 = self.model_patch16(image_tensor).squeeze(0)  # shape (768,)
        return torch.cat([feats8, feats16], dim=0)  # shape (1536,)




In [7]:
########################################
# MILAttentionModel (Double ViT Backbone)
########################################
class MILAttentionDoubleDINO(nn.Module):
    """
    Multi-Instance Learning model that uses the MultiPatchViTExtractor to get features 
    from each tile, applies attention, and then classifies.
    """
    def __init__(self, device="cpu", num_classes=5, embed_dim=512):
        super().__init__()
        # Instead of inception_v3, we use our double-ViT extractor
        self.extractor = MultiPatchViTExtractor(device=device)
        
        # The concatenated output of patch8 and patch16 is 1536 dims
        in_features = 768

        # Project to an embedding space if needed
        self.embed = nn.Linear(in_features, embed_dim)
        
        # Attention parameters
        self.attention_A = nn.Linear(embed_dim, 160)
        self.attention_B = nn.Linear(160, 1)
        
        # Classifier
        self.classifier = nn.Linear(embed_dim, num_classes)

        self.device = device

    def forward(self, x):
        """
        x shape: [B, N, C, H, W]
        B: batch size (# of patients/slides)
        N: # of tiles per slide
        C, H, W: channels, height, width (e.g., 3, 224, 224)
        """
        B, N, C, H, W = x.shape
        
        # We'll accumulate all tile features in a list, then stack
        all_features = []
        for b_idx in range(B):
            # Extract features for N tiles in the current batch element
            tile_features = []
            for n_idx in range(N):
                # Each tile is shape [C, H, W]
                tile = x[b_idx, n_idx, ...].unsqueeze(0).to(self.device)  # shape [1, C, H, W]
                
                with torch.no_grad():
                    feats = self.extractor.extract_features(tile)  # shape [1536,]
                tile_features.append(feats.unsqueeze(0))  # shape [1, 1536]
            
            tile_features = torch.cat(tile_features, dim=0)  # shape [N, 1536]
            all_features.append(tile_features.unsqueeze(0))  # shape [1, N, 1536]
        
        # Concatenate along batch dimension: [B, N, 1536]
        all_features = torch.cat(all_features, dim=0).to(self.device)
        
        # Project to embedding dimension
        embeddings = self.embed(all_features)  # [B, N, embed_dim]
        
        # Attention
        A = torch.relu(self.attention_A(embeddings))  # [B, N, 128]
        A = self.attention_B(A)                       # [B, N, 1]
        A = torch.softmax(A, dim=1)                   # attention weights over tiles
        
        # Weighted sum of embeddings
        weighted_sum = torch.sum(A * embeddings, dim=1)  # [B, embed_dim]
        
        # Final classification
        logits = self.classifier(weighted_sum)  # [B, num_classes]
        return logits

In [8]:
########################################
# Setup
########################################
data_dir = "/kaggle/input/UBC-OCEAN/"
df = pd.read_csv(os.path.join(data_dir, "train.csv"))

# Example label mapping already defined in dataset class
y = df['label']

train_transform = transforms.Compose([
    transforms.ColorJitter(brightness=.2,contrast=.2,saturation=.2,hue=.2),
    transforms.Resize((224,224)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [9]:
# Split WSI data
wsi_df = df[df['is_tma'] == 0]
train_wsi, val_wsi = train_test_split(wsi_df, test_size=0.3, stratify=wsi_df['label'], random_state=42)

# Split TMA data
tma_df = df[df['is_tma'] == 1]
train_tma, val_tma = train_test_split(tma_df, test_size=0.6, stratify=tma_df['label'], random_state=42)

# Combine splits
train_df = pd.concat([train_wsi, train_tma]).reset_index(drop=True)
val_df = pd.concat([val_wsi, val_tma]).reset_index(drop=True)

# Load data
train_dataset = WSI_TMA_MILDataset(train_df, data_dir=data_dir, n_tiles=n_tiles, tile_size=tile_size, transform=train_transform)
val_dataset = WSI_TMA_MILDataset(val_df, data_dir=data_dir, n_tiles=n_tiles, tile_size=tile_size, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [10]:
###################################
# Train and validate
###################################
model = MILAttentionDoubleDINO(num_classes=5, device=device).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-4)#, weight_decay=5e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=5e-5)
resume_best_model_path = "/kaggle/input/resume-after-epoch-5/pytorch/default/1/best_model_224_400tiles_tuned (2).pth"

model.load_state_dict(torch.load(resume_best_model_path))

best_balanced_acc = 0.0
best_auprc = 0.0
patience = 4
epochs_no_improve = 0
best_model_path = "best_model_224_400tiles_tuned.pth"

for epoch in range(6):  # Up to 20 epochs, adjust as needed
    print(f"\nEpoch {epoch+1}/{6}")
    # print(f"Current Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
    
    # Train
    model.train()
    running_loss = 0.0
    for tiles, labels in train_loader:
        tiles, labels = tiles.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(tiles)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * tiles.size(0)

    train_loss = running_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    running_loss_val = 0.0
    all_preds = []
    all_labels = []
    all_probs = []  # To store probabilities for AUPRC calculation
    with torch.no_grad():
        for tiles, labels in val_loader:
            tiles, labels = tiles.to(device), labels.to(device)
            logits = model(tiles)
            loss = criterion(logits, labels)
            running_loss_val += loss.item() * tiles.size(0)
            
            probs = torch.softmax(logits, dim=1).cpu().numpy()  # Get probabilities
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            
            all_probs.append(probs)
            all_preds.append(preds)
            all_labels.append(labels.cpu().numpy())
            

    val_loss = running_loss_val / len(val_loader.dataset)
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)  # Shape: (num_samples, num_classes)

    # Calculate metrics
    balanced_acc = balanced_accuracy_score(all_labels, all_preds)
    # Calculate AUPRC (macro-average across all classes)
    auprc = average_precision_score(
        np.eye(len(all_probs[0]))[all_labels],  # One-hot encode labels
        all_probs,
        average="macro"
    )

    print(f"Epoch {epoch+1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
          f"Val Balanced Acc: {balanced_acc:.4f}, Val AUPRC: {auprc:.4f}")

    # Check for improvement
    if balanced_acc > best_balanced_acc or auprc > best_auprc:
        if balanced_acc > best_balanced_acc:
            best_balanced_acc = balanced_acc
        if auprc > best_auprc:
            best_auprc = auprc
        epochs_no_improve = 0
        torch.save(model.state_dict(), best_model_path)
        print("  * Best model saved.")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print("Early stopping.")
            break

    # Step the scheduler
    # scheduler.step()



  torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch8_ep200.torch",
  torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch16_ep200.torch",



Epoch 1/6


  model.load_state_dict(torch.load(resume_best_model_path))


Epoch 1 - Train Loss: 0.4766, Val Loss: 0.6272, Val Balanced Acc: 0.7289, Val AUPRC: 0.8453
  * Best model saved.

Epoch 2/6
Epoch 2 - Train Loss: 0.4129, Val Loss: 0.7092, Val Balanced Acc: 0.6905, Val AUPRC: 0.8176

Epoch 3/6
Epoch 3 - Train Loss: 0.3936, Val Loss: 0.6624, Val Balanced Acc: 0.7361, Val AUPRC: 0.8374
  * Best model saved.

Epoch 4/6
Epoch 4 - Train Loss: 0.3560, Val Loss: 0.8804, Val Balanced Acc: 0.6607, Val AUPRC: 0.8498
  * Best model saved.

Epoch 5/6
Epoch 5 - Train Loss: 0.3220, Val Loss: 0.7728, Val Balanced Acc: 0.7417, Val AUPRC: 0.8135
  * Best model saved.

Epoch 6/6
Epoch 6 - Train Loss: 0.2861, Val Loss: 0.7747, Val Balanced Acc: 0.7673, Val AUPRC: 0.8500
  * Best model saved.


In [11]:
# Load best model and evaluate
model.load_state_dict(torch.load(best_model_path))

# Evaluate final performance on validation set
model.eval()
all_preds = []
all_labels = []
all_probs = []
with torch.no_grad():
    for tiles, labels in val_loader:
        tiles, labels = tiles.to(device), labels.to(device)
        logits = model(tiles)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_probs.append(probs)
        all_preds.append(preds)
        all_labels.append(labels.cpu().numpy())
all_preds = np.concatenate(all_preds)
all_labels = np.concatenate(all_labels)
all_probs = np.concatenate(all_probs)



  model.load_state_dict(torch.load(best_model_path))


In [12]:
# Final metrics
final_balanced_acc = balanced_accuracy_score(all_labels, all_preds)
final_auprc = average_precision_score(
    np.eye(len(all_probs[0]))[all_labels],
    all_probs,
    average="macro"
)
print(f"Final Balanced Accuracy: {final_balanced_acc:.4f}")
print(f"Final AUPRC: {final_auprc:.4f}")

Final Balanced Accuracy: 0.7673
Final AUPRC: 0.8500


In [13]:
np.save('/kaggle/working/all_labels.npy', all_labels)
np.save('/kaggle/working/all_probs.npy', all_probs)

print("Files saved successfully: all_labels.npy, all_probs.npy")

Files saved successfully: all_labels.npy, all_probs.npy
