In [1]:
import torch
import torch.nn as nn
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from copy import deepcopy
torch.set_float32_matmul_precision("medium")
from pathlib import Path
import numpy as np

In [2]:
import matplotlib.pyplot as plt

def plot(gt, pred, pred_refined, id, out_dir="./plots"):
    """Save one PNG per sample comparing ground truth with predictions before/after TTT.

    For every sample index ``s`` in the batch, a 1x3 figure is written to
    ``<out_dir>/<id[s]>.png`` showing ``img[s, 0]`` summed over its first
    remaining axis for ``gt``, ``pred`` and ``pred_refined`` (titles "gt",
    "pred", "pred after ttt").

    NOTE(review): assumes tensors are laid out as (batch, channel, D, H, W), so
    each panel is a projection of channel 0 along D -- TODO confirm.
    Predictions are drawn as passed in (no sigmoid here); the caller passes raw
    network outputs, presumably logits.

    Args:
        gt: ground-truth tensor (CPU or GPU); ``gt.size(0)`` sets the sample count.
        pred: prediction tensor before test-time training, same layout as ``gt``.
        pred_refined: prediction tensor after test-time training.
        id: per-sample identifiers, used as file stems.
        out_dir: output directory; created if it does not exist.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Panel order/titles are the same for every sample.
    panels = (("gt", gt), ("pred", pred), ("pred after ttt", pred_refined))

    for sample in range(gt.size(0)):
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        for ax, (title, img) in zip(axes, panels):
            ax.imshow(img[sample, 0].detach().cpu().sum(dim=0))
            ax.set_title(title)
            ax.axis("off")  # axes ticks add nothing to an image panel

        fig.tight_layout()
        fig.savefig(out_dir / f"{id[sample]}.png", bbox_inches='tight')
        # Close this specific figure: one is created per test sample and open
        # figures would otherwise accumulate in memory.
        plt.close(fig)


In [3]:
# Dice loss on sigmoid-activated outputs; it is used both as the TTT
# self-supervision objective and (as 1 - loss) for scoring predictions.
criterion = DiceLoss(sigmoid=True)

# NOTE(review): the two objects below are not referenced by any visible cell
# (cross_entropy only appears in a commented-out line).
cross_entropy = nn.CrossEntropyLoss()
scoring_fn = DiceMetric()
def select_device(index):
    """Return ``torch.device("cuda:<index>")`` if that GPU exists, else the CPU.

    Falling back (with a warning) keeps the notebook runnable on machines
    that have fewer GPUs than the workstation it was written on.
    """
    if torch.cuda.is_available() and index < torch.cuda.device_count():
        return torch.device(f"cuda:{index}")
    import warnings
    warnings.warn(f"cuda:{index} is not available; falling back to CPU")
    return torch.device("cpu")


device = select_device(2)
device_1 = select_device(1)  # NOTE(review): not used by any visible cell

In [4]:
def ttt_one_batch(batch, model, model_ttt, optimizer, n_steps):
    """Adapt ``model_ttt`` to one batch by self-training and return its new predictions.

    ``model_ttt`` is first reset to ``model``'s weights. Each of the
    ``n_steps`` steps then:
      1. predicts ``y_hat_0 = model_ttt.model(x, y=None)``;
      2. feeds that prediction to the reference ``model.model(x, y=y_hat_0)``;
      3. thresholds the sigmoid of that output at 0.5 to form a pseudo-label
         and takes one ``optimizer`` step on ``criterion(y_hat_0, pseudo_label)``.

    Args:
        batch: mapping whose ``"image"`` entry is convertible with ``torch.Tensor``.
        model: reference model; only its forward pass is used here.
        model_ttt: model to adapt; its weights are overwritten.
        optimizer: optimizer over ``model_ttt``'s parameters (built by the caller).
        n_steps: number of adaptation steps; 0 returns the un-adapted predictions.

    Returns:
        ``(y_hat_0, y_hat_1)`` computed *after* the last update, with
        ``model_ttt`` in eval mode and without autograd history.
        ``model_ttt`` is left in eval mode.
    """
    model_ttt.load_state_dict(model.state_dict())
    x = torch.Tensor(batch["image"]).to(device)  # loop-invariant input

    model_ttt.train()
    for _ in range(n_steps):
        optimizer.zero_grad()
        y_hat_0 = model_ttt.model(x, y=None)
        # The pseudo-label is thresholded, so no gradient can flow through the
        # reference pass; don't build (and hold) its autograd graph.
        with torch.no_grad():
            y_hat_1 = model.model(x, y=y_hat_0.detach())
        loss = criterion(y_hat_0, (torch.sigmoid(y_hat_1) > 0.5).int())
        loss.backward()
        optimizer.step()

    # Evaluate the adapted weights. Predictions made inside the loop predate the
    # final optimizer.step() and were produced in train mode.
    model_ttt.eval()
    with torch.no_grad():
        y_hat_0 = model_ttt.model(x, y=None)
        y_hat_1 = model.model(x, y=y_hat_0)
    return y_hat_0, y_hat_1

In [5]:
def ttt(model, test_loader, n_steps, lr):
    """Score ``model`` on ``test_loader`` before and after per-batch test-time training.

    For every batch four soft-Dice scores (``1 - criterion(pred, label)``) are
    recorded:

    * ``score_0``: ``model.model(x, y=None)`` (first pass, no adaptation)
    * ``score_1``: ``model.model(x, y=<first pass>)`` (second pass)
    * ``score_0_refined`` / ``score_1_refined``: the two outputs returned by
      ``ttt_one_batch`` after adapting a working copy of ``model`` with SGD
      (momentum 0.9, learning rate ``lr``) for ``n_steps`` steps.

    Ground truth and first-pass predictions before/after adaptation are saved
    with ``plot``. Batch ids are printed as progress, and the mean/std of
    each score list is printed at the end.

    Args:
        model: module exposing ``.model(x, y=...)``; moved to ``device`` and set to eval.
        test_loader: iterable of dicts with ``"image"``, ``"label"`` and ``"id"``.
        n_steps: adaptation steps per batch.
        lr: SGD learning rate.

    Returns:
        dict mapping each score name to its list of per-batch scores.
    """
    def dice_score(pred, target):
        # Soft Dice score (1 - Dice loss) as a Python float.
        return (1 - criterion(pred.to(device), target)).detach().item()

    model.eval()
    model.to(device)

    # One working copy is enough: ttt_one_batch reloads ``model``'s weights
    # into it before adapting, so every batch starts from the same parameters.
    model_ttt = deepcopy(model)
    model_ttt.to(device)

    scores = {name: [] for name in ("score_0", "score_1", "score_0_refined", "score_1_refined")}

    for batch in test_loader:
        print(batch["id"])
        x = torch.Tensor(batch["image"]).to(device)
        y = torch.Tensor(batch["label"]).to(device)

        # Baseline predictions from the un-adapted weights (``model`` itself, so
        # they are independent of the previous batch's adaptation). Scoring
        # needs no gradients.
        with torch.no_grad():
            y_hat_0 = model.model(x, y=None)
            y_hat_1 = model.model(x, y=y_hat_0)
        scores["score_0"].append(dice_score(y_hat_0, y))
        scores["score_1"].append(dice_score(y_hat_1, y))

        # Fresh optimizer per batch so SGD momentum does not carry over between volumes.
        optimizer = torch.optim.SGD(model_ttt.parameters(), lr=lr, momentum=0.9)
        y_hat_0_refined, y_hat_1_refined = ttt_one_batch(batch, model, model_ttt, optimizer, n_steps)
        scores["score_0_refined"].append(dice_score(y_hat_0_refined, y))
        scores["score_1_refined"].append(dice_score(y_hat_1_refined, y))

        plot(y, y_hat_0, y_hat_0_refined, batch["id"])

    for name, values in scores.items():
        if values:
            print(f"{name}: {sum(values) / len(values)}, std: {np.array(values).std()}")
        else:
            print(f"{name}: no batches evaluated")
    return scores

In [4]:
from datasets import build_dataset
from models import build_model
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import OmegaConf

# Compose configs/config.yaml with Hydra.
with initialize(version_base=None, config_path="configs/"):
    raw_cfg = compose(config_name='config.yaml')

# Turn off struct mode so new keys may be added, then lift the method-specific
# sub-config to the top level (values from `method` override the base ones).
OmegaConf.set_struct(raw_cfg, False)
cfg = OmegaConf.merge(raw_cfg, raw_cfg.method)

test_dataset = build_dataset(cfg, test=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, num_workers=16)


['exp_name', 'ckpt_path', 'seed', 'debug', 'devices', 'load_num_workers', 'log_every_n_steps', 'method', 'model_name', 'train_batch_size', 'eval_batch_size', 'test_batch_size', 'dataset', 'depth', 'initial_features', 'encoder_dropout', 'decoder_dropout', 'BN', 'elu', 'learning_rate', 'base_lr', 'max_lr', 'step_size', 'decay_milestones', 'decay_gamma', 'weight_decay', 'max_epochs', 'grad_clip_norm', 'train_data_root_dir', 'val_data_root_dir', 'test_data_root_dir', 'normalize_data']


In [5]:
from models.deepict_unet3d_ttt import UNet3D_Lightning_ITTT

# Restore the trained ITTT UNet from the checkpoint named in the config,
# mapping its tensors onto the selected device.
ckpt_path = cfg.ckpt_path
model = UNet3D_Lightning_ITTT.load_from_checkpoint(
    ckpt_path,
    config=cfg,
    map_location=device,
)

In [8]:
# Test-time-training hyper-parameters for this run.
TTT_N_STEPS = 10
TTT_LR = 5e-4
ttt_scores = ttt(model, test_loader, n_steps=TTT_N_STEPS, lr=TTT_LR)

['034_subtomo_[128 128 360]']
['034_subtomo_[128 128 592]']
['034_subtomo_[128 128 824]']
['034_subtomo_[128 360 128]']
['034_subtomo_[128 360 360]']
['034_subtomo_[128 360 592]']
['034_subtomo_[128 360 824]']
['034_subtomo_[128 592 128]']
['034_subtomo_[128 592 360]']
['034_subtomo_[128 592 592]']
['034_subtomo_[128 592 824]']
['034_subtomo_[128 824 128]']
['034_subtomo_[128 824 360]']
['034_subtomo_[128 824 592]']
['034_subtomo_[128 824 824]']
['030_subtomo_[128 128 592]']
['030_subtomo_[128 128 824]']
['030_subtomo_[128 360 360]']
['030_subtomo_[128 360 592]']
['030_subtomo_[128 360 824]']
['030_subtomo_[128 592 128]']
['030_subtomo_[128 592 360]']
['030_subtomo_[128 592 592]']
['030_subtomo_[128 592 824]']
['030_subtomo_[128 824 128]']
['030_subtomo_[128 824 360]']
['030_subtomo_[128 824 592]']
['030_subtomo_[128 824 824]']
['037_subtomo_[128 128 360]']
['037_subtomo_[128 128 592]']
['037_subtomo_[128 360 128]']
['037_subtomo_[128 360 360]']
['037_subtomo_[128 360 592]']
['037_subt

In [6]:
def inference_train_mode(model, test_loader, n_steps, lr):
    """Score ``model`` on ``test_loader`` with its layers in training mode (no adaptation).

    Records, per batch, the soft-Dice score (``1 - criterion``) of
    ``model.model(x, y=None)`` (``score_0``) and of the second pass
    ``model.model(x, y=<first pass>)`` (``score_1``), then prints the mean/std
    of each. Layers such as dropout are active in train mode, so results may
    vary between runs.

    The forward passes run on a deep copy. In train mode, normalisation layers
    (e.g. BatchNorm) update their running statistics, and running on ``model``
    itself would silently change it for later cells such as ``ttt``.

    Args:
        model: module exposing ``.model(x, y=...)``; moved to ``device`` but
            otherwise left unchanged (weights, buffers and train/eval mode).
        test_loader: iterable of dicts with ``"image"``, ``"label"`` and ``"id"``.
        n_steps: unused; kept so the signature mirrors ``ttt``.
        lr: unused; kept so the signature mirrors ``ttt``.

    Returns:
        dict with per-batch score lists under ``"score_0"`` and ``"score_1"``.
    """
    model.to(device)
    model_train = deepcopy(model)
    model_train.train()

    scores = {"score_0": [], "score_1": []}

    for batch in test_loader:
        print(batch["id"])
        x = torch.Tensor(batch["image"]).to(device)
        y = torch.Tensor(batch["label"]).to(device)

        # Only scores are needed, so skip the autograd graph. Grad mode does not
        # change how train-mode layers behave.
        with torch.no_grad():
            y_hat_0 = model_train.model(x, y=None)
            y_hat_1 = model_train.model(x, y=y_hat_0)

        scores["score_0"].append((1 - criterion(y_hat_0, y)).item())
        scores["score_1"].append((1 - criterion(y_hat_1, y)).item())

    for name, values in scores.items():
        if values:
            print(f"{name}: {sum(values) / len(values)}, std: {np.array(values).std()}")
        else:
            print(f"{name}: no batches evaluated")
    return scores

In [7]:
# n_steps / lr are accepted for signature parity with ttt; no adaptation happens here.
train_mode_scores = inference_train_mode(model, test_loader, n_steps=20, lr=0.0005)

['034_subtomo_[128 128 360]']
['034_subtomo_[128 128 592]']
['034_subtomo_[128 128 824]']
['034_subtomo_[128 360 128]']
['034_subtomo_[128 360 360]']
['034_subtomo_[128 360 592]']
['034_subtomo_[128 360 824]']
['034_subtomo_[128 592 128]']
['034_subtomo_[128 592 360]']
['034_subtomo_[128 592 592]']
['034_subtomo_[128 592 824]']
['034_subtomo_[128 824 128]']
['034_subtomo_[128 824 360]']
['034_subtomo_[128 824 592]']
['034_subtomo_[128 824 824]']
['030_subtomo_[128 128 592]']
['030_subtomo_[128 128 824]']
['030_subtomo_[128 360 360]']
['030_subtomo_[128 360 592]']
['030_subtomo_[128 360 824]']
['030_subtomo_[128 592 128]']
['030_subtomo_[128 592 360]']
['030_subtomo_[128 592 592]']
['030_subtomo_[128 592 824]']
['030_subtomo_[128 824 128]']
['030_subtomo_[128 824 360]']
['030_subtomo_[128 824 592]']
['030_subtomo_[128 824 824]']
['037_subtomo_[128 128 360]']
['037_subtomo_[128 128 592]']
['037_subtomo_[128 360 128]']
['037_subtomo_[128 360 360]']
['037_subtomo_[128 360 592]']
['037_subt