In [1]:
import torch
import torch.nn as nn
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
torch.set_float32_matmul_precision("medium")
from pathlib import Path
import numpy as np
from utils.utils import set_seed
from utils.ddw_subtomos import reassemble_subtomos
from torchmetrics.classification import BinaryConfusionMatrix, PrecisionRecallCurve
import wandb
import torchvision.transforms.functional as FT
from datasets import build_dataset
from models.denoiseg import Denoiseg
from hydra import initialize, compose
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import mrcfile
import uuid
from tqdm import tqdm
from torch_receptive_field import receptive_field, receptive_field_for_unit

In [2]:
with initialize(version_base=None, config_path="configs/"):
    cfg = compose(config_name='config.yaml')

OmegaConf.set_struct(cfg, False)
cfg = OmegaConf.merge(cfg, cfg.method)

set_seed(cfg.seed)

criterion = DiceLoss(sigmoid=True)
scoring_fn = DiceMetric(reduction='mean')
cross_entropy = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss(reduction='none')
# mse_loss = nn.L1Loss(reduction='none')
device = torch.device("cuda:3")
device_1 = torch.device("cuda:1")
bcm = BinaryConfusionMatrix().to(device)
# pr_curve = PrecisionRecallCurve(task="binary")

dataset = build_dataset(cfg, test=True)

test_dataset, test_val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 2, 2], generator=torch.Generator().manual_seed(cfg.seed))
print(len(test_dataset), len(test_val_dataset))
# test_dataset = dataset
# test_val_dataset = build_dataset(cfg, val=True)
test_loader = torch.utils.data.DataLoader(
    dataset,
    # num_workers=cfg.load_num_workers,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    # persistent_workers=cfg.persistent_workers,
)            
test_val_loader = torch.utils.data.DataLoader(
    test_val_dataset,
    num_workers=8,
    batch_size=2,
    # persistent_workers=cfg.persistent_workers,
)
test_test_loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=8,
    batch_size=1,
    shuffle=False,
    # persistent_workers=cfg.persistent_workers,
)
test_entire_gt = mrcfile.read(cfg.test_entire_gt)
test_entire_tomo = mrcfile.read(cfg.test_entire_tomo)

test_entire_gt = torch.Tensor(test_entire_gt)[None, None, ...]
test_entire_tomo = torch.Tensor(test_entire_tomo)

test_membrain_pred = mrcfile.read(cfg.test_membrain_pred)
test_membrain_pred = torch.Tensor(test_membrain_pred)[None, None, ...]

model = Denoiseg.load_from_checkpoint(cfg.ckpt_path, map_location=device, config=cfg)

Seed set to 42


145 2


In [3]:
def masked_mse_loss(x_hat, x, mask):
    loss = mse_loss(x_hat, x) * mask
    return loss.sum() / mask.sum()

def dice_from_conf_matrix(conf_matrix):
    return 2 * conf_matrix[1, 1] / (2 * conf_matrix[1, 1] + conf_matrix[0, 1] + conf_matrix[1, 0])

def precision_from_conf_matrix(conf_matrix: torch.Tensor):
    return conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[0, 1])

def recall_from_conf_matrix(conf_matrix: torch.Tensor):
    return conf_matrix[1, 1] / (conf_matrix[1, 1] + conf_matrix[1, 0])

def get_grad_norm(model):
    total_norm = 0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm

def process_id(id):
    id = id.split("_")[-1]
    id = id[1:-1].split(" ")
    id = [int(elem) for elem in id]
    return id

def normalize_image(image, max_intensity):

    return image * 255.0 / max_intensity

def normalize_min_max(image):

    image = (image - image.min()) / (image.max() - image.min())
    return image * 255.0

In [4]:
def macro_test_loop(test_test_loader, model, run, epoch):
    # Test
    conf_matrix = torch.zeros(2, 2)
    preds = []
    ids = []
    subtomo_start_coords = []
    denoised_preds = []
    raws = []
    with torch.no_grad():
        model.eval()
        for i, batch in enumerate(tqdm(test_test_loader)):
            x_unmasked = batch["raw_subtomo"].to(device)
            # x_unmasked = batch["image"].to(device)
            y_out, id = batch["label"].to(device), batch["id"]
            y_hat, x_hat = model.model(x_unmasked)

            ## TEST recurrent inference
            # y_hat, _ = model.model(x_hat)
            ##########################

            # dice_loss = criterion(y_hat, y_out).detach().item()
            dice = scoring_fn((y_hat.sigmoid() > 0.5).int(), y_out).detach().item()
            conf_matrix += bcm(y_hat.sigmoid(), y_out).detach().cpu()
            
            # slice_to_log = torch.argmax(y_out[0, 0].sum(dim=(-1, -2)))

            # id = id[0].replace(" ", "_")
            preds.append(y_hat.squeeze().detach().cpu().sigmoid())

            denoised_preds.append(x_hat.squeeze().detach().cpu())
            # ids.append(process_id(id[0]))
            subtomo_start_coords.append(batch["start_coord"][0])
            # run.log({
            #     f"train/{i}/subtomo": wandb.Image(x_unmasked[0, :, slice_to_log]),
            #     f"train/{i}/denoised": wandb.Image(x_hat[0, :, slice_to_log]),
            #     f"train/{i}/seg_gt": wandb.Image(FT.to_pil_image((y_out[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
            #     f"train/{i}/seg": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
            #     f"train/{i}/dice": dice,
            #     # f"train/{i}/dice_loss": dice_loss,
            #     "epoch": epoch,
            # })
        
        # reassembled_pred = reassemble_subtomos(preds, ids, test_entire_gt.shape[2:], 256, 12)
        reassembled_pred = reassemble_subtomos(
            subtomos=preds,
            subtomo_start_coords=subtomo_start_coords,
            subtomo_overlap=80,
            crop_to_size=test_entire_gt.shape[2:]
        )

        denoised_reassembled = reassemble_subtomos(
            subtomos=denoised_preds,
            subtomo_start_coords=subtomo_start_coords,
            subtomo_overlap=80,
            crop_to_size=test_entire_gt.shape[2:]
        )

        # reassembled_pred = (torch.Tensor(reassembled_pred).sigmoid() > 0.5).to(torch.uint8)[None, None, ...]
        reassembled_pred_binary = (reassembled_pred[None, None, ...] > 0.5).to(torch.uint8)
        reassembled_pred_binary_75 = (reassembled_pred[None, None, ...] > 0.75).to(torch.uint8)
        reassembled_pred_binary_90 = (reassembled_pred[None, None, ...] > 0.9).to(torch.uint8)
        reassembled_pred_binary_98 = (reassembled_pred[None, None, ...] > 0.98).to(torch.uint8)

        # postprocessed_pred = torch.Tensor(connected_components(reassembled_pred.squeeze().numpy(), 50)).to(torch.bool).to(torch.uint8)
        max_intensity = np.max(test_entire_gt[0].sum(dim=1).numpy())
        full_conf_matrix = bcm(reassembled_pred_binary.to(device), test_entire_gt.to(device)).detach().cpu()
        full_conf_matrix_75 = bcm(reassembled_pred_binary_75.to(device), test_entire_gt.to(device)).detach().cpu()
        full_conf_matrix_90 = bcm(reassembled_pred_binary_90.to(device), test_entire_gt.to(device)).detach().cpu()
        full_conf_matrix_98 = bcm(reassembled_pred_binary_98.to(device), test_entire_gt.to(device)).detach().cpu()

        slice_to_log = torch.argmax(test_entire_gt.squeeze().sum(dim=(-2, -1)))

        run.log({
            "train/macro_dice": dice_from_conf_matrix(conf_matrix),
            "train/rsm_macro_dice": scoring_fn(reassembled_pred_binary, test_entire_gt).item(),
            "train/rsm_macro_dice/75": scoring_fn(reassembled_pred_binary_75, test_entire_gt).item(),
            "train/rsm_macro_dice/90": scoring_fn(reassembled_pred_binary_90, test_entire_gt).item(),
            "train/rsm_macro_dice/98": scoring_fn(reassembled_pred_binary_98, test_entire_gt).item(),
            "train/rsm_precision": precision_from_conf_matrix(full_conf_matrix),
            "train/rsm_recall": recall_from_conf_matrix(full_conf_matrix),
            "train/rsm_precision/75": precision_from_conf_matrix(full_conf_matrix_75),
            "train/rsm_recall/75": recall_from_conf_matrix(full_conf_matrix_75),
            "train/rsm_precision/90": precision_from_conf_matrix(full_conf_matrix_90),
            "train/rsm_recall/90": recall_from_conf_matrix(full_conf_matrix_90),
            "train/rsm_precision/98": precision_from_conf_matrix(full_conf_matrix_98),
            "train/rsm_recall/98": recall_from_conf_matrix(full_conf_matrix_98),
            "train/denoised": wandb.Image(denoised_reassembled[slice_to_log]),
            "train/tomo": wandb.Image(test_entire_tomo[slice_to_log]),

            # "train/rsm_pr_curve": wandb.plot.pr_curve(
            #     test_entire_gt.squeeze().numpy().flatten(),
            #     np.stack((1 - reassembled_pred.numpy().flatten(), reassembled_pred.numpy().flatten()), axis=-1),
            # ),
            "membrain/macro_dice": scoring_fn(test_membrain_pred[:, :, :test_entire_gt.shape[2], :test_entire_gt.shape[3], :test_entire_gt.shape[4]], test_entire_gt).item(),
            "train/seg_overall_gt": wandb.Image(FT.to_pil_image(normalize_min_max(test_entire_gt[0, :].sum(dim=1)).to(torch.uint8), mode="L")),
            "train/seg_overall": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary[0].sum(dim=1)).to(torch.uint8), mode="L")),
            "train/seg_overall/75": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary_75[0].sum(dim=1)).to(torch.uint8), mode="L")),
            "train/seg_overall/90": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary_90[0].sum(dim=1)).to(torch.uint8), mode="L")),
            "train/seg_overall/98": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary_98[0].sum(dim=1)).to(torch.uint8), mode="L")),
            "train/seg_overall/proba": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred.sum(dim=0)), mode="L")),
            "membrain/seg_overall": wandb.Image(FT.to_pil_image(normalize_min_max(test_membrain_pred[0, :].sum(dim=1)).to(torch.uint8), mode="L")),
            # "train/seg_overall_postprocessed": wandb.Image(FT.to_pil_image(postprocessed_pred[None, ...].sum(dim=1).to(torch.uint8), mode="L")),
            "epoch": epoch,
        })

def ttt_one_tomo(test_loader, model, n_epochs, lr, momentum, test_val_loader=None, test_test_loader=None):
    
    random_uuid = uuid.uuid4()
    random_string = str(random_uuid).replace("-", "")[:5]

    run = wandb.init(project="cryo-ttt-ttt", name=f"{cfg.exp_name}-lr-{lr:.2E}-{random_string}")
    print(cfg)
    macro_test_loop(test_test_loader, model, run, epoch=0)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)
    # optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer, 
        base_lr=0.0003,   # Minimum learning rate
        max_lr=0.003,     # Maximum learning rate
        step_size_up=5,  # Number of iterations to go from base_lr to max_lr
        mode='triangular' # Policy for cyclic variation
    )     
    
    step = 0
    for epoch in range(1, n_epochs + 1):
        print(f"Epoch: {epoch}")
        for batch in tqdm(test_loader):
            # -------------
            # Train step
            model.train()

            x_masked, x_unmasked = batch["image"].to(device), batch["unmasked_image"].to(device)
            y_out, id = batch["label"].to(device), batch["id"]
            mask = batch["mask"].to(device)

            optimizer.zero_grad()

            # y_hat, x_hat = model.model(x_masked)
            # ### Normalize
            # x_masked, mean, std = model.normalize(x_masked)
            # ###

            y_hat, x_hat = model.model(x_masked)

            # ### Denormalize
            # x_hat = model.denormalize(x_hat, mean, std)
            # ###


            loss = masked_mse_loss(x_hat, x_unmasked, mask)

            loss.backward()
            optimizer.step()
            scheduler.step()

            train_grad_norm = get_grad_norm(model.model)
            dice_loss = criterion(y_hat, y_out).detach().item()
            dice = scoring_fn((y_hat.sigmoid() > 0.5).int(), y_out).detach().mean().item()

            val_losses = []
            for batch_val in tqdm(test_val_loader):
                # --------------
                # Validation step
                model.eval()

                x_masked_val, x_unmasked_val = batch_val["image"].to(device), batch_val["unmasked_image"].to(device)
                y_out_val, id_val = batch_val["label"].to(device), batch_val["id"]
                mask_val = batch_val["mask"].to(device)

                _, x_hat_val = model.model(x_masked_val)

                val_mse_loss = masked_mse_loss(x_hat_val, x_unmasked_val, mask_val)
                val_losses.append(val_mse_loss.item())

                # slice_to_log = torch.argmax(y_out[0, 0].sum(dim=(-1, -2)))
            
            # Log results after step i
            run.log({
                "train/mse_loss": loss.item(),
                "train/dice_loss": dice_loss,
                "train/dice": dice,
                "epoch": epoch,
                "step": step,
                # "train/subtomo": wandb.Image(x_unmasked[0, :, slice_to_log]),
                # "train/denoised": wandb.Image(x_hat[0, :, slice_to_log]),
                # "train/seg_gt": wandb.Image(FT.to_pil_image((y_out[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                # "train/seg": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                # "train/mask": wandb.Image(FT.to_pil_image((mask[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                "val/mse_loss": np.array(val_losses).mean(),
                # "train/grad_norm": train_grad_norm,
                # "train/intensity_var": x_hat.var(),
                # "train/seg_overall_gt": wandb.Image(FT.to_pil_image(y_out[0, :].sum(dim=1).to(torch.uint8), mode="L")),
                # "train/seg_overall": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :] > 0.5).sum(dim=1).to(torch.uint8), mode="L")),
            })

            step += 1
        macro_test_loop(test_test_loader, model, run, epoch=epoch)
        
    run.finish()
    # return y_hat, val_losses, val_dices

In [5]:
ttt_one_tomo(test_loader, model,  20, 0.003, 0.9, test_val_loader=test_val_loader, test_test_loader=test_test_loader)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdijor0310[0m ([33mcryo-diyor[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'exp_name': 'denoiseg-GN-def-norm-p160', 'wandb_project_name': 'cryo-ttt-v2', 'ckpt_path': '/workspaces/cryo/cryo-ttt/src/ttt_ckpt/denoiseg-f2fd-GN-vpp-norm-p160-m5_10-5ef5f/epoch=197-val/dice_loss=0.65.ckpt', 'seed': 42, 'debug': False, 'devices': [2, 3], 'profiler': None, 'strategy': 'ddp', 'shuffle': False, 'train_load_num_workers': 8, 'val_load_num_workers': 8, 'pin_memory': False, 'persistent_workers': False, 'accumulate_grad_batches': 1, 'gradient_clip_val': None, 'check_val_every_n_epochs': 1, 'log_every_n_steps': 1, 'num_sanity_val_steps': 0, 'enable_progress_bar': True, 'method': {'model_name': 'denoiseg', 'train_batch_size': 32, 'eval_batch_size': 32, 'test_batch_size': 1, 'dataset': 'denoiseg_f2fd', 'depth': 2, 'initial_features': 4, 'encoder_dropout': 0.0, 'decoder_dropout': 0.1, 'BN': 'group_norm', 'elu': False, 'bernoulli_mask_ratio': 0.5, 'phase_inversion_ratio': 0.1, 'min_mask_radius': 0.05, 'max_mask_radius': 0.1, 'lambda_ce': 0.2, 'learning_rate': 0.003, 'gamma_decay

100%|██████████| 147/147 [00:12<00:00, 11.66it/s]


Epoch: 1


100%|██████████| 1/1 [00:03<00:00,  3.25s/it]
100%|██████████| 1/1 [00:03<00:00,  3.24s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.22s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.26s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.20s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.20s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.24s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.20s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.20s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.20s/it]]
100%|██████████| 10/10 [00:51<00:00,  5.18s/it]
100%|██████████| 147/147 [00:13<00:00, 10.55it/s]


Epoch: 2


100%|██████████| 1/1 [00:03<00:00,  3.05s/it]
100%|██████████| 1/1 [00:03<00:00,  3.05s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.04s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.01s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.99s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.97s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.01s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.02s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.99s/it]]
100%|██████████| 10/10 [00:49<00:00,  4.94s/it]
100%|██████████| 147/147 [00:13<00:00, 10.79it/s]


Epoch: 3


100%|██████████| 1/1 [00:02<00:00,  2.96s/it]
100%|██████████| 1/1 [00:02<00:00,  2.96s/it]]
100%|██████████| 1/1 [00:02<00:00,  3.00s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.97s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.96s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.93s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.93s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.62s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.93s/it]]
100%|██████████| 10/10 [00:49<00:00,  4.95s/it]
100%|██████████| 147/147 [00:13<00:00, 10.68it/s]


Epoch: 4


100%|██████████| 1/1 [00:03<00:00,  3.19s/it]
100%|██████████| 1/1 [00:03<00:00,  3.18s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.14s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.21s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.16s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.19s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.16s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.18s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.22s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.18s/it]]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]
100%|██████████| 147/147 [00:13<00:00, 10.78it/s]


Epoch: 5


100%|██████████| 1/1 [00:02<00:00,  2.94s/it]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.96s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.91s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.97s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.93s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.10s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.96s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.98s/it]]
100%|██████████| 10/10 [00:48<00:00,  4.89s/it]
100%|██████████| 147/147 [00:13<00:00, 11.05it/s]


Epoch: 6


100%|██████████| 1/1 [00:03<00:00,  3.09s/it]
100%|██████████| 1/1 [00:03<00:00,  3.05s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.08s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.08s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.08s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.07s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.04s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.03s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.05s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.09s/it]]
100%|██████████| 10/10 [00:50<00:00,  5.01s/it]
100%|██████████| 147/147 [00:13<00:00, 10.75it/s]


Epoch: 7


100%|██████████| 1/1 [00:03<00:00,  3.29s/it]
100%|██████████| 1/1 [00:03<00:00,  3.32s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.38s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.30s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.30s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.31s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.27s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.26s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.25s/it]]
100%|██████████| 10/10 [00:52<00:00,  5.25s/it]
100%|██████████| 147/147 [00:13<00:00, 10.80it/s]


Epoch: 8


100%|██████████| 1/1 [00:03<00:00,  3.49s/it]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.45s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.41s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.41s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.40s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.42s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.40s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 10/10 [00:54<00:00,  5.42s/it]
100%|██████████| 147/147 [00:13<00:00, 10.76it/s]


Epoch: 9


100%|██████████| 1/1 [00:03<00:00,  3.46s/it]
100%|██████████| 1/1 [00:03<00:00,  3.44s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.42s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.68s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.48s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.61s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.21s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.20s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.20s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.22s/it]]
100%|██████████| 10/10 [00:49<00:00,  4.97s/it]
100%|██████████| 147/147 [00:11<00:00, 13.34it/s]


Epoch: 10


100%|██████████| 1/1 [00:03<00:00,  3.12s/it]
100%|██████████| 1/1 [00:03<00:00,  3.12s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.13s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.15s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.07s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.10s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.10s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.10s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.07s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.06s/it]]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]
100%|██████████| 147/147 [00:13<00:00, 10.87it/s]


Epoch: 11


100%|██████████| 1/1 [00:03<00:00,  3.17s/it]
100%|██████████| 1/1 [00:03<00:00,  3.22s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.17s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.18s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.14s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.44s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.19s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.20s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.19s/it]]
100%|██████████| 10/10 [00:51<00:00,  5.16s/it]
100%|██████████| 147/147 [00:13<00:00, 10.84it/s]


Epoch: 12


100%|██████████| 1/1 [00:03<00:00,  3.34s/it]
100%|██████████| 1/1 [00:03<00:00,  3.32s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.30s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.29s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.27s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.26s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.29s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.31s/it]]
100%|██████████| 10/10 [00:52<00:00,  5.27s/it]
100%|██████████| 147/147 [00:13<00:00, 10.67it/s]


Epoch: 13


100%|██████████| 1/1 [00:03<00:00,  3.32s/it]
100%|██████████| 1/1 [00:03<00:00,  3.26s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.29s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.27s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.30s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.29s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.30s/it]]
100%|██████████| 10/10 [00:52<00:00,  5.23s/it]
100%|██████████| 147/147 [00:12<00:00, 11.39it/s]


Epoch: 14


100%|██████████| 1/1 [00:02<00:00,  2.97s/it]
100%|██████████| 1/1 [00:02<00:00,  2.98s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.00s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.98s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.99s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.97s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.94s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.96s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.99s/it]]
100%|██████████| 10/10 [00:49<00:00,  4.91s/it]
100%|██████████| 147/147 [00:13<00:00, 11.05it/s]


Epoch: 15


100%|██████████| 1/1 [00:03<00:00,  3.06s/it]
100%|██████████| 1/1 [00:03<00:00,  3.03s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.09s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.04s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.05s/it]]
100%|██████████| 1/1 [00:02<00:00,  3.00s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.03s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.05s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.04s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.02s/it]]
100%|██████████| 10/10 [00:49<00:00,  4.98s/it]
100%|██████████| 147/147 [00:13<00:00, 10.80it/s]


Epoch: 16


100%|██████████| 1/1 [00:03<00:00,  3.14s/it]
100%|██████████| 1/1 [00:03<00:00,  3.23s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.59s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.12s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.09s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.04s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.06s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.08s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.06s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.08s/it]]
100%|██████████| 10/10 [00:51<00:00,  5.12s/it]
100%|██████████| 147/147 [00:13<00:00, 10.91it/s]


Epoch: 17


100%|██████████| 1/1 [00:03<00:00,  3.18s/it]
100%|██████████| 1/1 [00:03<00:00,  3.07s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.07s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.06s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.05s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.03s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.01s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.04s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.01s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.03s/it]]
100%|██████████| 10/10 [00:49<00:00,  5.00s/it]
100%|██████████| 147/147 [00:13<00:00, 10.80it/s]


Epoch: 18


100%|██████████| 1/1 [00:03<00:00,  3.36s/it]
100%|██████████| 1/1 [00:03<00:00,  3.35s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.40s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.59s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.21s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.30s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.17s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.21s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.22s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.23s/it]]
100%|██████████| 10/10 [00:46<00:00,  4.68s/it]
100%|██████████| 147/147 [00:11<00:00, 13.24it/s]


Epoch: 19


100%|██████████| 1/1 [00:02<00:00,  2.96s/it]
100%|██████████| 1/1 [00:02<00:00,  2.99s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.12s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.00s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.93s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.93s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.95s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.94s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.14s/it]]
100%|██████████| 10/10 [00:49<00:00,  4.93s/it]
100%|██████████| 147/147 [00:13<00:00, 10.71it/s]


Epoch: 20


100%|██████████| 1/1 [00:03<00:00,  3.25s/it]
100%|██████████| 1/1 [00:03<00:00,  3.35s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.27s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.29s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.24s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.26s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.22s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.25s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.28s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.27s/it]]
100%|██████████| 10/10 [00:52<00:00,  5.23s/it]
100%|██████████| 147/147 [00:14<00:00, 10.42it/s]


0,1
epoch,▁▁▂▂▂▃▃▃▃▃▃▃▃▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇█████
membrain/macro_dice,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
train/dice,▄▃ ▂▄▄▃▅ ▄▂█ ▃ ▇ ▃ ▂ ▁▄ ▄▆ ▄▅ ▂▁
train/dice_loss,▅▅▄▇▆▃▄▃▄▆▃▅▃▃▅▅▃▆▁▄▆▃▅█▃▅▅▄▅▄▄▅▄▅▂▄▅▅▅▄
train/macro_dice,▁▃▅▆▇██▇▇▇▇▆▅▅▅▅▄▄▃▂▂
train/mse_loss,▂▃▄▅▇▇▆▅▄▄▆▆▄▆▇█▆▄▃▅▅▄▅▇▅▆▅▇▃▁▅▄▄▄▅▇▄▆▄▄
train/rsm_macro_dice,▅▅▇▇██▇▇▆▆▆▅▄▄▄▄▃▃▂▁▁
train/rsm_macro_dice/75,██████▇▇▆▆▅▅▄▄▄▃▃▂▂▁▁
train/rsm_macro_dice/90,███▇▇▇▆▆▅▅▅▄▄▃▃▃▂▂▁▁▁

0,1
epoch,20.0
membrain/macro_dice,1.0
step,199.0
train/dice,0.36078
train/dice_loss,0.64314
train/macro_dice,0.54955
train/mse_loss,0.47576
train/rsm_macro_dice,0.51144
train/rsm_macro_dice/75,0.48412
train/rsm_macro_dice/90,0.43813


In [6]:
def masked_inference(test_test_loader, model, num_iter):

    random_uuid = uuid.uuid4()
    random_string = str(random_uuid).replace("-", "")[:5]

    run = wandb.init(project="cryo-ttt-ttt", name=f"{cfg.exp_name}-num-iter-{num_iter}-{random_string}")

    # Test

    # overall_preds = []
    # overall_denoised_preds = []
    current_pred = torch.zeros_like(test_entire_gt).squeeze()
    for _ in range(num_iter):
        conf_matrix = torch.zeros(2, 2)
        preds = []
        subtomo_start_coords = []
        denoised_preds = []
        with torch.no_grad():
            model.eval()
            for batch in tqdm(test_test_loader):
                x_unmasked = batch["image"].to(device)
                y_out, id = batch["label"].to(device), batch["id"]
                y_hat, x_hat = model.model(x_unmasked)

                # dice_loss = criterion(y_hat, y_out).detach().item()
                # dice = scoring_fn((y_hat.sigmoid() > 0.5).int(), y_out).detach().item()
                conf_matrix += bcm(y_hat.sigmoid(), y_out).detach().cpu()
                
                preds.append(y_hat.squeeze().detach().cpu().sigmoid())

                denoised_preds.append(x_hat.squeeze().detach().cpu())
                subtomo_start_coords.append(batch["start_coord"][0])
            
            reassembled_pred = reassemble_subtomos(
                subtomos=preds,
                subtomo_start_coords=subtomo_start_coords,
                subtomo_overlap=80,
                crop_to_size=test_entire_gt.shape[2:]
            )

            denoised_reassembled = reassemble_subtomos(
                subtomos=denoised_preds,
                subtomo_start_coords=subtomo_start_coords,
                subtomo_overlap=80,
                crop_to_size=test_entire_gt.shape[2:]
            )

            reassembled_pred = torch.stack((reassembled_pred, current_pred), dim=0).max(dim=0)[0]
            reassembled_pred_binary = (reassembled_pred[None, None, ...] > 0.5).to(torch.uint8)
            # reassembled_pred_binary_75 = (reassembled_pred[None, None, ...] > 0.75).to(torch.uint8)
            # reassembled_pred_binary_90 = (reassembled_pred[None, None, ...] > 0.9).to(torch.uint8)
            # reassembled_pred_binary_98 = (reassembled_pred[None, None, ...] > 0.98).to(torch.uint8)

            # postprocessed_pred = torch.Tensor(connected_components(reassembled_pred.squeeze().numpy(), 50)).to(torch.bool).to(torch.uint8)
            full_conf_matrix = bcm(reassembled_pred_binary.to(device), test_entire_gt.to(device)).detach().cpu()
            # full_conf_matrix_75 = bcm(reassembled_pred_binary_75.to(device), test_entire_gt.to(device)).detach().cpu()
            # full_conf_matrix_90 = bcm(reassembled_pred_binary_90.to(device), test_entire_gt.to(device)).detach().cpu()
            # full_conf_matrix_98 = bcm(reassembled_pred_binary_98.to(device), test_entire_gt.to(device)).detach().cpu()

            slice_to_log = torch.argmax(test_entire_gt.squeeze().sum(dim=(-2, -1)))

            run.log({
                "train/macro_dice": dice_from_conf_matrix(conf_matrix),
                "train/rsm_macro_dice": scoring_fn(reassembled_pred_binary, test_entire_gt).item(),
                # "train/rsm_macro_dice/75": scoring_fn(reassembled_pred_binary_75, test_entire_gt).item(),
                # "train/rsm_macro_dice/90": scoring_fn(reassembled_pred_binary_90, test_entire_gt).item(),
                # "train/rsm_macro_dice/98": scoring_fn(reassembled_pred_binary_98, test_entire_gt).item(),
                "train/rsm_precision": precision_from_conf_matrix(full_conf_matrix),
                "train/rsm_recall": recall_from_conf_matrix(full_conf_matrix),
                # "train/rsm_precision/75": precision_from_conf_matrix(full_conf_matrix_75),
                # "train/rsm_recall/75": recall_from_conf_matrix(full_conf_matrix_75),
                # "train/rsm_precision/90": precision_from_conf_matrix(full_conf_matrix_90),
                # "train/rsm_recall/90": recall_from_conf_matrix(full_conf_matrix_90),
                # "train/rsm_precision/98": precision_from_conf_matrix(full_conf_matrix_98),
                # "train/rsm_recall/98": recall_from_conf_matrix(full_conf_matrix_98),
                "train/denoised": wandb.Image(denoised_reassembled[slice_to_log]),
                "train/tomo": wandb.Image(test_entire_tomo[slice_to_log]),
                "membrain/macro_dice": scoring_fn(test_membrain_pred[:, :, :test_entire_gt.shape[2], :test_entire_gt.shape[3], :test_entire_gt.shape[4]], test_entire_gt).item(),
                "train/seg_overall_gt": wandb.Image(FT.to_pil_image(normalize_min_max(test_entire_gt[0, :].sum(dim=1)).to(torch.uint8), mode="L")),
                "train/seg_overall": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary[0].sum(dim=1)).to(torch.uint8), mode="L")),
                # "train/seg_overall/75": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary_75[0].sum(dim=1)).to(torch.uint8), mode="L")),
                # "train/seg_overall/90": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary_90[0].sum(dim=1)).to(torch.uint8), mode="L")),
                # "train/seg_overall/98": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred_binary_98[0].sum(dim=1)).to(torch.uint8), mode="L")),
                "train/seg_overall/proba": wandb.Image(FT.to_pil_image(normalize_min_max(reassembled_pred.sum(dim=0)), mode="L")),
                "membrain/seg_overall": wandb.Image(FT.to_pil_image(normalize_min_max(test_membrain_pred[0, :].sum(dim=1)).to(torch.uint8), mode="L")),
            })

        current_pred = reassembled_pred.clone()
        # overall_preds.append(reassembled_pred)
        # overall_denoised_preds.append(denoised_reassembled)

    # overall_pred = torch.stack(overall_preds, dim=0).max(dim=0)[0]
    # overall_pred_binary = (overall_pred[None, None, ...] > 0.5).to(torch.uint8)

    # overall_denoised = torch.stack(overall_denoised_preds, dim=0).mean(dim=0)

    # run.log({
    #     "train/seg_overall": wandb.Image(FT.to_pil_image(normalize_min_max(overall_pred_binary[0].sum(dim=1)).to(torch.uint8), mode="L")),
    #     "train/rsm_macro_dice": scoring_fn(overall_pred_binary, test_entire_gt).item(),
    #     "train/denoised": wandb.Image(overall_denoised[slice_to_log])
    # })
    
    run.finish()


In [7]:
masked_inference(test_test_loader, model, 50)

100%|██████████| 147/147 [00:13<00:00, 10.89it/s]
100%|██████████| 147/147 [00:13<00:00, 10.72it/s]
100%|██████████| 147/147 [00:13<00:00, 10.82it/s]
100%|██████████| 147/147 [00:13<00:00, 11.27it/s]
100%|██████████| 147/147 [00:13<00:00, 11.23it/s]
100%|██████████| 147/147 [00:12<00:00, 11.45it/s]
100%|██████████| 147/147 [00:13<00:00, 11.16it/s]
100%|██████████| 147/147 [00:13<00:00, 10.99it/s]
100%|██████████| 147/147 [00:13<00:00, 10.94it/s]
100%|██████████| 147/147 [00:13<00:00, 11.23it/s]
100%|██████████| 147/147 [00:13<00:00, 10.95it/s]
100%|██████████| 147/147 [00:13<00:00, 10.96it/s]
100%|██████████| 147/147 [00:13<00:00, 11.16it/s]
100%|██████████| 147/147 [00:13<00:00, 10.82it/s]
100%|██████████| 147/147 [00:13<00:00, 10.92it/s]
100%|██████████| 147/147 [00:14<00:00, 10.47it/s]
100%|██████████| 147/147 [00:13<00:00, 11.07it/s]
100%|██████████| 147/147 [00:13<00:00, 10.82it/s]
100%|██████████| 147/147 [00:14<00:00, 10.47it/s]
100%|██████████| 147/147 [00:13<00:00, 11.08it/s]


0,1
membrain/macro_dice,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/macro_dice,▆▃▅▂▂▄▅▃▄▅▄▃▃▅▆▆▆▄▅▂▆▃█▄▃▂▃▃▅▄▅▅▆▄▁▅▇▁▃▂
train/rsm_macro_dice,▄████▇▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁
train/rsm_precision,█▇▆▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
train/rsm_recall,▁▃▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇██████████████████

0,1
membrain/macro_dice,1.0
train/macro_dice,0.49245
train/rsm_macro_dice,0.50032
train/rsm_precision,0.38288
train/rsm_recall,0.72168
