BatchNorm for Domain Adaptation

In [1]:
import torch
import ignite.distributed as idist
from tqdm.notebook import tqdm

from uda.models import UNet, UNetConfig
from uda.datasets import CC359, CC359Config
from uda.metrics import dice_score

In [2]:
dataset_cfg = CC359Config.from_file("../config/dataset.yaml")
dataset_cfg.imsize = (192, 230, 230)
dataset_cfg.patch_size = (64, 115, 115)

dataset = CC359(dataset_cfg)
dataset.setup()
dataloader = dataset.val_dataloader(batch_size=4)

y_true = dataset.val_split.tensors[1]
y_true.shape

Loading files: 100%|████████████████████████████████████████| 9/9 [00:02<00:00,  4.33it/s]


torch.Size([2304, 1, 115, 115])

In [3]:
model_cfg = UNetConfig.from_file("../config/model.yaml")
# model = UNet.from_pretrained("/tmp/models/teacher/best_model.pt").to(idist.device())
model = UNet(model_cfg)

In [4]:
from uda.trainer import SegEvaluator
from ignite.handlers import EpochOutputStore
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from uda import pipe, sigmoid_round_output_transform, to_cpu_output_transform

In [5]:
evaluator = SegEvaluator(model)
ProgressBar(desc=f"Eval ({dataset.vendor})", persist=True).attach(evaluator)
eos = EpochOutputStore(
    output_transform=pipe(sigmoid_round_output_transform, to_cpu_output_transform)
)
eos.attach(evaluator, "output")

evaluator.run(dataset.val_dataloader(16))
preds, targets, data = [*zip(*evaluator.state.output)]

preds = torch.cat(preds).numpy()
targets = torch.cat(targets).numpy()
data = torch.cat(data).numpy()

preds.shape

2022-08-08 02:20:27,319 SegEvaluator INFO: Engine run starting with max_epochs=1.


Eval (GE_3)[1/144]   1%|           [00:00<?]

2022-08-08 02:20:28,154 SegEvaluator INFO: Epoch[1] Complete. Time taken: 00:00:01
2022-08-08 02:20:28,155 SegEvaluator INFO: Engine run complete. Time taken: 00:00:01


(2304, 1, 112, 112)

In [6]:
from uda.utils import reshape_to_volume

In [7]:
y_vol = reshape_to_volume(preds, model_cfg.dim, dataset.imsize, dataset.patch_size)
y_vol.shape

(3, 192, 224, 224)

In [9]:
y_vol = reshape_to_volume(preds.reshape(3, 1, 192, 224, 224), 3, dataset.imsize)
y_vol.shape

(3, 192, 224, 224)

In [None]:
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        print(m)

In [6]:
@torch.no_grad()
def get_predictions(model, dataloader) -> torch.Tensor:
    model.to(idist.device())
    preds = torch.cat(
        [
            model(x.to(idist.device())).sigmoid().round().cpu()
            for x, _ in tqdm(dataloader, desc="Predicting")
        ]
    )
    model.cpu()
    return preds

Model is trained on domain `A` and will now be evaluated on domain `B`

Set the model to eval, as you would usually do when evaluating on a new domain

In [7]:
from uda.models.modules import center_crop_nd

In [8]:
model.eval()
y_pred = get_predictions(model, dataloader)

y_true = center_crop_nd(y_true, y_pred.shape[1:])
dice_score(y_pred, y_true)

Predicting:   0%|          | 0/144 [00:00<?, ?it/s]

tensor(0.0497)

Now we set the model to train and run the dataset once (still no gradients, no training - just running the model)

Then we gonna get our predictions in eval mode again

In [None]:
model.train()
y_pred = get_predictions(model, dataloader)

model.eval()
y_pred = get_predictions(model, dataloader)

dice_score(y_pred, y_true)

The results have improved a lot, only due to adapting the running stats of our BatchNorm to the new domain