In [1]:
import torch
import torch.nn as nn
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
torch.set_float32_matmul_precision("medium")
from pathlib import Path
import numpy as np
from utils.utils import set_seed
from utils.ddw_subtomos import reassemble_subtomos
from torchmetrics.classification import BinaryConfusionMatrix, PrecisionRecallCurve
import wandb
import torchvision.transforms.functional as FT
from datasets import build_dataset
from models.denoiseg import Denoiseg
from hydra import initialize, compose
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import mrcfile
import uuid
from tqdm import tqdm
from torch_receptive_field import receptive_field, receptive_field_for_unit

In [2]:
with initialize(version_base=None, config_path="configs/"):
    cfg = compose(config_name='config.yaml')

# Allow new keys, then merge the `method` sub-config into the top level so its keys
# (e.g. model/dataset settings) can be read directly from `cfg`.
OmegaConf.set_struct(cfg, False)
cfg = OmegaConf.merge(cfg, cfg.method)

set_seed(cfg.seed)

criterion = DiceLoss(sigmoid=True)  # applies sigmoid itself, so it takes logits
scoring_fn = DiceMetric(reduction='mean')
cross_entropy = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss(reduction='none')  # per-element; masked_mse_loss applies the mask
# mse_loss = nn.L1Loss(reduction='none')
device = torch.device("cuda:2")
device_1 = torch.device("cuda:1")
bcm = BinaryConfusionMatrix().to(device)
# pr_curve = PrecisionRecallCurve(task="binary")

dataset = build_dataset(cfg, test=True)

# Hold out 2 sub-tomograms (seeded split) for the self-supervised validation loss.
test_dataset, test_val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 2, 2], generator=torch.Generator().manual_seed(cfg.seed))
print(len(test_dataset), len(test_val_dataset))
# test_dataset = dataset
# test_val_dataset = build_dataset(cfg, val=True)

# TTT training loader. It must use `test_dataset` (not the full `dataset`) so the two
# held-out sub-tomograms are never trained on and val/mse_loss stays a held-out measure.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    # num_workers=cfg.load_num_workers,
    num_workers=2,
    batch_size=16,
    shuffle=True,
    # persistent_workers=cfg.persistent_workers,
)
test_val_loader = torch.utils.data.DataLoader(
    test_val_dataset,
    num_workers=8,
    batch_size=2,
    # persistent_workers=cfg.persistent_workers,
)
# Full-tomogram evaluation loader: every sub-tomogram, fixed order, one per batch
# (macro_test_loop reads batch["start_coord"][0] when reassembling).
test_test_loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=8,
    batch_size=1,
    shuffle=False,
    # persistent_workers=cfg.persistent_workers,
)
test_entire_gt = mrcfile.read(cfg.test_entire_gt)
test_entire_tomo = mrcfile.read(cfg.test_entire_tomo)

# Whole-tomogram volumes; GT and MemBrain prediction get leading (batch, channel) dims.
test_entire_gt = torch.Tensor(test_entire_gt)[None, None, ...]
test_entire_tomo = torch.Tensor(test_entire_tomo)

test_membrain_pred = mrcfile.read(cfg.test_membrain_pred)
test_membrain_pred = torch.Tensor(test_membrain_pred)[None, None, ...]

model = Denoiseg.load_from_checkpoint(cfg.ckpt_path, map_location=device, config=cfg)

Seed set to 42


262 2


In [3]:
# Sanity check of the confusion-matrix metric on random boolean predictions vs. targets.
# torchmetrics lays the result out as [[TN, FP], [FN, TP]] (rows = target, cols = prediction).
a, b = (torch.rand(1, 100) > 0.5 for _ in range(2))

c = bcm(a.to(device), b.to(device))
c = c.detach().cpu()
c

tensor([[25, 16],
        [24, 35]])

In [4]:
def masked_mse_loss(x_hat, x, mask):
    """Average of the element-wise `mse_loss(x_hat, x)` over the positions selected by `mask`.

    The per-element loss (module-level `mse_loss`, reduction='none') is weighted by `mask`
    and divided by ``mask.sum()``, so only masked positions contribute. Note that an
    all-zero mask gives 0 / 0 (NaN).
    """
    weighted = mse_loss(x_hat, x) * mask
    normalizer = mask.sum()
    return weighted.sum() / normalizer

def dice_from_conf_matrix(conf_matrix):
    """Dice score 2*TP / (2*TP + FP + FN) from a 2x2 confusion matrix.

    Indexing follows [[TN, FP], [FN, TP]] (rows = target, columns = prediction).
    """
    true_pos = conf_matrix[1, 1]
    false_pos = conf_matrix[0, 1]
    false_neg = conf_matrix[1, 0]
    return 2 * true_pos / (2 * true_pos + false_pos + false_neg)

def precision_from_conf_matrix(conf_matrix: torch.Tensor):
    """Precision TP / (TP + FP) from a [[TN, FP], [FN, TP]] confusion matrix."""
    true_pos, false_pos = conf_matrix[1, 1], conf_matrix[0, 1]
    return true_pos / (true_pos + false_pos)

def recall_from_conf_matrix(conf_matrix: torch.Tensor):
    """Recall TP / (TP + FN) from a [[TN, FP], [FN, TP]] confusion matrix."""
    true_pos, false_neg = conf_matrix[1, 1], conf_matrix[1, 0]
    return true_pos / (true_pos + false_neg)

def get_grad_norm(model):
    """Global L2 norm of all parameter gradients of `model`, as a Python float.

    Parameters whose ``.grad`` is None are skipped instead of raising. This happens for
    frozen parameters, or for parameters the last backward did not reach (e.g. an output
    head that did not contribute to the loss after ``zero_grad()`` set grads to None).
    Returns 0.0 when no parameter has a gradient.
    """
    squared_total = 0.0
    for param in model.parameters():
        if param.grad is None:
            continue
        squared_total += param.grad.detach().norm(2).item() ** 2
    return squared_total ** 0.5

def process_id(id):
    """Parse the bracketed integer list at the end of a sub-tomogram id string.

    The id is expected to end in ``_[a b c]`` (e.g. ``"tomo_[0 80 160]"``). The text after
    the last underscore has its brackets stripped, and the whitespace-separated integers
    inside are returned as a list.

    Any run of whitespace is accepted as a separator. This means numpy-style padded reprs
    such as ``"[  0  80 160]"``, or an empty ``"[]"``, parse correctly; splitting on a single
    space produced empty tokens, and ``int('')`` then raised ValueError.
    """
    coords_text = id.split("_")[-1].strip()[1:-1]
    return [int(token) for token in coords_text.split()]

def normalize_image(image, max_intensity):
    """Linearly rescale `image` so that `max_intensity` maps to 255.

    Computed as ``image * 255 / max_intensity`` (multiply first, then divide).
    """
    scaled = image * 255.0
    return scaled / max_intensity

def normalize_min_max(image):
    """Linearly rescale `image` so its minimum maps to 0 and its maximum to 255.

    Works for torch tensors and numpy arrays; the result is floating point.

    A constant image (max == min) is mapped to all zeros. Without this guard, e.g. the
    projection of an all-empty binary mask would compute 0 / 0, and the resulting NaNs
    turn into garbage when cast to uint8 for logging.
    """
    low = image.min()
    span = image.max() - low
    if span == 0:
        # `* 0.0` keeps the same floating dtype the division would have produced.
        return (image - low) * 0.0
    return (image - low) / span * 255.0

In [5]:
# Probability thresholds applied to the reassembled segmentation, keyed by the suffix that is
# appended to the wandb metric/image names (0.5 keeps the original un-suffixed keys).
RSM_THRESHOLDS = {"": 0.5, "/75": 0.75, "/90": 0.9, "/98": 0.98}


def _projection_image(volume):
    """Sum `volume` over dim 1, min-max normalize to 0..255, and wrap as a grayscale wandb.Image.

    Assumes `volume` is (1, D, H, W)-like, i.e. a 5-D volume with the batch dim removed,
    so the result is a (1, H, W) projection. TODO: confirm for non-3-D tomograms.
    """
    projection = normalize_min_max(volume.sum(dim=1)).to(torch.uint8)
    return wandb.Image(FT.to_pil_image(projection, mode="L"))


def macro_test_loop(test_test_loader, model, run, epoch):
    """Evaluate `model` on every sub-tomogram, reassemble the whole tomogram, and log to wandb.

    Args:
        test_test_loader: loader over all sub-tomograms with batch_size=1. Each batch provides
            "unmasked_image", "label" and "start_coord".
        model: Lightning module whose ``model.model(x)`` returns ``(seg_logits, denoised)``.
        run: active wandb run to log into.
        epoch: value logged under the "epoch" key.

    Logged metrics:
        - patch-level Dice from the summed confusion matrix;
        - Dice / precision / recall and projection images of the reassembled prediction,
          for each threshold in ``RSM_THRESHOLDS``;
        - the MemBrain baseline Dice and projection;
        - one slice of the denoised and raw tomogram.

    Uses the module-level ``device``, ``bcm``, ``scoring_fn``, ``test_entire_gt``,
    ``test_entire_tomo`` and ``test_membrain_pred``. Leaves the model in eval mode.
    """
    conf_matrix = torch.zeros(2, 2)
    preds = []
    denoised_preds = []
    subtomo_start_coords = []
    with torch.no_grad():
        model.eval()
        for batch in tqdm(test_test_loader):
            x_unmasked = batch["unmasked_image"].to(device)
            y_out = batch["label"].to(device)
            y_hat, x_hat = model.model(x_unmasked)

            # Patch-level confusion matrix, accumulated over the whole loader.
            conf_matrix += bcm(y_hat.sigmoid(), y_out).detach().cpu()
            # squeeze() drops the batch dim; this relies on batch_size=1.
            preds.append(y_hat.squeeze().detach().cpu().sigmoid())
            denoised_preds.append(x_hat.squeeze().detach().cpu())
            subtomo_start_coords.append(batch["start_coord"][0])

        volume_shape = test_entire_gt.shape[2:]
        reassembled_pred = reassemble_subtomos(
            subtomos=preds,
            subtomo_start_coords=subtomo_start_coords,
            subtomo_overlap=80,
            crop_to_size=volume_shape,
        )
        denoised_reassembled = reassemble_subtomos(
            subtomos=denoised_preds,
            subtomo_start_coords=subtomo_start_coords,
            subtomo_overlap=80,
            crop_to_size=volume_shape,
        )

        # Move the (large) GT volume to the device once, not once per threshold.
        gt_on_device = test_entire_gt.to(device)
        # Log the slice along the first spatial axis with the most GT foreground.
        slice_to_log = torch.argmax(test_entire_gt.squeeze().sum(dim=(-2, -1)))

        membrain_cropped = test_membrain_pred[:, :, :test_entire_gt.shape[2], :test_entire_gt.shape[3], :test_entire_gt.shape[4]]
        metrics = {
            "train/macro_dice": dice_from_conf_matrix(conf_matrix),
            "train/denoised": wandb.Image(denoised_reassembled[slice_to_log]),
            "train/tomo": wandb.Image(test_entire_tomo[slice_to_log]),
            "membrain/macro_dice": scoring_fn(membrain_cropped, test_entire_gt).item(),
            "train/seg_overall_gt": _projection_image(test_entire_gt[0, :]),
            "membrain/seg_overall": _projection_image(test_membrain_pred[0, :]),
            "epoch": epoch,
        }

        # One binary volume at a time, instead of keeping all four alive.
        for suffix, threshold in RSM_THRESHOLDS.items():
            pred_binary = (reassembled_pred[None, None, ...] > threshold).to(torch.uint8)
            full_conf_matrix = bcm(pred_binary.to(device), gt_on_device).detach().cpu()
            metrics[f"train/rsm_macro_dice{suffix}"] = scoring_fn(pred_binary, test_entire_gt).item()
            metrics[f"train/rsm_precision{suffix}"] = precision_from_conf_matrix(full_conf_matrix)
            metrics[f"train/rsm_recall{suffix}"] = recall_from_conf_matrix(full_conf_matrix)
            metrics[f"train/seg_overall{suffix}"] = _projection_image(pred_binary[0])

        run.log(metrics)

    # DiceMetric buffers the result of every call. Nothing here aggregates that buffer,
    # so clear it to keep it from growing across epochs.
    scoring_fn.reset()

def ttt_one_tomo(test_loader, model, n_epochs, lr, momentum, test_val_loader=None, test_test_loader=None, base_lr=None):
    """Test-time-train `model` on one tomogram with the masked-voxel MSE objective, logging to wandb.

    Args:
        test_loader: training loader. Batches provide "image" (masked input), "unmasked_image",
            "label" and "mask".
        model: Lightning module whose ``model.model(x)`` returns ``(seg_logits, denoised)``.
        n_epochs: number of passes over ``test_loader``.
        lr: peak learning rate of the cyclic schedule (``max_lr``); also used in the run name.
        momentum: SGD momentum; also the peak of the momentum cycle.
        test_val_loader: optional loader for the held-out masked MSE, evaluated after every step.
        test_test_loader: optional loader for ``macro_test_loop``, run before training and after
            every epoch.
        base_lr: lower bound of the cyclic schedule. Defaults to ``lr / 10``, which for
            lr=3e-3 reproduces the previously hard-coded 3e-4 .. 3e-3 schedule.
    """
    if base_lr is None:
        base_lr = lr / 10

    random_string = uuid.uuid4().hex[:5]
    run = wandb.init(project="cryo-ttt-ttt", name=f"{cfg.exp_name}-lr-{lr:.2E}-{random_string}")
    try:
        print(cfg)
        if test_test_loader is not None:
            macro_test_loop(test_test_loader, model, run, epoch=0)

        optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)
        # CyclicLR overwrites the optimizer lr on construction and, by default, also cycles
        # momentum (0.8..0.9). Derive both ranges from the arguments so `lr`/`momentum` take
        # effect; with lr=3e-3 and momentum=0.9 this matches the old hard-coded schedule.
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer,
            base_lr=base_lr,
            max_lr=lr,
            step_size_up=5,  # iterations to go from base_lr to max_lr
            mode='triangular',
            base_momentum=min(0.8, momentum),
            max_momentum=momentum,
        )

        step = 0
        for epoch in range(1, n_epochs + 1):
            print(f"Epoch: {epoch}")
            for batch in tqdm(test_loader):
                # --- Train step ---
                model.train()
                x_masked = batch["image"].to(device)
                x_unmasked = batch["unmasked_image"].to(device)
                y_out = batch["label"].to(device)
                mask = batch["mask"].to(device)

                optimizer.zero_grad()
                y_hat, x_hat = model.model(x_masked)
                # Only the denoising output feeds the TTT loss (MSE where mask is nonzero).
                loss = masked_mse_loss(x_hat, x_unmasked, mask)
                loss.backward()
                optimizer.step()
                scheduler.step()

                # Segmentation metrics are for monitoring only; keep them out of autograd.
                with torch.no_grad():
                    dice_loss = criterion(y_hat, y_out).item()
                    dice = scoring_fn((y_hat.sigmoid() > 0.5).int(), y_out).mean().item()

                # --- Validation step (held-out masked MSE) ---
                val_losses = []
                if test_val_loader is not None:
                    model.eval()
                    with torch.no_grad():
                        for batch_val in tqdm(test_val_loader):
                            x_masked_val = batch_val["image"].to(device)
                            x_unmasked_val = batch_val["unmasked_image"].to(device)
                            mask_val = batch_val["mask"].to(device)
                            _, x_hat_val = model.model(x_masked_val)
                            val_losses.append(masked_mse_loss(x_hat_val, x_unmasked_val, mask_val).item())

                metrics = {
                    "train/mse_loss": loss.item(),
                    "train/dice_loss": dice_loss,
                    "train/dice": dice,
                    "epoch": epoch,
                    "step": step,
                }
                if val_losses:
                    metrics["val/mse_loss"] = np.array(val_losses).mean()
                run.log(metrics)

                step += 1
            if test_test_loader is not None:
                macro_test_loop(test_test_loader, model, run, epoch=epoch)
    finally:
        # Close the wandb run even on KeyboardInterrupt/errors. An interrupted run was
        # otherwise left open, and the notebook hook then failed with BrokenPipeError.
        run.finish()

In [6]:
# Run test-time training on the loaded tomogram, logging everything to wandb.
ttt_one_tomo(
    test_loader,
    model,
    n_epochs=30,
    lr=0.003,
    momentum=0.9,
    test_val_loader=test_val_loader,
    test_test_loader=test_test_loader,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdijor0310[0m ([33mcryo-diyor[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'exp_name': 'n2v-GN-vpp-def', 'wandb_project_name': 'cryo-ttt-v2', 'ckpt_path': '/workspaces/cryo/cryo-ttt/src/ttt_ckpt/denoiseg-GN-vpp-norm-p160-cb6bb/epoch=157-val/dice_loss=0.52.ckpt', 'seed': 42, 'debug': False, 'devices': [2, 3], 'profiler': None, 'strategy': 'ddp', 'shuffle': False, 'train_load_num_workers': 8, 'val_load_num_workers': 8, 'pin_memory': False, 'persistent_workers': False, 'accumulate_grad_batches': 1, 'gradient_clip_val': None, 'check_val_every_n_epochs': 1, 'log_every_n_steps': 1, 'num_sanity_val_steps': 0, 'enable_progress_bar': True, 'method': {'model_name': 'denoiseg', 'train_batch_size': 32, 'eval_batch_size': 32, 'test_batch_size': 1, 'dataset': 'denoiseg_v2', 'mask_ratio': 0.1, 'window_size': 5, 'depth': 2, 'initial_features': 4, 'encoder_dropout': 0.0, 'decoder_dropout': 0.1, 'BN': 'group_norm', 'elu': False, 'lambda_ce': 0.2, 'learning_rate': 0.003, 'gamma_decay': 0.99, 'max_epochs': 600, 'train_data_root_dir': '/media/ssd3/diyor/patch-160-overlap-80', 't

100%|██████████| 264/264 [00:17<00:00, 14.93it/s]


Epoch: 1


100%|██████████| 1/1 [00:02<00:00,  2.70s/it]
100%|██████████| 1/1 [00:02<00:00,  2.79s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.71s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.74s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.74s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.73s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.77s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.78s/it]]
100%|██████████| 1/1 [00:02<00:00,  2.77s/it]t]
100%|██████████| 1/1 [00:02<00:00,  2.79s/it]t]
100%|██████████| 1/1 [00:02<00:00,  2.83s/it]t]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]t]
100%|██████████| 1/1 [00:02<00:00,  2.81s/it]t]
100%|██████████| 1/1 [00:02<00:00,  2.76s/it]t]
100%|██████████| 1/1 [00:02<00:00,  2.73s/it]t]
100%|██████████| 17/17 [01:12<00:00,  4.27s/it]
100%|██████████| 264/264 [00:21<00:00, 12.13it/s]


Epoch: 2


100%|██████████| 1/1 [00:03<00:00,  3.75s/it]
100%|██████████| 1/1 [00:03<00:00,  3.75s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.79s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.82s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.82s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.89s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.82s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.80s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.80s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.78s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.83s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.80s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.80s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.83s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.85s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.80s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.75s/it]t]
100%|██████████| 17/17 [01:30<00:00,  5.33s/it]
100%|██████████| 264/264 [00:23<00:00, 11.34it/s]


Epoch: 3


100%|██████████| 1/1 [00:04<00:00,  4.00s/it]
100%|██████████| 1/1 [00:04<00:00,  4.03s/it]]
100%|██████████| 1/1 [00:04<00:00,  4.03s/it]]
100%|██████████| 1/1 [00:04<00:00,  4.04s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.70s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.68s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.71s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.70s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.63s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.67s/it]]
100%|██████████| 1/1 [00:01<00:00,  1.69s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.69s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.66s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.68s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.66s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.69s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.68s/it]t]
100%|██████████| 17/17 [01:04<00:00,  3.81s/it]
100%|██████████| 264/264 [00:16<00:00, 15.56it/s]


Epoch: 4


100%|██████████| 1/1 [00:03<00:00,  3.47s/it]
100%|██████████| 1/1 [00:03<00:00,  3.44s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.53s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.56s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.55s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.53s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.62s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.53s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.53s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.54s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.54s/it]t]
100%|██████████| 17/17 [01:26<00:00,  5.06s/it]
100%|██████████| 264/264 [00:22<00:00, 11.69it/s]


Epoch: 5


100%|██████████| 1/1 [00:03<00:00,  3.48s/it]
100%|██████████| 1/1 [00:03<00:00,  3.49s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.56s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.57s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.54s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.58s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.57s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.57s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.58s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.55s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.56s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.66s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.57s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.60s/it]t]
100%|██████████| 17/17 [01:26<00:00,  5.09s/it]
100%|██████████| 264/264 [00:21<00:00, 12.14it/s]


Epoch: 6


100%|██████████| 1/1 [00:03<00:00,  3.40s/it]
100%|██████████| 1/1 [00:03<00:00,  3.44s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.44s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.41s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.41s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.40s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.43s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.46s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.42s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.47s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.41s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.45s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.44s/it]t]
100%|██████████| 17/17 [01:24<00:00,  4.96s/it]
100%|██████████| 264/264 [00:22<00:00, 11.66it/s]


Epoch: 7


100%|██████████| 1/1 [00:03<00:00,  3.71s/it]
100%|██████████| 1/1 [00:03<00:00,  3.73s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.64s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.74s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.71s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.71s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.71s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.74s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.71s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.70s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.70s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.76s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.72s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.74s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.73s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.70s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.78s/it]t]
100%|██████████| 17/17 [01:29<00:00,  5.26s/it]
100%|██████████| 264/264 [00:22<00:00, 11.72it/s]


Epoch: 8


100%|██████████| 1/1 [00:03<00:00,  3.85s/it]
100%|██████████| 1/1 [00:03<00:00,  3.86s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.93s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.92s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.87s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.89s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.90s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.90s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.89s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.90s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.92s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.89s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.90s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.92s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.88s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.76s/it]t]
100%|██████████| 1/1 [00:01<00:00,  1.70s/it]t]
100%|██████████| 17/17 [01:28<00:00,  5.22s/it]
100%|██████████| 264/264 [00:16<00:00, 15.70it/s]


Epoch: 9


100%|██████████| 1/1 [00:03<00:00,  3.48s/it]
100%|██████████| 1/1 [00:03<00:00,  3.47s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.47s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.56s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.57s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.53s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.64s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]t]
100%|██████████| 1/1 [00:03<00:00,  3.54s/it]t]
100%|██████████| 17/17 [01:25<00:00,  5.05s/it]
100%|██████████| 264/264 [00:22<00:00, 11.78it/s]


Epoch: 10


100%|██████████| 1/1 [00:03<00:00,  3.49s/it]
100%|██████████| 1/1 [00:03<00:00,  3.48s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.49s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.50s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.52s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]]
100%|██████████| 1/1 [00:03<00:00,  3.51s/it]]
 47%|████▋     | 8/17 [00:43<00:48,  5.41s/it]


KeyboardInterrupt: 

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f452f52ada0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f452f52aec0, execution_count=6 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7f452f52aa10, raw_cell="ttt_one_tomo(test_loader, model, 30, 0.003, 0.9, t.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://dev-container%2B7b22686f737450617468223a222f686f6d652f6469796f722f6372796f222c226c6f63616c446f636b6572223a66616c73652c22636f6e66696746696c65223a7b22246d6964223a312c2270617468223a222f686f6d652f6469796f722f6372796f2f2e646576636f6e7461696e65722f646576636f6e7461696e65722e6a736f6e222c22736368656d65223a227673636f64652d66696c65486f7374227d7d@ssh-remote%2Bzion/workspaces/cryo/cryo-ttt/src/denoiseg_ttt_v3.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe