In [None]:
# Colab setup: download LATransformer.zip from Google Drive into /content and unpack it.
# Presumably the archive provides the `LATransformer` package imported below and the
# data/ tree read from /content/LATransformer/data — TODO confirm its contents.
%cd /content
!gdown https://drive.google.com/uc?id=1AWOgXpD7KZEOSXudgER9kSrH0E9Juv-G
!unzip -q LATransformer.zip

## Import Libraries

In [None]:
from __future__ import print_function

!pip install timm

import os
import time
import random
import zipfile
from itertools import chain

import timm
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from LATransformer.model import ClassBlock, LATransformer
from LATransformer.utils import save_network, update_summary

# Restrict PyTorch to the first GPU; this only takes effect if CUDA has not been initialised yet.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

Collecting timm
  Downloading timm-0.4.12-py3-none-any.whl (376 kB)
[K     |████████████████████████████████| 376 kB 4.3 MB/s 
Installing collected packages: timm
Successfully installed timm-0.4.12


### Set Config Parameters

In [None]:
# Training hyper-parameters used by the cells below.
batch_size = 32        # images per batch for both the train and validation loaders
num_epochs = 1000      # upper bound on the number of training epochs
lr = 3e-3              # base Adam learning rate
gamma = 0.7            # StepLR decay factor (that scheduler is created but never stepped)
unfreeze_after = 2     # unfreeze one more ViT block every `unfreeze_after` epochs
lr_decay = .8          # learning-rate multiplier applied at each unfreeze step
lmbd = 8               # passed to LATransformer(vit_base, lmbd) and used in the checkpoint name

## Load Data

In [None]:
# Per-channel normalisation statistics (the standard ImageNet mean / std values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
# Pass the InterpolationMode enum; the legacy int 3 (PIL BICUBIC) triggers a deprecation warning.
BICUBIC = transforms.InterpolationMode.BICUBIC

transform_train_list = [
    transforms.Resize((224, 224), interpolation=BICUBIC),
    transforms.RandomHorizontalFlip(),  # the only augmentation; training only
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
# Deterministic pipeline for evaluation: same resize/normalise, no random flip.
transform_val_list = [
    transforms.Resize((224, 224), interpolation=BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
]
data_transforms = {
    'train': transforms.Compose(transform_train_list),
    'val': transforms.Compose(transform_val_list),
}

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
image_datasets = {}
data_dir = "/content/LATransformer/data/"

# NOTE(review): data/val holds the query/gallery retrieval sets used in the Testing
# section (not class folders), so there is no separate validation split on disk and
# "val" re-reads the training images. Validation accuracy is therefore measured on
# training data: a sanity check, not a generalisation estimate.
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                               data_transforms['train'])
# Evaluate with the deterministic 'val' transform (no random flip).
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                             data_transforms['val'])

train_loader = DataLoader(dataset=image_datasets['train'], batch_size=batch_size, shuffle=True)
# Read from the 'val' dataset (not 'train'); evaluation order does not matter, so no shuffling.
valid_loader = DataLoader(dataset=image_datasets['val'], batch_size=batch_size, shuffle=False)

class_names = image_datasets['train'].classes
print(len(class_names))

62


## Load Model

In [None]:
# Pretrained ViT-B/16 (224x224 input) from timm, moved straight onto the target device.
# NOTE(review): the 751-way head does not match the 62 training identities printed above;
# presumably LATransformer uses its own classifiers instead — TODO confirm.
vit_base = timm.create_model(
    'vit_base_patch16_224',
    pretrained=True,
    num_classes=751,
).to(device)
vit_base.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn



### Train

In [None]:
class AverageMeter:
    """Keep the latest value of a metric plus its sample-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the weighted sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and fold it into the mean with weight ``n``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [None]:
def validate(model, loader, loss_fn, epoch=None):
    """Evaluate ``model`` on ``loader`` and return its mean loss and accuracy.

    ``model(images)`` must return a dict of logits, one entry per classifier head,
    with classes on dim 1. Predictions are the argmax of the summed per-head
    softmax probabilities; the loss is the sum of ``loss_fn`` over all heads.

    Args:
        model: network returning a dict of logits.
        loader: iterable of ``(images, target)`` batches.
        loss_fn: criterion applied to each head's logits and ``target``.
        epoch: optional 0-based epoch index, used only in the progress line
            (previously read from a global ``epoch``).

    Returns:
        OrderedDict with float ``'val_loss'`` and ``'val_accuracy'``, both averaged
        per sample (0.0 for an empty loader).
    """
    sm = nn.Softmax(dim=1)  # stateless; built once instead of per batch
    prefix = f"Epoch : {epoch+1} - " if epoch is not None else ""

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.no_grad():
        for images, target in tqdm(loader):
            images, target = images.to(device), target.to(device)
            output = model(images)

            # Ensemble the heads: sum softmax probabilities, then take the argmax.
            score = sum(sm(logits) for logits in output.values())
            preds = score.argmax(dim=1)
            loss = sum(loss_fn(logits, target) for logits in output.values())

            # Weight by batch size so a short final batch does not skew the averages.
            n = target.size(0)
            total_loss += loss.item() * n
            total_correct += (preds == target).sum().item()
            total_seen += n

            print(f"{prefix}val_loss : {total_loss / total_seen:.4f} - val_acc: {total_correct / total_seen:.4f}", end="\r")
    print()

    denom = max(total_seen, 1)
    return OrderedDict([('val_loss', total_loss / denom), ("val_accuracy", total_correct / denom)])

In [None]:
def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn,
        lr_scheduler=None, saver=None, output_dir='',
        loss_scaler=None, model_ema=None, mixup_fn=None):
    """Train ``model`` for one pass over ``loader``; return mean loss and accuracy.

    ``model(data)`` must return a dict of logits, one entry per classifier head,
    with classes on dim 1. The optimised loss is the sum of ``loss_fn`` over all
    heads. Accuracy uses the argmax of the summed per-head softmax probabilities.

    Args:
        epoch: 0-based epoch index, used only in the progress line.
        model, loader, optimizer, loss_fn: network, ``(data, target)`` batches,
            optimizer and per-head criterion.
        lr_scheduler, saver, output_dir, loss_scaler, model_ema, mixup_fn:
            accepted for call compatibility but not used by this implementation.

    Returns:
        OrderedDict with float ``'train_loss'`` and ``'train_accuracy'``, both averaged
        per sample (0.0 for an empty loader).
    """
    sm = nn.Softmax(dim=1)  # stateless; built once instead of per batch

    model.train()
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    for data, target in tqdm(loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)

        with torch.no_grad():
            preds = sum(sm(logits) for logits in output.values()).argmax(dim=1)

        loss = sum(loss_fn(logits, target) for logits in output.values())
        loss.backward()
        optimizer.step()

        # Accumulate Python numbers via .item(): summing the loss tensors themselves
        # would keep a reference to every batch's autograd history for the whole epoch.
        n = target.size(0)
        total_loss += loss.item() * n
        total_correct += (preds == target).sum().item()
        total_seen += n

        print(f"Epoch : {epoch+1} - loss : {total_loss / total_seen:.4f} - acc: {total_correct / total_seen:.4f}", end="\r")

    print()

    denom = max(total_seen, 1)
    return OrderedDict([('train_loss', total_loss / denom), ("train_accuracy", total_correct / denom)])


In [None]:
def freeze_all_blocks(model):
    """Disable gradients for every transformer block in ``model.model.blocks``.

    Only parameters inside the blocks are touched; anything outside them (e.g. the
    patch embedding) keeps its current ``requires_grad`` flag. All blocks are
    frozen regardless of depth, rather than a hard-coded first 12.
    """
    for block in model.model.blocks:
        for param in block.parameters():
            param.requires_grad = False
    

In [None]:
def unfreeze_blocks(model, amount= 1):
    """Re-enable gradients for the last ``amount`` blocks of ``model.model.blocks``.

    Args:
        model: object whose ``model.blocks`` is a sliceable sequence of modules.
        amount: number of trailing blocks to make trainable; values at or above the
            number of blocks unfreeze all of them, values <= 0 unfreeze none.

    Returns:
        ``model`` itself, modified in place.
    """
    blocks = model.model.blocks
    # Count back from the real depth and clamp at 0, so `amount` never produces a
    # negative start index (which would wrap around to only the last block(s)).
    start = max(len(blocks) - amount, 0)
    for block in blocks[start:]:
        for param in block.parameters():
            param.requires_grad = True
    return model

## Training Loop

In [None]:
# Wrap the pretrained ViT in the LA-Transformer and move it to the target device;
# printing model.eval() shows the architecture (and leaves the model in eval mode).
model = LATransformer(vit_base, lmbd).to(device)
print(model.eval())

# Objective and optimiser: cross-entropy per head, Adam with weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)

# NOTE(review): this StepLR is never stepped; the training loop adjusts the lr by hand
# at each unfreeze step instead.
scheduler = StepLR(optimizer, gamma=gamma, step_size=1)

# Start with every transformer block frozen; the training loop unfreezes them progressively.
freeze_all_blocks(model)

LATransformer(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
      

In [None]:
print("training...")
best_acc = 0.0
name = "la_with_lmbd_{}".format(lmbd)
output_dir = os.path.join("model", name)
# makedirs(exist_ok=True) replaces mkdir wrapped in a bare `except: pass`, which also hid real errors.
os.makedirs(output_dir, exist_ok=True)
summary_path = os.path.join(output_dir, 'summary.csv')
unfrozen_blocks = 0

for epoch in range(num_epochs):

    # Every `unfreeze_after` epochs: make one more block trainable and lower the lr.
    if epoch % unfreeze_after == 0:
        unfrozen_blocks += 1
        model = unfreeze_blocks(model, unfrozen_blocks)
        # Derive the lr from the configured base value instead of scaling the optimizer's
        # current lr in place, so re-running this cell does not keep shrinking it.
        current_lr = lr * lr_decay ** unfrozen_blocks
        for group in optimizer.param_groups:
            group['lr'] = current_lr
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("Unfrozen Blocks: {}, Current lr: {}, Trainable Params: {}".format(unfrozen_blocks,
                                                                             current_lr,
                                                                             trainable_params))

    train_metrics = train_one_epoch(
        epoch, model, train_loader, optimizer, criterion,
        lr_scheduler=None, saver=None)

    eval_metrics = validate(model, valid_loader, criterion)

    # Append this epoch's metrics to the CSV. Request the header only on the first
    # epoch instead of every epoch (presumably update_summary writes one whenever
    # write_header is True — confirm in LATransformer.utils).
    update_summary(epoch, train_metrics, eval_metrics, summary_path,
                   write_header=(epoch == 0))

    # Checkpoint whenever validation accuracy improves.
    if eval_metrics['val_accuracy'] > best_acc:
        best_acc = eval_metrics['val_accuracy']
        save_network(model, epoch, name)
        print("SAVED!")

training...
Unfrozen Blocks: 1, Current lr: 6.7553994410557506e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 1 - loss : 20.6106 - acc: 0.9879


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 1 - val_loss : 19.0567 - val_acc: 0.9839
SAVED!


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 2 - loss : 20.6025 - acc: 0.9859


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 2 - val_loss : 18.9239 - val_acc: 0.9849
SAVED!
Unfrozen Blocks: 2, Current lr: 5.404319552844601e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 3 - loss : 20.2954 - acc: 0.9869


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 3 - val_loss : 18.8069 - val_acc: 0.9859
SAVED!


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 4 - loss : 20.1325 - acc: 0.9849


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 4 - val_loss : 18.7170 - val_acc: 0.9839
Unfrozen Blocks: 3, Current lr: 4.323455642275681e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 5 - loss : 19.9176 - acc: 0.9879


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 5 - val_loss : 18.7147 - val_acc: 0.9839


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 6 - loss : 19.9928 - acc: 0.9899


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 6 - val_loss : 18.3632 - val_acc: 0.9839
Unfrozen Blocks: 4, Current lr: 3.458764513820545e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 7 - loss : 19.9575 - acc: 0.9869


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 7 - val_loss : 18.2254 - val_acc: 0.9839


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 8 - loss : 19.8022 - acc: 0.9859


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 8 - val_loss : 18.1219 - val_acc: 0.9839
Unfrozen Blocks: 5, Current lr: 2.7670116110564363e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 9 - loss : 19.5664 - acc: 0.9869


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 9 - val_loss : 18.2833 - val_acc: 0.9839


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 10 - loss : 19.7103 - acc: 0.9899


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 10 - val_loss : 18.0887 - val_acc: 0.9839
Unfrozen Blocks: 6, Current lr: 2.2136092888451492e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 11 - loss : 19.6042 - acc: 0.9909


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 11 - val_loss : 18.1363 - val_acc: 0.9839


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 12 - loss : 19.6022 - acc: 0.9879


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 12 - val_loss : 17.9558 - val_acc: 0.9849
Unfrozen Blocks: 7, Current lr: 1.7708874310761196e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 13 - loss : 19.5457 - acc: 0.9899


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 13 - val_loss : 18.0512 - val_acc: 0.9869
SAVED!


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 14 - loss : 19.6156 - acc: 0.9899


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 14 - val_loss : 17.9720 - val_acc: 0.9869
Unfrozen Blocks: 8, Current lr: 1.4167099448608957e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 15 - loss : 19.3437 - acc: 0.9899


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 15 - val_loss : 18.1298 - val_acc: 0.9859


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 16 - loss : 19.5186 - acc: 0.9889


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 16 - val_loss : 17.8525 - val_acc: 0.9839
Unfrozen Blocks: 9, Current lr: 1.1333679558887166e-06, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 17 - loss : 19.1921 - acc: 0.9869


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 17 - val_loss : 18.0220 - val_acc: 0.9869


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 18 - loss : 19.3099 - acc: 0.9919


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 18 - val_loss : 17.8191 - val_acc: 0.9859
Unfrozen Blocks: 10, Current lr: 9.066943647109733e-07, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 19 - loss : 19.1080 - acc: 0.9879


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 19 - val_loss : 17.9404 - val_acc: 0.9849


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 20 - loss : 19.3079 - acc: 0.9929


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 20 - val_loss : 17.8860 - val_acc: 0.9839
Unfrozen Blocks: 11, Current lr: 7.253554917687787e-07, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 21 - loss : 19.2115 - acc: 0.9889


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 21 - val_loss : 17.8046 - val_acc: 0.9849


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 22 - loss : 19.0883 - acc: 0.9919


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 22 - val_loss : 17.9578 - val_acc: 0.9859
Unfrozen Blocks: 12, Current lr: 5.802843934150231e-07, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 23 - loss : 19.3052 - acc: 0.9889


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 23 - val_loss : 17.6163 - val_acc: 0.9899
SAVED!


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 24 - loss : 19.1742 - acc: 0.9909


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 24 - val_loss : 17.7209 - val_acc: 0.9869
Unfrozen Blocks: 13, Current lr: 4.642275147320185e-07, Trainable Params: 91841537


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 25 - loss : 19.0299 - acc: 0.9899


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 25 - val_loss : 17.8628 - val_acc: 0.9879


  0%|          | 0/31 [00:00<?, ?it/s]

Epoch : 26 - loss : 0.6711 - acc: 0.0302

KeyboardInterrupt: ignored

# Testing

In [None]:
from __future__ import print_function

import os
import time
import glob
import random
import zipfile
from itertools import chain

import timm
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
from torchvision import models
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

from LATransformer.model import ClassBlock, LATransformer, LATransformerTest
from LATransformer.utils import save_network, update_summary, get_id
from LATransformer.metrics import rank1, rank5, rank10, calc_map
# Feature extraction and retrieval in this section run on the CPU.
device = "cpu"

In [None]:
# Rebuild the architecture: ViT backbone wrapped by the test-time LA-Transformer.
vit_base = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=751)
vit_base = vit_base.to(device)
model = LATransformerTest(vit_base, lmbd=8).to(device)

# Load the trained weights (presumably written by save_network(...) during training —
# confirm the file name). torch.load unpickles, so only load trusted checkpoints.
name = "la_with_lmbd_8"
save_path = os.path.join('./model', name, 'net_best.pth')
# map_location lets a checkpoint saved from GPU tensors load onto this CPU device.
state_dict = torch.load(save_path, map_location=device)
# strict=False tolerates key mismatches between the training and test models;
# report them so a layer that silently kept its initial weights gets noticed.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print("missing keys:", incompatible.missing_keys)
    print("unexpected keys:", incompatible.unexpected_keys)
model.eval()

LATransformerTest(
  (model): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=768, out_features=768, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (act): GELU()
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (1): Block(
  

In [None]:
# Retrieval evaluation must be deterministic, so the query and gallery images go through
# the same resize + normalise pipeline with no random augmentation (a random flip on the
# queries made the extracted features and scores change from run to run).
# InterpolationMode.BICUBIC replaces the deprecated int 3.
transform_query_list = [
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
transform_gallery_list = [
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
data_transforms = {
    'query': transforms.Compose(transform_query_list),
    'gallery': transforms.Compose(transform_gallery_list),
}

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
batch_size_test = 8
data_dir = "/content/LATransformer/data/val"

# One ImageFolder per retrieval split, each with its matching transform.
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ('query', 'gallery')
}
# shuffle=False keeps extracted features in the same order as image_datasets[split].imgs,
# which is how labels are matched to features later.
query_loader, gallery_loader = (
    DataLoader(dataset=image_datasets[split], batch_size=batch_size_test, shuffle=False)
    for split in ('query', 'gallery')
)

class_names = image_datasets['query'].classes
print(len(class_names))

12


In [None]:
activation = {}
def get_activation(name):
    """Return a forward hook that stores a module's detached output in ``activation[name]``.

    NOTE(review): this hook is not registered anywhere in this notebook.
    """
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook


def extract_feature(model, dataloaders):
    """Run ``model`` over every batch of ``dataloaders`` and stack the outputs on CPU.

    Args:
        model: feature extractor; its per-batch outputs are concatenated along dim 0.
        dataloaders: a single DataLoader yielding ``(image, label)`` batches; labels are ignored.

    Returns:
        Tensor of all outputs in loader order, or an empty FloatTensor if the loader is empty.
    """
    chunks = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for img, _label in tqdm(dataloaders):
            output = model(img.to(device))
            chunks.append(output.detach().cpu())
    # Concatenate once instead of re-copying an ever-growing tensor on every batch.
    return torch.cat(chunks, 0) if chunks else torch.FloatTensor()

In [None]:
# Embed every query and gallery image; features come back in loader order,
# matching image_datasets['query'/'gallery'].imgs.
query_feature = extract_feature(model, dataloaders=query_loader)
gallery_feature = extract_feature(model, dataloaders=gallery_loader)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/23 [00:00<?, ?it/s]

In [None]:
import os
def get_id(img_path):
    """Parse camera ids and person labels from ImageFolder ``(path, class_index)`` pairs.

    The label is the name of the image's parent directory (directories starting with
    ``-1`` map to -1). The camera id is the second ``_``-separated field of the file
    name without its extension, e.g. ``.../0007/0007_3_12.jpg`` -> camera 3, label 7.
    NOTE(review): this redefinition shadows ``get_id`` imported from LATransformer.utils.

    Returns:
        ``(camera_id, labels)``: two lists of ints, in input order.
    """
    camera_id = []
    labels = []
    for path, _class_index in img_path:
        filename = os.path.basename(path)
        label = os.path.basename(os.path.dirname(path))
        # splitext instead of filename[:-4] so extensions of any length (e.g. .jpeg) work.
        camera = os.path.splitext(filename)[0].split('_')[1]
        labels.append(-1 if label.startswith('-1') else int(label))
        camera_id.append(int(camera))
    return camera_id, labels

In [None]:
# ImageFolder.imgs lists (path, class_index) pairs in the same order the loaders produced features.
query_path = image_datasets['query'].imgs
gallery_path = image_datasets['gallery'].imgs

# NOTE(review): query_cam / gallery_cam are not used by the evaluation below.
query_cam, query_label = get_id(query_path)
gallery_cam, gallery_label = get_id(gallery_path)

In [None]:
def _concat_normalized_parts(features):
    """L2-normalise each part of every sample, then flatten the parts into one vector.

    ``features`` is expected to be ``(num_samples, num_parts, dim)``, e.g.
    ``(N, 14, 768)`` giving 10752-d vectors. Each part vector is divided by its L2 norm
    times ``sqrt(num_parts)``, so every concatenated vector has unit L2 norm overall
    (inner products between them are cosine similarities).

    Returns:
        List with one 1-D tensor per sample (empty list for empty input).
    """
    if len(features) == 0:
        return []
    num_parts = features.shape[1]
    norms = torch.norm(features, p=2, dim=2, keepdim=True) * np.sqrt(num_parts)
    # Clamp so an all-zero part vector yields zeros instead of NaNs.
    normalized = features / norms.clamp_min(1e-12)
    return list(normalized.reshape(features.shape[0], -1))


concatenated_query_vectors = _concat_normalized_parts(query_feature)
concatenated_gallery_vectors = _concat_normalized_parts(gallery_feature)
  

  0%|          | 0/28 [00:00<?, ?it/s]

  0%|          | 0/181 [00:00<?, ?it/s]

In [None]:
# faiss provides the nearest-neighbour index used for gallery retrieval below.
# NOTE(review): unpinned; the recorded run installed faiss-gpu 1.7.1.post2 — consider pinning.
!pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.1.post2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (89.7 MB)
[K     |████████████████████████████████| 89.7 MB 5.9 kB/s 
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.1.post2


In [None]:
import faiss
import numpy as np

# Gallery matrix: one unit-norm row per gallery image; faiss needs contiguous float32.
gallery_matrix = np.ascontiguousarray(
    np.stack([t.numpy() for t in concatenated_gallery_vectors]), dtype=np.float32)

# Inner product on unit-norm vectors == cosine similarity. The dimension comes from the
# data rather than a hard-coded 10752.
index = faiss.IndexIDMap(faiss.IndexFlatIP(gallery_matrix.shape[1]))
# faiss ids must be int64. Each gallery vector is tagged with its person label, so
# search() returns labels directly (many gallery images share a label).
index.add_with_ids(gallery_matrix, np.asarray(gallery_label, dtype=np.int64))


def search(query, k=1):
    """Return faiss ``(scores, ids)`` for the ``k`` gallery vectors most similar to ``query``.

    ``query`` is a 1-D torch tensor (one concatenated, normalised feature vector).
    Both results have shape ``(1, k)``; the ids are gallery person labels.
    """
    encoded_query = query.unsqueeze(dim=0).numpy()
    top_k = index.search(encoded_query, k)
    return top_k

In [None]:
# Retrieve the top-10 gallery labels for every query and accumulate rank-k hits and AP.
rank1_score = 0
rank5_score = 0
rank10_score = 0
ap = 0
count = 0
for query, label in zip(concatenated_query_vectors, query_label):
    count += 1
    output = search(query, k=10)
    rank1_score += rank1(label, output)
    rank5_score += rank5(label, output)
    rank10_score += rank10(label, output)
    ap += calc_map(label, output)
    print("Correct: {}, Total: {}, Incorrect: {}".format(rank1_score, count, count-rank1_score), end="\r")
print()  # end the \r progress line so the summary starts on its own line

# Average over the queries actually evaluated; max(..., 1) guards an empty query set.
num_queries = max(count, 1)
print("Rank1: {}, Rank5: {}, Rank10: {}, mAP: {}".format(rank1_score/num_queries,
                                                         rank5_score/num_queries,
                                                         rank10_score/num_queries,
                                                         ap/num_queries))

Correct: 1, Total: 1, Incorrect: 0Correct: 2, Total: 2, Incorrect: 0Correct: 3, Total: 3, Incorrect: 0Correct: 4, Total: 4, Incorrect: 0Correct: 5, Total: 5, Incorrect: 0Correct: 6, Total: 6, Incorrect: 0Correct: 7, Total: 7, Incorrect: 0Correct: 8, Total: 8, Incorrect: 0Correct: 9, Total: 9, Incorrect: 0Correct: 10, Total: 10, Incorrect: 0Correct: 11, Total: 11, Incorrect: 0Correct: 12, Total: 12, Incorrect: 0Correct: 13, Total: 13, Incorrect: 0Correct: 14, Total: 14, Incorrect: 0Correct: 15, Total: 15, Incorrect: 0Correct: 16, Total: 16, Incorrect: 0Correct: 17, Total: 17, Incorrect: 0Correct: 18, Total: 18, Incorrect: 0Correct: 19, Total: 19, Incorrect: 0Correct: 20, Total: 20, Incorrect: 0Correct: 21, Total: 21, Incorrect: 0Correct: 22, Total: 22, Incorrect: 0Correct: 23, Total: 23, Incorrect: 0Correct: 24, Total: 24, Incorrect: 0Correct: 25, Total: 25, Incorrect: 0Correct: 26, Total: 26, Incorrect: 0Correct: 27, Total: 27, Incorrect: 0Correct: 28, Total: