In [123]:
from __future__ import print_function

import os
import socket
import time
import sys
import subprocess
import numpy as np
import warnings
warnings.filterwarnings('ignore',category=FutureWarning)
print("parse_option")

import torch
import torch.optim as optim
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from models import model_pool
from models.util import create_model
from models.resnet_language import LangPuller


from dataset.mini_imagenet import ImageNet, MetaImageNet
from dataset.tiered_imagenet import TieredImageNet, MetaTieredImageNet
from dataset.transform_cfg import transforms_options, transforms_list

from util import adjust_learning_rate, create_and_save_embeds, create_and_save_descriptions
from eval.util import accuracy, AverageMeter, validate
from configs import parse_option_supervised
from PIL import Image
from torch.utils.data import Dataset
import os
import pickle
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import re

parse_option


In [124]:
import os
import pickle
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import re

In [125]:
class cub200(Dataset):
    """CUB-200-2011 images for the base session or for one incremental session.

    The constructor reads the CUB metadata files under ``<root>/CUB_200_2011``,
    keeps either the official train or test split, narrows it to the classes of
    the requested session, and then loads *and transforms* every selected image
    up front into ``self.imgs``.

    Which samples are kept:

    * ``base_sess`` truthy: classes ``0 .. index - 1`` of the chosen split.
    * ``base_sess`` falsy and ``train`` truthy: the files listed in
      ``index_list/session_<index_path>.txt``.
    * ``base_sess`` falsy and ``train`` falsy: the ten classes starting at
      ``100 + (index_path - 2) * 10``.

    Args:
        args: Accepted for signature compatibility; not read by this class.
        root: Directory that contains the ``CUB_200_2011`` folder.
        train: Truthy for the train split, falsy for the test split.
        index_path: Session number (used by the two novel-session selections).
        index: Number of leading classes to keep for the base session.
        base_sess: Truthy when building the base-session dataset.
        transform: Callable applied to each image (an ``H x W x 3`` uint8 array)
            that must return a tensor; a default augmentation pipeline is used
            when omitted.

    Attributes:
        data: Paths of the selected images.
        targets / labels: Zero-based class ids of the selected images (same list).
        imgs: ``numpy`` array holding every transformed image.
        label2human: 200-entry list mapping class id to a readable class name.
        basec_map: Identity mapping over the 100 base class ids.
    """

    NUM_CLASSES = 200                 # classes listed in classes.txt
    NUM_BASE_CLASSES = 100            # classes 0..99 form the base session
    NOVEL_CLASSES_PER_SESSION = 10    # each later session adds ten classes
    META_DIR = 'CUB_200_2011'

    #index_path -> txt_path,
    #index -> base size
    def __init__(self, args,root='./cub', train=True,
                 index_path=None, index=None, base_sess=None,transform=None,):
        # Was super(Dataset, self): that skips Dataset in the MRO instead of
        # starting from this class.
        super(cub200, self).__init__()
        self.root = root
        self.base_sess = base_sess
        self.index_path = index_path
        self.index = index

        self.train = train  # training set or test set
        self._pre_operate()
        self.mean = [120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0]
        self.std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
        self.normalize = transforms.Normalize(mean=self.mean, std=self.std)
        self.unnormalize = transforms.Normalize(mean=-np.array(self.mean)/self.std, std=1/np.array(self.std))

        if transform is None:
            # The base and novel sessions used two identical pipelines; one is enough.
            # Plain function references replace the former lambdas (same calls).
            transform = transforms.Compose([
                Image.fromarray,
                transforms.RandomResizedCrop(224),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.RandomHorizontalFlip(),
                np.asarray,
                transforms.ToTensor(),
                self.normalize
            ])
        self.transform = transform

        if base_sess:
            # The base session keeps the first `index` classes (100 in the notebook).
            if not self.train:
                print(index)  # trace kept from the original test-split path
            self.data, self.targets = self.SelectfromClasses(self.data, self.targets, index)
        elif self.train:
            # Novel-session training samples come from the session index file.
            self.data, self.targets = self.SelectfromTxt(self.data2label, index_path)
        else:
            # Novel-session test samples: every test image of that session's classes.
            self.data, self.targets = self.SelectfromNovelClasses(self.data, self.targets, index_path)

        self.labels = self.targets
        # NOTE: images are transformed once, here, so any random augmentation in
        # `self.transform` is sampled a single time per image and then reused.
        self.imgs = self._getImg(self.data)
        # Smallest label of the session, computed once instead of on every item.
        self._label_offset = min(self.labels) if len(self.labels) > 0 else 0

        # Labels are available by codes by default. Converting them into human readable labels.
        self.label2human = [""] * self.NUM_CLASSES
        with open(self._meta_path('classes.txt'), 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                # Each line looks like "<class number> <NNN>.<Name_With_Underscores>".
                catname, humanname = line.strip().lower().split(' ')
                num, humanname = humanname.strip().split('.', 1)
                humanname = " ".join(humanname.split('_'))
                if 1 <= int(catname) <= self.NUM_CLASSES:
                    self.label2human[int(catname) - 1] = humanname

        # Create mapping for base classes as they are not consecutive anymore.
        basec = np.arange(self.NUM_BASE_CLASSES)
        self.basec_map = dict(zip(basec, np.arange(len(basec))))

    def _meta_path(self, *parts):
        """Join ``parts`` below ``<root>/CUB_200_2011`` using '/' separators."""
        return '/'.join((self.root, self.META_DIR) + parts)

    @staticmethod
    def _select_by_class(data, targets, class_ids):
        """Collect ``(data, targets)`` entries whose target is in ``class_ids``, class by class."""
        data_tmp = []
        targets_tmp = []
        for cls_id in class_ids:
            for j in np.where(cls_id == targets)[0]:
                data_tmp.append(data[j])
                targets_tmp.append(targets[j])
        return data_tmp, targets_tmp

    def _getImg(self,d_list):
        """Load every path in ``d_list`` as RGB, apply ``self.transform``, and stack the results.

        Returns:
            ``numpy`` array with one transformed image per input path.
        """
        img_list = []
        for d_path in d_list:
            with Image.open(d_path) as raw:
                c_img = np.array(raw.convert('RGB'))
            img_list.append(self.transform(c_img).numpy())
        return np.array(img_list)

    def SelectfromTxt(self, data2label, index_path):
        """Select the images listed in ``index_list/session_<index_path>.txt``.

        Each line of the file is a path relative to ``self.root``.

        Args:
            data2label: Mapping from full image path to class id.
            index_path: Session number used to pick the index file.

        Returns:
            ``(paths, targets)`` lists in file order.

        Raises:
            KeyError: If a listed image is not present in ``data2label``.
        """
        list_file = self._meta_path('index_list', 'session_' + str(index_path) + '.txt')
        with open(list_file) as f:  # closed deterministically (was left open)
            index = f.read().splitlines()
        data_tmp = []
        targets_tmp = []
        for i in index:
            img_path = self.root + '/' + i
            data_tmp.append(img_path)
            targets_tmp.append(data2label[img_path])

        return data_tmp, targets_tmp

    def SelectfromClasses(self, data, targets, index):
        """Keep the samples of classes ``0 .. index - 1``, grouped by class.

        Args:
            data: Sequence of image paths.
            targets: ``numpy`` array of class ids aligned with ``data``.
            index: Number of leading classes to keep.

        Returns:
            ``(paths, targets)`` lists.
        """
        return self._select_by_class(data, targets, range(index))

    def SelectfromNovelClasses(self, data, targets,index_path):
        """Keep the samples of the ten classes that belong to session ``index_path``.

        Session 2 covers classes 100-109, session 3 covers 110-119, and so on.

        Returns:
            ``(paths, targets)`` lists.
        """
        first = self.NUM_BASE_CLASSES + (index_path - 2) * self.NOVEL_CLASSES_PER_SESSION
        return self._select_by_class(
            data, targets, range(first, first + self.NOVEL_CLASSES_PER_SESSION))

    def list2dict(self, list):
        """Turn ``"<id> <value>"`` lines into ``{int(id): value}``.

        Raises:
            EOFError: If the same id appears on more than one line.
        """
        mapping = {}
        for entry in list:
            fields = entry.split(' ')
            key = int(fields[0])
            if key in mapping:
                raise EOFError('The same ID can only appear once')
            mapping[key] = fields[1]
        return mapping

    def text_read(self, file):
        """Return the lines of ``file`` with their trailing newlines removed."""
        with open(file, 'r') as f:
            return [line.strip('\n') for line in f]

    def _pre_operate(self):
        """Read the CUB metadata and populate ``data``, ``targets`` and ``data2label``.

        Only the images of the split selected by ``self.train`` are kept. Paths
        are built from ``self.root`` so they match the keys that
        ``SelectfromTxt`` looks up (they were hard-coded to ``./cub`` before).
        """
        id2image = self.list2dict(self.text_read(self._meta_path('images.txt')))
        id2train = self.list2dict(self.text_read(self._meta_path('train_test_split.txt')))  # 1: train images; 0: test images
        id2class = self.list2dict(self.text_read(self._meta_path('image_class_labels.txt')))

        want_train = bool(self.train)
        selected = [k for k in sorted(id2train.keys())
                    if (id2train[k] == '1') == want_train]

        self.data = []
        self.targets = []
        self.data2label = {}
        for k in selected:
            image_path = self._meta_path('images', str(id2image[k]))
            label = int(id2class[k]) - 1  # class numbers in the file start at 1
            self.data.append(image_path)
            self.targets.append(label)
            self.data2label[image_path] = label
        self.targets = np.array(self.targets)

    def __getitem__(self, item):
        """Return ``(image, target, item)`` for sample ``item``.

        For the base session the target is shifted so that the smallest label
        of the session becomes 0. For novel sessions the global class id is
        returned unchanged.

        NOTE(review): the previous novel-session branch tried to assemble a
        whole episode but read attributes that ``__init__`` never sets
        (``fix_seed``, ``n_shots``, ``eval_mode``, ``train_transform``,
        ``test_transform``), so it raised ``AttributeError`` on every call.
        Per-sample access is provided instead; confirm this matches the
        intended episodic protocol.
        """
        img = self.imgs[item]
        target = self.targets[item]
        if self.base_sess:
            target = target - self._label_offset
        return img, target, item

    def __len__(self):
        """Number of selected samples."""
        return len(self.data)

In [2]:
from types import SimpleNamespace

# Settings object handed to the dataset/model constructors as `args`.
# A SimpleNamespace replaces the former `lambda x: None` attribute bag: it has
# the same attribute access, a readable repr, and can be pickled.
args = SimpleNamespace(
    n_ways=5,
    n_shots=5,
    n_queries=30,
    data_root='data',
    data_aug=True,
    n_test_runs=5,
    n_aug_support_samples=1,
    set_seed=20,
    continual=True,
    eval_mode="few-shot-incremental-fine-tune",
    n_base_support_samples=1,
    n_base_aug_support_samples=0,
)

In [3]:
# Loader over ImageNet(split='train', phase='test'): fixed order, partial last batch kept.
base_test_loader = DataLoader(
    ImageNet(args=args, split='train', phase='test'),
    batch_size=32,    # 64 // 2
    shuffle=False,
    drop_last=False,
    num_workers=2,    # 5 // 2
)

In [4]:
# Loader over MetaImageNet(split='val'): fixed order, partial last batch kept.
meta_valloader = DataLoader(
    MetaImageNet(
        args=args,
        split='val',
        fix_seed=True,
        use_episodes=False,
        disjoint_classes=True,
    ),
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=5,
)

In [7]:
import itertools  # imported here because the standalone import cell appears after this one

# Endless iterator over the validation loader.
# NOTE: itertools.cycle keeps a copy of every item it has yielded, so after one
# full pass all batches stay in memory and are replayed in the same order.
meta_valloader_it = itertools.cycle(iter(meta_valloader))

In [6]:
import itertools

In [8]:
base_valloader_it = itertools.cycle(iter(base_test_loader))

In [9]:
base_batch = next(base_valloader_it)

In [10]:
base_batch[1]

tensor([57, 18, 48, 38, 27, 18, 15, 58, 14, 24, 36, 28, 13, 26, 18, 28, 43, 30,
        52,  9, 26,  4, 10, 30, 40, 53, 45, 33, 31, 12,  3, 23])

In [11]:
def drop_a_dim(data):
    """Merge the two leading dimensions of an episodic batch.

    In this notebook a batch from the meta loader has image tensors of shape
    ``[5, 25, 3, 84, 84]``; this helper flattens them to ``[125, 3, 84, 84]``
    and flattens the label tensors to 1-D numpy arrays.

    Args:
        data: ``(support_xs, support_ys, query_xs, query_ys)``. ``support_xs``
            must be 5-D; ``query_xs`` must share its last three dimensions.

    Returns:
        ``(support_xs, support_ys, query_xs, query_ys)`` with the image tensors
        4-D and the labels as flat ``numpy`` arrays.

    Raises:
        ValueError: If ``support_xs`` is not 5-D.
    """
    support_xs, support_ys, query_xs, query_ys = data
    if support_xs.dim() != 5:
        raise ValueError(
            'support_xs must be 5-D, got shape {}'.format(tuple(support_xs.shape)))
    # Everything after the two leading dims is kept as-is (the old local names
    # height/width/channel did not match the actual axis order).
    sample_shape = tuple(support_xs.shape[2:])

    # reshape() equals view() for contiguous tensors and also accepts
    # non-contiguous ones, where view() raises.
    support_xs = support_xs.reshape((-1,) + sample_shape)
    query_xs = query_xs.reshape((-1,) + sample_shape)
    # .cpu() is a no-op for CPU tensors and makes .numpy() valid for device tensors.
    support_ys = support_ys.reshape(-1).detach().cpu().numpy()
    query_ys = query_ys.reshape(-1).detach().cpu().numpy()
    return (support_xs, support_ys, query_xs, query_ys)

In [12]:
d_idx = drop_a_dim(next(meta_valloader_it))

In [13]:
np.sort(np.unique(d_idx[3]))

array([ 6, 13, 17, 18, 20, 22, 25, 26, 28, 34, 40, 43, 58, 61, 67, 71, 73,
       78, 81, 85, 86, 90, 95, 97, 99])

In [14]:
len([name for name in meta_valloader.dataset.label2human if name != ''])

40

In [15]:
def get_vocabs(base_loader=None, novel_loader=None, query_ys=None):
    """Build the base, novel and combined class-name vocabularies.

    Args:
        base_loader: Loader whose ``dataset.label2human`` lists base class
            names; empty strings are skipped. May be None.
        novel_loader: Loader whose ``dataset.label2human`` is indexed by the
            class ids found in ``query_ys``. May be None.
        query_ys: Class ids of the novel samples; only the sorted unique values
            are used. Required when ``novel_loader`` is given.

    Returns:
        ``(vocab_base, vocab_all, vocab_novel, orig2id)`` where

        * ``vocab_base`` is the list of base names, or None without a base loader;
        * ``vocab_all`` is base names followed by novel names;
        * ``vocab_novel`` is the list of novel names, or None without a novel loader;
        * ``orig2id`` maps each novel class id to its position in ``vocab_all``,
          or is None without a novel loader.
    """
    vocab_all = []
    vocab_base = None
    if base_loader is not None:
        vocab_base = [name for name in base_loader.dataset.label2human if name != '']
        vocab_all += vocab_base

    vocab_novel, orig2id = None, None

    if novel_loader is not None:
        novel_ids = np.sort(np.unique(query_ys))
        label2human_novel = novel_loader.dataset.label2human
        vocab_novel = [label2human_novel[i] for i in novel_ids]
        # Novel ids are numbered after the base vocabulary; without a base
        # loader they start at 0 (previously len(None) raised TypeError).
        n_base = 0 if vocab_base is None else len(vocab_base)
        orig2id = dict(zip(novel_ids, n_base + np.arange(len(novel_ids))))
        vocab_all += vocab_novel

    return vocab_base, vocab_all, vocab_novel, orig2id

In [106]:
out_vocabs = get_vocabs(train_test,val_test,d_idx[3])

In [107]:
out_vocabs[2]

['parakeet auklet',
 'indigo bunting',
 'spotted catbird',
 'gray catbird',
 'eastern towhee',
 'brandt cormorant',
 'bronzed cowbird',
 'shiny cowbird',
 'american crow',
 'purple finch',
 'scissor tailed flycatcher',
 'frigatebird',
 'california gull',
 'herring gull',
 'ruby throated hummingbird',
 'pomarine jaeger',
 'florida jay',
 'belted kingfisher',
 'ringed kingfisher',
 'pacific loon',
 'mallard',
 'mockingbird',
 'hooded oriole',
 'scott oriole',
 'brown pelican']

In [18]:
# Scratch check of range bounds: the end value is exclusive, so this prints 101 through 105.
for idx in range(101,106):
    print(idx)

101
102
103
104
105


In [53]:
# Base-session test split: the first 100 classes, served in fixed order.
val_check = DataLoader(
    cub200(args=args, train=False, base_sess=True, index=100, index_path=1),
    batch_size=32,    # 64 // 2
    shuffle=False,
    drop_last=False,
    num_workers=2,    # 5 // 2
)

100


In [56]:
# Base-session train split: the first 100 classes, served in fixed order.
train_check = DataLoader(
    cub200(args=args, train=True, base_sess=True, index=100, index_path=1),
    batch_size=32,    # 64 // 2
    shuffle=False,
    drop_last=False,
    num_workers=2,    # 5 // 2
)

In [58]:
set(train_check.dataset.targets)

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99}

In [66]:
# Walk the base-session training loader once and print each batch's label tensor.
# Every batch unpacks into (images, targets, sample indices); only the targets are shown.
for idx, (img,target,_) in enumerate(train_check):
    print(target)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 2, 2, 2, 2])
tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 3, 3, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        4, 4, 4, 4, 4, 4, 4, 4])
tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5])
tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6,
        6, 6, 6, 6, 6, 6, 6, 6])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7])
tensor([7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8, 8, 8])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9,
        9, 9, 9,

In [65]:
val_check.dataset.label2human[0]

'black footed albatross'

In [68]:
# Train split of incremental session 2 (samples come from the session index file).
train_test = DataLoader(
    cub200(args=args, train=True, base_sess=False, index_path=2),
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=5,
)

In [72]:
train_test.dataset.

TypeError: __getitem__() missing 1 required positional argument: 'item'

In [74]:
# Test split of incremental session 2 (all test images of that session's classes).
val_test = DataLoader(
    cub200(args=args, train=False, base_sess=False, index_path=2),
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=5,
)

In [75]:
val_test.dataset.targets

[100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 100,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 101,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 102,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 103,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 104,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 105,
 106,
 106,
 106,
 106,
 106,
 106,
 106,
 106

In [None]:
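# In summary: support_xs are the train images and support_ys their labels; query_xs are the test images and query_ys their labels. vocab_novel is the label2human of the classes being trained in this session, and it keeps getting appended to vocab_base.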

In [77]:
len(val_test.dataset.targets)

279

In [126]:
# Session-2 loaders: train split first, then the matching test split.
val_train = DataLoader(
    cub200(args=args, train=True, base_sess=False, index_path=2),
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=5,
)

val_test = DataLoader(
    cub200(args=args, train=False, base_sess=False, index_path=2),
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=5,
)

In [120]:
base_batch[1]

tensor([57, 18, 48, 38, 27, 18, 15, 58, 14, 24, 36, 28, 13, 26, 18, 28, 43, 30,
        52,  9, 26,  4, 10, 30, 40, 53, 45, 33, 31, 12,  3, 23])

In [117]:
base_batch_t = val_check.dataset.imgs, val_check.dataset.targets,val_check.dataset.label2human

In [122]:
len(base_batch_t[0])

2864

In [165]:
# Support set is taken from the session's train dataset, query set from its test dataset.
support_xs, support_ys = train_test.dataset.imgs, train_test.dataset.targets
query_xs, query_ys = val_test.dataset.imgs, val_test.dataset.targets

In [186]:
h = next(meta_valloader_it)

In [190]:
h[2].size()

torch.Size([5, 150, 3, 84, 84])

In [188]:
torch.Tensor(support_xs).size()

torch.Size([50, 3, 224, 224])

In [175]:
type(query_ys)

list

In [196]:
np.array(support_ys)

array([100, 100, 100, 100, 100, 101, 101, 101, 101, 101, 102, 102, 102,
       102, 102, 103, 103, 103, 103, 103, 104, 104, 104, 104, 104, 105,
       105, 105, 105, 105, 106, 106, 106, 106, 106, 107, 107, 107, 107,
       107, 108, 108, 108, 108, 108, 109, 109, 109, 109, 109])

In [197]:
d_idx[1]

array([13, 13, 13, 13, 13, 18, 18, 18, 18, 18, 58, 58, 58, 58, 58, 67, 67,
       67, 67, 67, 78, 78, 78, 78, 78, 17, 17, 17, 17, 17, 26, 26, 26, 26,
       26, 28, 28, 28, 28, 28, 73, 73, 73, 73, 73, 97, 97, 97, 97, 97, 43,
       43, 43, 43, 43, 61, 61, 61, 61, 61, 85, 85, 85, 85, 85, 95, 95, 95,
       95, 95, 99, 99, 99, 99, 99,  6,  6,  6,  6,  6, 34, 34, 34, 34, 34,
       81, 81, 81, 81, 81, 86, 86, 86, 86, 86, 90, 90, 90, 90, 90, 20, 20,
       20, 20, 20, 22, 22, 22, 22, 22, 25, 25, 25, 25, 25, 40, 40, 40, 40,
       40, 71, 71, 71, 71, 71])

In [180]:
query_xs.shape

(279, 3, 224, 224)

In [193]:
type(d_idx[0])

torch.Tensor

In [194]:
d_idx[0].shape

torch.Size([125, 3, 84, 84])

In [133]:
import copy

In [134]:
# NOTE(review): the recorded output shows this call raising
# "NotImplementedError: dataset not supported: cub200", so the lines below did not run.
model = create_model('resnet18', 100, args, vocab=None, dataset='cub200')
# NOTE(review): `ckpt` is not defined in any cell visible here -- presumably a
# checkpoint loaded elsewhere; confirm before running.
model.load_state_dict(ckpt['model'])
# Work on a deep copy moved to the GPU so `model` itself is left untouched.
basenet = copy.deepcopy(model).cuda()
base_weight, base_bias = basenet._get_base_weights()

NotImplementedError: dataset not supported: cub200

In [136]:
t = [[1,2,3],[4,5,6],[7,8,9]]

In [141]:
k = torch.rand(100,640)
l = torch.rand(10,640)

In [156]:
k.shape

torch.Size([100, 640])

In [151]:
t = torch.cat([k,l],0)

In [149]:
k.size(0)

100

In [159]:
t  = torch.nn.Parameter(t,requires_grad=False)

In [164]:
torch.norm(t[0:k.size(0), :]- k)

tensor(0.)

In [109]:
m_v_i = get_vocabs(train_check,val_test,val_test.dataset.targets)

In [112]:
train_check.dataset.label2human[0:100]

['black footed albatross',
 'laysan albatross',
 'sooty albatross',
 'groove billed ani',
 'crested auklet',
 'least auklet',
 'parakeet auklet',
 'rhinoceros auklet',
 'brewer blackbird',
 'red winged blackbird',
 'rusty blackbird',
 'yellow headed blackbird',
 'bobolink',
 'indigo bunting',
 'lazuli bunting',
 'painted bunting',
 'cardinal',
 'spotted catbird',
 'gray catbird',
 'yellow breasted chat',
 'eastern towhee',
 'chuck will widow',
 'brandt cormorant',
 'red faced cormorant',
 'pelagic cormorant',
 'bronzed cowbird',
 'shiny cowbird',
 'brown creeper',
 'american crow',
 'fish crow',
 'black billed cuckoo',
 'mangrove cuckoo',
 'yellow billed cuckoo',
 'gray crowned rosy finch',
 'purple finch',
 'northern flicker',
 'acadian flycatcher',
 'great crested flycatcher',
 'least flycatcher',
 'olive sided flycatcher',
 'scissor tailed flycatcher',
 'vermilion flycatcher',
 'yellow bellied flycatcher',
 'frigatebird',
 'northern fulmar',
 'gadwall',
 'american goldfinch',
 'euro

In [111]:
m_v_i[0]

['black footed albatross',
 'laysan albatross',
 'sooty albatross',
 'groove billed ani',
 'crested auklet',
 'least auklet',
 'parakeet auklet',
 'rhinoceros auklet',
 'brewer blackbird',
 'red winged blackbird',
 'rusty blackbird',
 'yellow headed blackbird',
 'bobolink',
 'indigo bunting',
 'lazuli bunting',
 'painted bunting',
 'cardinal',
 'spotted catbird',
 'gray catbird',
 'yellow breasted chat',
 'eastern towhee',
 'chuck will widow',
 'brandt cormorant',
 'red faced cormorant',
 'pelagic cormorant',
 'bronzed cowbird',
 'shiny cowbird',
 'brown creeper',
 'american crow',
 'fish crow',
 'black billed cuckoo',
 'mangrove cuckoo',
 'yellow billed cuckoo',
 'gray crowned rosy finch',
 'purple finch',
 'northern flicker',
 'acadian flycatcher',
 'great crested flycatcher',
 'least flycatcher',
 'olive sided flycatcher',
 'scissor tailed flycatcher',
 'vermilion flycatcher',
 'yellow bellied flycatcher',
 'frigatebird',
 'northern fulmar',
 'gadwall',
 'american goldfinch',
 'euro

In [85]:
m_v_i = next(meta_valloader_it)

In [99]:
train_test.dataset.imgs.shape

(50, 3, 224, 224)

In [90]:
batch_size, _, height, width, channel = m_v_i[0].size()

In [101]:
d_idx[0].size()

torch.Size([125, 3, 84, 84])

In [100]:
m_v_i[0].size()

torch.Size([5, 25, 3, 84, 84])