In [1]:
import  torch, os
import  numpy as np
import  scipy.stats
from    torch.utils.data import DataLoader
from    torch.optim import lr_scheduler
import  random, sys, pickle
from utils.dataloader import train_data_gen , test_data_gen
# from maml_meta import Meta
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from    torch import optim
from maml_learner import Learner
from    torch.nn import functional as F
from    copy import deepcopy
import json
import shutil

In [2]:
with open('maml_configs/5way1shot3distractor.json') as json_file:
    args = json.load(json_file)
print(args)
# if os.path.exists('maml_runs/' + args["save_path"]):
#     shutil.rmtree('maml_runs/' + args["save_path"])
writer = SummaryWriter('maml_runs/' + args["save_path"])

{'epoch': 96000, 'n_way': 5, 'k_spt': 1, 'k_qry': 15, 'img_sz': 84, 'task_num': 5, 'img_c': 3, 'meta_lr': 0.001, 'update_lr': 0.01, 'update_step': 5, 'update_step_test': 10, 'loss': 'cross_entropy', 'min_learning_rate': 1e-15, 'number_of_training_steps_per_iter': 4, 'multi_step_loss_num_epochs': 15, 'spy_gen_num': 5, 'qry_gen_num': 25, 'num_distractor': 3, 'spy_distractor_num': 1, 'qry_distractor_num': 15, 'batch_for_gradient': 25, 'no_save': 0, 'learn_inner_lr': 0, 'create_graph': 0, 'msl': 0, 'single_fast_test': 0, 'consine_schedule': 0, 'save_path': '5way1shot5distractor'}


In [3]:
def mkdir_p(path):
        
    if not os.path.exists("maml/" + path):
        os.makedirs("maml/" + path)     

In [4]:
spt_size = args["k_spt"] * args["n_way"]
qry_size = args["k_qry"] * args["n_way"]

In [5]:
config = [
    ('conv2d', [32, 3, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 2, 0]),
    ('conv2d', [32, 32, 3, 3, 1, 0]),
    ('relu', [True]),
    ('bn', [32]),
    ('max_pool2d', [2, 1, 0]),
    ('flatten', []),
    ('linear', [6, 32 * 5 * 5])
]

In [6]:
train_data_generator = train_data_gen(args)
test_data_generator = test_data_gen(args)

load datasets/BelgiumTSC
load complete time 0.46569180488586426
load datasets/ArTS
load complete time 0.45115089416503906
load datasets/chinese_traffic_sign
load complete time 0.8656120300292969
load datasets/CVL
load complete time 0.5910384654998779
load datasets/FullJCNN2013
load complete time 0.3078114986419678
load datasets/logo_2k
load complete time 1.0549578666687012
load datasets/GTSRB
load complete time 0.1003561019897461
load datasets/DFG
load complete time 0.03681230545043945


In [7]:
train_dataloader = DataLoader(train_data_generator, args["task_num"], shuffle=True, num_workers=1, pin_memory=True)

In [8]:
class Meta(nn.Module):
    """
    Meta Learner
    """
    def __init__(self, args, config):
        """

        :param args:
        """
        super(Meta, self).__init__()

        self.update_lr = args["update_lr"]
        self.meta_lr = args["meta_lr"]
        self.n_way = args["n_way"]
        self.k_spt = args["k_spt"]
        self.k_qry = args["k_qry"]
        self.task_num = args["task_num"]
        self.update_step = args["update_step"]
        self.update_step_test = args["update_step_test"]

        self.distractor = args["num_distractor"]
        self.net = Learner(config, args["img_c"], args["img_sz"])
        self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)
        self.device = torch.device('cuda')

    def forward(self, x_spt, y_spt, x_qry, y_qry, unlabel_spt_image=None, unlabel_qry_image=None):
        """

        :param x_spt:   [b, setsz, c_, h, w]
        :param y_spt:   [b, setsz]
        :param x_qry:   [b, querysz, c_, h, w]
        :param y_qry:   [b, querysz]
        :return:
        """
        task_num, setsz, c_, h, w = x_spt.size()
        querysz = x_qry.size(1)
        if self.distractor:
            unlabel_querysz = unlabel_qry.size(1)
            corrects = {key: np.zeros(self.update_step + 1) for key in 
                            [
                            "total_query_nway",
                            "label_query_nway",
                            "unlabel_query_nway"
                            ]}
        else:
            corrects = {key: np.zeros(self.update_step + 1) for key in 
                [
                "query_nway"
                ]}
        losses_q = [0 for _ in range(self.update_step + 1)]  # losses_q[i] is the loss on step i

        for i in range(task_num):
            if self.distractor:
                spt_image = torch.concat((x_spt[i],unlabel_spt[i]))
                spt_unlabel_label = torch.full((unlabel_spt.size(1),), 5, dtype=torch.long,device=self.device)
                spt_label = torch.cat((y_spt[i],spt_unlabel_label))
                qry_image = torch.concat((x_qry[i],unlabel_qry[i]))
                qry_unlabel_label = torch.full((unlabel_qry.size(1),), 5, dtype=torch.long,device=self.device)
                qry_label = torch.cat((y_qry[i],qry_unlabel_label))
            else:
                spt_image = x_spt[i]
                spt_label = y_spt[i]
                qry_image = x_qry[i]
                qry_label = y_qry[i]
            # 1. run the i-th task and compute loss for k=0
            logits = self.net(spt_image, vars=None, bn_training=True)
            loss = F.cross_entropy(logits, spt_label)
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

            # this is the loss and accuracy before first update
            with torch.no_grad():
                # [setsz, nway]
                if self.distractor:
                    total_logits_q = self.net(qry_image, self.net.parameters(), bn_training=False)
                    total_pred_q = F.softmax(total_logits_q, dim=1).argmax(dim=1)
                    total_q_correct = torch.eq(total_pred_q, qry_label).sum().item()
                    corrects['total_query_nway'][0] += total_q_correct
                    loss_q = F.cross_entropy(total_logits_q, qry_label)
                    losses_q[0] += loss_q

                    label_logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=False)
                    label_pred_q = F.softmax(label_logits_q, dim=1).argmax(dim=1)
                    label_pred_q_correct = torch.eq(label_pred_q, y_qry[i]).sum().item()
                    corrects['label_query_nway'][0] += label_pred_q_correct

                    unlabel_logits_q = self.net(unlabel_qry[i], self.net.parameters(), bn_training=False)
                    unlabel_pred_q = F.softmax(unlabel_logits_q, dim=1).argmax(dim=1)
                    other = torch.eq(unlabel_pred_q, qry_unlabel_label).sum().item()
                    corrects["unlabel_query_nway"][0] += other
                else:
                    logits_q = self.net(qry_image, self.net.parameters(), bn_training=True)
                    loss_q = F.cross_entropy(logits_q, qry_label)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                    corrects['query_nway'][0] += q_discrim_correct
                    losses_q[0] += loss_q
            # this is the loss and accuracy after the first update
            with torch.no_grad():
                # [setsz, nway]
                if self.distractor:
                    total_logits_q = self.net(qry_image,fast_weights , bn_training=False)
                    total_pred_q = F.softmax(total_logits_q, dim=1).argmax(dim=1)
                    total_q_correct = torch.eq(total_pred_q, qry_label).sum().item()
                    corrects['total_query_nway'][1] += total_q_correct
                    loss_q = F.cross_entropy(total_logits_q, qry_label)
                    losses_q[1] += loss_q

                    label_logits_q = self.net(x_qry[i], fast_weights, bn_training=False)
                    label_pred_q = F.softmax(label_logits_q, dim=1).argmax(dim=1)
                    label_pred_q_correct = torch.eq(label_pred_q, y_qry[i]).sum().item()
                    corrects['label_query_nway'][1] += label_pred_q_correct

                    unlabel_logits_q = self.net(unlabel_qry[i], self.net.parameters(), bn_training=False)
                    unlabel_pred_q = F.softmax(unlabel_logits_q, dim=1).argmax(dim=1)
                    other = torch.eq(unlabel_pred_q, qry_unlabel_label).sum().item()
                    corrects["unlabel_query_nway"][1] += other
                else:
                    logits_q = self.net(qry_image, fast_weights, bn_training=False)
                    loss_q = F.cross_entropy(logits_q, qry_label)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                    corrects['query_nway'][1] += q_discrim_correct
                    losses_q[1] += loss_q
            for k in range(1, self.update_step):
                # 1. run the i-th task and compute loss for k=1~K-1
                logits = self.net(spt_image, fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, spt_label)
                # 2. compute grad on theta_pi
                grad = torch.autograd.grad(loss, fast_weights)
                # 3. theta_pi = theta_pi - train_lr * grad
                fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

                logits_q = self.net(qry_image, fast_weights, bn_training=True)
                # loss_q will be overwritten and just keep the loss_q on last update step.
                loss_q = F.cross_entropy(logits_q, qry_label)
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    if self.distractor:
                        total_logits_q = self.net(qry_image, fast_weights, bn_training=False)
                        total_pred_q = F.softmax(total_logits_q, dim=1).argmax(dim=1)
                        total_q_correct = torch.eq(total_pred_q, qry_label).sum().item()
                        corrects['total_query_nway'][k+1] += total_q_correct

                        label_logits_q = self.net(x_qry[i], fast_weights, bn_training=False)
                        label_pred_q = F.softmax(label_logits_q, dim=1).argmax(dim=1)
                        label_pred_q_correct = torch.eq(label_pred_q, y_qry[i]).sum().item()
                        corrects['label_query_nway'][k+1] += label_pred_q_correct

                        unlabel_logits_q = self.net(unlabel_qry[i], fast_weights, bn_training=False)
                        unlabel_pred_q = F.softmax(unlabel_logits_q, dim=1).argmax(dim=1)
                        other = torch.eq(unlabel_pred_q, qry_unlabel_label).sum().item()
                        corrects["unlabel_query_nway"][k+1] += other

                    else:
                        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                        q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                        corrects['query_nway'][k+1] += q_discrim_correct

        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = losses_q[-1] / task_num
        # optimize theta parameters
        self.meta_optim.zero_grad()
        loss_q.backward()

        self.meta_optim.step()
        
        accs = {}
        if self.distractor:
            accs["total_query_nway"] = corrects["total_query_nway"] / (task_num * (querysz + unlabel_querysz))
            accs["label_query_nway"] = corrects["label_query_nway"] / (task_num * querysz)
            accs["unlabel_query_nway"] = corrects["unlabel_query_nway"] / (task_num * unlabel_querysz)
        else:
            accs["query_nway"] = corrects["query_nway"] / (task_num * (querysz + querysz))
        return accs,loss_q


    def finetunning(self, x_spt, y_spt, x_qry, y_qry, unlabel_spt=None, unlabel_qry=None):
        """

        :param x_spt:   [setsz, c_, h, w]
        :param y_spt:   [setsz]
        :param x_qry:   [querysz, c_, h, w]
        :param y_qry:   [querysz]
        :return:
        """
        assert len(x_spt.shape) == 4

        querysz = x_qry.size(0)
        if self.distractor:
            unlabel_querysz = unlabel_qry.size(0)
            corrects = {key: np.zeros(self.update_step_test + 1) for key in 
                            [
                            "total_query_nway",
                            "label_query_nway",
                            "unlabel_query_nway"
                            ]}
        else:
            corrects = {key: np.zeros(self.update_step_test + 1) for key in 
                            [
                            "query_nway"
                            ]}
        # in order to not ruin the state of running_mean/variance and bn_weight/bias
        # we finetunning on the copied model instead of self.net
        net = deepcopy(self.net)
        
        if self.distractor:
            spt_image = torch.concat((x_spt,unlabel_spt))
            spt_unlabel_label = torch.full((unlabel_spt.size(0),), 5, dtype=torch.long,device=self.device)
            spt_label = torch.cat((y_spt,spt_unlabel_label))
            qry_image = torch.concat((x_qry,unlabel_qry))
            qry_unlabel_label = torch.full((unlabel_qry.size(0),), 5, dtype=torch.long,device=self.device)
            qry_label = torch.cat((y_qry,qry_unlabel_label))
        else:
            spt_image = x_spt
            spt_label = y_spt
            qry_image = x_qry
            qry_label = y_qry
        # 1. run the i-th task and compute loss for k=0
        logits = net(spt_image)
        loss = F.cross_entropy(logits, spt_label)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters())))

        # this is the loss and accuracy before first update
        with torch.no_grad():
            if self.distractor:
                total_logits_q = self.net(qry_image, self.net.parameters(), bn_training=False)
                total_pred_q = F.softmax(total_logits_q, dim=1).argmax(dim=1)
                total_q_correct = torch.eq(total_pred_q, qry_label).sum().item()
                corrects['total_query_nway'][0] += total_q_correct

                label_logits_q = self.net(x_qry, self.net.parameters(), bn_training=False)
                label_pred_q = F.softmax(label_logits_q, dim=1).argmax(dim=1)
                label_pred_q_correct = torch.eq(label_pred_q, y_qry).sum().item()
                corrects['label_query_nway'][0] += label_pred_q_correct

                unlabel_logits_q = self.net(unlabel_qry, self.net.parameters(), bn_training=False)
                unlabel_pred_q = F.softmax(unlabel_logits_q, dim=1).argmax(dim=1)
                other = torch.eq(unlabel_pred_q, qry_unlabel_label).sum().item()
                corrects["unlabel_query_nway"][0] += other
            else:
                logits_q = self.net(qry_image, self.net.parameters(), bn_training=True)
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                corrects['query_nway'][0] += q_discrim_correct
        # this is the loss and accuracy after the first update
        with torch.no_grad():
            if self.distractor:
                logits_q = self.net(qry_image, fast_weights, bn_training=False)
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                corrects['total_query_nway'][1] += q_discrim_correct

                logits_q = self.net(x_qry, fast_weights, bn_training=False)
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                q_discrim_correct = torch.eq(pred_q, y_qry).sum().item()
                corrects['label_query_nway'][1] += q_discrim_correct

                logits_q = self.net(unlabel_qry, fast_weights, bn_training=False)
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                other = torch.eq(pred_q, qry_unlabel_label).sum().item()
                corrects["unlabel_query_nway"][1] += other
            else:
                logits_q = self.net(qry_image, fast_weights, bn_training=True)
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                corrects['query_nway'][1] += q_discrim_correct
        for k in range(1, self.update_step_test):
            # 1. run the i-th task and compute loss for k=1~K-1
            logits = net(spt_image, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, spt_label)
            # 2. compute grad on theta_pi
            grad = torch.autograd.grad(loss, fast_weights)
            # 3. theta_pi = theta_pi - train_lr * grad
            fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

            logits_q = net(qry_image, fast_weights, bn_training=True)
            # loss_q will be overwritten and just keep the loss_q on last update step.
            loss_q = F.cross_entropy(logits_q, qry_label)

            with torch.no_grad():
                if self.distractor:
                    logits_q = self.net(qry_image, fast_weights, bn_training=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                    corrects['total_query_nway'][k+1] += q_discrim_correct

                    logits_q = self.net(x_qry, fast_weights, bn_training=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    q_discrim_correct = torch.eq(pred_q, y_qry).sum().item()
                    corrects['label_query_nway'][k+1] += q_discrim_correct

                    logits_q = self.net(unlabel_qry, fast_weights, bn_training=False)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    other = torch.eq(pred_q, qry_unlabel_label).sum().item()
                    corrects["unlabel_query_nway"][k+1] += other
                else:
                    logits_q = self.net(qry_image, fast_weights, bn_training=True)
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    q_discrim_correct = torch.eq(pred_q, qry_label).sum().item()
                    corrects['query_nway'][k+1] += q_discrim_correct
        del net
        if self.distractor:
            accs["total_query_nway"] = corrects["total_query_nway"] / (querysz + unlabel_querysz)
            accs["label_query_nway"] = corrects["label_query_nway"] / querysz
            accs["unlabel_query_nway"] = corrects["unlabel_query_nway"] / unlabel_querysz
        else:
            accs["query_nway"] = corrects["query_nway"] / querysz
        return accs

In [9]:
device = torch.device('cuda')
maml = Meta(args, config).to(device)

In [10]:
tmp = filter(lambda x: x.requires_grad, maml.parameters())
num = sum(map(lambda x: np.prod(x.shape), tmp))

In [None]:
path = args["save_path"]
step = 0
mkdir_p(path)
for epoch in range(args["epoch"]//6000):
        # fetch meta_batchsz num of episode each time

    train_dataloader = DataLoader(train_data_generator, args["task_num"], shuffle=True, num_workers=1, pin_memory=True)

    for _,data  in enumerate(train_dataloader):
        if len(data) == 4:
            (x_spt, y_spt, x_qry, y_qry) = data
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)
            accs,loss_q = maml(x_spt, y_spt, x_qry, y_qry)
        else:
            (x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry) = data
            x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device), \
            unlabel_spt.to(device), unlabel_qry.to(device)
            accs,loss_q = maml(x_spt, y_spt, x_qry, y_qry,unlabel_spt_image=unlabel_spt, unlabel_qry_image=unlabel_qry)

        writer.add_scalar('Loss/train_loss', loss_q, step)
        writer.add_scalar('Accuracy/train_total_query_nway_accuracy', accs["total_query_nway"][-1], step)
        writer.add_scalar('Accuracy/train_label_query_nway_accuracy', accs["label_query_nway"][-1], step)
        writer.add_scalar('Accuracy/train_unlabel_query_nway_accuracy', accs["unlabel_query_nway"][-1], step)
        if step % 30 == 0:
            print('step:', step, '\ttraining acc:', accs)
        if step % 100 == 0:  # evaluation
            db_test = DataLoader(test_data_generator, 1, shuffle=True, num_workers=1, pin_memory=True)
            accs_all_test = {
                            "total_query_nway":[],
                            "unlabel_query_nway":[],
                            "label_query_nway":[]
            }

            for test_data in db_test:
                if len(test_data) == 4:
                    x_spt, y_spt, x_qry, y_qry = test_data
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                else:
                    x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry = test_data
                    x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device),\
                                                unlabel_spt.squeeze(0).to(device), unlabel_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry, unlabel_spt, unlabel_qry)
                accs_all_test["total_query_nway"].append(accs["total_query_nway"])
                accs_all_test["label_query_nway"].append(accs["label_query_nway"])
                accs_all_test["unlabel_query_nway"].append(accs["unlabel_query_nway"])

            # [b, update_step+1]
            accs["total_query_nway"] = np.array(accs_all_test["total_query_nway"]).mean(axis=0).astype(np.float16)
            accs["label_query_nway"] = np.array(accs_all_test["label_query_nway"]).mean(axis=0).astype(np.float16)
            accs["unlabel_query_nway"] = np.array(accs_all_test["unlabel_query_nway"]).mean(axis=0).astype(np.float16)
            
            
            print('Test acc:', accs)
            writer.add_scalar('Accuracy/test_total_query_nway_accuracy', accs["total_query_nway"][-1], step)
            writer.add_scalar('Accuracy/test_label_query_nway_accuracy', accs["label_query_nway"][-1], step)
            writer.add_scalar('Accuracy/test_unlabel_query_nway_accuracy', accs["unlabel_query_nway"][-1], step)

            torch.save({'model_state_dict': maml.state_dict()}, "maml/" + path + "/model_step" + str(step) + ".pt")
        step += 1

step: 0 	training acc: {'total_query_nway': array([0.175     , 0.34      , 0.345     , 0.385     , 0.37833333,
       0.39      ]), 'label_query_nway': array([0.152     , 0.23733333, 0.26933333, 0.31466667, 0.31466667,
       0.34133333]), 'unlabel_query_nway': array([0.21333333, 0.21333333, 0.47111111, 0.50222222, 0.48444444,
       0.47111111])}
Test acc: {'total_query_nway': array([0.1696, 0.3333, 0.3494, 0.3572, 0.3608, 0.3657, 0.3674, 0.3694,
       0.3713, 0.3728, 0.3738], dtype=float16), 'label_query_nway': array([0.169 , 0.2544, 0.277 , 0.2827, 0.2844, 0.2883, 0.2896, 0.2917,
       0.2932, 0.2952, 0.2966], dtype=float16), 'unlabel_query_nway': array([0.1708, 0.4653, 0.4702, 0.4812, 0.4883, 0.4944, 0.4978, 0.4993,
       0.5015, 0.5015, 0.5024], dtype=float16)}
step: 30 	training acc: {'total_query_nway': array([0.24      , 0.43666667, 0.43166667, 0.46833333, 0.47333333,
       0.48833333]), 'label_query_nway': array([0.07733333, 0.36266667, 0.39733333, 0.408     , 0.42133333,
