In [1]:
import  torch, os
import  numpy as np
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
from utils.results.dataloaders_mini_imagenet import train_data_gen , test_data_gen
# from maml_meta import Meta
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from    torch import optim
from maml_learner import Learner
from    torch.nn import functional as F
from    copy import deepcopy

import json
import shutil
import torchvision.utils as vutils

In [2]:
# Load the experiment configuration and start a fresh TensorBoard run directory.
with open("results_configs/mini_results2.json") as json_file:
    args = json.load(json_file)
print(args)
run_dir = "results_runs/" + args["save_path"]
# Clear event files from a previous run with the same save_path so curves don't overlap.
# (This used to remove "maml_runs/..." while the writer logged to "results_runs/...",
# so the directory actually being written to was never cleared.)
if os.path.exists(run_dir):
    shutil.rmtree(run_dir)
writer = SummaryWriter(run_dir)

{'epoch': 60000, 'n_way': 5, 'k_spt': 5, 'k_qry': 15, 'data_augmentation_num': 3, 'img_sz': 84, 'task_num': 4, 'img_c': 3, 'meta_lr': 0.001, 'update_lr': 0.01, 'update_step': 5, 'update_step_test': 5, 'loss': 'cross_entropy', 'number_of_training_steps_per_iter': 5, 'multi_step_loss_num_epochs': 10, 'spy_gan_num': 1, 'qry_gan_num': 5, 'num_distractor': 5, 'gan': 0, 'spy_distractor_num': 1, 'qry_distractor_num': 3, 'batch_for_gradient': 25, 'fm': 32, 'no_save': 0, 'learn_inner_lr': 0, 'create_graph': 0, 'msl': 1, 'single_fast_test': 0, 'consine_schedule': 15000, 'eta_min': 1e-06, 'save_path': 'mini_results2'}


In [3]:
def mkdir_p(path):
    """Create ``model_results/<path>`` (including parents) if it does not already exist.

    ``exist_ok=True`` replaces the old check-then-create sequence, which could race
    with another process creating the same directory in between.
    """
    os.makedirs("model_results/" + path, exist_ok=True)

In [4]:
# BASE
# Four conv blocks (conv -> LeakyReLU -> BN -> max-pool) followed by a linear head.
# The head has n_way + 1 outputs: the last index is the extra class used for
# distractor / GAN query items (previously hard-coded as 6 for n_way == 5).
fm = args["fm"]
img_c = args["img_c"]
n_classes = args["n_way"] + 1
# NOTE(review): fm * 5 * 5 assumes img_sz == 84 with unpadded 3x3 convs and the pools
# below (84 -> 41 -> 19 -> 8 -> 5), assuming Learner's conv2d param layout is
# [ch_out, ch_in, k, k, stride, padding] — recompute if img_sz changes.
config = [
    ("conv2d", [fm, img_c, 3, 3, 1, 0]),
    ("leakyrelu", [0.2, True]),
    ("bn", [fm]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [fm, fm, 3, 3, 1, 0]),
    ("leakyrelu", [0.2, True]),
    ("bn", [fm]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [fm, fm, 3, 3, 1, 0]),
    ("leakyrelu", [0.2, True]),
    ("bn", [fm]),
    ("max_pool2d", [2, 2, 0]),
    ("conv2d", [fm, fm, 3, 3, 1, 0]),
    ("leakyrelu", [0.2, True]),
    ("bn", [fm]),
    ("max_pool2d", [2, 1, 0]),
    ("flatten", []),
    ("linear", [n_classes, fm * 5 * 5]),
]

In [5]:
# Episode generators for meta-training and meta-testing, both configured from `args`
# (built left to right: train first, then test).
train_data_generator, test_data_generator = train_data_gen(args), test_data_gen(args)

load datasets/mini_imagenet
load complete time 1.764672040939331
load datasets/mini_imagenet
load complete time 0.038918495178222656


In [6]:
class Meta(nn.Module):
    """
    Meta learner: first-order MAML (inner gradients are taken without create_graph)
    with a multi-step, MAML++-style query loss.

    Query sets can optionally be extended with distractor images and/or GAN samples.
    Those extra items are labelled with class index ``n_way`` (one past the last real
    class), so the network head presumably has ``n_way + 1`` outputs; the
    ``*_recall`` metrics drop the last output column to score the real classes only.
    """
    def __init__(self, args, config):
        """
        :param args: experiment config dict (update/meta learning rates, n_way, step counts,
                     num_distractor, gan, multi-step-loss and cosine-schedule settings, ...)
        :param config: layer specification forwarded to ``Learner``
        """
        super(Meta, self).__init__()

        self.update_lr = args["update_lr"]
        self.meta_lr = args["meta_lr"]
        self.n_way = args["n_way"]
        self.k_spt = args["k_spt"]
        self.k_qry = args["k_qry"]
        self.task_num = args["task_num"]
        self.update_step = args["update_step"]
        self.update_step_test = args["update_step_test"]
        self.distractor = args["num_distractor"]
        self.net = Learner(config, args["img_c"], args["img_sz"])
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
        # The scheduler used to be a discarded local, so the cosine schedule never took
        # effect. Keep a handle; forward() steps it after every meta-update.
        self.scheduler = None
        if args["consine_schedule"]:
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.meta_optim, T_max=args["consine_schedule"], eta_min=args["eta_min"])
        self.device = torch.device("cuda")
        self.gan = args["gan"]
        self.multi_step_loss_num_epochs = args["multi_step_loss_num_epochs"]

    def get_per_step_loss_importance_vector(self):
        """
        Weights of the per-step query losses in the multi-step loss.

        Weights start uniform (1 / update_step). As ``self.current_epoch`` grows, every
        non-final weight decays linearly to a floor of ``0.03 / update_step``, and the
        final-step weight grows by the same total amount. The final weight is capped so
        that the weights keep summing to 1.

        :return: float tensor of length ``update_step`` on ``self.device``
        """
        loss_weights = np.ones(shape=(self.update_step)) * (1.0 / self.update_step)
        decay_rate = 1.0 / self.update_step / self.multi_step_loss_num_epochs
        min_value_for_non_final_losses = 0.03 / self.update_step
        for i in range(len(loss_weights) - 1):
            loss_weights[i] = np.maximum(loss_weights[i] - (self.current_epoch * decay_rate),
                                         min_value_for_non_final_losses)

        loss_weights[-1] = np.minimum(
            loss_weights[-1] + (self.current_epoch * (self.update_step - 1) * decay_rate),
            1.0 - ((self.update_step - 1) * min_value_for_non_final_losses))
        return torch.Tensor(loss_weights).to(device=self.device)

    def _extra_class_labels(self, images):
        """Label tensor for distractor/GAN images: class ``n_way`` (was hard-coded 5), one per image."""
        return torch.full((images.size(0),), self.n_way, dtype=torch.long, device=images.device)

    @staticmethod
    def _num_correct(logits, labels):
        """Number of rows whose argmax prediction equals ``labels``."""
        return torch.eq(F.softmax(logits, dim=1).argmax(dim=1), labels).sum().item()

    def _record_labelled_metrics(self, corrects, step, logits, labels):
        """
        Accumulate recall on labelled query items at index ``step``.

        ``label_query_nway_recall`` takes the argmax over all outputs (predicting the
        extra class counts as a miss). ``query_nway_recall`` drops the last output column
        first. Only keys already present in ``corrects`` are updated.
        """
        if "label_query_nway_recall" in corrects:
            corrects["label_query_nway_recall"][step] += self._num_correct(logits, labels)
        if "query_nway_recall" in corrects:
            corrects["query_nway_recall"][step] += self._num_correct(logits[:, :-1], labels)

    def _record_query_metrics(self, net, corrects, step, weights, qry_image, qry_label, x_qry, y_qry,
                              unlabel_qry=None, unlabel_label=None, gan_qry=None, gan_label=None):
        """
        Evaluate ``net`` with ``weights`` (bn_training=False) on every query source tracked
        in ``corrects`` and add the correct counts at index ``step``. Call under no_grad.

        :param qry_image / qry_label: full query set (labelled + distractor + GAN items)
        :param x_qry / y_qry: labelled query items only
        :return: logits on the full query set
        """
        total_logits = net(qry_image, weights, bn_training=False)
        corrects["total_query_nway"][step] += self._num_correct(total_logits, qry_label)
        if "query_nway_recall" in corrects:
            self._record_labelled_metrics(corrects, step, net(x_qry, weights, bn_training=False), y_qry)
        if "distractor_query_nway_recall" in corrects:
            unlabel_logits = net(unlabel_qry, weights, bn_training=False)
            corrects["distractor_query_nway_recall"][step] += self._num_correct(unlabel_logits, unlabel_label)
        if "gan_query_nway" in corrects:
            gan_logits = net(gan_qry, weights, bn_training=False)
            corrects["gan_query_nway"][step] += self._num_correct(gan_logits, gan_label)
        return total_logits

    def forward(self, x_spt, y_spt, x_qry, y_qry, current_epoch, unlabel_spt_image=None, unlabel_qry_image=None, gan_spt=None, gan_qry=None):
        """
        Run one meta-training step over a batch of tasks and update the meta-parameters.

        :param x_spt:   [b, setsz, c_, h, w]
        :param y_spt:   [b, setsz]
        :param x_qry:   [b, querysz, c_, h, w]
        :param y_qry:   [b, querysz]
        :param current_epoch: progress counter that anneals the multi-step loss weights
        :param unlabel_spt_image: per-task distractor support images; required if num_distractor is set
        :param unlabel_qry_image: per-task distractor query images; required if num_distractor is set
        :param gan_spt: per-task GAN support images; required if gan is set
        :param gan_qry: per-task GAN query images; required if gan is set
        :return: (accs, loss_q). accs maps metric name to a per-step accuracy array of
                 length update_step + 1; loss_q is the weighted meta loss.
        :raises ValueError: if an enabled optional input is missing
        """
        # Use the arguments; the old body silently read notebook globals of similar names.
        unlabel_spt, unlabel_qry = unlabel_spt_image, unlabel_qry_image
        if self.distractor and (unlabel_spt is None or unlabel_qry is None):
            raise ValueError("num_distractor is set: unlabel_spt_image and unlabel_qry_image are required")
        if self.gan and (gan_spt is None or gan_qry is None):
            raise ValueError("gan is set: gan_spt and gan_qry are required")

        self.current_epoch = current_epoch
        per_step_loss_importance_vectors = self.get_per_step_loss_importance_vector()
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)
        # Default to 0 so the total-accuracy denominator is defined in the GAN-only case.
        unlabel_querysz = unlabel_qry.size(1) if self.distractor else 0
        gan_qrysz = gan_qry.size(1) if self.gan else 0
        open_set = bool(self.distractor or self.gan)

        if open_set:
            metric_keys = ["total_query_nway"]
            if self.distractor:
                metric_keys += ["query_nway_recall", "label_query_nway_recall", "distractor_query_nway_recall"]
            if self.gan:
                metric_keys.append("gan_query_nway")
        else:
            metric_keys = ["query_nway_recall", "label_query_nway_recall"]
        corrects = {key: np.zeros(self.update_step + 1) for key in metric_keys}
        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i

        for i in range(task_num):
            spt_image, spt_label = x_spt[i], y_spt[i]
            qry_image, qry_label = x_qry[i], y_qry[i]
            unlabel_qry_i = unlabel_label = gan_qry_i = gan_label = None
            if self.distractor:
                spt_image = torch.cat((spt_image, unlabel_spt[i]))
                spt_label = torch.cat((spt_label, self._extra_class_labels(unlabel_spt[i])))
                unlabel_qry_i = unlabel_qry[i]
                unlabel_label = self._extra_class_labels(unlabel_qry_i)
                qry_image = torch.cat((qry_image, unlabel_qry_i))
                qry_label = torch.cat((qry_label, unlabel_label))
            if self.gan:
                spt_image = torch.cat((spt_image, gan_spt[i]))
                spt_label = torch.cat((spt_label, self._extra_class_labels(gan_spt[i])))
                gan_qry_i = gan_qry[i]
                gan_label = self._extra_class_labels(gan_qry_i)
                qry_image = torch.cat((qry_image, gan_qry_i))
                qry_label = torch.cat((qry_label, gan_label))
            query = (qry_image, qry_label, x_qry[i], y_qry[i], unlabel_qry_i, unlabel_label, gan_qry_i, gan_label)

            # k = 0: first inner-loop step from the meta-parameters.
            weights = list(self.net.parameters())
            logits = self.net(spt_image, vars=None, bn_training=True)
            loss = F.cross_entropy(logits, spt_label)
            grad = torch.autograd.grad(loss, weights)
            fast_weights = [w - self.update_lr * g for g, w in zip(grad, weights)]

            # Loss/accuracy before adaptation (logging only; losses_q[0] is not in the meta loss).
            with torch.no_grad():
                if open_set:
                    logits_q = self._record_query_metrics(self.net, corrects, 0, weights, *query)
                else:
                    logits_q = self.net(qry_image, weights, bn_training=True)
                    self._record_labelled_metrics(corrects, 0, logits_q, qry_label)
                losses_q[0] += F.cross_entropy(logits_q, qry_label)

            for k in range(1, self.update_step + 1):
                if k > 1:
                    # inner-loop step k on the support set
                    logits = self.net(spt_image, fast_weights, bn_training=True)
                    loss = F.cross_entropy(logits, spt_label)
                    grad = torch.autograd.grad(loss, fast_weights)
                    fast_weights = [w - self.update_lr * g for g, w in zip(grad, fast_weights)]
                # The query loss after k steps stays in the graph for the meta-update. Step 1 used
                # to be computed under no_grad, so its loss weight had no effect and
                # update_step == 1 left nothing to backpropagate.
                logits_q = self.net(qry_image, fast_weights, bn_training=True)
                losses_q[k] += F.cross_entropy(logits_q, qry_label)
                with torch.no_grad():
                    if open_set:
                        self._record_query_metrics(self.net, corrects, k, fast_weights, *query)
                    else:
                        self._record_labelled_metrics(corrects, k, logits_q, qry_label)

        # Meta loss: importance-weighted query losses of steps 1..update_step, averaged over tasks.
        loss_q = 0
        for num_step, step_loss in enumerate(losses_q[1:]):
            loss_q = loss_q + per_step_loss_importance_vectors[num_step] * step_loss / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()
        if self.scheduler is not None:
            self.scheduler.step()

        accs = {}
        if open_set:
            accs["total_query_nway"] = corrects["total_query_nway"] / (task_num * (querysz + unlabel_querysz + gan_qrysz))
            if self.distractor:
                accs["label_query_nway_recall"] = corrects["label_query_nway_recall"] / (task_num * querysz)
                accs["query_nway_recall"] = corrects["query_nway_recall"] / (task_num * querysz)
                accs["distractor_query_nway_recall"] = corrects["distractor_query_nway_recall"] / (task_num * unlabel_querysz)
            if gan_qrysz:
                accs["gan_query_nway"] = corrects["gan_query_nway"] / (task_num * gan_qrysz)
        else:
            accs["query_nway_recall"] = corrects["query_nway_recall"] / (task_num * querysz)
            accs["label_query_nway_recall"] = corrects["label_query_nway_recall"] / (task_num * querysz)
        return accs, loss_q

    def finetunning(self, x_spt, y_spt, x_qry, y_qry, unlabel_spt=None, unlabel_qry=None, gan_spt=None, gan_qry=None):
        """
        Adapt a copy of the meta-learned net to one test task and measure query accuracy per step.

        :param x_spt: [setsz, c_, h, w] support images of a single task
        :param y_spt: [setsz] support labels
        :param x_qry: [querysz, c_, h, w] labelled query images
        :param y_qry: [querysz] query labels
        :param unlabel_spt: optional distractor support images (used only if num_distractor is set)
        :param unlabel_qry: optional distractor query images; when None, no distractor metric is reported
        :param gan_spt: accepted for signature compatibility; currently unused
        :param gan_qry: accepted for signature compatibility; currently unused
        :return: dict of per-step accuracy arrays of length update_step_test + 1
        """
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        has_distractors = unlabel_qry is not None
        unlabel_querysz = unlabel_qry.size(0) if has_distractors else 0

        metric_keys = ["total_query_nway", "query_nway_recall"]
        if has_distractors:
            metric_keys.append("distractor_query_nway_recall")
        corrects = {key: np.zeros(self.update_step_test + 1) for key in metric_keys}

        # Fine-tune AND evaluate on a copy so self.net's BN running stats are untouched.
        # (The old code trained on the copy but evaluated on self.net, mixing two models' stats.)
        net = deepcopy(self.net)
        spt_image, spt_label = x_spt, y_spt
        if self.distractor and unlabel_spt is not None:
            spt_image = torch.cat((spt_image, unlabel_spt))
            spt_label = torch.cat((spt_label, self._extra_class_labels(unlabel_spt)))
        qry_image, qry_label, unlabel_label = x_qry, y_qry, None
        if has_distractors:
            unlabel_label = self._extra_class_labels(unlabel_qry)
            qry_image = torch.cat((qry_image, unlabel_qry))
            qry_label = torch.cat((qry_label, unlabel_label))
        query = (qry_image, qry_label, x_qry, y_qry, unlabel_qry, unlabel_label)

        # First inner step from the (copied) meta-parameters.
        weights = list(net.parameters())
        logits = net(spt_image)
        loss = F.cross_entropy(logits, spt_label)
        grad = torch.autograd.grad(loss, weights)
        fast_weights = [w - self.update_lr * g for g, w in zip(grad, weights)]

        with torch.no_grad():
            self._record_query_metrics(net, corrects, 0, weights, *query)       # before the first update
            self._record_query_metrics(net, corrects, 1, fast_weights, *query)  # after the first update

        for k in range(2, self.update_step_test + 1):
            logits = net(spt_image, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, spt_label)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = [w - self.update_lr * g for g, w in zip(grad, fast_weights)]
            with torch.no_grad():
                self._record_query_metrics(net, corrects, k, fast_weights, *query)
        del net

        accs = {
            "total_query_nway": corrects["total_query_nway"] / (querysz + unlabel_querysz),
            "query_nway_recall": corrects["query_nway_recall"] / querysz,
        }
        if has_distractors:
            accs["distractor_query_nway_recall"] = corrects["distractor_query_nway_recall"] / unlabel_querysz
        return accs


In [7]:
device = torch.device("cuda")
maml = Meta(args, config).to(device)

In [8]:
tmp = filter(lambda x: x.requires_grad, maml.parameters())
num = sum(map(lambda x: np.prod(x.shape), tmp))

In [None]:
# Meta-training loop: log training metrics every step; every 500 steps, evaluate
# on the test tasks, log the mean per-step accuracies, and save a checkpoint.
path = args["save_path"]
step = 0
mkdir_p(path)

# TensorBoard tag for each per-step accuracy key returned by Meta.forward.
TRAIN_TAGS = {
    "total_query_nway": "Accuracy/train_total_query_nway",
    "label_query_nway_recall": "Accuracy/train_label_query_nway_recall",
    "distractor_query_nway_recall": "Accuracy/train_distractor_query_nway_recall",
    "query_nway_recall": "Accuracy/train_query_nway_recall",
    # Meta returns "gan_query_nway"; the old check looked for "gan_query_nway_recall" and never fired.
    "gan_query_nway": "Accuracy/train_gan_query_nway_recall",
}
# TensorBoard tag for each per-step accuracy key returned by Meta.finetunning.
TEST_TAGS = {
    "total_query_nway": "Accuracy/test_total_query_nway_accuracy",
    "label_query_nway_recall": "Accuracy/test_label_query_nway_accuracy",
    "distractor_query_nway_recall": "Accuracy/test_distractor_query_nway_recall_accuracy",
    "gan_query_nway": "Accuracy/test_gan_query_nway_accuracy",
    "query_nway_recall": "Accuracy/test_query_nway_accuracy",
}

for epoch in range(args["epoch"] // 1000):
    train_dataloader = DataLoader(train_data_generator, args["task_num"], shuffle=True, num_workers=1, pin_memory=True)
    for data in train_dataloader:
        # Batches hold 4 tensors, or 6 when distractor images are included.
        batch = [t.to(device) for t in data]
        if len(batch) == 4:
            x_spt, y_spt, x_qry, y_qry = batch
            unlabel_spt = unlabel_qry = None
        else:
            x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry = batch
        # NOTE(review): the loaders never yield GAN samples, so gan_spt/gan_qry are None
        # (previously int 0 placeholders); enabling args["gan"] needs a GAN source.
        # NOTE(review): the global `step`, not `epoch`, drives multi-step-loss annealing — confirm intended.
        accs, loss_q = maml(x_spt, y_spt, x_qry, y_qry, step,
                            unlabel_spt_image=unlabel_spt, unlabel_qry_image=unlabel_qry,
                            gan_spt=None, gan_qry=None)

        writer.add_scalar("Loss/train_loss", loss_q.item(), step)
        for key, tag in TRAIN_TAGS.items():
            if key in accs:
                writer.add_scalar(tag, accs[key][-1], step)
        if step % 100 == 0:
            print("step:", step, "\ttraining acc:", accs)
        if step % 500 == 0:  # evaluation
            db_test = DataLoader(test_data_generator, 1, shuffle=True, num_workers=1, pin_memory=True)
            accs_all_test = {key: [] for key in TEST_TAGS}
            for test_data in db_test:
                # batch size is 1: drop the leading batch dimension
                test_batch = [t.squeeze(0).to(device) for t in test_data]
                if len(test_batch) == 4:
                    x_spt, y_spt, x_qry, y_qry = test_batch
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                else:
                    x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry = test_batch
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry)
                for key, values in accs.items():
                    if key in accs_all_test:
                        accs_all_test[key].append(values)

            # [n_test_tasks, update_step_test + 1] -> mean over tasks for each adaptation step.
            # Keyed on what was collected, not on the last task's `accs` as before.
            test_accs = {}
            for key, values in accs_all_test.items():
                if values:
                    test_accs[key] = np.array(values).mean(axis=0).astype(np.float16)
                    writer.add_scalar(TEST_TAGS[key], test_accs[key][-1], step)
            print("Test acc:", test_accs)

            torch.save(maml.state_dict(), "model_results/" + path + "/model_step" + str(step) + ".pt")
        step += 1

step: 0 	training acc: {'total_query_nway': array([0.16388889, 0.18888889, 0.225     , 0.24722222, 0.26111111,
       0.275     ]), 'label_query_nway_recall': array([0.18666667, 0.21      , 0.25333333, 0.28      , 0.29      ,
       0.3       ]), 'query_nway_recall': array([0.19666667, 0.24      , 0.27333333, 0.30333333, 0.31      ,
       0.32333333]), 'distractor_query_nway_recall': array([0.05      , 0.08333333, 0.08333333, 0.08333333, 0.11666667,
       0.13333333])}
Test acc: {'total_query_nway': array([0.1694, 0.1777, 0.1888, 0.1917, 0.2028, 0.1973], dtype=float16), 'query_nway_recall': array([0.2234, 0.2267, 0.2367, 0.2367, 0.2534, 0.25  ], dtype=float16), 'distractor_query_nway_recall': array([0.05   , 0.1    , 0.1333 , 0.1333 , 0.1333 , 0.11664],
      dtype=float16)}
step: 100 	training acc: {'total_query_nway': array([0.14444444, 0.22777778, 0.25833333, 0.26666667, 0.3       ,
       0.3       ]), 'label_query_nway_recall': array([0.14333333, 0.23666667, 0.27333333, 0.29    

In [None]:
# Walk the test loader and keep the tensors of its final task on `device`
# (the leading batch-of-1 dimension is squeezed away).
for test_data in db_test:
    moved = [tensor.squeeze(0).to(device) for tensor in test_data]
    if len(moved) == 4:
        x_spt, y_spt, x_qry, y_qry = moved
    else:
        x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry = moved

In [None]:
# Scratch experiment: fine-tune a copy of the meta-learned net on the last test task
# loaded above for `update_step_test` inner steps with a small inner learning rate, and
# record the number of correct full-query predictions after each step.
update_step_test = 10
inner_lr = 0.0001  # hand-picked for this experiment (training uses args["update_lr"])
corrects = {"total_query_nway": np.zeros(update_step_test + 1)}

# Work on a copy: the support passes below run with bn_training=True, and the old
# `net = maml.net` let them overwrite the live model's BN running statistics.
net = deepcopy(maml.net)
spt_image, spt_label = x_spt, y_spt
qry_image, qry_label = x_qry, y_qry


def record_total_accuracy(step_idx, weights):
    """Add the correct full-query predictions at ``step_idx``; return the query loss (no grad)."""
    with torch.no_grad():
        logits_q = net(qry_image, weights, bn_training=False)
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        corrects["total_query_nway"][step_idx] += torch.eq(pred_q, qry_label).sum().item()
        return F.cross_entropy(logits_q, qry_label)


# loss and accuracy before the first update
loss_q = record_total_accuracy(0, net.parameters())

logits = net(spt_image)
loss = F.cross_entropy(logits, spt_label)
grad = torch.autograd.grad(loss, net.parameters())
fast_weights = [w - inner_lr * g for g, w in zip(grad, net.parameters())]
# loss and accuracy after the first update
loss_q = record_total_accuracy(1, fast_weights)

for k in range(1, update_step_test):
    logits = net(spt_image, fast_weights, bn_training=True)
    loss = F.cross_entropy(logits, spt_label)
    grad = torch.autograd.grad(loss, fast_weights)
    fast_weights = [w - inner_lr * g for g, w in zip(grad, fast_weights)]
    # (removed: a graph-building query loss that was immediately overwritten)
    loss_q = record_total_accuracy(k + 1, fast_weights)

In [None]:
# Per-step counts of correct full-query predictions from the scratch fine-tuning cell.
corrects

In [None]:
import  torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Undo the per-channel normalisation in two passes: first divide out the std
# (Normalize with std = 1/std multiplies by std), then add back the mean.
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
])
inv_tensor = invTrans(qry_image)
# One figure per query image; CHW is transposed to HWC for matplotlib.
for img in inv_tensor:
    plt.figure()
    plt.imshow(img.cpu().numpy().transpose(1, 2, 0))
    plt.show()
    # break

In [None]:
# Shape of the last de-normalised image plotted (CHW, given the transpose(1, 2, 0) used for display).
img.size()

In [None]:
# Re-inspect the per-step correct counts after plotting.
corrects

In [None]:
import numpy as np
values = np.array([1, 2, 3, 1, 2, 4, 5, 6, 3, 2, 1])
searchval = 3
# Indices of every element equal to `searchval` (identical to np.where(...)[0] for 1-D input).
np.flatnonzero(values == searchval)

In [None]:
# Scratch check: pick one of two candidate values at random (result not stored).
random.choice([5, 11])

In [None]:
# Inspect a raw support tensor (presumably still normalised; see the inverse transform cells).
x_spt[0][5]

In [None]:
import  torchvision.transforms as transforms
import matplotlib.pyplot as plt
# Plot the de-normalised support images of one task.
# Assumes x_spt is batched as [task, setsz, c, h, w] — TODO confirm for the current kernel state.
# (The old loop built `inv_tensor` from task 0 but plotted raw, still-normalised images
# from task 2, and overwrote the loop index `i` with the image tensor.)
task_idx = 0
invTrans = transforms.Compose([
    transforms.Normalize(mean=[0., 0., 0.], std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
])
inv_tensor = invTrans(x_spt[task_idx])
for i, image in enumerate(inv_tensor):
    print(i)
    plt.figure()
    plt.imshow(image.cpu().numpy().transpose(1, 2, 0))
    plt.show()
    # break

In [None]:
# Inspect the first raw support tensor (presumably still normalised).
x_spt[0][0]