In [1]:
!ls /kaggle/input/pyvips-python-and-deb-package-gpu
# intall the deb packages
!yes | dpkg -i --force-depends /kaggle/input/pyvips-python-and-deb-package-gpu/linux_packages/archives/*.deb
# install the python wrapper
!pip install pyvips -f /kaggle/input/pyvips-python-and-deb-package-gpu/python_packages/ --no-index
!pip list | grep pyvips

linux_packages	python_packages
Selecting previously unselected package apparmor.
(Reading database ... 122997 files and directories currently installed.)
Preparing to unpack .../apparmor_3.0.4-2ubuntu2.2_amd64.deb ...
Unpacking apparmor (3.0.4-2ubuntu2.2) ...
Selecting previously unselected package autoconf.
Preparing to unpack .../autoconf_2.71-2_all.deb ...
Unpacking autoconf (2.71-2) ...
Selecting previously unselected package automake.
Preparing to unpack .../automake_13a1.16.5-1.3_all.deb ...
Unpacking automake (1:1.16.5-1.3) ...
Selecting previously unselected package autotools-dev.
Preparing to unpack .../autotools-dev_20220109.1_all.deb ...
Unpacking autotools-dev (20220109.1) ...
Selecting previously unselected package bzip2-doc.
Preparing to unpack .../bzip2-doc_1.0.8-5build1_all.deb ...
Unpacking bzip2-doc (1.0.8-5build1) ...
Selecting previously unselected package file.
Preparing to unpack .../file_13a5.41-3ubuntu0.1_amd64.deb ...
Unpacking file (1:5.41-3

In [2]:
import pyvips

In [3]:
import albumentations as A

  data = fetch_version_info()


In [4]:
import os
import pandas as pd
import numpy as np
import random
import cv2
from PIL import Image, ImageFile
import gc 
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torch.nn.functional as F
import timm

ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
os.environ['VIPS_CONCURRENCY'] = '4'
os.environ['VIPS_DISC_THRESHOLD'] = '15gb'

# Paths and constants
data_dir = "/kaggle/input/UBC-OCEAN/"  # Update with your data directory
test_images_dir = os.path.join(data_dir, "test_images")
test_thumbnails_dir = os.path.join(data_dir, "test_thumbnails")
test_csv_path = os.path.join(data_dir, "test.csv")
submission_path = "submission.csv"
model_path = "/kaggle/input/lunit-224-400-tiles-7698/pytorch/default/1/best_model_224_400tiles.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
entropy_threshold = 1.5  # Threshold for declaring class 'Other'

# Label mapping
label_mapping = {'HGSC': 0, 'EC': 1, 'CC': 2, 'LGSC': 3, 'MC': 4}
reverse_label_mapping = {v: k for k, v in label_mapping.items()}
reverse_label_mapping[5] = 'Other'  # Add "Other" for entropy-based predictions

In [5]:
# Number of tiles and tile size
n_tiles = 400
tile_size = 224
batch_size = 1  # Adjust batch size as needed

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
######### MIL Dataset

class WSI_TMA_MILDataset(Dataset):
    def __init__(
        self, 
        image_ids, 
        images_dir, 
        thumbnails_dir, 
        n_tiles, 
        tile_size, 
        transform=None
    ):
        """
        image_ids: list of image IDs (strings or integers)
        images_dir: directory containing the full-resolution images
        thumbnails_dir: directory containing the thumbnail images
        n_tiles: number of tiles per image
        tile_size: size of each tile
        transform: torchvision transforms
        """
        self.image_ids = image_ids
        self.images_dir = images_dir
        self.thumbnails_dir = thumbnails_dir
        self.n_tiles = n_tiles
        self.tile_size = tile_size
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """
        Returns:
            tiles_tensor: A tensor of shape (n_tiles, C, tile_size, tile_size).
            label: The numeric label (torch.long).
        """

        image_id = self.image_ids[idx]
        
        # thumb_path = os.path.join(self.thumbnails_dir, f"{image_id}_thumbnail.png")
        # if os.path.exists(thumb_path):
        #     img_path = thumb_path
        # else:
        #     # Fallback to full image
        #     img_path = os.path.join(self.images_dir, f"{image_id}.png")
        img_path = os.path.join(self.images_dir,f"{image_id}.png")
        try:
            # Attempt to open the image
            # image = Image.open(img_path).convert("RGB")
            # image = cv2.imread(img_path, cv2.IMREAD_COLOR)
            image = pyvips.Image.new_from_file(img_path, access='sequential').numpy()

        except (FileNotFoundError, IOError):
            print(f"Warning: Unable to load image {img_path}. Skipping.")
            return None, image_id  # Return a placeholder for the image with the image_id
            
        # img_pil = Image.fromarray(image)
        # cropped1 = get_cropped_image(img_pil)
        # cropped2 = crop_wsi_with_otsu(cropped1)
        # img_rgb = cv2.cvtColor(np.array(img_pil), cv2.COLOR_BGR2RGB)
        # h, w, c = img_rgb.shape

        tiles = []
        # Resize the image to 4096x4096
        # img_resized = Image.fromarray(img_rgb).resize((12288, 12288))
        # img_resized = np.array(img_resized)
        is_tma = image.shape[0] <= 5000 and image.shape[1] <= 5000

        # 1. downsample
        if is_tma:
            resize = A.Resize(image.shape[0], image.shape[1])
        else:
            resize = A.Resize(image.shape[0]//3, image.shape[1] // 3)
        img_resized = resize(image=image)['image']
        # Determine the size of each tile (assume 20x20 grid for 400 tiles)
        tile_size = img_resized.shape[0] // 20

        # Create 400 tiles of equal size
        for i in range(20):  # Loop over rows
            for j in range(20):  # Loop over columns
                # Calculate the coordinates of the current tile
                x_start = j * tile_size
                y_start = i * tile_size
                tile = img_resized[y_start:y_start + tile_size, x_start:x_start + tile_size, :]
        
                # Convert the tile to a PIL Image
                tile_img = Image.fromarray(tile)
        
                # Apply transformation if specified
                if self.transform:
                    tile_img = self.transform(tile_img)
        
                tiles.append(tile_img)

        tiles_tensor = torch.stack(tiles, dim=0)
        del image, img_resized, resize
        gc.collect()
        torch.cuda.empty_cache()

        return tiles_tensor, image_id

In [7]:
########################################
# MultiPatchViTExtractor
########################################
class MultiPatchViTExtractor:
    """
    Loads two ViT models (patch8 and patch16) from timm, extracts features from each,
    and concatenates them into a single vector.
    """
    def __init__(self, device="cpu"):
        self.device = device
        
        # Example model names
        self.model_patch8  = timm.create_model("vit_small_patch8_224",  pretrained=False)
        self.model_patch16 = timm.create_model("vit_small_patch16_224", pretrained=False)

        # load lunit weights
        self.model_patch8.load_state_dict(
            torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch8_ep200.torch", 
                       map_location="cpu"), strict=False)
        self.model_patch16.load_state_dict(
            torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch16_ep200.torch", 
                       map_location="cpu"), strict=False)

        # Remove classification heads
        self.model_patch8.head  = nn.Identity()
        self.model_patch16.head = nn.Identity()
        
        self.model_patch8.to(device).eval()
        self.model_patch16.to(device).eval()
    
    @torch.no_grad()
    def extract_features(self, image_tensor: torch.Tensor) -> torch.Tensor:
        """
        Expects image_tensor shape: (1, 3, 224, 224).
        Returns concatenated features from patch8 & patch16, e.g. shape: (768+768,).
        """
        feats8  = self.model_patch8(image_tensor).squeeze(0)   # shape (768,)
        feats16 = self.model_patch16(image_tensor).squeeze(0)  # shape (768,)
        return torch.cat([feats8, feats16], dim=0)  # shape (1536,)

In [8]:
#######################################
# MILAttentionModel (Double ViT Backbone)
########################################
class MILAttentionDoubleDINO(nn.Module):
    """
    Multi-Instance Learning model that uses the MultiPatchViTExtractor to get features 
    from each tile, applies attention, and then classifies.
    """
    def __init__(self, device="cpu", num_classes=5, embed_dim=512):
        super().__init__()
        # Instead of inception_v3, we use our double-ViT extractor
        self.extractor = MultiPatchViTExtractor(device=device)
        
        # The concatenated output of patch8 and patch16 is 1536 dims
        in_features = 768

        # Project to an embedding space if needed
        self.embed = nn.Linear(in_features, embed_dim)
        
        # Attention parameters
        self.attention_A = nn.Linear(embed_dim, 128)
        self.attention_B = nn.Linear(128, 1)
        
        # Classifier
        self.classifier = nn.Linear(embed_dim, num_classes)

        self.device = device

    def forward(self, x):
        """
        x shape: [B, N, C, H, W]
        B: batch size (# of patients/slides)
        N: # of tiles per slide
        C, H, W: channels, height, width (e.g., 3, 224, 224)
        """
        B, N, C, H, W = x.shape
        
        # We'll accumulate all tile features in a list, then stack
        all_features = []
        for b_idx in range(B):
            # Extract features for N tiles in the current batch element
            tile_features = []
            for n_idx in range(N):
                # Each tile is shape [C, H, W]
                tile = x[b_idx, n_idx, ...].unsqueeze(0).to(self.device)  # shape [1, C, H, W]
                
                with torch.no_grad():
                    feats = self.extractor.extract_features(tile)  # shape [1536,]
                tile_features.append(feats.unsqueeze(0))  # shape [1, 1536]
            
            tile_features = torch.cat(tile_features, dim=0)  # shape [N, 1536]
            all_features.append(tile_features.unsqueeze(0))  # shape [1, N, 1536]
            del tile_features, tile
            gc.collect()
            torch.cuda.empty_cache()

        
        # Concatenate along batch dimension: [B, N, 1536]
        all_features = torch.cat(all_features, dim=0).to(self.device)
        
        # Project to embedding dimension
        embeddings = self.embed(all_features)  # [B, N, embed_dim]
        
        # Attention
        A = torch.relu(self.attention_A(embeddings))  # [B, N, 128]
        A = self.attention_B(A)                       # [B, N, 1]
        A = torch.softmax(A, dim=1)                   # attention weights over tiles
        
        # Weighted sum of embeddings
        weighted_sum = torch.sum(A * embeddings, dim=1)  # [B, embed_dim]

        del all_features, embeddings
        gc.collect()
        torch.cuda.empty_cache()

        # Final classification
        logits = self.classifier(weighted_sum)  # [B, num_classes]
        return logits

In [9]:
# Define transforms
test_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [10]:
# Extract all image IDs from the test_images directory
image_ids = [os.path.splitext(img)[0] for img in os.listdir(test_images_dir) if img.endswith(".png")]

# Create test dataset and loader
test_dataset = WSI_TMA_MILDataset(image_ids,
                                  n_tiles = n_tiles,
                                  tile_size = tile_size,
                                  images_dir=test_images_dir, 
                                  thumbnails_dir=test_thumbnails_dir,
                                  transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

# Load model
model = MILAttentionDoubleDINO(num_classes=5,device=device)
model.load_state_dict(torch.load(model_path, map_location=device))

  torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch8_ep200.torch",
  torch.load("/kaggle/input/lunit-dino-weights/dino_vit_small_patch16_ep200.torch",
  model.load_state_dict(torch.load(model_path, map_location=device))


<All keys matched successfully>

In [11]:
with torch.no_grad():
    model.to(device)
    model.eval()

In [12]:
# Perform inference with entropy thresholding
predictions = []
image_ids_result = []

with torch.no_grad():
    for images, ids in test_loader:
        valid_images = []
        valid_ids = []
        
        # Filter out None images
        for img, img_id in zip(images, ids):
            if img is not None and img_id is not None:
                valid_images.append(img)
                valid_ids.append(img_id)
        
        if not valid_images:  # Skip batch if no valid images
            continue
        
        valid_images = torch.stack(valid_images).to(device)
        logits = model(valid_images)
        
        # Convert logits to probabilities using softmax
        probabilities = F.softmax(logits, dim=1).cpu().numpy()
        print(probabilities)
        # Compute entropy
        entropies = -np.sum(probabilities * np.log(probabilities + 1e-9), axis=1)
        print(entropies)
        # Predict labels based on entropy
        preds = np.argmax(probabilities, axis=1)
        preds[entropies > entropy_threshold] = 5  # Assign "Other" class if entropy > threshold
        
        predictions.extend(preds)
        image_ids_result.extend(valid_ids)

        del valid_images, logits
        gc.collect()
        torch.cuda.empty_cache()

# Map predictions to labels using reverse_label_mapping
mapped_predictions = [reverse_label_mapping[pred] for pred in predictions]

# Create submission DataFrame
submission_df = pd.DataFrame({
    "image_id": image_ids_result,
    "label": mapped_predictions
})

# Save to CSV
submission_df.to_csv(submission_path, index=False)
print(f"Submission saved to {submission_path}")

[[1.4588974e-01 3.7229471e-03 8.4387887e-01 6.1678374e-03 3.4055355e-04]]
[0.47899622]
Submission saved to submission.csv
