In [0]:
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms as tvt

from model.net import Net
from utils.utils import load_checkpoint, Params

# Seed NumPy's global RNG so the validation sample drawn later with
# np.random.randint is the same on every fresh kernel run.
np.random.seed(0)

In [0]:
# Base directory of the dataset: the CSV index and the "images" folder are read from here.
root = "/dbfs/sketches"
# Experiment folder. It is appended to `root` by plain string concatenation
# (root + model_dir), so the leading "/" must be kept.
model_dir = "/experiments/base_model"

In [0]:
class SketchesDataset(Dataset):
    """Multi-label sketch dataset backed by a CSV index and a folder of PNG images.

    Each CSV row must provide an ``"Image Id"`` value; every column after the
    first one is treated as a label column. Images are read from
    ``<root>/images/<Image Id>.png``.
    """

    def __init__(self, root: str, csv_file: str, transform: tvt.Compose = None) -> None:
        """Read the CSV index and remember the label column names.

        Args:
            root: Directory containing the CSV file and the ``images`` folder.
            csv_file: Name of the CSV file, relative to ``root``.
            transform: Optional callable (e.g. a ``tvt.Compose``) applied to
                each PIL image before it is returned. ``None`` disables it.
        """
        self.root = root
        self.data = pd.read_csv(os.path.join(root, csv_file))
        self.transform = transform
        # Every column except the first is a label; kept as a NumPy array so
        # callers can select names with a boolean mask.
        self.label_names = self.data.columns[1:].to_numpy()

    def __len__(self) -> int:
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, torch.Tensor]:
        """Load one sample.

        Args:
            idx: Positional index of the row in the CSV index.

        Returns:
            A pair ``(img, labels)``. ``img`` is the RGB PIL image, or whatever
            ``transform`` returns when a transform was given. ``labels`` is a
            1-D ``float32`` tensor with one entry per label column.
        """
        row = self.data.iloc[idx]

        im_path = os.path.join(self.root, "images", row["Image Id"] + ".png")
        # convert() decodes the pixels into a new image, so the file handle can
        # be released as soon as the `with` block exits.
        with Image.open(im_path) as raw:
            img = raw.convert("RGB")

        # The row is an object-dtype Series (string id + numeric labels) with a
        # string index. Handing it to torch.tensor directly makes torch index it
        # with integers, which relies on pandas' deprecated positional fallback.
        # Slice positionally and cast to a numeric array first.
        label_values = row.iloc[1:].to_numpy(dtype=np.float32)
        labels = torch.tensor(label_values, dtype=torch.float32)

        if self.transform is not None:
            img = self.transform(img)

        return img, labels

In [0]:
def get_transform(mode: str, params: Params) -> tvt.Compose:
    """Build the image preprocessing / augmentation pipeline.

    Args:
        mode: ``"train"`` adds the random augmentations; any other value
            (e.g. ``"val"``) yields the deterministic pipeline only.
        params: Hyper-parameters. Reads ``height`` and ``width`` always, and
            ``flip``, ``brightness``, ``contrast``, ``saturation``, ``hue`` and
            ``degree`` when ``mode == "train"``.

    Returns:
        Composition of all the data transforms.
    """
    steps = [tvt.Resize((params.height, params.width))]

    if mode == "train":
        # Augmentations must run on the un-normalised image: ColorJitter assumes
        # float tensors live in [0, 1] and clamps to that range, so applying it
        # after Normalize (as was done previously) corrupts the data.
        # NOTE(review): a checkpoint trained with the old ordering saw a
        # different train-time distribution; the "val" pipeline is unchanged.
        steps += [
            tvt.RandomHorizontalFlip(params.flip),
            tvt.ColorJitter(
                brightness=params.brightness,
                contrast=params.contrast,
                saturation=params.saturation,
                hue=params.hue,
            ),
            tvt.RandomRotation(params.degree),
        ]

    steps += [
        tvt.ToTensor(),
        tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return tvt.Compose(steps)

In [0]:
# Hyper-parameters live next to the checkpoint, inside the experiment folder.
experiment_dir = f"{root}{model_dir}"
param_path = os.path.join(experiment_dir, "params.yml")
params = Params(param_path)
print(vars(params))

In [0]:
# Build the network, restore the best checkpoint and switch to evaluation mode.
checkpoint_path = f"{root}{model_dir}/best.pth.tar"
model = Net(params)
load_checkpoint(checkpoint_path, model)
model.eval()

In [0]:
# Deterministic preprocessing only: any mode other than "train" skips the augmentations.
trans = get_transform(mode="val", params=params)

In [0]:
# Validation split, indexed by the CSV file found under `root`.
dataset = SketchesDataset(
    root=root,
    csv_file="val_sketches_mcml.csv",
    transform=trans,
)

In [0]:
# Decision threshold: a label counts as predicted when its sigmoid
# probability is strictly greater than this value.
thr = 0.5

In [0]:
# Draw one sample uniformly from [0, len(dataset)) using the seeded NumPy RNG.
n_samples = len(dataset)
idx = np.random.randint(n_samples)
img, label = dataset[idx]

In [0]:
# Inference only: run the forward pass without autograd so no graph is built
# (this replaces the previous after-the-fact .detach()).
with torch.no_grad():
    pred = model(img.unsqueeze(0))  # unsqueeze adds a batch dimension of size 1
    scores = torch.sigmoid(pred)

# Boolean mask over the labels for the single image in the batch. .cpu() is a
# no-op on CPU tensors and keeps .numpy() valid if the output is on another device.
probs = (scores > thr)[0].cpu().numpy()

In [0]:
# Undo the normalisation for display, using the same per-channel mean/std that
# tvt.Normalize receives in get_transform; view(3, 1, 1) broadcasts over H and W.
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
# Float rounding can push values marginally outside [0, 1]; clamp so imshow
# does not emit its "Clipping input data" warning.
un_img = (img * norm_std + norm_mean).clamp(0.0, 1.0)
plt.imshow(un_img.permute(1, 2, 0).numpy())  # channels-first -> channels-last
plt.axis("off")

# Convert the torch mask to NumPy explicitly before indexing the NumPy name array.
gt = ", ".join(dataset.label_names[label.numpy() == 1.0])
print(f"Groundtruth labels: {gt}")
p = ", ".join(dataset.label_names[probs])
print(f"Predictions: {p}")