In [None]:
# IPython autoreload: with mode 2, imported modules (e.g. the project modules
# unet, data_loading, data_processing imported below) are reloaded before each
# cell runs, so edits to those .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from torch.utils.data import DataLoader
from unet import UNet
from data_loading import XRayDataset
import lightning as L
import pathlib
import torch
import data_processing
import numpy as np
from lightning.pytorch.callbacks import TQDMProgressBar
from torchvision import transforms
import matplotlib.pyplot as plt
from lightning.pytorch.loggers import TensorBoardLogger
import lightning as L
import skimage as ski
import worker_seed
import loss

In [None]:
# Seed for the generator that drives the training DataLoader's shuffle order.
SEED = 0

# Generator.manual_seed returns the generator itself, so seeding can be chained.
# Assigning the result (instead of leaving `g.manual_seed(0)` as the cell's last
# expression) keeps an address-dependent Generator repr out of the cell output.
g = torch.Generator().manual_seed(SEED)

In [None]:
# Root folder of the non-augmented sample data; validation samples are read
# from its "validation" subfolder. (pathlib drops the "." component, so this is
# the same path as `Path.cwd() / "./datasample/cc_no_aug"`.)
DATA_DIR = pathlib.Path.cwd() / "datasample" / "cc_no_aug"
# Worker processes for the training loader; lower this on machines with fewer cores.
NUM_WORKERS = 7

# NOTE(review): the validation folder is nested inside the training root. If
# XRayDataset scans its root recursively, validation samples would leak into the
# training set — confirm against data_loading.XRayDataset.
xray_dataset_training = XRayDataset(DATA_DIR,
                                    data_processing.preprocess_inputs,
                                    data_processing.preprocess_labels)
xray_dataset_validation = XRayDataset(DATA_DIR / "validation",
                                      data_processing.preprocess_inputs,
                                      data_processing.preprocess_labels)

# Training loader: shuffled using the seeded generator `g`, so the sample order
# is reproducible; worker_seed.seed_worker presumably seeds each worker's RNGs
# (see the worker_seed module).
train_dataloader = DataLoader(xray_dataset_training,
                              batch_size=1,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              worker_init_fn=worker_seed.seed_worker,
                              generator=g)
# Validation loader: fixed order, one sample per batch, loaded in the main process.
val_dataloader = DataLoader(xray_dataset_validation, batch_size=1, shuffle=False)

In [None]:
# Number of samples in each split (use len(), not the __len__ dunder directly).
print(f"training samples:   {len(xray_dataset_training)}")
print(f"validation samples: {len(xray_dataset_validation)}")

In [None]:
# Restore trained U-Net weights from a checkpoint, mapping all tensors to CPU.
# NOTE(review): torch.load is pickle-based — only load checkpoints you trust.
# NOTE(review): the checkpoint name says "cc_aug" while the datasets above use
# "cc_no_aug" — confirm this pairing is intended.
ckpt_path = "./checkpoints/ckp_cc_aug_unet.ckpt"
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint["state_dict"]
loaded = UNet(loss.Loss)
# Left as the last expression so the notebook shows the key-matching report
# returned by load_state_dict.
loaded.load_state_dict(state_dict)

In [None]:
# Put the restored network in evaluation mode (training=False on it and all
# submodules). nn.Module.eval() returns the module itself, so `model` and
# `loaded` are the same object. Ending on an assignment keeps the module repr
# out of the cell output.
loaded.eval()
model = loaded

In [None]:
# Persistent iterator over the validation loader (despite the "test" in its
# name, it draws from val_dataloader). The next cell calls next() on it, so each
# re-run of that cell advances to the following validation sample; re-run this
# cell to start again from the first one.
iter_test_loader = iter(val_dataloader)

In [None]:
# Fetch the next validation batch. Re-running this cell advances
# `iter_test_loader`; it raises StopIteration once the validation set is
# exhausted (re-run the previous cell to restart).
sample_input, sample_label = next(iter_test_loader)
# Keep only the first sample while preserving a leading batch dimension of
# size 1. `x[:1]` is equivalent to the former `x[0].unsqueeze(0)`.
sample_input, sample_label = sample_input[:1], sample_label[:1]
model.eval()
with torch.no_grad():
    # Call the module itself rather than `.forward()` so that registered hooks
    # run, as the nn.Module documentation recommends.
    predictions = model(sample_input)

In [None]:
# Sanity-check shapes and the output range of the single-sample prediction.
# Previously, torch.max(...) was not the cell's last expression, so its value
# was computed and silently discarded. It is printed explicitly now.
print("input shape:      ", tuple(sample_input.shape))
print("prediction shape: ", tuple(predictions.squeeze().shape))
# torch.max without `dim` returns a 0-d tensor; .item() converts it to a Python number.
print("prediction max:   ", torch.max(predictions).item())
print(sample_input.size())

In [None]:
# Side-by-side comparison of input image, ground-truth label and prediction.
# Every panel is resized with data_processing.resize_to_roughy_input_size; the
# label and prediction additionally go through data_processing.postprocess.
# The permute(1, 2, 0) implies the resized input is 3-D (C, H, W) and is moved
# to (H, W, C) for imshow.
panels = {
    "input": data_processing.resize_to_roughy_input_size(sample_input.squeeze()).permute(1, 2, 0),
    "label": data_processing.postprocess(
        data_processing.resize_to_roughy_input_size(sample_label.squeeze())),
    "prediction": data_processing.postprocess(
        data_processing.resize_to_roughy_input_size(predictions.squeeze())),
}

fig, axs = plt.subplots(1, len(panels), figsize=(10, 3))
for ax, (title, img) in zip(axs, panels.items()):
    # cmap/vmin/vmax apply to single-channel (2-D) images; matplotlib ignores
    # them for RGB(A) input.
    ax.imshow(img, cmap="gray", vmin=0, vmax=255)
    ax.set_title(title)
    ax.axis("off")
fig.tight_layout()