In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip, Resize, ToTensor

from workshop.dataset import CUBDataset
from workshop.model import BirdNet

### View examples of training data

In [None]:
datapath = "/local_storage/datasets/CUB_20"
batch_size = 32
SEED = 0

# Seed the global torch RNG so DataLoader shuffling and the random training
# augmentations are reproducible under Restart & Run All.
torch.manual_seed(SEED)

# Training pipeline: random crop + horizontal flip act as data augmentation.
train_transforms = Compose([
    Resize(256),
    RandomCrop((224, 224), pad_if_needed=True),
    RandomHorizontalFlip(),
    ToTensor(),
])

# Evaluation pipeline: deterministic center crop and no flipping, so a given
# test image always produces the same input tensor (and hence the same
# prediction) instead of a different random view on every run.
test_transforms = Compose([
    Resize(256),
    CenterCrop((224, 224)),
    ToTensor(),
])

ds_train = CUBDataset(
    root=datapath,
    train=True,
    transforms=train_transforms,
)
data_loader_train = DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1)

ds_test = CUBDataset(
    root=datapath,
    train=False,
    transforms=test_transforms,
)
# shuffle=True only so the inference cell below shows a varied sample of the
# test set; it is reproducible because of the seed above.
data_loader_test = DataLoader(
    ds_test,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1)

In [None]:
EXAMPLES_TO_SHOW = 5
# Grab one shuffled batch; element 0 holds the images, element 1 the labels.
batch = next(iter(data_loader_train))
train_images, train_labels = batch[0], batch[1]

for shown_count, (image_chw, label) in enumerate(zip(train_images, train_labels), start=1):
    # matplotlib wants channels-last, so move the channel axis to the end.
    plt.imshow(image_chw.permute(1, 2, 0))
    plt.show()
    class_name = ds_train.label_to_class_name(label.item())
    print("Example of: %s" % (class_name))
    if shown_count == EXAMPLES_TO_SHOW:
        break

### Load trained model

In [None]:
project_dir = Path("..").resolve()

# Edit `model_path` to point at a run you have actually trained; the run
# directory below is only an example.
# NOTE(review): 20 is presumably the number of classes (dataset is CUB_20) — confirm.
model = BirdNet(20)
model_path = project_dir/"runs/bs64_lr0.001_wd1e-05_NLZ9MDSO9J/final_model.pt"
if not model_path.is_file():
    raise FileNotFoundError(
        f"No checkpoint found at {model_path}; edit `model_path` to point at "
        "the final_model.pt of one of your own runs."
    )

# Only the `classifier` sub-module's weights are restored from the checkpoint;
# the rest of BirdNet keeps whatever BirdNet(20) sets up (presumably a
# pretrained backbone — confirm in workshop.model).
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust; on torch >= 1.13 consider passing weights_only=True.
model.classifier.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

# The model is only used for inference below, so put it in evaluation mode.
model.eval()

### View inference examples on test data

In [None]:
EXAMPLES_TO_SHOW = 5
batch = next(iter(data_loader_test))
input_ = batch[0]
target = batch[1]

# Evaluation mode makes layers such as dropout / batch-norm (if BirdNet has
# any) behave deterministically. no_grad skips building the autograd graph,
# which is not needed for a pure forward pass and would waste memory.
model.eval()
with torch.no_grad():
    output_ = model(input_)

for datapointIdx, img in enumerate(input_[:EXAMPLES_TO_SHOW]):
    # Predicted class = index of the highest score for this sample.
    prediction = output_[datapointIdx].argmax().item()
    # matplotlib wants channels-last, so move the channel axis to the end.
    pltImage = img.permute(1, 2, 0)
    plt.imshow(pltImage)
    plt.show()
    print("Guessed class: %s, ground truth: %s" % (
        ds_test.label_to_class_name(prediction),
        ds_test.label_to_class_name(target[datapointIdx].item()),
    ))