In [1]:
import torch
import torch.nn as nn
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from copy import deepcopy
torch.set_float32_matmul_precision("medium")
from pathlib import Path
import numpy as np
from utils.utils import set_seed
from utils.subtomos import reassemble_subtomograms, reassemble_subtomograms_v2
from utils.connected_components import connected_components
from torchmetrics.classification import BinaryConfusionMatrix
import wandb
import matplotlib.pyplot as plt
import torchvision.transforms.functional as FT
from datasets import build_dataset
# from models import build_model
from models.denoiseg import Denoiseg
from hydra import initialize, compose
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import mrcfile
import uuid

In [5]:
with initialize(version_base=None, config_path="configs/"):
    cfg = compose(config_name='config.yaml')

OmegaConf.set_struct(cfg, False)
cfg = OmegaConf.merge(cfg, cfg.method)

set_seed(cfg.seed)

criterion = DiceLoss(sigmoid=True)
scoring_fn = DiceMetric(reduction='mean')
cross_entropy = nn.CrossEntropyLoss()
mse_loss = nn.MSELoss(reduction='none')
# mse_loss = nn.L1Loss(reduction='none')
device = torch.device("cuda:3")
device_1 = torch.device("cuda:1")
bcm = BinaryConfusionMatrix().to(device)

dataset = build_dataset(cfg, test=True)

test_dataset, test_val_dataset = torch.utils.data.random_split(dataset, [len(dataset) - 2, 2], generator=torch.Generator().manual_seed(cfg.seed))
print(len(test_dataset), len(test_val_dataset))
# test_dataset = dataset
# test_val_dataset = build_dataset(cfg, val=True)
test_loader = torch.utils.data.DataLoader(
    dataset,
    # num_workers=cfg.load_num_workers,
    num_workers=0,
    batch_size=5,
    shuffle=True,
    # persistent_workers=cfg.persistent_workers,
)            
test_val_loader = torch.utils.data.DataLoader(
    test_val_dataset,
    num_workers=cfg.load_num_workers,
    batch_size=2,
    persistent_workers=cfg.persistent_workers,
)
test_test_loader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.load_num_workers,
    batch_size=1,
    shuffle=False,
    persistent_workers=cfg.persistent_workers,
)
test_entire_gt = mrcfile.read(cfg.test_entire_gt)
test_entire_gt = torch.Tensor(test_entire_gt)[None, None, ...]
test_membrain_pred = mrcfile.read(cfg.test_membrain_pred)
test_membrain_pred = torch.Tensor(test_membrain_pred)[None, None, ...]
# print(test_entire_gt.shape)

model = Denoiseg.load_from_checkpoint(cfg.ckpt_path, map_location=device, config=cfg)

Seed set to 42


36 2


In [6]:
import os
os.environ["OMP_NUM_THREADS"] = "64"
os.environ["MKL_NUM_THREADS"] = "64"

for batch in test_loader:
    print(batch["mask"].shape)

torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([3, 1, 256, 256, 256])


In [7]:
del os.environ["OMP_NUM_THREADS"]
del os.environ["MKL_NUM_THREADS"]
for batch in test_loader:
    print(batch["mask"].shape)


torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([5, 1, 256, 256, 256])
torch.Size([3, 1, 256, 256, 256])


In [4]:
def masked_mse_loss(x_hat, x, mask):
    loss = mse_loss(x_hat, x) * mask
    return loss.sum() / mask.sum()

def dice_from_conf_matrix(conf_matrix):
    return 2 * conf_matrix[1, 1] / (2 * conf_matrix[1, 1] + conf_matrix[0, 1] + conf_matrix[1, 0])

def get_grad_norm(model):
    total_norm = 0
    for p in model.parameters():
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
    total_norm = total_norm ** (1. / 2)
    return total_norm

def process_id(id):
    id = id.split("_")[-1]
    id = id[1:-1].split(" ")
    id = [int(elem) for elem in id]
    return id

def normalize_image(image, max_intensity):

    return image * 255.0 / max_intensity

In [5]:
def macro_test_loop(test_test_loader, model, run, epoch):
    # Test
    conf_matrix = torch.zeros(2, 2)
    preds = []
    ids = []
    with torch.no_grad():
        model.eval()
        for i, batch in enumerate(test_test_loader):
            x_unmasked = batch["unmasked_image"].to(device)
            y_out, id = batch["label"].to(device), batch["id"]
            y_hat, x_hat = model.model(x_unmasked)

            dice_loss = criterion(y_hat, y_out).detach().item()
            dice = scoring_fn((y_hat.sigmoid() > 0.5).int(), y_out).detach().item()
            conf_matrix += bcm(y_hat.sigmoid(), y_out).detach().cpu()
            
            slice_to_log = torch.argmax(y_out[0, 0].sum(dim=(-1, -2)))

            # id = id[0].replace(" ", "_")
            preds.append(y_hat.squeeze().detach().cpu().numpy())
            ids.append(process_id(id[0]))
            run.log({
                f"train/{i}/subtomo": wandb.Image(x_unmasked[0, :, slice_to_log]),
                f"train/{i}/denoised": wandb.Image(x_hat[0, :, slice_to_log]),
                f"train/{i}/seg_gt": wandb.Image(FT.to_pil_image((y_out[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                f"train/{i}/seg": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                f"train/{i}/dice": dice,
                # f"train/{i}/dice_loss": dice_loss,
                "epoch": epoch,
            })
        
        reassembled_pred = reassemble_subtomograms_v2(preds, ids, test_entire_gt.shape[2:], 256, 12)
        reassembled_pred = (torch.Tensor(reassembled_pred).sigmoid() > 0.5).to(torch.uint8)[None, None, ...]
        # postprocessed_pred = torch.Tensor(connected_components(reassembled_pred.squeeze().numpy(), 50)).to(torch.bool).to(torch.uint8)
        # print(test_entire_gt[0].sum(dim=1).dim())
        max_intensity = np.max(test_entire_gt[0].sum(dim=1).numpy())
        run.log({
            "train/macro_dice": dice_from_conf_matrix(conf_matrix),
            "train/rsm_macro_dice": scoring_fn(reassembled_pred, test_entire_gt).item(),
            "membrain/macro_dice": scoring_fn(test_membrain_pred[:, :, :test_entire_gt.shape[2], :test_entire_gt.shape[3], :test_entire_gt.shape[4]], test_entire_gt).item(),
            "train/seg_overall_gt": wandb.Image(FT.to_pil_image(normalize_image(test_entire_gt[0, :].sum(dim=1), max_intensity).to(torch.uint8), mode="L")),
            "train/seg_overall": wandb.Image(FT.to_pil_image(normalize_image(reassembled_pred[0].sum(dim=1), max_intensity).to(torch.uint8), mode="L")),
            "membrain/seg_overall": wandb.Image(FT.to_pil_image(normalize_image(test_membrain_pred[0, :].sum(dim=1), max_intensity).to(torch.uint8), mode="L")),
            # "train/seg_overall_postprocessed": wandb.Image(FT.to_pil_image(postprocessed_pred[None, ...].sum(dim=1).to(torch.uint8), mode="L")),
            "epoch": epoch,
        })

def ttt_one_tomo(test_loader, model, n_epochs, lr, momentum, test_val_loader=None, test_test_loader=None):
    
    random_uuid = uuid.uuid4()
    random_string = str(random_uuid).replace("-", "")[:5]

    run = wandb.init(project="cryo-ttt-ttt", name=f"{cfg.exp_name}-lr-{lr:.2E}-{random_string}")
    print(cfg)
    macro_test_loop(test_test_loader, model, run, epoch=0)
    model.train()

    optimizer = torch.optim.SGD(model.parameters(), momentum=momentum, lr=lr)    
    
    step = 0
    for epoch in range(1, n_epochs + 1):
        print(f"Epoch: {epoch}")
        for batch in test_loader:
            # -------------
            # Train step
            model.train()

            x_masked, x_unmasked = batch["image"].to(device), batch["unmasked_image"].to(device)
            y_out, id = batch["label"].to(device), batch["id"]
            mask = batch["mask"].to(device)

            optimizer.zero_grad()

            # y_hat, x_hat = model.model(x_masked)
            # ### Normalize
            # x_masked, mean, std = model.normalize(x_masked)
            # ###

            y_hat, x_hat = model.model(x_masked)

            # ### Denormalize
            # x_hat = model.denormalize(x_hat, mean, std)
            # ###


            loss = masked_mse_loss(x_hat, x_unmasked, mask)

            loss.backward()
            optimizer.step()

            train_grad_norm = get_grad_norm(model.model)
            dice_loss = criterion(y_hat, y_out).detach().item()
            dice = scoring_fn((y_hat.sigmoid() > 0.5).int(), y_out).detach().mean().item()

            val_losses = []
            for batch_val in test_val_loader:
                # --------------
                # Validation step
                model.eval()

                x_masked_val, x_unmasked_val = batch_val["image"].to(device), batch_val["unmasked_image"].to(device)
                y_out_val, id_val = batch_val["label"].to(device), batch_val["id"]
                mask_val = batch_val["mask"].to(device)

                _, x_hat_val = model.model(x_masked_val)

                val_mse_loss = masked_mse_loss(x_hat_val, x_unmasked_val, mask_val)
                val_losses.append(val_mse_loss.item())

                slice_to_log = torch.argmax(y_out[0, 0].sum(dim=(-1, -2)))
            
            # Log results after step i
            run.log({
                "train/mse_loss": loss.item(),
                "train/dice_loss": dice_loss,
                "train/dice": dice,
                "epoch": epoch,
                "step": step,
                # "train/subtomo": wandb.Image(x_unmasked[0, :, slice_to_log]),
                # "train/denoised": wandb.Image(x_hat[0, :, slice_to_log]),
                # "train/seg_gt": wandb.Image(FT.to_pil_image((y_out[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                # "train/seg": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                "train/mask": wandb.Image(FT.to_pil_image((mask[0, :, slice_to_log] * 255).to(torch.uint8), mode="L")),
                "val/mse_loss": np.array(val_losses).mean(),
                # "train/grad_norm": train_grad_norm,
                # "train/intensity_var": x_hat.var(),
                # "train/seg_overall_gt": wandb.Image(FT.to_pil_image(y_out[0, :].sum(dim=1).to(torch.uint8), mode="L")),
                # "train/seg_overall": wandb.Image(FT.to_pil_image((y_hat.sigmoid()[0, :] > 0.5).sum(dim=1).to(torch.uint8), mode="L")),
            })

            step += 1
        macro_test_loop(test_test_loader, model, run, epoch=epoch)
        
    run.finish()
    # return y_hat, val_losses, val_dices

In [5]:
ttt_one_tomo(test_loader, model, 30, 0.0005, 0.9, test_val_loader=test_val_loader, test_test_loader=test_test_loader)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mdijor0310[0m ([33mcryo-diyor[0m). Use [1m`wandb login --relogin`[0m to force relogin


{'exp_name': 'def-vpp-001-norm-shuffle', 'ckpt_path': '/workspaces/cryo/cryo-ttt/src/ttt_ckpt/denoiseg-gn-def-norm_data-shuffle-67fe4/epoch=893-val/dice_loss=0.44.ckpt', 'seed': 42, 'debug': False, 'devices': [2], 'profiler': 'simple', 'strategy': 'auto', 'shuffle': True, 'load_num_workers': 32, 'pin_memory': False, 'persistent_workers': True, 'accumulate_grad_batches': 2, 'gradient_clip_val': None, 'log_every_n_steps': 1, 'method': {'model_name': 'denoiseg', 'train_batch_size': 8, 'eval_batch_size': 1, 'test_batch_size': 1, 'dataset': 'denoiseg', 'mask_ratio': 0.1, 'window_size': 5, 'depth': 2, 'initial_features': 4, 'encoder_dropout': 0.0, 'decoder_dropout': 0.1, 'BN': 'group_norm', 'elu': False, 'learning_rate': 0.0005, 'decay_milestones': [], 'decay_gamma': 1.0, 'max_epochs': 1000, 'grad_clip_norm': None, 'train_data_root_dir': '/media/ssd3/diyor/30-01-deepict-norm-256/work/training_data/def_train', 'val_data_root_dir': '/media/ssd3/diyor/30-01-deepict-norm-256/work/training_data/d

RuntimeError: cuDNN error: CUDNN_STATUS_NOT_INITIALIZED

In [None]:
for batch in test_loader:
    x_unmasked = batch["unmasked_image"]

    print(f"mean: {x_unmasked.mean(dim=(-1, -2, -3, -4))}")
    print(f"var: {x_unmasked.var(dim=(-1, -2, -3, -4))}")


mean: tensor([-0.0078,  0.0039,  0.0066,  0.0021, -0.0166])
var: tensor([1.0411, 1.0905, 1.0543, 1.0585, 0.8818])
mean: tensor([ 0.0134, -0.0009, -0.0142,  0.0024,  0.0047])
var: tensor([1.1368, 1.0887, 1.0626, 1.1120, 1.0318])
mean: tensor([ 0.0020,  0.0087,  0.0012,  0.0035, -0.0145])
var: tensor([1.0689, 0.9181, 0.8958, 1.1233, 0.8515])
