In [3]:
#!/usr/bin/env/torch11 python3
import copy
import glob
import os
import pickle
import sys
import time
from datetime import datetime

import numpy as np

# pytorch, torch vision
import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
#from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from dataloader import DataLoader

sys.path.append('/home/absamant/UCDR/src/')
from options.options_snmpnet_ssl import Options
from data.DomainNet import domainnet
from data.dataloaders import BaselineDataset
from models.deepall.deepall_seresnet50 import SnMpNet_SSL
from utils.logger import AverageMeter
#from trainer import evaluate

# Seed every RNG source up front so the notebook reproduces on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed the Generator as well: an unseeded default_rng() pulls fresh OS entropy on
# every kernel start, silently defeating the seeding above for code that uses RG.
RG = np.random.default_rng(SEED)

In [4]:
class Params:
    """Minimal stand-in for a parsed command-line options object.

    Only ``root_path`` is defined. An instance is passed as ``args`` to the
    DomainNet split helper and to the evaluation routine later in the notebook;
    any other attribute those helpers read is not provided here.
    """

    root_path = '/BS/UCDR/work/datasets/'

In [5]:
# --- Experiment configuration ------------------------------------------------
# seen_domain and holdout_domain are both evaluated as query domains against
# the gallery_domain images.
seen_domain, holdout_domain, gallery_domain = 'sketch', 'quickdraw', 'real'
# Integer flag; when falsy the checkpoint folder name gets a '_noaux' suffix.
include_auxillary_domains = 1

# Data loading
num_workers = 1
batch_size = 64

# Root of the trained checkpoints and the dataset sub-folder beneath it.
checkpoint_path_remote = '/BS/UCDR/work/pretrained_models/'
dataset = 'DomainNet'

# Options object handed to the split helper and the evaluation routine.
args = Params()

In [8]:
# Class splits and word2vec embeddings shipped with the repository.
_split_dir = '../../data/DomainNet/'
va_classes = np.load(_split_dir + 'val_classes.npy').tolist()
te_classes = np.load(_split_dir + 'test_classes.npy').tolist()
# NOTE(review): allow_pickle=True can execute code embedded in the file; only load trusted data.
semantic_vec = np.load(_split_dir + 'w2v_domainnet.npy', allow_pickle=True, encoding='latin1').item()

use_gpu = torch.cuda.is_available()
device = torch.device('cuda:0') if use_gpu else torch.device('cpu')
if use_gpu:
    # Let cuDNN benchmark/pick convolution algorithms, and seed every visible GPU.
    cudnn.benchmark = True
    torch.cuda.manual_seed_all(0)
print(f'\nDevice:{device}')

root_path = ''

tr_classes = np.load(_split_dir + 'train_classes.npy').tolist()
	
# ImageNet normalisation statistics. Currently unused because the ToTensor /
# Normalize(im_mean, im_std) steps are disabled in both pipelines below.
im_mean = [0.485, 0.456, 0.406]
im_std = [0.229, 0.224, 0.225]

# Image transformations. Both pipelines end in np.asarray, so samples come out as
# numpy arrays rather than tensors — presumably the custom DataLoader converts and
# normalises them (TODO confirm). np.asarray is used directly instead of
# `lambda x: np.asarray(x)`: it is the same call but picklable, which matters if
# the loader starts worker processes with the 'spawn' method.
image_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop((224, 224), (0.8, 1.0)),
        transforms.RandomHorizontalFlip(0.5),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
        np.asarray,
    ]),
    'eval': transforms.Compose([
        transforms.Resize((224, 224)),
        np.asarray,
    ]),
}

# Model: built with pretrained=None; trained weights are loaded from the checkpoint
# later if it exists. Placed on the selected device explicitly (instead of the old
# hard-coded, commented-out .cuda()) so the cell works on CPU-only and GPU machines.
model = SnMpNet_SSL(semantic_dim=300, pretrained=None).to(device)

# Checkpoints and results are grouped per seen/unseen/gallery combination.
save_folder_name = 'seen-' + seen_domain + '_unseen-' + holdout_domain + '_x_' + gallery_domain
if not include_auxillary_domains:
    save_folder_name += '_noaux'

path_cp = os.path.join(checkpoint_path_remote, dataset, save_folder_name)
today = datetime.now().strftime('%B %d, %Y')
path_log = os.path.join('./results', dataset, save_folder_name)
# exist_ok avoids the check-then-create race of isdir() + makedirs().
os.makedirs(path_log, exist_ok=True)

# Earlier results live in dated sub-folders, one file per evaluated checkpoint and
# named after it; map each back to the checkpoint's .pth name. splitext strips
# whatever extension the file has (the old fixed-length slice assumed '.txt'),
# and glob on an empty directory simply returns [].
models_tested = [
    os.path.splitext(os.path.basename(result_path))[0] + '.pth'
    for result_path in glob.glob(os.path.join(path_log, '*/*.*'))
]
print('Total models tested before: ', len(models_tested))

path_log_save = os.path.join(path_log, today)
os.makedirs(path_log_save, exist_ok=True)

	# for best_model_name in os.listdir(path_cp):

	# 	if best_model_name not in models_tested:

best_model_name = 'val_map200-0.3794_prec200-0.3032_ep-5_mixlevel-img_wcce-1.0_wratio-1.0_wmse-1.0_clswts-1.0_e-100_es-15_opt-sgd_bs-60_lr-0.001_l2-0.0_beta-2.0_warmup-2_seed-0_tv-0.pth'
#best_model_name = 'val_map200-0.2945_prec200-0.2363_ep-19_mixlevel-img_wcce-1.0_wratio-1.0_wmse-1.0_clswts-1.0_e-100_es-15_opt-sgd_bs-60_lr-0.001_l2-0.0_beta-1_warmup-2_seed-0_tv-0.pth'
# best_model_name = 'val_map200-0.2935_prec200-0.2354_ep-7_mixlevel-img_wcce-1.0_wratio-1.0_wmse-1.0_clswts-1.0_e-100_es-15_opt-sgd_bs-60_lr-0.001_l2-0.0_beta-1_warmup-2_seed-0_tv-0.pth'
best_model_file = os.path.join(path_cp, best_model_name)

if os.path.isfile(best_model_file):
		
	print("\nLoading best model from '{}'".format(best_model_file))
	# load the best model yet
	checkpoint = torch.load(best_model_file)
	epoch = checkpoint['epoch']
	best_map = checkpoint['best_map']
	model.load_state_dict(checkpoint['model_state_dict'])
	print("Loaded best model '{0}' (epoch {1}; mAP@200 {2:.4f})\n".format(best_model_file, epoch, best_map))

	outstr = ''

	gzs = 0
	splits_gallery = domainnet.trvalte_per_domain(args, gallery_domain, gzs, tr_classes, va_classes, te_classes)
	data_te_gallery = BaselineDataset(np.array(splits_gallery['te']), transforms=image_transforms['eval'])
	# PyTorch test loader for gallery
	te_loader_gallery = DataLoader(dataset=data_te_gallery, batch_size=64*5, shuffle=False, 
									   num_workers=num_workers, pin_memory=True)

	for domain in [seen_domain, holdout_domain]:

		test_head_str = 'Query:' + domain + '; Gallery:' + gallery_domain + '; Generalized:' + str(gzs)
		print(test_head_str)
		outstr += test_head_str

		splits_query = domainnet.trvalte_per_domain(args, domain, 0, tr_classes, va_classes, te_classes)
			
			
		data_te_query = BaselineDataset(np.array(splits_query['te']), transforms=image_transforms['eval'])			
		data_te_comb = BaselineDataset(np.array(splits_query['te'] + splits_gallery['te']), transforms=image_transforms['eval'])

		# PyTorch test loader for query
		te_loader_query = DataLoader(dataset=data_te_query, batch_size=64*5, shuffle=False, 
										 num_workers=num_workers)

		te_loader_ttt = DataLoader(dataset=data_te_comb, batch_size=batch_size, shuffle=True, 
									   num_workers=num_workers)

		model_ttt = SSL_train(te_loader_ttt, model)

		print(f'#Test queries:{len(te_loader_query.dataset)}; #Test gallery samples:{len(te_loader_gallery.dataset)}.\n')
		# te_data = evaluate(te_loader_query, te_loader_gallery, model_ttt, None, epoch, args, 'Usual')
		te_data = evaluate(te_loader_query, te_loader_gallery, model_ttt, None, epoch, args, 'val')
		
			# outstr+="\n\nmAP@200 = %.4f, Prec@200 = %.4f, mAP@all = %.4f, Prec@100 = %.4f, Time = %.6f\nmAP@200 (binary) = %.4f, "\
			# 	   "Prec@200 (binary) = %.4f, mAP@all (binary) = %.4f, Prec@100 (binary) = %.4f, Time (binary) = %.6f"\
			# 	   %(np.mean(te_data['aps@200']), te_data['prec@200'], np.mean(te_data['aps@all']), te_data['prec@100'], 
			# 		te_data['time_euc'], np.mean(te_data['aps@200_bin']), te_data['prec@200_bin'], np.mean(te_data['aps@all_bin']), 
			# 		te_data['prec@100_bin'], te_data['time_bin'])

		outstr+="\n\nmAP@200 = %.4f, Prec@200 = %.4f"%(np.mean(te_data['aps@200']), te_data['prec@200'])

		outstr += '\n\n'
		
	print(outstr)
	result_file = open(os.path.join(path_log_save, best_model_name[:-len('.pth')]+'.txt'), 'w')
	result_file.write(outstr)
	result_file.close()

	print('\nTest Results saved!')

		# else:
		# 	continue


Device:cpu
Total models tested before:  0
