## Reload the trained model and run inference

In [None]:
from torch.utils.data import DataLoader

from load_dataset import load_dataset

# Shape of a single model input as (channels, height, width).
INPUT_SIZE = (1, 448, 448)

# Build the inference dataset. The meaning of input_size / clahe_tile_size is
# defined inside load_dataset (presumably resize + CLAHE tile size — confirm there).
infer_dataset = load_dataset("cxr8", input_size=INPUT_SIZE, clahe_tile_size=8)

# Sanity check: show the source file and tensor shape of the first ten samples.
for idx in range(10):
    src_path = infer_dataset.imgs[idx][0]
    sample = infer_dataset[idx][0]
    print(f"{src_path} {sample.shape}")

In [None]:
import torch
from torchinfo import summary

from vae import VAE, vae_loss

# Hyper-parameters; these must match the ones the checkpoint below was trained with.
KERNEL_SIZE = 11
DIRECTIONS = 7
LATENT_DIM = 32  # alternative tried: 64
show_summary = True
WEIGHTS_PATH = "weights/20230425175434_clahe8_kernel11_latent32_orisq.zip"
# The inference cells call .numpy() on model outputs, which requires CPU tensors,
# so leave this off unless those cells are adapted to move data to/from the GPU.
USE_CUDA = False

model = VAE(INPUT_SIZE, init_kernel_size=KERNEL_SIZE, latent_dim=LATENT_DIM)
# map_location="cpu" lets a checkpoint saved from a GPU session load on a
# CPU-only machine (without it torch.load tries to restore CUDA tensors).
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location="cpu"))
# Inference only: put any dropout / batch-norm layers into evaluation behaviour.
# Done before summary() so its dummy forward pass runs with the inference-mode
# layers (torchinfo may also switch modes itself, depending on its version).
model.eval()

if show_summary:
    print(
        summary(
            model,
            # Dummy batch of 37 inputs: (batch, channels, height, width).
            input_size=(37, *INPUT_SIZE),
            depth=10,
            col_names=[
                "input_size",
                "kernel_size",
                "mult_adds",
                "num_params",
                "output_size",
                "trainable",
            ],
        )
    )

model = model.cpu()
if USE_CUDA:
    # is_available must be *called*; the bare function object is always truthy.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    model = model.to(device)
    print({p.device for p in model.parameters()})


In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

n_show = 5  # number of samples to preview; one subplot column each
fig, ax = plt.subplots(2, n_show, figsize=(20, 8))

# Top row: input images; bottom row: VAE reconstructions.
# no_grad: pure inference, so skip building autograd graphs.
with torch.no_grad():
    for n in range(n_show):
        # Load/preprocess each sample once instead of once per use.
        image = infer_dataset[n][0]
        print(f"{infer_dataset.imgs[n][0]} {image.shape}")

        # Add a leading batch dimension so the encoder and the full model
        # both receive the same batched input.
        batch = torch.unsqueeze(image, 0)

        # NOTE(review): the encoder output is indexed like a sequence whose
        # first element is the latent code — confirm the return type in vae.py.
        latent_vec = model.encoder(batch)
        print(latent_vec[0].detach().numpy())
        recon_batch, mu, log_var = model(batch)

        ax[0][n].imshow(torch.squeeze(image), cmap="bone")
        ax[1][n].imshow(torch.squeeze(recon_batch.detach()), cmap="bone")

In [None]:
from PIL import Image
from pathlib import Path
import numpy as np

# Raw string: "\T" and "\c" are invalid escape sequences in a normal literal
# (SyntaxWarning on recent Pythons); the resulting path text is unchanged.
reconst_base_path = Path(r"C:\Temp\cxr8")
print(reconst_base_path)
reconst_base_path.mkdir(parents=True, exist_ok=True)

# Export at most this many samples; never index past the end of the dataset.
max_export = 1000
num_export = min(max_export, len(infer_dataset.imgs))
latent_csv_path = reconst_base_path / "latent_vecs.csv"

# One row per image: "<file stem>,<v0>,<v1>,..." with 6 decimals per component.
# NOTE: append mode (as before) — re-running this cell adds duplicate rows;
# delete the CSV first, or switch to "w", if a fresh file is wanted.
with torch.no_grad(), open(latent_csv_path, "a") as csv_file:
    for n in range(num_export):
        orig_path = Path(infer_dataset.imgs[n][0])
        image = infer_dataset[n][0]  # load/preprocess the sample only once
        print(f"{orig_path.stem} {image.shape}")

        # Leading batch dimension so both model calls get batched input.
        batch = torch.unsqueeze(image, 0)

        # NOTE(review): [0][0] presumably selects the first encoder output for
        # the single sample in the batch — confirm against vae.py.
        latent_vec = model.encoder(batch)
        latent_vec = latent_vec[0][0].detach().numpy()
        csv_file.write(f"{orig_path.stem},")
        csv_file.write(",".join(f"{e:.6f}" for e in latent_vec))
        csv_file.write("\n")

        reconst, mu, log_var = model(batch)

        reconst = torch.squeeze(reconst).detach().numpy() * 255.0
        # Round and clip before the uint8 cast: a bare astype truncates and
        # wraps any value outside [0, 255] into garbage pixels.
        reconst = np.clip(np.rint(reconst), 0, 255).astype(np.uint8)
        im = Image.fromarray(reconst)

        reconst_path = reconst_base_path / f"{orig_path.stem}-reconst.png"
        print(f"Saving to {reconst_path}")
        im.save(reconst_path)