In [None]:
# Install PyTorch from the official wheel index for the cu118 (CUDA 11.8) builds.
# NOTE(review): versions are unpinned; pin torch/torchvision/torchaudio for reproducible runs.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
# Remaining dependencies. `torchinfo` provides the `summary` used in the model cell;
# matplotlib drives the training preview.
# NOTE(review): `torch-summary` appears unused next to `torchinfo` in the visible cells, and
#   scipy / opencv-python / SimpleITK / pandas are presumably needed by the local
#   cxr8_dataset / vae modules -- confirm, and pin versions for reproducibility.
%pip install tensorboard torchinfo torch-summary matplotlib scipy opencv-python SimpleITK pandas

In [None]:
import os
from pathlib import Path

from torch.utils.data import DataLoader

from cxr8_dataset import Cxr8Dataset, get_clahe_transforms

# Dataset root is the "cxr8" folder under the directory named by $DATA_TEMP
# (os.environ[...] raises KeyError if the variable is not set).
data_temp_path = os.environ["DATA_TEMP"]
root_path = Path(data_temp_path).joinpath("cxr8")

# CLAHE-based preprocessing pipeline from the local cxr8_dataset module.
# NOTE(review): exact meaning of clahe_tile_size / input_size is defined in cxr8_dataset.
transforms = get_clahe_transforms(input_size=448, clahe_tile_size=8)
train_dataset = Cxr8Dataset(root_path, transform=transforms)

# Shape of one preprocessed sample; the model cell sizes the VAE from this.
input_size = train_dataset[0]["image"].shape

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
print(f"train_dataset length = {len(train_dataset)}; input_size = {input_size}")

In [None]:
import torch
from torchinfo import summary

from vae import VAE, vae_loss

KERNEL_SIZE = 13
DIRECTIONS = 7
LATENT_DIM = 128  # previously tried: 64
# Batch size used only for the torchinfo shape/cost report below (arbitrary).
SUMMARY_BATCH_SIZE = 37

model = VAE(input_size, init_kernel_size=KERNEL_SIZE, latent_dim=LATENT_DIM)
print(
    summary(
        model,
        input_size=(SUMMARY_BATCH_SIZE, *input_size),
        depth=10,
        col_names=[
            "input_size",
            "kernel_size",
            "mult_adds",
            "num_params",
            "output_size",
            "trainable",
        ],
    )
)

# `is_available` must be *called*: the bare function object is always truthy, which
# previously selected "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)
# Sanity check: every parameter should now live on `device`.
print({p.device for p in model.parameters()})

In [None]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from IPython.display import clear_output, display
import torch.nn.functional as F

num_epochs = 3

import torch.optim as optim
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# TODO: get this from the VAE construction
v1_weight = torch.tensor(  
    [1.0**4] * 8
    + [0.9**4] * 8
    + [0.8**4] * 8
    + [0.7**4] * 8
    + [0.6**4] * 8
)
print(f"v1_weight = {v1_weight}")

# torch.autograd.set_detect_anomaly(True)

fig, ax = plt.subplots(3, 5, figsize=(20,12))

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_count = 0
    for batch_idx, batch in enumerate(train_loader):
        x = batch["image"]

        x = x.to(device)
        optimizer.zero_grad()
        
        result_dict = model.forward_dict(x)
        x_recon = result_dict["x_recon"]
        # TODO: can these be passed as **params?
        recon_loss, kldiv_loss, loss = vae_loss(
            x_recon, 
            x, 
            result_dict["mu"], 
            result_dict["log_var"],
            v1_weight, 
            result_dict["x_after_v1"], 
            result_dict["x_before_v1"],
            recon_loss_metric="l1_loss",  
            beta=0.1
        )

        if train_count % 10 == 9:
            # TODO: move this to output to tensorboard
            x = x[0:5].clone()

            x_xformed = model.stn(x)
            x_xformed = x_xformed.cpu().detach().numpy()

            x = x.cpu().detach().numpy()
            
            x_recon = x_recon[0:5].clone()
            x_recon = x_recon.cpu().detach().numpy()

            # blend_data = 0.5 * orig_data + 0.5 * recon_data

            blend_data = np.stack([x_recon, x_xformed, x_recon], axis=-1)

            # print(v.shape)
            for n in range(5):
                ax[0][n].imshow(np.squeeze(x[n]), cmap='bone')
                ax[1][n].imshow(np.squeeze(blend_data[n])) # cmap='bone')
                ax[2][n].imshow(np.squeeze(x_recon[n]), cmap='bone')

            clear_output(wait=True)

            display(plt.gcf())

        if loss.isnan():
            print("loss.isnan()")
                        
        loss.backward()
        # print(f"loss = {loss}")
        
        train_loss += loss.item()
        train_count += 1.0

        # print(list(model.parameters()))
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        print(
            f"Epoch [{epoch+1}/{num_epochs}], Batch: {batch_idx}, Loss: {train_loss / train_count:.6f} ({recon_loss:.6f}/{kldiv_loss:.6f})"
        )


In [None]:
from datetime import datetime
from pathlib import Path

# Timestamped checkpoint name encoding the run's settings.
# NOTE(review): "clahe8" hard-codes the clahe_tile_size=8 used in the data cell.
when = datetime.now().strftime("%Y%m%d%H%M%S")
weight_path = (
    Path("weights") / f"{when}_clahe8_kernel{KERNEL_SIZE}_latent{LATENT_DIM}_orisq.zip"
)

# torch.save does not create missing directories; make sure weights/ exists.
weight_path.parent.mkdir(parents=True, exist_ok=True)
# Only the state_dict is stored; reload it into a VAE built with the same settings.
torch.save(model.state_dict(), weight_path)
print(f"saved weights to {weight_path}")