BatchNorm for Domain Adaptation

In [1]:
import torch
import ignite.distributed as idist
from tqdm.notebook import tqdm

from uda.models import UNet, UNetConfig
from uda.datasets import CC359, CC359Config
from uda.metrics import dice_score

In [30]:
dataset_cfg = CC359Config.from_file("../config/dataset.yaml")
dataset_cfg.patch_size = (64, 115, 115)

dataset = CC359(dataset_cfg)
dataset.setup()
dataloader = dataset.val_dataloader(batch_size=4)

y_true = dataset.val_split.tensors[1]
y_true.shape

Loading files: 100%|████████████████████████████████████████| 9/9 [00:01<00:00,  4.95it/s]


torch.Size([2304, 1, 115, 115])

In [3]:
model_cfg = UNetConfig.from_file("../config/model.yaml")
# model = UNet.from_pretrained("/tmp/models/teacher/best_model.pt").to(idist.device())
model = UNet(model_cfg)

In [4]:
from uda.trainer import SegEvaluator
from ignite.handlers import EpochOutputStore
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from uda import pipe, sigmoid_round_output_transform, to_cpu_output_transform

In [8]:
evaluator = SegEvaluator(model)
ProgressBar(desc=f"Eval ({dataset.config.vendor})", persist=True).attach(evaluator)
eos = EpochOutputStore(
    output_transform=pipe(sigmoid_round_output_transform, to_cpu_output_transform)
)
eos.attach(evaluator, "output")

evaluator.run(dataset.val_dataloader(16))
preds, targets, data = [*zip(*evaluator.state.output)]

preds = torch.cat(preds).numpy()
targets = torch.cat(targets).numpy()
data = torch.cat(data).numpy()

preds.shape

2022-08-07 13:08:40,084 SegEvaluator INFO: Engine run starting with max_epochs=1.


Eval (GE_3)[1/144]   1%|           [00:00<?]

2022-08-07 13:08:40,829 SegEvaluator INFO: Epoch[1] Complete. Time taken: 00:00:01
2022-08-07 13:08:40,829 SegEvaluator INFO: Engine run complete. Time taken: 00:00:01


(2304, 1, 112, 112)

In [27]:
from collections.abc import Callable
from typing import Any, Optional, Union

import numpy as np
import torch
from ignite.utils import to_onehot
from patchify import unpatchify


def reshape_to_volume(
    data: Union[np.ndarray, torch.Tensor], dim: int, imsize: tuple[int, int, int], patch_size: Optional[tuple[int, int, int]]
) -> Union[np.ndarray, torch.Tensor]:
    # check if torch.Tensor (patchify uses numpy backend)
    if isinstance(data, torch.Tensor):
        data = data.numpy()
        output_type_tensor = True
    else:
        output_type_tensor = False

    # unpatchify if data is patchified
    if patch_size is not None:
        # compute number of patches for each axis
        n_patches = [axis_size // patch_size for axis_size, patch_size in zip(imsize, patch_size)]
        print(n_patches)
        cropped_patch_size = patch_size[:-dim] + data.shape[-dim:]  # if data was cropped due to down/upsampling inaccuracies
        print(cropped_patch_size)
        cropped_imsize = [ps * np for ps, np in zip(cropped_patch_size, n_patches)]
        print(cropped_imsize)
        batch_size = int(data.shape[0] // (np.prod(n_patches) * np.prod(imsize[:-dim])))
        print("bs", batch_size)
        print(data.shape)
        # subsume batch_size in first patch axis (z-axis)
        data = data.reshape(batch_size * n_patches[0], *n_patches[1:], *cropped_patch_size)
        print(data.shape)
        # unpatchify (subsume batch_size in first image axis)
        data = unpatchify(data, imsize=(batch_size * cropped_imsize[0], *cropped_imsize[1:]))
    else:
        cropped_imsize = imsize[:-dim] + data.shape[-dim:]

    # extract batch_size in first axis
    data = data.reshape(-1, *cropped_imsize)

    return torch.from_numpy(data) if output_type_tensor else data

In [28]:
y_vol = reshape_to_volume(preds.reshape(12, 1, 192, 112, 112), 3, dataset.imsize, dataset.patch_size)
y_vol.shape

[1, 2, 2]
(192, 112, 112)
[192, 224, 224]
bs 3
(12, 1, 192, 112, 112)
(3, 2, 2, 192, 112, 112)


(3, 192, 224, 224)

In [29]:
y_vol = reshape_to_volume(preds, model_cfg.dim, dataset.imsize, dataset.patch_size)
y_vol.shape

[1, 2, 2]
(192, 112, 112)
[192, 224, 224]
bs 3
(2304, 1, 112, 112)
(3, 2, 2, 192, 112, 112)


(3, 192, 224, 224)

In [10]:
y_pred.numpy().shape

(16, 1, 112, 112)

In [None]:
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        print(m)

In [6]:
@torch.no_grad()
def get_predictions(model, dataloader) -> torch.Tensor:
    model.to(idist.device())
    preds = torch.cat(
        [
            model(x.to(idist.device())).sigmoid().round().cpu()
            for x, _ in tqdm(dataloader, desc="Predicting")
        ]
    )
    model.cpu()
    return preds

Model is trained on domain `A` and will now be evaluated on domain `B`

Set the model to eval, as you would usually do when evaluating on a new domain

In [7]:
from uda.models.modules import center_crop_nd

In [8]:
model.eval()
y_pred = get_predictions(model, dataloader)

y_true = center_crop_nd(y_true, y_pred.shape[1:])
dice_score(y_pred, y_true)

Predicting:   0%|          | 0/144 [00:00<?, ?it/s]

tensor(0.0497)

Now we set the model to train and run the dataset once (still no gradients, no training - just running the model)

Then we gonna get our predictions in eval mode again

In [None]:
model.train()
y_pred = get_predictions(model, dataloader)

model.eval()
y_pred = get_predictions(model, dataloader)

dice_score(y_pred, y_true)

The results have improved a lot, only due to adapting the running stats of our BatchNorm to the new domain