# Setup

In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from functools import reduce

# Set to True to mount Google Drive (Colab only); USING_COLAB then enables copying logs/models there
USE_DRIVE = False

In [None]:
# check if Colab is used to mount the drive
# (USING_COLAB stays False unless Google Drive is mounted successfully)
USING_COLAB = False
if USE_DRIVE:
	try:
		from google.colab import drive
		# store logs on google drive to avoid colab crash
		drive.mount('/content/drive')
		USING_COLAB = True
	except Exception as err:
		# not running on Colab (ImportError) or the mount failed: keep logs local.
		# Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit propagate.
		print(f"Google Drive not mounted, logs will stay local: {err!r}")
		USING_COLAB = False

# download the images (only if the expected files/folders are not already present)
# NOTE(review): the check does not cover the "queries" folder that the test section
# reads; confirm dataset.zip provides it.
if not (os.path.isfile("annotations_train.csv") and os.path.isdir("test") and os.path.isdir("train")):
	# IPython shell escapes: fetch dataset.zip from Google Drive, unpack it quietly,
	# then delete the archive
	# !wget -c 'https://github.com/itLovaz/DL_Project/raw/main/dataset.zip'
	!gdown --id 14upp5aMYjBwDR57u_H4hMfMKYx5zw8k6
	!unzip dataset.zip > /dev/null
	!rm -r dataset.zip

# Download our pretrained classification model into pretrained_models/
# (skipped when that folder already exists; "!" lines are IPython shell escapes)
if not os.path.isdir("pretrained_models"):
	!gdown --id 116ScZ5oGOQVKDFWLC6DI-4q61-FIw8l4
	!mkdir pretrained_models
	!mv resnet_18_acc_70_30_seed_26_88.63.h5 pretrained_models/resnet_18_acc_70_30_seed_26_88.63.h5

# Datasets

In [None]:
def group_files_per_id(file_list):
	"""Group image file names by the integer identity written before their first "_".

	Returns a dict mapping identity -> list of file names. Identities appear in
	the order they are first seen, and files keep their input order.
	"""
	grouped = {}
	for name in file_list:
		person = int(name.split("_")[0])
		grouped.setdefault(person, []).append(name)
	return grouped

## Training Dataset

In [None]:
class ReidTrainDataset(Dataset):
	"""Training dataset whose items are whole groups of same-identity images.

	Item ``index`` is a tensor stacking ``n_select`` images of one identity, so a
	DataLoader batch has shape (batch, n_select, C, H, W). Identities with fewer
	than ``min_group_length`` images are padded with randomly chosen copies of
	their own images; those copies are marked with a leading ``"#"`` and get
	random augmentation in ``__getitem__``.

	:param file_list: image file names; the identity is the integer before the
		first "_" (as parsed by ``group_files_per_id``).
	:param directory: folder containing the image files.
	:param min_group_length: minimum number of entries per identity after padding.
	:param n_select: number of images returned per identity (<= min_group_length).
	"""

	def __init__(self, file_list, directory, min_group_length = 12, n_select = 12) -> None:
		super().__init__()
		assert min_group_length >= n_select
		self.dir = directory
		self.n_select = n_select

		# Normalize(mean=0, std=1) leaves the tensor values unchanged
		self.transform = T.Compose([
			T.ToTensor(),
			T.Normalize(mean=[0.], std=[1.])
		])

		# applied only to the "#"-marked padding copies: crop a 96x48 window and
		# resize back to 128x64 (presumably the dataset resolution — confirm)
		self.augmentation = T.Compose([
			T.RandomRotation(degrees=(0,15)),
			T.RandomCrop((96, 48)),
			T.Resize((128, 64)),
			T.RandomHorizontalFlip(p=0.5),
		])

		file_list_per_id = group_files_per_id(file_list)

		# pad each identity up to min_group_length with "#"-marked copies
		for files in file_list_per_id.values():
			originals = files.copy()
			for _ in range(min_group_length - len(originals)):
				files.append("#" + np.random.choice(originals))

		self.file_batches = list(file_list_per_id.values())

	def __getitem__(self, index):
		"""Return a (n_select, C, H, W) tensor: a random subset of one identity's images."""
		files = np.array(self.file_batches[index])
		np.random.shuffle(files)

		images = []
		for file in files[:self.n_select]:
			augment = file.startswith("#")
			path = os.path.join(self.dir, file[1:] if augment else file)
			# the context manager closes the image file as soon as it is decoded
			with Image.open(path) as img:
				if augment:
					img = self.augmentation(img)
				images.append(self.transform(img))

		# one stack instead of torch.cat in the loop, which re-copied the growing
		# minibatch on every iteration
		return torch.stack(images) if images else torch.Tensor()

	def __len__(self):
		return len(self.file_batches)

## Validation Dataset

In [None]:
class ReidValidDataset(Dataset):
	"""Validation dataset: item ``i`` is the image ``file_list[i]`` as a tensor.

	``ground_truth`` maps, for every identity with at least two images, the
	index of its first image (used as the query) to the set of indexes of its
	other images (the positives). Identities with a single image are left out:
	a query without positives has no defined average precision.
	"""
	def __init__(self, file_list, directory) -> None:
		super().__init__()
		self.dir = directory
		self.file_list = file_list
		self.transform = T.Compose([
			T.ToTensor(),
			T.Normalize(mean=[0.], std=[1.])
		])

		# group dataset indexes (not file names) by the id before the first "_"
		indexes_per_id = {}
		for i, file in enumerate(file_list):
			img_id = int(file.split("_")[0])
			indexes_per_id.setdefault(img_id, []).append(i)

		# compute ground truth: first image of each id -> indexes of its other images
		self.ground_truth = {
			indexes[0]: set(indexes[1:])
			for indexes in indexes_per_id.values()
			if len(indexes) > 1
		}

	def __getitem__(self, index):
		img = Image.open(os.path.join(self.dir, self.file_list[index]))
		if self.transform:
			img = self.transform(img)

		return img

	def __len__(self):
		return len(self.file_list)

In [None]:
def get_data(train_batch_size, valid_batch_size, data_dir, train_samples_ratio = 0.7, min_group_length = 10, n_select = 10, seed = 26):
	"""Build the training and validation DataLoaders from the images in ``data_dir``.

	The split is done per identity (the integer before the first "_" of each
	file name): identities are shuffled with ``seed``, those at positions below
	``train_samples_ratio * n_identities`` go to training and the rest to
	validation, so no identity appears in both sets.

	:return: ``(train_loader, valid_loader)``. The training loader yields
		identity groups and is shuffled; the validation loader yields single
		images and is not shuffled.
	"""
	import itertools

	# group images by id
	file_list_per_id = group_files_per_id(sorted(os.listdir(data_dir)))

	# shuffle ids; NOTE: this seeds the *global* NumPy RNG, so the random padding
	# choices made afterwards by ReidTrainDataset are reproducible as well
	shuffled_ids = np.array(list(file_list_per_id.keys()))
	np.random.seed(seed)
	np.random.shuffle(shuffled_ids)

	# divide per id and create 2 lists of per-identity file lists
	train_limit = train_samples_ratio * len(file_list_per_id)
	train_groups = []
	valid_groups = []
	for position, img_id in enumerate(shuffled_ids):
		if position < train_limit:
			train_groups.append(file_list_per_id[img_id])
		else:
			valid_groups.append(file_list_per_id[img_id])

	# flatten in one linear pass; reduce(a + b) was quadratic and raised
	# TypeError when one of the splits was empty
	train_list = np.array(list(itertools.chain.from_iterable(train_groups)))
	valid_list = np.array(list(itertools.chain.from_iterable(valid_groups)))

	training_data = ReidTrainDataset(train_list, data_dir, min_group_length, n_select)
	validation_data = ReidValidDataset(valid_list, data_dir)

	train_loader = torch.utils.data.DataLoader(training_data, train_batch_size, shuffle=True)
	valid_loader = torch.utils.data.DataLoader(validation_data, valid_batch_size, shuffle=False)

	return train_loader, valid_loader

# Model

## classification model (to load)

In [None]:
class MyModel(nn.Module):
	"""Multi-attribute person classifier: a backbone plus one small head per attribute.

	``configuration`` keys read here:
	- ``'model'``: backbone module; its ``fc`` attribute is replaced by an identity.
	- ``'use_extra_fc_layer'`` / ``'extra_layer'``: optional layer applied to the
	  backbone features.
	- ``'last_layer_output_size'``: input size of every attribute head.

	``forward(x)`` returns ``(outputs, features)`` where ``outputs`` maps every
	name in ``labels`` to its head's prediction: a 4-way softmax for ``"age"``,
	a single sigmoid for every other label.
	"""

	labels = [
		"age", "backpack", "bag", "handbag", "down", "clothes", "up", "hair", "hat", 
		"gender", "upblack", "upwhite", "upred", "uppurple", "upyellow", "upgray", 
		"upblue", "upgreen", "downblack", "downwhite", "downpink", "downpurple", 
		"downyellow", "downgray", "downblue", "downgreen", "downbrown", "upmulti", "downmulti"]

	def __init__(self, configuration):
		super().__init__()

		self.configuration = configuration
		self.model = self.configuration['model']
		self.model.fc = nn.Identity() # bypass the backbone's own final layer

		# use an extra layer if original network last layer output size is too big
		if self.configuration['use_extra_fc_layer']:
			self.extra_layer = self.configuration['extra_layer']

		# apply last layer differently if age or other attributes
		self.relu = nn.ReLU()
		# NOTE: labelsFC is a plain dict, so its heads are NOT registered submodules
		# (absent from parameters()/state_dict()/train()/eval()); `to` moves them by
		# hand. Switching to nn.ModuleDict would change parameters()/state_dict() for
		# existing training code and saved models — review before changing.
		self.labelsFC = { label: None for label in self.labels }
		self.outputs = self.labelsFC.copy()
		for label in self.labelsFC:
			if label == "age":
				self.labelsFC[label] = nn.Sequential(nn.Linear(self.configuration['last_layer_output_size'], 4), nn.Softmax(dim=1))
			else:
				self.labelsFC[label] = nn.Sequential(nn.Linear(self.configuration['last_layer_output_size'], 1), nn.Sigmoid())

	def to(self, *args, **kwargs):
		"""Move/cast the module, including the unregistered per-label heads.

		Accepts the same arguments as ``nn.Module.to`` — positional or keyword,
		e.g. ``to("cuda")`` or ``to(device="cuda")`` — and returns ``self``.
		"""
		# moves every registered submodule: model, relu and extra_layer (if any)
		super().to(*args, **kwargs)
		# the heads live in a plain dict, which nn.Module.to does not see
		for head in self.labelsFC.values():
			head.to(*args, **kwargs)
		return self

	def forward(self, x):
		"""Return ``(per-label outputs dict, shared feature tensor)`` for batch ``x``."""
		x = self.relu(self.model(x))

		if self.configuration['use_extra_fc_layer']:
			x = self.extra_layer(x)

		# build a fresh dict on every call: reusing one dict made all previously
		# returned results alias (and be overwritten by) the latest forward pass
		outputs = {label: head(x) for label, head in self.labelsFC.items()}
		self.outputs = outputs  # still exposed as an attribute, as before
		return outputs, x

## Re-Identification Model

In [None]:
class ResNetIdentification(nn.Module):
	"""Re-identification network built on a pretrained attribute classifier.

	Loads a whole pickled classifier from ``class_model_loc`` (expected to return
	``(attribute_outputs, features)`` from its forward, like ``MyModel``) and
	returns only its shared feature vector.
	"""
	def __init__(self, class_model_loc="pretrained_models/resnet_acc_88.47.h5"):
		super().__init__()
		self.resnet18 = self._load_model(class_model_loc)
		# replace the classifier's ReLU on the backbone output with an identity,
		# so the features are used without that activation
		self.resnet18.relu = nn.Identity()
		# resnet_out_size = 512

	@staticmethod
	def _load_model(path):
		"""Unpickle a complete model object saved with ``torch.save(model, path)``.

		torch >= 2.6 defaults ``weights_only=True``, which rejects full-model
		pickles, so it is disabled explicitly when the installed torch supports it.
		SECURITY: unpickling runs arbitrary code — only load trusted files.
		"""
		import inspect

		load_kwargs = {}
		if "weights_only" in inspect.signature(torch.load).parameters:
			load_kwargs["weights_only"] = False
		return torch.load(path, **load_kwargs)

	def to(self, *args, **kwargs):
		"""Move/cast the wrapped model; accepts any ``nn.Module.to`` arguments.

		Delegates to the wrapped model's own ``to`` (MyModel overrides it to also
		move its unregistered heads); ``nn.Module.to``/``.cuda`` on this wrapper
		would bypass that override.
		"""
		self.resnet18.to(*args, **kwargs)
		return self
	
	def forward(self, x):
		_, features = self.resnet18(x)
		return features

# Utils

In [None]:
def exp_folder(exp_name):
	"""Create a new, uniquely named log folder under ``runs/`` and return its name.

	Tries ``exp_name`` first, then ``exp_name2``, ``exp_name3``, ... until a
	name that is not already a folder is found.
	"""
	if not os.path.isdir("runs/"):
		os.mkdir("runs/")
	candidate, counter = exp_name, 2
	while os.path.isdir(f"runs/{candidate}"):
		candidate = f"{exp_name}{counter}"
		counter += 1
	os.mkdir(f"runs/{candidate}")
	return candidate

	
def log_data(writer, train_loss, m_ap, epoch):
	"""Record the epoch's training loss and validation mAP on a TensorBoard writer."""
	for tag, value in (("Distance Loss", train_loss), ("map mean avg. precision", m_ap)):
		writer.add_scalar(tag, value, epoch)

def print_data(train_loss, m_ap, epoch):
	"""Print a banner with the 1-based epoch number, the training loss and the mAP."""
	print("▚" * 14)
	print("Epoch: {}".format(epoch + 1))
	print("distance loss: {:^10.5f}".format(train_loss))
	print("m_ap: {:^10.5f}".format(m_ap))

import gc
def clean_gpu(model):
	"""Release GPU memory held by ``model`` after an experiment.

	Moves the model to the CPU with ``model.to("cpu")`` rather than ``.cpu()``,
	so custom ``to`` overrides (which here also move the classifier's
	unregistered per-label heads) are honoured. It then runs the garbage
	collector and empties the CUDA cache; both are harmless without a GPU.
	The caller must drop its own references (model, optimizer) for their
	memory to become reclaimable — rebinding a local name here cannot do that.
	"""
	torch.cuda.empty_cache()
	gc.collect()
	model.to("cpu")
	gc.collect()
	torch.cuda.empty_cache()

def download_to_drive():
	# Zip the local /content/runs log folder and copy the archive to Google Drive.
	# Does nothing unless Drive was mounted (USING_COLAB). The "!" lines are IPython
	# shell escapes, so this only runs inside a notebook.
	# NOTE(review): the archive is written to /content/ but copied via a relative
	# path, so this assumes the working directory is /content (Colab default).
	if USING_COLAB:
	  !zip -r /content/filebestbest.zip /content/runs
	  !cp filebestbest.zip /content/drive/MyDrive/

# mAP and Re-ranking

In [None]:
from typing import Dict, Set, List
def evaluate_map(predictions: Dict[str, List], ground_truth: Dict[str, Set]):
  '''
  Computes the mAP (https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173) of the predictions with respect to the given ground truth
  In person reidentification mAP refers to the mean of the AP over all queries.
  The AP for a query is the area under the precision-recall curve obtained from the list of predictions considering the
  ground truth elements as positives and the other ones as negatives

  :param predictions: dictionary from query filename to list of test image filenames associated with the query ordered
                      from the most to the least confident prediction.
                      Represents the predictions to be evaluated.
  :param ground_truth: dictionary from query filename to set of test image filenames associated with the query
                       Represents the ground truth on which to evaluate predictions.

  :return: the mean AP, in [0, 1], over the queries that have at least one
           ground-truth positive. Queries without predictions count as AP 0.
           Queries with an empty positive set are skipped, since their AP is
           undefined. Returns 0.0 when no query has positives.
  '''

  total_ap = 0.0
  valid_queries = 0
  for query, positives in ground_truth.items():

    # AP is undefined without positives (it would divide by zero below): skip
    if not positives:
      continue
    valid_queries += 1

    # No predictions were performed for the current query, AP = 0
    if query not in predictions:
      continue

    current_ap = 0.0  # The area under the curve for the current sample

    # Recall increments of this quantity each time a new correct prediction is encountered in the prediction list
    delta_recall = 1.0 / len(positives)

    # Each time a positive is encountered, add precision-at-this-rank times the
    # recall step, i.e. the area under the curve since the last positive
    encountered_positives = 0
    for rank, prediction in enumerate(predictions[query], start=1):
      if prediction in positives:
        encountered_positives += 1
        current_ap += (encountered_positives / rank) * delta_recall

    total_ap += current_ap

  # Compute mean over all evaluable queries
  if valid_queries == 0:
    return 0.0
  return total_ap / valid_queries

# Code taken from the paper "Re-ranking Person Re-identification with k-reciprocal Encoding"
# https://github.com/zhunzhong07/person-re-ranking
"""
Created on Mon Jun 26 14:46:56 2017
@author: luohao
"""

"""
CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking
"""

"""
API
probFea: all feature vectors of the query set, shape = (num_query_images, feature_dim)
galFea: all feature vectors of the gallery set, shape = (num_gallery_images, feature_dim)
k1,k2,lambda: parameters, the original paper uses (k1=20,k2=6,lambda=0.3)
MemorySave: set to 'True' when using MemorySave mode
Minibatch: available when 'MemorySave' is 'True'
"""

from scipy.spatial.distance import cdist
def re_ranking(probFea,galFea,k1=20,k2=6,lambda_value=0.3, MemorySave = False, Minibatch = 2000):
    """k-reciprocal re-ranking (Zhong et al., CVPR 2017) of query-to-gallery distances.

    :param probFea: query features, one row per query image.
    :param galFea: gallery features, one row per gallery image (same column count).
    :param k1: neighbourhood size used to build the k-reciprocal sets.
    :param k2: neighbourhood size for local query expansion (k2 == 1 disables it).
    :param lambda_value: weight of the original distance in the final mix
        ``(1 - lambda_value) * jaccard + lambda_value * original``.
    :param MemorySave: compute the pairwise distances in row chunks of ``Minibatch``.
    :param Minibatch: chunk size used when ``MemorySave`` is True.
    :return: array of shape (num_queries, num_gallery); smaller means closer.

    Prints progress messages. NOTE(review): features are cast to float16 and
    squared distances are stored in float16, which overflows to inf above
    65504 — confirm the feature scale keeps pairwise distances below ~255.
    """

    query_num = probFea.shape[0]
    all_num = query_num + galFea.shape[0]    
    # stack queries and gallery so all pairwise distances are computed together
    feat = np.append(probFea,galFea,axis = 0)
    feat = feat.astype(np.float16)
    print('computing original distance')
    if MemorySave:
        # squared Euclidean distances, Minibatch rows at a time
        original_dist = np.zeros(shape = [all_num,all_num],dtype = np.float16)
        i = 0
        while True:
            it = i + Minibatch
            if it < np.shape(feat)[0]:
                original_dist[i:it,] = np.power(cdist(feat[i:it,],feat),2).astype(np.float16)
            else:
                original_dist[i:,:] = np.power(cdist(feat[i:,],feat),2).astype(np.float16)
                break
            i = it
    else:
        # squared Euclidean distances in one go
        original_dist = cdist(feat,feat).astype(np.float16)
        original_dist = np.power(original_dist,2).astype(np.float16)
    del feat
    # NOTE: this equals all_num (queries + gallery), the size of the stacked set
    gallery_num = original_dist.shape[0]
    # divide each column by its maximum, then transpose so every row i holds the
    # distances of image i scaled by its own largest distance
    original_dist = np.transpose(original_dist/np.max(original_dist,axis = 0))
    V = np.zeros_like(original_dist).astype(np.float16)
    # initial_rank[i] lists all images ordered by distance from image i
    initial_rank = np.argsort(original_dist).astype(np.int32)

    
    print('starting re_ranking')
    for i in range(all_num):
        # k-reciprocal neighbors: j is kept if i is also among j's k1+1 nearest
        forward_k_neigh_index = initial_rank[i,:k1+1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1]
        fi = np.where(backward_k_neigh_index==i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            # expand with each candidate's own (k1/2)-reciprocal set when more
            # than 2/3 of that set is already among i's reciprocal neighbours
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2))+1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2))+1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2/3*len(candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index)
            
        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # row i of V: weights exp(-distance) over the expanded set, normalised to sum 1
        weight = np.exp(-original_dist[i,k_reciprocal_expansion_index])
        V[i,k_reciprocal_expansion_index] = weight/np.sum(weight)
    # only the query rows of the original distance are needed from here on
    original_dist = original_dist[:query_num,]    
    if k2 != 1:
        # local query expansion: average V over each image's k2 nearest neighbours
        V_qe = np.zeros_like(V,dtype=np.float16)
        for i in range(all_num):
            V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # invIndex[c]: indexes of the images with a non-zero weight in column c of V
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:,i] != 0)[0])
    
    jaccard_dist = np.zeros_like(original_dist,dtype = np.float16)

    
    for i in range(query_num):
        # temp_min[0, j] accumulates sum_k min(V[i,k], V[j,k]) over the non-zero columns of row i
        temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float16)
        indNonZero = np.where(V[i,:] != 0)[0]
        indImages = []
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]])
        jaccard_dist[i] = 1-temp_min/(2-temp_min)
    
    # mix Jaccard and original distances, then keep the query x gallery block
    final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value
    del original_dist
    del V
    del jaccard_dist
    final_dist = final_dist[:query_num,query_num:]
    return final_dist

# Losses

## Custom

In [None]:
class CustomLoss:
	"""Variance-based loss on grouped features of shape (batch, n_sample, feature_dim).

	Returns (total within-identity variance) - (total variance of the identity
	means across the batch). Minimising it tightens each identity's features
	while pushing the identity means apart. Variances are PyTorch's default
	unbiased estimate, summed over all remaining dimensions.
	"""
	def __init__(self) -> None:
		pass

	def __call__(self, features):
		between_ids = features.mean(dim=1).var(dim=0).sum()  # to be maximised
		within_ids = features.var(dim=1).sum(dim=0).sum()     # to be minimised
		return within_ids - between_ids

## Centroid

In [None]:
class CentroidTripletLoss:
	"""Triplet margin loss between samples and identity centroids.

	For every identity group in the batch, each sample (anchor) is compared to
	its own group's centroid (positive) and to the centroid of another identity
	(negative). The negative is either the identity whose samples are closest to
	the anchor group (summed mean L1 distance) or a random other identity:

	- epochPercNearestNegativePhase: 0 for nearest only,
	                                 1 for random only,
	                                 0.2 for 2phase (random while epochPerc < 0.2)
	- incrProbMode: True to use incremental probability mode (bypasses the first
	  parameter): a random negative is picked with probability 1 - epochPerc.

	``epochPerc`` (training progress in [0, 1]) is set by the caller; it
	defaults to 1, i.e. nearest negatives only.
	"""
	def __init__(self, epochPercNearestNegativePhase = 0.2, incrProbMode = True) -> None:
		self.loss = nn.TripletMarginLoss(reduction="none")
		self.epochPercNearestNegativePhase = epochPercNearestNegativePhase
		self.incrProbMode = incrProbMode
		self.epochPerc = 1

	def _negative_index(self, features, minibatch, i):
		"""Pick the index of the negative identity for group ``i`` (needs >= 2 groups)."""
		dist = torch.cdist(features, minibatch, p=1).mean(2).sum(1)
		order = dist.argsort().cpu().detach().numpy()
		# exclude the anchor's own group explicitly: its distance is the mean
		# intra-group spread, not 0, so it is not guaranteed to rank first
		candidates = order[order != i]
		if self.incrProbMode:
			use_random = np.random.uniform() < 1-self.epochPerc
		else:
			use_random = self.epochPerc < self.epochPercNearestNegativePhase
		return np.random.choice(candidates) if use_random else candidates[0]

	def __call__(self, features):
		"""Sum the loss over a (n_ids, n_samples, feature_dim) batch; returns shape (1,)."""
		feature_per_class = features.shape[1]
		loss = torch.Tensor([0.]).to(features.device)

		if features.shape[0] < 2:
			# a single identity has no negative: return zero, kept attached to the
			# graph when possible so callers can still call backward()
			return loss + features.sum() * 0 if features.requires_grad else loss

		for i, minibatch in enumerate(features):
			neg_idx = self._negative_index(features, minibatch, i)
			# compute centroids, repeated once per anchor
			negative = features[neg_idx].mean(0).repeat(feature_per_class, 1)
			positive = minibatch.mean(0).repeat(feature_per_class, 1)
			# loss(anchor, positive, negative)
			loss += self.loss(minibatch, positive, negative).sum()
		return loss


## Triplet

In [None]:
class TripletLoss:
	"""Triplet margin loss over identity groups of shape (n_ids, n_samples, feature_dim).

	Anchor ``j`` of a group is paired with sample ``n_samples - 1 - j`` of the
	same group as positive (the group flipped), so an even group size avoids
	self-pairs. Negatives come from other identities:

	- nearestIdBased: use id-based strategy (the negatives are all samples of
	  one other identity, the nearest by summed mean L1 distance or a random one).
	  Otherwise the negatives are the n_samples nearest (or random) individual
	  samples of all other identities.
	- epochPercNearestNegativePhase: 0 for nearest only,
	                                 1 for random only,
	                                 0.2 for 2phase (random while epochPerc < 0.2)
	- incrProbMode: True to use incremental probability mode (bypasses the
	  previous parameter): random negatives with probability 1 - epochPerc.

	``epochPerc`` (training progress in [0, 1]) is set by the caller; it
	defaults to 1, i.e. nearest negatives only.
	"""
	def __init__(self, nearestIdBased=False, epochPercNearestNegativePhase = 0.2, incrProbMode = True) -> None:
		self.loss = nn.TripletMarginLoss(reduction="none")
		self.nearestIdBased = nearestIdBased
		self.epochPercNearestNegativePhase = epochPercNearestNegativePhase
		self.incrProbMode = incrProbMode
		self.epochPerc = 1

	def _use_random_negative(self):
		"""Whether to mine a random (instead of nearest) negative; draws from NumPy's RNG in incremental mode."""
		if self.incrProbMode:
			return np.random.uniform() < 1-self.epochPerc
		return self.epochPerc < self.epochPercNearestNegativePhase

	def __call__(self, features):
		"""Sum the loss over all groups of the batch; returns a tensor of shape (1,)."""
		loss = torch.Tensor([0.]).to(features.device)

		if features.shape[0] < 2:
			# a single identity has no negatives: return zero, kept attached to the
			# graph when possible so callers can still call backward()
			return loss + features.sum() * 0 if features.requires_grad else loss

		for i, minibatch in enumerate(features):
			n_anchors = minibatch.shape[0]

			if self.nearestIdBased:
				dist = torch.cdist(features, minibatch, p=1).mean(2).sum(1)
				order = dist.argsort().cpu().detach().numpy()
				# exclude the anchor's own group explicitly: its distance is the
				# intra-group spread, not 0, so it is not guaranteed to rank first
				candidates = order[order != i]
				if self._use_random_negative():
					neg_idx = np.random.choice(candidates)
				else:
					neg_idx = candidates[0]
				negative = features[neg_idx]

			else:
				# every sample of every other identity is a candidate negative
				negative_features = torch.cat((features[:i], features[i+1:])).flatten(0,1)
				dist = torch.cdist(negative_features, minibatch, p=1).mean(1)
				min_dist_idxs = dist.argsort()[:n_anchors]
				if self._use_random_negative():
					negative = negative_features[np.random.randint(negative_features.shape[0], size=[n_anchors])]
				else:
					negative = negative_features[min_dist_idxs]

			positive = torch.flip(minibatch, [0])
			# loss(anchor, positive, negative)
			loss += self.loss(minibatch, positive, negative).sum()
		return loss

## Multi

In [None]:
class MultiLoss:
	"""Weighted sum of TripletLoss, CentroidTripletLoss and CustomLoss.

	``epochPerc`` is copied to both triplet-style losses on every call, and
	``w`` holds the weights in the order (triplet, centroid, custom).
	"""
	def __init__(self) -> None:
		self.tripletLoss = TripletLoss(nearestIdBased=False, epochPercNearestNegativePhase=0, incrProbMode=True)
		self.centroidLoss = CentroidTripletLoss(epochPercNearestNegativePhase=False, incrProbMode=True)
		self.customLoss = CustomLoss()
		self.epochPerc = 1
		self.w = [10, 5, 1] # weights for each loss (triplet, centroid, custom)

	def __call__(self, features):
		for scheduled in (self.tripletLoss, self.centroidLoss):
			scheduled.epochPerc = self.epochPerc
		# evaluated in this order so the NumPy random draws match the original
		triplet_term = self.tripletLoss(features)
		centroid_term = self.centroidLoss(features)
		custom_term = self.customLoss(features)
		return triplet_term * self.w[0] + centroid_term * self.w[1] + custom_term * self.w[2]

# Train and Test functions

In [None]:
def train_reid(net, data_loader, optimizer, cost_function, device="cpu"):
	"""Run one training epoch; return the summed batch loss divided by the number of images.

	Every batch has shape (batch, n_select, C, H, W): groups of ``n_select``
	images sharing an identity. The groups are flattened into one image batch
	for ``net``, and the resulting features are regrouped to
	(batch, n_select, feature_dim) before ``cost_function`` is applied.
	"""
	net.train()
	seen_images = 0.
	loss_sum = 0.
	with torch.set_grad_enabled(True):
		for groups in data_loader:
			groups = groups.to(device)
			n_groups, group_size, channels, height, width = groups.shape
			seen_images += n_groups * group_size

			features = net(groups.reshape((n_groups * group_size, channels, height, width)))
			features = features.reshape((n_groups, group_size, features.shape[1]))

			loss = cost_function(features)
			loss.backward()
			optimizer.step()
			optimizer.zero_grad()

			loss_sum += loss.item()

	return loss_sum / seen_images

def test_reid(net, data_loader, device="cpu", useReranking=False, select_size=0):
	"""Return the validation mAP (in percent) of ``net`` on ``data_loader``.

	Every image is embedded once. Each query index listed in
	``data_loader.dataset.ground_truth`` is then ranked against all other
	images, by L1 feature distance or, with ``useReranking``, by the
	k-reciprocal re-ranked distance. With ``select_size > 0`` only the first
	``select_size - 1`` ranked images are kept per query.
	"""
	ground_truth = data_loader.dataset.ground_truth

	# first compute all the features (one concatenation instead of re-copying
	# the growing tensor on every batch)
	net.eval()
	with torch.no_grad():
		chunks = [net(batch.to(device)).to("cpu") for batch in data_loader]
	all_features = torch.cat(chunks) if chunks else torch.Tensor()

	if useReranking:
		dist_matrix = re_ranking(all_features, all_features)

	# then compute distance and the prediction dict {query index: [ranked indexes]}
	predictions = {}
	for index in ground_truth.keys():
		if useReranking:
			dist = dist_matrix[index]
		else:
			dist = torch.cdist(all_features, all_features[index][None, :], p=1).squeeze(1)

		ranking = dist.argsort()
		# drop the query itself explicitly: with ties (e.g. duplicate images) it is
		# not guaranteed to be at position 0, so slicing [1:] could drop a match
		ranking = ranking[ranking != index]
		if select_size > 0:
			ranking = ranking[:select_size - 1]

		predictions[index] = ranking.tolist()

	m_ap = evaluate_map(predictions, ground_truth)

	return m_ap * 100

In [None]:
def main(net, train_batch_size, valid_batch_size, device, epochs, optimizer, cost_function, data_dir, train_sample_ratio, min_group_length, n_select, exp_name="reid", save_threshold=68):
	"""Train a re-identification network and report the validation mAP every epoch.

	Logs to a fresh TensorBoard folder ``runs/<exp_name>``. Whenever a new best
	validation mAP above ``save_threshold`` (percent) is reached, the whole
	network is saved as ``reid_<mAP>.h5`` (and copied to Drive on Colab). The
	run finishes with one evaluation using re-ranking.

	:param cost_function: loss applied to features grouped per identity. If it
		has an ``epochPerc`` attribute, it is set to ``epoch / epochs`` before
		every epoch (negative-mining schedule).
	:param save_threshold: minimum best mAP (percent) required to save the model.
	"""
	# Creates a logger for the experiment
	max_map = epochs_without_improvements = 0
	exp_name = exp_folder(exp_name)
	writer = SummaryWriter(log_dir=f"runs/{exp_name}", comment=f"experiment {exp_name}")
	try:
		train_loader, valid_loader = get_data(train_batch_size, valid_batch_size, data_dir, train_sample_ratio, min_group_length, n_select)

		train_loss = float("nan")  # reported by the final print even when epochs == 0
		for e in range(epochs):

			# set to the loss the epoch percentage: curr_epoch / max_epochs
			# (duck-typed, so MultiLoss and subclasses are scheduled too)
			if hasattr(cost_function, "epochPerc"):
				cost_function.epochPerc = e / epochs

			train_loss = train_reid(net, train_loader, optimizer, cost_function, device)
			m_ap = test_reid(net, valid_loader, device)

			log_data(writer, train_loss, m_ap, e)

			print_data(train_loss, m_ap, e)

			# save the weights when accuracy increases
			max_map = max(m_ap, max_map)
			if max_map == m_ap:
				print("Best accuracy: {0}".format(max_map))
				epochs_without_improvements = 0
				if max_map > save_threshold:
					model_name = 'reid_{0}.h5'.format(round(max_map, 2))
					torch.save(net, model_name)
					if USING_COLAB:
						os.system("cp {0} /content/drive/MyDrive/".format(model_name))
			else:
				# the best mAP did not improve this epoch
				epochs_without_improvements += 1
				print("Epochs without improvements: {0}".format(epochs_without_improvements))

		# Compute Re-ranking at the end
		print("\nFinal test with reranking...\n")
		m_ap = test_reid(net, valid_loader, device, True)
		print_data(train_loss, m_ap, epochs)
	finally:
		writer.close()

# Start training

In [None]:
# data
data_dir = "train"
train_sample_ratio = 0.8  # fraction of identities used for training
min_group_length = 12     # identities are padded with augmented copies up to this size
n_select = 12 # even number <= min_group_length (TripletLoss pairs image i with image n_select-1-i)

# optimisation
epochs = 50
train_batch_size = 128  # identity groups per training batch
test_batch_size = 512   # single images per validation batch
lr = 1.5e-4
weight_decay = 1e-3
device = "cuda"
cost_function = TripletLoss() # default incr. prob. mode

## Single Training

In [None]:
# build the network on top of the pretrained classifier and train it once
net = ResNetIdentification("pretrained_models/resnet_18_acc_70_30_seed_26_88.63.h5").to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
main(
	net, train_batch_size, test_batch_size, device, epochs, optimizer, cost_function,
	data_dir, train_sample_ratio, min_group_length, n_select, exp_name="reid",
)

## Multi Training for experiments

In [None]:
_PRETRAINED_CLASSIFIER = "pretrained_models/resnet_18_acc_70_30_seed_26_88.63.h5"

# experiment name -> pretrained classifier to start from and loss to train with
configurations = {
	'reid_triplet_incr_prob': {'model': _PRETRAINED_CLASSIFIER, 'loss': TripletLoss(False, 0, True)},
	'reid_triplet_nearest_only': {'model': _PRETRAINED_CLASSIFIER, 'loss': TripletLoss(False, 0, False)},
	'reid_centroid_incr_prob': {'model': _PRETRAINED_CLASSIFIER, 'loss': CentroidTripletLoss(0, True)},
	'reid_custom_loss': {'model': _PRETRAINED_CLASSIFIER, 'loss': CustomLoss()},
}

def multi_experiments(configurations, num_repeated_training = 1):
	"""Run every configuration ``num_repeated_training`` times, each with a fresh network.

	Each configuration supplies the pretrained classifier path (``"model"``) and
	the loss instance (``"loss"``). The other hyper-parameters come from the
	notebook-level globals (device, lr, weight_decay, epochs, batch sizes, ...).
	"""
	for _ in range(num_repeated_training):
		for exp_name, configuration in configurations.items():

			# create the network
			net = ResNetIdentification(configuration["model"]).to(device)
			optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
			cost_function = configuration["loss"]

			print('#######################################')
			print(exp_name)
			print('#######################################')

			# train the network
			main(net, train_batch_size, test_batch_size, device, epochs, optimizer, cost_function, data_dir, train_sample_ratio, min_group_length, n_select, exp_name)

			download_to_drive()

			# drop the optimizer before cleaning: its Adam state tensors stay on the
			# device and would otherwise survive until the next iteration rebinds it
			del optimizer
			clean_gpu(net)
			del net

# multi_experiments(configurations, 3)

In [None]:
# %load_ext tensorboard
# !rmdir /s /q "%appdata%/../Local/Temp/.tensorboard-info" # Windows only
# %tensorboard --logdir runs/ --host localhost --port 8888
# if you use vscode go to http://localhost:8888/

# Test

## Generate distance matrix from query and test features

In [None]:
class ImageDataset(Dataset):
	"""Unlabelled image-folder dataset: item ``i`` is the i-th file (sorted by name) as a tensor."""
	def __init__(self, directory) -> None:
		super().__init__()
		self.dir = directory
		self.transform = T.Compose([T.ToTensor(), T.Normalize(mean=[0.], std=[1.])])
		self.file_list = sorted(os.listdir(directory))

	def __len__(self):
		return len(self.file_list)

	def __getitem__(self, index):
		img = Image.open(os.path.join(self.dir, self.file_list[index]))
		return self.transform(img) if self.transform else img

def _extract_features(net, loader, device, label):
	"""Embed every batch of ``loader`` with ``net``, printing progress; return CPU features."""
	chunks = []
	for i, batch in enumerate(loader):
		chunks.append(net(batch.to(device)).to("cpu"))
		print(f"\rComputing {label} features: {i+1}/{len(loader)}", end="")
	# one concatenation instead of re-copying the growing tensor on every batch
	return torch.cat(chunks) if chunks else torch.Tensor()


def generate_dist_matrix(net, query_dataset, test_dataset, device = "cuda", batch_size = 512):
	"""Return the re-ranked distance matrix between query and test images.

	Both datasets are embedded with ``net`` (in eval mode, without gradients)
	and passed to ``re_ranking``. The result has one row per query image and
	one column per test image.
	"""
	net.eval()
	net.to(device)
	test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)
	query_loader = torch.utils.data.DataLoader(query_dataset, batch_size, shuffle=False)

	with torch.no_grad():
		test_features = _extract_features(net, test_loader, device, "test")
		print()
		query_features = _extract_features(net, query_loader, device, "query")

	print("\nComputing reranking pass...")
	dist_matrix = re_ranking(query_features, test_features)
	return dist_matrix


# to evaluate a saved model instead of the one just trained:
# net = torch.load("pretrained_models/reid_map_68.63.h5")
query_dataset, test_dataset = ImageDataset("queries"), ImageDataset("test")
dist_matrix = generate_dist_matrix(net, query_dataset, test_dataset)

## Generate prediction file

In [None]:
def write_prediction_file(dist_matrix, query_dataset, test_dataset, path="predictions/reid_test2.txt", select_size=100):
	"""Write the ``select_size`` closest test images of every query to ``path``.

	Each output line has the form ``"<query file>: <test file>, <test file>, ..."``,
	with test files ordered from smallest to largest distance in ``dist_matrix``
	(one row per query, one column per test image). Missing parent folders of
	``path`` are created.
	"""
	directory = os.path.dirname(path)
	if directory:
		os.makedirs(directory, exist_ok=True)

	test_data_file_list = np.array(test_dataset.file_list)
	with open(path, "w", encoding="utf-8") as reid_test_file:
		for index in range(len(query_dataset)):
			min_args = dist_matrix[index].argsort()[:select_size]
			best_file_list = test_data_file_list[min_args].tolist()
			s = ', '.join(best_file_list)
			reid_test_file.write(f'{query_dataset.file_list[index]}: {s}\n')

write_prediction_file(dist_matrix, query_dataset=query_dataset, test_dataset=test_dataset)

In [None]:
def show_prediction(query_img_id, dist_matrix, query_data, test_data, n_select = 50, treshold = None):
	"""Display a query image, then a grid of its closest test images.

	:param query_img_id: index of the query in ``query_data`` and row of ``dist_matrix``.
	:param dist_matrix: distances, one row per query and one column per test image.
	:param query_data: dataset of (C, H, W) query image tensors.
	:param test_data: dataset of (C, H, W) test image tensors.
	:param n_select: number of closest test images to draw, 10 per row, each
		titled with its distance.
	:param treshold: if given, drawing stops right after the first image whose
		distance exceeds it.
	"""
	import math

	query_img = query_data[query_img_id]
	plt.imshow(query_img.permute(1,2,0))
	plt.show()

	dist = dist_matrix[query_img_id]
	ranking = dist.argsort().squeeze()

	n_cols = 10
	n_rows = max(1, math.ceil(n_select / n_cols))
	# squeeze=False always yields a 2-D grid of axes; with the default, a single
	# row (10 <= n_select < 20) came back 1-D and iterating its "rows" crashed
	_, axs = plt.subplots(n_rows, n_cols, figsize=(16,17), squeeze=False)
	cells = axs.flatten()
	for ax in cells[max(n_select, 0):]:
		ax.axis("off")  # grid cells beyond n_select stay empty

	for idx in range(n_select):
		ax = cells[idx]
		best_img = test_data[ranking[idx]]
		ax.imshow(best_img.permute(1,2,0))
		ax.set_title(dist[ranking[idx]])
		# `is not None` so that a threshold of 0 is honoured too
		if treshold is not None and dist[ranking[idx]] > treshold:
			break

show_prediction(126, dist_matrix, query_data=query_dataset, test_data=test_dataset)