In [1]:
import argparse, random
from tqdm import tqdm

import torch
from torch.optim.lr_scheduler import StepLR

# custom packages
from configs.conf import configuration
from dataset_mini import *

In [2]:
# deal with params: command-line flags override the matching fields of the base
# configuration. parse_known_args (rather than parse_args) tolerates argv entries
# this parser does not define, e.g. anything the Jupyter kernel launcher adds.
parser = argparse.ArgumentParser()
parser.add_argument('--exp_name', type=str, default='debug', help="model_name")
parser.add_argument('--gpu', type=int, default=1, help="gpu")
parser.add_argument('--n_epochs', type=int, default=2000, help="epoch") # 2000
parser.add_argument('--alg', type=str, default='cycle_2', help="alg")
parser.add_argument('--command', type=str, default='train', help="train or infer")
# Keep the unparsed leftovers under a real name: assigning to `_` at notebook top
# level makes IPython stop updating its `_`/`__`/`___` output-history variables.
get_args, unknown_args = parser.parse_known_args()

args = configuration()
# Copy every parsed flag onto the config under its dest name (exp_name, gpu,
# n_epochs, alg, command), in definition order.
for key, value in vars(get_args).items():
    setattr(args, key, value)

args.initialize()

In [3]:
# RANDOM SEED: seed Python's, NumPy's and PyTorch's RNGs (CPU, plus every
# visible GPU when CUDA is available) from the single config value.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

# Runtime knobs: cap torch's intra-op CPU threads.
torch.set_num_threads(2)
# NOTE(review): benchmark=True lets cuDNN auto-tune its algorithms, which can make
# GPU results vary between runs despite the seeds above. Switch it off (and set
# torch.backends.cudnn.deterministic = True) if exact reproducibility matters.
torch.backends.cudnn.benchmark = True

In [4]:
# Image geometry from the config's comma-separated "width,height,channels" spec.
im_width, im_height, channels = (int(v) for v in args.x_dim.split(','))
print(im_width, im_height, channels)

# Step 1: init dataloader
print("init data loader")
if args.dataset == 'mini':
    loader_train = dataset_mini(args.n_examples, args.n_episodes, 'train', args)
    loader_val = dataset_mini(args.n_examples, args.n_episodes, 'val', args)
else:
    # Fail fast: later cells use loader_train / loader_val unconditionally and
    # would otherwise die with an unhelpful NameError.
    raise ValueError("unsupported dataset: {!r} (only 'mini' is implemented)".format(args.dataset))

84 84 3
init data loader


In [None]:
root_dir = './dataset/miniImagenet'  # dataset root; pkl caches live under {root_dir}/data/

In [11]:
# Debugging aid: bind the validation loader to a top-level name `self` so the
# next cell can run loader-method-style code (self.root_dir, self.split, ...)
# directly at notebook scope.
self = loader_val

In [12]:
# Load the pkl-processed mini-ImageNet split into labeled / unlabeled arrays,
# replaying the loader's loading step at notebook scope (reads/writes `self`).
pkl_name = '{}/data/mini-imagenet-cache-{}.pkl'.format(self.root_dir, self.split)
print('Loading pkl dataset: {} '.format(pkl_name))

with open(pkl_name, "rb") as f:
    # encoding='bytes' lets pickles written by Python 2 load; their str keys then
    # come back as bytes. Pickles written by Python 3 keep str keys either way,
    # so load the (large) file once and pick whichever key type is present.
    data = pkl.load(f, encoding='bytes')
if b'image_data' in data:
    image_data = data[b'image_data']
    class_dict = data[b'class_dict']
else:
    image_data = data['image_data']
    class_dict = data['class_dict']

data_classes = sorted(class_dict.keys())  # sorted to keep the class order stable

n_classes = len(data_classes)
print('n_classes:{}, n_label:{}, n_unlabel:{}'.format(n_classes, self.n_label, self.n_unlabel))
# Axis order: (class, image within class, height, width, channels).
dataset_l = np.zeros([n_classes, self.n_label, self.im_height, self.im_width, self.channels], dtype=np.float32)
if self.n_unlabel > 0:
    dataset_u = np.zeros([n_classes, self.n_unlabel, self.im_height, self.im_width, self.channels], dtype=np.float32)
else:
    dataset_u = []

for i, cls in enumerate(data_classes):
    idxs = class_dict[cls]
    # A fresh RandomState with the fixed seed keeps the label/unlabel split
    # identical across runs (note: shuffles class_dict's index list in place).
    np.random.RandomState(self.seed).shuffle(idxs)
    dataset_l[i] = image_data[idxs[:self.n_label]]
    if self.n_unlabel > 0:
        # Take exactly n_unlabel images after the labeled ones; an open-ended
        # slice fails to broadcast whenever a class has more than
        # n_label + n_unlabel images.
        dataset_u[i] = image_data[idxs[self.n_label:self.n_label + self.n_unlabel]]
print('labeled data:', np.shape(dataset_l))
print('unlabeled data:', np.shape(dataset_u))

self.dataset_l = dataset_l
self.dataset_u = dataset_u
self.n_classes = n_classes

Loading pkl dataset: ./dataset/miniImagenet/data/mini-imagenet-cache-val.pkl 
n_classes:16, n_label:600, n_unlabel:0
labeled data: (16, 600, 84, 84, 3)
unlabeled data: (0,)


In [13]:
# Inspect how many classes the loaded split contains.
self.n_classes

16

In [14]:
# Inspect the sorted class ids of the loaded split.
data_classes

['n01855672',
 'n02091244',
 'n02114548',
 'n02138441',
 'n02174001',
 'n02950826',
 'n02971356',
 'n02981792',
 'n03075370',
 'n03417042',
 'n03535780',
 'n03584254',
 'n03770439',
 'n03773504',
 'n03980874',
 'n09256479']

In [22]:
# Simulate 100 episodes that each draw 5 distinct classes (global NumPy RNG) and
# credit every drawn class with 20 images, to see how evenly classes get sampled.
dict_class = {}
for ep in range(100):
    selected_classes = np.random.permutation(data_classes)[:5]
    for cls in selected_classes:
        dict_class[cls] = dict_class.get(cls, 0) + 20

In [23]:
# Inspect the per-class tallies produced by the sampling simulation.
dict_class

{'n03773504': 720,
 'n02971356': 720,
 'n09256479': 720,
 'n03417042': 720,
 'n01855672': 520,
 'n02114548': 640,
 'n03535780': 620,
 'n03980874': 660,
 'n02950826': 620,
 'n03770439': 660,
 'n02174001': 640,
 'n02138441': 660,
 'n02981792': 520,
 'n03075370': 560,
 'n03584254': 440,
 'n02091244': 580}

In [11]:
# Shape of the labeled image array: (classes, images per class, height, width, channels).
self.dataset_l.shape

(16, 600, 84, 84, 3)

In [12]:
# Draw 5 distinct class indices from range(self.n_classes) via the global NumPy RNG.
selected_classes = np.random.permutation(np.arange(self.n_classes))[:5]

In [13]:
# Inspect the sampled class indices.
selected_classes

array([9, 8, 5, 3, 0])

In [7]:
# Run the loader's own pkl-loading method; its printed summary can be compared
# with the manual replay cell above.
loader_val.load_data_pkl()

Loading pkl dataset: ./dataset/miniImagenet/data/mini-imagenet-cache-val.pkl 
n_classes:16, n_label:600, n_unlabel:0
labeled data: (16, 600, 84, 84, 3)
unlabeled data: (0,)


In [8]:
# Sample one episode from the validation loader using the configured
# n_test_way / n_test_shot / n_test_query sizes.
support, s_labels, query, q_labels, unlabel = loader_val.next_data(
    args.n_test_way,
    args.n_test_shot,
    args.n_test_query,
)

In [9]:
# Inspect the support-set labels returned by next_data.
s_labels

array([[0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3],
       [4, 4, 4, 4, 4]], dtype=uint8)