In [1]:
from typing import Any, Dict, List

import torch
from torch import nn, optim
from torchvision.models.resnet import resnet18
from torchvision import transforms as T
import pytorch_lightning as pl
from pytorch_lightning.utilities.types import STEP_OUTPUT

from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader

  warn(


In [2]:
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

In [3]:
dataset = get_dataset(
    "camelyon17",
    root_dir="../data",
)

In [4]:
train_set = dataset.get_subset(
    split="train",
    transform=transform
)
val_set = dataset.get_subset(
    split="id_val",
    transform=transform
)

In [5]:
train_loader = get_train_loader("standard", train_set, batch_size=64)
val_loader = get_eval_loader("standard", val_set, batch_size=64)

In [6]:
class R50(pl.LightningModule):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.model = resnet18(num_classes=2)
        self.criterion = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        X, t, _ = batch
        
        y = self.model(X)
        loss = self.criterion(y, t)

        return loss
    
    def configure_optimizers(self) -> Any:
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)

        return {
            "optimizer": optimizer,
        }

In [7]:
trainer = pl.Trainer(
    max_steps=1,
    accelerator="auto"
)
trainer.fit(
    R50(),
    train_loader,
    val_loader
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | ResNet           | 11.2 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.710    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [8]:
class DomainMapper():
    def __init__(self, domains: List) -> None:
        self.unique_domains = domains.unique()
        self.map_dict: Dict = {self.unique_domains[i].item():i for i in range(len(self.unique_domains))}
        self.unmap_dict: Dict = dict((v, k) for k, v in self.map_dict.items())

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.map(x)

    def map(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tensor([self.map_dict[v.item()] for v in x])
    
    def unmap(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tensor([self.unmap_dict[v.item()] for v in x])

In [9]:
dom_mapper = DomainMapper(train_set.metadata_array[:,0])
dom_mapper.map_dict, dom_mapper.unmap_dict

({0: 0, 3: 1, 4: 2}, {0: 0, 1: 3, 2: 4})

In [10]:
torch.all(dom_mapper.unmap(dom_mapper(train_set.metadata_array[:,0])) == train_set.metadata_array[:,0])

tensor(True)

In [11]:
class LinearDomainClf(pl.LightningModule):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.linear_head = nn.Linear(in_features=2048, out_features=3) 
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        X, t, _ = batch

        y = self.linear_head(X)
        loss = self.criterion(y, t)

        return loss
    
    def configure_optimizers(self) -> Any:
        optimizer = optim.AdamW(self.parameters(), lr=1e-3)

        return {
            "optimizer": optimizer,
        }