# Settings

In [76]:
import torch
import numpy as np
from resnet_model import ResNet18
from torchvision import transforms

In [71]:
from PIL import Image
import os

In [3]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

### Model weights path

In [8]:
model_path = 'model_weights/mahanalobis_final_resnet_k35.ckpt'

# Load model

In [12]:
model = ResNet18(flatten=True).to(device, dtype=torch.float)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

ResNet(
  (trunk): Sequential(
    (0): SimpleBlock(
      (C1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (C2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace)
      (relu2): ReLU(inplace)
      (shortcut): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (BNshortcut): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): SimpleBlock(
      (C1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (C2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN2): BatchNorm2d(64, eps=1e-05, mome

# Preparation model

### Load support set

In [13]:
support_path = 'support_set/'

In [66]:
images = []
class_mapping = {}
class_number = 1
for root, folders, files in os.walk(support_path):
    if len(files) == 0:
        continue
    
    class_name = root.split('/')[-1]
    
    if class_name != class_mapping:
        class_mapping[class_name] = class_number
        class_number += 1
    
    for f in files:
        images.append({
            'class_name': class_name,
            'filepath': os.path.join(root,f)})

Create mapping to get label back

In [51]:
reverse_class_mapping = {value: key for key, value in class_mapping.items()}

In [108]:
class_mapping

{'24fitness': 1, 'a1': 3, 'adidas-text': 5, 'adidas3': 2, 'airhawk': 4}

Map the support batch in order of class mapping

In [85]:
images_path = []
for i in class_mapping:
    for j in images:
        if j['class_name'] == i:
            images_path.append(j['filepath'])

In [86]:
pil_images = []
for i in images_path:
    pil_images.append(Image.open(i).convert('RGB'))

Transform support set and put in tensor

In [77]:
transform = transforms.Compose([
                transforms.Resize((60,60)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
            ])

In [109]:
support_batch = torch.stack([transform(i) for i in pil_images])

Check if order is good

In [114]:
for i in range(len(pil_images)):
    state = torch.equal(new[i], transform(pil_images[i]))
    if state == False:
        print(i)

In [115]:
support_batch.shape

torch.Size([25, 3, 60, 60])

### Prepare prototypes and inverse covariance matrices

# Inference block

### Load query image