## Inference Generator

This notebook loads the CXR8 data, loads a trained model, runs inference with that model, and writes the results to a temp directory.

In [1]:
from torch.utils.data import DataLoader

from load_dataset import load_dataset

# Shape handed to the dataset loader and, in a later cell, to the model constructor.
INPUT_SIZE = (1, 448, 448)
infer_dataset = load_dataset("cxr8", input_size=INPUT_SIZE, clahe_tile_size=8)

# Sanity check: print the source path and tensor shape of the first few samples
# (at most 10, fewer if the dataset is smaller).
preview_count = min(10, len(infer_dataset))
for n in range(preview_count):
    print(f"{infer_dataset.imgs[n][0]} {infer_dataset[n][0].shape}")

D:\data\cxr8\filtered\no finding\00000002_000.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000005_000.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000005_001.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000005_002.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000005_003.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000005_004.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000005_005.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000006_000.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000007_000.png torch.Size([1, 448, 448])
D:\data\cxr8\filtered\no finding\00000008_001.png torch.Size([1, 448, 448])


In [2]:
import torch
from torchinfo import summary

from vae import VAE, vae_loss

# Model hyper-parameters. NOTE(review): these presumably have to match the values
# used when the checkpoint below was trained -- confirm against the training run.
KERNEL_SIZE = 11
DIRECTIONS = 7  # not referenced in this cell; kept for any later cells that use it
LATENT_DIM = 32  # 64
WEIGHTS_PATH = "weights/20230425175434_clahe8_kernel11_latent32_orisq.zip"
show_summary = True

model = VAE(INPUT_SIZE, init_kernel_size=KERNEL_SIZE, latent_dim=LATENT_DIM)

# Deserialize onto the CPU first so a checkpoint saved from a GPU still loads on a
# CPU-only machine; the model is moved to the target device further down.
# NOTE(review): load_state_dict is strict, so the checkpoint's keys must match the
# current VAE definition exactly (the recorded output of this cell shows a
# missing/unexpected key mismatch for this weights file).
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))

if show_summary:
    # Layer-by-layer report for a batch of 37 inputs of INPUT_SIZE.
    print(
        summary(
            model,
            input_size=(37, *INPUT_SIZE),
            depth=10,
            col_names=[
                "input_size",
                "kernel_size",
                "mult_adds",
                "num_params",
                "output_size",
                "trainable",
            ],
        )
    )

# is_available must be *called*: the bare function object is always truthy, which
# would select "cuda" even on hosts without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)
# This notebook only runs inference, so switch any train-time-only layers
# (e.g. dropout / batch norm, if the VAE contains them) to evaluation behavior.
model.eval()
# Show the distinct devices holding the parameters (expect exactly one).
print({p.device for p in model.parameters()})


make_oriented_map: weights_real.shape = torch.Size([40, 1, 11, 11])
make_oriented_map: weights_real.shape = torch.Size([40, 1, 11, 11])
self.input_size_to_fc = [512, 7, 7]


RuntimeError: Error(s) in loading state_dict for VAE:
	Missing key(s) in state_dict: "localization.0.weight", "localization.0.bias", "localization.3.weight", "localization.3.bias", "fc_rot.0.weight", "fc_rot.0.bias", "fc_rot.2.weight", "fc_rot.2.bias", "fc_xlate.0.weight", "fc_xlate.0.bias", "fc_xlate.2.weight", "fc_xlate.2.bias". 
	Unexpected key(s) in state_dict: "localization.2.weight", "localization.2.bias", "fc_lut.0.bias". 

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

# Visual check: original images on the top row, model reconstructions on the bottom.
START_INDEX = 1000  # dataset offset of the first sample to visualize
NUM_EXAMPLES = 5  # one subplot column per sample

# squeeze=False keeps `ax` two-dimensional even if NUM_EXAMPLES is set to 1.
fig, ax = plt.subplots(2, NUM_EXAMPLES, figsize=(20, 8), squeeze=False)

# Only iterate over samples that really exist at or after START_INDEX; bounding
# the loop by len(infer_dataset) alone would index past the end of a dataset
# with fewer than START_INDEX + NUM_EXAMPLES items.
num_available = max(0, min(NUM_EXAMPLES, len(infer_dataset) - START_INDEX))

# Feed inputs to whichever device the model's parameters live on instead of
# unconditionally calling .cuda(), so the cell also works on CPU-only hosts.
model_device = next(model.parameters()).device

# No gradients are needed for visualization; no_grad avoids building autograd
# graphs and makes explicit .detach() calls unnecessary.
with torch.no_grad():
    for n in range(num_available):
        x_input = infer_dataset[n + START_INDEX][0]
        # Add a leading batch dimension of size 1 before calling the model.
        x_input = torch.unsqueeze(x_input, 0).to(model_device)
        print(f"x_input.shape = {x_input.shape}")

        warp, lut = model.stn_grid_lut(x_input)

        lut = lut.cpu()
        print(f"lut = {lut.numpy()}")

        warp = warp.cpu()
        print(f"warp = {warp.numpy()}")

        x_recon, mu, log_var = model(x_input)
        print(f"x_recon.shape = {x_recon.shape}")

        x_input = x_input.cpu()
        x_recon = x_recon.cpu()

        # squeeze drops the size-1 dimensions so imshow receives a 2-D image.
        ax[0][n].imshow(torch.squeeze(x_input), cmap='bone')
        ax[0][n].set_title(f"input #{n + START_INDEX}")
        ax[1][n].imshow(torch.squeeze(x_recon), cmap='bone')
        ax[1][n].set_title("reconstruction")

# Hide panels left empty when fewer than NUM_EXAMPLES samples were available.
for unused in range(num_available, NUM_EXAMPLES):
    ax[0][unused].axis("off")
    ax[1][unused].axis("off")