# Settings

In [1]:
import torch
import numpy as np
from resnet_model import ResNet18
from protonet import *
from torchvision import transforms

In [2]:
from PIL import Image
import os
import json

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")

### Model weights path

In [4]:
# Checkpoint file whose state dict is loaded into the ResNet18 model in the "Load model" section.
model_path = 'model_weights/mahanalobis_final_resnet_k35.ckpt'

In [72]:
# Show which checkpoint files are available in the weights directory.
os.listdir('model_weights')

['mahanalobis_final_resnet_k35.ckpt']

### Model parameters

In [25]:
# Episode layout handed to compute_matrix / compute_prototypes:
# k_way classes with n_shot support images each (5 * 5 = 25 support images).
k_way, n_shot = 5, 5
# Size of each of the two halves the model output is split into.
embedding_dim = 256
# Distance name forwarded to pairwise_distances.
distance = 'gaussian'

# Load model

In [77]:
# Build the network on the selected device, restore the trained weights
# (remapped onto that device) and switch to evaluation mode.
model = ResNet18(flatten=True).to(device, dtype=torch.float)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

ResNet(
  (trunk): Sequential(
    (0): SimpleBlock(
      (C1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (C2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace)
      (relu2): ReLU(inplace)
      (shortcut): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (BNshortcut): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): SimpleBlock(
      (C1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (C2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (BN2): BatchNorm2d(64, eps=1e-05, mome

# Model preparation

### Load support set

In [7]:
# Root directory of the support set. It is walked recursively below, and every
# directory that contains files is treated as one class named after that directory.
support_path = 'support_set/'

In [48]:
# Collect every support image together with its class name, and assign each
# class a consecutive integer index in the order the directories are visited.
images = []            # one {'class_name', 'filepath'} record per image file
class_mapping = {}     # class name -> integer class index
class_number = 0       # next free index; equals the number of classes after the loop
for root, _dirs, files in os.walk(support_path):
    # Directories without files (e.g. the root that only holds class folders)
    # do not represent a class.
    if not files:
        continue

    # The class is named after the directory that directly contains the images.
    class_name = os.path.basename(root)

    # Register the class only once. The original compared the name (a str)
    # against the whole dict, which is always unequal, instead of testing membership.
    if class_name not in class_mapping:
        class_mapping[class_name] = class_number
        class_number += 1

    for f in files:
        images.append({
            'class_name': class_name,
            'filepath': os.path.join(root, f)})

Create the reverse mapping (class index → class name) so that predicted indices can be turned back into labels

In [49]:
# Invert class_mapping: integer class index -> class name.
reverse_class_mapping = dict(zip(class_mapping.values(), class_mapping.keys()))

In [50]:
class_mapping

{'24fitness': 0, 'a1': 2, 'adidas-text': 4, 'adidas3': 1, 'airhawk': 3}

Order the support image paths by class index, following the numbering in the reverse class mapping

In [53]:
# File paths grouped by class, with the groups in ascending class-index order,
# so that all images of class 0 come first, then class 1, and so on.
images_path = [
    entry['filepath']
    for class_idx in range(class_number)
    for entry in images
    if entry['class_name'] == reverse_class_mapping[class_idx]
]

In [54]:
# Load every support image as RGB, in the same order as images_path.
pil_images = []
for image_file in images_path:
    # convert() returns a new in-memory image, so the file handle opened by
    # Image.open can be closed right away instead of staying open until
    # garbage collection.
    with Image.open(image_file) as opened:
        pil_images.append(opened.convert('RGB'))

Transform the support set and stack it into a single tensor

In [55]:
# Preprocessing for support images: resize to 60x60, convert to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((60, 60)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

In [56]:
# Apply the preprocessing to every support image and stack the results along
# a new leading batch dimension.
support_batch = torch.stack(list(map(transform, pil_images)))

Check that the order of the stacked batch matches the order of the image list

In [57]:
# Print the index of every batch entry that differs from re-transforming the
# image at the same position; no output means the order is consistent.
for idx, pil_image in enumerate(pil_images):
    if not torch.equal(support_batch[idx], transform(pil_image)):
        print(idx)

In [58]:
support_batch.shape

torch.Size([25, 3, 60, 60])

### Prepare prototypes and inverse covariance matrices

In [59]:
# This is inference only, so run without building an autograd graph. The
# batch is moved to the model's device first, because the model was placed on
# `device` and a CPU batch would fail when that device is CUDA.
with torch.no_grad():
    embeddings = model(support_batch.to(device))
    # The model output is split along dim 1 into two halves of embedding_dim
    # columns: the first is used for the prototypes, the second is passed on
    # as the raw covariance input.
    support, raw_covariance_matrix = torch.split(embeddings, [embedding_dim, embedding_dim], dim=1)
    # NOTE(review): the exact semantics of these protonet helpers (including the
    # meaning of the 1.0 argument) are not visible here -- confirm in protonet.py.
    inv_covariance_matrix = calculate_inverse_covariance_matrix(raw_covariance_matrix, 1.0)
    S = compute_matrix(inv_covariance_matrix, k_way, n_shot, embedding_dim)
    prototypes = compute_prototypes(support, k_way, n_shot)

### Store prototypes, S and reverse label mapping

In [64]:
prototypes.shape

torch.Size([5, 256])

In [65]:
S.shape

torch.Size([5, 256])

In [66]:
# Make sure the output directory exists so the notebook also works on a fresh checkout.
os.makedirs('preparation_files', exist_ok=True)
# Persist plain tensors: detach() drops any autograd history so the stored
# files hold only the values needed at inference time.
torch.save(prototypes.detach(), 'preparation_files/prototypes.pt')
torch.save(S.detach(), 'preparation_files/S.pt')

In [68]:
reverse_class_mapping

{0: '24fitness', 1: 'adidas3', 2: 'a1', 3: 'airhawk', 4: 'adidas-text'}

In [71]:
# Store the index -> class-name mapping next to the prototypes.
# JSON object keys are always strings, so the integer indices come back as
# '0', '1', ... when this file is loaded again.
with open('preparation_files/reverse_class_mapping.json', 'w') as mapping_file:
    mapping_file.write(json.dumps(reverse_class_mapping))

# Inference block

### Load query image

In [7]:
os.listdir('query_images')

['adidas_img000259_0.jpg',
 'adidas_img000508_3.jpg',
 'redbull_img000344_5.jpg',
 'Walmart_img000083_2.jpg',
 'airhawk_img000053_0.jpg']

### Process query image

In [56]:
# Path of the single query image to classify (one of the files in query_images/).
query_path = 'query_images/Walmart_img000083_2.jpg'

In [57]:
# Query preprocessing: resize to 60x60, convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5. It is redefined here
# so the inference block can be run without the preparation section.
transform = transforms.Compose([
    transforms.Resize((60, 60)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

In [58]:
# Load the query image as RGB, preprocess it and add a leading batch
# dimension so the model receives a batch of one image.
# The context manager closes the file handle once the RGB copy has been made.
with Image.open(query_path) as query_image:
    query = torch.unsqueeze(transform(query_image.convert('RGB')), 0)

In [59]:
query.shape

torch.Size([1, 3, 60, 60])

In [60]:
# Forward pass for inference only: no autograd graph is needed, and the input
# must live on the same device as the model (a CPU tensor fails on a CUDA model).
with torch.no_grad():
    query_output = model(query.to(device))

In [61]:
# Keep the first embedding_dim columns as the query embedding; the second
# half of the model output is not needed for the query and is discarded.
query_embedding, _ = query_output.split([embedding_dim, embedding_dim], dim=1)

### Load preparation files

In [62]:
# Restore the stored prototypes and S matrix onto the device used for
# inference. map_location makes files saved on a GPU machine loadable on a
# CPU-only one (and vice versa).
prototypes = torch.load('preparation_files/prototypes.pt', map_location=device)
S = torch.load('preparation_files/S.pt', map_location=device)

In [63]:
# Read the index -> class-name mapping back from disk.
# The keys are strings ('0', '1', ...) because JSON has no integer keys.
with open('preparation_files/reverse_class_mapping.json') as mapping_file:
    reverse_class_mapping = json.loads(mapping_file.read())

In [64]:
reverse_class_mapping

{'0': '24fitness',
 '1': 'adidas3',
 '2': 'a1',
 '3': 'airhawk',
 '4': 'adidas-text'}

### Make inference

In [65]:
# Distance from the query embedding to each class prototype; the result has
# one row per query and one column per prototype.
# NOTE(review): how pairwise_distances uses `distance` and `S` is defined in
# protonet.py and is not visible here -- confirm there.
distances = pairwise_distances(query_embedding, prototypes, distance, S)

In [66]:
distances.shape

torch.Size([1, 5])

In [67]:
distances

tensor([[19.7987, 32.2502, 37.2896, 32.6666, 37.1529]], grad_fn=<SqrtBackward>)

In [68]:
# Turn distances into class probabilities: a smaller distance gives a larger
# probability (softmax over the negated distances, across the class dimension).
y_pred = torch.softmax(distances.neg(), dim=1)

In [69]:
y_pred

tensor([[9.9999e-01, 3.9122e-06, 2.5339e-08, 2.5797e-06, 2.9051e-08]],
       grad_fn=<SoftmaxBackward>)

In [70]:
# Report the closest prototype distance, the top probability and its label.
# detach().cpu() makes the NumPy conversion work for tensors that require
# grad or live on a GPU; the legacy `.data` accessor fails on CUDA tensors.
print('min distance:', torch.min(distances).detach().cpu().numpy())
print('max prob:', torch.max(y_pred).detach().cpu().numpy())
# The mapping was loaded from JSON, so it is indexed by the string form of the class index.
print('label:', reverse_class_mapping[str(torch.argmax(y_pred).item())])

min distance: 19.798737
max prob: 0.99999344
label: 24fitness


In [73]:
device

device(type='cpu')