In [1]:
import torch
import torch.nn as nn
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from copy import deepcopy
torch.set_float32_matmul_precision("medium")
from pathlib import Path
import numpy as np
from utils.utils import set_seed
from utils.subtomos import reassemble_subtomograms_v2
from torchmetrics.classification import BinaryConfusionMatrix
import wandb
import matplotlib.pyplot as plt
import torchvision.transforms.functional as FT
import mrcfile

In [2]:
from datasets import build_dataset
# from models import build_model
from models.denoiseg import Denoiseg
from hydra import initialize, compose
from omegaconf import OmegaConf


with initialize(version_base=None, config_path="configs/"):
    cfg = compose(config_name='config.yaml')

OmegaConf.set_struct(cfg, False)
cfg = OmegaConf.merge(cfg, cfg.method)

set_seed(cfg.seed)

criterion = DiceLoss(sigmoid=True)
scoring_fn = DiceMetric()
cross_entropy = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss(reduction='none')
# mse_loss = nn.L1Loss(reduction='none')
device = torch.device("cuda:3")
device_1 = torch.device("cuda:1")
bcm = BinaryConfusionMatrix().to(device)

dataset = build_dataset(cfg, test=True)

test_dataset, test_val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 2, 2])
print(len(test_dataset), len(test_val_dataset))
# test_dataset = dataset
# test_val_dataset = build_dataset(cfg, val=True)
test_loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.load_num_workers,
    batch_size=1,
    shuffle=True,
    persistent_workers=cfg.persistent_workers,
)            
test_val_loader = torch.utils.data.DataLoader(
    test_val_dataset,
    num_workers=cfg.load_num_workers,
    batch_size=2,
    persistent_workers=cfg.persistent_workers,
)            

test_entire_gt = mrcfile.read(cfg.test_entire_gt)
test_entire_gt = torch.Tensor(test_entire_gt)[None, None, ...]

model = Denoiseg.load_from_checkpoint(cfg.ckpt_path, map_location=device, config=cfg)

Seed set to 42


13 2


In [3]:
import matplotlib.pyplot as plt

def plot(gt, pred, pred_refined, id):

    batch_size = gt.size(0)
    for sample in range(batch_size):
        # gt = gt.sum(dim=0)
        # pred = pred.sum(dim=0)
        # pred_refined = pred.sum(dim=0)
        diff = torch.abs(pred_refined - pred)

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # Display each image with its corresponding title
        images = [gt, pred, pred_refined, diff]
        titles = ["gt", "pred", "pred after ttt", "diff"]

        for ax, img, title in zip(axes, images, titles):
            ax.imshow(img[sample, 0].cpu().detach().sum(dim=0))
            ax.set_title(title)
            ax.axis("off")  # Hide axes for better visual appearance

        # Adjust layout and save the figure
        plt.tight_layout()
        plt.savefig(Path("./plots") / f"{id[sample]}.png", bbox_inches='tight')
        plt.close()

In [4]:
import torch
import random
import torch.nn.functional as F

# Step 1: Define the transformation and its inverse
def random_transform_with_inverse(tensor):
    transformation = random.choice(['rotate', 'mirror'])
    inverse_function = None

    if transformation == 'rotate':
        k = random.choice([1, 2, 3])
        transformed_tensor = torch.rot90(tensor, k, [3, 4])
        inverse_function = lambda x: torch.rot90(x, -k, [3, 4])

    elif transformation == 'mirror':
        axis = random.choice([3, 4])
        transformed_tensor = torch.flip(tensor, [axis])
        inverse_function = lambda x: torch.flip(x, [axis])

    return transformed_tensor, inverse_function

In [5]:
# def logging(run, dice_loss, dice_coef, )

In [6]:
def masked_mse_loss(x_hat, x, mask):
    loss = mse_loss(x_hat, x) * mask
    return loss.sum() / mask.sum()

def dice_from_conf_matrix(conf_matrix):
    return 2 * conf_matrix[1, 1] / (2 * conf_matrix[1, 1] + conf_matrix[0, 1] + conf_matrix[1, 0])

def get_grad_norm(model):
    total_norm = 0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm


In [7]:
def ttt_one_batch(x_masked, x_unmasked, mask, model, n_steps, lr, momentum, validation=False, y=None, id=None, test_val_loader=None):
    
    # run = wandb.init(project="cryo-ttt-ttt", name=id[0])
    val_losses = []
    val_dices = []
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)
    # run.log(
    #     {
    #         # "train/subtomo": wandb.Image(x_unmasked[0, :, 128]),
    #         "train/seg_overall_gt": wandb.Image(y[0, :].sum(dim=0).float()),
    #     }
    # )
    

    for step in range(n_steps):

        optimizer.zero_grad()

        y_hat, x_hat = model.model(x_masked)

        loss = masked_mse_loss(x_hat, x_unmasked, mask)

        loss.backward()
        optimizer.step()

        train_grad_norm = get_grad_norm(model.model)
        if validation:
            model.eval()
            y_hat, _ = model.model(x_unmasked)
            x_unmasked_aug, inverse_fn = random_transform_with_inverse(x_unmasked)
            y_hat_aug, _ = model.model(x_unmasked_aug)
            y_hat_aug = inverse_fn(y_hat_aug)

            consistency = criterion((y_hat.sigmoid() > 0.5).int(), (y_hat_aug.sigmoid() > 0.5).int()).detach().item()
            dice_loss = criterion(y_hat, y).detach().item()
            dice_coef = scoring_fn((y_hat.sigmoid() > 0.5).int(), y).detach().item()
            print(f"loss: {dice_loss:.3f} | coef: {dice_coef:.3f} | mse_loss: {loss.item():.3f}")
            val_losses.append(dice_loss)
            val_dices.append(dice_coef)

            for batch in test_val_loader:
                x_masked_val, x_unmasked_val = batch["image"].to(device), batch["unmasked_image"].to(device)
                y_out, id_val = batch["label"].to(device), batch["id"]
                mask_val = batch["mask"].to(device)

                _, x_hat_val = model.model(x_masked_val)

                val_mse_loss = masked_mse_loss(x_hat_val, x_unmasked_val, mask_val)
                # val_dice_loss = criterion(y_hat_val, y_out).detach().item()
                # run.log({
                #     "val/mse_loss": val_mse_loss,
                #     "val/dice_loss": val_dice_loss,
                # })

            model.train()
        
        slice_to_log = torch.argmax(y[0, 0].sum(dim=(-1, -2)))
        # run.log({
        #     "train/mse_loss": loss.item(),
        #     "train/dice_loss": dice_loss,
        #     "train/dice": dice_coef,
        #     "step": step,
        #     # "train/seg": wandb.Image(torch.cat((y[0, :, 128].float(), y_hat.sigmoid()[0, :, 128]), dim=1)),
        #     # "train/denoised": wandb.Image(torch.cat((x_unmasked[0, :, 128], x_hat[0, :, 128]), dim=1)),
        #     "train/subtomo": wandb.Image(x_unmasked[0, :, slice_to_log]),
        #     "train/denoised": wandb.Image(x_hat[0, :, slice_to_log]),
        #     "train/seg_gt": wandb.Image(FT.to_pil_image((y[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
        #     "train/seg": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
        #     "train/mask": wandb.Image(FT.to_pil_image((mask[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
        #     "val/mse_loss": val_mse_loss.item(),
        #     "train/grad_norm": train_grad_norm,
        #     "train/intensity_var": x_hat.var(),
        #     "train/seg_overall_gt": wandb.Image(FT.to_pil_image(y[0, :].sum(dim=1).to(torch.uint8), mode="L")),
        #     "train/seg_overall": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :] > 0.5).sum(dim=1).to(torch.uint8), mode="L")),
        #     "train/consistency": consistency,
        # })
        x_masked, mask = dataset.generate_mask(deepcopy(x_unmasked[0, 0].cpu()))
        x_masked, mask = torch.Tensor(x_masked)[None, None, ...].to(device), torch.Tensor(mask)[None, None, ...].to(device)

    model.eval()
    y_hat, _ = model.model(x_unmasked)
    
    # run.finish()
    return y_hat, val_losses, val_dices

In [8]:
def ttt(model, test_loader, n_steps, lr, momentum, validation=False, test_val_loader=None):
    
    dices = []
    dices_raw = []
    conf_matrix = torch.zeros(2, 2)
    conf_matrix_raw = torch.zeros(2, 2)

    val_losses = []
    val_dices = []
    # run = wandb.init(project="cryo-ttt-ttt", name="first")
    # run.finish()

    y_hats = []
    y_hats_raw = []
    centers = []
    
    for batch in test_loader:

        # model.load_state_dict(torch.load(cfg.ckpt_path))

        x_masked, x_unmasked = batch["image"].to(device), batch["unmasked_image"].to(device)
        y_out, id = batch["label"].to(device), batch["id"]
        mask = batch["mask"].to(device)

        model.eval()
        y_hat_raw, _ = model.model(x_unmasked)
        # optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
        y_hat, val_step_losses, val_step_dices = ttt_one_batch(
            x_masked,
            x_unmasked,
            mask, deepcopy(model), n_steps, lr, momentum, validation=validation, y=y_out, id=id, test_val_loader=test_val_loader)

        y_hats.append(y_hat.squeeze().detach().cpu().numpy())
        y_hats_raw.append(y_hat_raw.squeeze().detach().cpu().numpy())
        center = id[0].split("_")[-1][1:-1]
        center = center.split(" ")
        center = [int(elem) for elem in center]
        print(f"Center: {center}")
        centers.append(center)

        val_losses.append(val_step_losses)
        val_dices.append(val_step_dices)

        dice_score = 1 - criterion(y_hat, y_out).detach().item()
        dice_score_raw = 1 - criterion(y_hat_raw, y_out).detach().item()
        conf_matrix += bcm(y_hat.sigmoid()[:, :, 12:-12, 12:-12, 12:-12], y_out[:, :, 12:-12, 12:-12, 12:-12]).detach().cpu()
        conf_matrix_raw += bcm(y_hat_raw.sigmoid()[:, :, 12:-12, 12:-12, 12:-12], y_out[:, :, 12:-12, 12:-12, 12:-12]).detach().cpu()

        print(f"bilo: {dice_score_raw:.3f}, stalo: {dice_score:.3f}")
        dices.append(dice_score)
        dices_raw.append(dice_score_raw)
    
    print(f"Mean Dice after ttt: {sum(dices) / len(dices)}")
    print(f"Macro Dice after ttt: {dice_from_conf_matrix(conf_matrix)}")
    print(f"Mean Dice raw: {sum(dices_raw) / len(dices_raw)}")
    print(f"Macro Dice raw: {dice_from_conf_matrix(conf_matrix_raw)}")

    reassembled_pred = reassemble_subtomograms_v2(y_hats, centers, test_entire_gt.shape[2:], 256, 12)
    reassembled_pred = (torch.Tensor(reassembled_pred).sigmoid() > 0.5).to(torch.uint8)[None, None, ...]

    reassembled_pred_raw = reassemble_subtomograms_v2(y_hats_raw, centers, test_entire_gt.shape[2:], 256, 12)
    reassembled_pred_raw = (torch.Tensor(reassembled_pred_raw).sigmoid() > 0.5).to(torch.uint8)[None, None, ...]


    very_macro_dice = scoring_fn(reassembled_pred_raw, test_entire_gt).item()
    very_macro_dice_after = scoring_fn(reassembled_pred, test_entire_gt).item()

    run = wandb.init(project="cryo-ttt-ttt", name="last")
    print(cfg.test_data_root_dir)
    run.log({
        "macro_dice": dice_from_conf_matrix(conf_matrix_raw),
        "macro_dice_after": dice_from_conf_matrix(conf_matrix),
        "mean_dice": sum(dices_raw) / len(dices_raw),
        "mean_dice_after": sum(dices) / len(dices),
        "very_macro_dice": very_macro_dice,
        "very_macro_dice_after": very_macro_dice_after,
    })
    run.finish()

    return dices, dices_raw, conf_matrix, conf_matrix_raw, val_losses, val_dices

In [9]:
dices = ttt(model, test_loader, 10, 0.001, 0.9, validation=True, test_val_loader=test_val_loader)

loss: 0.694 | coef: 0.375 | mse_loss: 0.697
loss: 0.683 | coef: 0.379 | mse_loss: 0.656
loss: 0.684 | coef: 0.370 | mse_loss: 0.620
loss: 0.688 | coef: 0.360 | mse_loss: 0.582
loss: 0.689 | coef: 0.354 | mse_loss: 0.556
loss: 0.693 | coef: 0.346 | mse_loss: 0.530
loss: 0.699 | coef: 0.339 | mse_loss: 0.504
loss: 0.707 | coef: 0.331 | mse_loss: 0.472
loss: 0.718 | coef: 0.325 | mse_loss: 0.446
loss: 0.724 | coef: 0.318 | mse_loss: 0.430
Center: [128, 824, 360]
bilo: 0.289, stalo: 0.276
loss: 0.540 | coef: 0.476 | mse_loss: 0.697
loss: 0.525 | coef: 0.490 | mse_loss: 0.663
loss: 0.512 | coef: 0.502 | mse_loss: 0.632
loss: 0.503 | coef: 0.510 | mse_loss: 0.607
loss: 0.496 | coef: 0.517 | mse_loss: 0.587
loss: 0.491 | coef: 0.521 | mse_loss: 0.570
loss: 0.488 | coef: 0.523 | mse_loss: 0.550
loss: 0.486 | coef: 0.524 | mse_loss: 0.530
loss: 0.486 | coef: 0.523 | mse_loss: 0.506
loss: 0.488 | coef: 0.522 | mse_loss: 0.483
Center: [128, 128, 128]
bilo: 0.449, stalo: 0.512
loss: 0.430 | coef: 

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdijor0310[0m ([33mcryo-diyor[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
macro_dice,▁
macro_dice_after,▁
mean_dice,▁
mean_dice_after,▁
very_macro_dice,▁
very_macro_dice_after,▁

0,1
macro_dice,0.34655
macro_dice_after,0.68218
mean_dice,0.36662
mean_dice_after,0.56817
very_macro_dice,0.3464
very_macro_dice_after,0.68178


In [10]:
y_hats[0].shape

NameError: name 'y_hats' is not defined

In [9]:
import torch
import random
import torch.nn.functional as F

# Step 1: Define the transformation and its inverse
def random_transform_with_inverse(tensor):
    transformation = random.choice(['rotate', 'mirror'])
    inverse_function = None

    if transformation == 'rotate':
        k = random.choice([1, 2, 3])
        transformed_tensor = torch.rot90(tensor, k, [3, 4])
        inverse_function = lambda x: torch.rot90(x, -k, [3, 4])

    elif transformation == 'mirror':
        axis = random.choice([3, 4])
        transformed_tensor = torch.flip(tensor, [axis])
        inverse_function = lambda x: torch.flip(x, [axis])

    return transformed_tensor, inverse_function

# # Step 2: Example usage in the flow
# B, D, H, W = 2, 10, 64, 64
# original_tensor = torch.rand((B, 1, D, H, W))

# # Apply random transformation and get the inverse function
# augmented_tensor, inverse_function = random_transform_with_inverse(original_tensor)

# # Mock neural network model
# class MockModel(torch.nn.Module):
#     def forward(self, x):
#         return x  # Mock output, replace with actual model inference

# model = MockModel()

# # Step 3: Pass through the neural network
# original_segmentation = model(original_tensor)
# augmented_segmentation = model(augmented_tensor)

# # Step 4: Apply inverse transformation to the augmented segmentation
# transformed_back_segmentation = inverse_function(augmented_segmentation)

# # Step 5: Compare original and transformed-back segmentation
# consistency_loss = F.mse_loss(original_segmentation, transformed_back_segmentation)

# print("Consistency Loss:", consistency_loss.item())

In [8]:
def ttt_consistency_one_batch(x, model, n_steps, lr, momentum):
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)
    for step in range(n_steps):

        x1, inverse_fn_1 = random_transform_with_inverse(x)
        x2, inverse_fn_2 = random_transform_with_inverse(x)

        optimizer.zero_grad()

        y1, _ = model.model(x1)
        y2, _ = model.model(x2)
        y1 = inverse_fn_1(y1)
        y2 = inverse_fn_2(y2)

        # loss = F.binary_cross_entropy_with_logits(y1, y2)
        loss = criterion(y1, (y2.sigmoid() > 0.5).int())
        loss.backward()
        optimizer.step()
        # print(f"Step {step}: {loss:.2f}")

    model.eval()
    y_hat, _ = model.model(x)
    
    return y_hat

In [9]:
def ttt_consistency(model, test_loader, n_steps, lr, momentum):
    
    dices = []
    dices_raw = []
    conf_matrix = torch.zeros(2, 2)
    conf_matrix_raw = torch.zeros(2, 2)
    for batch in test_loader:

        # model.load_state_dict(torch.load(cfg.ckpt_path))

        x_unmasked = batch["unmasked_image"].to(device)
        y_out, id = batch["label"].to(device), batch["id"]

        model.eval()
        y_hat_raw, _ = model.model(x_unmasked)

        y_hat = ttt_consistency_one_batch(x_unmasked, deepcopy(model), n_steps, lr, momentum)

        dice_score = 1 - criterion(y_hat, y_out).detach().item()
        dice_score_raw = 1 - criterion(y_hat_raw, y_out).detach().item()
        conf_matrix += bcm(y_hat.sigmoid(), y_out).detach().cpu()
        conf_matrix_raw += bcm(y_hat_raw.sigmoid(), y_out).detach().cpu()

        print(f"bilo: {dice_score_raw:.3f}, stalo: {dice_score:.3f}")

        dices.append(dice_score)
        dices_raw.append(dice_score_raw)
    
    print(f"Mean Dice after ttt: {sum(dices) / len(dices)}")
    print(f"Macro Dice after ttt: {dice_from_conf_matrix(conf_matrix)}")
    print(f"Mean Dice raw: {sum(dices_raw) / len(dices_raw)}")
    print(f"Macro Dice raw: {dice_from_conf_matrix(conf_matrix_raw)}")

    return dices, dices_raw, conf_matrix, conf_matrix_raw

In [None]:
dices, dices_raw = ttt_consistency(model, test_loader, 10, 0.001, 0.9)

In [13]:
def ttt_combined(model, test_loader, n_steps_1, n_steps_2, lr, momentum):

    dices_raw = []
    dices_1 = []
    dices_2 = []
    conf_matrix_raw = torch.zeros(2, 2)
    conf_matrix_1 = torch.zeros(2, 2)
    conf_matrix_2 = torch.zeros(2, 2)
    for batch in test_loader:

        # model.load_state_dict(torch.load(cfg.ckpt_path))

        x_masked, x_unmasked = batch["image"].to(device), batch["unmasked_image"].to(device)
        y_out, id = batch["label"].to(device), batch["id"]
        mask = batch["mask"].to(device)

        model.eval()
        y_hat_raw, _ = model.model(x_unmasked)
        dice_score_raw = 1 - criterion(y_hat_raw, y_out).detach().item()
        conf_matrix_raw += bcm(y_hat_raw.sigmoid(), y_out).detach().cpu()

        # optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
        model_temp = deepcopy(model)

        y_hat = ttt_one_batch(x_masked, x_unmasked, mask, model_temp, n_steps_1, lr, momentum)
        dice_score_1 = 1 - criterion(y_hat, y_out).detach().item()
        conf_matrix_1 += bcm(y_hat.sigmoid(), y_out).detach().cpu()
        plot(y_out, y_hat_raw, y_hat, id)
        print(id)

        y_hat = ttt_consistency_one_batch(x_unmasked, model_temp, n_steps_2, lr, momentum)
        dice_score_2 = 1 - criterion(y_hat, y_out).detach().item()
        conf_matrix_2 += bcm(y_hat.sigmoid(), y_out).detach().cpu()

        print(f"bilo: {dice_score_raw:.3f}, stalo: {dice_score_1:.3f} stalo: {dice_score_2:.3f}")
        # print(dice_from_conf_matrix(conf_matrix_1))
        # print(scoring_fn(y_hat.sigmoid(), y_out).detach().item())
        # print(dice_from_conf_matrix(conf_matrix_2))
        dices_raw.append(dice_score_raw)
        dices_1.append(dice_score_1)
        dices_2.append(dice_score_2)
        # plot(y_out, y_hat_raw, y_hat, id)
    
    print(sum(dices_2) / len(dices_2))
    print(sum(dices_1) / len(dices_1))
    print(sum(dices_raw) / len(dices_raw))
    print()
    print(f"Macro Dice 2: {dice_from_conf_matrix(conf_matrix_2)}")
    print(f"Macro Dice 1: {dice_from_conf_matrix(conf_matrix_1)}")
    print(f"Macro Dice: {dice_from_conf_matrix(conf_matrix_raw)}")
    return dices_1, dices_2, dices_raw, conf_matrix_1, conf_matrix_2, conf_matrix_raw

In [None]:
out = ttt_combined(model, test_loader, 10, 10, 0.001, 0.9)

In [7]:
def ttt_summed_one_batch(x_masked, x_unmasked, mask, model, n_steps, lr, momentum):
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)
    for step in range(n_steps):

        x1, inverse_fn_1 = random_transform_with_inverse(x_unmasked)
        x2, inverse_fn_2 = random_transform_with_inverse(x_unmasked)

        optimizer.zero_grad()

        y1, _ = model.model(x1)
        y2, _ = model.model(x2)
        y1 = inverse_fn_1(y1)
        y2 = inverse_fn_2(y2)

        _, x_denoised = model.model(x_masked)

        loss = criterion(y1, (y2.sigmoid() > 0.5).int()) + masked_mse_loss(x_denoised, x_unmasked, mask)
        loss.backward()
        optimizer.step()

    model.eval()
    y_hat, _ = model.model(x_unmasked)
    
    return y_hat

In [8]:
def ttt_summed(model, test_loader, n_steps, lr, momentum):

    dices_raw = []
    dices_1 = []
    conf_matrix_raw = torch.zeros(2, 2)
    conf_matrix_1 = torch.zeros(2, 2)

    for batch in test_loader:

        x_masked, x_unmasked = batch["image"].to(device), batch["unmasked_image"].to(device)
        y_out, id = batch["label"].to(device), batch["id"]
        mask = batch["mask"].to(device)

        model.eval()
        y_hat_raw, _ = model.model(x_unmasked)
        dice_score_raw = 1 - criterion(y_hat_raw, y_out).detach().item()
        conf_matrix_raw += bcm(y_hat_raw.sigmoid(), y_out).detach().cpu()
        
        y_hat = ttt_summed_one_batch(
            x_masked,
            x_unmasked,
            mask,
            deepcopy(model),
            n_steps,
            lr,
            momentum
        )
        dice_score_1 = 1 - criterion(y_hat, y_out).detach().item()
        conf_matrix_1 += bcm(y_hat.sigmoid(), y_out).detach().cpu()

        print(f"bilo: {dice_score_raw:.3f}, stalo: {dice_score_1:.3f}")

        dices_raw.append(dice_score_raw)
        dices_1.append(dice_score_1)
    
    print(sum(dices_1) / len(dices_1))
    print(sum(dices_raw) / len(dices_raw))
    print()
    print(f"Macro Dice 1: {dice_from_conf_matrix(conf_matrix_1)}")
    print(f"Macro Dice: {dice_from_conf_matrix(conf_matrix_raw)}")
    return dices_1, dices_raw, conf_matrix_1, conf_matrix_raw

In [None]:
out = ttt_summed(model, test_loader, 20, 0.0005, 0.9)

## Ablation

In [None]:
out = ttt(model, test_loader, 100, 0.001, 0.9, validation=True)

In [39]:
import matplotlib.pyplot as plt
import numpy as np

def plot_two_line_plots(lines_plot1, lines_plot2, save_path):
    """
    Plots two subplots, each containing multiple lines. Each line is provided
    as a list of y-values, and x-values are assumed to be integer steps (0, 1, 2, ...).
    
    Parameters:
    - lines_plot1: List of lists, where each list contains the y-values for a line in plot 1.
    - lines_plot2: List of lists, where each list contains the y-values for a line in plot 2.
    """
    # Generate x-coordinates (just integers, same length as y)
    x = list(range(len(lines_plot1[0])))  # Assuming all lines have the same length
    
    # Create subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.set_ylim(0, 1)
    ax2.set_ylim(0, 1)
    
    # Plot lines for the first plot
    for y in lines_plot1:
        ax1.plot(x, y)
    ax1.plot(x, np.array(lines_plot1).mean(axis=0), linestyle="--", linewidth=2)
    ax1.set_title('Losses')

    # Plot lines for the second plot
    for y in lines_plot2:
        ax2.plot(x, y)
    ax2.plot(x, np.array(lines_plot2).mean(axis=0), linestyle="--", linewidth=2)
    ax2.set_title('Scores')

    fig.savefig(save_path)
    # Show the plots
    plt.tight_layout()
    plt.show()

In [None]:
import scienceplots
plt.style.use('science')

plot_two_line_plots(out[-2], out[-1], "./plots/graphs/ttt_steps.png")

In [None]:
macro_dices = []
outs = []
prefix = Path("/mnt/hdd_pool_zion/userdata/diyor/data/deepict/training/12-12-trimmed-256-b20/work/testing_data")
tomo_paths = [
    # "vpp_train/001",
    "vpp_train/002",
    # "vpp_train/003",
    "vpp_train/004",
    # "vpp_train/005",
    "vpp_train/006",
    "vpp_test/007",
    "vpp_val/008",
    # "vpp_val/009",
    "vpp_val/010",
    # "def_train/026",
    # "def_train/030",
    # "def_train/034",
    # "def_val/037",
    # "def_test/041",
]
for path in tomo_paths:

    macro_dices_tomo = []
    cfg.test_data_root_dir = prefix / path
    print(f"Tomo {path}")
    test_loader = torch.utils.data.DataLoader(
        build_dataset(cfg, test=True),
        num_workers=cfg.load_num_workers,
        batch_size=1,
        persistent_workers=cfg.persistent_workers,
    )

    for steps in range(0, 40, 10):
        out = ttt(model, test_loader, steps, 0.001, 0.9, validation=False)
        macro_dices_tomo.append(dice_from_conf_matrix(out[2]))
        # outs.append(out)
    
    macro_dices.append(macro_dices_tomo)
    # losses.append(1 - out[1])

In [None]:
import scienceplots
plt.style.use(['science', 'nature'])
for elem in macro_dices:
    plt.plot(list(range(0, 40, 10)), elem)
plt.ylim(0, 1)
plt.xlim(0, 30)
plt.grid()
plt.show()

## Plotting 

In [1]:
import matplotlib.pyplot as plt
import pickle

In [1]:
import numpy as np
from pathlib import Path

prefix = Path("/mnt/hdd_pool_zion/userdata/diyor/data/deepict/training/12-12-trimmed-256-b20/work/testing_data")
tomo_paths = [
    # "vpp_train/001",
    # "vpp_train/002",
    # "vpp_train/003",
    # "vpp_train/004",
    # "vpp_train/005",
    # "vpp_train/006",
    # "vpp_test/007",
    # "vpp_val/008",
    # "vpp_val/009",
    # "vpp_val/010",
    "def_train/026",
    "def_train/030",
    "def_train/034",
    "def_val/037",
    "def_test/041",
]
x = np.arange(len(tomo_paths))

In [None]:
cached_table = {
    # "baseline": [],
    # "denoising_ttt": [],
    # "rotation_ttt": [],
    # "combined_ttt": [],
    "summed_ttt": [],
}

for path in tomo_paths:

    cfg.test_data_root_dir = prefix / path
    print(f"Tomo {path}")
    test_loader = torch.utils.data.DataLoader(
        build_dataset(cfg, test=True),
        num_workers=cfg.load_num_workers,
        batch_size=1,
        persistent_workers=cfg.persistent_workers,
    )

    dices_denoising, dices_raw, conf_m_denoising, conf_m_raw = ttt(model, test_loader, 20, 0.001, 0.9)
    # dices_rotation, _, conf_m_rotation, _ = ttt_consistency(model, test_loader, 10, 0.001, 0.9)
    # _, dices_combined, _, _, conf_m_combined, _ = ttt_combined(model, test_loader, 10, 10, 0.001, 0.9)


    # cached_table["baseline"].append(dice_from_conf_matrix(conf_m_raw))
    # cached_table["denoising_ttt"].append(dice_from_conf_matrix(conf_m_denoising))
    # cached_table["rotation_ttt"].append(dice_from_conf_matrix(conf_m_rotation))
    # cached_table["combined_ttt"].append(dice_from_conf_matrix(conf_m_combined))

# Save table to alter the plotting in the future
# with open("cache/ttt_results_vpp_vpp_macro.pkl", "wb") as handle:
    # pickle.dump(cached_table, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import pickle

names = {
    "baseline": "Baseline (Segmentation + Noise2Void)",
    "denoising_ttt": "Noise2Void TTT [10 steps]",
    "rotation_ttt": "Rotation/Flip Consistency TTT [10 steps]",
    "combined_ttt": "Noise2Void + Rotation/Flip Consistency TTT [10 + 10 steps]"
}
plt.style.use(['science', 'nature'])
plt.rcParams.update({'font.size': 11})
with open("cache/ttt_results_vpp_def_macro.pkl", "rb") as handle:
    cached_table = pickle.load(handle)
table = {}
for k, v in cached_table.items():
    table[names[k]] = cached_table[k]

width = 0.2  # the width of the bars
multiplier = -0.5

fig, ax = plt.subplots(layout='constrained')
# fig, ax = plt.subplots()
fig.set_size_inches(6.5, 2.786)

for attribute, measurement in table.items():
    offset = width * multiplier
    rects = ax.bar(x + offset, measurement, width, label=attribute)
    ax.axhline(np.array(measurement).mean(), color=rects.patches[0].get_facecolor(), linestyle="--", zorder=0)
    # ax.bar_label(rects, padding=3)
    multiplier += 1

# Add some text for labels, title and custom x-axis tick labels, etc.
x_labels = ["def/" + tomo.split("/")[-1] for tomo in tomo_paths]

ax.set_ylabel('Dice coefficient')
# ax.set_title('Trained on DEF, TTT on VPP')
ax.set_xticks(x + width, x_labels)
ax.legend(loc='upper left', ncols=2)
ax.set_ylim(0, 1)
# ax.yaxis.grid(True)

plt.show()
fig.savefig("plots/graphs/vpp_def_macro_65.png", dpi=600)

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import pickle

x = [0]
names = {
    "baseline": "Baseline (Segmentation + Noise2Void)",
    "denoising_ttt": "Noise2Void TTT [10 steps]",
    "rotation_ttt": "Rotation/Flip Consistency TTT [10 steps]",
    "combined_ttt": "Noise2Void + Rotation/Flip Consistency TTT [10 + 10 steps]"
}
plt.style.use(['science', 'nature'])
plt.rcParams.update({'font.size': 11})

with open("cache/ttt_results_def_def_macro.pkl", "rb") as handle:
    cached_table = pickle.load(handle)
table = {}
for k, v in cached_table.items():
    table[names[k]] = cached_table[k]

with open("cache/ttt_results_vpp_vpp_macro.pkl", "rb") as handle:
    cached_table = pickle.load(handle)
table1 = {}
for k, v in cached_table.items():
    table1[names[k]] = cached_table[k]

width = 0.2  # the width of the bars
multiplier = -0.5

fig, ax = plt.subplots(1, 2, layout='constrained')
# fig, ax = plt.subplots()
fig.set_size_inches(6.5, 2.786)

for attribute, measurement in table.items():
    offset = width * multiplier
    rects = ax[0][0].bar(x + offset, measurement, width, label=attribute)
    ax[0][0].axhline(np.array(measurement).mean(), color=rects.patches[0].get_facecolor(), linestyle="--", zorder=0)
    # ax.bar_label(rects, padding=3)
    multiplier += 1

for attribute, measurement in table.items():
    offset = width * multiplier
    rects = ax[0][1].bar(x + offset, measurement, width, label=attribute)
    ax[0][1].axhline(np.array(measurement).mean(), color=rects.patches[0].get_facecolor(), linestyle="--", zorder=0)
    # ax.bar_label(rects, padding=3)
    multiplier += 1

# Add some text for labels, title and custom x-axis tick labels, etc.
# x_labels = ["def/" + tomo.split("/")[-1] for tomo in tomo_paths]
x_labels = ["def/041"]
ax[0][0].set_ylabel('Dice coefficient')
# ax.set_title('Trained on DEF, TTT on VPP')
ax[0][0].set_xticks(x + width, x_labels)
# ax[0][0].legend(loc='upper left', ncols=2)
ax[0][0].set_ylim(0, 1)

# x_labels = ["vpp/" + tomo.split("/")[-1] for tomo in tomo_paths]
# ax[0][0].set_ylabel('Dice coefficient')
# ax.set_title('Trained on DEF, TTT on VPP')
x_labels = ["vpp/007"]
ax[0][0].set_xticks(x + width, x_labels)
# ax[0][0].legend(loc='upper left', ncols=2)
ax[0][0].set_ylim(0, 1)
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center')
# ax.yaxis.grid(True)

plt.show()
fig.savefig("plots/graphs/vpp_def_macro_65.png", dpi=600)