In [1]:
import  torch, os
import  numpy as np

from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
from matplotlib import pyplot as plt
from PIL import Image
import json

from datetime import datetime
from meta_ganpp import MetaGAN
from dataloader import train_data_generator,test_data_generator
import tqdm
import shutil
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Peak memory occupied by tensors on `device` so far, in bytes (the recorded
# output here is 0); left as the last expression so the notebook displays it.
torch.cuda.max_memory_allocated(device=device)

  from .autonotebook import tqdm as notebook_tqdm


0

In [2]:
def mkdir_p(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    Does nothing if the directory already exists. ``exist_ok=True`` avoids the
    race of a separate exists-check, where another process creating the
    directory in between would make ``os.makedirs`` raise.

    :param path: directory to create.
    :raises FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def save_train_accs(path, accs):
    """Append the current training metrics, one text file per metric, under ``path``.

    :param path: run directory (must already exist).
    :param accs: dict of training metrics; must contain the keys ``q_nway``,
        ``discrim_loss``, ``gen_nway`` and ``gen_loss``. Each numeric value
        (scalar or 1-D sequence) is appended to its file as a single row.
    """
    # metric key -> file it is appended to (files are processed in this order)
    metric_files = (
        ("q_nway", "q_nway_accuracies.txt"),
        ("discrim_loss", "discrim_loss.txt"),
        ("gen_nway", "gen_nway_accuracies.txt"),
        ("gen_loss", "gen_loss.txt"),
    )
    for key, filename in metric_files:
        # `with` closes the handle even if savetxt raises.
        with open(path + "/" + filename, "ab") as f:
            np.savetxt(f, np.array([accs[key]]))

def save_test_accs(path, accs):
    """Append ``accs`` as one row of ``test_q_nway_accuracies.txt`` under ``path``.

    :param path: run directory (must already exist).
    :param accs: numeric scalar or 1-D sequence (in the training cell, presumably
        one mean test accuracy per fine-tuning step); written as a single row.
    """
    # `with` closes the handle even if savetxt raises.
    with open(path + "/test_q_nway_accuracies.txt", "ab") as f:
        np.savetxt(f, np.array([accs]))

def save_imgs(path, imgs, step, img_shape=(3, 84, 84)):
    """Save generated images for ``step`` as raw text and as a PNG grid.

    Writes two files under ``path``:

    * ``images_step{step}.txt`` -- the first (up to) 50 images, one flattened
      image per row (appended, so repeated calls for a step accumulate rows).
    * ``images_step{step}.png`` -- a 7x7 grid of the first 49 of those images,
      each min-max scaled to 0..255 independently.

    :param path: run directory (must already exist).
    :param imgs: array whose first two axes are merged into one list of images
        (in the training cell the first axis is the test task); every image
        must hold ``prod(img_shape)`` values.
    :param step: training step, used in the file names.
    :param img_shape: channel-first ``(C, H, W)`` shape of one image; defaults
        to the 3x84x84 images this notebook trains on.
    """
    prefix = path + "/images_step" + str(step)
    some_imgs = np.reshape(imgs, [imgs.shape[0] * imgs.shape[1], -1])[0:50]

    # Raw values; `with` closes the file even if savetxt raises.
    with open(prefix + ".txt", 'ab') as img_f:
        np.savetxt(img_f, some_imgs)

    # Presumably a workaround for the "duplicate OpenMP runtime" abort some
    # MKL/conda setups hit when plotting alongside torch.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    # Draw on a dedicated figure instead of whatever pyplot figure is current.
    fig = plt.figure()
    for i, flat_img in enumerate(some_imgs[:49]):
        img = flat_img.reshape(img_shape).transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)
        value_range = np.ptp(img)
        if value_range > 0:
            im = ((img - np.min(img)) * 255 / value_range).astype(np.uint8)
        else:
            # Constant image: scaling would divide by zero (NaN -> undefined uint8).
            im = np.zeros(img.shape, dtype=np.uint8)
        if im.shape[-1] == 1:
            im = im[..., 0]  # imshow rejects (H, W, 1); show single-channel as 2-D
        ax = fig.add_subplot(7, 7, 1 + i)
        ax.axis('off')
        ax.imshow(im)
    fig.savefig(prefix + ".png")
    plt.close(fig)

In [3]:
# Experiment configuration. This dict is passed to the data generators and to
# MetaGAN, so key names (including the "consine_schedule" spelling) are kept
# as-is -- presumably they are looked up by exactly these names.
# NOTE(review): "task_num" mirrors "tasks_per_batch" (the DataLoader batch size
# in the training cell); keep the two in sync.
args = {
    "epoch": 24000,
    "n_way": 5,
    "k_spt": 1,
    "k_qry": 15,
    "img_sz": 84,
    "tasks_per_batch": 4,
    "img_c": 3,
    "task_num": 4,
    "meta_lr": 1e-3,
    "update_lr": 1e-3,
    "gan_update_lr": 2e-4,
    "update_steps": 4,
    "update_steps_test": 10,
    "no_save": False,
    "learn_inner_lr": True,
    "condition_discrim": False,
    "loss": "cross_entropy",
    "create_graph": False,
    "single_fast_test": False,
    "consine_schedule": True,
    "min_learning_rate": 1e-10,
    "number_of_training_steps_per_iter": 4,
    "multi_step_loss_num_epochs": 15,
    "save_path": "0425_consine_mamlgan_plus2",
}

# Seed every RNG in play so weight initialisation and DataLoader shuffling are
# repeatable across runs (cuDNN kernels may still be nondeterministic).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# WARNING: starting a run deletes any results a previous run left in save_path.
if os.path.exists(args["save_path"]):
    shutil.rmtree(args["save_path"])

In [4]:
# The imported dataset builders are named train_data_generator /
# test_data_generator, and the datasets below reuse those names. Calling the
# builders through the module keeps this cell re-runnable; otherwise a second
# run would try to call the dataset object built by the first.
import dataloader

train_data_generator = dataloader.train_data_generator(args)
test_data_generator = dataloader.test_data_generator(args)

load dataset/BelgiumTSC
load complete time 3.410439968109131
load dataset/ArTS
load complete time 3.38081955909729
load dataset/chinese_traffic_sign
load complete time 0.6300342082977295
load dataset/CVL
load complete time 0.4109306335449219
load dataset/FullJCNN2013
load complete time 0.2707638740539551
load dataset/logo_2k
load complete time 1.054626703262329
load dataset/GTSRB
load complete time 0.06203627586364746
load dataset/DFG
load complete time 0.02844381332397461


In [5]:
# Support / query set sizes per task: presumably k_* examples for each of the
# n_way classes.
spt_size, qry_size = (args[shots_key] * args["n_way"] for shots_key in ("k_spt", "k_qry"))

In [6]:
# Two stride-2 3x3 conv -> LeakyReLU(0.2) -> BatchNorm stages, 64 channels each
# (first stage reads 3 input channels, second reads 64).
shared_config = [
    layer
    for c_in in (3, 64)
    for layer in (
        ('conv2d', [64, c_in, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    )
]

# N-way classifier: four stride-2 3x3 conv stages (activation, then BatchNorm).
# Assuming the [ch_out, ch_in, k, k, stride, padding] conv layout, an 84x84
# input shrinks 84 -> 41 -> 20 -> 9 -> 4, so the flattened features number
# 512 * 4 * 4 = 8192 -- the linear layer's input size.
nway_config = [
    layer
    for c_out, c_in, activation in (
        (64, 3, ('leakyrelu', [.2, True])),
        (128, 64, ('leakyrelu', [.2, True])),
        (256, 128, ('relu', [True])),
        (512, 256, ('relu', [True])),
    )
    for layer in (('conv2d', [c_out, c_in, 3, 3, 2, 0]), activation, ('bn', [c_out]))
] + [
    ('flatten', []),
    ('linear', [args["n_way"], 512 * 4 * 4]),
]

# Unconditional discriminator: three stride-2 3x3 convs and one stride-1 2x2
# conv, each followed by LeakyReLU(0.2) and BatchNorm (84 -> 41 -> 20 -> 9 -> 8
# spatially), then a single-output linear head over 64 * 8 * 8 features.
discriminator_config = [
    layer
    for c_in, kernel, stride in ((3, 3, 2), (64, 3, 2), (64, 3, 2), (64, 2, 1))
    for layer in (
        ('conv2d', [64, c_in, kernel, kernel, stride, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    )
] + [
    ('flatten', []),
    ('linear', [1, 64 * 8 * 8]),
]

# discriminator_config = [
#     ('conv2d', [64, 3, 3, 3, 2, 0]),
#     ('leakyrelu', [.2, True]),

#     ('conv2d', [64, 64, 3, 3, 2, 0]),
#     ('leakyrelu', [.2, True]),

#     ('conv2d', [64, 64, 3, 3, 2, 0]),
#     ('leakyrelu', [.2, True]),

#     ('conv2d', [64, 64, 2, 2, 1, 0]),
#     ('leakyrelu', [.2, True]),
    
#     ('flatten',[]),
#     ('concat_y',[]),

#     ('linear', [1, 64 * 8 * 8])]

# When condition_discrim is set, replace the discriminator above with one that
# has four conv(3x3, stride 1) -> ReLU -> BatchNorm -> max-pool stages of 32
# channels (the last pool uses stride 1), a 'condition' layer (presumably
# injecting the class labels), and a two-layer linear head.
if args["condition_discrim"]:
    discriminator_config = [
        layer
        for c_in, pool_stride in ((3, 2), (32, 2), (32, 2), (32, 1))
        for layer in (
            ('conv2d', [32, c_in, 3, 3, 1, 0]),
            ('relu', [True]),
            ('bn', [32]),
            ('max_pool2d', [2, pool_stride, 0]),
        )
    ] + [
        ('condition', [512, 32, 5]),
        ('leakyrelu', [0.2, True]),
        ('flatten', []),
        ('linear', [1024, 1600]),
        ('bn', [1024]),
        ('linear', [1, 1024]),
        # don't use a sigmoid at the end
    ]



# gen_config = [
#     ('convt2d', [3, 64, 3, 3, 1, 1]),
#     ('leakyrelu', [.2, True]),
#     ('bn', [64]),
#     ('random_proj', [100, 84, 64]),
#     ('convt2d', [128, 64, 3, 3, 1, 1]),
#     #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('relu', [.2, True]),
#     ('bn', [64]),
#     # ('encode', [1024, 64*28*28]),
#     # ('decode', [64*28*28, 1024]),
#     ('relu', [.2, True]),
#     ('conv2d', [64, 64, 3, 3, 1, 1]),
#     #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('relu', [.2, True]),
#     ('bn', [64]),
# ]
# gen_config = [
#     ('conv2d', [32, 3, 4, 4, 2, 0]),
#     ('leakyrelu', [.2, True]),
#     ('bn', [32]),
#     ('random_proj', [100, 41, 32]),

#     ('convt2d', [64, 32, 3, 3, 1, 1]),
#     #('convt2d', [1, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('relu', [.2, True]),
#     ('bn', [32]),
#     # ('encode', [1024, 64*28*28]),
#     # # ('decode', [64*28*28, 1024]),
#     ('relu', [.2, True]),
#     ('conv2d', [3, 32, 3, 3, 1, 1]),
#     ('convt2d', [3, 3, 4, 4, 2, 0]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('relu', [.2, True]),
#     ('bn', [3]),
#     ("sigmoid",[])
# ]



# gen_config = [
#     ('c_gan',[100,21*21*32,100+5]), # [latent_dim, embedding_dim, ch_out, h_out/w_out]
#     # img: (32, 21, 21)
#     ('convt2d', [32, 16, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('bn', [16]),
#     ('relu', [True]),
#     # img: (16, 42, 42)
#     ('convt2d', [16, 3, 4, 4, 2, 1]),
#     # # img: (3, 84, 84)
#     ('sigmoid', [True])
# ]
# gen_config = [
#     ('c_gan',[100,512,7,256]), # [latent_dim, embedding_dim, ch_out, h_out/w_out]
#     # img: (32, 21, 21)
#     ('convt2d', [32, 16, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('bn', [16]),
#     ('relu', [True]),
#     # img: (16, 42, 42)
#     ('convt2d', [16, 3, 4, 4, 2, 1]),
#     # # img: (3, 84, 84)
#     ('sigmoid', [True])
# ]
# Generator: 'c_gan' produces a (256, 7, 7) map, then three transposed-conv ->
# BatchNorm -> ReLU upsampling stages (7 -> 14 -> 28 -> 84 spatially,
# 256 -> 128 -> 64 -> 32 channels) and a final 3x3 conv down to 3 channels.
# NOTE(review): the output sigmoid is disabled, so generated pixels are
# unbounded. The training cell's recorded output shows a CUDA binary-cross-
# entropy assertion (`input_val >= zero && input_val <= one`): some tensor fed
# to BCE is outside [0, 1]. Neither this generator nor the discriminator heads
# end in a sigmoid -- confirm MetaGAN applies one or uses BCE-with-logits.
gen_config = [
    ('c_gan', [100, 512, 256, 7]),  # [latent_dim, embedding_dim, ch_out, h_out/w_out]
] + [
    layer
    # [ch_in, ch_out, kernel_sz, stride, padding] of each upsampling stage
    for c_in, c_out, kernel, stride, pad in (
        (256, 128, 4, 2, 1),  # -> (128, 14, 14)
        (128, 64, 4, 2, 1),   # -> (64, 28, 28)
        (64, 32, 3, 3, 0),    # -> (32, 84, 84)
    )
    for layer in (
        ('convt2d', [c_in, c_out, kernel, kernel, stride, pad]),
        ('bn', [c_out]),
        ('relu', [True]),
    )
] + [
    ('conv2d', [3, 32, 3, 3, 1, 1]),  # -> (3, 84, 84)
    # ('sigmoid', [True])
]
# Build the meta-GAN from the layer configs above and move it to `device`
# (defined once, in the imports cell).
mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config, gen_config).to(device)

  super(Adam, self).__init__(params, defaults)


In [7]:
# import  torch
# from    torch import nn
# from    torch.nn import functional as F
# import  numpy as np



# class Generator(nn.Module):
#     """

#     """

#     def __init__(self, config, img_c, img_sz, num_classes):
#         """

#         :param config: network config file, type:list of (string, list)
#         :param img_c: 1 or 3
#         :param img_sz:  28 or 84
#         """
#         super(Generator, self).__init__()


#         self.config = config

#         self.num_classes = num_classes

#         # this dict contains all tensors needed to be optimized
#         self.vars = nn.ParameterList()
#         # running_mean and running_var
#         self.vars_bn = nn.ParameterList()

#         for i, (name, param) in enumerate(self.config):
#             if name == 'conv2d':
#                 # [ch_out, ch_in, kernelsz, kernelsz]
#                 w = nn.Parameter(torch.ones(*param[:4]))
#                 # gain=1 according to cbfin's implementation
#                 torch.nn.init.kaiming_normal_(w)
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[0])))

#             elif name == 'convt2d':
#                 # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#                 # output will be sz = stride * (input_sz) + kernel_sz
#                 w = nn.Parameter(torch.ones(*param[:4]))
#                 # gain=1 according to cbfin's implementation
#                 torch.nn.init.kaiming_normal_(w)
#                 self.vars.append(w)
#                 # [ch_in, ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[1])))

#             elif name == 'linear':
#                 # [ch_out, ch_in]
#                 w = nn.Parameter(torch.ones(*param))
#                 # gain=1 according to cbfinn's implementation
#                 torch.nn.init.kaiming_normal_(w)
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[0])))
#             elif name == 'encode':
#                 # [ch_out, ch_in]
#                 w = nn.Parameter(torch.ones(*param))
#                 # gain=1 according to cbfinn's implementation
#                 torch.nn.init.kaiming_normal_(w)
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[0])))
#             elif name == 'decode':
#                 # [ch_out, ch_in]
#                 w = nn.Parameter(torch.ones(*param))
#                 # gain=1 according to cbfinn's implementation
#                 torch.nn.init.kaiming_normal_(w)
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[0])))
#             elif name == 'bn':
#                 # [ch_out]
#                 w = nn.Parameter(torch.ones(param[0]))
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[0])))

#                 # must set requires_grad=False
#                 running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
#                 running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
#                 self.vars_bn.extend([running_mean, running_var])
#             elif name == "random_proj":
#                 # [ch_in, ch_out, img_sz]
#                 # latent_dim, latent_ch_out, emb_dim, emb_ch_out, hw_out = param
#                 emb_dim, emb_ch_out, hw_out = param
#                 # latent projection params
#                 # latent_dim, hw_out, rand_ch_out = param
#                 w_lat = nn.Parameter(torch.ones(hw_out*hw_out*latent_ch_out, latent_dim))
#                 torch.nn.init.kaiming_normal_(w_lat)
      
#             elif name == "c_gan":
#                 w = nn.Parameter(torch.ones(param[2]*param[3]*param[3],param[0] + param[1]))
#                 # gain=1 according to cbfinn's implementation
#                 torch.nn.init.kaiming_normal_(w)
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[2]*param[3]*param[3])))
                
#                 w = nn.Parameter(torch.ones(param[2]*param[3]*param[3]))
#                 self.vars.append(w)
#                 # [ch_out]
#                 self.vars.append(nn.Parameter(torch.zeros(param[2]*param[3]*param[3])))

#                 # must set requires_grad=False
#                 running_mean = nn.Parameter(torch.zeros(param[2]*param[3]*param[3]), requires_grad=False)
#                 running_var = nn.Parameter(torch.ones(param[2]*param[3]*param[3]), requires_grad=False)
#                 self.vars_bn.extend([running_mean, running_var])
                
#             elif name in ['tanh', 'relu', 'upsample', 'avg_pool2d', 'max_pool2d',
#                           'flatten', 'reshape', 'leakyrelu', 'sigmoid', 'identity', 'update_identity', 'encode', 'decode']:
#                 continue
#             else:
#                 raise NotImplementedError

#     def forward(self, x, y, vars=None, bn_training=True):
#         """
#         This function can be called by finetunning, however, in finetunning, we dont wish to update
#         running_mean/running_var. Thought weights/bias of bn is updated, it has been separated by fast_weights.
#         Indeed, to not update running_mean/running_var, we need set update_bn_statistics=False
#         but weight/bias will be updated and not dirty initial theta parameters via fast_weiths.
#         :param x: [b, 512]
#         :param vars:
#         :param bn_training: set False to not update
#         :return: x, loss, likelihood, kld
#         """

#         batch_sz = x.size()[0]

#         x_orig = x

#         if vars == None:
#             vars = self.vars

#         idx = 0
#         bn_idx = 0

#         # assert self.config[0][0] is 'random_proj'
#         # need to start with the random projection
#         for name, param in self.config:
#             # print(name)
#             if name == 'conv2d':
#                 w, b = vars[idx], vars[idx + 1]
#                 # remember to keep synchrozied of forward_encoder and forward_decoder!
#                 x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
#                 idx += 2
#                 # print(name, param, '\tout:', x.shape)
#             elif name == 'convt2d':
#                 w, b = vars[idx], vars[idx + 1]
#                 # remember to keep synchrozied of forward_encoder and forward_decoder!
#                 x = F.conv_transpose2d(x, w, b, stride=param[4], padding=param[5])
#                 idx += 2
#                 # print(name, param, '\tout:', x.shape)
#             elif name == 'linear':
#                 w, b = vars[idx], vars[idx + 1]
#                 x = F.linear(x, w, b)
#                 idx += 2
#                 # print('forward:', idx, x.norm().item())
#             elif name == 'bn':
#                 w, b = vars[idx], vars[idx + 1]
#                 running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
#                 x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
#                 idx += 2
#                 bn_idx += 2
#             elif name == 'encode':
#                 x = x.view(x.size(0), -1)
#                 w, b = vars[idx], vars[idx + 1]
#                 x = F.linear(x, w, b)
#                 idx += 2
#             elif name == 'decode':
#                 w, b = vars[idx], vars[idx + 1]
#                 x = F.linear(x, w, b)
#                 x = x.view(x.size(0), 64,28,28)
#                 idx += 2
#             elif name == 'random_proj':

#                 latent_dim, latent_ch_out, emb_dim, emb_ch_out, hw_out = param
#                 # latent_dim, hw_out, rand_ch_out = param
#                 cuda = torch.cuda.is_available()

#                 # send random tensor to linear layer, reshape into noise channels
#                 FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 
#                 rand = FloatTensor((x.size(0),latent_dim))
#                 torch.randn(x.size(0),latent_dim, out=rand, requires_grad=True)
#                 # w_lat, b_lat = vars[idx], vars[idx + 1]
#                 # rand = F.linear(rand, w_lat, b_lat)
#                 # rand = F.leaky_relu(rand, 0.2)
#                 # rand = rand.view(rand.size(0), rand_ch_out, hw_out, hw_out)
#                 x = torch.cat((y, rand), 1)

#                 # w_lat, b_lat = vars[idx], vars[idx + 1]

#                 # rand = F.linear(rand, w_lat, b_lat)
#                 # rand = F.leaky_relu(rand, 0.2)
#                 # rand = rand.view(rand.size(0), latent_ch_out, hw_out, hw_out)

#                 # send class embbeddings through a linear layer, reshape embeddings channels
#                 # w_emb, b_emb = vars[idx+2], vars[idx + 3]
#                 # x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
#                 idx += 2

#                 x = F.linear(x, w_emb, b_emb)
#                 x = F.leaky_relu(x, 0.2)
#                 x = x.view(x.size(0), emb_ch_out, hw_out, hw_out)

#                 # concatenate embeddings and projections
                

#                 idx += 2
#             elif name == "c_gan":
#                 latent_dim = param[0]
#                 cuda = torch.cuda.is_available()
#                 FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 
#                 rand = FloatTensor(x.size(0),latent_dim)
                
#                 torch.randn(x.size(0),latent_dim, out=rand, requires_grad=False)
                
#                 x = torch.cat((x, rand), 1)

#                 w, b = vars[idx], vars[idx + 1]

#                 x = F.linear(x, w, b)
#                 idx += 2
#                 print(x.size())
#                 w, b = vars[idx], vars[idx + 1]
#                 print(w.size())
#                 running_mean, running_var = self.vars_bn[bn_idx], self.vars_bn[bn_idx+1]
#                 x = F.batch_norm(x, running_mean, running_var, weight=w, bias=b, training=bn_training)
#                 idx += 2
#                 bn_idx += 2
                
#                 x = x.view(x.size(0),param[2],param[3],param[3])
                
                
#             elif name == 'update_identity':
#                 x_orig = x
#             elif name == 'identity':
#                 # print(x.shape)
#                 x += x_orig
#             elif name == 'flatten':
#                 # print(x.shape)
#                 x = x.view(x.size(0), -1)
#             elif name == 'reshape':
#                 # [b, 8] => [b, 2, 2, 2]
#                 x = x.view(x.size(0), *param)
#             elif name == 'relu':
#                 x = F.relu(x, inplace=param[0])
#             elif name == 'leakyrelu':
#                 x = F.leaky_relu(x, negative_slope=param[0], inplace=param[1])
#             elif name == 'tanh':
#                 x = F.tanh(x)
#             elif name == 'sigmoid':
#                 x = torch.sigmoid(x)
#             elif name == 'upsample':
#                 x = F.upsample_nearest(x, scale_factor=param[0])
#             elif name == 'max_pool2d':
#                 x = F.max_pool2d(x, param[0], param[1], param[2])
#             elif name == 'avg_pool2d':
#                 x = F.avg_pool2d(x, param[0], param[1], param[2])

#             else:
#                 raise NotImplementedError

#         # make sure variable is used properly
#         assert idx == len(vars)
#         assert bn_idx == len(self.vars_bn)

#         # right now still returning y so that we can easilly extend to generating diff nums of examples by adjusting y in here
#         return x, y


#     def zero_grad(self, vars=None):
#         """

#         :param vars:
#         :return:
#         """
#         with torch.no_grad():
#             if vars == None:
#                 for p in self.vars:
#                     if not p.grad == None:
#                         p.grad.zero_()
#             else:
#                 for p in vars:
#                     if not p.grad ==  None:
#                         p.grad.zero_()

#     def parameters(self):
#         """
#         override this function since initial parameters will return with a generator.
#         :return:
#         """
#         return self.vars

In [8]:
# from conditioner import Conditioner
# gen_config = [
#     # img: (256, 7, 7)
#     ('c_gan',[100,512,256,7]), # [latent_dim, embedding_dim, ch_out, h_out/w_out]
#     # img: (128, 14, 14)
#     ('convt2d', [256, 128, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('bn', [128]),
#     ('relu', [True]),
#     # img: (64, 28, 28)
#     ('convt2d', [128, 64, 4, 4, 2, 1]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('bn', [64]),
#     ('relu', [True]),
#     # img: (32, 84, 84)
#     ('convt2d', [64, 32, 3, 3, 3, 0]), # [ch_in, ch_out, kernel_sz, kernel_sz, stride, padding]
#     ('bn', [32]),
#     ('relu', [True]),
#     ('conv2d', [3, 32, 3, 3, 1, 1]),
#     # # # img: (3, 84, 84)
#     # ('sigmoid', [True])
# ]
# x_spt[0].size()
# conditioner = Conditioner().cuda()
# image_embeddings = conditioner(x_spt[0]).squeeze()

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# a = Generator(gen_config,3,84,5).to(device)

# a(image_embeddings,y_spt[0])[0].shape

In [9]:
# Number of trainable scalars in the model, shown as the cell output.
# Tensor.numel() is an exact int; np.prod(x.shape) returns the float 1.0 for
# 0-d parameters (if any, e.g. learned scalar step sizes), which made the
# total a float.
num = sum(p.numel() for p in mamlGAN.parameters() if p.requires_grad)
num

In [10]:
save_model = not args["no_save"]
if save_model:
    path = args["save_path"]
    mkdir_p(path)
    # Record the layer configs and key flags alongside this run's results.
    architecture_lines = [
        "shared_config = " + json.dumps(shared_config),
        "nway_config = " + json.dumps(nway_config),
        "discriminator_config = " + json.dumps(discriminator_config),
        "gen_config = " + json.dumps(gen_config),
        "learn_inner_lr = " + str(args["learn_inner_lr"]),
        "condition_discrim = " + str(args["condition_discrim"]),
    ]
    # `with` closes the file even if the write fails.
    with open(path + '/architecture.txt', 'w') as f:
        f.write("\n".join(architecture_lines))

In [11]:
# Meta-training loop. The progress bar counts `args["epoch"]` steps; the
# training data is traversed `args["epoch"] // 6000` times.
# NOTE(review): 6000 presumably equals the steps per pass
# (len(train_dataloader)); confirm, otherwise the bar total and step count drift.
# Every 100 steps the training metrics are printed/logged. Every 500 steps the
# model is fine-tuned on every test task, and checkpoints plus an image grid
# are written.
path = args["save_path"]
# The run directory only exists when saving is enabled, so all disk writes are
# guarded by this flag.
save_model = not args["no_save"]
step = 0
best_epoch = 0
best_acc = []

# DataLoaders can be iterated repeatedly (each pass re-shuffles), so build
# them once instead of once per pass / per evaluation.
train_dataloader = DataLoader(train_data_generator, args["tasks_per_batch"], shuffle=True, num_workers=4, pin_memory=True)
db_test = DataLoader(test_data_generator, 1, shuffle=True, num_workers=4, pin_memory=True)

with tqdm.tqdm(initial=step, total=int(args["epoch"])) as pbar_train:
    for _ in range(args["epoch"] // 6000):
        # each batch holds tasks_per_batch episodes
        for x_spt, y_spt, x_qry, y_qry in train_dataloader:
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs = mamlGAN(x_spt, y_spt, x_qry, y_qry, step)

            if step % 100 == 0:
                print("step " + str(step))
                for key, value in accs.items():
                    print(key + ": " + str(value))
                if save_model:
                    save_train_accs(path, accs)

            if step % 500 == 0:  # evaluation
                accs_all_test = []
                imgs_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # The test loader uses batch size 1: drop that leading axis.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    task_accs, imgs = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)
                    torch.cuda.empty_cache()
                    accs_all_test.append(task_accs)
                    imgs_all_test.append(imgs.detach().cpu().numpy())

                imgs_all_test = np.array(imgs_all_test)
                # Mean over test tasks; presumably one value per fine-tuning step.
                test_accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', test_accs)

                # np.size/np.max also cope with a 0-d accuracy array.
                is_best = np.size(best_acc) == 0 or np.max(test_accs) > np.max(best_acc)
                if is_best:
                    best_acc = test_accs
                    best_epoch = step

                if save_model:
                    save_test_accs(path, test_accs)
                    state = {'model_state_dict': mamlGAN.state_dict()}
                    if is_best:
                        torch.save(state, path + "/best")
                    torch.save(state, path + "/model_step" + str(step))
                    save_imgs(path, imgs_all_test, step)
            step += 1
            pbar_train.update(1)

  0%|                                                 | 0/24000 [00:00<?, ?it/s]../aten/src/ATen/native/cuda/Loss.cu:115: operator(): block: [0,0,0], thread: [3,0,0] Assertion `input_val >= zero && input_val <= one` failed.
  0%|                                                 | 0/24000 [00:02<?, ?it/s]


RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR

In [None]:
# Raw values of the first generated image from the first test task.
first_test_task_imgs = imgs_all_test[0]
first_test_task_imgs[0]

In [None]:
# Preview one generated image (first test task, index 4) as (H, W, C).
# Generator outputs are unbounded (gen_config's output sigmoid is disabled),
# and imshow clips float RGB data to [0, 1], so min-max scale it first.
sample_img = imgs_all_test[0][4].transpose((1, 2, 0))
sample_range = np.ptp(sample_img)
if sample_range > 0:
    sample_img = (sample_img - sample_img.min()) / sample_range
plt.imshow(sample_img)
plt.axis('off')
plt.title('Test task 0, generated image 4');

In [None]:
# Summary row: the step with the best test accuracy, followed by that
# evaluation's accuracies. Built from best_acc (not the most recent `accs`)
# under a new name, so re-running this cell does not compound. float64 keeps
# step numbers exact (numpy has no float8; float16 cannot represent e.g. 24001).
# NOTE(review): this row is longer than the per-evaluation rows of
# test_q_nway_accuracies.txt, so np.loadtxt on that file will then fail.
best_summary = np.array([best_epoch] + list(best_acc), dtype=np.float64)
if not args["no_save"]:
    save_test_accs(path, best_summary)
best_summary

In [None]:
3 * 84 * 84  # values in one flattened 3x84x84 image (= 21168)