In [1]:
import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F
from    torch.utils.data import TensorDataset, DataLoader
from    torch import optim
import  numpy as np
from torch.autograd import Variable

from    learner import Learner
from generator import Generator
from    copy import deepcopy
from conditioner import Conditioner
import os
from torchsummary import summary

from utils.dataloader import train_data_gen , test_data_gen
import shutil
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm

from torch.utils.tensorboard import SummaryWriter
import json
# import tqdm

In [2]:
with open('configs/gen8.json') as json_file:
    args = json.load(json_file)

In [3]:
print(args)

{'epoch': 30000, 'n_way': 5, 'k_spt': 1, 'k_qry': 10, 'img_sz': 84, 'tasks_per_batch': 5, 'img_c': 3, 'meta_gen_lr': 0.0005, 'meta_discrim_lr': 0.0001, 'update_lr': 0.004, 'update_steps': 1, 'update_steps_test': 1, 'loss': 'cross_entropy', 'min_learning_rate': 1e-15, 'number_of_training_steps_per_iter': 4, 'multi_step_loss_num_epochs': 15, 'spy_gen_num': 5, 'qry_gen_num': 25, 'num_distractor': 0, 'batch_for_gradient': 50, 'no_save': 0, 'learn_inner_lr': 0, 'create_graph': 0, 'msl': 0, 'single_fast_test': 0, 'consine_schedule': 0, 'save_path': 'gen13'}


In [4]:
if os.path.exists("images/" + args["save_path"]):
    shutil.rmtree("images/" + args["save_path"])
    
if os.path.exists("data/" + args["save_path"]):
    shutil.rmtree("data/" + args["save_path"])
    
if os.path.exists("save_models/" + args["save_path"]):
    shutil.rmtree("save_models/" + args["save_path"])
    
if os.path.exists("runs/" + args["save_path"]):
    shutil.rmtree("runs/" + args["save_path"])    

writer = SummaryWriter('runs/' + args["save_path"])

In [5]:
def mkdir_p(path):
    if not os.path.exists("images/" + path):
        os.makedirs("images/" + path)
        
    if not os.path.exists("data/" + path):
        os.makedirs("data/" + path)
        
    if not os.path.exists("save_models/" + path):
        os.makedirs("save_models/" + path)        


def save_imgs(path, imgs, step):

    some_imgs = np.reshape(imgs, [imgs.shape[0]*imgs.shape[1], -1])[0:50]

    # save png of imgs
    i = 0
    for flat_img in some_imgs:
        img = flat_img.reshape(3,84,84).swapaxes(0,1).swapaxes(1,2)
        im = ((img - np.min(img))*255/(np.max(img - np.min(img)))).astype(np.uint8)
        if i < 15:
            plt.subplot(5, 3, i+1)
            plt.axis('off')
            plt.imshow(im)
        i += 1
    plt.savefig("images/" + path + "/images_step" + str(step) + ".png")
    plt.close()

In [6]:
train_data_generator = train_data_gen(args)
test_data_generator = test_data_gen(args)

load datasets/BelgiumTSC
load complete time 0.3554813861846924
load datasets/ArTS
load complete time 0.3645918369293213
load datasets/chinese_traffic_sign
load complete time 0.6034379005432129
load datasets/CVL
load complete time 0.4320671558380127
load datasets/FullJCNN2013
load complete time 0.2325131893157959
load datasets/logo_2k
load complete time 1.0139472484588623
load datasets/GTSRB
load complete time 0.07374167442321777
load datasets/DFG
load complete time 0.030877351760864258


In [7]:
ndf = 64
discriminator_config = [
    ('conv2d', [ndf, 3, 4, 4, 2, 1]),
    ('leakyrelu', [0.2,True]),
    # ('bn', [ndf]),
    
    ('conv2d', [ndf*2, ndf, 4, 4, 2, 1]),
    ('bn', [ndf*2]),
    ('leakyrelu', [0.2,True]),

    ('conv2d', [ndf*4, ndf*2, 4, 4, 2, 1]),
    ('bn', [ndf*4]),
    ('leakyrelu', [0.2,True]),
    
    
    ('conv2d', [ndf*8, ndf*4, 4, 4, 2, 1]),
    ('bn', [ndf*8]),
    ('leakyrelu', [0.2,True]),
    
    ('conv2d', [1,ndf*8 , 2, 2, 1, 0]),
    ('flatten', []),
    ('linear',[6, 16]),
    ('softmax',[])
]

nz = 100
ngf = 64
gen_config = [
    ('convert_z',[]),
    ('convt2d',[nz,ngf*8,4,4,1,0]),
    ('bn',[ngf * 8]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*8,ngf*4,4,4,2,0]),
    ('bn',[ngf * 4]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*4,ngf*2,4,4,2,0]),
    ('bn',[ngf * 2]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*2,ngf,3,3,2,1]),
    ('bn',[ngf]),
    ('leakyrelu', [.2, True]),      
    
    ('convt2d',[ngf,3,3,3,2,1]),
    ('convt2d',[3,3,2,2,1,1]),
    ("tanh",[])
]

In [8]:
class Meta(nn.Module):
    """
    Meta Learner with GAN incorporated
    """
    def __init__(self, args, discriminator_config, gen_config):
        """
        :param args:
        """
        super(Meta, self).__init__()
        
        cuda = torch.cuda.is_available()
        self.FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 
        self.LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
        self.total_epochs = args["epoch"]   
        # model parameters config
        self.meta_gen_lr = args["meta_gen_lr"]
        self.meta_discrim_lr = args["meta_discrim_lr"]
        
        self.update_lr = args["update_lr"]
        
        self.update_steps = args["update_steps"]
        self.update_steps_test = args["update_steps_test"]
        
        # dataset config
        self.img_c = args["img_c"]
        self.img_sz = args["img_sz"]        
        self.n_way = args["n_way"]
        self.k_spt = args["k_spt"]
        self.k_qry = args["k_qry"]
        self.MSL = args["msl"]
        # generator num
        self.spy_gen_num = args["spy_gen_num"]
        self.qry_gen_num = args["qry_gen_num"]
        # query gan batch
        self.batch_for_gradient = args["batch_for_gradient"]
        self.fix_noise = torch.randn(self.batch_for_gradient, nz,1,1, device=device)
        self.criterion = nn.BCELoss()
        # load model
        self.generator = Generator(gen_config, self.img_c, self.img_sz)
        self.discrim_net = Learner(discriminator_config, self.img_c, self.img_sz)
        beta1 = 0.0
        beta2 = 0.0


        self.meta_gen_optim = optim.Adam(self.generator.parameters(), lr=self.meta_gen_lr,betas=(beta1, 0.9))
        self.meta_d_optim = optim.Adam(self.discrim_net.parameters(), lr=self.meta_discrim_lr,betas=(beta2, 0.9))


        self.real_value = 1
        self.fake_value = 0
    
    def pred(self, x, weights=None, nets=None, nway=True, discrim=True, conditions=False):
        if weights == None:
            discrim_weights = self.discrim_net.parameters()
        else:
            discrim_weights = weights

        discrim_logits = self.discrim_net(x, vars=discrim_weights, bn_training=True) if discrim else None
          
        return discrim_logits

    def get_num_corrects(self, y, x=None, weights=None):
            
        with torch.no_grad():

            discrim_logits = self.pred(x, weights=weights)[:,0]
            nway_correct = (discrim_logits).mean().item()

        return nway_correct

    def update_weights(self, net_losses, net_weights,learned_lrs):

        update_lr = self.update_lr
        grad = torch.autograd.grad(net_losses, net_weights)
        weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grad, net_weights)))

        return weights
    
    def meta_test(self,qry_img,qry_label,discrim_weight,gen_weight):
        ### discriminator train
        q_real_discrim_logits = self.pred(qry_img, weights=discrim_weight)[:,0]

        real_discrim_loss_q = self.criterion(q_real_discrim_logits, qry_label)

        discrim_fake_label = torch.full((self.qry_gen_num,), self.fake_value, dtype=torch.float, device=device) 
        noise = torch.randn(self.qry_gen_num, nz,1,1, device=device)
        q_gen = torch.empty(0,3,84,84).cuda()
        
        if self.qry_gen_num < self.batch_for_gradient:
            q_gen = self.generator(qry_img, noise , vars=gen_weight)
        else:
            for i in range(self.qry_gen_num//self.batch_for_gradient):
                noise_tmp = noise[i*self.batch_for_gradient:(i+1)*self.batch_for_gradient]
                q_gen = torch.cat([q_gen,self.generator(qry_img[i*self.batch_for_gradient:(i+1)*self.batch_for_gradient], noise_tmp , vars=gen_weight)])

        q_fake_discrim_logits = self.pred(q_gen.detach(), weights=discrim_weight)[:,0]
        fake_discrim_loss_q = self.criterion(q_fake_discrim_logits, discrim_fake_label)
        d_loss_q = (fake_discrim_loss_q + real_discrim_loss_q)
        
        ### generator train
        gen_fake_label = torch.full((self.qry_gen_num,), self.real_value, dtype=torch.float, device=device)
        gen_q_discrim = self.pred(q_gen, weights=discrim_weight)[:,0]
        g_loss_q = self.criterion(gen_q_discrim, gen_fake_label)
        return d_loss_q, g_loss_q

    def single_task_forward(self, x_spt, y_spt, x_qry, y_qry, update_steps,nets=None, images=False):
        
        corrects = {key: np.zeros(update_steps + 1) for key in 
                        [
                        "D(x)",
                        "D(G(z))"
                        ]}

        y_spt = torch.full((y_spt.size(0),), self.real_value, dtype=torch.float, device=device)
        y_qry = torch.full((y_qry.size(0),), self.real_value, dtype=torch.float, device=device)
        support_sz, c_, h, w = x_spt.size()
        nz = 100

        discrim_weights,gen_weights = [x.parameters() for x in nets]

        # this is the meta-test loss and accuracy before first update

        q_discrim = self.get_num_corrects(y=y_qry, weights=None, x=x_qry)
        corrects["D(x)"][0] += q_discrim
        # run the i-th task and compute loss for k-th inner update
        
        for k in range(1, update_steps + 1):
            ## discrim loss
            noise = torch.randn(self.spy_gen_num, nz , 1, 1, device=device)
            x_gen = self.generator(x_spt, noise , vars=gen_weights)
            
            # update discrim weight

            real_discrim_logits = self.pred(x_spt, weights=discrim_weights)[:,0]
            # real_discrim_logits = self.pred(x_spt, weights=discrim_weights)

            fake_discrim_logits = self.pred(x_gen, weights=discrim_weights)[:,0]
            # fake_discrim_logits = self.pred(x_gen, weights=discrim_weights)
            
            fake_label = torch.full((self.spy_gen_num,), self.fake_value, dtype=torch.float, device=device)
            

            real_discrim_loss = self.criterion(real_discrim_logits, y_spt)
            fake_discrim_loss = self.criterion(fake_discrim_logits,fake_label)
            D_loss = fake_discrim_loss + real_discrim_loss
            # print(fake_discrim_loss.item(),real_discrim_loss.item())
            discrim_weights = self.update_weights(D_loss, discrim_weights,self.update_lr) 
            
            with torch.no_grad():
                x_gen = self.generator(x_qry, self.fix_noise , vars=gen_weights) 
                gen_correct = self.pred(x_gen, weights=discrim_weights)[:,0]
                gen_correct = gen_correct.mean().item()
                corrects["D(G(z))"][k-1] += gen_correct
                
                q_discrim_correct = self.get_num_corrects(y=y_qry, x=x_qry, weights=discrim_weights)
                corrects["D(x)"][k] += q_discrim_correct
#             # meta-test nway and discrim accuracy
#             # [query_sz]
        
#         # final gen-discrim and gen-nway accuracy
        with torch.no_grad():
            x_gen = self.generator(x_qry, self.fix_noise , vars=gen_weights)
            gen_correct = self.pred(x_gen, weights=discrim_weights)[:,0]
            # gen_correct = self.pred(x_gen, weights=discrim_weights)
            gen_correct = gen_correct.mean().item()
            corrects["D(G(z))"][-1] += gen_correct
        d_loss_q, g_loss_q = self.meta_test(x_qry,y_qry,discrim_weights,gen_weights)
            
        if images:
            return d_loss_q,g_loss_q, corrects, x_gen
        else:
            return d_loss_q,g_loss_q, corrects

    def forward(self, x_spt, y_spt, x_qry, y_qry,step):
        """
        :param x_spt:   [b, support_sz, c_, h, w]
        :param y_spt:   [b, support_sz]
        :param x_qry:   [b, query_sz, c_, h, w]
        :param y_qry:   [b, query_sz]
        :return:
        """
        self.current_epoch = step 
        tasks_per_batch, support_sz, c_, h, w = x_spt.size()
        query_sz = x_qry.size(1)
        g_loss_q = 0
        d_loss_q = 0
        gen_losses_q = [0 for _ in range(self.update_steps + 1)]
        discrim_losses_q = [0 for _ in range(self.update_steps + 1)]
        corrects = {key: np.zeros(self.update_steps + 1) for key in 
                        [
                        "D(x)",
                        "D(G(z))"
                        ]}
        net = [self.discrim_net,self.generator]
        for i in range(tasks_per_batch):
            d_loss_q_tmp,g_loss_q_tmp, corrects_tmp = self.single_task_forward(x_spt[i], y_spt[i], x_qry[i], y_qry[i],self.update_steps,nets = net,images=False)
            g_loss_q += g_loss_q_tmp
            d_loss_q += d_loss_q_tmp
            assert len(corrects_tmp.keys()) == len(corrects.keys())
            for key in corrects.keys():
                corrects[key] += corrects_tmp[key]
            
        # end of all tasks
        # sum over final losses on query set across all tasks
        if step > 30:
            g_loss_q /= tasks_per_batch
            self.meta_gen_optim.zero_grad()
            g_loss_q.backward()
            self.meta_gen_optim.step()        

        # optimize theta parameters
        d_loss_q /= tasks_per_batch
        self.meta_d_optim.zero_grad()
        d_loss_q.backward()
        self.meta_d_optim.step()
        
        accs = {}
        accs["D(x)"] = corrects["D(x)"] / (tasks_per_batch)
        accs["D(G(z))"] = corrects["D(G(z))"] / (tasks_per_batch)

        return accs,d_loss_q,g_loss_q

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """

        :param x_spt:   [support_sz, c_, h, w]
        :param y_spt:   [support_sz]
        :param x_qry:   [query_sz, c_, h, w]
        :param y_qry:   [query_sz]
        :return:
        """

        support_sz, c_, h, w = x_spt.size()

        assert len(x_spt.shape) == 4

        query_sz = x_qry.size(0)

        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetunning on the copied model instead of self.net
        
        discrim_net = deepcopy(self.discrim_net)
        generator = deepcopy(self.generator)
        net = [self.discrim_net,self.generator]
        d_loss_q,g_loss_q, corrects, imgs = self.single_task_forward(x_spt, y_spt, x_qry, y_qry,self.update_steps_test, nets=net,images=True)

        del discrim_net
        
        accs["D(x)"] = corrects["D(x)"]
        accs["D(G(z))"] = corrects["D(G(z))"]

        return accs, imgs,d_loss_q,g_loss_q


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mamlGAN = Meta(args, discriminator_config, gen_config).to(device)
step = 0
path = args["save_path"]
mkdir_p(path)
best_acc = []

with tqdm.tqdm(initial=step,
                   total=int(args["epoch"])) as pbar_train:
    for _ in range(args["epoch"] * args["tasks_per_batch"]//6000):
        # fetch meta_batchsz num of episode each time
        train_dataloader = DataLoader(train_data_generator, args["tasks_per_batch"], shuffle=True, num_workers=4, pin_memory=True)

        for _, (x_spt, y_spt, x_qry, y_qry) in enumerate(train_dataloader):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs,d_loss,g_loss = mamlGAN(x_spt, y_spt, x_qry, y_qry,step)
            # accs,d_loss = mamlGAN(x_spt, y_spt, x_qry, y_qry,step)
            writer.add_scalar('Loss/train_d_loss', d_loss, step)
            writer.add_scalar('Loss/train_g_loss', g_loss, step)
            writer.add_scalar('Accuracy/"train_D(x)', accs["D(x)"][-1], step)
            writer.add_scalar('Accuracy/"train_D(G(z))', accs["D(G(z))"][-1], step)
            if step % 100 == 0:
                print("step " + str(step))
                print('d loss:',d_loss.item())
                print('g loss:',g_loss.item())
                print("accs",accs)


            if step % 500 == 0:  # evaluation
                db_test = DataLoader(test_data_generator, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_test = []
                imgs_all_test = []
                d_loss_all_test = []
                g_loss_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    # accs, d_loss = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs, imgs,d_loss,g_loss = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)


                    accs_all_test.append(accs)
                    imgs_all_test.append(imgs.cpu().detach().numpy())
                    d_loss_all_test.append(d_loss.item())
                    g_loss_all_test.append(g_loss.item())

                imgs_all_test = np.array(imgs_all_test)
                # [b, update_step+1]
                # accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                d_loss = np.mean(np.array(d_loss_all_test))
                g_loss = np.mean(np.array(g_loss_all_test))

                print('d loss:',d_loss)
                print('g loss:',g_loss)
                print('Test acc:', accs)    

                writer.add_scalar('Loss/test_d_loss', d_loss, step)
                writer.add_scalar('Loss/test_g_loss', g_loss, step)
                writer.add_scalar('Accuracy/"test_D(x)', accs["D(x)"][-1], step)
                writer.add_scalar('Accuracy/"test_D(G(z))', accs["D(G(z))"][-1], step)

                if not len(best_acc):
                    best_acc = accs
                    best_epoch = step
                    torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/best.pth")
                else:
                    if max(accs) > max(best_acc):
                        best_acc = accs
                        best_epoch = step
                        torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/best.pth")
                torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/model_step" + str(step) + ".pth")

                save_imgs(path, imgs_all_test, step)

            step = step + 1
            pbar_train.update(1)

  0%|          | 0/30000 [00:00<?, ?it/s]

step 0
d loss: 1.4976173639297485
g loss: 10.921062469482422
accs {'D(x)': array([0.20729602, 0.47721277]), 'D(G(z))': array([0.22991294, 0.22991011])}
d loss: 1.349763709306717
g loss: 2.174746948480606
Test acc: {'D(x)': array([0.30513278, 0.59015721]), 'D(G(z))': array([0.19835092, 0.19836073])}
step 100
d loss: 1.3975893259048462
g loss: 0.7426688075065613
accs {'D(x)': array([0.83565809, 0.79118344]), 'D(G(z))': array([0.5523163 , 0.55231531])}
step 200
d loss: 0.17730998992919922
g loss: 4.299203395843506
accs {'D(x)': array([0.95750504, 0.90808061]), 'D(G(z))': array([0.07030124, 0.07030432])}
step 300
d loss: 0.35638609528541565
g loss: 2.512193441390991
accs {'D(x)': array([0.95898908, 0.86937265]), 'D(G(z))': array([0.10200596, 0.10200369])}
step 400
d loss: 0.3572602868080139
g loss: 2.267479658126831
accs {'D(x)': array([0.95211376, 0.90674336]), 'D(G(z))': array([0.18995091, 0.18994861])}
step 500
d loss: 0.027400070801377296
g loss: 4.040902614593506
accs {'D(x)': array([

KeyboardInterrupt: 