In [1]:
import  torch, os
import  numpy as np

from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
from matplotlib import pyplot as plt
from PIL import Image
import json

from datetime import datetime
from meta_gan import MetaGAN
from dataloader import train_data_generator,test_data_generator
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def mkdir_p(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    Does nothing when the directory already exists.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race: if something else
    # creates the directory between an exists() check and makedirs(), the
    # plain makedirs() call would raise.
    os.makedirs(path, exist_ok=True)

def save_train_accs(path, accs):
    """Append one row per training metric to its text file under ``path``.

    Each metric is written with ``np.savetxt`` to a file opened in append-binary
    mode, so every call adds one line per file and earlier rows are kept.

    Args:
        path: Existing output directory.
        accs: Mapping with at least the keys ``"q_nway"``, ``"discrim_loss"``,
            ``"gen_nway"`` and ``"gen_loss"``; each value is wrapped as
            ``np.array([value])`` and written as a single row.

    Raises:
        KeyError: if a required key is missing (files for the metrics listed
            before it have already been appended to).
    """
    # (key in ``accs``, output file name), in the order the files are written.
    metric_files = (
        ("q_nway", "q_nway_accuracies.txt"),
        ("discrim_loss", "discrim_loss.txt"),
        ("gen_nway", "gen_nway_accuracies.txt"),
        ("gen_loss", "gen_loss.txt"),
    )
    for key, filename in metric_files:
        # ``with`` guarantees the file is closed even if savetxt raises.
        with open(os.path.join(path, filename), 'ab') as fh:
            np.savetxt(fh, np.array([accs[key]]))

def save_test_accs(path, accs):
    """Append one row of test accuracies to ``<path>/test_q_nway_accuracies.txt``.

    Args:
        path: Existing output directory.
        accs: 1-D array-like of accuracies; wrapped as ``np.array([accs])`` and
            written as a single line.
    """
    # ``with`` guarantees the file is closed even if savetxt raises.
    with open(os.path.join(path, 'test_q_nway_accuracies.txt'), 'ab') as fh:
        np.savetxt(fh, np.array([accs]))

def save_imgs(path, imgs, step, img_shape=(3, 84, 84)):
    """Save a sample of images for ``step`` as a raw text dump and a PNG grid.

    The first two axes of ``imgs`` are merged and each image is flattened, so
    ``imgs`` becomes an ``(n_images, n_values)`` array. The first 50 rows are
    appended to ``<path>/images_step<step>.txt``; up to 49 of them are drawn in
    a 7x7 grid and saved to ``<path>/images_step<step>.png``.

    Args:
        path: Existing output directory.
        imgs: numpy array whose first two axes index images; each image must
            hold ``prod(img_shape)`` values in channels-first order.
        step: Step number, used only in the output file names.
        img_shape: ``(channels, height, width)`` of one image. The default
            matches the 3x84x84 images used in this notebook.
    """
    max_rows, grid_side = 50, 7
    prefix = path + "/images_step" + str(step)

    # Raw dump: one flattened image per line.
    some_imgs = np.reshape(imgs, [imgs.shape[0] * imgs.shape[1], -1])[:max_rows]
    with open(prefix + ".txt", 'ab') as img_f:
        np.savetxt(img_f, some_imgs)

    # NOTE(review): presumably a workaround for the Intel OpenMP
    # "libiomp5 already initialized" abort when plotting; kept as before.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

    # Draw into a dedicated figure so any figure left open elsewhere is untouched,
    # and close exactly that figure afterwards.
    fig = plt.figure()
    for i, flat_img in enumerate(some_imgs[:grid_side * grid_side]):
        # (C, H, W) -> (H, W, C), the layout imshow expects.
        img = flat_img.reshape(img_shape).transpose(1, 2, 0)
        lo = np.min(img)
        span = np.max(img) - lo
        # Stretch to 0..255. A constant image (span 0) would divide by zero and
        # produce NaNs, so it is rendered black instead.
        if span > 0:
            scaled = (img - lo) * 255 / span
        else:
            scaled = np.zeros_like(img)
        im = scaled.astype(np.uint8)
        if im.shape[-1] == 1:
            im = im[..., 0]  # imshow needs (H, W) for single-channel images
        ax = fig.add_subplot(grid_side, grid_side, i + 1)
        ax.axis('off')
        ax.imshow(im, cmap='gray' if im.ndim == 2 else None)
    fig.savefig(prefix + ".png")
    plt.close(fig)

In [3]:
# Experiment configuration, passed as-is to the data generators and MetaGAN.
# Keys are kept verbatim (including the "consine_schedule" spelling) because the
# consumers look them up by name -- NOTE(review): confirm in meta_gan /
# dataloader before renaming any key.
args = dict(
    epoch=24000,                  # training cell runs epoch // 6000 outer passes
    n_way=5,                      # size of the n-way classifier output
    k_spt=1,                      # spt_size = k_spt * n_way
    k_qry=15,                     # qry_size = k_qry * n_way
    img_sz=84,                    # layer sizes below are written for 84x84 input
    tasks_per_batch=4,            # batch size of the training DataLoader
    img_c=3,
    task_num=4,                   # NOTE(review): same value as tasks_per_batch; which one MetaGAN reads is not visible here
    meta_lr=1e-3,
    update_lr=2e-4,
    gan_update_lr=2e-4,
    update_steps=5,
    update_steps_test=10,
    no_save=False,                # read by the architecture-saving cell
    learn_inner_lr=True,
    condition_discrim=False,      # selects the conditional discriminator config
    loss="cross_entropy",
    create_graph=False,
    single_fast_test=False,
    consine_schedule=True,
    min_learning_rate=1e-10,
    number_of_training_steps_per_iter=5,
    multi_step_loss_num_epochs=15,
    save_path='0425_consine_mamlgan2',  # output directory for metrics/checkpoints/images
)

In [4]:
# Build the train/test task datasets. The factories are called through the
# ``dataloader`` module rather than through the names imported in the first
# cell: those names are rebound to the datasets here, so calling them again on
# a re-run would call the dataset instead of the factory. Going through the
# module keeps this cell safe to re-run.
import dataloader

train_data_generator = dataloader.train_data_generator(args)
test_data_generator = dataloader.test_data_generator(args)

load dataset/BelgiumTSC
load complete time 2.818248987197876
load dataset/ArTS
load complete time 2.73152232170105
load dataset/chinese_traffic_sign
load complete time 0.58467698097229
load dataset/CVL
load complete time 0.4786684513092041
load dataset/FullJCNN2013
load complete time 0.17418503761291504
load dataset/logo_2k
load complete time 1.0330214500427246
load dataset/GTSRB
load complete time 0.14809322357177734
load dataset/DFG
load complete time 0.04896831512451172


In [5]:
# Number of images in a task's support / query set (per-class count x n_way).
# NOTE(review): neither name is used by the cells shown here.
spt_size, qry_size = args["k_spt"] * args["n_way"], args["k_qry"] * args["n_way"]

In [6]:
# Two stride-2 3x3 conv blocks (conv -> LeakyReLU(0.2) -> BatchNorm), 3 -> 64 -> 64
# channels. conv2d params appear to be [ch_out, ch_in, kernel, kernel, stride,
# padding] (each 'bn' size matches ch_out).
shared_config = [
    layer
    for ch_in in (3, 64)
    for layer in (
        ('conv2d', [64, ch_in, 3, 3, 2, 0]),
        ('leakyrelu', [.2, True]),
        ('bn', [64]),
    )
]

# N-way classifier: four stride-2 3x3 conv blocks (64 -> 128 -> 256 -> 512
# channels) and a linear head with args["n_way"] outputs. conv2d params appear
# to be [ch_out, ch_in, kernel, kernel, stride, padding]. For 84x84 input with no
# padding the side length shrinks 84 -> 41 -> 20 -> 9 -> 4, so the flattened
# feature size is 512 * 4 * 4 = 8192 (update it if img_sz changes).
nway_config = (
    [('conv2d', [64, 3, 3, 3, 2, 0]), ('leakyrelu', [.2, True]), ('bn', [64])]
    + [('conv2d', [128, 64, 3, 3, 2, 0]), ('leakyrelu', [.2, True]), ('bn', [128])]
    # the last two blocks use plain ReLU instead of LeakyReLU
    + [('conv2d', [256, 128, 3, 3, 2, 0]), ('relu', [True]), ('bn', [256])]
    + [('conv2d', [512, 256, 3, 3, 2, 0]), ('relu', [True]), ('bn', [512])]
    + [('flatten', []), ('linear', [args["n_way"], 512 * 4 * 4])]
)

# Discriminator layer list. conv2d params appear to be
# [ch_out, ch_in, kernel, kernel, stride, padding] (each 'bn' size matches ch_out).
if args["condition_discrim"]:
    # Conditional variant: four conv -> ReLU -> BN -> max-pool blocks, a
    # 'condition' layer, then a two-layer linear head.
    # NOTE(review): the 'condition' params and the 1600-wide linear input are
    # interpreted by MetaGAN -- confirm there before changing the conv stack.
    discriminator_config = (
        [('conv2d', [32, 3, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0])]
        + [('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0])]
        + [('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 2, 0])]
        + [('conv2d', [32, 32, 3, 3, 1, 0]), ('relu', [True]), ('bn', [32]), ('max_pool2d', [2, 1, 0])]
        + [('condition', [512, 32, 5]), ('leakyrelu', [0.2, True])]
        + [('flatten', []), ('linear', [1024, 1600]), ('bn', [1024])]
        # single output score; deliberately no sigmoid at the end
        + [('linear', [1, 1024])]
    )
else:
    # Unconditional variant: three stride-2 3x3 blocks (84 -> 41 -> 20 -> 9 for
    # 84x84 input) and one 2x2 stride-1 block (9 -> 8), each
    # conv -> LeakyReLU(0.2) -> BN, then one score from 64 * 8 * 8 features.
    discriminator_config = (
        [('conv2d', [64, 3, 3, 3, 2, 0]), ('leakyrelu', [.2, True]), ('bn', [64])]
        + [('conv2d', [64, 64, 3, 3, 2, 0]), ('leakyrelu', [.2, True]), ('bn', [64])]
        + [('conv2d', [64, 64, 3, 3, 2, 0]), ('leakyrelu', [.2, True]), ('bn', [64])]
        + [('conv2d', [64, 64, 2, 2, 1, 0]), ('leakyrelu', [.2, True]), ('bn', [64])]
        + [('flatten', []), ('linear', [1, 64 * 8 * 8])]
    )



# Generator layer list. Param conventions used in this file:
#   convt2d: [ch_in, ch_out, kernel, kernel, stride, padding]
#   conv2d:  [ch_out, ch_in, kernel, kernel, stride, padding]
# Every conv here is 3x3 / stride 1 / padding 1, which preserves spatial size.
gen_config = (
    # input stage: 3 -> 64 channels
    [('convt2d', [3, 64, 3, 3, 1, 1]), ('leakyrelu', [.2, True]), ('bn', [64])]
    # NOTE(review): 'random_proj' params [100, 84, 64] are interpreted by MetaGAN;
    # the next layer reads 128 channels, so this layer is presumably expected to
    # output 128 (e.g. 64 feature + 64 projected noise channels) -- confirm.
    + [('random_proj', [100, 84, 64])]
    # 128 -> 64 channels
    + [('convt2d', [128, 64, 3, 3, 1, 1]), ('relu', [.2, True]), ('bn', [64])]
    # NOTE(review): a second relu right after bn; relu entries here carry
    # [.2, True] whereas nway_config uses ('relu', [True]). Kept exactly as-is.
    + [('relu', [.2, True])]
    + [('conv2d', [64, 64, 3, 3, 1, 1]), ('relu', [.2, True]), ('bn', [64])]
    # project back to 3 channels, followed by a sigmoid
    + [('conv2d', [3, 64, 3, 3, 1, 1]), ('sigmoid', [True])]
)

# Build the MetaGAN model and move it to ``device``; ``device`` is defined once,
# in the imports cell, rather than being redefined here.
mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config, gen_config).to(device)

  super(Adam, self).__init__(params, defaults)


In [7]:
# Total number of trainable parameters in the model.
# ``numel()`` returns an int for every tensor, including 0-d ones, whereas
# ``np.prod`` of an empty shape is the float 1.0 and would make the total a float.
num = sum(p.numel() for p in mamlGAN.parameters() if p.requires_grad)
num

In [8]:
# Unless saving is disabled, create the output directory and record the layer
# configs plus the two architecture flags in ``<save_path>/architecture.txt``.
save_model = not args["no_save"]
if save_model:
    path = args["save_path"]
    mkdir_p(path)
    arch_lines = [
        "shared_config = " + json.dumps(shared_config),
        "nway_config = " + json.dumps(nway_config),
        "discriminator_config = " + json.dumps(discriminator_config),
        "gen_config = " + json.dumps(gen_config),
        "learn_inner_lr = " + str(args["learn_inner_lr"]),
        "condition_discrim = " + str(args["condition_discrim"]),
    ]
    # ``with`` closes the file even if the write fails.
    with open(os.path.join(path, 'architecture.txt'), 'w') as arch_file:
        arch_file.write("\n".join(arch_lines))

In [9]:
# Number of passes the training cell makes over the training DataLoader.
# NOTE(review): the meaning of the 6000 divisor is not documented in this
# notebook (possibly tasks per pass) -- confirm.
n_outer_passes = args["epoch"] // 6000
n_outer_passes

4

In [10]:
# Meta-training loop.
#  - every `log_every` steps: print the training metrics and append them to the
#    metric files under `path`;
#  - every `eval_every` steps: fine-tune/evaluate on every test task, print the
#    mean test accuracies and save them with a checkpoint and sample images.
# All file writes are skipped when args["no_save"] is set, because `path` is
# only created (by the architecture cell) when saving is enabled.
path = args["save_path"]
save_outputs = not args["no_save"]
log_every, eval_every = 100, 500
# save_imgs only ever writes the first 50 flattened images.
max_saved_img_rows = 50

# With shuffle=True a DataLoader reshuffles on every iteration, so each one is
# built once instead of on every pass / every evaluation.
train_dataloader = DataLoader(train_data_generator, args["tasks_per_batch"], shuffle=True, num_workers=4, pin_memory=True)
test_dataloader = DataLoader(test_data_generator, 1, shuffle=True, num_workers=4, pin_memory=True)


def evaluate_on_test_tasks(step):
    """Fine-tune on each test task, print mean accuracies, and save results if enabled."""
    task_accs_all = []
    test_imgs = []
    n_img_rows = 0
    for x_spt, y_spt, x_qry, y_qry in test_dataloader:
        # test batch size is 1: drop the leading batch axis
        x_spt, y_spt, x_qry, y_qry = (
            t.squeeze(0).to(device) for t in (x_spt, y_spt, x_qry, y_qry)
        )
        task_accs, imgs = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)
        task_accs_all.append(task_accs)
        # Copy generated images to host memory only until save_imgs has enough
        # rows; the rows it writes come from the first tasks, so the output is
        # unchanged while memory no longer grows with the test-set size.
        if save_outputs and n_img_rows < max_saved_img_rows:
            test_imgs.append(imgs.detach().cpu().numpy())
            n_img_rows += imgs.shape[0]

    # task_accs_all is [n_tasks, update_step+1]; average over the test tasks
    test_accs = np.array(task_accs_all).mean(axis=0).astype(np.float16)
    print('Test acc:', test_accs)
    if save_outputs:
        save_test_accs(path, test_accs)
        torch.save({'model_state_dict': mamlGAN.state_dict()}, path + "/model_step" + str(step))
        save_imgs(path, np.array(test_imgs), step)


step = 0
for _ in range(args["epoch"] // 6000):
    for x_spt, y_spt, x_qry, y_qry in train_dataloader:
        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

        accs = mamlGAN(x_spt, y_spt, x_qry, y_qry)

        if step % log_every == 0:
            print("step " + str(step))
            for key, value in accs.items():
                print(key + ": " + str(value))
            if save_outputs:
                save_train_accs(path, accs)
        if step % eval_every == 0:
            evaluate_on_test_tasks(step)
        step += 1

step 0
q_nway: [0.23       0.25333333 0.26333333 0.27333333 0.28333333 0.3       ]
discrim_loss: [8.4948684  4.05652153 4.27415031 3.59539306 3.48411924 3.74690288]
gen_loss: [2.93323123 2.88178617 3.05683374 3.31443202 3.1094467  3.15787145]
gen_nway: [0.2  0.2  0.25 0.4  0.25 0.25]
Test acc: [0.2017 0.4966 0.5005 0.4998 0.499  0.4976]
step 100
q_nway: [0.19       0.47333333 0.53333333 0.54       0.56333333 0.55333333]
discrim_loss: [14.96501994  2.95786697  3.32242298  3.2412473   3.48072928  2.83375162]
gen_loss: [3.94521987 2.94608107 2.88987952 2.98121196 3.8003329  3.26951051]
gen_nway: [0.1  0.45 0.9  0.9  0.85 0.95]
step 200
q_nway: [0.12       0.58       0.63333333 0.58666667 0.62       0.62666667]
discrim_loss: [12.78364134  4.03368771  4.27373445  3.2404117   2.95953536  3.70325744]
gen_loss: [3.44803157 2.99808899 4.2243073  5.22732759 3.34146935 4.0522584 ]
gen_nway: [0.15 0.85 1.   1.   1.   1.  ]
step 300
q_nway: [0.23666667 0.60666667 0.61333333 0.62333333 0.61333333 0.