In [1]:
import io
import torch
import torch.nn as nn
import PIL.Image
import einops
import matplotlib.pyplot as plt
import numpy as np
import datasets
import math
from IPython.display import HTML
from types import SimpleNamespace
from timm.optim import Mars
from fastprogress import progress_bar, master_bar
from torchvision.transforms.v2 import ToPILImage, PILToTensor, CenterCrop, RandomCrop
from codec import AutoEncoderND
from tft.wavelet import WPT3D, IWPT3D
from tft.utils import compand, decompand
from tft.transforms import RandomCrop3D

In [None]:
# MedMNIST-3D subsets (the `_64` variants under the danjacobellis namespace);
# every subset is concatenated into a single dataset per split.
medmnist_types = ['organ', 'adrenal', 'fracture', 'nodule', 'synapse', 'vessel']


def load_medmnist3d(split, kinds=None):
    """Load and concatenate every MedMNIST-3D subset for one split.

    Args:
        split: split name forwarded to ``datasets.load_dataset``
            (e.g. ``'train'`` or ``'validation'``).
        kinds: subset name prefixes to load; defaults to ``medmnist_types``.

    Returns:
        One ``datasets.Dataset`` with the subsets concatenated in ``kinds`` order.
    """
    if kinds is None:
        kinds = medmnist_types
    # `kind` instead of `type` so the builtin name is not reused.
    return datasets.concatenate_datasets([
        datasets.load_dataset(f"danjacobellis/{kind}mnist3d_64", split=split)
        for kind in kinds
    ])


dataset_train = load_medmnist3d('train')
dataset_valid = load_medmnist3d('validation')

In [None]:
device = "cuda"
config = SimpleNamespace()

# Model architecture
config.F = 8                  # passed to the model as J = log2(F); must be a power of two
config.latent_dim = 32
config.input_channels = 1
config.encoder_depth = 6      # passed to the model as num_res_blocks

# Rate-distortion trade-off: training loss = log10(MSE) + λ * rate
config.λ = 1e-1

# Learning-rate schedule parameters (consumed by rc_sched)
config.max_lr = 1e-3
config.min_lr = config.max_lr / 1e3
config.lr_pow = 6

# Data and training loop
config.vol_size = 56          # crop size handed to RandomCrop3D
config.batch_size = 64
config.num_workers = 32
config.epochs = 96
# One optimizer step per full batch (the training loader uses drop_last=True).
config.total_steps = config.epochs * (dataset_train.num_rows // config.batch_size)
config.checkpoint = None      # optional path to a checkpoint holding a 'state_dict' entry

# Seed the torch and numpy RNGs so model initialisation (and the random crops,
# if RandomCrop3D draws from these generators) is repeatable across runs.
# GPU kernels may still introduce run-to-run nondeterminism.
config.seed = 0
torch.manual_seed(config.seed)
np.random.seed(config.seed)

In [None]:
# The model takes J = log2(F). int(np.log2(F)) silently truncates when F is not
# a power of two, so reject that case explicitly.
log2_F = int(np.log2(config.F))
assert 2 ** log2_F == config.F, f"config.F must be a power of two, got {config.F}"

model = AutoEncoderND(
    dim=3,
    input_channels=config.input_channels,
    J=log2_F,
    latent_dim=config.latent_dim,
    num_res_blocks=config.encoder_depth,
).to(device)

if config.checkpoint:
    # weights_only=False unpickles arbitrary Python objects and can execute code:
    # only point config.checkpoint at files you trust. map_location places the
    # tensors on `device` regardless of where the checkpoint was saved.
    # Only the model weights are restored; the optimizer and LR schedule below
    # start from scratch.
    checkpoint = torch.load(config.checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['state_dict'])

n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.2f} M parameters")

# Base lr is 1.0 on purpose: the LambdaLR schedule multiplies it by
# rc_sched(...), which already returns the absolute learning rate.
optimizer = Mars(model.parameters(), lr=1.0, caution=True)

def rc_sched(i_step, config):
    """Raised-cosine learning rate for optimizer step ``i_step``.

    With progress t = i_step / config.total_steps, the rate is
    min_lr + (max_lr - min_lr) * (1 - cos(pi * t) ** (2 * lr_pow)).
    It starts at config.min_lr (t = 0), peaks at config.max_lr (t = 0.5) and
    returns to config.min_lr at t = 1. A larger lr_pow keeps the rate close to
    max_lr over a wider middle span.

    Returns the absolute learning rate, not a multiplier.
    """
    progress = i_step / config.total_steps
    envelope = 1 - np.cos(np.pi * progress) ** (2 * config.lr_pow)
    lr_span = config.max_lr - config.min_lr
    return lr_span * envelope + config.min_lr

# rc_sched yields an absolute learning rate; since the optimizer's base lr is
# 1.0, LambdaLR's multiplicative factor is the learning rate itself.
schedule = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: rc_sched(step, config)
)

In [None]:
# Random 3D crop shared by collate_fn, applied independently to each sample's
# (1, 1, ...) volume. Built from config.vol_size; presumably the crop edge
# length in voxels — TODO confirm against tft.transforms.RandomCrop3D.
rand_crop = RandomCrop3D(config.vol_size)
def pil_to_grid3d(img):
    """Unpack a tiled 2D image into a stack of 64 slices.

    ``np.array(img)`` must have shape (4*y, 4*z, 4): a 4 x 4 grid of tiles of
    size (y, z), each with 4 channels. Every (tile row, tile column, channel)
    triple becomes one slice, ordered in that nesting.

    Args:
        img: a PIL image (or array-like) with the layout above.

    Returns:
        A torch tensor of shape (64, y, z) with the dtype of ``np.array(img)``.
    """
    # np.array already makes a private, writable copy of the pixels;
    # from_numpy shares it instead of copying a second time as torch.tensor would.
    x = torch.from_numpy(np.array(img))
    return einops.rearrange(x, '(a y) (b z) c -> (a b c) y z', a=4, b=4, c=4)
def collate_fn(batch):
    """Turn a list of dataset rows into one float batch scaled to [-1, 1].

    Each row's ``'image'`` is unpacked with pil_to_grid3d, given two leading
    singleton axes, and randomly cropped by rand_crop. The crops are
    concatenated along dim 0, giving a (len(batch), 1, ...) tensor. The result
    is mapped affinely by x / 127.5 - 1, which sends uint8 values 0..255 to [-1, 1].
    """
    crops = []
    for sample in batch:
        volume = pil_to_grid3d(sample['image'])[None, None]
        crops.append(rand_crop(volume))
    stacked = torch.cat(crops).to(torch.float)
    return stacked / 127.5 - 1.0

In [None]:
# Per-step training logs (learning_rates[0] is the LR before the first step).
learning_rates = [optimizer.param_groups[0]['lr']]
losses = []        # per-batch log10(MSE) distortion term
rate_losses = []   # per-batch rate term returned by the model

# collate_fn scales inputs to [-1, 1], so the peak-to-peak range is 2 and
# PSNR = 10*log10(2**2 / MSE) = 20*log10(2) - 10*log10(MSE).
psnr_offset_db = 20 * math.log10(2.0)

# Build the loader once: shuffle=True still reshuffles on every epoch, and
# persistent workers avoid re-spawning num_workers processes each epoch.
dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    persistent_workers=config.num_workers > 0,
    drop_last=True,
    shuffle=True,
    collate_fn=collate_fn,
)
# rc_sched's progress must stay within [0, 1], so the number of optimizer
# steps has to match config.total_steps exactly.
assert len(dataloader_train) * config.epochs == config.total_steps, (
    "config.total_steps does not match the number of training steps"
)

mb = master_bar(range(config.epochs))
for i_epoch in mb:
    model.train()
    pb = progress_bar(dataloader_train, parent=mb)
    for i_batch, x in enumerate(pb):
        x = x.to(device)

        x_hat, rate = model(x)
        log_mse = torch.nn.functional.mse_loss(x_hat, x).log10()
        # Out-of-place sum, so no tensor in the autograd graph is modified in place.
        loss = log_mse + config.λ * rate

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        schedule.step()

        losses.append(log_mse.item())
        rate_losses.append(rate.item())
        learning_rates.append(optimizer.param_groups[0]['lr'])

        pb.comment = f"PSNR: {psnr_offset_db - 10 * losses[-1]:.3g}, R: {rate_losses[-1]:.2g}, LR: {learning_rates[-1]:.2g}"

In [1]:
# Disabled debugging cell: interactive slice viewer for one collated training
# crop (values in [-1, 1]). Uncomment to inspect the output of collate_fn.
# batch = dataset_train.select(range(8))
# x = collate_fn(batch)
# from IPython.display import display, clear_output
# from ipywidgets import interact
# volume = x[3,0].numpy()
# def show_slice(index):
#     clear_output(wait=True)  # Clear previous output
#     slice_data = volume[index, :, :]
#     img = ToPILImage()(slice_data/2+0.5).resize((256,256),resample=PIL.Image.Resampling.LANCZOS)
#     display(img)
# interact(show_slice, index=(0, volume.shape[0] - 1, 1));