BatchNorm for Domain Adaptation

In [1]:
import torch
import ignite.distributed as idist
from tqdm.notebook import tqdm

from uda.models import UNet
from uda.datasets import CC359
from uda.metrics import dice_score

In [5]:
dataset = CC359.from_preconfigured("../config/cc359.yaml")
dataset.setup()
dataloader = dataset.val_dataloader(batch_size=4)

y_true = dataset.val_split.tensors[1]

Loading files: 100%|██████████████████████████████████████| 60/60 [00:11<00:00,  5.44it/s]


In [2]:
model = UNet.from_pretrained("/tmp/models/teacher-model.pt").to(idist.device())

In [3]:
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        print(m)

BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
BatchNorm2d(128, eps=1e-05,
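
Whether a forward pass in train mode can adapt anything depends on whether each layer keeps running statistics at all. Below is a minimal sketch for checking this. It uses a small stand-in network rather than the `UNet` above, and the helper name `count_bn_with_running_stats` is hypothetical.

```python
import torch
from torch import nn


def count_bn_with_running_stats(model: nn.Module) -> tuple[int, int]:
    """Return (layers that keep running stats, total BatchNorm2d layers)."""
    bn_layers = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    # With track_running_stats=False, PyTorch leaves running_mean as None
    tracked = sum(m.running_mean is not None for m in bn_layers)
    return tracked, len(bn_layers)


# Stand-in model (an assumption for illustration, not the UNet above)
toy = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1),
    nn.BatchNorm2d(8, track_running_stats=False),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.BatchNorm2d(8, track_running_stats=True),
)

tracked, total = count_bn_with_running_stats(toy)
print(f"{tracked} of {total} BatchNorm2d layers track running statistics")
```

Layers without running statistics normalize with the statistics of the current batch in both train and eval mode.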

In [4]:
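@torch.no_grad()
def get_predictions(model, dataloader) -> torch.Tensor:
    """Run the model over the dataloader and return binarized predictions."""
    return torch.cat(
        [
            # logits -> probabilities -> hard 0/1 mask, moved back to the CPU
            model(x.to(idist.device())).sigmoid().round().cpu()
            for x, _ in tqdm(dataloader, desc="Predicting")
        ]
    )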

The model was trained on domain `A` and is now evaluated on domain `B`.

First, set the model to eval mode, as you usually would when evaluating on a new domain.

In [6]:
model.eval()
y_pred = get_predictions(model, dataloader)

dice_score(y_pred, y_true)

Predicting:   0%|          | 0/960 [00:00<?, ?it/s]

tensor(0.7680)

Now we set the model to train mode and run the dataset through it once. There are still no gradients and no training; we are only running the model.

Then we get our predictions in eval mode again.

In [7]:
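# Forward passes in train mode; the predictions from this pass are discarded
model.train()
y_pred = get_predictions(model, dataloader)

# Predict again in eval mode and score these predictions
model.eval()
y_pred = get_predictions(model, dataloader)

dice_score(y_pred, y_true)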

Predicting:   0%|          | 0/960 [00:00<?, ?it/s]

Predicting:   0%|          | 0/960 [00:00<?, ?it/s]

tensor(0.7680)

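No weights were updated between the two evaluations, so any difference between the scores can only come from the BatchNorm running statistics adapting to the new domain. In the outputs shown here, however, the Dice score is 0.7680 both times. That is consistent with the layers printed above, which report `track_running_stats=False`: such layers keep no running statistics and normalize with the statistics of the current batch in both train and eval mode, so the extra pass has nothing to adapt.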
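
The adaptation step itself can be illustrated on a single layer. This is a minimal sketch under stated assumptions: the layer is a toy `BatchNorm2d` created with `track_running_stats=True`, and the "domain B" batches are synthetic random tensors, not CC359 data.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy layer with running statistics enabled (an assumption for illustration)
bn = nn.BatchNorm2d(3, track_running_stats=True)

# Synthetic "domain B" batches: per-channel mean near 5, std near 2
batches = [5.0 + 2.0 * torch.randn(16, 3, 8, 8) for _ in range(50)]

mean_before = bn.running_mean.clone()  # initialised to zeros
weight_before = bn.weight.detach().clone()

bn.train()  # train mode: each forward pass updates the running statistics
with torch.no_grad():  # no gradients, so the learnable weights never change
    for x in batches:
        bn(x)
bn.eval()  # eval mode: normalise with the adapted running statistics

print("running mean before:", mean_before)
print("running mean after: ", bn.running_mean)
print("weights unchanged:  ", torch.equal(weight_before, bn.weight))
```

With the default `momentum=0.1`, each batch moves the running statistics a tenth of the way toward the batch statistics, so they approach the new domain's mean and variance after a few dozen batches while the affine weights stay untouched.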