In [None]:
import os
import torch
from options import Option
from data_utils.dataset import load_data_test
from model.model import Model
from utils.util import setup_seed, load_checkpoint
import torchvision
import einops

In [None]:
args = Option().parse()
args.load = "./checkpoints/sketchy_ext/best_checkpoint.pth"
args.batch = 2
args.valid_shrink_sk=200
args.valid_shrink_im=100


print("test args:", str(args))

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from utils.ap import calculate
from tqdm import tqdm
import time

In [None]:
os.environ["CUDA_VISIBLE_DEVICES"] = args.choose_cuda
print("current cuda: " + args.choose_cuda)
setup_seed(args.seed)

In [None]:
# prepare data
sk_valid_data, im_valid_data = load_data_test(args)

# prepare model
model = Model(args)
model = model.half()

if args.load is not None:
    checkpoint = load_checkpoint(args.load)

cur = model.state_dict()
new = {k: v for k, v in checkpoint['model'].items() if k in cur.keys()}
cur.update(new)
model.load_state_dict(cur)

if len(args.choose_cuda) > 1:
    model = torch.nn.parallel.DataParallel(model.to('cuda'))
model = model.cuda()

In [None]:
model.eval()
torch.set_grad_enabled(False)

print('loading image data')
sk_dataload = DataLoader(sk_valid_data, batch_size=args.test_sk, num_workers=args.num_workers, drop_last=False)
print('loading sketch data')
im_dataload = DataLoader(im_valid_data, batch_size=args.test_im, num_workers=args.num_workers, drop_last=False)

dist_im = None
all_dist = None

In [None]:
for i, (sk, sk_label) in enumerate(tqdm(sk_dataload)):
        #sk.shape=(20,3,224,224)
        print(i)
        if i == 0:
            all_sk_label = sk_label.numpy()
        else:
            all_sk_label = np.concatenate((all_sk_label, sk_label.numpy()), axis=0)

        sk_len = sk.size(0)
        sk = sk.cuda()
        #debug
        print(sk[0].shape)
        # cv2.imwrite(f"./logs/sk-{i}",sk[0].cpu().numpy())
        if i==0:
            grid_sk = torchvision.utils.make_grid(sk)
            torchvision.utils.save_image(grid_sk,f"./logs/sk.jpg")
        
        sk, sk_idxs = model(sk, None, 'test', only_sa=True)#sk.shape=(20,192,768)
        for j, (im, im_label) in enumerate(tqdm(im_dataload)):
            if i == 0 and j == 0:
                all_im_label = im_label.numpy()
            elif i == 0 and j > 0:
                all_im_label = np.concatenate((all_im_label, im_label.numpy()), axis=0)

            im_len = im.size(0)
            im = im.cuda()
            im, im_idxs = model(im, None, 'test', only_sa=True)

            sk_temp = sk.unsqueeze(1).repeat(1, im_len, 1, 1).flatten(0, 1).cuda() #(400,197,768) #?difference
            im_temp = im.unsqueeze(0).repeat(sk_len, 1, 1, 1).flatten(0, 1).cuda() #(400,197,768)
            
            if args.retrieval == 'rn':
                feature_1, feature_2 = model(sk_temp, im_temp, 'test')
            #? when retrieval == 'sa'
            if args.retrieval == 'sa':
                feature_1, feature_2 = torch.cat((sk_temp[:, 0], im_temp[:, 0]), dim=0), None

            # print(feature_1.size())    # [2*sk*im, 768] #2 means sk and im cls
            # print(feature_2.size())    # [sk*im, 1]

            if args.retrieval == 'rn':
                if j == 0:
                    dist_im = - feature_2.view(sk_len, im_len).cpu().data.numpy()  # 1*args.batch
                else:
                    dist_im = np.concatenate((dist_im, - feature_2.view(sk_len, im_len).cpu().data.numpy()), axis=1)
            if args.retrieval == 'sa':
                dist_temp = F.pairwise_distance(F.normalize(feature_1[:sk_len * im_len]),
                                                F.normalize(feature_1[sk_len * im_len:]), 2)
                if j == 0:
                    dist_im = dist_temp.view(sk_len, im_len).cpu().data.numpy()
                else:
                    dist_im = np.concatenate((dist_im, dist_temp.view(sk_len, im_len).cpu().data.numpy()), axis=1)

        if i == 0:
            all_dist = dist_im
        else:
            all_dist = np.concatenate((all_dist, dist_im), axis=0)
        print(all_dist.shape)
        #all_dist.shape=(all_sk_label.size, all_im_label.size)
    # print(all_sk_label.size, all_im_label.size)     # [762 x 1711] / 2
class_same = (np.expand_dims(all_sk_label, axis=1) == np.expand_dims(all_im_label, axis=0)) * 1
# print(all_dist.size, class_same.size)     # [762 x 1711] / 2


In [None]:
print(class_same.shape)
print(class_same)
np.savetxt("./logs/all_dist",all_dist)
np.savetxt("./logs/class_same",class_same)


In [None]:
map_all, map_200, precision100, precision200 = calculate(all_dist, class_same, test=True)
print(map_all,map_200,precision100,precision200)

In [None]:
arg_sort_sim = all_dist.argsort()   # 得到从小到大索引值
print(arg_sort_sim.shape)
print(arg_sort_sim)
np.savetxt("./logs/arg_sort_sim",torch.tensor(arg_sort_sim,dtype=int))

In [None]:
def patch2im(patch_index,im, patch_size):
    '''
    im: (c, w, h)
    patch_index: (2)
    return: (c, patch_size, patch_size)
    '''
    # print(patch_index.shape, im.shape, patch_size)
    # print(patch_index)
    # print(patch_index[0].item()*patch_size)
        
    return im[:, \
        patch_index[0]*patch_size:(patch_index[0].item()+1)*patch_size, \
        patch_index[1].item()*patch_size:(patch_index[1].item()+1)*patch_size]


In [None]:
def patch_match(im, indices,patch_size):
    '''
        im: (b,c,w,h)
        indices: (m,im.shape.len)
    '''
    # print(im.shape)
    x = torch.zeros((0,)+tuple(im.shape[1:]))
    # print(x)
    for i in indices:
        patch_index = np.unravel_index(i[1],(im.size(-1)/patch_size,im.size(-1)/patch_size))
        item = patch2im(torch.tensor(patch_index,dtype=int), im[i[0]], patch_size)
        x= torch.cat([x, item])
    return x 


In [None]:
def patch_replace_data(im_index,im):
    '''
    im_index: [b_i, n_i]
    im: [b, n, ......]
    '''
    for i,v in enumerate(im_index):
        if i == 0:    
            print(v)
            im_rtn = im[v[0]][v[1]].unsqueeze(0)
            print(im_rtn.shape)
        else:    
            im_rtn = torch.cat((im_rtn,im[v[0]][v[1]].unsqueeze(0)))
    
    return im_rtn

In [None]:
sk_index= 1
im_index = arg_sort_sim[sk_index,:3]
# (sk_tmp, im_tmp) = patch_replace_data(max_indices, im_valid_data[im_index[0],im_index[1],im_index[2],])
print(sk_index, im_index)
(sk,_) = sk_valid_data[sk_index]
sk = sk.unsqueeze(0)

tmp = [im_valid_data[i] for i in im_index]
im = [i[0].unsqueeze(0) for i in tmp]
im = torch.concatenate(im)
print(sk.shape, im.shape)

torchvision.utils.save_image(sk.cuda(),f"./logs/sk-{sk_index}.jpg")

im_tmp = torchvision.utils.make_grid(im)
torchvision.utils.save_image(im_tmp.cuda(),f"./logs/im_top_{len(im_index)}.jpg")
print(sk.shape, im_tmp.shape)

In [None]:
from model import rn


print(sk.shape, im.shape)


sk_sa, sk_idxs = model(sk.cuda(), None, 'test', only_sa=True)#sk_sa.shape=(20,192,768)
im_sa, im_idxs = model(im.cuda(), None, 'test', only_sa=True)#im_sa.shape=(20,192,768)


sk_im_sa = torch.cat((sk_sa, im_sa), dim=0)
print(sk_im_sa.shape)
ca_fea = model.ca(sk_im_sa)  # [2b, 197, 768]
cls_fea = ca_fea[:, 0]  # [2b, 1, 768]
token_fea = ca_fea[:, 1:]  # [2b, 196, 768]
print(token_fea.shape)

token_fea_tmp = einops.rearrange(token_fea, "b (h w) c -> b c h w", h=14)
print(token_fea_tmp.shape)
up_fea = model.output4VQGAN(token_fea_tmp)
print(up_fea.shape)
up_fea = einops.rearrange(up_fea, "b c h w -> b (h w) c")
print(up_fea.shape)

batch = token_fea.size(0)

In [None]:
# token_fea = einops.rearrange(token_fea,"b d h w -> b d (h w)") #token_fea = token_fea.view(batch, 768, 14, 14)
sk_fea = up_fea[0]
im_fea = up_fea[sk.size(0):]
# np.savetxt("./logs/sk_fea", sk_fea.cpu())
# np.savetxt("./logs/im_fea", im_fea.cpu())
print(sk_fea.shape, im_fea.shape)
cos_scores = rn.cos_similar(sk_fea, im_fea)
print(cos_scores.shape)
np.savetxt("./logs/cos_scores",cos_scores.cpu()[0])

In [None]:
# print(cos_scores.argsort(0).shape,cos_scores.argsort(0))
# print(torch.argmax(einops.rearrange(cos_scores,"a b c -> b (a c)")))
b = einops.rearrange(cos_scores,"a b c -> b (a c)")
# print(cos_scores.shape,cos_scores)

max_indices = torch.empty((0,2), dtype=int)
print(b)
print(max_indices)

for i in b:
    max_indices_item = torch.argmax(i)
    # print(i.shape)
    new = np.unravel_index(max_indices_item.cpu(),(cos_scores.shape[0],cos_scores.shape[2]))
    # print(torch.Tensor(new))
    max_indices = torch.cat((max_indices, torch.tensor(new, dtype=torch.int).unsqueeze(0)), 0)
    # print(max_indices)
    
# print(np.unravel_index(b.values, (3, 196)))
np.savetxt("./logs/max_indices",max_indices)


In [None]:
#patch replace op test
# indices = max_indices

# print(im.shape)
# # x = torch.zeros((0,)+tuple(im.shape[1:]))
# x = torch.zeros((0, 3, 14,14))
# # print(x)
# for i,v in enumerate(indices):
#     patch_index = np.unravel_index(i,(16,16))
#     item = patch2im(torch.tensor(patch_index,dtype=int), im[0], im.shape[-1]//16)
#     # print(item.shape)
#     x= torch.cat([x, item.unsqueeze(0)])

In [None]:
print(max_indices.shape,up_fea.shape)
im_replaced = patch_replace_data(max_indices, im_fea)
print(im_replaced.shape)

In [None]:
# indices = max_indices

# print(im.shape)
# # x = torch.zeros((0,)+tuple(im.shape[1:]))
# x = torch.zeros((0, 3, 16,16))
# # print(x)
# for i in indices:
#     patch_index = np.unravel_index(i[1],(14,14))
#     item = patch2im(torch.tensor(patch_index,dtype=int), im[i[0]], int(im.shape[-1]/14))
#     # print(item.shape)
#     x= torch.cat([x, item.unsqueeze(0)])

In [None]:
# x = torchvision.utils.make_grid(x,nrow=14)
# torchvision.utils.save_image(x,"./logs/patch_replace.jpg")

In [None]:
# valid
# map_all, map_200, precision_100, precision_200 = valid_cls(args, model, sk_valid_data, im_valid_data)
print(f'map_all:{map_all:.4f} map_200:{map_200:.4f} precision_100:{precision100:.4f} precision_200:{precision200:.4f}')

In [None]:
import torch
vqgan_dict = torch.load("../download/last.ckpt")

In [None]:
!proxychains git clone https://github.com/CompVis/taming-transformers
%cd taming-transformers

In [None]:
!mkdir -p logs/vqgan_imagenet_f16_1024/checkpoints
!mkdir -p logs/vqgan_imagenet_f16_1024/configs
# !wget 'https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1' -O 'logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt' 
!proxychains wget 'https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1' -O 'taming-transformers/logs/vqgan_imagenet_f16_1024/configs/model.yaml' 

In [None]:
%pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 einops>=0.3.0

In [10]:
import sys
sys.path.append("taming-transformers/")


In [11]:

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda:0


In [12]:
import os
os.getcwd()

'/mnt/g/github/ZSE-SBIR'

In [13]:
import yaml
import torch
from omegaconf import OmegaConf
from taming.models.vqgan import VQModel, GumbelVQ

def load_config(config_path, display=False):
  config = OmegaConf.load(config_path)
  if display:
    print(yaml.dump(OmegaConf.to_container(config)))
  return config

def load_vqgan(config, ckpt_path=None, is_gumbel=False):
  if is_gumbel:
    model = GumbelVQ(**config.model.params)
  else:
    model = VQModel(**config.model.params)
  if ckpt_path is not None:
    sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    missing, unexpected = model.load_state_dict(sd, strict=False)
  return model.eval()

def preprocess_vqgan(x):
  x = 2.*x - 1.
  return x

def custom_to_pil(x):
  x = x.detach().cpu()
  x = torch.clamp(x, -1., 1.)
  x = (x + 1.)/2.
  x = x.permute(1,2,0).numpy()
  x = (255*x).astype(np.uint8)
  x = Image.fromarray(x)
  if not x.mode == "RGB":
    x = x.convert("RGB")
  return x

def reconstruct_with_vqgan(x, model):
  # could also use model(x) for reconstruction but use explicit encoding and decoding here
  z, _, [_, _, indices] = model.encode(x)
  print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
  xrec = model.decode(z)
  return xrec

In [14]:
config1024 = load_config("taming-transformers/logs/vqgan_imagenet_f16_1024/configs/model.yaml", display=False)
model1024 = load_vqgan(config1024, ckpt_path="taming-transformers/logs/vqgan_imagenet_f16_1024/checkpoints/last.ckpt").to(DEVICE)

print(model1024)
print(model1024, file=open("logs/model1024_info","w"))

Working with z of shape (1, 256, 16, 16) = 65536 dimensions.




loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.
VQModel(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0-1): 2 x Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (2): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps

In [15]:
# xrec  = reconstruct_with_vqgan(x, model1024)
h = einops.rearrange(im_replaced.unsqueeze(0),"b (h w) c -> b c h w",h=32) #(1,256,1024)
h = h.to(torch.float32)

# h = h.flatten()
print(h, file=open("./logs/h","w"))
h = einops.rearrange(h,"(b c h w) -> b c h w",b=1,c=256,h=32) #(1,256,1024)

print(h.shape)
print(h.dtype)


NameError: name 'einops' is not defined

In [None]:
h = einops.rearrange(h, 'b c h w -> b h w c').contiguous()

In [33]:
import torch
h = torch.arange(0.,1024*256.).reshape(1,32,32,256).cuda()
print(h)

tensor([[[[0.0000e+00, 1.0000e+00, 2.0000e+00,  ..., 2.5300e+02,
           2.5400e+02, 2.5500e+02],
          [2.5600e+02, 2.5700e+02, 2.5800e+02,  ..., 5.0900e+02,
           5.1000e+02, 5.1100e+02],
          [5.1200e+02, 5.1300e+02, 5.1400e+02,  ..., 7.6500e+02,
           7.6600e+02, 7.6700e+02],
          ...,
          [7.4240e+03, 7.4250e+03, 7.4260e+03,  ..., 7.6770e+03,
           7.6780e+03, 7.6790e+03],
          [7.6800e+03, 7.6810e+03, 7.6820e+03,  ..., 7.9330e+03,
           7.9340e+03, 7.9350e+03],
          [7.9360e+03, 7.9370e+03, 7.9380e+03,  ..., 8.1890e+03,
           8.1900e+03, 8.1910e+03]],

         [[8.1920e+03, 8.1930e+03, 8.1940e+03,  ..., 8.4450e+03,
           8.4460e+03, 8.4470e+03],
          [8.4480e+03, 8.4490e+03, 8.4500e+03,  ..., 8.7010e+03,
           8.7020e+03, 8.7030e+03],
          [8.7040e+03, 8.7050e+03, 8.7060e+03,  ..., 8.9570e+03,
           8.9580e+03, 8.9590e+03],
          ...,
          [1.5616e+04, 1.5617e+04, 1.5618e+04,  ..., 1.5869

In [34]:
h = model1024.quantize.forward(h) #don't use same name
print(model1024.quantize)


tensor([[[[0.0000e+00, 1.0000e+00, 2.0000e+00,  ..., 2.5300e+02,
           2.5400e+02, 2.5500e+02],
          [2.5600e+02, 2.5700e+02, 2.5800e+02,  ..., 5.0900e+02,
           5.1000e+02, 5.1100e+02],
          [5.1200e+02, 5.1300e+02, 5.1400e+02,  ..., 7.6500e+02,
           7.6600e+02, 7.6700e+02],
          ...,
          [7.4240e+03, 7.4250e+03, 7.4260e+03,  ..., 7.6770e+03,
           7.6780e+03, 7.6790e+03],
          [7.6800e+03, 7.6810e+03, 7.6820e+03,  ..., 7.9330e+03,
           7.9340e+03, 7.9350e+03],
          [7.9360e+03, 7.9370e+03, 7.9380e+03,  ..., 8.1890e+03,
           8.1900e+03, 8.1910e+03]],

         [[8.1920e+03, 8.1930e+03, 8.1940e+03,  ..., 8.4450e+03,
           8.4460e+03, 8.4470e+03],
          [8.4480e+03, 8.4490e+03, 8.4500e+03,  ..., 8.7010e+03,
           8.7020e+03, 8.7030e+03],
          [8.7040e+03, 8.7050e+03, 8.7060e+03,  ..., 8.9570e+03,
           8.9580e+03, 8.9590e+03],
          ...,
          [1.5616e+04, 1.5617e+04, 1.5618e+04,  ..., 1.5869

In [35]:
print(a)

(tensor([[[[ 0.2914,  0.9260, -1.3055,  ...,  0.3967, -0.5718,  1.0182],
          [ 0.2914,  0.9260, -1.3055,  ...,  0.3967, -0.5717,  1.0182],
          [ 0.2914,  0.9260, -1.3055,  ..., -0.2224, -1.0343,  1.0951],
          ...,
          [ 0.2915,  0.9258, -1.3057,  ...,  0.3965, -0.5718,  1.0181],
          [ 0.2915,  0.9258, -1.3057,  ...,  0.3965, -0.5718,  1.0181],
          [ 0.2915,  0.9258, -1.3057,  ..., -0.2227, -1.0342,  1.0952]],

         [[ 1.0356,  1.2827, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  1.6152,  0.8984,  1.5244],
          ...,
          [ 1.0361,  1.2832, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  1.6152,  0.8984,  1.5244]],

         [[-1.1230,  0.1719, -0.8711,  ...,  1.7871,  1.0898, -0.1523],
          [-1.1230,  0.1719, 

In [32]:
quant, emb_loss, info = model1024.quantize(h)

(tensor([[[[ 0.2914,  0.9260, -1.3055,  ...,  0.3967, -0.5718,  1.0182],
          [ 0.2914,  0.9260, -1.3055,  ...,  0.3967, -0.5717,  1.0182],
          [ 0.2914,  0.9260, -1.3055,  ..., -0.2224, -1.0343,  1.0951],
          ...,
          [ 0.2915,  0.9258, -1.3057,  ...,  0.3965, -0.5718,  1.0181],
          [ 0.2915,  0.9258, -1.3057,  ...,  0.3965, -0.5718,  1.0181],
          [ 0.2915,  0.9258, -1.3057,  ..., -0.2227, -1.0342,  1.0952]],

         [[ 1.0356,  1.2827, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  1.6152,  0.8984,  1.5244],
          ...,
          [ 1.0361,  1.2832, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  0.9814,  0.5361,  2.5293],
          [ 1.0361,  1.2832, -0.2197,  ...,  1.6152,  0.8984,  1.5244]],

         [[-1.1230,  0.1719, -0.8711,  ...,  1.7871,  1.0898, -0.1523],
          [-1.1230,  0.1719, 

RuntimeError: Tensor type unknown to einops <class 'tuple'>