In [None]:
import os

import torch

from options import Option
from data_utils.dataset import load_data_test
from model.model import Model
from utils.util import setup_seed, load_checkpoint
import torchvision
import einops

In [None]:
args = Option().parse()
args.load = "./checkpoint/sketchy_ext/best_checkpoint.pth"
args.batch = 2


print("test args:", str(args))


In [None]:
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from utils.ap import calculate
from tqdm import tqdm

import time

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = args.choose_cuda
print("current cuda: " + args.choose_cuda)
setup_seed(args.seed)


In [None]:
# prepare data
sk_valid_data, im_valid_data = load_data_test(args)

# prepare model
model = Model(args)
model = model.half()

if args.load is not None:
    checkpoint = load_checkpoint(args.load)

cur = model.state_dict()
new = {k: v for k, v in checkpoint['model'].items() if k in cur.keys()}
cur.update(new)
model.load_state_dict(cur)

if len(args.choose_cuda) > 1:
    model = torch.nn.parallel.DataParallel(model.to('cuda'))
model = model.cuda()


In [None]:
model.eval()
torch.set_grad_enabled(False)

print('loading image data')
sk_dataload = DataLoader(sk_valid_data, batch_size=args.test_sk, num_workers=args.num_workers, drop_last=False)
print('loading sketch data')
im_dataload = DataLoader(im_valid_data, batch_size=args.test_im, num_workers=args.num_workers, drop_last=False)

dist_im = None
all_dist = None


In [None]:
print(sk_valid_data[30][0])
torchvision.utils.save_image(sk_valid_data[30][0].cuda(),"./output/sk-30.jpg")

In [None]:
for i, (sk, sk_label) in enumerate(tqdm(sk_dataload)):
        #sk.shape=(20,3,224,224)
        print(i)
        if i == 0:
            all_sk_label = sk_label.numpy()
        else:
            all_sk_label = np.concatenate((all_sk_label, sk_label.numpy()), axis=0)

        sk_len = sk.size(0)
        sk = sk.cuda()
        #debug
        print(sk.shape, sk[0].shape)
        # cv2.imwrite(f"./output/sk-{i}",sk[0].cpu().numpy())
        if i==0:
            grid_sk = torchvision.utils.make_grid(sk)
            print(sk)
            torchvision.utils.save_image(grid_sk,f"./output/sk.jpg")
        
        sk, sk_idxs = model(sk, None, 'test', only_sa=True)#sk.shape=(20,192,768)
        for j, (im, im_label) in enumerate(tqdm(im_dataload)):
            if i == 0 and j == 0:
                all_im_label = im_label.numpy()
            elif i == 0 and j > 0:
                all_im_label = np.concatenate((all_im_label, im_label.numpy()), axis=0)

            im_len = im.size(0)
            im = im.cuda()
            im, im_idxs = model(im, None, 'test', only_sa=True)

            sk_temp = sk.unsqueeze(1).repeat(1, im_len, 1, 1).flatten(0, 1).cuda() #(400,197,768) #?difference
            im_temp = im.unsqueeze(0).repeat(sk_len, 1, 1, 1).flatten(0, 1).cuda() #(400,197,768)
            
            if args.retrieval == 'rn':
                feature_1, feature_2 = model(sk_temp, im_temp, 'test')
            #? when retrieval == 'sa'
            if args.retrieval == 'sa':
                feature_1, feature_2 = torch.cat((sk_temp[:, 0], im_temp[:, 0]), dim=0), None

            # print(feature_1.size())    # [2*sk*im, 768] #2 means sk and im cls
            # print(feature_2.size())    # [sk*im, 1]

            if args.retrieval == 'rn':
                if j == 0:
                    dist_im = - feature_2.view(sk_len, im_len).cpu().data.numpy()  # 1*args.batch
                else:
                    dist_im = np.concatenate((dist_im, - feature_2.view(sk_len, im_len).cpu().data.numpy()), axis=1)
            if args.retrieval == 'sa':
                dist_temp = F.pairwise_distance(F.normalize(feature_1[:sk_len * im_len]),
                                                F.normalize(feature_1[sk_len * im_len:]), 2)
                if j == 0:
                    dist_im = dist_temp.view(sk_len, im_len).cpu().data.numpy()
                else:
                    dist_im = np.concatenate((dist_im, dist_temp.view(sk_len, im_len).cpu().data.numpy()), axis=1)

        if i == 0:
            all_dist = dist_im
        else:
            all_dist = np.concatenate((all_dist, dist_im), axis=0)
        print(all_dist.shape)
        #all_dist.shape=(all_sk_label.size, all_im_label.size)
    # print(all_sk_label.size, all_im_label.size)     # [762 x 1711] / 2
class_same = (np.expand_dims(all_sk_label, axis=1) == np.expand_dims(all_im_label, axis=0)) * 1
# print(all_dist.size, class_same.size)     # [762 x 1711] / 2


In [None]:
print(all_dist.shape, class_same.shape)
print(all_dist, class_same)
np.savetxt("./output/all_dist",all_dist)
np.savetxt("./output/class_same",class_same)


In [None]:
map_all, map_200, precision100, precision200 = calculate(all_dist, class_same, test=True)

In [None]:
arg_sort_sim = all_dist.argsort()   # 得到从小到大索引值
print(arg_sort_sim.shape)
print(arg_sort_sim)

In [None]:
(sk,_) = sk_valid_data[0]
sk = torch.unsqueeze(sk,0)

(im,_) = im_valid_data[30]
im = torch.unsqueeze(im,0)

im = torch.cat((im,im_valid_data[28][0].unsqueeze(0),im_valid_data[27][0].unsqueeze(0)))

print(sk.shape, im.shape)

In [None]:
from model import rn

sk_sa, sk_idxs = model(sk.cuda(), None, 'test', only_sa=True)#sk.shape=(20,192,768)
im_sa, im_idxs = model(im.cuda(), None, 'test', only_sa=True)#sk.shape=(20,192,768)


sk_im_sa = torch.cat((sk_sa, im_sa), dim=0)
ca_fea = model.ca(sk_im_sa)  # [2b, 197, 768]
cls_fea = ca_fea[:, 0]  # [2b, 1, 768]
token_fea = ca_fea[:, 1:]  # [2b, 196, 768]
batch = token_fea.size(0)

print(token_fea.shape)


In [None]:
# token_fea = einops.rearrange(token_fea,"b d h w -> b d (h w)") #token_fea = token_fea.view(batch, 768, 14, 14)

sk_fea = token_fea[sk.size(0)]
im_fea = token_fea[sk.size(0)+1:]
# np.savetxt("./output/sk_fea", sk_fea.cpu())
# np.savetxt("./output/im_fea", im_fea.cpu())
print(sk_fea, im_fea)
cos_scores = rn.cos_similar(sk_fea, im_fea)
print(cos_scores.shape, cos_scores)
np.savetxt("./output/cos_scores",cos_scores.cpu()[0])

In [None]:
# print(cos_scores.argsort(0).shape,cos_scores.argsort(0))
# print(torch.argmax(einops.rearrange(cos_scores,"a b c -> b (a c)")))
b = einops.rearrange(cos_scores,"a b c -> b (a c)")
# print(cos_scores.shape,cos_scores)

max_indices = torch.empty((0,2), dtype=int)
print(b)
print(max_indices)

for i in b:
    max_indices_item = torch.argmax(i)
    print(i.shape)
    new = np.unravel_index(max_indices_item.cpu(),(cos_scores.shape[0],cos_scores.shape[2]))
    # print(torch.Tensor(new))
    max_indices = torch.cat((max_indices, torch.tensor(new, dtype=torch.int).unsqueeze(0)), 0)
    print(max_indices)
    
# print(np.unravel_index(b.values, (3, 196)))
np.savetxt("./output/max_indices",max_indices)


In [83]:
patch_match(im, max_indices)

indices = max_indices
print(im.shape)
x = torch.zeros((0,)+tuple(im.shape[1:]))
print(x)
for i in indices:
    selected_im = i[0]
    patch_index = np.unravel_index(i[1],(16,16))
    item = patch2im(patch_index, im, im.shape[-1]/16)
    x= torch.cat(x, item)

torch.Size([3, 3, 224, 224])
tensor([], size=(0, 3, 224, 224))


TypeError: cat() received an invalid combination of arguments - got (Tensor, Tensor), but expected one of:
 * (tuple of Tensors tensors, int dim, *, Tensor out)
 * (tuple of Tensors tensors, name dim, *, Tensor out)


In [None]:
def patch_match(im, indices):
    '''
        im: (b,c,w,h)
        indices: (m,im.shape.len)
    '''
    print(im.shape)
    x = torch.zeros((0,)+tuple(im.shape[1:]))
    print(x)
    for i in indices:
        selected_im = i[0]
        patch_index = np.unravel_index(i[1],(16,16))
        item = patch2im(patch_index, im, im.shape[-1]/16)
        x= torch.cat(x, item)
    return x 
def patch2im(patch_index,im, patch_size):
    '''
    im: (c, w, h)
    patch_index: (2)
    '''
    width_range = (patch_index[0]*patch_size,(patch_index[0]+1)*patch_size)
    height_range = (patch_index[1]*patch_size,(patch_index[1]+1)*patch_size)
    return im[:,width_range,height_range]

In [None]:
  # [2b, n, n] n=patches**2
cos_scores = cos_scores.view(batch // 2, -1)

rn.cos_similar(im)

In [None]:
print(sk_dataload[30])

In [None]:
# valid
# map_all, map_200, precision_100, precision_200 = valid_cls(args, model, sk_valid_data, im_valid_data)
print(f'map_all:{map_all:.4f} map_200:{map_200:.4f} precision_100:{precision100:.4f} precision_200:{precision200:.4f}')