In [1]:
import json
import os
from copy import deepcopy
from datetime import datetime

import numpy as np
import torch
import torchvision.datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.nn import functional as F
from torch.optim import lr_scheduler
from torch.utils.data import TensorDataset, DataLoader
from torchsummary import summary

from utils.dataloader import train_data_gen,test_data_gen

# Target device for every model and tensor below: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Experiment configuration shared by the cells below (plain dict, read with args["key"]).
args = dict(
    epoch=24000,
    n_way=5,                    # output width of Nway_classifier
    k_spt=1,
    k_qry=15,
    img_sz=64,                  # image side length assumed by save_imgs
    tasks_per_batch=4,
    img_c=3,
    task_num=4,
    meta_lr=1e-3,
    update_lr=2e-4,
    update_steps=5,
    update_steps_test=10,
    no_save=False,
    learn_inner_lr=True,
    condition_discrim=False,
    loss="cross_entropy",
    create_graph=False,
    num_distractor=1,
    save_path='0409_conditional_result',
)

In [3]:
def mkdir_p(path):
    """Create directory ``path`` together with any missing parents.

    Calling it on a directory that already exists is a no-op. Creation is
    delegated to ``os.makedirs(..., exist_ok=True)`` so there is no window
    between an existence check and the creation in which another process
    could create the directory and make this call fail.

    Args:
        path: directory path to create.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)

def save_train_accs(path, accs):
    """Append one row of training accuracies to each per-metric text file.

    For every metric key a file ``<path>/<key>_accuracies.txt`` is opened in
    binary append mode and ``accs[key]`` is written as a single row with
    ``np.savetxt``. The ``with`` block guarantees the handle is closed even if
    writing fails.

    Args:
        path: existing directory that receives the files.
        accs: mapping providing the keys ``"q_nway"``, ``"q_discrim"`` and
            ``"gen_discrim"``; each value must be a 1-D sequence of numbers.

    Raises:
        KeyError: if one of the three keys is missing from ``accs`` (files for
            earlier keys have already been written at that point).
    """
    for key in ("q_nway", "q_discrim", "gen_discrim"):
        with open(path + '/' + key + '_accuracies.txt', 'ab') as out_file:
            # Wrapping in a list makes savetxt emit the values as one row.
            np.savetxt(out_file, np.array([accs[key]]))

def save_test_accs(path, accs):
    """Append one row of test accuracies to ``<path>/test_q_nway_accuracies.txt``.

    The file is opened in binary append mode inside a ``with`` block so the
    handle is released even when ``np.savetxt`` raises.

    Args:
        path: existing directory that receives the file.
        accs: 1-D sequence of numbers written as a single row.
    """
    with open(path + '/test_q_nway_accuracies.txt', 'ab') as out_file:
        # Wrapping in a list makes savetxt emit the values as one row.
        np.savetxt(out_file, np.array([accs]))

def save_imgs(path, imgs, step):
    """Dump a sample of images as raw text and as a 7x7 PNG contact sheet.

    The first two axes of ``imgs`` are merged, every image is flattened, and
    the first 50 flattened images are appended to
    ``<path>/images_step<step>.txt``. The first 49 of them are min-max scaled
    to 0..255 and drawn on a 7x7 grid saved to ``<path>/images_step<step>.png``.

    Args:
        path: existing directory that receives both files.
        imgs: numpy array with at least two leading axes; each image must hold
            ``3 * args["img_sz"] * args["img_sz"]`` values in channel-first order.
        step: value embedded in the output file names.

    Side effects:
        Sets the ``KMP_DUPLICATE_LIB_OK`` environment variable to ``'True'``.
    """
    # `plt` was referenced here but is not imported by any visible cell;
    # importing it locally keeps the function usable after a kernel restart.
    import matplotlib.pyplot as plt

    some_imgs = np.reshape(imgs, [imgs.shape[0]*imgs.shape[1], -1])[0:50]

    # save raw txt files
    with open(path+"/images_step" + str(step) + ".txt", 'ab') as img_f:
        np.savetxt(img_f, some_imgs)

    os.environ['KMP_DUPLICATE_LIB_OK']='True'

    # save png of imgs
    side = args["img_sz"]
    # A dedicated figure avoids drawing on top of whatever figure happens to be current.
    fig = plt.figure()
    try:
        # The grid has 49 cells, so the 50th sampled image is written to the txt file only.
        for i, flat_img in enumerate(some_imgs[:49]):
            # (C, H, W) -> (H, W, C), the layout imshow expects.
            img = flat_img.reshape(3, side, side).transpose(1, 2, 0)
            shifted = img - np.min(img)
            span = np.max(shifted)
            if span > 0:
                im = (shifted*255/span).astype(np.uint8)
            else:
                # Constant image: the original 0/0 produced NaNs; render it black instead.
                im = np.zeros(img.shape, dtype=np.uint8)
            ax = fig.add_subplot(7, 7, 1 + i)
            ax.axis('off')
            ax.imshow(im)
        fig.savefig(path+"/images_step" + str(step) + ".png")
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

In [4]:
# Build the train and test data generators from the shared config dict.
# The training loop further down iterates train_data_generator and unpacks each item as
# (x_spt, y_spt, x_qry, y_qry, unlbl_x_spt, unlbl_x_qry).
# NOTE(review): test_data_generator is presumably structured the same way — confirm in utils.dataloader.
train_data_generator = train_data_gen(args)
test_data_generator = test_data_gen(args)

load dataset/BelgiumTSC
load complete time 4.5584189891815186
load dataset/ArTS
load complete time 4.621725797653198
load dataset/chinese_traffic_sign
load complete time 0.7982516288757324
load dataset/CVL
load complete time 0.6074516773223877
load dataset/FullJCNN2013
load complete time 0.3425893783569336
load dataset/logo_2k
load complete time 1.3441529273986816
load dataset/GTSRB
load complete time 0.2050189971923828
load dataset/DFG
load complete time 0.07396125793457031


In [5]:
class Nway_classifier(nn.Module):
    """Small two-conv / three-linear CNN that emits one logit per class.

    Args:
        n_way: number of output logits. ``None`` (default) reads ``args["n_way"]``
            from the notebook configuration, as the original code did.
        in_channels: number of input image channels (default 3).
        img_sz: side length of the square input images (default 64). The size
            of the first linear layer is derived from it instead of being
            hard-coded for 64x64 inputs.

    Raises:
        ValueError: if ``img_sz`` is too small to survive both conv/pool stages.
    """

    def __init__(self, n_way=None, in_channels=3, img_sz=64):
        super().__init__()
        if n_way is None:
            n_way = args["n_way"]
        # Each stage is an unpadded 5x5 convolution (side - 4) followed by a
        # 2x2 max-pool with stride 2 (floor halving); 64 -> 30 -> 13.
        feat_sz = ((img_sz - 4) // 2 - 4) // 2
        if feat_sz < 1:
            raise ValueError("img_sz=%r is too small for two 5x5 conv + 2x2 pool stages" % (img_sz,))
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * feat_sz * feat_sz, 120)
        self.fc2 = nn.Linear(120, 64)
        self.fc3 = nn.Linear(64, n_way)

    def forward(self, x):
        """Return logits of shape (batch, n_way) for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [6]:
# Instantiate the N-way classifier on the selected device and print a per-layer
# summary for a single 3x64x64 input.
nway_classifier = Nway_classifier().to(device)
summary(nway_classifier, (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 60, 60]             456
         MaxPool2d-2            [-1, 6, 30, 30]               0
            Conv2d-3           [-1, 16, 26, 26]           2,416
         MaxPool2d-4           [-1, 16, 13, 13]               0
            Linear-5                  [-1, 120]         324,600
            Linear-6                   [-1, 64]           7,744
            Linear-7                    [-1, 5]             325
Total params: 335,541
Trainable params: 335,541
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 0.31
Params size (MB): 1.28
Estimated Total Size (MB): 1.64
----------------------------------------------------------------


In [7]:
class Semi_nway_classifier(nn.Module):
    """Same CNN layout as ``Nway_classifier`` with one additional output logit.

    Args:
        num_classes: number of output logits. ``None`` (default) uses
            ``args["n_way"] + 1``, which equals the previously hard-coded 6 for
            the notebook's ``n_way = 5``.
            NOTE(review): the extra logit is presumably the distractor/"other"
            class — confirm against the training objective.
        in_channels: number of input image channels (default 3).
        img_sz: side length of the square input images (default 64).

    Raises:
        ValueError: if ``img_sz`` is too small to survive both conv/pool stages.
    """

    def __init__(self, num_classes=None, in_channels=3, img_sz=64):
        super().__init__()
        if num_classes is None:
            num_classes = args["n_way"] + 1
        # Each stage is an unpadded 5x5 convolution (side - 4) followed by a
        # 2x2 max-pool with stride 2 (floor halving); 64 -> 30 -> 13.
        feat_sz = ((img_sz - 4) // 2 - 4) // 2
        if feat_sz < 1:
            raise ValueError("img_sz=%r is too small for two 5x5 conv + 2x2 pool stages" % (img_sz,))
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * feat_sz * feat_sz, 120)
        self.fc2 = nn.Linear(120, 64)
        self.fc3 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [8]:
# Instantiate the semi-supervised classifier on the selected device and print a
# per-layer summary for a single 3x64x64 input.
semi_nway_classifier = Semi_nway_classifier().to(device)
summary(semi_nway_classifier, (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 60, 60]             456
         MaxPool2d-2            [-1, 6, 30, 30]               0
            Conv2d-3           [-1, 16, 26, 26]           2,416
         MaxPool2d-4           [-1, 16, 13, 13]               0
            Linear-5                  [-1, 120]         324,600
            Linear-6                   [-1, 64]           7,744
            Linear-7                    [-1, 6]             390
Total params: 335,606
Trainable params: 335,606
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 0.31
Params size (MB): 1.28
Estimated Total Size (MB): 1.64
----------------------------------------------------------------


In [9]:
class discriminator(nn.Module):
    """Two-conv / three-linear CNN that emits a single logit per image.

    The training cells feed this logit to
    ``F.binary_cross_entropy_with_logits``, so no sigmoid is applied here.

    Args:
        in_channels: number of input image channels (default 3).
        img_sz: side length of the square input images (default 64). The size
            of the first linear layer is derived from it instead of being
            hard-coded for 64x64 inputs.

    Raises:
        ValueError: if ``img_sz`` is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels=3, img_sz=64):
        super().__init__()
        # Each stage is an unpadded 5x5 convolution (side - 4) followed by a
        # 2x2 max-pool with stride 2 (floor halving); 64 -> 30 -> 13.
        feat_sz = ((img_sz - 4) // 2 - 4) // 2
        if feat_sz < 1:
            raise ValueError("img_sz=%r is too small for two 5x5 conv + 2x2 pool stages" % (img_sz,))
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * feat_sz * feat_sz, 120)
        self.fc2 = nn.Linear(120, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        """Return raw logits of shape (batch, 1) for a batch of images ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [10]:
# Instantiate the discriminator on the selected device and print a per-layer
# summary for a single 3x64x64 input. The variable name `disrim` (sic) is
# referenced by later cells, so it is kept as is.
disrim = discriminator().to(device)
summary(disrim, (3, 64, 64))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 60, 60]             456
         MaxPool2d-2            [-1, 6, 30, 30]               0
            Conv2d-3           [-1, 16, 26, 26]           2,416
         MaxPool2d-4           [-1, 16, 13, 13]               0
            Linear-5                  [-1, 120]         324,600
            Linear-6                   [-1, 64]           7,744
            Linear-7                    [-1, 1]              65
Total params: 335,281
Trainable params: 335,281
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.05
Forward/backward pass size (MB): 0.31
Params size (MB): 1.28
Estimated Total Size (MB): 1.64
----------------------------------------------------------------


In [11]:
# Generator hyper-parameters: latent size, base feature-map width, output channels.
nz, ngf, nc = 100, 64, 3


class Generator(nn.Module):
    """Transposed-convolution generator: (N, nz, 1, 1) noise -> (N, nc, 64, 64) image.

    Four ConvTranspose2d / BatchNorm2d / ReLU stages grow the spatial size
    1 -> 4 -> 8 -> 16 -> 32; a final ConvTranspose2d + Tanh produces the
    64x64 output with values in [-1, 1]. All layers live in ``self.main`` in
    the same order as before, so ``state_dict`` keys are unchanged.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each normalised stage;
        # every transposed convolution uses a 4x4 kernel and no bias.
        normalised_stages = (
            (nz, ngf * 8, 1, 0),        # -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),   # -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),   # -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),       # -> (ngf) x 32 x 32
        )
        layers = []
        for c_in, c_out, stride, padding in normalised_stages:
            layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU(True))
        # Output stage -> (nc) x 64 x 64, squashed by tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map a batch of latent vectors to a batch of images."""
        return self.main(input)

In [12]:
# Instantiate the generator on the selected device and print a per-layer
# summary for a single (100, 1, 1) latent input.
generator = Generator().to(device)
summary(generator, (100,1,1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 512, 4, 4]         819,200
       BatchNorm2d-2            [-1, 512, 4, 4]           1,024
              ReLU-3            [-1, 512, 4, 4]               0
   ConvTranspose2d-4            [-1, 256, 8, 8]       2,097,152
       BatchNorm2d-5            [-1, 256, 8, 8]             512
              ReLU-6            [-1, 256, 8, 8]               0
   ConvTranspose2d-7          [-1, 128, 16, 16]         524,288
       BatchNorm2d-8          [-1, 128, 16, 16]             256
              ReLU-9          [-1, 128, 16, 16]               0
  ConvTranspose2d-10           [-1, 64, 32, 32]         131,072
      BatchNorm2d-11           [-1, 64, 32, 32]             128
             ReLU-12           [-1, 64, 32, 32]               0
  ConvTranspose2d-13            [-1, 3, 64, 64]           3,072
             Tanh-14            [-1, 3,

In [13]:
# Choose the CUDA tensor constructors when a GPU is visible, the CPU ones otherwise.
cuda = torch.cuda.is_available()
if cuda:
    FloatTensor = torch.cuda.FloatTensor
    IntTensor = torch.cuda.IntTensor
else:
    FloatTensor = torch.FloatTensor
    IntTensor = torch.IntTensor

# One flat parameter list over all four networks, in the order
# nway classifier, discriminator, semi classifier, generator.
params = [
    param
    for net in (nway_classifier, disrim, semi_nway_classifier, generator)
    for param in net.parameters()
]

# Single Adam optimiser over every parameter.
meta_optim = optim.Adam(params, lr=0.02)

In [14]:
def pred(x, weights=(None, None, None), nets=None):
    """Run ``x`` through the n-way, discriminator and semi-supervised networks.

    Args:
        x: input batch passed unchanged to each network.
        weights: accepted for call compatibility but not used; the networks are
            always evaluated with their own parameters.
        nets: optional ``(nway_net, discrim_net, semi_net)`` triple of callables.
            When omitted, the notebook-level ``nway_classifier``, ``disrim`` and
            ``semi_nway_classifier`` are used, which is what the previous
            implementation always did (it ignored this argument).

    Returns:
        Tuple ``(class_logits, discrim_logits, semi_class_logits)``.
    """
    if nets is None:
        nets = (nway_classifier, disrim, semi_nway_classifier)
    nway_net, discrim_net, semi_net = nets

    # Evaluation order kept from the original: discriminator, n-way, semi.
    discrim_logits = discrim_net(x)
    class_logits = nway_net(x)
    semi_class_logits = semi_net(x)

    return class_logits, discrim_logits, semi_class_logits

In [15]:
def get_num_corrects(unlabel, y=None, x=None, weights=None, class_logits=None, semi_logits=None,discrim_logits=None ,conditions=None):
    """Count predictions whose argmax over dim 1 equals the target ``y``.

    Args:
        unlabel: when truthy, the n-way count is skipped and returned as ``None``.
        y: target tensor compared element-wise against the argmax predictions.
        x, weights, discrim_logits, conditions: accepted for call
            compatibility; not used.
        class_logits: logits used for the n-way count (required unless
            ``unlabel`` is truthy).
        semi_logits: logits used for the semi count; falls back to
            ``class_logits`` when omitted.

    Returns:
        Tuple ``(nway_correct, semi_correct)`` of Python ints
        (``nway_correct`` is ``None`` when ``unlabel`` is truthy).

    Raises:
        ValueError: if the logits needed for a requested count were not given.
    """
    # Identity test: `== None` on a tensor goes through tensor comparison
    # machinery and is not a reliable "was it passed?" check.
    if semi_logits is None:
        semi_logits = class_logits
    if semi_logits is None:
        raise ValueError("get_num_corrects needs class_logits or semi_logits")
    if not unlabel and class_logits is None:
        raise ValueError("get_num_corrects needs class_logits when unlabel is False")

    # Counting is bookkeeping only; keep it out of the autograd graph.
    with torch.no_grad():
        if unlabel:
            nway_correct = None
        else:
            nway_correct = torch.eq(class_logits.argmax(dim=1), y).sum().item()
        semi_correct = torch.eq(semi_logits.argmax(dim=1), y).sum().item()
    return nway_correct, semi_correct

In [16]:
import os
# Request synchronous CUDA kernel launches via the CUDA_LAUNCH_BLOCKING environment variable.
# NOTE(review): earlier cells already move models to `device`, so CUDA has presumably been
# initialised before this runs — confirm the variable still takes effect when set this late;
# it may need to be set in the first cell, before anything touches the GPU.
os.environ['CUDA_LAUNCH_BLOCKING']="1"

In [None]:
# Scratch version of the per-task inner loop: for every task yielded by train_data_generator,
# copy the networks, record query accuracy before adaptation, then run `update_steps`
# discriminator updates on the support set. Most of the remaining algorithm is commented out below.
update_steps = 4
real_val = 1.0
fake_val = 0.0
# NOTE(review): 5 + 1 = 6, but Semi_nway_classifier has 6 outputs (argmax in 0..5), so a
# target of 6.0 can never equal an argmax prediction — confirm the intended distractor label.
distractor_val = float(5 + 1)
# NOTE(review): `gen_weights` is first assigned inside the loop below, so on a fresh kernel
# this line raises NameError.
inner_g_optim = optim.Adam(gen_weights, 2e-4, betas=(0.5, 0.999))

for step, (x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry) in enumerate(train_data_generator):

    # Move every tensor of the task to the selected device.
    x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), \
    y_qry.to(device),unlbl_x_spt.to(device),unlbl_x_qry.to(device)
    # x_spt is unpacked as a 4-D tensor: (support_sz, channels, height, width).
    support_sz, c_, h, w = x_spt.size()
    # Per-step accumulators; index 0 holds the value before the first update.
    # Only the first three keys are written by the active code in this cell.
    corrects = {key: np.zeros(update_steps + 1) for key in 
                    [
                    "q_nway", # number of meta-test (query) images correctly classified
                    "q_n+1_nway",
                    "distractor_n+1_way",
                    "nway_loss",
                    "n_1_way_loss",
                    "gen_loss", # number of generated images correctly discriminated
                    "discrim_loss"
                    ]}

    # Task-local copies of the networks.
    semi_net = deepcopy(semi_nway_classifier)
    nway_net = deepcopy(nway_classifier)
    discrim_net = deepcopy(disrim)
    gen_net = deepcopy(generator)                            
    # NOTE(review): the third entry is the global semi_nway_classifier, not the `semi_net`
    # copy made above — confirm which one is intended.
    nets = (nway_net, discrim_net, semi_nway_classifier)
    # Each entry is the generator object returned by .parameters(), i.e. it can be
    # iterated only once.
    net_weights = [net.parameters() for net in nets]
    gen_weights = gen_net.parameters()


    # Constant target tensors (Variable is a legacy wrapper; requires_grad=False keeps them constants).
    real = Variable(FloatTensor(support_sz, 1).fill_(real_val), requires_grad=False)
    fake = Variable(FloatTensor(support_sz, 1).fill_(fake_val), requires_grad=False)
    y_distractor = Variable(FloatTensor(support_sz).fill_(distractor_val), requires_grad=False)
    # NOTE(review): length 15 is hard-coded; presumably it must match unlbl_x_qry.size(0) — confirm.
    semi_disrractor = Variable(FloatTensor(15).fill_(distractor_val), requires_grad=False)
    # qry_distractor = Variable(FloatTensor(75, 1).fill_(distractor_val), requires_grad=False)
    # this is the meta-test loss and accuracy before first update
    # NOTE(review): pred() as defined in this notebook ignores `weights` and always evaluates
    # the global networks, so the deep copies made above are not what is being scored or trained.
    real_class_logits, real_discrim_logits,real_semi_logits = pred(x_qry, weights=net_weights)
    # semi_logits is not passed, so get_num_corrects falls back to class_logits for the second count.
    q_nway, q_semi_way = get_num_corrects(unlabel=False,y=y_qry, weights=[None, None, None],class_logits=real_class_logits ,x=x_qry)
    corrects['q_nway'][0] += q_nway
    corrects['q_n+1_nway'][0] += q_semi_way
    unlbl_class_logits, unlbl_discrim_logits,unlbl_semi_logits = pred(unlbl_x_qry, weights=net_weights)
    # NOTE(review): the n-way logits (unlbl_class_logits) are passed as semi_logits here, not
    # unlbl_semi_logits — confirm this is intended.
    _, distractor_semi_way = get_num_corrects(unlabel=True, y=semi_disrractor, weights=[None, None, None],semi_logits= unlbl_class_logits,x=unlbl_x_qry)
    corrects['distractor_n+1_way'][0] += distractor_semi_way

    for k in range(1, update_steps + 1):
        noise = torch.randn(x_spt.size(0), 100, 1, 1, device=device)

        # NOTE(review): uses the global `generator`, not the `gen_net` copy; x_gen is not used by
        # the active code below.
        x_gen = generator(noise)

        # train  discriminator
        real_class_logits, real_discrim_logits,real_semi_logits = pred(x_spt, weights=net_weights)
        # nway_loss and real_semi_loss are computed but only referenced by commented-out code.
        nway_loss = F.cross_entropy(real_class_logits, y_spt)
        real_semi_loss = F.cross_entropy(real_semi_logits, y_spt)
        real_discrim_loss = F.binary_cross_entropy_with_logits(real_discrim_logits, real)

        # NOTE(review): `inner_d_optim` is not defined at notebook level in any visible cell
        # (only as a local inside Meta.single_task_forward), so this raises NameError.
        inner_d_optim.zero_grad()
        real_discrim_loss.backward(retain_graph=True)
        inner_d_optim.step()

#         inner_n_way_optim.zero_grad()
#         nway_loss.backward(retain_graph=True)
#         inner_n_way_optim.step()            

#         inner_unlabel_optim.zero_grad()
#         nway_loss.backward(retain_graph=True)
#         inner_unlabel_optim.step()

#         _, gen_discrim_logits,gen_semi_logits = pred(x_gen, weights=net_weights)
#         gen_discrim_loss = F.binary_cross_entropy_with_logits(gen_discrim_logits, fake)
#         semi_disrractor = Variable(IntTensor(5).fill_(distractor_val), requires_grad=False)

#         semi_disrractor = semi_disrractor.type(torch.cuda.LongTensor) 
#         gen_semi_loss = F.cross_entropy(gen_semi_logits, semi_disrractor)

#         inner_d_optim.zero_grad()
#         gen_discrim_loss.backward(retain_graph=True)
#         inner_d_optim.step()

#         inner_unlabel_optim.zero_grad()
#         gen_semi_loss.backward(retain_graph=True)
#         inner_unlabel_optim.step()            


#         unlabel_class_logits, unlabel_discrim_logits,unlabel_semi_logits = pred(unlbl_x_spt, weights=net_weights)
#         unlabel_discrim_loss = F.binary_cross_entropy_with_logits(gen_discrim_logits, real)   
#         unlabel_semi_loss = F.cross_entropy(unlabel_semi_logits, y_distractor)

#         inner_d_optim.zero_grad()
#         unlabel_discrim_loss.backward(retain_graph=True)
#         inner_d_optim.step()

#         inner_unlabel_optim.zero_grad()
#         unlabel_semi_loss.backward(retain_graph=True)
#         inner_unlabel_optim.step()    

#         #train generator
#         x_gen = generator(noise)
#         _, gen_discrim_logits,gen_semi_logits = pred(x_gen, weights=net_weights)
#         gen_discrim_loss = F.binary_cross_entropy_with_logits(gen_discrim_logits, fake)
#         gen_loss = -1 * torch.nn.functional.logsigmoid(gen_discrim_logits).mean() #- gen_discrim_loss
#         inner_g_optim.zero_grad()
#         gen_loss.backward()
#         inner_g_optim.step()  

#         # meta-test nway and discrim accuracy
#         # [query_sz]

#         q_nway, q_semi_way = get_num_corrects(unlabel=False,y=y_qry, weights=[None, None, None], x=x_qry)

#         corrects['q_nway'][k] += q_nway
#         corrects['q_n+1_nway'][k] += q_semi_way

#         _, distractor_semi_way = get_num_corrects(unlabel=True, y=semi_disrractor, weights=[None, None, None], x=unlbl_x_qry)

#         corrects['distractor_n+1_way'][k] += distractor_semi_way

#     # meta-test loss
#     real_class_logits, real_discrim_logits,real_semi_logits = pred(x_spt, weights=net_weights)
#     loss_q = F.cross_entropy(real_class_logits, y_spt)
#     print("gen_loss:",gen_loss)
#     print("nway_loss:",loss_q)
    # Prints the discriminator loss of the last inner step, once per task.
    print("disrim_loss",real_discrim_loss)

In [None]:
# Scratch cell: score the most recent generated batch with the discriminator and semi classifier.
# Relies on `x_gen` and `net_weights` left in the kernel by the training-loop cell above, so it
# only works after that cell has run.
_, gen_discrim_logits,gen_semi_logits = pred(x_gen, weights=net_weights)

In [None]:
class Meta(nn.Module):
    """Meta-learning wrapper around the notebook's classifier, discriminator and generator nets.

    The constructor binds the notebook-level ``nway_classifier``, ``disrim``,
    ``semi_nway_classifier`` and ``generator`` modules and builds one Adam
    optimiser over all of their parameters.

    NOTE(review): ``single_task_forward`` is still unfinished. It references
    ``y_gen``, ``class_image_embeddings``, ``self.learned_lrs``,
    ``self.update_weights``, ``self.update_weights_learned_lr`` and a conditional
    ``self.generator(x, y, ...)`` call, none of which are defined anywhere in the
    visible notebook, and it writes ``corrects`` keys that are never created.
    Those spots are marked below and left for the author to resolve; the helper
    methods (``pred``, ``get_num_corrects``, ``loss_cross_entropy``) are usable.
    """

    def __init__(self, args, config):
        """
        Args:
            args: mapping of hyper-parameters (the notebook's ``args`` dict).
            config: not used; kept so existing call sites keep working.
        """
        super().__init__()
        self.update_lr = args["update_lr"]
        self.meta_lr = args["meta_lr"]
        self.n_way = args["n_way"]
        self.k_spt = args["k_spt"]
        self.k_qry = args["k_qry"]
        self.task_num = args["task_num"]
        # The notebook's args dict spells these keys in the plural; the singular
        # spelling that was looked up before is still accepted as a fallback.
        self.update_steps = args["update_steps"] if "update_steps" in args else args["update_step"]
        self.update_steps_test = (
            args["update_steps_test"] if "update_steps_test" in args else args["update_step_test"]
        )
        # Singular aliases: the attribute names this class originally exposed.
        self.update_step = self.update_steps
        self.update_step_test = self.update_steps_test
        self.learn_inner_lr = args["learn_inner_lr"]
        self.nway_net = nway_classifier
        self.discrim_net = disrim
        self.semi_net = semi_nway_classifier
        self.generator_net = generator

        cuda = torch.cuda.is_available()
        self.FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

        # Was self.nway_classifier / self.disrim / self.generator, attributes that are never set.
        params = list(self.nway_net.parameters()) + list(self.discrim_net.parameters()) + list(self.semi_net.parameters())
        params += list(self.generator_net.parameters())

        # args is subscripted everywhere else, so attribute access (args.loss) was inconsistent
        # and fails on a plain dict.
        self.create_graph = args["create_graph"]
        self.meta_optim = optim.Adam(params, lr=self.meta_lr)
        self.loss = args["loss"]

        self.real_val = 1.0 # requires that real_val > fake_val
        self.fake_val = 0.0
        # NOTE(review): an (n_way + 1)-output classifier predicts indices 0..n_way, so a target
        # of n_way + 1 can never match an argmax — confirm the intended distractor label.
        self.distractor_val = float(self.n_way + 1)

    def pred(self, x, weights=(None, None, None), nets=None, nway=True, discrim=True, conditions=None):
        """Run ``x`` through the n-way, discriminator and semi networks.

        The networks defined in this notebook take a single positional input,
        so they are called as ``net(x)``; the ``vars=`` / ``bn_training=`` /
        ``conditions=`` keyword calls used before are not accepted by them.

        Args:
            x: input batch.
            weights: accepted for call compatibility; not used.
            nets: optional ``(nway_net, discrim_net, semi_net)`` triple of
                callables; defaults to the modules bound in ``__init__``.
            nway: when false, ``class_logits`` is returned as ``None``.
            discrim: when false, ``discrim_logits`` is returned as ``None``.
            conditions: accepted for call compatibility; not used.

        Returns:
            Tuple ``(class_logits, discrim_logits, semi_class_logits)``.
        """
        if nets is None:
            nets = (self.nway_net, self.discrim_net, self.semi_net)
        nway_net, discrim_net, semi_net = nets

        discrim_logits = discrim_net(x) if discrim else None
        class_logits = nway_net(x) if nway else None
        # Previously computed with nway_net, which made the semi logits a copy of the n-way logits.
        semi_class_logits = semi_net(x)

        return class_logits, discrim_logits, semi_class_logits

    def get_num_corrects(self, unlabel, y, x=None, weights=None, class_logits=None, discrim_logits=None, conditions=None, semi_logits=None, nets=None):
        """Count predictions whose argmax over dim 1 equals ``y``.

        Args:
            unlabel: when truthy, the n-way count is skipped and returned as ``None``.
            y: target tensor compared element-wise with the argmax predictions.
            x: input batch; only used when no logits are supplied, in which
                case ``self.pred`` is called to obtain them.
            weights: forwarded to ``self.pred`` (which ignores it).
            class_logits: pre-computed n-way logits.
            discrim_logits, conditions: accepted for call compatibility; not used.
            semi_logits: pre-computed semi logits; falls back to ``class_logits``.
            nets: forwarded to ``self.pred``.

        Returns:
            Tuple ``(nway_correct, semi_correct)`` of Python ints
            (``nway_correct`` is ``None`` when ``unlabel`` is truthy).
        """
        with torch.no_grad():
            if class_logits is None and semi_logits is None:
                class_logits, discrim_logits, semi_logits = self.pred(x, weights=weights, nets=nets)
            if semi_logits is None:
                semi_logits = class_logits

            if unlabel:
                nway_correct = None
            else:
                nway_correct = torch.eq(class_logits.argmax(dim=1), y).sum().item()
            semi_correct = torch.eq(semi_logits.argmax(dim=1), y).sum().item()

        return nway_correct, semi_correct

    def loss_cross_entropy(self, class_logits, y_class, discrim_logits=None, y_discrim=None, distractor=False):
        """Cross-entropy on the class logits, plus optional BCE on the discriminator logits.

        Args:
            class_logits: logits passed to ``F.cross_entropy``.
            y_class: class targets.
            discrim_logits: optional discriminator logits.
            y_discrim: targets for ``F.binary_cross_entropy_with_logits``.
            distractor: accepted for call compatibility; not used. It now has a
                default because a non-default parameter after defaulted ones is
                a syntax error.

        Returns:
            ``nway_loss`` alone when ``discrim_logits`` is ``None``, otherwise
            the tuple ``(nway_loss, discrim_loss)``.
        """
        nway_loss = F.cross_entropy(class_logits, y_class)
        if discrim_logits is None:
            return nway_loss

        discrim_loss = F.binary_cross_entropy_with_logits(discrim_logits, y_discrim)

        return nway_loss, discrim_loss

    def single_task_forward(self, x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry, nets=None,generator_net = None, images=False):
        """Adapt copies of the networks on one task and report query performance.

        NOTE(review): unfinished — see the class docstring. Execution currently
        stops at the first line that uses ``y_gen``.

        Args:
            x_spt, y_spt: support images and labels.
            x_qry, y_qry: query images and labels.
            unlbl_x_spt, unlbl_x_qry: unlabeled support / query images.
            nets, generator_net: ignored; task-local deep copies are always made.
            images: when true, the last generated batch is returned as well.

        Returns:
            ``(loss_q, corrects)`` or ``(loss_q, corrects, x_gen)``.
        """
        corrects = {key: np.zeros(self.update_steps + 1) for key in
                        [
                        "q_nway", # number of meta-test (query) images correctly classified
                        "q_n+1_nway",
                        "distractor_n+1_way",
                        "nway_loss",
                        "n_1_way_loss",
                        "gen_loss", # number of generated images correctly discriminated
                        "discrim_loss"
                        ]}

        # Was read without ever being assigned in this method.
        support_sz = x_spt.size(0)

        semi_net = deepcopy(self.semi_net)
        nway_net = deepcopy(self.nway_net)
        discrim_net = deepcopy(self.discrim_net)
        gen_net = deepcopy(self.generator_net)
        # Order must match the optimisers below (index 0: n-way, 1: discriminator, 2: semi);
        # the tuple previously started with an undefined `shared_net`.
        nets = (nway_net, discrim_net, semi_net)
        # Materialise the parameter generators so they can be iterated more than once.
        net_weights = [list(net.parameters()) for net in nets]
        gen_weights = list(gen_net.parameters())

        inner_g_optim = optim.Adam(gen_weights, 2e-4, betas=(0.5, 0.999))
        inner_n_way_optim = optim.Adam(net_weights[0], 2e-4, betas=(0.5, 0.999))
        inner_d_optim = optim.Adam(net_weights[1], 2e-4, betas=(0.5, 0.999))
        inner_unlabel_optim = optim.Adam(net_weights[2], 2e-4, betas=(0.5, 0.999))

        # this is the meta-test loss and accuracy before first update
        # (`real=` is not a parameter of get_num_corrects; `unlabel=` is.)
        q_nway, q_semi_way = self.get_num_corrects(unlabel=False, y=y_qry, weights=net_weights, x=x_qry, nets=nets)
        corrects['q_nway'][0] += q_nway
        corrects['q_n+1_nway'][0] += q_semi_way
        # One target per unlabeled query image, so the comparison with argmax is element-wise
        # (a (support_sz, 1) target would broadcast against the (N,) predictions).
        y_distractor = torch.full((unlbl_x_qry.size(0),), self.distractor_val, device=unlbl_x_qry.device)
        _, distractor_semi_way = self.get_num_corrects(unlabel=True, y=y_distractor, weights=net_weights, x=unlbl_x_qry, nets=nets)
        corrects['distractor_n+1_way'][0] += distractor_semi_way

        real = torch.full((support_sz, 1), self.real_val, device=x_spt.device)
        fake = torch.full((support_sz, 1), self.fake_val, device=x_spt.device)

        for k in range(1, self.update_steps + 1):
            noise = torch.randn(support_sz, 100, 1, 1, device=x_spt.device)
            # Generator takes the noise only; it has no `vars=` argument.
            x_gen = gen_net(noise)

            real_class_logits, real_discrim_logits,real_semi_logits = self.pred(x_spt, weights=net_weights, nets=nets)
            gen_class_logits, gen_discrim_logits,gen_semi_logits = self.pred(x_gen, weights=net_weights, nets=nets)
            # Previously assigned to the gen_* names, silently overwriting the generated-image logits.
            unlbl_class_logits, unlbl_discrim_logits,unlbl_semi_logits = self.pred(unlbl_x_spt, weights=net_weights, nets=nets)

            real_nway_loss, real_discrim_loss = self.loss_cross_entropy(real_class_logits, y_spt, real_discrim_logits, real, distractor=False)
            # NOTE(review): `y_gen` is not defined anywhere — the generator here is unconditional
            # and produces no labels. Decide which target generated images should get.
            gen_nway_loss, gen_discrim_loss = self.loss_cross_entropy(gen_class_logits, y_gen, gen_discrim_logits, fake, distractor=True)

            nway_loss = (gen_nway_loss + real_nway_loss) / 2
            discrim_loss = (gen_discrim_loss + real_discrim_loss) / 2
            shared_loss = nway_loss + discrim_loss  #

            gen_loss = -1 * torch.nn.functional.logsigmoid(gen_discrim_logits).mean() #- gen_discrim_loss

            # NOTE(review): update_weights_learned_lr, update_weights and learned_lrs are not defined
            # on this class; the inner_*_optim optimisers created above are never stepped.
            net_losses = (shared_loss, nway_loss, discrim_loss)
            if self.learn_inner_lr:
                net_weights, gen_weights = self.update_weights_learned_lr(net_losses, net_weights, gen_loss, gen_weights, self.learned_lrs[k-1])
            else:
                net_weights, gen_weights = self.update_weights(net_losses, net_weights, gen_loss, gen_weights)

            # NOTE(review): `class_image_embeddings` and `gen_nway_correct` are undefined here, and
            # the "gen_nway" / "gen_discrim" / "q_discrim" keys do not exist in `corrects`.
            _, gen_discrim_correct = self.get_num_corrects(unlabel=True, y=y_gen, class_logits=gen_class_logits, discrim_logits=gen_discrim_logits, conditions=class_image_embeddings)

            corrects["gen_nway"][k-1] += gen_nway_correct
            corrects["gen_discrim"][k-1] += gen_discrim_correct

            # meta-test nway and discrim accuracy
            # [query_sz]
            q_nway_correct, q_discrim_correct = self.get_num_corrects(unlabel=False, y=y_qry, x=x_qry, weights=net_weights, conditions=class_image_embeddings, nets=nets)
            corrects['q_nway'][k] += q_nway_correct
            corrects['q_discrim'][k] += q_discrim_correct


        # final gen-discrim and gen-nway accuracy
        with torch.no_grad():
            # NOTE(review): `self.generator` is never set and the bound generator takes only a
            # noise tensor; this conditional call cannot work as written.
            x_gen, y_gen = self.generator(x_spt, y_spt, vars=gen_weights, bn_training=False)
            gen_nway_correct, gen_discrim_correct = self.get_num_corrects(unlabel=True, y=y_gen, x=x_gen, weights=net_weights, conditions=class_image_embeddings, nets=nets)

            corrects['gen_nway'][-1] += gen_nway_correct
            corrects['gen_discrim'][-1] += gen_discrim_correct

        # meta-test loss
        # pred returns three values; unpacking into two raised ValueError.
        q_class_logits, _, _ = self.pred(x_qry, weights=net_weights, nets=nets, discrim=False)
        loss_q = self.loss_cross_entropy(q_class_logits, y_qry) # doesn't use discrim loss

        if images:
            return loss_q, corrects, x_gen
        else:
            return loss_q, corrects
#     def forward(self, x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry):

#         task_num, setsz, c_, h, w = x_spt.size()
#         querysz = x_qry.size(1)

#         losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i
#         corrects = [0 for _ in range(self.update_step + 1)]

#         for i in range(task_num):

#             # 1. run the i-th task and compute loss for k=0
#             logits = self.net(x_spt[i], vars=None, bn_training=True)
#             loss = F.cross_entropy(logits, y_spt[i])
#             grad = torch.autograd.grad(loss, self.net.parameters())
#             fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

#             # this is the loss and accuracy before first update
#             with torch.no_grad():
#                 # [setsz, nway]
#                 logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
#                 loss_q = F.cross_entropy(logits_q, y_qry[i])
#                 losses_q[0] += loss_q

#                 pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#                 correct = torch.eq(pred_q, y_qry[i]).sum().item()
#                 corrects[0] = corrects[0] + correct

#             # this is the loss and accuracy after the first update
#             with torch.no_grad():
#                 # [setsz, nway]
#                 logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
#                 loss_q = F.cross_entropy(logits_q, y_qry[i])
#                 losses_q[1] += loss_q
#                 # [setsz]
#                 pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#                 correct = torch.eq(pred_q, y_qry[i]).sum().item()
#                 corrects[1] = corrects[1] + correct

#             for k in range(1, self.update_step):
#                 # 1. run the i-th task and compute loss for k=1~K-1
#                 logits = self.net(x_spt[i], fast_weights, bn_training=True)
#                 loss = F.cross_entropy(logits, y_spt[i])
#                 # 2. compute grad on theta_pi
#                 grad = torch.autograd.grad(loss, fast_weights)
#                 # 3. theta_pi = theta_pi - train_lr * grad
#                 fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

#                 logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
#                 # loss_q will be overwritten and just keep the loss_q on last update step.
#                 loss_q = F.cross_entropy(logits_q, y_qry[i])
#                 losses_q[k + 1] += loss_q

#                 with torch.no_grad():
#                     pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#                     correct = torch.eq(pred_q, y_qry[i]).sum().item()  # convert to numpy
#                     corrects[k + 1] = corrects[k + 1] + correct

#         # optimize theta parameters
#         loss_q = losses_q[-1] / task_num
#         self.meta_optim.zero_grad()
#         loss_q.backward()
#         self.meta_optim.step()


#         accs = np.array(corrects) / (querysz * task_num)

#         return accs


#     def finetunning(self, x_spt, y_spt, x_qry, y_qry):
#         """

#         :param x_spt:   [setsz, c_, h, w]
#         :param y_spt:   [setsz]
#         :param x_qry:   [querysz, c_, h, w]
#         :param y_qry:   [querysz]
#         :return:
#         """
#         assert len(x_spt.shape) == 4

#         querysz = x_qry.size(0)

#         corrects = [0 for _ in range(self.update_step_test + 1)]

#         # in order to not ruin the state of running_mean/variance and bn_weight/bias
#         # we finetunning on the copied model instead of self.net
#         net = deepcopy(self.net)

#         # 1. run the i-th task and compute loss for k=0
#         logits = net(x_spt)
#         loss = F.cross_entropy(logits, y_spt)
#         grad = torch.autograd.grad(loss, net.parameters())
#         fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))

#         # this is the loss and accuracy before first update
#         with torch.no_grad():
#             # [setsz, nway]
#             logits_q = net(x_qry, net.parameters(), bn_training=True)
#             # [setsz]
#             pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#             # scalar
#             correct = torch.eq(pred_q, y_qry).sum().item()
#             corrects[0] = corrects[0] + correct

#         # this is the loss and accuracy after the first update
#         with torch.no_grad():
#             # [setsz, nway]
#             logits_q = net(x_qry, fast_weights, bn_training=True)
#             # [setsz]
#             pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#             # scalar
#             correct = torch.eq(pred_q, y_qry).sum().item()
#             corrects[1] = corrects[1] + correct

#         for k in range(1, self.update_step_test):
#             # 1. run the i-th task and compute loss for k=1~K-1
#             logits = net(x_spt, fast_weights, bn_training=True)
#             loss = F.cross_entropy(logits, y_spt)
#             # 2. compute grad on theta_pi
#             grad = torch.autograd.grad(loss, fast_weights)
#             # 3. theta_pi = theta_pi - train_lr * grad
#             fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

#             logits_q = net(x_qry, fast_weights, bn_training=True)
#             # loss_q will be overwritten and just keep the loss_q on last update step.
#             loss_q = F.cross_entropy(logits_q, y_qry)

#             with torch.no_grad():
#                 pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#                 correct = torch.eq(pred_q, y_qry).sum().item()  # convert to numpy
#                 corrects[k + 1] = corrects[k + 1] + correct
                
#         del net

#         accs = np.array(corrects) / querysz

#         return accs

In [None]:
# mamlGAN = MetaGAN(args, shared_config, nway_config, discriminator_config, gen_config).to(device)

In [None]:
# tmp = filter(lambda x: x.requires_grad, mamlGAN.parameters())
# num = sum(map(lambda x: np.prod(x.shape), tmp))
# print(mamlGAN)
# print('Total trainable tensors:', num)

In [None]:
# Meta-training driver: iterate over task batches, log every 30 steps, evaluate every 500 steps.
# NOTE(review): `mini`, `mini_test`, `mamlGAN`, `save_model`, `save_accs` and `path` are not
# defined in any visible cell (the mamlGAN construction above is commented out), so this cell
# cannot run yet. Only the defects that can be fixed locally are fixed here:
#   * `args` is a dict, so it is subscripted instead of using attribute access;
#   * all six task tensors are moved to the device (six names were unpacked from four values).
for epoch in range(args["epoch"]//10000):
    # fetch meta_batchsz num of episode each time
    db = DataLoader(mini, args["tasks_per_batch"], shuffle=True, num_workers=1, pin_memory=True)

    for step, (x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry) in enumerate(db):

        x_spt, y_spt, x_qry, y_qry, unlbl_x_spt, unlbl_x_qry = (
            t.to(device) for t in (x_spt, y_spt, x_qry, y_qry, unlbl_x_spt, unlbl_x_qry)
        )
        accs = mamlGAN(x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry)

        if step % 30 == 0:
            print("step " + str(step))
            for key in accs.keys():
                print(key + ": " + str(accs[key]))
            if save_model:
                save_accs(path, accs)

        if step % 500 == 0:  # evaluation
            db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
            accs_all_test = []
            imgs_all_test = []
            for x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry in db_test:
                # Drop the size-1 batch axis added by the DataLoader before moving to the device.
                x_spt, y_spt, x_qry, y_qry, unlbl_x_spt, unlbl_x_qry = (
                    t.squeeze(0).to(device)
                    for t in (x_spt, y_spt, x_qry, y_qry, unlbl_x_spt, unlbl_x_qry)
                )

                accs, imgs = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry,unlbl_x_spt,unlbl_x_qry)
                accs_all_test.append(accs)
                imgs_all_test.append(imgs.cpu().detach().numpy())

            imgs_all_test = np.array(imgs_all_test)

            if save_model:
                save_imgs(path, imgs_all_test, step)

            # NOTE(review): prints the accuracies of the last test task only; `accs_all_test`
            # is collected but never aggregated — confirm whether a mean was intended.
            print('Test acc:', accs)