In [1]:
import os
import torch
from options import Option
from data_utils.dataset import load_data_test
from model.model import Model
from utils.util import setup_seed, load_checkpoint
import torchvision
import einops




In [2]:
# Parse the project options, then override the two settings this notebook
# needs: the checkpoint to evaluate and the batch size.
args = Option().parse()
vars(args).update(
    load="./checkpoint/sketchy_ext/best_checkpoint.pth",
    batch=2,
)

print("test args:", str(args))

test args: Namespace(data_path='./datasets', dataset='sketchy_extend', test_class='test_class_sketchy25', cls_number=100, d_model=768, d_ff=1024, head=8, number=1, pretrained=True, anchor_number=49, save='./checkpoints/sketchy_ext', batch=2, epoch=30, datasetLen=10000, learning_rate=1e-05, weight_decay=0.01, load='./checkpoint/sketchy_ext/best_checkpoint.pth', retrieval='rn', testall=False, test_sk=20, test_im=20, num_workers=4, choose_cuda='0', seed=2021)


In [3]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.ap import calculate
from tqdm import tqdm
import time

In [4]:
# Restrict the GPUs visible to this process to the ones named in the options.
# NOTE(review): torch is already imported above; this presumably still takes
# effect because no CUDA call has been made yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = args.choose_cuda
print("current cuda: " + args.choose_cuda)
# Seed through the project helper so repeated runs are comparable.
setup_seed(args.seed)

current cuda: 0


In [5]:
# prepare data: test-split sketch and image datasets
sk_valid_data, im_valid_data = load_data_test(args)

# prepare model (cast to half precision before the weights are loaded, so the
# checkpoint tensors are copied into fp16 parameters)
model = Model(args)
model = model.half()

if args.load is not None:
    checkpoint = load_checkpoint(args.load)

    # Only copy the checkpoint entries whose keys exist in the current model;
    # everything else keeps its freshly initialised value.
    # Kept inside the `if`: previously `checkpoint` was read even when no
    # load path was given, which raised a NameError.
    cur = model.state_dict()
    new = {k: v for k, v in checkpoint['model'].items() if k in cur}
    cur.update(new)
    model.load_state_dict(cur)

model = model.cuda()
# `choose_cuda` is a comma-separated id list (e.g. "0,1"); count the ids rather
# than the characters so a single id such as "10" is not treated as multi-GPU.
if len(args.choose_cuda.split(',')) > 1:
    model = torch.nn.parallel.DataParallel(model)

used for valid or test sketch / image:
(77,) (1711,)
used for train sketch / image:
(55252,) (68401,)
=> loading model './checkpoint/sketchy_ext/best_checkpoint.pth'


In [6]:
# Inference only: no autograd bookkeeping, and layers switched to eval mode.
torch.set_grad_enabled(False)
model.eval()

# Both loaders share the worker count and keep the last partial batch.
_loader_opts = dict(num_workers=args.num_workers, drop_last=False)

# NOTE(review): the two progress messages look swapped relative to the loader
# that follows each of them (the first loader wraps the sketch set, the second
# the image set). The strings are left as they are so the recorded output
# still matches.
print('loading image data')
sk_dataload = DataLoader(sk_valid_data, batch_size=args.test_sk, **_loader_opts)
print('loading sketch data')
im_dataload = DataLoader(im_valid_data, batch_size=args.test_im, **_loader_opts)

# Filled in by the retrieval loop.
dist_im = all_dist = None

loading image data
loading sketch data


In [None]:
# Score every query sketch against every gallery image, batch by batch, and
# assemble the full (num_sketches, num_images) distance matrix `all_dist`,
# plus the 0/1 matrix `class_same` marking sketch/image pairs with equal labels.
if args.retrieval not in ('rn', 'sa'):
    raise ValueError(f"unsupported retrieval mode: {args.retrieval!r}")
os.makedirs("./output", exist_ok=True)  # save_image does not create the folder

# Collect per-batch pieces in lists and concatenate once at the end instead of
# re-allocating the growing arrays on every iteration.
sk_label_parts, im_label_parts, dist_rows = [], [], []

for i, (sk, sk_label) in enumerate(tqdm(sk_dataload)):
    sk_label_parts.append(sk_label.numpy())

    sk_len = sk.size(0)
    sk = sk.cuda()
    if i == 0:
        # Dump the first sketch batch as a quick visual sanity check.
        grid_sk = torchvision.utils.make_grid(sk)
        torchvision.utils.save_image(grid_sk, "./output/sk.jpg")

    # First stage of the model on the sketches alone (only_sa=True).
    sk, sk_idxs = model(sk, None, 'test', only_sa=True)

    # NOTE(review): the image features below are recomputed for every sketch
    # batch; caching them once would avoid the repeated forward passes.
    dist_cols = []
    for j, (im, im_label) in enumerate(tqdm(im_dataload)):
        if i == 0:
            # The image labels are identical on every outer pass; record once.
            im_label_parts.append(im_label.numpy())

        im_len = im.size(0)
        im = im.cuda()
        im, im_idxs = model(im, None, 'test', only_sa=True)

        # Build all sk_len * im_len (sketch, image) pairs: sketch k is repeated
        # im_len times in a row, the image batch is tiled sk_len times.
        sk_temp = sk.unsqueeze(1).repeat(1, im_len, 1, 1).flatten(0, 1).cuda()
        im_temp = im.unsqueeze(0).repeat(sk_len, 1, 1, 1).flatten(0, 1).cuda()

        if args.retrieval == 'rn':
            # feature_2 holds one value per pair; it is negated to be used as a
            # distance (presumably a relation score where larger = more
            # similar — confirm against the model).
            feature_1, feature_2 = model(sk_temp, im_temp, 'test')
            dist_block = -feature_2.view(sk_len, im_len)
        else:  # 'sa': L2 distance between the normalised first tokens
            feature_1, feature_2 = torch.cat((sk_temp[:, 0], im_temp[:, 0]), dim=0), None
            dist_temp = F.pairwise_distance(F.normalize(feature_1[:sk_len * im_len]),
                                            F.normalize(feature_1[sk_len * im_len:]), 2)
            dist_block = dist_temp.view(sk_len, im_len)

        dist_cols.append(dist_block.detach().cpu().numpy())

    # One row block: this sketch batch against the whole gallery.
    dist_im = np.concatenate(dist_cols, axis=1)
    dist_rows.append(dist_im)

all_sk_label = np.concatenate(sk_label_parts, axis=0)
all_im_label = np.concatenate(im_label_parts, axis=0)
all_dist = np.concatenate(dist_rows, axis=0)
print(all_dist.shape)  # (all_sk_label.size, all_im_label.size)

# class_same[s, m] == 1 iff sketch s and image m carry the same label.
class_same = (np.expand_dims(all_sk_label, axis=1) == np.expand_dims(all_im_label, axis=0)) * 1


In [8]:
# Show the label-agreement matrix, then persist both matrices as plain text.
print(class_same.shape, class_same, sep="\n")
np.savetxt(fname="./output/all_dist", X=all_dist)
np.savetxt(fname="./output/class_same", X=class_same)


(77, 1711)
[[1 1 1 ... 0 0 0]
 [1 1 1 ... 0 0 0]
 [1 1 1 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 1 1]]


In [9]:
# Retrieval metrics from the distance matrix and the label-agreement matrix.
retrieval_metrics = calculate(all_dist, class_same, test=True)
map_all, map_200, precision100, precision200 = retrieval_metrics
print(map_all, map_200, precision100, precision200)

eval time: 0.5573358535766602
0.680195628754997 0.6593741691830883 0.6408593850189362 0.6408593850189362


In [10]:
# For each sketch (row), the image indices ordered from smallest to largest
# distance, i.e. best match first.
arg_sort_sim = np.argsort(all_dist)
print(arg_sort_sim.shape)
print(arg_sort_sim)

(77, 1711)
[[ 286  254  302 ...  173 1272  179]
 [ 382  302  271 ...  342  349  347]
 [ 253  302  255 ...  349  228  179]
 ...
 [ 160 1244 1213 ...  582 1366  604]
 [1411 1455 1421 ...  337 1213  567]
 [ 266  971  247 ... 1304  928  819]]


In [16]:
def patch_replace_data(sk_index, im_index, sk_data=None, im_data=None):
    '''
    Fetch one sketch and a batch of images by dataset index.

    sk_index: int, index of the sketch to fetch.
    im_index: iterable of int, indices of the images to fetch (kept in order).
    sk_data:  indexable dataset whose items are (tensor, label) pairs;
              defaults to the notebook-level `sk_valid_data`.
    im_data:  same, for images; defaults to the notebook-level `im_valid_data`.

    return: (sk, im) where sk has a leading batch dimension of 1 and im has a
            leading batch dimension of len(im_index).
    raises: ValueError if im_index is empty (there is nothing to stack).
    '''
    if sk_data is None:
        sk_data = sk_valid_data
    if im_data is None:
        im_data = im_valid_data

    im_index = list(im_index)
    if not im_index:
        raise ValueError("im_index must contain at least one image index")

    sk = sk_data[sk_index][0].unsqueeze(0)
    # Stack once instead of growing the batch with repeated torch.cat calls.
    im = torch.stack([im_data[v][0] for v in im_index])

    return sk, im
    

In [17]:
# Query sketch 0 and its three closest images (the first three entries of
# row 0 in the recorded arg_sort_sim output).
sk_index = 0
im_index = [286, 254, 302]
sk_tmp, im_tmp = patch_replace_data(sk_index, im_index)
print(sk_tmp.shape, im_tmp.shape)

# Save the sketch on its own ...
torchvision.utils.save_image(tensor=sk_tmp.cuda(), fp=f"./output/sk-{sk_index}.jpg")

# ... and the retrieved images tiled into a single picture. From here on
# im_tmp is the tiled grid image, not the batch.
im_tmp = torchvision.utils.make_grid(im_tmp)
torchvision.utils.save_image(tensor=im_tmp.cuda(), fp=f"./output/im_top_{len(im_index)}.jpg")
print(sk_tmp.shape, im_tmp.shape)

torch.Size([1, 3, 224, 224]) torch.Size([3, 3, 224, 224])
torch.Size([1, 3, 224, 224]) torch.Size([3, 228, 680])


In [19]:
from model import rn

# Reload the query sketch and its retrieved images as raw batches.
sk, im = patch_replace_data(sk_index, im_index)
print(sk.shape, im.shape)

# First stage of the model, run on sketches and images separately.
sk_sa, sk_idxs = model(sk.cuda(), None, 'test', only_sa=True)
im_sa, im_idxs = model(im.cuda(), None, 'test', only_sa=True)

# Sketch rows first, image rows after, then one pass through model.ca.
sk_im_sa = torch.cat((sk_sa, im_sa), dim=0)
ca_fea = model.ca(sk_im_sa)

# Split off the first token from the remaining per-patch tokens.
cls_fea, token_fea = ca_fea[:, 0], ca_fea[:, 1:]
batch = token_fea.shape[0]  # number of sketches + images

# The recorded run printed torch.Size([4, 196, 768]) here.
print(token_fea.shape)


torch.Size([1, 3, 224, 224]) torch.Size([3, 3, 224, 224])
torch.Size([4, 196, 768])


In [21]:
# Separate the patch tokens of the sketch from those of the images.
# NOTE(review): only the first sketch is used, while the image slice starts at
# sk.size(0); the two agree only when exactly one sketch was loaded.
sk_fea, im_fea = token_fea[0], token_fea[sk.shape[0]:]
print(sk_fea.shape, im_fea.shape)

# Patch-to-patch cosine similarity between the sketch and every image.
cos_scores = rn.cos_similar(sk_fea, im_fea)
print(cos_scores.shape)

# Persist the similarity matrix of the first image only.
np.savetxt("./output/cos_scores", cos_scores[0].cpu())

torch.Size([196, 768]) torch.Size([2, 196, 768])
torch.Size([2, 196, 196])


In [22]:
# Debug: number of sketches in the `sk` batch (first dimension).
print(sk.size(0))

1


In [None]:
# For every row of cos_scores' middle axis, find the (axis-0 index, axis-2
# index) of its highest score. Downstream cells use column 0 to pick an image
# and column 1 as a patch index inside that image, i.e. cos_scores is treated
# as (image, sketch patch, image patch) — TODO confirm against rn.cos_similar.
num_images, num_im_patches = cos_scores.shape[0], cos_scores.shape[2]

# Bring the middle axis to the front and flatten the other two, so one argmax
# per row searches across all images and all of their patches at once.
b = einops.rearrange(cos_scores, "a b c -> b (a c)")
flat_best = torch.argmax(b, dim=1).cpu().numpy()

# Undo the flattening: flat index -> (image index, image-patch index).
best_im, best_patch = np.unravel_index(flat_best, (num_images, num_im_patches))
max_indices = torch.from_numpy(np.stack((best_im, best_patch), axis=1))

np.savetxt("./output/max_indices", max_indices.numpy())


In [None]:
# Cut the best-matching image patch out of the raw images for every row of
# max_indices. Requires `patch2im`, which is defined in a cell further down —
# run that cell first.
indices = max_indices

# The patch tokens form a square grid whose side is sqrt(#tokens); with the
# 196 tokens and 224-pixel images shown above that is a 14 x 14 grid of
# 16-pixel patches. (The grid side and the patch side were previously swapped.)
grid_size = int(round(cos_scores.shape[-1] ** 0.5))
patch_size = im.shape[-1] // grid_size

patches = []
for im_idx, token_idx in indices.tolist():
    # Row-major token order: token -> (row, column) on the patch grid.
    row, col = divmod(token_idx, grid_size)
    patches.append(patch2im(torch.tensor((row, col), dtype=int), im[im_idx], patch_size))

# (num_matches, channels, patch_size, patch_size)
x = torch.stack(patches) if patches else im.new_zeros((0, im.shape[1], patch_size, patch_size))

In [None]:
# `x` holds one patch per sketch token in row-major order, so tiling them on
# a square grid with sqrt(#patches) patches per row puts every patch back at
# its token's position (14 per row for 196 patches; the previous fixed value
# of 16 shifted every row).
patches_per_row = max(1, int(round(x.shape[0] ** 0.5)))
x = torchvision.utils.make_grid(x, nrow=patches_per_row)
torchvision.utils.save_image(x, "./output/patch_replace.jpg")

In [None]:
def patch_match(im, indices, patch_size=16):
    '''
    Cut one patch out of `im` for every (image index, token index) pair.

    im:         (b, c, h, w) batch of images.
    indices:    iterable of (image index, patch-token index) pairs; the token
                index is interpreted row-major on the square patch grid.
    patch_size: side of one patch in pixels. Defaults to 16, which matches the
                196 tokens per 224-pixel image seen in this notebook —
                TODO confirm against the model's patch embedding.

    return: (m, c, patch_size, patch_size) tensor, one patch per pair, in the
            order of `indices`.
    '''
    patch_size = int(patch_size)
    grid_size = im.shape[-1] // patch_size

    patches = []
    for im_idx, token_idx in indices:
        row, col = divmod(int(token_idx), grid_size)
        patches.append(patch2im(torch.tensor((row, col), dtype=int), im[int(im_idx)], patch_size))

    if not patches:
        return im.new_zeros((0, im.shape[1], patch_size, patch_size))
    # Stack adds the batch dimension; the patches themselves are (c, p, p).
    return torch.stack(patches)


In [None]:
def patch2im(patch_index,im, patch_size):
    '''
    Cut a single square patch out of an image.

    im: (c, d1, d2) image tensor.
    patch_index: two integers (tensor, tuple or list): the patch position
                 along d1 and along d2, counted in patches.
    patch_size: side of the patch in pixels; converted with int(), so an
                integral float such as 16.0 is accepted.
    return: (c, patch_size, patch_size) view into `im`.
    '''
    # Plain Python ints keep the slice arithmetic independent of whether the
    # caller passed tensors, numpy integers or floats.
    first, second = int(patch_index[0]), int(patch_index[1])
    size = int(patch_size)

    return im[:,
              first * size:(first + 1) * size,
              second * size:(second + 1) * size]


In [None]:
# Debug: shape and full contents of `x`. Which object this is depends on which
# of the preceding cells was run last (patch batch vs. tiled grid image).
print(x.shape,x)

In [None]:
  # [2b, n, n] n=patches**2
# NOTE(review): scratch cell. It overwrites cos_scores with a 2-D view of
# batch // 2 rows, so re-running earlier cells that expect the 3-D tensor
# afterwards will see a different shape.
cos_scores = cos_scores.view(batch // 2, -1)

# NOTE(review): called with a single argument here, whereas the earlier call
# passes two (sketch tokens, image tokens) — this may raise a TypeError;
# confirm against rn.cos_similar's signature.
rn.cos_similar(im)

In [None]:
# A DataLoader is an iterable of batches and cannot be indexed; look the
# sample up in the underlying dataset instead. Print the shape and label
# rather than dumping the whole image tensor.
sk_sample, sk_sample_label = sk_valid_data[30]
print(sk_sample.shape, sk_sample_label)

In [None]:
# Final report. The four metrics are the values returned by `calculate` in the
# metrics cell above; the commented call below is an alternative entry point
# that is not used in this notebook.
# map_all, map_200, precision_100, precision_200 = valid_cls(args, model, sk_valid_data, im_valid_data)
print(f'map_all:{map_all:.4f} map_200:{map_200:.4f} precision_100:{precision100:.4f} precision_200:{precision200:.4f}')