In [1]:
# Authenticate with Kaggle. Credentials are entered through the interactive
# login widget rather than being hard-coded in the notebook.
import kagglehub
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [2]:
# Download (and extract) the competition data; the returned value is the local
# directory that the later cells build their paths from.
COMPETITION_SLUG = 'histopathologic-cancer-detection'
histopathologic_cancer_detection_path = kagglehub.competition_download(COMPETITION_SLUG)

Downloading from https://www.kaggle.com/api/v1/competitions/data/download-all/histopathologic-cancer-detection...


100%|██████████| 6.31G/6.31G [02:44<00:00, 41.1MB/s]

Extracting files...





In [3]:
import copy
import os
import time
from pathlib import Path
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.auto import tqdm

In [4]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Using device: cuda


In [5]:
class HistopathologicDataset(Dataset):
    """Map-style dataset over the ``.tif`` patches in a single directory.

    Files are discovered once at construction time and sorted by path, so the
    index -> file mapping is deterministic. A sample's id is its file stem.

    Args:
        data_dir: Directory containing the ``.tif`` images (not searched recursively).
        labels_df: DataFrame with ``id`` and ``label`` columns. Required to
            fetch items when ``is_test`` is False.
        transform: Optional callable applied to the RGB ``PIL.Image``.
        is_test: If True, ``__getitem__`` returns ``(file_id, image)``;
            otherwise it returns ``(image, label)``.
    """

    def __init__(self, data_dir: str, labels_df: Optional[pd.DataFrame] = None,
                 transform=None, is_test=False):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.is_test = is_test

        self.image_files = sorted(f for f in self.data_dir.iterdir() if f.suffix == '.tif')

        if not is_test and labels_df is not None:
            # id -> label lookup so __getitem__ does not have to search the frame.
            self.labels_dict = labels_df.set_index('id')['label'].to_dict()
        else:
            self.labels_dict = None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(file_id, image)`` in test mode, else ``(image, label)``.

        Raises:
            ValueError: If labels are needed but no ``labels_df`` was given.
            KeyError: If the file's id is missing from ``labels_df``.
        """
        img_path = self.image_files[idx]
        file_id = img_path.stem

        if not self.is_test:
            if self.labels_dict is None:
                raise ValueError(
                    "HistopathologicDataset was created with is_test=False but "
                    f"without labels_df; cannot look up a label for {file_id!r}."
                )
            # Resolve the label before decoding so a missing id fails fast.
            label = self.labels_dict[file_id]

        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives it.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        if self.is_test:
            return file_id, image
        return image, label

In [6]:
class Squash(nn.Module):
    """Capsule squashing non-linearity applied along the last dimension.

    Computes ``|x|^2 / (1 + |x|^2) * x / (|x| + eps)``. Each vector keeps its
    direction while its length is mapped into ``[0, 1)``.

    Args:
        eps: Added to the norm in the denominator to avoid division by zero.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        length = x.norm(p=2, dim=-1, keepdim=True)
        length_sq = length ** 2
        scale = length_sq / (1 + length_sq)
        direction = x / (length + self.eps)
        return scale * direction


In [7]:
class PrimaryCapsules(nn.Module):
    """Convolutional primary-capsule layer.

    One convolution emits ``out_channels * capsule_dim`` feature maps. Every
    spatial position of every capsule channel becomes one ``capsule_dim``
    vector, and the vectors are then squashed.

    Output shape: ``[batch, out_channels * H' * W', capsule_dim]``, where
    ``H'``/``W'`` are the conv output's spatial size. Capsules are ordered
    channel-major (all positions of channel 0, then channel 1, ...).
    """

    def __init__(self, in_channels, out_channels, capsule_dim, kernel_size=9, stride=2):
        super().__init__()
        self.capsule_dim = capsule_dim
        self.out_channels = out_channels

        self.conv = nn.Conv2d(
            in_channels,
            out_channels * capsule_dim,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.squash = Squash()

    def forward(self, x):
        features = self.conv(x)
        n = features.size(0)
        # Split channels into (capsule channel, capsule dim) and flatten H x W.
        grouped = features.view(n, self.out_channels, self.capsule_dim, -1)
        # Put capsule_dim last: [n, out_channels, H'*W', capsule_dim].
        per_position = grouped.transpose(2, 3).contiguous()
        capsules = per_position.view(n, -1, self.capsule_dim)
        return self.squash(capsules)

In [80]:
class DynamicRouting(nn.Module):
    """Routing-by-agreement from input capsules to output (class) capsules.

    Args:
        in_capsules: Number of input capsules (must equal ``x.size(1)``).
        out_capsules: Number of output capsules.
        in_dim: Length of each input capsule vector.
        out_dim: Length of each output capsule vector.
        num_iterations: Number of routing iterations (>= 1).

    Input ``x``: ``[batch, in_capsules, in_dim]``.
    Output: ``[batch, out_capsules, out_dim]`` squashed capsule vectors.

    Raises:
        ValueError: If ``num_iterations`` < 1.
    """

    def __init__(self, in_capsules, out_capsules, in_dim, out_dim, num_iterations=3):
        super().__init__()
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.in_capsules = in_capsules
        self.out_capsules = out_capsules
        self.num_iterations = num_iterations

        # One (out_dim x in_dim) transform per (input capsule, output capsule) pair.
        # NOTE(review): unit-variance init. With thousands of input capsules the
        # summed predictions can be large enough to saturate Squash; consider a
        # smaller init scale (confirm empirically).
        self.W = nn.Parameter(torch.randn(1, in_capsules, out_capsules, out_dim, in_dim))
        self.squash = Squash()

    def forward(self, x):
        batch_size = x.size(0)

        # [B, in, in_dim] -> [B, in, 1, in_dim, 1] so matmul broadcasts over out_capsules.
        x = x.unsqueeze(2).unsqueeze(-1)

        # Prediction vectors u_hat: [B, in, out, out_dim]. W's leading dim of 1
        # broadcasts against the batch, so no explicit expand is needed.
        u_hat = torch.matmul(self.W, x).squeeze(-1)

        # Routing logits, allocated directly on the target device/dtype instead
        # of building a CPU tensor and copying it on every forward pass.
        b = torch.zeros(batch_size, self.in_capsules, self.out_capsules, 1,
                        device=u_hat.device, dtype=u_hat.dtype)

        for iteration in range(self.num_iterations):
            c = F.softmax(b, dim=2)  # coupling coefficients over output capsules
            s = (c * u_hat).sum(dim=1)  # [B, out, out_dim]
            v = self.squash(s)
            if iteration < self.num_iterations - 1:
                # Dot product between each prediction and the current output.
                agreement = (u_hat * v.unsqueeze(1)).sum(dim=-1, keepdim=True)
                b = b + agreement

        return v

In [81]:
class ReconstructionNetwork(nn.Module):
    """Decoder that reconstructs the input image from the class capsules.

    Every class capsule except one is zeroed out. The kept capsule is the one
    selected by ``labels`` when given, otherwise the longest capsule. The
    masked, flattened capsules then pass through three fully connected layers
    (ReLU, ReLU, sigmoid). The result is reshaped to
    ``[batch, img_channels, img_size, img_size]``.
    """

    def __init__(self, capsule_dim, num_classes, img_channels=3, img_size=46):
        super().__init__()
        self.num_classes = num_classes
        self.capsule_dim = capsule_dim

        flat_pixels = img_channels * img_size * img_size
        self.fc1 = nn.Linear(num_classes * capsule_dim, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, flat_pixels)

        self.img_channels = img_channels
        self.img_size = img_size

    def forward(self, x, labels=None):
        n = x.size(0)
        if labels is None:
            # No ground truth: keep the capsule with the largest length.
            keep = torch.norm(x, dim=-1).max(dim=1)[1]
        else:
            keep = labels
        mask = torch.zeros_like(x)
        mask[torch.arange(n), keep] = 1.0

        hidden = (x * mask).view(n, -1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        pixels = torch.sigmoid(self.fc3(hidden))
        return pixels.view(n, self.img_channels, self.img_size, self.img_size)

In [82]:
class CapsNet(nn.Module):
    """Capsule-network classifier with a reconstruction decoder.

    Pipeline: ``conv1`` (9x9, 256 channels, ReLU) -> ``PrimaryCapsules``
    (9x9, stride 2) -> ``DynamicRouting`` to ``num_classes`` class capsules ->
    ``ReconstructionNetwork``. The predicted class is the class capsule with
    the largest length.

    Args:
        img_channels: Channels of the input image.
        img_size: Height/width of the square input image.
        num_classes: Number of class capsules.
        primary_dim: Length of each primary capsule vector.
        primary_capsules: Number of primary capsule channels.
        class_dim: Length of each class capsule vector.
        routing_iterations: Routing-by-agreement iterations.

    Raises:
        ValueError: If ``img_size`` is too small to leave at least one primary
            capsule position after the two unpadded convolutions.
    """

    def __init__(self, img_channels=3, img_size=46, num_classes=2,
                 primary_dim=8, primary_capsules=32, class_dim=16,
                 routing_iterations=3):
        super().__init__()

        self.num_classes = num_classes
        self.class_dim = class_dim

        conv1_kernel = 9
        primary_kernel, primary_stride = 9, 2

        self.conv1 = nn.Conv2d(img_channels, 256, kernel_size=conv1_kernel, stride=1)
        self.relu = nn.ReLU(inplace=True)
        conv1_out_size = img_size - conv1_kernel + 1  # unpadded, stride 1

        self.primary_capsules = PrimaryCapsules(
            in_channels=256,
            out_channels=primary_capsules,
            capsule_dim=primary_dim,
            kernel_size=primary_kernel,
            stride=primary_stride,
        )

        # Unpadded conv output size: floor((n - kernel) / stride) + 1.
        # (``(n - 8) // 2`` gives the same value only for even n.)
        primary_out_size = (conv1_out_size - primary_kernel) // primary_stride + 1
        if primary_out_size < 1:
            raise ValueError(
                f"img_size={img_size} is too small: the primary capsule layer "
                "would have no spatial positions."
            )
        num_primary_caps = primary_capsules * primary_out_size ** 2

        self.digit_capsules = DynamicRouting(
            in_capsules=num_primary_caps,
            out_capsules=num_classes,
            in_dim=primary_dim,
            out_dim=class_dim,
            num_iterations=routing_iterations,
        )

        self.reconstruction_net = ReconstructionNetwork(
            capsule_dim=class_dim,
            num_classes=num_classes,
            img_channels=img_channels,
            img_size=img_size,
        )

    def _encode(self, x):
        """Return the class capsules ``[batch, num_classes, class_dim]`` for images ``x``."""
        features = self.relu(self.conv1(x))
        primary_caps = self.primary_capsules(features)
        return self.digit_capsules(primary_caps)

    def forward(self, x, labels=None):
        """Return ``(class_capsules, reconstruction)``.

        When given, ``labels`` only choose which capsule the decoder
        reconstructs from. Otherwise the longest capsule is used.
        """
        class_caps = self._encode(x)
        reconstruction = self.reconstruction_net(class_caps, labels)
        return class_caps, reconstruction

    def predict(self, x):
        """Return ``(predicted_class, capsule_lengths)``; the decoder is skipped."""
        class_caps = self._encode(x)
        lengths = torch.norm(class_caps, dim=-1)
        predictions = lengths.argmax(dim=1)

        return predictions, lengths

In [83]:
class CapsuleLoss(nn.Module):
    """Margin loss on capsule lengths plus a weighted reconstruction loss.

    The margin loss per sample is
    ``sum_k T_k * relu(m_plus - |v_k|)^2 + lambda * (1 - T_k) * relu(|v_k| - m_minus)^2``,
    averaged over the batch. ``T`` is the one-hot target.

    Args:
        num_classes: Number of class capsules / one-hot width.
        m_plus: Lower length bound for the target capsule.
        m_minus: Upper length bound for non-target capsules.
        lambda_val: Down-weighting of the non-target term.
        reconstruction_weight: Multiplier on the reconstruction loss.
        recon_reduction: ``'mean'`` (default, the previous behaviour) takes
            the MSE over every pixel. ``'sum'`` sums the squared error per
            image and averages over the batch. The 'sum' form is the one
            commonly used with a 0.0005 weight in the original CapsNet
            paper; verify it suits this setup before switching.

    Returns (from ``forward``):
        ``(total_loss, margin_loss, reconstruction_loss)`` scalar tensors.

    Raises:
        ValueError: If ``recon_reduction`` is not ``'mean'`` or ``'sum'``.
    """

    def __init__(self, num_classes=2, m_plus=0.9, m_minus=0.1,
                 lambda_val=0.5, reconstruction_weight=0.0005,
                 recon_reduction='mean'):
        super().__init__()
        if recon_reduction not in ('mean', 'sum'):
            raise ValueError(
                f"recon_reduction must be 'mean' or 'sum', got {recon_reduction!r}"
            )
        self.num_classes = num_classes
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_val = lambda_val
        self.reconstruction_weight = reconstruction_weight
        self.recon_reduction = recon_reduction

    def forward(self, class_capsules, labels, reconstructions, images):
        lengths = torch.norm(class_capsules, dim=-1)  # [batch_size, num_classes]

        labels_one_hot = F.one_hot(labels, num_classes=self.num_classes).float()
        present_loss = labels_one_hot * F.relu(self.m_plus - lengths) ** 2
        absent_loss = (1 - labels_one_hot) * F.relu(lengths - self.m_minus) ** 2
        margin_loss = (present_loss + self.lambda_val * absent_loss).sum(dim=1).mean()

        if self.recon_reduction == 'sum':
            squared_error = (reconstructions - images) ** 2
            reconstruction_loss = squared_error.flatten(start_dim=1).sum(dim=1).mean()
        else:
            reconstruction_loss = F.mse_loss(reconstructions, images, reduction='mean')

        total_loss = margin_loss + self.reconstruction_weight * reconstruction_loss

        return total_loss, margin_loss, reconstruction_loss


In [84]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch goes through ``model(images, labels)``. ``criterion`` returns
    ``(total, margin, reconstruction)`` losses, and the total is
    back-propagated.

    Args:
        model: Capsule model returning ``(class_capsules, reconstructions)``.
        dataloader: Yields ``(images, labels)`` batches; must support ``len``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss module with the ``CapsuleLoss.forward`` signature.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, avg_margin_loss, avg_recon_loss, accuracy)``. Losses are
        means of the per-batch values. Accuracy is a percentage over all
        samples, computed from the class capsules of the same forward pass
        used for the update (i.e. before that batch's optimizer step).

    Raises:
        ValueError: If ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    model.train()

    total_loss = 0.0
    total_margin_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(dataloader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        class_caps, reconstructions = model(images, labels)

        loss, margin_loss, recon_loss = criterion(class_caps, labels, reconstructions, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predicted class = longest capsule (same rule as CapsNet.predict).
        # Reusing the capsules avoids a second forward pass that, in train mode
        # with autograd on, would also build and discard a full graph.
        with torch.no_grad():
            predictions = torch.norm(class_caps, dim=-1).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        total_loss += loss.item()
        total_margin_loss += margin_loss.item()
        total_recon_loss += recon_loss.item()

        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100 * correct / total:.2f}%'
        })

    avg_loss = total_loss / num_batches
    avg_margin_loss = total_margin_loss / num_batches
    avg_recon_loss = total_recon_loss / num_batches
    accuracy = 100 * correct / total

    return avg_loss, avg_margin_loss, avg_recon_loss, accuracy

In [85]:
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating weights.

    Args:
        model: Capsule model returning ``(class_capsules, reconstructions)``.
        dataloader: Yields ``(images, labels)`` batches; must support ``len``.
        criterion: Loss module with the ``CapsuleLoss.forward`` signature.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, avg_margin_loss, avg_recon_loss, metrics)``. Losses are
        means of the per-batch values. ``metrics`` holds 'accuracy',
        'precision', 'recall' and 'f1' as percentages, with class 1 treated
        as positive, plus the raw confusion counts 'tp', 'tn', 'fp' and 'fn'.

    Raises:
        ValueError: If ``dataloader`` has no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("validate_epoch received an empty dataloader")

    model.eval()

    total_loss = 0.0
    total_margin_loss = 0.0
    total_recon_loss = 0.0
    correct = 0
    total = 0

    batch_predictions = []
    batch_labels = []

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            class_caps, reconstructions = model(images, labels)

            loss, margin_loss, recon_loss = criterion(class_caps, labels, reconstructions, images)

            # In CapsNet, `labels` only select the capsule the decoder
            # reconstructs from. The class capsules are what model.predict
            # would compute, so classify from them without a second forward pass.
            predictions = torch.norm(class_caps, dim=-1).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

            total_loss += loss.item()
            total_margin_loss += margin_loss.item()
            total_recon_loss += recon_loss.item()

            batch_predictions.append(predictions.cpu())
            batch_labels.append(labels.cpu())

            pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%'
            })

    avg_loss = total_loss / num_batches
    avg_margin_loss = total_margin_loss / num_batches
    avg_recon_loss = total_recon_loss / num_batches
    accuracy = 100 * correct / total

    all_predictions = torch.cat(batch_predictions).numpy()
    all_labels = torch.cat(batch_labels).numpy()

    tp = ((all_predictions == 1) & (all_labels == 1)).sum()
    tn = ((all_predictions == 0) & (all_labels == 0)).sum()
    fp = ((all_predictions == 1) & (all_labels == 0)).sum()
    fn = ((all_predictions == 0) & (all_labels == 1)).sum()

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    metrics = {
        'accuracy': accuracy,
        'precision': precision * 100,
        'recall': recall * 100,
        'f1': f1 * 100,
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn
    }

    return avg_loss, avg_margin_loss, avg_recon_loss, metrics

In [86]:
# Training preprocessing: central 46x46 crop (the CapsNet input size used
# below), random horizontal/vertical flips, then conversion to a tensor.
train_steps = [
    transforms.CenterCrop(size=46),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
]
train_transform = transforms.Compose(train_steps)

In [87]:
# Deterministic evaluation preprocessing: the same central 46x46 crop as
# training, with no random flips.
val_transform = transforms.Compose(
    [transforms.CenterCrop(size=46), transforms.ToTensor()]
)

In [88]:
# Layout of the extracted competition archive.
base_path = Path(histopathologic_cancer_detection_path)
train_dir = base_path.joinpath("train")
test_dir = base_path.joinpath("test")
labels_path = base_path.joinpath("train_labels.csv")

In [89]:
# One row per training patch; the dataset uses the 'id' and 'label' columns.
labels_df = pd.read_csv(filepath_or_buffer=labels_path)

In [90]:
# Number of labelled training patches (rows in train_labels.csv).
len(labels_df)

220025

In [91]:
# Class balance of the training labels (0 vs 1).
labels_df['label'].value_counts()

Unnamed: 0_level_0,count
label,Unnamed: 1_level_1
0,130908
1,89117


In [92]:
# All labelled training patches, with the augmenting train_transform.
full_dataset = HistopathologicDataset(
    train_dir,
    labels_df=labels_df,
    is_test=False,
    transform=train_transform,
)


In [93]:
n_labelled = len(full_dataset)
train_size = int(0.8 * n_labelled)
val_size = n_labelled - train_size

# Fixed generator seed so the 80/20 split is reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)


In [94]:
# `random_split` returns two Subsets that share the same `full_dataset` object.
# Setting `val_dataset.dataset.transform` would therefore also switch the
# *training* subset to `val_transform` and silently drop the random flips.
# Point the validation subset at a shallow copy instead: it shares the file
# list and label dict but has its own transform.
val_base_dataset = copy.copy(full_dataset)
val_base_dataset.transform = val_transform
val_dataset.dataset = val_base_dataset

assert train_dataset.dataset.transform is not val_transform, "training subset lost its augmentation"
assert val_dataset.dataset.transform is val_transform

In [95]:
# Size of the 80% training split.
len(train_dataset)

176020

In [96]:
# Size of the validation split (the remainder after the 80% training split).
len(val_dataset)

44005

In [97]:
# Mini-batch size and number of DataLoader worker processes.
batch_size, num_workers = 32, 2

In [98]:
# Settings shared by both loaders; host memory is pinned only when a CUDA
# device is available.
loader_kwargs = {
    'batch_size': batch_size,
    'num_workers': num_workers,
    'pin_memory': torch.cuda.is_available(),
}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [99]:
# CapsNet hyper-parameters; img_size matches the 46x46 center crop.
capsnet_config = dict(
    img_channels=3,
    img_size=46,
    num_classes=2,
    primary_dim=8,
    primary_capsules=32,
    class_dim=16,
    routing_iterations=3,
)
model = CapsNet(**capsnet_config).to(device)

In [100]:
# (element count, trainable?) for every parameter tensor in the model.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(size for size, _ in param_sizes)
trainable_params = sum(size for size, trainable in param_sizes if trainable)

In [101]:
# Total number of parameter elements in the model.
total_params

14263244

In [102]:
# Parameter elements with requires_grad=True (equals the total when nothing is frozen).
trainable_params

14263244

In [103]:
# Margin loss plus a small reconstruction term.
criterion = CapsuleLoss(
    num_classes=2,
    m_plus=0.9,
    m_minus=0.1,
    lambda_val=0.5,
    reconstruction_weight=0.0005,
)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Halve the learning rate when the metric passed to scheduler.step() stops
# decreasing (patience of 2 epochs).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='min', factor=0.5, patience=2
)

In [104]:
num_epochs = 10
# Best-so-far trackers used for checkpointing.
best_val_loss, best_val_acc = float('inf'), 0

In [105]:
# Per-epoch metric series, one independent list per key.
history = {
    key: []
    for key in (
        'train_loss', 'train_margin_loss', 'train_recon_loss', 'train_acc',
        'val_loss', 'val_margin_loss', 'val_recon_loss', 'val_acc',
        'val_precision', 'val_recall', 'val_f1',
    )
}


In [None]:
for epoch in range(num_epochs):
    epoch_num = epoch + 1
    print(f"\nEpoch {epoch_num}/{num_epochs}")

    train_loss, train_margin, train_recon, train_acc = train_epoch(
        model, train_loader, optimizer, criterion, device
    )
    val_loss, val_margin, val_recon, val_metrics = validate_epoch(
        model, val_loader, criterion, device
    )

    # The plateau scheduler (mode='min') is driven by the validation loss.
    scheduler.step(val_loss)

    epoch_record = {
        'train_loss': train_loss,
        'train_margin_loss': train_margin,
        'train_recon_loss': train_recon,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_margin_loss': val_margin,
        'val_recon_loss': val_recon,
        'val_acc': val_metrics['accuracy'],
        'val_precision': val_metrics['precision'],
        'val_recall': val_metrics['recall'],
        'val_f1': val_metrics['f1'],
    }
    for key, value in epoch_record.items():
        history[key].append(value)

    summary_lines = (
        f"\nEpoch {epoch_num} Summary:",
        f"Train Loss: {train_loss:.4f} (Margin: {train_margin:.4f}, Recon: {train_recon:.6f})",
        f"Train Accuracy: {train_acc:.2f}%",
        f"Val Loss: {val_loss:.4f} (Margin: {val_margin:.4f}, Recon: {val_recon:.6f})",
        f"Val Accuracy: {val_metrics['accuracy']:.2f}%",
        f"Val Precision: {val_metrics['precision']:.2f}%",
        f"Val Recall: {val_metrics['recall']:.2f}%",
        f"Val F1: {val_metrics['f1']:.2f}%",
        f"Val Confusion: TP={val_metrics['tp']}, TN={val_metrics['tn']}, FP={val_metrics['fp']}, FN={val_metrics['fn']}",
    )
    for line in summary_lines:
        print(line)

    # Checkpoint whenever the validation loss improves on the best seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_acc = val_metrics['accuracy']
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_metrics['accuracy'],
        }
        torch.save(checkpoint, 'best_capsnet_model.pth')
        print(f"Saved best model (Val Loss: {val_loss:.4f}, Val Acc: {val_metrics['accuracy']:.2f}%)")


Epoch 1/10


Training:   0%|          | 0/5501 [00:00<?, ?it/s]