In [None]:
import sys

# Make the parent directory importable so `src.*` (datamodule, model) resolves.
# Guarded so re-running this cell does not keep appending duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import torch
import numpy as np
from src.datamodule import HouseDataModule
from src.model import HPClassifier
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [None]:
# Datamodule over the clipped+buffered image crops. setup(stage="fit") is
# presumably what populates `dm.trn_ds` and the validation split used below.
dm = HouseDataModule(
    label_file="data.csv",
    img_dir="output/images_clipped_buffered/",
    num_workers=1,
    batch_size=16,
)
dm.setup(stage="fit")

In [None]:
# Visual sanity check of the training dataset via its own `show` helper.
# NOTE(review): the meaning of 50 (sample index vs. number of samples) is
# defined in src.datamodule — confirm there.
dm.trn_ds.show(50)

In [None]:
# Hard-coded checkpoint of one specific training run; change this path to
# evaluate a different run/epoch.
CKPT_PATH = "../.aim/hp-padded/46d4794f3e584704bb9c29f9/checkpoints/epoch:9-step:380-loss:2.739.ckpt"
model = HPClassifier.load_from_checkpoint(CKPT_PATH)
# Only inference happens below: switch dropout/batch-norm layers (if any) to eval behaviour.
model.eval()

In [None]:
# Keep an explicit iterator over the validation loader so that successive
# `next(val_dl)` calls step through its batches.
val_loader = dm.val_dataloader()
val_dl = iter(val_loader)

In [None]:
# Pull the next validation batch. Re-running this cell advances the iterator,
# so it yields a different batch each time rather than the same one.
batch = next(val_dl)

In [None]:
# Split the batch into the image tensor and the five per-head label index tensors.
(
    img,
    complete_idx,
    condition_idx,
    material_idx,
    security_idx,
    use_idx,
) = batch

In [None]:
# Convert the label tensors to CPU numpy arrays for indexing the label-name tables.
# torch.as_tensor returns a tensor unchanged and wraps a numpy array back into a
# tensor, so re-running this cell works instead of failing with AttributeError
# (numpy arrays have no .detach()).
complete_idx, condition_idx, material_idx, security_idx, use_idx = (
    torch.as_tensor(labels).detach().cpu().numpy()
    for labels in (complete_idx, condition_idx, material_idx, security_idx, use_idx)
)

In [None]:
# Inference only. eval() makes dropout/batch-norm layers (if any) behave
# deterministically, and no_grad() avoids building an autograd graph that would
# otherwise be kept alive by the logits.
model.eval()
with torch.no_grad():
    complete_logits, condition_logits, material_logits, security_logits, use_logits = model(img.to(model.device))

In [None]:
# Predicted class per head = index of the largest logit along dim 1
# (presumably the class axis of (batch, n_classes) logits).
complete_preds, condition_preds, material_preds, security_preds, use_preds = (
    head_logits.argmax(dim=1)
    for head_logits in (complete_logits, condition_logits, material_logits, security_logits, use_logits)
)

In [None]:
# Convert the predictions to CPU numpy arrays for the plotting cell.
# torch.as_tensor returns a tensor unchanged and re-wraps a numpy array, so
# re-running this cell works instead of failing on already-converted arrays.
complete_preds, condition_preds, material_preds, security_preds, use_preds = (
    torch.as_tensor(preds).detach().cpu().numpy()
    for preds in (complete_preds, condition_preds, material_preds, security_preds, use_preds)
)

In [None]:
# Grid of validation images titled with actual ("A") vs predicted ("P") labels
# for the heads: complete, condition, material, security, use.

# Per-channel mean/std used to undo input normalisation for display (the standard
# ImageNet statistics). NOTE(review): assumes the datamodule normalises with
# exactly these values — confirm in src.datamodule.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])
N_COLS = 2

# Size the grid from the batch rather than hard-coding 8x2, so a batch with fewer
# than 16 images (smaller batch_size, short final batch) does not raise IndexError.
n_images = len(img)
n_rows = max(1, (n_images + N_COLS - 1) // N_COLS)
fig, axs = plt.subplots(nrows=n_rows, ncols=N_COLS, figsize=(10 * N_COLS, 5 * n_rows), squeeze=False)
lookup = dm.trn_ds  # index -> label-name tables; assumed shared by the val split
for i, ax in enumerate(axs.flatten()):
    ax.set_axis_off()
    if i >= n_images:
        continue  # leftover grid cell when the image count is odd
    _img = img[i].permute(1, 2, 0).detach().cpu().numpy()  # channels-first -> channels-last for imshow
    _img = np.clip(_img * IMAGENET_STD + IMAGENET_MEAN, 0, 1)
    actual = ",".join([
        "A",
        lookup.rcomplete[complete_idx[i]],
        lookup.rcondition[condition_idx[i]],
        lookup.rmaterial[material_idx[i]],
        lookup.rsecurity[security_idx[i]],
        lookup.ruse[use_idx[i]],
    ])
    prediction = ",".join([
        "P",
        lookup.rcomplete[complete_preds[i]],
        lookup.rcondition[condition_preds[i]],
        lookup.rmaterial[material_preds[i]],
        lookup.rsecurity[security_preds[i]],
        lookup.ruse[use_preds[i]],
    ])
    ax.imshow(_img)
    ax.set_title(actual + "\n" + prediction)

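## Inspect raw images and their dimensions

List the generated image crops and sample their heights and widths to check the spread of aspect ratios.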

In [None]:
!ls output/images_clipped_buffered/

In [None]:
from itertools import islice
from pathlib import Path
from torchvision.io import read_image

# Sample heights/widths of generated .jpg images anywhere under output/.
# - islice caps the sample at exactly N_SIZE_SAMPLES images.
# - Paths are sorted, so every run draws the same sample (glob order is filesystem-dependent).
# - The decoded image is bound to `sample_img`, not `img`, so the batch tensor used
#   by the prediction-grid cell is not overwritten.
N_SIZE_SAMPLES = 100
height, width = list(), list()
for img_path in islice(sorted(Path("output").glob("**/*.jpg")), N_SIZE_SAMPLES):
    sample_img = read_image(str(img_path))  # Tensor[C, H, W]
    height.append(sample_img.shape[1])
    width.append(sample_img.shape[2])

In [None]:
# Per-image aspect ratio (height / width) of the sampled images, as a float array.
ratio = np.asarray(height, dtype=float) / np.asarray(width, dtype=float)