In [1]:
from typing import Any, OrderedDict

import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Function

from torchvision import transforms as T
from torchvision.models.resnet import resnet50, ResNet50_Weights

import pytorch_lightning as L

from torchmetrics import Accuracy

from wilds import get_dataset
from wilds.common.data_loaders import get_eval_loader, get_train_loader
from wilds.common.grouper import CombinatorialGrouper

from utils import get_backbone_from_ckpt, DomainMapper, evaluate

<ol>
    <li>Reproduce domain shift problem and visualize</li>
    <li>Test L-DANN</li>
</ol>

To demonstrate the domain shift present in the Camelyon17 dataset, we evaluate both the in-domain and the out-of-domain accuracy:

In [2]:
# Batch size shared by every loader built in this cell.
BS = 256

# Images are only converted to tensors; no augmentation or normalisation is applied.
transform = T.Compose([T.ToTensor()])

dataset = get_dataset("camelyon17", root_dir="../../data/")

# Full splits ("id_val" is the in-domain split, "test" the out-of-domain one).
# A smaller training fraction (10_000/302436) was left in the original cell as a
# commented-out alternative to frac=1.
train_set = dataset.get_subset("train", transform=transform, frac=1)
val_set = dataset.get_subset("id_val", transform=transform, frac=1)
test_set = dataset.get_subset("test", transform=transform, frac=1)

# Reduced copies of the two evaluation splits (frac=1/5 and frac=1/20).
val_set_s = dataset.get_subset("id_val", transform=transform, frac=1/5)
test_set_s = dataset.get_subset("test", transform=transform, frac=1/20)

# Domains are defined by the 'hospital' metadata field; the mapper is built from
# the first metadata column of the training subset.
grouper = CombinatorialGrouper(dataset, ["hospital"])
dom_mapper = DomainMapper(train_set.metadata_array[:, 0])

# Keyword arguments common to all evaluation loaders.
_eval_loader_kwargs = dict(batch_size=BS, num_workers=8)

train_loader = get_train_loader(
    "standard",
    train_set,
    grouper=grouper,
    uniform_over_groups=True,
    batch_size=BS,
    num_workers=8,
)
val_loader = get_eval_loader("standard", val_set, **_eval_loader_kwargs)
test_loader = get_eval_loader("standard", test_set, **_eval_loader_kwargs)

val_loader_s = get_eval_loader("standard", val_set_s, **_eval_loader_kwargs)
test_loader_s = get_eval_loader("standard", test_set_s, **_eval_loader_kwargs)

In [3]:
# Define model
class ReverseLayerF(Function):
    """Gradient-reversal function.

    The forward pass hands the input through unchanged; the backward pass
    negates the incoming gradient and scales it by ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Remember ``alpha`` for the backward pass and return a view of ``x``."""
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-grad_output * alpha`` for ``x`` and no gradient for ``alpha``."""
        flipped = torch.neg(grad_output)
        scaled = flipped * ctx.alpha
        # Second entry corresponds to ``alpha``, which receives no gradient.
        return scaled, None

class LDANN(nn.Module):
    """ResNet-50 label classifier with a domain-discriminator head.

    The discriminator head reads the same pooled features as ``backbone.fc``,
    but through ``ReverseLayerF``, so gradients flowing from the discriminator
    into the backbone are negated and scaled by ``alpha``.

    Args:
        weights: ``"scratch"`` for a randomly initialised ResNet-50,
            ``"ImageNet"`` for torchvision's default pretrained weights (with a
            fresh 2-way ``fc``), or anything else, which is passed to
            ``get_backbone_from_ckpt`` as a checkpoint location.
        alpha: Scale applied to the reversed gradient.
        num_domains: Number of outputs of the discriminator head.
    """

    def __init__(self, weights, alpha=10.0, num_domains=3) -> None:
        super().__init__()
        self.backbone = self._make_backbone(weights)
        # Size the discriminator from the classifier's input width instead of a
        # hard-coded constant, so both heads always read the same feature size.
        self.disc_head = nn.Linear(self.backbone.fc.in_features, num_domains)

        # Not used inside this class; kept so existing attribute access keeps working.
        self.crit_pred = nn.CrossEntropyLoss()
        self.crit_disc = nn.CrossEntropyLoss()

        self.alpha = alpha

    def _stem(self, x):
        """Apply conv1 -> bn1 -> relu -> maxpool of the backbone."""
        bb = self.backbone
        return bb.maxpool(bb.relu(bb.bn1(bb.conv1(x))))

    def _stages(self):
        """Return the four residual stages of the backbone in execution order."""
        bb = self.backbone
        return (bb.layer1, bb.layer2, bb.layer3, bb.layer4)

    def forward(self, x):
        """Return ``(y, z)``: label logits and domain logits for ``x``.

        Both heads consume the flattened output of ``backbone.avgpool``; only
        the domain head sees it through the gradient-reversal function.
        Per-stage pooled features are not computed here (they were unused);
        use ``embed`` to obtain them.
        """
        x = self._stem(x)
        for stage in self._stages():
            x = stage(x)

        f = torch.flatten(self.backbone.avgpool(x), 1)

        y = self.backbone.fc(f)
        z = self.disc_head(ReverseLayerF.apply(f, self.alpha))

        return y, z

    def embed(self, x: torch.Tensor):
        """Return spatially averaged features after the stem and each stage.

        Returns:
            A dict with keys ``"z0"`` (after the stem), ``"z1"``..``"z4"``
            (after ``layer1``..``layer4``), each the mean over dims 2 and 3,
            and ``"x"``, the flattened ``avgpool`` output.
        """
        x = self._stem(x)
        feats = {"z0": x.mean(dim=[2, 3])}

        for idx, stage in enumerate(self._stages(), start=1):
            x = stage(x)
            feats[f"z{idx}"] = x.mean(dim=[2, 3])

        # flatten(…, 1) keeps the batch axis; a bare squeeze() would also drop
        # it when the batch holds a single sample.
        feats["x"] = torch.flatten(self.backbone.avgpool(x), 1)

        return feats

    def _make_backbone(self, weights):
        """Build the 2-class ResNet-50 backbone selected by ``weights``."""
        if weights == "scratch":
            return resnet50(num_classes=2)

        if weights == "ImageNet":
            backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
            # Swap the pretrained classifier for a fresh 2-way head.
            backbone.fc = nn.Linear(2048, 2)
            return backbone

        # Any other value is treated as a checkpoint location.
        backbone = resnet50(num_classes=2)
        print(weights)
        sd = get_backbone_from_ckpt(weights)
        # strict=False: tolerate key mismatches and report them instead of raising.
        missing_keys, unexpected_keys = backbone.load_state_dict(sd, strict=False)
        print("missing:", missing_keys, "unexpected:", unexpected_keys)

        return backbone

class SimpleCNN(L.LightningModule):
    """Lightning wrapper around a model whose forward returns ``(y, z)``.

    ``y`` is scored against the class targets and ``z`` against domain labels
    derived from the batch metadata via ``grouper`` and ``dom_mapper``.

    Args:
        model: Module exposing ``backbone.fc``, ``alpha`` and a forward pass
            returning ``(label_logits, domain_logits)``.
        grouper: Object whose ``metadata_to_group`` maps batch metadata to
            group ids.
        dom_mapper: Callable applied to the group ids to obtain the domain
            targets used for the discriminator loss.
    """

    def __init__(self, model, grouper, dom_mapper, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

        self.model = model

        self.criterion = torch.nn.CrossEntropyLoss()
        # NOTE(review): only ``backbone.fc`` parameters are handed to the
        # optimizer, so no other parameter of ``model`` (remaining backbone
        # layers, ``disc_head``) is stepped by it. Confirm this is intended,
        # in particular for runs with alpha > 0.
        self.optimizer = torch.optim.AdamW(params=self.model.backbone.fc.parameters(), lr=1e-4)
        self.metric = Accuracy(num_classes=2, task='multiclass')

        self.grouper: CombinatorialGrouper = grouper
        self.dom_mapper: DomainMapper = dom_mapper

        self.alpha = self.model.alpha

    def training_step(self, batch, batch_idx):
        """Log both losses; return their sum if ``alpha > 0``, else the label loss."""
        X, t, M = batch

        y, z = self.model(X)

        loss_y = self.criterion(y, t)
        self.log("loss_y", loss_y)

        # Group lookup runs on CPU metadata; the resulting targets are then
        # moved to wherever the logits live rather than assuming CUDA.
        d = self.grouper.metadata_to_group(M.cpu())
        d = self.dom_mapper(d).to(z.device)

        loss_d = self.criterion(z, d)
        self.log("loss_d", loss_d)

        if self.alpha > 0:
            return loss_y + loss_d
        return loss_y

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute and log the label loss under ``'val/loss'``.

        ``dataloader_idx`` defaults to 0 so the step also works when only a
        single validation dataloader is supplied.
        """
        X, t, _ = batch

        y, _ = self.model(X)

        loss = self.criterion(y, t)

        self.log('val/loss', loss)

        return loss

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the accuracy of the arg-max predictions under ``'accuracy'``.

        ``dataloader_idx`` defaults to 0 so the step also works when only a
        single test dataloader is supplied.
        """
        X, t, _ = batch

        y, _ = self.model(X)
        y = y.argmax(dim=1)

        accuracy = self.metric(y, t)

        self.log('accuracy', accuracy)

    def configure_optimizers(self) -> Any:
        """Return the optimizer created in ``__init__``."""
        return self.optimizer

In [4]:
def get_backbone_from_ckpt(ckpt_path: str) -> OrderedDict[str, torch.Tensor]:
    """Extract the backbone entries from a checkpoint's ``"state_dict"``.

    Entries whose name starts with ``"backbone"`` are kept, and everything up
    to and including the first ``"."`` is stripped from their names.

    NOTE(review): this definition shadows the ``get_backbone_from_ckpt``
    imported from ``utils`` at the top of the notebook — keep only one.

    Args:
        ckpt_path: Location of a checkpoint readable by ``torch.load`` that
            contains a ``"state_dict"`` mapping.

    Returns:
        An ordered mapping from the stripped parameter names to their values
        (a state dict, not a module).
    """
    # Load onto the CPU so checkpoints saved from a GPU can also be read on
    # machines without one.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    return OrderedDict(
        (name.partition(".")[2], param)
        for name, param in checkpoint["state_dict"].items()
        if name.startswith("backbone")
    )

In [5]:
# One training epoch; the validation check interval is set to the number of
# training batches.
trainer = L.Trainer(accelerator="auto", max_epochs=1, val_check_interval=len(train_loader))

# ImageNet-initialised backbone with alpha=0.0: SimpleCNN.training_step then
# returns the label loss only.
ldann = LDANN(weights="ImageNet", alpha=0.0)
model = SimpleCNN(model=ldann, grouper=grouper, dom_mapper=dom_mapper)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# Validate on the reduced in-domain and out-of-domain subsets during training.
_val_loaders = [val_loader_s, test_loader_s]
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=_val_loaders)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | LDANN              | 23.5 M
1 | criterion | CrossEntropyLoss   | 0     
2 | metric    | MulticlassAccuracy | 0     
-------------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.073    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [7]:
# Final evaluation on the full in-domain ("val") and out-of-domain ("test") splits.
_test_loaders = {"val": val_loader, "test": test_loader}
trainer.test(model=model, dataloaders=_test_loaders)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'accuracy/dataloader_idx_0': 0.9218116998672485},
 {'accuracy/dataloader_idx_1': 0.7648199796676636}]

In [8]:
# from torch.utils.data import DataLoader
# from tqdm import tqdm

In [9]:
# train_loader.batch_size

In [10]:
# @torch.no_grad()
# def compute_embeddings(model: nn.Module, train_loader: DataLoader, test_loader: DataLoader):
#     num_train_samples = len(train_loader.dataset)
#     num_test_samples  = len(test_loader.dataset)
    
#     train = {
#         'embeddings': torch.empty(num_train_samples, 2048),
#         'labels'    : torch.empty(num_train_samples)
#     }
#     BS = train_loader.batch_size
#     for i, (X, t, _) in enumerate(tqdm(train_loader)):
#         bs = X.shape[0]
#         train['embeddings'][i*BS:i*BS+bs] = model(X.cuda()).cpu()
#         train['labels'][i*BS:i*BS+bs]     = t

#     test = {
#         'embeddings': torch.empty(num_test_samples, 2048),
#         'labels'    : torch.empty(num_test_samples)
#     }
#     BS = test_loader.batch_size
#     for i, (X, t, _) in enumerate(tqdm(test_loader)):
#         bs = X.shape[0]
#         test['embeddings'][i*BS:i*BS+bs] = model(X.cuda()).cpu()
#         test['labels'][i*BS:i*BS+bs]     = t

#     return train, test

In [11]:
# model_ = model.model.backbone
# model_.fc = nn.Identity()
# model_ = model_.cuda()
# train_embs, test_embs = compute_embeddings(model_, train_loader, test_loader)

In [12]:
# from sklearn.preprocessing import StandardScaler
# from sklearn.linear_model import LogisticRegression
# from sklearn.metrics import classification_report
# from sklearn.pipeline import make_pipeline

In [13]:
# lor_b = make_pipeline(
#         StandardScaler(), 
#         LogisticRegression(max_iter=10_000, verbose=1,
#     )
# )

In [14]:
# lor_b.fit(train_embs['embeddings'], train_embs['labels'])