In [1]:
import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F
from    torch.utils.data import TensorDataset, DataLoader
from    torch import optim
import  numpy as np
from torch.autograd import Variable

from   models.learner import Learner
from models.generator import Generator
from    copy import deepcopy
import os
from torchsummary import summary

from utils.dataloader import train_data_gen , test_data_gen
import shutil
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
from torch.utils.tensorboard import SummaryWriter
import json
# import tqdm

In [2]:
PATH = "save_models/0612_maml_gen/model_step9900.pth"

In [3]:
with open('configs/0625_5way1shot0distractor1gen.json') as json_file:
    args = json.load(json_file)
print(args)

{'epoch': 96000, 'n_way': 5, 'k_spt': 1, 'k_qry': 15, 'img_sz': 84, 'tasks_per_batch': 5, 'img_c': 3, 'meta_gen_lr': 0.0005, 'meta_discrim_lr': 0.0001, 'update_lr': 0.004, 'update_steps': 2, 'update_steps_test': 2, 'loss': 'cross_entropy', 'min_learning_rate': 1e-15, 'number_of_training_steps_per_iter': 4, 'multi_step_loss_num_epochs': 15, 'spy_gen_num': 5, 'qry_gen_num': 25, 'num_distractor': 2, 'spy_distractor_num': 1, 'qry_distractor_num': 15, 'batch_for_gradient': 25, 'no_save': 0, 'learn_inner_lr': 0, 'create_graph': 0, 'msl': 0, 'single_fast_test': 0, 'consine_schedule': 0, 'save_path': '0625_5way1shot0distractor1gen'}


In [4]:
# if os.path.exists("images/" + args["save_path"]):
#     shutil.rmtree("images/" + args["save_path"])
    
# if os.path.exists("data/" + args["save_path"]):
#     shutil.rmtree("data/" + args["save_path"])
    
# if os.path.exists("save_models/" + args["save_path"]):
#     shutil.rmtree("save_models/" + args["save_path"])
    
# if os.path.exists("runs/" + args["save_path"]):
#     shutil.rmtree("runs/" + args["save_path"])    

writer = SummaryWriter('metagan_runs/' + args["save_path"])

In [5]:
def mkdir_p(path):
    if not os.path.exists("images/" + path):
        os.makedirs("images/" + path)
        
    if not os.path.exists("data/" + path):
        os.makedirs("data/" + path)
        
    if not os.path.exists("save_models/" + path):
        os.makedirs("save_models/" + path)        

def save_imgs(path, imgs, step):

    some_imgs = np.reshape(imgs, [imgs.shape[0]*imgs.shape[1], -1])[0:50]

    # save png of imgs
    i = 0
    for flat_img in some_imgs:
        img = flat_img.reshape(3,84,84).swapaxes(0,1).swapaxes(1,2)
        im = ((img - np.min(img))*255/(np.max(img - np.min(img)))).astype(np.uint8)
        if i < 15:
            plt.subplot(5, 3, i+1)
            plt.axis('off')
            plt.imshow(im)
        i += 1
    plt.savefig("images/" + path + "/images_step" + str(step) + ".png")
    plt.close()

In [6]:
train_data_generator = train_data_gen(args)
test_data_generator = test_data_gen(args)

load datasets/BelgiumTSC
load complete time 0.4268476963043213
load datasets/ArTS
load complete time 0.4329042434692383
load datasets/chinese_traffic_sign
load complete time 0.7351255416870117
load datasets/CVL
load complete time 0.5277915000915527
load datasets/FullJCNN2013
load complete time 0.27721738815307617
load datasets/logo_2k
load complete time 1.0279045104980469
load datasets/GTSRB
load complete time 0.10019350051879883
load datasets/DFG
load complete time 0.11236929893493652


In [7]:
ndf = 64
discriminator_config = [
    ('conv2d', [ndf, 3, 4, 4, 2, 1]),
    ('leakyrelu', [0.2,True]),
    # ('bn', [ndf]),
    
    ('conv2d', [ndf*2, ndf, 4, 4, 2, 1]),
    ('bn', [ndf*2]),
    ('leakyrelu', [0.2,True]),

    ('conv2d', [ndf*4, ndf*2, 4, 4, 2, 1]),
    ('bn', [ndf*4]),
    ('leakyrelu', [0.2,True]),
    
    
    ('conv2d', [ndf*8, ndf*4, 4, 4, 2, 1]),
    ('bn', [ndf*8]),
    ('leakyrelu', [0.2,True]),
    
    ('conv2d', [1,ndf*8 , 2, 2, 1, 0]),
    ('flatten', []),
    ('linear',[6, 16]),
    ('softmax',[])
]
nz = 100
ngf = 64
gen_config = [
    ('convert_z',[]),
    ('convt2d',[nz,ngf*8,4,4,1,0]),
    ('bn',[ngf * 8]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*8,ngf*4,4,4,2,0]),
    ('bn',[ngf * 4]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*4,ngf*2,4,4,2,0]),
    ('bn',[ngf * 2]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*2,ngf,3,3,2,1]),
    ('bn',[ngf]),
    ('leakyrelu', [.2, True]),      
    
    ('convt2d',[ngf,3,3,3,2,1]),
    ('convt2d',[3,3,2,2,1,1]),
    ("tanh",[])
]

In [10]:
class Meta(nn.Module):
    """
    Meta Learner with GAN incorporated
    """
    def __init__(self, args, discriminator_config, gen_config):
        """
        :param args:
        """
        super(Meta, self).__init__()
        
        cuda = torch.cuda.is_available()
        self.FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 
        self.LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
        self.total_epochs = args["epoch"]   
        # model parameters config
        self.meta_gen_lr = args["meta_gen_lr"]
        self.meta_discrim_lr = args["meta_discrim_lr"]
        
        self.update_lr = args["update_lr"]
        self.consine_schedule = args["consine_schedule"]
        self.update_steps = args["update_steps"]
        self.update_steps_test = args["update_steps_test"]
        self.distractor = args["num_distractor"]
        # dataset config
        self.img_c = args["img_c"]
        self.img_sz = args["img_sz"]        
        self.n_way = args["n_way"]
        self.k_spt = args["k_spt"]
        self.k_qry = args["k_qry"]
        self.MSL = args["msl"]
        # generator num
        self.spy_gen_num = args["spy_gen_num"]
        self.qry_gen_num = args["qry_gen_num"]
        # query gan batch
        self.batch_for_gradient = args["batch_for_gradient"]
        self.fix_noise = torch.randn(self.qry_gen_num, nz,1,1, device=device)
        self.criterion = nn.BCELoss()
        # load model
        self.generator = Generator(gen_config, self.img_c, self.img_sz)
        self.discrim_net = Learner(discriminator_config, self.img_c, self.img_sz)
        beta1 = 0.0
        beta2 = 0.0
        
        self.meta_gen_optim = optim.Adam(self.generator.parameters(), lr=self.meta_gen_lr,betas=(beta1, 0.9))
        self.meta_d_optim = optim.Adam(self.discrim_net.parameters(), lr=self.meta_discrim_lr,betas=(beta2, 0.9))
        if self.consine_schedule:
            self.min_learning_rate = 1e-8
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.meta_d_optim, T_max=self.total_epochs,
                                                                  eta_min=self.min_learning_rate)
        self.real_value = 1
        self.fake_value = 0
        self.discrim_fake = 5
    def cross_entropy(self, output,label):
        output = torch.log(output)
        loss = F.nll_loss(output,label)
        return loss
    
    def pred(self, x, weights=None, nets=None, nway=True, discrim=True, conditions=False):
        if weights == None:
            discrim_weights = self.discrim_net.parameters()
        else:
            discrim_weights = weights

        discrim_logits = self.discrim_net(x, vars=discrim_weights, bn_training=True) if discrim else None
          
        return discrim_logits

    def get_num_corrects(self, y, x=None, weights=None, real=True):
        with torch.no_grad():
            discrim_logits = self.pred(x, weights=weights)
            if real :
                if self.distractor:
                    nway_pred_q = discrim_logits.argmax(dim=1)
                else:
                    nway_pred_q = discrim_logits[:,:-1].argmax(dim=1)
            else:
                nway_pred_q = discrim_logits.argmax(dim=1)
            
            nway_correct = torch.eq(nway_pred_q, y).sum().item()
            pred_q = discrim_logits.argmax(dim=1)
            other = torch.tensor([5]*len(discrim_logits)).cuda()
            other_correct = torch.eq(pred_q, other).sum().item()
        return nway_correct, other_correct

        
    def update_weights(self, net_losses, net_weights,learned_lrs, gen=False):
        if gen:
            update_lr = self.gen_update_lr
        else:
            update_lr = self.update_lr
        # grad = torch.autograd.grad(net_losses, net_weights, retain_graph=True, create_graph=self.create_graph)
        grad = torch.autograd.grad(net_losses, net_weights)
        weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grad, net_weights)))

        return weights
    
    def meta_test(self,qry_img,qry_label,discrim_weight,gen_weight):
        ### discriminator train
        q_real_discrim_logits = self.pred(qry_img, weights=discrim_weight)

        real_discrim_loss_q = self.cross_entropy(q_real_discrim_logits, qry_label)
        if torch.isnan(real_discrim_loss_q):
            print(self.current_epoch)
            print("real d loss error")
            print(q_real_discrim_logits)
            
        discrim_fake_label = torch.full((self.qry_gen_num,), self.discrim_fake, dtype=torch.long, device=device) 
        noise = torch.randn(self.qry_gen_num, nz,1,1, device=device)
        q_gen = torch.empty(0,3,84,84).cuda()
        if self.qry_gen_num < self.batch_for_gradient:

            q_gen = self.generator(qry_img, noise , vars=gen_weight)
        else:
            for i in range(self.qry_gen_num//self.batch_for_gradient):

                noise_tmp = noise[i*self.batch_for_gradient:(i+1)*self.batch_for_gradient]

                q_gen = torch.cat([q_gen,self.generator(qry_img[i*self.batch_for_gradient:(i+1)*self.batch_for_gradient], noise_tmp , vars=gen_weight)])
        q_fake_discrim_logits = self.pred(q_gen.detach(), weights=discrim_weight)
        fake_discrim_loss_q = self.cross_entropy(q_fake_discrim_logits, discrim_fake_label)

        if torch.isnan(fake_discrim_loss_q):
            print("fake d loss error")
            print(q_fake_discrim_logits)
        d_loss_q = (fake_discrim_loss_q + real_discrim_loss_q)
        
        ### generator train
        gen_fake_label = torch.full((self.qry_gen_num,), self.fake_value, dtype=torch.float, device=device)
        gen_q_discrim = self.pred(q_gen, weights=discrim_weight)[:,-1]

        g_loss_q = self.criterion(gen_q_discrim, gen_fake_label)
        if torch.isnan(g_loss_q):
            print("g loss error")
            
        return d_loss_q, g_loss_q

    def single_task_forward(self, x_spt, y_spt, x_qry, y_qry, update_steps,nets=None, images=False):
        
        corrects = {key: np.zeros(update_steps + 1) for key in 
                        [
                        "query_nway", # number of meta-test (query) images correctly discriminated
                        "predict_other",
                        "gen_discrim", # number of generated images correctly discriminated
                        ]}

        support_sz, c_, h, w = x_spt.size()
        nz = 100
        
        discrim_weights,gen_weights = [x.parameters() for x in nets]

        # this is the meta-test loss and accuracy before first update

        q_discrim,other = self.get_num_corrects(y=y_qry, weights=None, x=x_qry)
        corrects["query_nway"][0] += q_discrim
        corrects["predict_other"][0] += other
        # run the i-th task and compute loss for k-th inner update
        query_fake_label = torch.full((self.qry_gen_num,), self.discrim_fake, dtype=torch.long, device=device)
        for k in range(1, update_steps + 1):
            ## discrim loss
            noise = torch.randn(self.spy_gen_num, nz , 1, 1, device=device)
            x_gen = self.generator(x_spt, noise , vars=gen_weights)
            
            # update discrim weight

            real_discrim_logits = self.pred(x_spt, weights=discrim_weights)
            if (True in torch.isnan(real_discrim_logits)):
                print("inner real d loss error")
                print(q_discrim,other)
                print(query_fake_label)
                print(noise)
                print(x_gen)
                print(real_discrim_logits)
            fake_discrim_logits = self.pred(x_gen, weights=discrim_weights)
            if (True in torch.isnan(fake_discrim_logits)):
                print("inner fake d loss error")
                print(fake_discrim_logits)
            fake_label = torch.full((self.spy_gen_num,), self.discrim_fake, dtype=torch.long, device=device)
            
            real_discrim_loss = self.cross_entropy(real_discrim_logits, y_spt)
            fake_discrim_loss = self.cross_entropy(fake_discrim_logits,fake_label)
            D_loss = fake_discrim_loss + real_discrim_loss

            discrim_weights = self.update_weights(D_loss, discrim_weights,self.update_lr) 
            
            with torch.no_grad():
                x_gen = self.generator(x_qry, self.fix_noise , vars=gen_weights) 
                gen_correct,_ = self.get_num_corrects(y=query_fake_label,x=x_gen, weights=discrim_weights,real=False)
                corrects["gen_discrim"][k-1] += gen_correct
                
                q_discrim_correct,other = self.get_num_corrects(y=y_qry, x=x_qry, weights=discrim_weights)
                corrects['query_nway'][k] += q_discrim_correct
                corrects["predict_other"][k] += other
            # meta-test nway and discrim accuracy
            # [query_sz]

        # final gen-discrim and gen-nway accuracy
        with torch.no_grad():
            x_gen = self.generator(x_qry, self.fix_noise , vars=gen_weights)
            gen_correct,_ = self.get_num_corrects(y=query_fake_label,x=x_gen, weights=discrim_weights,real=False)
            corrects["gen_discrim"][-1] += gen_correct
        d_loss_q, g_loss_q = self.meta_test(x_qry,y_qry,discrim_weights,gen_weights)
            
        if images:
            return d_loss_q,g_loss_q, corrects, x_gen
        else:
            return d_loss_q,g_loss_q, corrects

    def forward(self, x_spt, y_spt, x_qry, y_qry,step):
        """
        :param x_spt:   [b, support_sz, c_, h, w]
        :param y_spt:   [b, support_sz]
        :param x_qry:   [b, query_sz, c_, h, w]
        :param y_qry:   [b, query_sz]
        :return:
        """
        self.current_epoch = step 
        tasks_per_batch, support_sz, c_, h, w = x_spt.size()
        query_sz = x_qry.size(1)
        g_loss_q = 0
        d_loss_q = 0
        gen_losses_q = [0 for _ in range(self.update_steps + 1)]
        discrim_losses_q = [0 for _ in range(self.update_steps + 1)]
        corrects = {key: np.zeros(self.update_steps + 1) for key in 
                        [
                        "query_nway", # number of meta-test (query) images correctly discriminated
                        "predict_other",
                        "gen_discrim", # number of generated images correctly discriminated
                        ]}
        net = [self.discrim_net,self.generator]
        for i in range(tasks_per_batch):
            d_loss_q_tmp,g_loss_q_tmp, corrects_tmp = self.single_task_forward(x_spt[i], y_spt[i], x_qry[i], y_qry[i],self.update_steps,nets = net,images=False)
            g_loss_q += g_loss_q_tmp
            d_loss_q += d_loss_q_tmp
            assert len(corrects_tmp.keys()) == len(corrects.keys())
            for key in corrects.keys():
                corrects[key] += corrects_tmp[key]
            
        # end of all tasks
        # sum over final losses on query set across all tasks
        g_loss_q /= tasks_per_batch
        self.meta_gen_optim.zero_grad()
        g_loss_q.backward()
        self.meta_gen_optim.step()        

        # optimize theta parameters
        d_loss_q /= tasks_per_batch
        self.meta_d_optim.zero_grad()
        d_loss_q.backward()
        self.meta_d_optim.step()
        
        accs = {}
        accs["query_nway"] = corrects["query_nway"] / (tasks_per_batch * query_sz)
        accs["predict_other"] = corrects["predict_other"] / (tasks_per_batch * query_sz)
        accs["gen_discrim"] = corrects["gen_discrim"] / (tasks_per_batch * self.qry_gen_num)
        return accs,d_loss_q,g_loss_q

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """

        :param x_spt:   [support_sz, c_, h, w]
        :param y_spt:   [support_sz]
        :param x_qry:   [query_sz, c_, h, w]
        :param y_qry:   [query_sz]
        :return:
        """

        support_sz, c_, h, w = x_spt.size()

        assert len(x_spt.shape) == 4

        query_sz = x_qry.size(0)

        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetunning on the copied model instead of self.net
        
        discrim_net = deepcopy(self.discrim_net)
        generator = deepcopy(self.generator)
        net = [self.discrim_net,self.generator]
        d_loss_q,g_loss_q, corrects, imgs = self.single_task_forward(x_spt, y_spt, x_qry, y_qry,self.update_steps_test, nets=net,images=True)

        del discrim_net
        
        accs["query_nway"] = corrects["query_nway"] / (query_sz)
        accs["predict_other"] = corrects["predict_other"] / (query_sz)
        accs["gen_discrim"] = corrects["gen_discrim"] / (self.qry_gen_num)

        return accs, imgs,d_loss_q,g_loss_q


In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mamlGAN = Meta(args, discriminator_config, gen_config).to(device)
step = 0
path = args["save_path"]
mkdir_p(path)
best_acc = []
mamlGAN.load_state_dict(torch.load(PATH),False)

_IncompatibleKeys(missing_keys=['generator.vars.0', 'generator.vars.1', 'generator.vars.2', 'generator.vars.3', 'generator.vars.4', 'generator.vars.5', 'generator.vars.6', 'generator.vars.7', 'generator.vars.8', 'generator.vars.9', 'generator.vars.10', 'generator.vars.11', 'generator.vars.12', 'generator.vars.13', 'generator.vars.14', 'generator.vars.15', 'generator.vars.16', 'generator.vars.17', 'generator.vars.18', 'generator.vars.19', 'generator.vars_bn.0', 'generator.vars_bn.1', 'generator.vars_bn.2', 'generator.vars_bn.3', 'generator.vars_bn.4', 'generator.vars_bn.5', 'generator.vars_bn.6', 'generator.vars_bn.7', 'discrim_net.vars.0', 'discrim_net.vars.1', 'discrim_net.vars.2', 'discrim_net.vars.3', 'discrim_net.vars.4', 'discrim_net.vars.5', 'discrim_net.vars.6', 'discrim_net.vars.7', 'discrim_net.vars.8', 'discrim_net.vars.9', 'discrim_net.vars.10', 'discrim_net.vars.11', 'discrim_net.vars.12', 'discrim_net.vars.13', 'discrim_net.vars.14', 'discrim_net.vars.15', 'discrim_net.var

In [9]:
with tqdm.tqdm(initial=step,
                   total=int(args["epoch"])) as pbar_train:
    for _ in range(args["epoch"] * args["tasks_per_batch"]//6000):
        train_dataloader = DataLoader(train_data_generator, args["tasks_per_batch"], shuffle=True, num_workers=2, pin_memory=True)

        for _, (x_spt, y_spt, x_qry, y_qry) in enumerate(train_dataloader):
            tmp_x_spt, tmp_y_spt, tmp_x_qry, tmp_y_qry = x_spt.squeeze(0), y_spt.squeeze(0), \
                                         x_qry.squeeze(0), y_qry.squeeze(0)
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs,d_loss,g_loss = mamlGAN(x_spt, y_spt, x_qry, y_qry,step)
            writer.add_scalar('Loss/train_d_loss', d_loss, step)
            writer.add_scalar('Loss/train_g_loss', g_loss, step)
            writer.add_scalar('Accuracy/query_nway', accs["query_nway"][-1], step)
            writer.add_scalar('Accuracy/gen_discrim', accs["gen_discrim"][-1], step)
            writer.add_scalar('Accuracy/predict_other', accs["predict_other"][-1],step)
            if step % 100 == 0:
                print("step " + str(step))
                print('d loss:',d_loss.item())
                print('g loss:',g_loss.item())
                print("accs",accs)


            if step % 300 == 0:  # evaluation
                db_test = DataLoader(test_data_generator, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_test = []
                imgs_all_test = []
                d_loss_all_test = []
                g_loss_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    tmp_x_spt, tmp_y_spt, tmp_x_qry, tmp_y_qry = x_spt.squeeze(0), y_spt.squeeze(0), \
                                                 x_qry.squeeze(0), y_qry.squeeze(0)
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)
                    # accs, d_loss = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs, imgs,d_loss,g_loss = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)


                    accs_all_test.append(accs)
                    imgs_all_test.append(imgs.cpu().detach().numpy())
                    d_loss_all_test.append(d_loss.item())
                    g_loss_all_test.append(g_loss.item())

                imgs_all_test = np.array(imgs_all_test)
                # [b, update_step+1]
                # accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                d_loss = np.mean(np.array(d_loss_all_test))
                g_loss = np.mean(np.array(g_loss_all_test))

                print('d loss:',d_loss)
                print('g loss:',g_loss)
                print('Test acc:', accs)    

                writer.add_scalar('Loss/test_d_loss', d_loss, step)
                writer.add_scalar('Loss/test_g_loss', g_loss, step)
                writer.add_scalar('Accuracy/test_query_nway', accs["query_nway"][-1], step)
                writer.add_scalar('Accuracy/test_gen_discrim', accs["gen_discrim"][-1], step)
                writer.add_scalar('Accuracy/test_predict_other', accs["predict_other"][-1],step)
                if not len(best_acc):
                    best_acc = accs
                    best_epoch = step
                    torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/best.pth")
                else:
                    if max(accs) > max(best_acc):
                        best_acc = accs
                        best_epoch = step
                        torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/best.pth")
                torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/model_step" + str(step) + ".pth")

                save_imgs(path, imgs_all_test, step)

            step = step + 1
            pbar_train.update(1)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
tmp_x_spt.size()

In [None]:
tmp_y_spt.size()

In [None]:
tmp_x_qry.size()

In [None]:
tmp_y_qry.size()

In [None]:
tmp_y_qry