In [18]:
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import transforms as T

from torchmetrics import Accuracy

import pytorch_lightning as L
from pytorch_lightning.utilities.types import STEP_OUTPUT

from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader, get_train_loader
from wilds.common.grouper import CombinatorialGrouper

In [29]:
class CNN(nn.Module):
    """Small LeNet-style convolutional classifier.

    The network is three ``Conv2d(kernel=5) -> ReLU -> MaxPool2d(2, 2)`` stages
    followed by three fully connected layers.

    ``fc1`` is sized for ``32 * 8 * 8`` flattened features, so ``forward`` only
    works when the convolutional stack produces an 8x8 map. A 96x96 input does:
    96 -> 92 -> 46 -> 42 -> 21 -> 17 -> 8. ``embed`` has no such restriction
    because it averages over the spatial dimensions.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels=3, num_classes=2):
        super().__init__()
        # Layer names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.conv3 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def _features(self, x):
        """Run the shared conv -> ReLU -> pool stack and return the feature map."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        return x

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = torch.flatten(self._features(x), 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

    def embed(self, x):
        """Return one value per conv3 channel: the spatial mean of the feature map.

        For a 4-D input the result has shape ``(batch, 32)``.
        """
        return self._features(x).mean(dim=[2, 3])

In [19]:
# Per-sample preprocessing: convert each image to a tensor, nothing else.
transform = T.Compose([T.ToTensor()])

In [25]:
class SimpleCNN(L.LightningModule):
    """Lightning wrapper that trains ``CNN`` with cross-entropy and logs accuracy.

    Every batch is unpacked as a 3-tuple ``(inputs, targets, extra)``; the
    third element is ignored.

    Args:
        in_channels: Number of input image channels, forwarded to ``CNN``.
        num_cls: Number of classes; sizes both the classifier head and the
            accuracy metric.
        lr: Learning rate for the AdamW optimizer (keyword-only).
        *args, **kwargs: Forwarded unchanged to ``LightningModule.__init__``.
    """

    def __init__(self, in_channels=3, num_cls=2, *args: Any, lr: float = 1e-4, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.lr = lr
        self.model = CNN(in_channels=in_channels, num_classes=num_cls)

        self.criterion = torch.nn.CrossEntropyLoss()
        # The metric must agree with the classifier head, so it uses num_cls
        # rather than a hard-coded class count.
        self.metric = Accuracy(num_classes=num_cls, task='multiclass')

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute and log the cross-entropy loss for one training batch."""
        X, t, _ = batch

        y = self.model(X)
        loss = self.criterion(y, t)

        self.log("loss", loss)

        return loss

    def validation_step(self, batch, batch_idx) -> STEP_OUTPUT:
        """Compute and log the accuracy of arg-max predictions for one batch."""
        X, t, _ = batch

        # Reduce logits to hard class indices before scoring.
        y = self.model(X).argmax(dim=1)

        accuracy = self.metric(y, t)

        self.log('accuracy', accuracy)

        return accuracy

    def configure_optimizers(self) -> Any:
        """Build the AdamW optimizer over all module parameters.

        The optimizer is created here, when Lightning asks for it, instead of
        being constructed once in ``__init__`` and cached on the module.
        """
        return torch.optim.AdamW(params=self.parameters(), lr=self.lr)

In [26]:
# Load the WILDS "camelyon17" dataset from the local data directory.
dataset = get_dataset("camelyon17", root_dir="../../data/")

# Grouper keyed on the 'hospital' metadata field; handed to the train loader
# together with uniform_over_groups=True.
grouper = CombinatorialGrouper(dataset, ['hospital'])

# Training split and loader.
train_set = dataset.get_subset("train", transform=transform)
train_loader = get_train_loader(
    "standard",
    train_set,
    grouper=grouper,
    uniform_over_groups=True,
    batch_size=128,
    num_workers=8,
)

# Validation split ("id_val") and loader.
val_set = dataset.get_subset("id_val", transform=transform)
val_loader = get_eval_loader(
    'standard',
    val_set,
    batch_size=64,
    num_workers=8,
)

# Test split and loader.
test_set = dataset.get_subset('test', transform=transform)
test_loader = get_eval_loader(
    'standard',
    test_set,
    batch_size=64,
    num_workers=8,
)



In [27]:
# Seed the global RNGs so weight initialisation and batch sampling are
# repeatable across kernel restarts.
L.seed_everything(42, workers=True)

# Run validation about three times per training epoch. Floor division keeps
# the interval an integer batch count; max(1, ...) stops it collapsing to 0
# when the loader has fewer than three batches.
trainer = L.Trainer(
    accelerator="auto",
    max_epochs=5,
    val_check_interval=max(1, len(train_loader) // 3),
)

model = SimpleCNN(in_channels=3, num_cls=2)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
# Train the model, validating on the "id_val" loader at the configured interval.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


  | Name      | Type               | Params
-------------------------------------------------
0 | model     | CNN                | 271 K 
1 | criterion | CrossEntropyLoss   | 0     
2 | metric    | MulticlassAccuracy | 0     
-------------------------------------------------
271 K     Trainable params
0         Non-trainable params
271 K     Total params
1.088     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
