In [1]:
import torch
from models import tent
import torch.nn as nn
import torch.nn.functional as F
from monai.losses import DiceLoss
import numpy as np

## Init

In [2]:
# Dice loss computed on sigmoid-activated logits. The inference helpers in this
# notebook report `1 - criterion(...)` as the Dice score.
criterion = DiceLoss(sigmoid=True)
# Not referenced by the cells visible in this notebook; kept for cells that may use it.
cross_entropy = nn.CrossEntropyLoss()

# The experiments were run on GPU index 2. Fall back to the first GPU, or to the
# CPU, so the notebook still runs on machines that expose fewer devices.
PREFERRED_GPU = 2
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device(f"cuda:{PREFERRED_GPU}")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [27]:
from datasets import build_dataset
from models import build_model
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf
from models.deepict_unet3d import UNet3D_Lightning


# Compose the Hydra configuration from configs/config.yaml.
with initialize(config_path="configs/", version_base=None):
    cfg = compose(config_name="config.yaml")

# Struct mode is switched off before the merge below, which copies the entries
# of `cfg.method` onto the top level of the config.
OmegaConf.set_struct(cfg, False)
cfg = OmegaConf.merge(cfg, cfg.method)

# Test-split loader; batch size and worker count are fixed here rather than
# read from the config.
test_loader = torch.utils.data.DataLoader(
    build_dataset(cfg, test=True),
    batch_size=5,
    num_workers=16,
)

# Restore the trained network, mapping the checkpoint onto `device`.
model = UNet3D_Lightning.load_from_checkpoint(
    cfg.ckpt_path,
    map_location=device,
    config=cfg,
)

['exp_name', 'ckpt_path', 'seed', 'debug', 'devices', 'load_num_workers', 'log_every_n_steps', 'method', 'model_name', 'train_batch_size', 'eval_batch_size', 'test_batch_size', 'dataset', 'depth', 'initial_features', 'encoder_dropout', 'decoder_dropout', 'BN', 'elu', 'learning_rate', 'decay_milestones', 'decay_gamma', 'weight_decay', 'max_epochs', 'grad_clip_norm', 'train_data_root_dir', 'val_data_root_dir', 'test_data_root_dir', 'normalize_data']


## Inference with BN in train-mode

In [4]:
def configure_model_BN(model, eps, momentum, reset_stats, no_stats):
    """Set up every ``nn.BatchNorm3d`` layer of ``model`` for test-time normalization.

    Only ``nn.BatchNorm3d`` modules are touched; all other modules keep their
    current mode and state. The model is modified in place.

    Args:
        model: module whose sub-modules are scanned via ``model.modules()``.
        eps: value written to each layer's ``eps``.
        momentum: value written to each layer's ``momentum``.
        reset_stats: if true, call ``reset_running_stats()`` on each layer.
        no_stats: if true, turn off ``track_running_stats`` and set
            ``running_mean`` / ``running_var`` to ``None`` so that only the
            statistics of the current batch are used.

    Returns:
        The same ``model`` object that was passed in.
    """
    batchnorm_layers = [
        layer for layer in model.modules() if isinstance(layer, nn.BatchNorm3d)
    ]

    for bn in batchnorm_layers:
        # Train mode makes the layer normalize with batch-wise statistics.
        bn.train()
        # Numerical-stability epsilon and running-statistics update rate.
        bn.eps, bn.momentum = eps, momentum

        if reset_stats:
            # Forget the statistics accumulated during training.
            bn.reset_running_stats()

        if not no_stats:
            continue

        # Drop the running statistics entirely: batch statistics only.
        bn.track_running_stats = False
        bn.running_mean = bn.running_var = None

    return model

def inference_train_mode(model, test_loader):
    """Evaluate ``model`` with BatchNorm3d layers normalizing by batch statistics.

    The model is put in eval mode and its ``nn.BatchNorm3d`` layers are then
    reconfigured by ``configure_model_BN`` (train mode, running statistics
    reset and disabled). This modifies ``model`` in place, so the layers stay
    reconfigured after the function returns.

    For each batch the Dice score is computed as ``1 - criterion(...)``; the
    mean and standard deviation over batches are printed. Relies on the
    notebook-level ``criterion`` and ``device``.

    Args:
        model: wrapper exposing the network as ``model.model``; the network's
            forward pass returns a pair whose first element is the prediction.
        test_loader: iterable of batches with ``"image"`` and ``"label"`` entries.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    model = configure_model_BN(model, eps=1e-5, momentum=0.1, reset_stats=True, no_stats=True)

    score_0 = []

    # Nothing is back-propagated here, so skip building the autograd graph.
    # This does not change the forward results, only the memory footprint.
    with torch.no_grad():
        for batch in test_loader:
            x, y = torch.Tensor(batch["image"]), torch.Tensor(batch["label"])

            y_hat_0, _ = model.model(x.to(device))
            dice_score = 1 - criterion(y_hat_0, y.to(device)).item()

            score_0.append(dice_score)

    if not score_0:
        # Without this guard the mean below fails with a bare ZeroDivisionError.
        raise ValueError("test_loader yielded no batches; cannot compute a Dice score")

    print(f"score_0: {sum(score_0) / len(score_0):.2f}, std: {np.array(score_0).std():.2f}")


In [5]:
# Evaluate with batch statistics in the BatchNorm3d layers and print the mean
# Dice score. Note: this call changes `model` in place (its BatchNorm3d layers
# are switched to train mode and lose their running statistics).
inference_train_mode(model, test_loader)

score_0: 0.36, std: 0.22


## Tent inference (and training)

In [28]:
model.model.eval()
tent_model = tent.configure_model(model.model)

# Parameters selected for adaptation, together with their names.
params, param_names = tent.collect_params(tent_model)

# Optimize only the collected parameters instead of every parameter of the
# network. If `collect_params` finds nothing, keep the previous behaviour of
# passing all parameters so the cell still runs.
# NOTE(review): which parameters `collect_params` selects is defined in
# models/tent.py -- confirm it covers the BatchNorm3d layers used here.
adapt_params = params if len(params) > 0 else tent_model.parameters()

# optimizer = torch.optim.Adam(adapt_params, lr=1e-3)
optimizer = torch.optim.SGD(adapt_params, lr=2e-4, momentum=0.9)

# Wrap the network for test-time adaptation (5 steps, episodic mode).
tent_model = tent.Tent(
    tent_model,
    optimizer,
    steps=5,
    episodic=True,
)

In [29]:
def inference_tent(model, test_loader):
    """Run ``model`` over ``test_loader`` and print the mean/std Dice score.

    For each batch the Dice score is computed as ``1 - criterion(...)`` from
    the model output and the batch label. Relies on the notebook-level
    ``criterion`` and ``device``.

    Args:
        model: callable returning the prediction for an input batch (in this
            notebook, the Tent-wrapped network).
        test_loader: iterable of batches with ``"image"`` and ``"label"`` entries.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    score_0 = []

    # Deliberately no torch.no_grad() here: the Tent wrapper presumably takes
    # optimizer steps inside its forward pass -- confirm against models/tent.py.
    for batch in test_loader:
        x, y = torch.Tensor(batch["image"]), torch.Tensor(batch["label"])

        y_hat_0 = model(x.to(device))
        dice_score = 1 - criterion(y_hat_0, y.to(device)).detach().item()

        score_0.append(dice_score)

    if not score_0:
        # Without this guard the mean below fails with a bare ZeroDivisionError.
        raise ValueError("test_loader yielded no batches; cannot compute a Dice score")

    print(f"score_0: {sum(score_0) / len(score_0):.2f}, std: {np.array(score_0).std():.2f}")


In [30]:
# Anomaly detection is a debugging aid that slows autograd down. Scope it to
# this call so it does not stay switched on for the rest of the kernel session.
with torch.autograd.set_detect_anomaly(True):
    inference_tent(tent_model, test_loader)

score_0: 0.36, std: 0.15
