# Import libraries

In [None]:
import io
import os
import sys
from typing import Callable, Optional, Tuple
sys.path.insert(0, "../src")

import gcsfs
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import default_rng
import pandas as pd
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay, PrecisionRecallDisplay, precision_recall_curve
import torch
from torch.utils.data import Dataset
import torchvision.transforms as tvt
from tqdm.notebook import trange

from model.net import avg_acc_gpu, avg_f1_score_gpu, confusion_matrix, Net
from utils.utils import load_checkpoint, Params

# Unseeded NumPy generator; it picks the random batch offset and the random
# image to display further down, so results differ between runs.
rng = default_rng()

In [None]:
# Google Cloud Storage filesystem handle; passed to TestDataset, which uses it
# to open the image files.
gfs = gcsfs.GCSFileSystem(project="airesearch-1409")

# Define variables

In [None]:
# Where things live: the bucket with the dataset and the local experiment
# directory that holds the trained checkpoint.
root = "gs://hm_images/"
model_path = "../experiments/base_model"

# Sub-directories of the bucket (not referenced by the cells visible here,
# which spell the same paths out literally).
img_path = "images"
annotation_path = "annotations"

# Threshold handed to the metric helpers and compared against the sigmoid
# outputs when listing predicted attributes.
thr = 0.5

In [None]:
# Settings consumed by the model constructor and the preprocessing pipeline.
# "device" is only used when "cuda" is True.
params = Params(
    dict(
        num_classes=72,
        dropout=0.5,
        height=256,
        width=256,
        crop=224,
        data_dir=root,
        batch_size=128,
        cuda=torch.cuda.is_available(),
        device="cuda:0",
    )
)

# Load dataset

In [None]:
class TestDataset(Dataset):
    """Attribute-prediction dataset whose images are read through a filesystem object.

    The annotation CSV must contain a ``path`` column with the image file
    name; every column after the first one is treated as a label.

    Args:
        root: Directory (or bucket URL) containing the dataset
        file_path: Path of the train/val/test CSV file relative to the root
        gfs: Filesystem object used to open the image files
        transforms: Optional callable applied to each loaded PIL image
        img_dir: Directory holding the images, relative to the root
    """

    def __init__(
        self,
        root: str,
        file_path: str,
        gfs: gcsfs.core.GCSFileSystem,
        transforms: Optional[Callable] = None,
        img_dir: str = "images",
    ) -> None:
        self.root = root
        self.data = pd.read_csv(os.path.join(root, file_path))
        self.transforms = transforms
        self.gfs = gfs
        self.img_dir = img_dir

    @staticmethod
    def _labels_from_row(row: pd.Series) -> torch.Tensor:
        """Convert every entry after the first one of `row` to a float32 tensor."""
        # A row that mixes the string path with numeric labels has object
        # dtype, so select positionally and convert to float32 explicitly
        # instead of relying on implicit sequence conversion.
        values = row.iloc[1:].to_numpy(dtype=np.float32, copy=True)
        return torch.from_numpy(values)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get an item from the dataset given the index idx

        Returns:
            The RGB image (after `transforms`, when given) and the float32
            label tensor of row `idx`.
        """
        row = self.data.iloc[idx]

        im_path = os.path.join(self.root, self.img_dir, row["path"])
        # Read the whole file inside a context manager so the remote handle
        # is closed right away instead of lingering until garbage collection.
        with self.gfs.open(im_path, "rb") as fh:
            raw = fh.read()
        img = Image.open(io.BytesIO(raw)).convert("RGB")

        labels = self._labels_from_row(row)

        if self.transforms is not None:
            img = self.transforms(img)

        return img, labels

    def __len__(self) -> int:
        """Length of the dataset"""
        return len(self.data)

# Evaluation-time preprocessing: resize, centre-crop, convert to a tensor and
# normalise each channel with the mean/std listed below.
preprocess_steps = [
    tvt.Resize(size=(params.height, params.width)),
    tvt.CenterCrop(size=params.crop),
    tvt.ToTensor(),
    tvt.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = tvt.Compose(preprocess_steps)

In [None]:
# Test split, read from the bucket with the preprocessing defined above.
test_ds = TestDataset(
    root=params.data_dir,
    file_path="annotations/test.csv",
    gfs=gfs,
    transforms=transform,
)

In [None]:
# Load one contiguous window of `params.batch_size` samples starting at a
# random offset of the test set.
num_items = len(test_ds)
last_start = num_items - params.batch_size
if last_start < 0:
    raise ValueError(
        f"Dataset has {num_items} samples, fewer than batch_size={params.batch_size}"
    )
# `integers` excludes its upper bound, so +1 keeps the final window reachable
# (and makes a dataset of exactly `batch_size` samples valid).
start = rng.integers(last_start + 1)

batch_imgs, batch_labels = [], []
for i in trange(start, start + params.batch_size):
    img, label = test_ds[i]
    batch_imgs.append(img)
    batch_labels.append(label)
inp_data = torch.stack(batch_imgs, 0)
labels = torch.stack(batch_labels, 0)

print(inp_data.shape)
print(labels.shape)

In [None]:
# Read the annotation table again to recover the attribute names: every
# column after the first one.
annotation_csv = os.path.join(root, "annotations/test.csv")
data = pd.read_csv(annotation_csv)
cols = data.columns[1:].tolist()
cols

# Load model

In [None]:
# Build the network and load the "best" checkpoint of the experiment into it.
checkpoint_file = os.path.join(model_path, "best.pth.tar")
model = Net(params)
load_checkpoint(checkpoint_file, model);

# Prediction

In [None]:
# Switch to inference mode, then move the model and the batch to the
# configured device when CUDA is available.
model.eval()
if params.cuda:
    model.to(params.device)
    inp_data = inp_data.to(params.device)
    labels = labels.to(params.device)

# Forward pass without gradient tracking.
with torch.no_grad():
    output = model(inp_data)

In [None]:
# Host-side NumPy copies for scikit-learn / matplotlib: sigmoid of the model
# outputs and the ground-truth labels.
preds = output.sigmoid().cpu().numpy()
labels_cpu = labels.cpu().numpy()

In [None]:
# Confusion matrix (indexed per attribute when plotted below) plus the
# averaged accuracy and F1, all evaluated at threshold `thr`.
mat = confusion_matrix(output, labels, thr).numpy()

avg_acc = avg_acc_gpu(output, labels, thr)
print(f"Avg. Accuracy: {avg_acc:.3f} @ {thr}")
avg_f1 = avg_f1_score_gpu(output, labels, thr)
print(f"Avg. F1 score: {avg_f1:.3f} @ {thr}")

In [None]:
# Index of the attribute (column of `cols`) to inspect.
idx = 0

fig, (cm_ax, pr_ax) = plt.subplots(1, 2, figsize=(15, 6))

# Left panel: confusion matrix stored at position `idx`.
cm_display = ConfusionMatrixDisplay(confusion_matrix=mat[idx])
cm_display.plot(ax=cm_ax, cmap="Blues")

# Right panel: precision-recall curve from that attribute's scores.
y_true, y_score = labels_cpu[:, idx], preds[:, idx]
prec, recall, _ = precision_recall_curve(y_true, y_score)
PrecisionRecallDisplay(precision=prec, recall=recall).plot(ax=pr_ax)

fig.suptitle(f"{cols[idx]}", fontsize=16)
fig.tight_layout()

# Visualize

In [None]:
# Invert the per-channel scaling and shift (the same constants the input
# pipeline normalises with), then clamp to [0, 1] so matplotlib can display
# the result.
channel_std = np.asarray([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
channel_mean = np.asarray([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
test_imgs = np.clip(inp_data.cpu().numpy() * channel_std + channel_mean, 0.0, 1.0)

# Array form of the attribute names so they can be selected with boolean masks.
col_names = np.asarray(cols)

In [None]:
# Pick a random sample of the batch and show it next to its true and
# predicted attribute names.
i = rng.integers(params.batch_size)

# Move the channel axis last, the layout `imshow` expects.
sample_img = np.transpose(test_imgs[i], (1, 2, 0))
plt.imshow(sample_img)

true_mask = labels_cpu[i].astype(bool)
pred_mask = preds[i] > thr
print(f"Labels: {col_names[true_mask]}")
print(f"Predictions: {col_names[pred_mask]}")