In [1]:
import torch
from torch import nn
from data.datamgr import SimpleDataManager , SetDataManager
from models.predesigned_modules import resnet12
import sys
import os
from utils import *
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"

import time
import numpy as np
import warnings
warnings.filterwarnings('ignore')
# fix seed
np.random.seed(1)
torch.manual_seed(1)
import tqdm
from torch.nn.parallel import DataParallel
# torch.backends.cudnn.benchmark = True
from models.models_mae import mae_vit_base_patch16
# from sklearn import svm     #导入算法模块

#--------------参数设置--------------------
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--image_size', default=224, type=int, choices=[84, 224], help='input image size, 84 for miniImagenet and tieredImagenet, 224 for cub')
parser.add_argument('--dataset', default='mini_imagenet', choices=['mini_imagenet','tiered_imagenet','cub'])
parser.add_argument('--data_path', default='/home/jiangweihao/data/mini-imagenet/',type=str, help='dataset path')

parser.add_argument('--train_n_episode', default=300, type=int, help='number of episodes in meta train')
parser.add_argument('--val_n_episode', default=300, type=int, help='number of episodes in meta val')
parser.add_argument('--train_n_way', default=5, type=int, help='number of classes used for meta train')
parser.add_argument('--val_n_way', default=5, type=int, help='number of classes used for meta val')
parser.add_argument('--n_shot', default=1, type=int, help='number of labeled data in each class, same as n_support')
parser.add_argument('--n_query', default=15, type=int, help='number of unlabeled data in each class')
parser.add_argument('--num_classes', default=64, type=int, help='total number of classes in pretrain')

parser.add_argument('--batch_size', default=128, type=int, help='total number of batch_size in pretrain')
# parser.add_argument('--freq', default=10, type=int, help='total number of inner frequency')

# Float hyper-parameters must be parsed as float: with type=int a command-line
# value such as `--momentum 0.9` is rejected by int('0.9').
parser.add_argument('--momentum', default=0.9, type=float, help='parameter of optimization')
parser.add_argument('--weight_decay', default=5.e-4, type=float, help='parameter of optimization')

parser.add_argument('--gpu', default='3')
# type=int so that `--epochs 1` yields the integer 1 instead of the string '1'.
parser.add_argument('--epochs', default=100, type=int)

# Explicit args list: inside Jupyter sys.argv holds kernel arguments.
params = parser.parse_args(args=['--gpu', '4',  '--epochs','1'])

# -------------设置GPU--------------------
set_gpu(params.gpu)
# -------------导入数据--------------------

# Pick split names / label count for the chosen dataset.
json_file_read = False
if params.dataset == 'mini_imagenet':
    base_file = 'train'
    val_file = 'val'
    params.num_classes = 64
elif params.dataset == 'cub':
    # CUB splits are described by JSON files.
    base_file = 'base.json'
    val_file = 'val.json'
    json_file_read = True
    params.num_classes = 200
elif params.dataset == 'tiered_imagenet':
    base_file = 'train'
    val_file = 'val'
    params.num_classes = 351
else:
    # The exception must be raised; merely constructing it let an unknown
    # dataset fall through and fail later with a NameError on base_file.
    raise ValueError('dataset error')

# -----------  base data ----------------------
base_datamgr = SimpleDataManager(params.data_path, params.image_size, batch_size=params.batch_size, json_read=json_file_read)
base_loader = base_datamgr.get_data_loader(base_file, aug=True)

#-----------  train data ----------------------
train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
train_datamgr = SetDataManager(params.data_path, params.image_size, n_query=params.n_query, n_episode=params.train_n_episode, json_read=json_file_read, **train_few_shot_params)
train_loader = train_datamgr.get_data_loader(base_file, aug=True)

#------------ val data ------------------------
test_few_shot_params = dict(n_way=params.val_n_way, n_support=params.n_shot)
val_datamgr = SetDataManager(params.data_path, params.image_size, n_query=params.n_query, n_episode=params.val_n_episode, json_read=json_file_read, **test_few_shot_params)
val_loader = val_datamgr.get_data_loader(val_file, aug=False)

  from .autonotebook import tqdm as notebook_tqdm


set gpu: 4


In [2]:

# ----------- load model -------------------------
model = mae_vit_base_patch16()
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load('/home/jiangweihao/code/MAE_fsl/mae_pretrain_vit_base.pth')
# The weights are stored under the 'model' key of the checkpoint.
state_dict = state_dict['model']
# strict=False: missing/unexpected keys are ignored without any message.
# NOTE(review): presumably the checkpoint only covers part of the model —
# inspect the missing/unexpected keys returned by load_state_dict to confirm.
model.load_state_dict(state_dict,strict=False)
model.cuda()

# from torchinfo import summary
# summary(model,[5,3,224,224])

# del model.fc                         # remove the final fully-connected layer
model.eval()

def cache_model(support,query,model,mask_ratio=(0, 0.25, 0.5, 0.75),modal='mean'):
    """Encode support/query images under several mask ratios and L2-normalise.

    For every ratio in ``mask_ratio`` both image batches are passed through
    ``model.forward_encoder``. Token 0 of the encoder output is treated as the
    cls token and the remaining tokens as patch features.

    Args:
        support: support images, passed as-is to ``model.forward_encoder``.
        query: query images, passed as-is to ``model.forward_encoder``.
        model: object with ``forward_encoder(x, mask_ratio=...)`` returning a
            3-tuple whose first element is a (N, tokens, D) tensor.
        mask_ratio: iterable of mask ratios, one encoder pass per ratio.
        modal: ``'mean'`` averages the patch tokens of each pass and then
            averages over the passes, giving (N, D) features. Any other value
            concatenates all patch tokens of all passes along dim 1.

    Returns:
        ``(support_f, support_cls_token, query_f, query_cls_token)``. The cls
        tokens of all passes are concatenated along dim 1 (shape (N, R*D) for
        R ratios) and normalised as one vector. Every output has unit L2 norm
        along its last dimension.

    Raises:
        ValueError: if ``mask_ratio`` is empty.
    """
    mask_ratio = list(mask_ratio)
    if not mask_ratio:
        raise ValueError('mask_ratio must contain at least one value')

    support_feats, query_feats = [], []
    support_cls, query_cls = [], []
    with torch.no_grad():
        # Data augmentation for the cache model: one encoder pass per ratio.
        for ratio in mask_ratio:
            support_out, _, _ = model.forward_encoder(support, mask_ratio=ratio)
            query_out, _, _ = model.forward_encoder(query, mask_ratio=ratio)
            # Token 0 is the cls token; the remaining tokens are patches.
            support_cls.append(support_out[:, 0, :])
            query_cls.append(query_out[:, 0, :])
            if modal == 'mean':
                support_feats.append(support_out[:, 1:, :].mean(dim=1, keepdim=True))
                query_feats.append(query_out[:, 1:, :].mean(dim=1, keepdim=True))
            else:
                support_feats.append(support_out[:, 1:, :])
                query_feats.append(query_out[:, 1:, :])

    support_f = torch.cat(support_feats, dim=1)
    query_f = torch.cat(query_feats, dim=1)
    support_cls_token = torch.cat(support_cls, dim=1)
    query_cls_token = torch.cat(query_cls, dim=1)

    if modal == 'mean':
        # (N, R, D) -> (N, D): average over the mask-ratio passes.
        support_f = support_f.mean(dim=1)
        query_f = query_f.mean(dim=1)

    # L2 normalisation along the feature dimension.
    support_f = support_f / support_f.norm(dim=-1, keepdim=True)
    support_cls_token = support_cls_token / support_cls_token.norm(dim=-1, keepdim=True)
    query_f = query_f / query_f.norm(dim=-1, keepdim=True)
    query_cls_token = query_cls_token / query_cls_token.norm(dim=-1, keepdim=True)

    return support_f, support_cls_token, query_f, query_cls_token

def catch_feature(query, model, mask_ratio=0):
    """Run one encoder pass and split its output into cls and patch tokens.

    Returns ``(cls_token, patch_tokens)``: token 0 and tokens 1..end of the
    first element returned by ``model.forward_encoder``. Runs without
    gradient tracking.
    """
    with torch.no_grad():
        encoded, _unused_mask, _unused_ids = model.forward_encoder(query, mask_ratio=mask_ratio)
    cls_token = encoded[:, 0, :]
    patch_tokens = encoded[:, 1:, :]
    return cls_token, patch_tokens

# ---------------------------------------------
# Cross-entropy loss. NOTE(review): not referenced by any later visible cell.
loss_fn = torch.nn.CrossEntropyLoss()

# NOTE(review): hard-coded; params.epochs is not used here.
epochs = 100

start = time.time()


# `log` and `Timer` presumably come from `from utils import *` — TODO confirm.
log('==========start testing on train set===============')

# for epoch in range(epochs):   

# Accumulators for loss / accuracy statistics.
out_avg_loss = []
timer = Timer()

avg_loss = 0
total_correct = 0
val_acc = []




In [3]:
# for idy, (temp2,target) in enumerate(train_loader):   
temp2, target =next(iter(train_loader))

In [4]:
temp2.shape

torch.Size([5, 16, 3, 224, 224])

In [5]:
# Split each episode into support and query samples along dim 1.
support,query = temp2.split([params.n_shot,params.n_query],dim=1)
cache_values, q_values = target.split([params.n_shot,params.n_query],dim=1)

# cache_values = F.one_hot(cache_values).half()
# Keep column 0 of every row (presumably one row per class) as the support label.
cache_values = cache_values.reshape(-1,cache_values.shape[-1])[:,0]
q_values = q_values.reshape(-1)
cache_values, q_values = cache_values.cuda(), q_values.cuda()

# Merge the first two dims so support/query become flat (batch, c, h, w) image batches.
n,k,c,h,w = support.shape
support = support.reshape(-1,c,h,w)
support = support.cuda()
query = query.reshape(-1,c,h,w)
query = query.cuda()


In [6]:
print(support.shape)
print(query.shape)

torch.Size([5, 3, 224, 224])
torch.Size([75, 3, 224, 224])


In [8]:

def patchify(imgs, patch_size=16):
    """Split a batch of square images into flattened, non-overlapping patches.

    imgs: (N, C, H, W) with H == W and H divisible by ``patch_size``
    returns: (N, L, patch_size**2 * C) with L = (H // patch_size) ** 2.
        Patches are ordered row-major over the image grid, and each patch is
        flattened in (row, col, channel) order.

    The channel count C is read from ``imgs`` (it used to be hard-coded to 3).
    """
    p = patch_size
    n, c = imgs.shape[0], imgs.shape[1]
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(n, h * w, p ** 2 * c)

def unpatchify(x, patch_size=16):
    """Inverse of ``patchify``: reassemble flattened patches into images.

    x: (N, L, patch_size**2 * C), where L is a perfect square
    returns: (N, C, H, W) with H = W = sqrt(L) * patch_size

    The channel count C is inferred from the last dimension of ``x`` (it used
    to be hard-coded to 3).
    """
    p = patch_size
    n, num_patches, patch_dim = x.shape
    h = w = int(num_patches ** .5)
    assert h * w == num_patches
    assert patch_dim % (p * p) == 0
    c = patch_dim // (p * p)

    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, c, h * p, w * p)

def random_masking(x, mask_ratio=0.5):
    """Randomly keep a subset of tokens for each sample (MAE-style masking).

    x: [N, L, D] token sequence.
    Each sample draws uniform noise per token. The ``int(L * (1 - mask_ratio))``
    tokens with the smallest noise are kept.

    Returns:
        x_masked: [N, len_keep, D] kept tokens, in ascending-noise order.
        mask: [N, L] binary mask in the original token order (0 = kept, 1 = removed).
        ids_restore: [N, L] permutation that undoes the noise-based shuffle.
    """
    batch, length, dim = x.shape
    n_keep = int(length * (1 - mask_ratio))

    noise = torch.rand(batch, length, device=x.device)
    order = torch.argsort(noise, dim=1)          # ascending: the first n_keep survive
    ids_restore = torch.argsort(order, dim=1)

    keep_idx = order[:, :n_keep]
    gather_idx = keep_idx.unsqueeze(-1).repeat(1, 1, dim)
    x_masked = torch.gather(x, dim=1, index=gather_idx)

    # Build the mask in shuffled order, then map it back to token order.
    shuffled_mask = torch.ones([batch, length], device=x.device)
    shuffled_mask[:, :n_keep] = 0
    mask = torch.gather(shuffled_mask, dim=1, index=ids_restore)

    return x_masked, mask, ids_restore

In [17]:
print(cache_values)
print(q_values.unsqueeze(1).repeat(1,5).shape)

tensor([36, 37,  3, 42, 53], device='cuda:0')
torch.Size([75, 5])


In [22]:
# Binary relation matrix: label[i, j] == 1 iff query i has the class of support j.
# Broadcasting (n_query_total, 1) against (1, n_support) works for any episode
# size instead of the hard-coded 75 x 5.
label = (q_values.unsqueeze(1) == cache_values.unsqueeze(0)).type(torch.int16)
print(label)


tensor([[1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0],


In [41]:
# 将 query 和 support的图片进行交叉融合
#--------方法1：将各自取50%，然后直接拼接-----------
query_patch = patchify(query)          # torch.Size([75, 196, 768])
support_patch = patchify(support)  
query_patch, _, _ = random_masking(query_patch)         # torch.Size([75, 98, 768])
support_patch, _, _ = random_masking(support_patch)
print(query_patch.shape)
print(support_patch.shape)
imags = torch.cat((query_patch.unsqueeze(1).repeat(1,5,1,1), support_patch.unsqueeze(0).repeat(75,1,1,1)), dim=2)
print(imags.shape)
imags = imags.reshape(-1,imags.shape[2],imags.shape[3])
imags = unpatchify(imags)
print(imags.shape)
label = label.reshape(-1)

# -------方法2：将对应mask的位置互补----------------
# 生成mask， imags = query*mask + support*(1-mask)
def random_compose(x, y, mask_ratio=0.5):
    """Mix every query with every support at randomly chosen patch positions.

    One random binary mask is drawn per query: per-sample shuffling via argsort
    of uniform noise, as in ``random_masking``. The same mask is reused for
    every support paired with that query.

    x: [N, L, D] patch sequence, query
    y: [N1, L, D] patch sequence, support
    mask_ratio: ``len_keep = int(L * (1 - mask_ratio))``. ``L - len_keep``
        positions take the query patch; the remaining positions take the
        support patch.

    Returns:
        [N, N1, L, D] tensor ``out`` with
        ``out[i, j] = x[i] * mask[i] + y[j] * (1 - mask[i])``.
    """
    N, L, _ = x.shape
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # 1 = take the query patch, 0 = take the support patch. The mask is built
    # in shuffled order, then mapped back to token order.
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ids_restore)

    # Broadcast instead of materialising repeated copies (the former unused
    # `one` tensor alone cost an extra N*N1*L*D buffer):
    # mask (N, 1, L, 1), x (N, 1, L, D), y (1, N1, L, D) -> (N, N1, L, D).
    mask = mask[:, None, :, None]
    return x.unsqueeze(1) * mask + y.unsqueeze(0) * (1 - mask)


torch.Size([75, 98, 768])
torch.Size([5, 98, 768])
torch.Size([75, 5, 196, 768])
torch.Size([375, 3, 224, 224])


In [42]:
query_patch = patchify(query)          # torch.Size([75, 196, 768])
support_patch = patchify(support)
images = random_compose(query_patch,support_patch)

In [43]:
images.shape

torch.Size([75, 5, 196, 768])

In [36]:
x = patchify(query)
mask_ratio = 0.5
N, L, D = x.shape  # batch, length, dim
len_keep = int(L * (1 - mask_ratio))

noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)

# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

# ids_keep = ids_shuffle[:, len_keep:L]
# y_masked = torch.gather(y, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)

In [37]:
# Complementary composition (method 2): query patches where mask == 1 and
# support patches where mask == 0, i.e. imags = query*mask + support*(1-mask).
# A new name is used so that re-running this cell does not re-expand `mask`.
n_query_imgs, n_support_imgs = query.shape[0], support.shape[0]
mask_4d = mask.unsqueeze(1).repeat(1, n_support_imgs, 1).unsqueeze(-1).repeat(1, 1, 1, D)
print(mask_4d.shape)
query_patch = patchify(query)          # torch.Size([75, 196, 768])
support_patch = patchify(support)
x = query_patch.unsqueeze(1).repeat(1, n_support_imgs, 1, 1)
y = support_patch.unsqueeze(0).repeat(n_query_imgs, 1, 1, 1)
x = x * mask_4d
# The support term needs the complement of the mask. Multiplying both terms
# by `mask` zeroed the same positions in query and support.
y = y * (1 - mask_4d)
print(x.shape)
imags = x + y

torch.Size([75, 5, 196, 768])
torch.Size([75, 5, 196, 768])


In [10]:
query_patch = patchify(query)
# print(query_patch.shape)
# query_recover = unpatchify(query_patch)
# print(query_recover.shape)
# print(query == query_recover)
x, mask, ids_restore = random_masking(query_patch)
print(x.shape)
print(mask.shape)
print(ids_restore.shape)

torch.Size([75, 98, 768])
torch.Size([75, 196])
torch.Size([75, 196])


In [11]:
print(mask[0][0:10])
print(ids_restore[0][0:10])

tensor([1., 0., 1., 1., 1., 1., 0., 0., 1., 0.], device='cuda:0')
tensor([173,   5, 175, 109, 147, 176,  67,   8, 130,   3], device='cuda:0')


In [None]:
# -----------feature extractor------------------
mask_ratio=[0,0.25,0.5,0.75]       # 0,0.25,0.5,0.75
support_f , support_cls_token, query_f, query_cls_token = cache_model(support,query,model,mask_ratio=mask_ratio,modal='else')

In [None]:
print("support_f.shape:",support_f.shape)
print("support_cls_token.shape:",support_cls_token.shape)
print("query_f.shape:",query_f.shape)
print("query_cls_token.shape:",query_cls_token.shape)

In [None]:
len_list = (np.array([1.0,1.0,1.0,1.0])-np.array(mask_ratio))*196
len_list = len_list.cumsum()

In [None]:
int(len_list[3])

In [None]:
support_f_1 = support_f[:,0:int(len_list[0]),:]
support_f_2 = support_f[:,int(len_list[0]):int(len_list[1]),:]
support_f_3 = support_f[:,int(len_list[1]):int(len_list[2]),:]
support_f_4 = support_f[:,int(len_list[2]):int(len_list[3]),:]


In [None]:
l = int(support_cls_token.shape[1]/4)
support_cls_token_1 = support_cls_token[:,0:l]
support_cls_token_2 = support_cls_token[:,l:2*l]
support_cls_token_3 = support_cls_token[:,2*l:3*l]
support_cls_token_4 = support_cls_token[:,3*l:]
support_cls_token_4.shape

In [None]:

query_f_1 = query_f[:,0:int(len_list[0]),:]
query_f_2 = query_f[:,int(len_list[0]):int(len_list[1]),:]
query_f_3 = query_f[:,int(len_list[1]):int(len_list[2]),:]
query_f_4 = query_f[:,int(len_list[2]):int(len_list[3]),:]
query_f_4.shape

In [None]:
query_cls_token_1 = query_cls_token[:,0:l]
query_cls_token_2 = query_cls_token[:,l:2*l]
query_cls_token_3 = query_cls_token[:,2*l:3*l]
query_cls_token_4 = query_cls_token[:,3*l:4*l]
query_cls_token_4.shape

### 计算不同mask下得到的cls_token之间的余弦相似度关系

In [None]:

r15 = support_cls_token_1 @ support_cls_token_3.t()
r15

In [None]:
## 计算不同mask 下得到的query和support cls_token之间的关系
qs_33 = query_cls_token_3 @ support_cls_token_3.t()
qs_33.shape

In [None]:
def acc(relation, n_way=None, n_query=None):
    """Few-shot accuracy of a (query x class) score matrix.

    Query rows are assumed to be grouped by class: the first ``n_query`` rows
    belong to class 0, the next ``n_query`` to class 1, and so on.

    Args:
        relation: (n_way * n_query, n_way) scores. The predicted class of a
            row is its arg-max column.
        n_way: number of classes; defaults to ``params.val_n_way``.
        n_query: queries per class; defaults to ``params.n_query``.

    Returns:
        ``(accuracy, pred)``: accuracy as a 0-dim float tensor, and the
        predicted class index of every row.
    """
    if n_way is None:
        n_way = params.val_n_way
    if n_query is None:
        n_query = params.n_query
    # Ground truth [0]*n_query + [1]*n_query + ..., on relation's device
    # (previously forced to CUDA regardless of where relation lived).
    y = torch.arange(n_way, device=relation.device).repeat_interleave(n_query)

    pred = relation.detach().max(dim=1)[1]
    # Divide by the number of labels actually compared. The old denominator
    # used train_n_way while the labels were built from val_n_way.
    cos_acc = pred.eq(y).sum() / y.numel()
    return cos_acc, pred

In [None]:
a , pred= acc(qs_33)
print(a)
print(pred)

In [None]:
y = np.repeat(range(params.val_n_way),params.n_query)
y = torch.from_numpy(y)
y = y.cuda()
print(y==pred)

In [None]:
query_cls_token.shape


In [None]:
query_cls_token_reshape = query_cls_token.reshape(75,4,-1).unsqueeze(1).repeat(1,5,1,1)
support_cls_token_reshape = support_cls_token.reshape(5,4,-1).unsqueeze(0).repeat(75,1,1,1)
support_cls_token_reshape.shape

qs = query_cls_token_reshape @ support_cls_token_reshape.transpose(3,2)

In [None]:
print(qs.shape)
qs = qs.reshape(75,5,-1)


In [None]:
qs_sum = qs.sum(-1)
print(qs_sum.shape)
acc_sum,pred = acc(qs_sum)
print(acc_sum)
print(pred)

In [None]:
acc_all = []
pred_all = []
for i in range(qs.shape[-1]):
    a,pred = acc(qs[:,:,i])
    acc_all.append(a)
    pred_all.append(pred)
acc_all = np.array(torch.tensor(acc_all,device='cpu')).reshape(4,4)
print(acc_all)
pred_all = [pred.tolist() for pred in pred_all]
pred_all = torch.tensor(pred_all)
print(pred_all)

### 计算不同mask情况下patch 对分类的作用

In [None]:
# query_f_4:torch.Size([75, 49, 768])
# support_f_4:torch.Size([5, 49, 768])
def cal_sim_patch(qf,sf):
    """Patch-to-patch dot-product similarity between every query and support.

    qf: (Q, Lq, D) query patch features
    sf: (S, Ls, D) support patch features
    returns: (Q, S, Lq, Ls) with ``out[i, j] = qf[i] @ sf[j].T``

    Batched matmul broadcasting handles any Q and S (previously hard-coded to
    75 queries and 5 supports) without materialising repeated copies.
    """
    return qf.unsqueeze(1) @ sf.unsqueeze(0).transpose(-1, -2)

In [None]:
query_f.shape

In [None]:
sim = cal_sim_patch(query_f,support_f)

In [None]:
print(sim.shape)
print(len_list)

In [None]:
sim_all = sim.sum(-1).sum(-1)
acc_p_all,pred = acc(sim_all)
print(acc_p_all)

In [None]:
len_list = [int(l) for l in len_list]
sim_1 = sim[:,:,0:len_list[0],0:len_list[0]]
print(sim_1.shape)
sim_2 = sim[:,:,len_list[0]:len_list[1],len_list[0]:len_list[1]]
print(sim_2.shape)
sim_3 = sim[:,:,len_list[1]:len_list[2],len_list[1]:len_list[2]]
print(sim_3.shape)
sim_4 = sim[:,:,len_list[2]:len_list[3],len_list[2]:len_list[3]]
print(sim_4.shape)

In [None]:
acc_p_1,_ = acc(sim_1.sum(-1).sum(-1))
print(acc_p_1)
acc_p_2,_ = acc(sim_2.sum(-1).sum(-1))
print(acc_p_2)
acc_p_3,_ = acc(sim_3.sum(-1).sum(-1))
print(acc_p_3)
acc_p_4,_ = acc(sim_4.sum(-1).sum(-1))
print(acc_p_4)

In [None]:
def sim_knn(sim,neighbor_k=3):
    """Sum of the top-``neighbor_k`` similarities of every query patch.

    sim: (Q, S, Lq, Ls) patch-similarity tensor (e.g. from ``cal_sim_patch``).
    For every (query, support) pair, the ``neighbor_k`` largest values of each
    row (query patch) are taken and summed over all rows.

    returns: (Q, S) scores on the same device as ``sim``. A single vectorised
    ``topk`` replaces the former per-pair Python loop and the hard-coded
    ``.cuda()``.
    """
    topk_values = torch.topk(sim, neighbor_k, dim=-1).values  # (Q, S, Lq, k)
    return topk_values.sum(dim=(-2, -1))

In [None]:
k_sim1 = sim_knn(sim_1)
print(k_sim1.shape)

In [None]:
# 计算每个patch前k个相似度
acc_p,_ = acc(sim_knn(sim))
print(acc_p)
acc_p_1,_ = acc(sim_knn(sim_1))
print(acc_p_1)
acc_p_2,_ = acc(sim_knn(sim_2))
print(acc_p_2)
acc_p_3,_ = acc(sim_knn(sim_3))
print(acc_p_3)
acc_p_4,_ = acc(sim_knn(sim_4))
print(acc_p_4)

### 单纯依赖patch来计算相似度也不行
### 利用cls_token和patch进行交叉

#### 同一个图像在不同mask下得到的cls_token均具有区分度，即自相关系数高、互相关系数低；但单纯依赖cls_token之间的相似度进行分类，准确率只有0.4（mask=0.5时）。


##### 验证一下：同一幅图像中，各patch与cls_token之间的相似度

In [None]:

cls_patch_sim = support_cls_token_1.unsqueeze(1) @ support_f_1.transpose(2,1)
cls_patch_sim = cls_patch_sim.squeeze(1)
print(cls_patch_sim.shape)

In [None]:
neighbor_k = 15
topk_value, topk_index = torch.topk(cls_patch_sim, neighbor_k, 1)
print(topk_value.shape)
print(topk_index)

In [None]:
# For every support image, select its own top-k patches -> (n_support, k, D).
# Plain indexing `support_f_1[:, topk_index, :]` pairs every image with the
# index rows of every other image and yields an (n, n, k, D) tensor.
sf = torch.gather(support_f_1, 1, topk_index.unsqueeze(-1).expand(-1, -1, support_f_1.shape[-1]))
sf.shape

In [None]:
a = torch.randint(1,9,(3,4,5))
print(a)
b = torch.randint(1,4,(3,2))
print(b)

In [None]:
aa = a[:,b,:]
# print(aa)
aaa = aa[:,1,:,:].squeeze(1)
print(aaa)

In [None]:
print(cls_patch_sim.sum(-1))

##### 用分别来自query和support的cls_token与patch之间的交叉关系，来确定选取哪些patch代表整个image

In [None]:
def _mean_of_topk_union(sim, patch, neighbor_k):
    """Average, per item, the union of patches picked by its top-k scores.

    sim: (P, C, L) similarity of each of the P items' L patches to C cls tokens.
    patch: (P, L, D) patch features of the same P items.
    For every (item, cls token) pair the ``neighbor_k`` highest-scoring patch
    positions are taken. The union of those positions over all C cls tokens
    is averaged into one (D,) vector per item.
    returns: (P, D) on ``patch``'s device.
    """
    num_items, num_patches = sim.shape[0], sim.shape[-1]
    _, topk_index = torch.topk(sim, neighbor_k, dim=-1)                # (P, C, k)
    selected = torch.zeros(num_items, num_patches, dtype=patch.dtype, device=patch.device)
    selected.scatter_(1, topk_index.reshape(num_items, -1), 1.0)      # union over C
    weights = selected.unsqueeze(-1)                                   # (P, L, 1)
    return (patch * weights).sum(dim=1) / weights.sum(dim=1)


def select_query_patch(cls_token,patch,neighbor_k = 3):
    """Pick representative query patches using the support cls tokens.

    cls_token: (S, D) support cls tokens.
    patch: (Q, L, D) query patch features.
    For every query, the ``neighbor_k`` patches with the largest dot product
    with each support cls token are selected. The mean of the union of those
    patches is returned as that query's feature.
    returns: (Q, D) on ``patch``'s device (works for any Q and S).
    """
    sim = torch.einsum('qld,sd->qsl', patch, cls_token)   # (Q, S, L)
    return _mean_of_topk_union(sim, patch, neighbor_k)


def select_support_patch(cls_token,patch,neighbor_k = 3):
    """Pick representative support patches using the query cls tokens.

    cls_token: (Q, D) query cls tokens.
    patch: (S, L, D) support patch features.
    For every support, the ``neighbor_k`` patches with the largest dot product
    with each query cls token are selected. The mean of the union of those
    patches is returned as that support's feature.
    returns: (S, D) on ``patch``'s device (works for any Q and S).
    """
    sim = torch.einsum('sld,qd->sql', patch, cls_token)   # (S, Q, L)
    return _mean_of_topk_union(sim, patch, neighbor_k)

In [None]:
patch = query_f_3
cls_token = support_cls_token_1
sim = patch.unsqueeze(1).repeat(1,5,1,1) @ cls_token.unsqueeze(1).unsqueeze(0).repeat(75,1,1,1).transpose(3,2)  # [75,196,768] [5,768]  --> [75,5,196,1]
sim = sim.squeeze(-1)                 # [75,5,196]
new_patch = torch.zeros(patch.shape[0],patch.shape[-1])


In [None]:
print(patch.shape)
print(cls_token.shape)
print(sim.shape)

In [None]:
neighbor_k =15
i = 0 
_, topk_index = torch.topk(sim[i,:,:], neighbor_k, -1)             # [5,196]   -->    [5,neighbor_k]
# _, topk_index = torch.topk(sim[0,:,:], neighbor_k, -1)
tik = set(np.array(torch.tensor(topk_index.reshape(-1),device='cpu')))          # 求合集
tik = np.array(list(tik))


In [None]:

# new_patch = patch[i,tik,:]
new_patch_ = patch[i,tik,:].mean(dim=0)
new_patch_.shape

In [None]:
new_patch = select_query_patch(support_cls_token_1,query_f_3,neighbor_k = 3)
new_patch.shape

In [None]:
new_support_patch = select_support_patch(query_cls_token_3,support_f_3,neighbor_k = 3)
new_support_patch.shape

In [None]:
sim_patch = new_patch @ new_support_patch.t()
sim_patch = sim_patch.cuda()
acc_patch , pred= acc(sim_patch)
print(acc_patch)

#### 利用各自的cls_token 和 patch的关系，确定比较重要的patch

In [None]:
def select_patch(cls_token,patch,neighbor_k = 15):
    """Keep, for every image, the patches most similar to its own cls token.

    cls_token: (B, D) cls tokens.
    patch: (B, L, D) patch features of the same B images.
    returns: (B, neighbor_k, D), where ``out[b]`` holds the ``neighbor_k``
        patches of image ``b`` with the largest dot product with
        ``cls_token[b]``, in descending order of similarity.

    The former ``patch[:, topk_index, :][:, 1]`` applied the top-k indices of
    image 1 to every image (and failed for B < 2). Indices are now gathered
    per image.
    """
    sim = torch.einsum('bld,bd->bl', patch, cls_token)    # (B, L)
    _, topk_index = torch.topk(sim, neighbor_k, dim=-1)    # (B, k)
    gather_index = topk_index.unsqueeze(-1).expand(-1, -1, patch.shape[-1])
    # squeeze(1) kept for backward compatibility when neighbor_k == 1.
    return torch.gather(patch, 1, gather_index).squeeze(1)

In [None]:
print(support_f_1.shape)
print(support_cls_token_1.unsqueeze(1).transpose(2,1).shape)
# Similarity of every patch to its own image's cls token: (n, L, 1).
cc = torch.bmm(support_f_1,support_cls_token_1.unsqueeze(1).transpose(2,1))
print(cc.shape)
_, topk_index = torch.topk(cc.squeeze(-1), neighbor_k, -1)
# Gather each image's own top-k patches -> (n, k, D). The former
# `support_f_1[:,topk_index,:][:,1,:,:]` applied image 1's indices to every image.
new_patch = torch.gather(support_f_1, 1, topk_index.unsqueeze(-1).expand(-1, -1, support_f_1.shape[-1]))
print(new_patch.shape)
new_patch = new_patch.mean(dim=1)
print(new_patch.shape)

In [None]:
spf = select_patch(support_cls_token_1,support_f_1)
print(spf.shape)
qyf = select_patch(query_cls_token_1,query_f_1)
print(qyf.shape)

In [None]:
sim_qs = cal_sim_patch(qyf,spf)
sim_select_patch = sim_knn(sim_qs)
accu,pred = acc(sim_select_patch)
print(accu)

In [None]:
# select patch 加上 cls_token
sim_qs = (qyf.mean(dim=1)+query_cls_token_3)@(spf.mean(dim=1)+support_cls_token_3).t()

accu,pred = acc(sim_qs)
print(accu)

sim_qs = (query_cls_token_3)@(support_cls_token_3).t()

accu,pred = acc(sim_qs)
print(accu)

In [1]:
#--------------- test for AverageMeter  ------------------
class AverageMeter(object):
    """Keep the latest value together with a running weighted average.

    ``update(val, n)`` stores ``val`` as the current value and counts it ``n``
    times toward ``sum``, ``count`` and ``avg``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [4]:
losses = AverageMeter()
losses.update(5,3)
print(losses.val)
print(losses.avg)

losses.update(4,3)
print(losses.val)
print(losses.avg)

5
5.0
4
4.5
