In [2]:
import  torch
from    torch import nn
from    torch import optim
from    torch.nn import functional as F
from    torch.utils.data import TensorDataset, DataLoader
from    torch import optim
import  numpy as np
from torch.autograd import Variable

from   models.learner import Learner
from models.generator import Generator
from    copy import deepcopy
import os
from torchsummary import summary

from utils.dataloader import train_data_gen , test_data_gen
import shutil
import matplotlib.pyplot as plt
# import tqdm.notebook as tqdm
from torch.utils.tensorboard import SummaryWriter
import json
import tqdm

In [3]:
with open('configs/0613_maml_gen.json') as json_file:
    args = json.load(json_file)

In [4]:
print(args)

{'epoch': 96000, 'n_way': 5, 'k_spt': 5, 'k_qry': 15, 'img_sz': 84, 'tasks_per_batch': 3, 'img_c': 3, 'meta_gen_lr': 0.0005, 'meta_discrim_lr': 0.0001, 'update_lr': 0.004, 'update_steps': 2, 'update_steps_test': 2, 'loss': 'cross_entropy', 'min_learning_rate': 1e-15, 'number_of_training_steps_per_iter': 4, 'multi_step_loss_num_epochs': 15, 'spy_gen_num': 5, 'qry_gen_num': 25, 'num_distractor': 0, 'batch_for_gradient': 50, 'no_save': 0, 'learn_inner_lr': 0, 'create_graph': 0, 'msl': 0, 'single_fast_test': 0, 'consine_schedule': 0, 'save_path': '0612_maml_gen'}


In [5]:
# if os.path.exists("images/" + args["save_path"]):
#     shutil.rmtree("images/" + args["save_path"])
    
# if os.path.exists("data/" + args["save_path"]):
#     shutil.rmtree("data/" + args["save_path"])
    
# if os.path.exists("save_models/" + args["save_path"]):
#     shutil.rmtree("save_models/" + args["save_path"])
    
# if os.path.exists("runs/" + args["save_path"]):
#     shutil.rmtree("runs/" + args["save_path"])    

writer = SummaryWriter('runs/' + args["save_path"])

In [6]:
def mkdir_p(path):
    if not os.path.exists("images/" + path):
        os.makedirs("images/" + path)
        
    if not os.path.exists("data/" + path):
        os.makedirs("data/" + path)
        
    if not os.path.exists("save_models/" + path):
        os.makedirs("save_models/" + path)        

def save_imgs(path, imgs, step):

    some_imgs = np.reshape(imgs, [imgs.shape[0]*imgs.shape[1], -1])[0:50]

    # save png of imgs
    i = 0
    for flat_img in some_imgs:
        img = flat_img.reshape(3,84,84).swapaxes(0,1).swapaxes(1,2)
        im = ((img - np.min(img))*255/(np.max(img - np.min(img)))).astype(np.uint8)
        if i < 15:
            plt.subplot(5, 3, i+1)
            plt.axis('off')
            plt.imshow(im)
        i += 1
    plt.savefig("images/" + path + "/images_step" + str(step) + ".png")
    plt.close()

In [7]:
train_data_generator = train_data_gen(args)
test_data_generator = test_data_gen(args)

load datasets/BelgiumTSC
load complete time 0.33933115005493164
load datasets/ArTS
load complete time 0.3034861087799072
load datasets/chinese_traffic_sign
load complete time 0.6109473705291748
load datasets/CVL
load complete time 0.42543721199035645
load datasets/FullJCNN2013
load complete time 0.22013282775878906
load datasets/logo_2k
load complete time 0.9927217960357666
load datasets/GTSRB
load complete time 0.0888059139251709
load datasets/DFG
load complete time 0.03560972213745117


In [8]:
ndf = 64
discriminator_config = [
    ('conv2d', [ndf, 3, 4, 4, 2, 1]),
    ('leakyrelu', [0.2,True]),
    # ('bn', [ndf]),
    
    ('conv2d', [ndf*2, ndf, 4, 4, 2, 1]),
    ('bn', [ndf*2]),
    ('leakyrelu', [0.2,True]),

    ('conv2d', [ndf*4, ndf*2, 4, 4, 2, 1]),
    ('bn', [ndf*4]),
    ('leakyrelu', [0.2,True]),
    
    
    ('conv2d', [ndf*8, ndf*4, 4, 4, 2, 1]),
    ('bn', [ndf*8]),
    ('leakyrelu', [0.2,True]),
    
    ('conv2d', [1,ndf*8 , 2, 2, 1, 0]),
    ('flatten', []),
    ('linear',[6, 16]),
    ('softmax',[])
]
nz = 100
ngf = 64
gen_config = [
    ('convert_z',[]),
    ('convt2d',[nz,ngf*8,4,4,1,0]),
    ('bn',[ngf * 8]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*8,ngf*4,4,4,2,0]),
    ('bn',[ngf * 4]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*4,ngf*2,4,4,2,0]),
    ('bn',[ngf * 2]),
    ('leakyrelu', [.2, True]),  
    
    ('convt2d',[ngf*2,ngf,3,3,2,1]),
    ('bn',[ngf]),
    ('leakyrelu', [.2, True]),      
    
    ('convt2d',[ngf,3,3,3,2,1]),
    ('convt2d',[3,3,2,2,1,1]),
    ("tanh",[])
]

In [11]:
class Meta(nn.Module):
    """
    Meta Learner with GAN incorporated
    """
    def __init__(self, args, discriminator_config, gen_config):
        """
        :param args:
        """
        super(Meta, self).__init__()
        
        cuda = torch.cuda.is_available()
        self.FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 
        self.LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
        self.total_epochs = args["epoch"]   
        # model parameters config
        self.meta_gen_lr = args["meta_gen_lr"]
        self.meta_discrim_lr = args["meta_discrim_lr"]
        
        self.update_lr = args["update_lr"]
        self.consine_schedule = args["consine_schedule"]
        self.update_steps = args["update_steps"]
        self.update_steps_test = args["update_steps_test"]

        # dataset config
        self.img_c = args["img_c"]
        self.img_sz = args["img_sz"]        
        self.n_way = args["n_way"]
        self.k_spt = args["k_spt"]
        self.k_qry = args["k_qry"]
        self.MSL = args["msl"]
        # generator num
        self.spy_gen_num = args["spy_gen_num"]
        self.qry_gen_num = args["qry_gen_num"]
        # query gan batch
        self.batch_for_gradient = args["batch_for_gradient"]
        self.fix_noise = torch.randn(self.qry_gen_num, nz,1,1, device=device)
        self.criterion = nn.BCELoss()
        # load model
        self.generator = Generator(gen_config, self.img_c, self.img_sz)
        self.discrim_net = Learner(discriminator_config, self.img_c, self.img_sz)
        beta1 = 0.0
        beta2 = 0.0
        
        self.meta_gen_optim = optim.Adam(self.generator.parameters(), lr=self.meta_gen_lr,betas=(beta1, 0.9))
        self.meta_d_optim = optim.Adam(self.discrim_net.parameters(), lr=self.meta_discrim_lr,betas=(beta2, 0.9))
        if self.consine_schedule:
            self.min_learning_rate = 1e-8
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=self.meta_d_optim, T_max=self.total_epochs,
                                                                  eta_min=self.min_learning_rate)
        self.real_value = 1
        self.fake_value = 0
        self.discrim_fake = 5
    def cross_entropy(self, output,label):
        output = torch.log(output)
        loss = F.nll_loss(output,label)
        return loss
    
    def pred(self, x, weights=None, nets=None, nway=True, discrim=True, conditions=False):
        if weights == None:
            discrim_weights = self.discrim_net.parameters()
        else:
            discrim_weights = weights

        discrim_logits = self.discrim_net(x, vars=discrim_weights, bn_training=True) if discrim else None
          
        return discrim_logits

    def get_num_corrects(self, y, x=None, weights=None):
            
        with torch.no_grad():
            discrim_logits = self.pred(x, weights=weights)
            pred_q = discrim_logits.argmax(dim=1)

            nway_correct = torch.eq(pred_q, y).sum().item()
            
            other = torch.tensor([5]*len(discrim_logits)).cuda()
            other_correct = torch.eq(pred_q, other).sum().item()
        return nway_correct, other_correct

        
    def update_weights(self, net_losses, net_weights,learned_lrs, gen=False):
        if gen:
            update_lr = self.gen_update_lr
        else:
            update_lr = self.update_lr
        # grad = torch.autograd.grad(net_losses, net_weights, retain_graph=True, create_graph=self.create_graph)
        grad = torch.autograd.grad(net_losses, net_weights)
        weights = list(map(lambda p: p[1] - update_lr * p[0], zip(grad, net_weights)))

        return weights
    
    def meta_test(self,qry_img,qry_label,discrim_weight,gen_weight):
        ### discriminator train
        q_real_discrim_logits = self.pred(qry_img, weights=discrim_weight)

        real_discrim_loss_q = self.cross_entropy(q_real_discrim_logits, qry_label)

        discrim_fake_label = torch.full((self.qry_gen_num,), self.discrim_fake, dtype=torch.long, device=device) 
        noise = torch.randn(self.qry_gen_num, nz,1,1, device=device)
        q_gen = torch.empty(0,3,84,84).cuda()
        if self.qry_gen_num < self.batch_for_gradient:

            q_gen = self.generator(qry_img, noise , vars=gen_weight)
        else:
            for i in range(self.qry_gen_num//self.batch_for_gradient):

                noise_tmp = noise[i*self.batch_for_gradient:(i+1)*self.batch_for_gradient]

                q_gen = torch.cat([q_gen,self.generator(qry_img[i*self.batch_for_gradient:(i+1)*self.batch_for_gradient], noise_tmp , vars=gen_weight)])
        q_fake_discrim_logits = self.pred(q_gen.detach(), weights=discrim_weight)
        fake_discrim_loss_q = self.cross_entropy(q_fake_discrim_logits, discrim_fake_label)
        if torch.isnan(fake_discrim_loss_q):
            print("fake d loss error")
        if torch.isnan(real_discrim_loss_q):
            print("real d loss error")
        d_loss_q = (fake_discrim_loss_q + real_discrim_loss_q)
        
        ### generator train
        gen_fake_label = torch.full((self.qry_gen_num,), self.real_value, dtype=torch.float, device=device)
        gen_q_discrim = 1 - self.pred(q_gen, weights=discrim_weight)[:,-1]
        g_loss_q = self.criterion(gen_q_discrim, gen_fake_label)
        if torch.isnan(g_loss_q):
            print("g loss error")
            
        return d_loss_q, g_loss_q

    def single_task_forward(self, x_spt, y_spt, x_qry, y_qry, update_steps,nets=None, images=False):
        
        corrects = {key: np.zeros(update_steps + 1) for key in 
                        [
                        "query_nway", # number of meta-test (query) images correctly discriminated
                        "predict_other",
                        "gen_discrim", # number of generated images correctly discriminated
                        ]}

        support_sz, c_, h, w = x_spt.size()
        nz = 100

        discrim_weights,gen_weights = [x.parameters() for x in nets]

        # this is the meta-test loss and accuracy before first update

        q_discrim,other = self.get_num_corrects(y=y_qry, weights=None, x=x_qry)
        corrects["query_nway"][0] += q_discrim
        corrects["predict_other"][0] += other
        # run the i-th task and compute loss for k-th inner update
        query_fake_label = torch.full((self.qry_gen_num,), self.discrim_fake, dtype=torch.long, device=device)
        for k in range(1, update_steps + 1):
            ## discrim loss
            noise = torch.randn(self.spy_gen_num, nz , 1, 1, device=device)
            x_gen = self.generator(x_spt, noise , vars=gen_weights)
            
            # update discrim weight

            real_discrim_logits = self.pred(x_spt, weights=discrim_weights)

            fake_discrim_logits = self.pred(x_gen, weights=discrim_weights)

            fake_label = torch.full((self.spy_gen_num,), self.discrim_fake, dtype=torch.long, device=device)
            
            real_discrim_loss = self.cross_entropy(real_discrim_logits, y_spt)
            fake_discrim_loss = self.cross_entropy(fake_discrim_logits,fake_label)
            D_loss = fake_discrim_loss + real_discrim_loss

            discrim_weights = self.update_weights(D_loss, discrim_weights,self.update_lr) 
            
            with torch.no_grad():
                x_gen = self.generator(x_qry, self.fix_noise , vars=gen_weights) 
                gen_correct,_ = self.get_num_corrects(y=query_fake_label,x=x_gen, weights=discrim_weights)
                corrects["gen_discrim"][k-1] += gen_correct
                
                q_discrim_correct,other = self.get_num_corrects(y=y_qry, x=x_qry, weights=discrim_weights)
                corrects['query_nway'][k] += q_discrim_correct
                corrects["predict_other"][k] += other
            # meta-test nway and discrim accuracy
            # [query_sz]


        
        # final gen-discrim and gen-nway accuracy
        with torch.no_grad():
            x_gen = self.generator(x_qry, self.fix_noise , vars=gen_weights)
            gen_correct,_ = self.get_num_corrects(y=query_fake_label,x=x_gen, weights=discrim_weights)
            corrects["gen_discrim"][-1] += gen_correct
        d_loss_q, g_loss_q = self.meta_test(x_qry,y_qry,discrim_weights,gen_weights)
            
        if images:
            return d_loss_q,g_loss_q, corrects, x_gen
        else:
            return d_loss_q,g_loss_q, corrects

    def forward(self, x_spt, y_spt, x_qry, y_qry,step):
        """
        :param x_spt:   [b, support_sz, c_, h, w]
        :param y_spt:   [b, support_sz]
        :param x_qry:   [b, query_sz, c_, h, w]
        :param y_qry:   [b, query_sz]
        :return:
        """
        self.current_epoch = step 
        tasks_per_batch, support_sz, c_, h, w = x_spt.size()
        query_sz = x_qry.size(1)
        g_loss_q = 0
        d_loss_q = 0
        gen_losses_q = [0 for _ in range(self.update_steps + 1)]
        discrim_losses_q = [0 for _ in range(self.update_steps + 1)]
        corrects = {key: np.zeros(self.update_steps + 1) for key in 
                        [
                        "query_nway", # number of meta-test (query) images correctly discriminated
                        "predict_other",
                        "gen_discrim", # number of generated images correctly discriminated
                        ]}
        net = [self.discrim_net,self.generator]
        for i in range(tasks_per_batch):
            d_loss_q_tmp,g_loss_q_tmp, corrects_tmp = self.single_task_forward(x_spt[i], y_spt[i], x_qry[i], y_qry[i],self.update_steps,nets = net,images=False)
            g_loss_q += g_loss_q_tmp
            d_loss_q += d_loss_q_tmp
            assert len(corrects_tmp.keys()) == len(corrects.keys())
            for key in corrects.keys():
                corrects[key] += corrects_tmp[key]
            
        # end of all tasks
        # sum over final losses on query set across all tasks
        g_loss_q /= tasks_per_batch
        self.meta_gen_optim.zero_grad()
        g_loss_q.backward()
        self.meta_gen_optim.step()        

        # optimize theta parameters
        d_loss_q /= tasks_per_batch
        self.meta_d_optim.zero_grad()
        d_loss_q.backward()
        self.meta_d_optim.step()
        
        accs = {}
        accs["query_nway"] = corrects["query_nway"] / (tasks_per_batch * query_sz)
        accs["predict_other"] = corrects["predict_other"] / (tasks_per_batch * query_sz)
        accs["gen_discrim"] = corrects["gen_discrim"] / (tasks_per_batch * self.qry_gen_num)
        return accs,d_loss_q,g_loss_q

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        """

        :param x_spt:   [support_sz, c_, h, w]
        :param y_spt:   [support_sz]
        :param x_qry:   [query_sz, c_, h, w]
        :param y_qry:   [query_sz]
        :return:
        """

        support_sz, c_, h, w = x_spt.size()

        assert len(x_spt.shape) == 4

        query_sz = x_qry.size(0)

        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetunning on the copied model instead of self.net
        
        discrim_net = deepcopy(self.discrim_net)
        generator = deepcopy(self.generator)
        net = [self.discrim_net,self.generator]
        d_loss_q,g_loss_q, corrects, imgs = self.single_task_forward(x_spt, y_spt, x_qry, y_qry,self.update_steps_test, nets=net,images=True)

        del discrim_net
        
        accs["query_nway"] = corrects["query_nway"] / (query_sz)
        accs["predict_other"] = corrects["predict_other"] / (query_sz)
        accs["gen_discrim"] = corrects["gen_discrim"] / (self.qry_gen_num)

        return accs, imgs,d_loss_q,g_loss_q


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mamlGAN = Meta(args, discriminator_config, gen_config).to(device)
step = 0
path = args["save_path"]
mkdir_p(path)
best_acc = []

with tqdm.tqdm(initial=step,
                   total=int(args["epoch"])) as pbar_train:
    for _ in range(args["epoch"] * args["tasks_per_batch"]//6000):
        # fetch meta_batchsz num of episode each time
        train_dataloader = DataLoader(train_data_generator, args["tasks_per_batch"], shuffle=True, num_workers=4, pin_memory=True)

        for _, (x_spt, y_spt, x_qry, y_qry) in enumerate(train_dataloader):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            accs,d_loss,g_loss = mamlGAN(x_spt, y_spt, x_qry, y_qry,step)
            # accs,d_loss = mamlGAN(x_spt, y_spt, x_qry, y_qry,step)
            writer.add_scalar('Loss/train_d_loss', d_loss, step)
            writer.add_scalar('Loss/train_g_loss', g_loss, step)
            writer.add_scalar('Accuracy/query_nway', accs["query_nway"][-1], step)
            writer.add_scalar('Accuracy/gen_discrim', accs["gen_discrim"][-1], step)
            writer.add_scalar('Accuracy/predict_other', accs["predict_other"][-1],step)
            if step % 100 == 0:
                print("step " + str(step))
                print('d loss:',d_loss.item())
                print('g loss:',g_loss.item())
                print("accs",accs)


            if step % 300 == 0:  # evaluation
                db_test = DataLoader(test_data_generator, 1, shuffle=True, num_workers=4, pin_memory=True)
                accs_all_test = []
                imgs_all_test = []
                d_loss_all_test = []
                g_loss_all_test = []
                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    # accs, d_loss = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs, imgs,d_loss,g_loss = mamlGAN.finetunning(x_spt, y_spt, x_qry, y_qry)


                    accs_all_test.append(accs)
                    imgs_all_test.append(imgs.cpu().detach().numpy())
                    d_loss_all_test.append(d_loss.item())
                    g_loss_all_test.append(g_loss.item())

                imgs_all_test = np.array(imgs_all_test)
                # [b, update_step+1]
                # accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                d_loss = np.mean(np.array(d_loss_all_test))
                g_loss = np.mean(np.array(g_loss_all_test))

                print('d loss:',d_loss)
                print('g loss:',g_loss)
                print('Test acc:', accs)    

                writer.add_scalar('Loss/test_d_loss', d_loss, step)
                writer.add_scalar('Loss/test_g_loss', g_loss, step)
                writer.add_scalar('Accuracy/test_query_nway', accs["query_nway"][-1], step)
                writer.add_scalar('Accuracy/test_gen_discrim', accs["gen_discrim"][-1], step)
                writer.add_scalar('Accuracy/test_predict_other', accs["predict_other"][-1],step)
                if not len(best_acc):
                    best_acc = accs
                    best_epoch = step
                    torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/best.pth")
                else:
                    if max(accs) > max(best_acc):
                        best_acc = accs
                        best_epoch = step
                        torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/best.pth")
                torch.save({'model_state_dict': mamlGAN.state_dict()}, "save_models/" + path + "/model_step" + str(step) + ".pth")

                save_imgs(path, imgs_all_test, step)

            step = step + 1
            pbar_train.update(1)

  0%|                                                 | 0/96000 [00:00<?, ?it/s]

step 0
d loss: 2.363431215286255
g loss: 1.2237510681152344
accs {'query_nway': array([0.17333333, 0.34666667, 0.37777778]), 'predict_other': array([0.18222222, 0.17333333, 0.16444444]), 'gen_discrim': array([0.58666667, 0.77333333, 0.77333333])}
d loss: 2.6459457620978357
g loss: 0.8663842096924782
Test acc: {'query_nway': array([0.16      , 0.53333333, 0.57333333]), 'predict_other': array([0.16      , 0.22666667, 0.26666667]), 'gen_discrim': array([0.08, 0.64, 0.64])}


  0%|                                     | 101/96000 [00:23<6:41:00,  3.99it/s]

step 100
d loss: 1.7167425155639648
g loss: 3.443291187286377
accs {'query_nway': array([0.18666667, 0.23111111, 0.29333333]), 'predict_other': array([0.00888889, 0.05777778, 0.06666667]), 'gen_discrim': array([0.90666667, 1.        , 1.        ])}


  0%|                                     | 202/96000 [00:39<3:35:30,  7.41it/s]

step 200
d loss: 1.26929771900177
g loss: 4.7000532150268555
accs {'query_nway': array([0.23111111, 0.41333333, 0.53777778]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  0%|                                     | 300/96000 [00:57<4:09:59,  6.38it/s]

step 300
d loss: 1.1243021488189697
g loss: 4.804119110107422
accs {'query_nway': array([0.24      , 0.43111111, 0.55111111]), 'predict_other': array([0.        , 0.01333333, 0.00888889]), 'gen_discrim': array([0.92, 1.  , 1.  ])}
d loss: 1.3644985228776931
g loss: 3.9385994017124175
Test acc: {'query_nway': array([0.16      , 0.38666667, 0.49333333]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  0%|▏                                    | 402/96000 [01:17<3:37:17,  7.33it/s]

step 400
d loss: 1.5945181846618652
g loss: 2.6161673069000244
accs {'query_nway': array([0.20444444, 0.12888889, 0.42666667]), 'predict_other': array([0.12888889, 0.63555556, 0.        ]), 'gen_discrim': array([1.        , 0.97333333, 0.97333333])}


  1%|▏                                    | 502/96000 [01:34<3:31:54,  7.51it/s]

step 500
d loss: 1.6586192846298218
g loss: 5.451716423034668
accs {'query_nway': array([0.20888889, 0.44      , 0.51555556]), 'predict_other': array([0.02666667, 0.05333333, 0.03111111]), 'gen_discrim': array([0.90666667, 0.90666667, 0.90666667])}


  1%|▏                                    | 600/96000 [01:50<4:46:03,  5.56it/s]

step 600
d loss: 1.2928227186203003
g loss: 6.15001106262207
accs {'query_nway': array([0.19111111, 0.37333333, 0.46666667]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}
d loss: 1.1821458883583547
g loss: 6.106702932715416
Test acc: {'query_nway': array([0.33333333, 0.53333333, 0.61333333]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  1%|▎                                    | 702/96000 [02:11<3:37:55,  7.29it/s]

step 700
d loss: 1.6518462896347046
g loss: 4.654877662658691
accs {'query_nway': array([0.03555556, 0.25333333, 0.38222222]), 'predict_other': array([0.57333333, 0.02222222, 0.06666667]), 'gen_discrim': array([0.76      , 0.98666667, 0.98666667])}


  1%|▎                                    | 802/96000 [02:27<4:23:30,  6.02it/s]

step 800
d loss: 1.1354906558990479
g loss: 4.807629108428955
accs {'query_nway': array([0.20888889, 0.45333333, 0.56888889]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  1%|▎                                    | 900/96000 [02:46<4:09:04,  6.36it/s]

step 900
d loss: 1.508542537689209
g loss: 4.495633125305176
accs {'query_nway': array([0.11111111, 0.45777778, 0.50222222]), 'predict_other': array([0.52888889, 0.01777778, 0.05333333]), 'gen_discrim': array([0.64      , 0.90666667, 0.90666667])}
d loss: 1.3559754610061645
g loss: 3.765403465926647
Test acc: {'query_nway': array([0.        , 0.29333333, 0.68      ]), 'predict_other': array([0.85333333, 0.        , 0.        ]), 'gen_discrim': array([0.16, 0.92, 0.92])}


  1%|▍                                   | 1002/96000 [03:03<3:25:47,  7.69it/s]

step 1000
d loss: 1.0568253993988037
g loss: 4.226287841796875
accs {'query_nway': array([0.26222222, 0.48444444, 0.61333333]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  1%|▍                                   | 1102/96000 [03:18<3:34:06,  7.39it/s]

step 1100
d loss: 1.2637505531311035
g loss: 8.159392356872559
accs {'query_nway': array([0.        , 0.30666667, 0.46666667]), 'predict_other': array([1., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  1%|▍                                   | 1200/96000 [03:35<5:19:15,  4.95it/s]

step 1200
d loss: 1.0161058902740479
g loss: 5.44084358215332
accs {'query_nway': array([0.19555556, 0.51111111, 0.59555556]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}
d loss: 1.1787633046507835
g loss: 5.65355912744999
Test acc: {'query_nway': array([0.34666667, 0.50666667, 0.73333333]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  1%|▍                                   | 1302/96000 [03:55<3:45:05,  7.01it/s]

step 1300
d loss: 1.3308796882629395
g loss: 3.334371566772461
accs {'query_nway': array([0.14222222, 0.16      , 0.48      ]), 'predict_other': array([0.11111111, 0.49333333, 0.00888889]), 'gen_discrim': array([1., 1., 1.])}


  1%|▌                                   | 1402/96000 [04:11<3:34:18,  7.36it/s]

step 1400
d loss: 1.4642432928085327
g loss: 8.723167419433594
accs {'query_nway': array([0.09777778, 0.32      , 0.41777778]), 'predict_other': array([0.04, 0.  , 0.  ]), 'gen_discrim': array([1., 1., 1.])}


  2%|▌                                   | 1500/96000 [04:29<3:33:16,  7.38it/s]

step 1500
d loss: 1.1851251125335693
g loss: 6.390166282653809
accs {'query_nway': array([0.22222222, 0.46222222, 0.55555556]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}
d loss: 1.1669722333550454
g loss: 5.367955774068832
Test acc: {'query_nway': array([0.22666667, 0.37333333, 0.52      ]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  2%|▌                                   | 1602/96000 [04:50<4:35:11,  5.72it/s]

step 1600
d loss: 0.9553803205490112
g loss: 6.40683650970459
accs {'query_nway': array([0.22222222, 0.56      , 0.61333333]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  2%|▋                                   | 1702/96000 [05:07<4:51:23,  5.39it/s]

step 1700
d loss: 1.4141197204589844
g loss: 9.670666694641113
accs {'query_nway': array([0.18222222, 0.34222222, 0.42666667]), 'predict_other': array([0.00444444, 0.00444444, 0.00444444]), 'gen_discrim': array([1., 1., 1.])}


  2%|▋                                   | 1800/96000 [05:24<4:09:54,  6.28it/s]

step 1800
d loss: 1.19931960105896
g loss: 6.67488956451416
accs {'query_nway': array([0.23555556, 0.44      , 0.56444444]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}
d loss: 1.0563895933330059
g loss: 6.353534096479416
Test acc: {'query_nway': array([0.13333333, 0.38666667, 0.46666667]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  2%|▋                                   | 1902/96000 [05:44<4:18:18,  6.07it/s]

step 1900
d loss: 0.7681993246078491
g loss: 7.84945821762085
accs {'query_nway': array([0.18666667, 0.47555556, 0.67555556]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  2%|▊                                   | 2001/96000 [06:04<6:18:25,  4.14it/s]

step 2000
d loss: 1.2263422012329102
g loss: 9.844486236572266
accs {'query_nway': array([0.21333333, 0.39555556, 0.52888889]), 'predict_other': array([0., 0., 0.]), 'gen_discrim': array([1., 1., 1.])}


  2%|▊                                   | 2021/96000 [06:07<3:42:12,  7.05it/s]