In [2]:
import torch
import numpy as np
from MiniImagenet import MiniImagenet
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import torch.nn as nn


class Learner(nn.Module):
    def __init__(self, config, imgc, imgsz):
        super(Learner, self).__init__()
        self.config = config
        self.vars = nn.ParameterList()
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name == "conv2d":
                w = nn.Parameter(torch.ones(*param[:4]))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "linear":
                w = nn.Parameter(torch.ones(*param))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "bn":
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

    def forward(self, x, vars=None, bn_training=True):
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == "conv2d":
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == "linear":
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == "bn":
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = (
                    self.vars_bn[bn_idx],
                    self.vars_bn[bn_idx + 1],
                )
                x = F.batch_norm(
                    x, running_mean, running_var, weight=w, bias=b, training=bn_training
                )
                idx += 2
                bn_idx += 2
            elif name == "flatten":
                x = x.view(x.size(0), -1)
            elif name == "relu":
                x = F.relu(x, inplace=param[0])
            elif name == "max_pool2d":
                x = F.max_pool2d(x, param[0], param[1], param[2])

        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x

    def parameters(self):
        return self.vars


class Meta(nn.Module):
    def __init__(
        self,
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ):
        super(Meta, self).__init__()
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.config = config
        self.imgc = imgc
        self.imgsz = imgsz

        self.net = Learner(config, imgc, imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(
                    lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters()),
                )
            )

            with torch.no_grad():
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            with torch.no_grad():
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
                )

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct

        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)

        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
        )

        with torch.no_grad():
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        with torch.no_grad():
            logits_q = net(x_qry, fast_weights, bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
            )

            logits_q = net(x_qry, fast_weights, bn_training=True)
            loss_q = F.cross_entropy(logits_q, y_qry)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] = corrects[k + 1] + correct

        del net
        accs = np.array(corrects) / querysz
        return accs


def main():
    epoch = 60000
    n_way = 5
    k_spt = 1
    k_qry = 15
    imgsz = 84
    imgc = 3
    task_num = 4
    meta_lr = 1e-3
    update_lr = 0.01
    update_step = 5
    update_step_test = 10
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    config = [
        ("conv2d", [32, 3, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 1, 0]),
        ("flatten", []),
        ("linear", [n_way, 32 * 5 * 5]),
    ]

    device = torch.device("cuda")
    maml = Meta(
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ).to(device)

    mini = MiniImagenet(
        "mini-imagenet",
        mode="train",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=10000,
        resize=imgsz,
    )
    mini_test = MiniImagenet(
        "mini-imagenet",
        mode="test",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=100,
        resize=imgsz,
    )

    for epoch_count in range(epoch // 10000):
        db = DataLoader(mini, task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.to(device),
                y_spt.to(device),
                x_qry.to(device),
                y_qry.to(device),
            )
            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print("step:", step, "\ttraining acc:", accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(
                    mini_test, 1, shuffle=True, num_workers=1, pin_memory=True
                )
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = (
                        x_spt.squeeze(0).to(device),
                        y_spt.squeeze(0).to(device),
                        x_qry.squeeze(0).to(device),
                        y_qry.squeeze(0).to(device),
                    )
                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print("Test acc:", accs)


if __name__ == "__main__":
    main()

shuffle DB :train, b:10000, 5-way, 1-shot, 15-query, resize:84
shuffle DB :test, b:100, 5-way, 1-shot, 15-query, resize:84
step: 0 	training acc: [0.21       0.22333333 0.23       0.23       0.22666667 0.23      ]


KeyboardInterrupt: 

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import numpy as np
from PIL import Image
import csv
import random
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import torch.nn as nn


class MiniImagenet(Dataset):
    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        self.batchsz = batchsz
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.setsz = self.n_way * self.k_shot
        self.querysz = self.n_way * self.k_query
        self.resize = resize
        self.startidx = startidx
        print(
            "shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d"
            % (mode, batchsz, n_way, k_shot, k_query, resize)
        )

        if mode == "train":
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )

        self.path = os.path.join(root, "images")
        csvdata = self.loadCSV(os.path.join(root, mode + ".csv"))
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)
            self.img2label[k] = i + self.startidx
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=",")
            next(csvreader, None)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels

    def create_batch(self, batchsz):
        self.support_x_batch = []
        self.query_x_batch = []
        for b in range(batchsz):
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                selected_imgs_idx = np.random.choice(
                    len(self.data[cls]), self.k_shot + self.k_query, False
                )
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[: self.k_shot])
                indexDtest = np.array(selected_imgs_idx[self.k_shot :])
                support_x.append(np.array(self.data[cls])[indexDtrain].tolist())
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)
            self.query_x_batch.append(query_x)

    def __getitem__(self, index):
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        support_y = np.zeros((self.setsz), dtype=np.int32)
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        query_y = np.zeros((self.querysz), dtype=np.int32)

        flatten_support_x = [
            os.path.join(self.path, item)
            for sublist in self.support_x_batch[index]
            for item in sublist
        ]
        support_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.support_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        flatten_query_x = [
            os.path.join(self.path, item)
            for sublist in self.query_x_batch[index]
            for item in sublist
        ]
        query_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.query_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        unique = np.unique(support_y)
        random.shuffle(unique)
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)

        return (
            support_x,
            torch.LongTensor(support_y_relative),
            query_x,
            torch.LongTensor(query_y_relative),
        )

    def __len__(self):
        return self.batchsz


class Learner(nn.Module):
    def __init__(self, config, imgc, imgsz):
        super(Learner, self).__init__()
        self.config = config
        self.vars = nn.ParameterList()
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name == "conv2d":
                w = nn.Parameter(torch.ones(*param[:4]))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "linear":
                w = nn.Parameter(torch.ones(*param))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "bn":
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

    def forward(self, x, vars=None, bn_training=True):
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == "conv2d":
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == "linear":
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == "bn":
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = (
                    self.vars_bn[bn_idx],
                    self.vars_bn[bn_idx + 1],
                )
                x = F.batch_norm(
                    x, running_mean, running_var, weight=w, bias=b, training=bn_training
                )
                idx += 2
                bn_idx += 2
            elif name == "flatten":
                x = x.view(x.size(0), -1)
            elif name == "relu":
                x = F.relu(x, inplace=param[0])
            elif name == "max_pool2d":
                x = F.max_pool2d(x, param[0], param[1], param[2])

        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x

    def parameters(self):
        return self.vars


class Meta(nn.Module):
    def __init__(
        self,
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ):
        super(Meta, self).__init__()
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.config = config
        self.imgc = imgc
        self.imgsz = imgsz

        self.net = Learner(config, imgc, imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(
                    lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters()),
                )
            )

            with torch.no_grad():
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            with torch.no_grad():
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
                )

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct

        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)

        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
        )

        with torch.no_grad():
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        with torch.no_grad():
            logits_q = net(x_qry, fast_weights, bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
            )

            logits_q = net(x_qry, fast_weights, bn_training=True)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] = corrects[k + 1] + correct

        del net
        accs = np.array(corrects) / querysz
        return accs


def main():
    epoch = 60000
    n_way = 5
    k_spt = 1
    k_qry = 15
    imgsz = 84
    imgc = 3
    task_num = 4
    meta_lr = 1e-3
    update_lr = 0.01
    update_step = 5
    update_step_test = 10

    # Set the paths to your mini-imagenet dataset
    root = "./mini-imagenet"

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    config = [
        ("conv2d", [32, 3, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 1, 0]),
        ("flatten", []),
        ("linear", [n_way, 32 * 5 * 5]),
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    # batchsz here means total episode number
    mini = MiniImagenet(
        root,
        mode="train",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=10000,
        resize=imgsz,
    )
    mini_test = MiniImagenet(
        root,
        mode="test",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=100,
        resize=imgsz,
    )

    for epoch_count in range(epoch // 10000):
        db = DataLoader(mini, task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = (
                x_spt.to(device),
                y_spt.to(device),
                x_qry.to(device),
                y_qry.to(device),
            )

            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 50 == 0:
                print("step:", step, "\ttraining acc:", accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(
                    mini_test, 1, shuffle=True, num_workers=1, pin_memory=True
                )
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    x_spt, y_spt, x_qry, y_qry = (
                        x_spt.squeeze(0).to(device),
                        y_spt.squeeze(0).to(device),
                        x_qry.squeeze(0).to(device),
                        y_qry.squeeze(0).to(device),
                    )

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print("Test acc:", accs)


if __name__ == "__main__":
    main()

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

# Load and preprocess the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train.reshape((-1, 28, 28, 1))
x_test = x_test.reshape((-1, 28, 28, 1))


# Define the model architecture
def create_model():
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Conv2D(
                32, (3, 3), activation="relu", input_shape=(28, 28, 1)
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation="relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    return model


# Define the MAML model
class MAML(tf.keras.Model):
    def __init__(self, model):
        super(MAML, self).__init__()
        self.model = model

    def train_step(self, data):
        x, y = data
        x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
        y = tf.reshape(y, (-1,))  # Reshape the target labels
        with tf.GradientTape() as tape:
            y_pred = self.model(x)
            loss = self.compiled_loss(y, y_pred)
        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data
        x = tf.reshape(x, (-1, 28, 28, 1))  # Reshape the input tensor
        y = tf.reshape(y, (-1,))  # Reshape the target labels
        y_pred = self.model(x)
        self.compiled_loss(y, y_pred)
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}


# Define the meta-learning parameters
num_meta_updates = 10
num_inner_updates = 5
meta_batch_size = 32
inner_batch_size = 10

# Create the MAML model
model = MAML(create_model())
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Initialize variables to store accuracy over time
meta_updates = []
accuracy_over_time = []

# Meta-training loop
for meta_update in range(num_meta_updates):
    # Sample a meta-batch of tasks
    meta_batch = tf.random.shuffle(tf.range(len(x_train)))[:meta_batch_size]

    # Inner loop updates for each task
    for task in meta_batch:
        task_data = (
            x_train[task : task + inner_batch_size],
            y_train[task : task + inner_batch_size],
        )
        for inner_update in range(num_inner_updates):
            model.train_step(task_data)

    # Evaluate on the meta-test set
    _, accuracy = model.evaluate(x_test, y_test)

    # Store the meta-update step and accuracy
    meta_updates.append(meta_update + 1)
    accuracy_over_time.append(accuracy)

# Fine-tuning on a new task
new_task_data = (x_test[:100], y_test[:100])
model.fit(new_task_data[0], new_task_data[1], epochs=10)

# Plot the accuracy over time graph
plt.plot(meta_updates, accuracy_over_time)
plt.xlabel("Meta-Update Step")
plt.ylabel("Accuracy")
plt.title("Accuracy Over Time")
plt.show()

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import numpy as np
from PIL import Image
import csv
import random
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import torch.nn as nn
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
from sklearn.model_selection import train_test_split
from Pyfhel import Pyfhel


class MiniImagenet(Dataset):
    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        self.batchsz = batchsz
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.setsz = self.n_way * self.k_shot
        self.querysz = self.n_way * self.k_query
        self.resize = resize
        self.startidx = startidx
        print(
            "shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d"
            % (mode, batchsz, n_way, k_shot, k_query, resize)
        )

        if mode == "train":
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )

        self.path = os.path.join(root, "images")
        csvdata = self.loadCSV(os.path.join(root, mode + ".csv"))
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)
            self.img2label[k] = i + self.startidx
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=",")
            next(csvreader, None)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels

    def create_batch(self, batchsz):
        self.support_x_batch = []
        self.query_x_batch = []
        for b in range(batchsz):
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                selected_imgs_idx = np.random.choice(
                    len(self.data[cls]), self.k_shot + self.k_query, False
                )
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[: self.k_shot])
                indexDtest = np.array(selected_imgs_idx[self.k_shot :])
                support_x.append(np.array(self.data[cls])[indexDtrain].tolist())
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)
            self.query_x_batch.append(query_x)

    def __getitem__(self, index):
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        support_y = np.zeros((self.setsz), dtype=np.int32)
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        query_y = np.zeros((self.querysz), dtype=np.int32)

        flatten_support_x = [
            os.path.join(self.path, item)
            for sublist in self.support_x_batch[index]
            for item in sublist
        ]
        support_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.support_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        flatten_query_x = [
            os.path.join(self.path, item)
            for sublist in self.query_x_batch[index]
            for item in sublist
        ]
        query_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.query_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        unique = np.unique(support_y)
        random.shuffle(unique)
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)

        return (
            support_x,
            torch.LongTensor(support_y_relative),
            query_x,
            torch.LongTensor(query_y_relative),
        )

    def __len__(self):
        return self.batchsz


class Learner(nn.Module):
    def __init__(self, config, imgc, imgsz):
        super(Learner, self).__init__()
        self.config = config
        self.vars = nn.ParameterList()
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name == "conv2d":
                w = nn.Parameter(torch.ones(*param[:4]))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "linear":
                w = nn.Parameter(torch.ones(*param))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "bn":
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

    def forward(self, x, vars=None, bn_training=True):
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == "conv2d":
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == "linear":
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == "bn":
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = (
                    self.vars_bn[bn_idx],
                    self.vars_bn[bn_idx + 1],
                )
                x = F.batch_norm(
                    x, running_mean, running_var, weight=w, bias=b, training=bn_training
                )
                idx += 2
                bn_idx += 2
            elif name == "flatten":
                x = x.view(x.size(0), -1)
            elif name == "relu":
                x = F.relu(x, inplace=param[0])
            elif name == "max_pool2d":
                x = F.max_pool2d(x, param[0], param[1], param[2])

        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x

    def parameters(self):
        return self.vars


class Meta(nn.Module):
    def __init__(
        self,
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ):
        super(Meta, self).__init__()
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.config = config
        self.imgc = imgc
        self.imgsz = imgsz

        self.net = Learner(config, imgc, imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(
                    lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters()),
                )
            )

            with torch.no_grad():
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            with torch.no_grad():
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
                )

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct

        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)

        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
        )

        with torch.no_grad():
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        with torch.no_grad():
            logits_q = net(x_qry, fast_weights, bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt)

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import numpy as np
from PIL import Image
import csv
import random
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import torch.nn as nn
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
from sklearn.model_selection import train_test_split
from Pyfhel import Pyfhel


class MiniImagenet(Dataset):
    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        self.batchsz = batchsz
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.setsz = self.n_way * self.k_shot
        self.querysz = self.n_way * self.k_query
        self.resize = resize
        self.startidx = startidx
        print(
            "shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d"
            % (mode, batchsz, n_way, k_shot, k_query, resize)
        )

        if mode == "train":
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )

        self.path = os.path.join(root, "images")
        csvdata = self.loadCSV(os.path.join(root, mode + ".csv"))
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)
            self.img2label[k] = i + self.startidx
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=",")
            next(csvreader, None)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels

    def create_batch(self, batchsz):
        self.support_x_batch = []
        self.query_x_batch = []
        for b in range(batchsz):
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                selected_imgs_idx = np.random.choice(
                    len(self.data[cls]), self.k_shot + self.k_query, False
                )
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[: self.k_shot])
                indexDtest = np.array(selected_imgs_idx[self.k_shot :])
                support_x.append(np.array(self.data[cls])[indexDtrain].tolist())
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)
            self.query_x_batch.append(query_x)

    def __getitem__(self, index):
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        support_y = np.zeros((self.setsz), dtype=np.int32)
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        query_y = np.zeros((self.querysz), dtype=np.int32)

        flatten_support_x = [
            os.path.join(self.path, item)
            for sublist in self.support_x_batch[index]
            for item in sublist
        ]
        support_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.support_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        flatten_query_x = [
            os.path.join(self.path, item)
            for sublist in self.query_x_batch[index]
            for item in sublist
        ]
        query_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.query_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        unique = np.unique(support_y)
        random.shuffle(unique)
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)

        return (
            support_x,
            torch.LongTensor(support_y_relative),
            query_x,
            torch.LongTensor(query_y_relative),
        )

    def __len__(self):
        return self.batchsz


class Learner(nn.Module):
    def __init__(self, config, imgc, imgsz):
        super(Learner, self).__init__()
        self.config = config
        self.vars = nn.ParameterList()
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name == "conv2d":
                w = nn.Parameter(torch.ones(*param[:4]))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "linear":
                w = nn.Parameter(torch.ones(*param))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "bn":
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

    def forward(self, x, vars=None, bn_training=True):
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == "conv2d":
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == "linear":
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == "bn":
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = (
                    self.vars_bn[bn_idx],
                    self.vars_bn[bn_idx + 1],
                )
                x = F.batch_norm(
                    x, running_mean, running_var, weight=w, bias=b, training=bn_training
                )
                idx += 2
                bn_idx += 2
            elif name == "flatten":
                x = x.view(x.size(0), -1)
            elif name == "relu":
                x = F.relu(x, inplace=param[0])
            elif name == "max_pool2d":
                x = F.max_pool2d(x, param[0], param[1], param[2])

        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x

    def parameters(self):
        return self.vars


class Meta(nn.Module):
    def __init__(
        self,
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ):
        super(Meta, self).__init__()
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.config = config
        self.imgc = imgc
        self.imgsz = imgsz

        self.net = Learner(config, imgc, imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(
                    lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters()),
                )
            )

            with torch.no_grad():
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            with torch.no_grad():
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
                )

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct

        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)

        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
        )

        with torch.no_grad():
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        with torch.no_grad():
            logits_q = net(x_qry, fast_weights, bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_

In [None]:
# entropy(logits, y_spt)
#             grad = torch.autograd.grad(loss, fast_weights)
#             fast_weights = list(
#                 map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
#             )

#             logits_q = net(x_qry, fast_weights, bn_training=True)

#             with torch.no_grad():
#                 pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
#                 correct = torch.eq(pred_q, y_qry).sum().item()
#                 corrects[k + 1] = corrects[k + 1] + correct

#         del net
#         accs = np.array(corrects) / querysz
#         return accs


# def train_model(model, x_train, y_train):
#     meta_updates = []
#     accuracy_over_time = []
#     for meta_update in range(num_meta_updates):
#         # Sample a meta-batch of tasks
#         meta_batch_indices = torch.randperm(len(x_train))[:meta_batch_size]

#         # Inner loop updates for each task
#         for task_index in meta_batch_indices:
#             task_data = (
#                 x_train[task_index : task_index + inner_batch_size],
#                 y_train[task_index : task_index + inner_batch_size],
#             )

#             # Create a copy of the model for the inner loop updates
#             inner_model = deepcopy(model)
#             inner_model.load_state_dict(model.state_dict())

#             for inner_update in range(num_inner_updates):
#                 # Perform inner loop update on the task-specific model
#                 inner_model.train_step(task_data)

#             # Update the original model with the adapted weights
#             model.load_state_dict(inner_model.state_dict())

#         # Evaluate on the meta-test set
#         _, accuracy = model.evaluate(x_test, y_test)

#         # Store the meta-update step and accuracy
#         meta_updates.append(meta_update + 1)
#         accuracy_over_time.append(accuracy)

#     avg_accuracy = sum(accuracy_over_time) / len(accuracy_over_time)
#     return model, avg_accuracy


# def generate_diffie_hellman_parameters():
#     parameters = dh.generate_parameters(generator=2, key_size=512)
#     return parameters


# def generate_diffie_hellman_keys(parameters):
#     private_key = parameters.generate_private_key()
#     public_key = private_key.public_key()
#     return private_key, public_key


# def derive_key(private_key, peer_public_key):
#     shared_key = private_key.exchange(peer_public_key)
#     derived_key = HKDF(
#         algorithm=hashes.SHA256(),
#         length=32,
#         salt=None,
#         info=b"handshake data",
#     ).derive(shared_key)
#     return derived_key


# def encrypt_message_AES(key, message):
#     serialized_obj = pickle.dumps(message)
#     cipher = Cipher(algorithms.AES(key), modes.ECB())
#     encryptor = cipher.encryptor()
#     padded_obj = serialized_obj + b" " * (16 - len(serialized_obj) % 16)
#     ciphertext = encryptor.update(padded_obj) + encryptor.finalize()
#     return ciphertext


# def decrypt_message_AES(key, ciphertext):
#     cipher = Cipher(algorithms.AES(key), modes.ECB())
#     decryptor = cipher.decryptor()
#     padded_obj = decryptor.update(ciphertext) + decryptor.finalize()
#     serialized_obj = padded_obj.rstrip(b" ")
#     obj = pickle.loads(serialized_obj)
#     return obj


# def setup_AES():
#     num_clients = len(clients)
#     parameters = generate_diffie_hellman_parameters()
#     server_private_key, server_public_key = generate_diffie_hellman_keys(parameters)
#     client_keys = [generate_diffie_hellman_keys(parameters) for _ in range(num_clients)]
#     shared_keys = [
#         derive_key(server_private_key, client_public_key)
#         for _, client_public_key in client_keys
#     ]
#     client_shared_keys = [
#         derive_key(client_private_key, server_public_key)
#         for client_private_key, _ in client_keys
#     ]

#     return client_keys, shared_keys, client_shared_keys


# def encrypt_wt(wtarray, i):
#     cwt = []
#     for layer in wtarray:
#         flat_array = layer.astype(np.float64).flatten()
#         chunks = np.array_split(flat_array, (len(flat_array) + 2 ** 10 - 1) // 2 ** 10)
#         clayer = []
#         for chunk in chunks:
#             ptxt = HE.encodeFrac(chunk)
#             ctxt = HE.encryptPtxt(ptxt)
#             clayer.append(ctxt)
#         cwt.append(clayer.copy())
#     ciphertext = encrypt_message_AES(client_shared_keys[i], cwt)
#     return ciphertext


# def aggregate_wt(encrypted_cwts):
#     cwts = []
#     for i, ecwt in enumerate(encrypted_cwts):
#         cwts.append(decrypt_message_AES(shared_keys[i], ecwt))
#     resmodel = []
#     for j in range(len(cwts[0])):  # for layers
#         layer = []
#         for k in range(len(cwts[0][j])):  # for chunks
#             tmp = cwts[0][j][k].copy()
#             for i in range(1, len(cwts)):  # for clients
#                 tmp = tmp + cwts[i][j][k]
#             tmp = tmp / len(cwts)
#             layer.append(tmp)
#         resmodel.append(layer)

#     res = [resmodel.copy() for _ in range(len(clients))]
#     return res


# def decrypt_weights(res):
#     decrypted_weights = []
#     for client_weights, model in zip(res, models):
#         decrypted_client_weights = []
#         wtarray = model.get_weights()
#         for layer_weights, layer in zip(client_weights, wtarray):
#             decrypted_layer_weights = []
#             flat_array = layer.astype(np.float64).flatten()
#             chunks = np.array_split(flat_array, (len(flat_array) + 2 ** 10 - 1) // 2 ** 10)
#             for chunk, encrypted_chunk in zip(chunks, layer_weights):
#                 decrypted_chunk = HE.decryptFrac(encrypted_chunk)
#                 original_chunk_size = len(chunk)
#                 decrypted_chunk = decrypted_chunk[:original_chunk_size]
#                 decrypted_layer_weights.append(decrypted_chunk)
#             decrypted_layer_weights = np.concatenate(decrypted_layer_weights, axis=0)
#             decrypted_layer_weights = decrypted_layer_weights.reshape(layer.shape)
#             decrypted_client_weights.append(decrypted_layer_weights)
#         decrypted_weights.append(decrypted_client_weights)
#     return decrypted_weights


# def main():
#     epoch = 60000
#     n_way = 5
#     k_spt = 1
#     k_qry = 15
#     imgsz = 84
#     imgc = 3
#     task_num = 4
#     meta_lr = 1e-3
#     update_lr = 0.01
#     update_step = 5
#     update_step_test = 10

#     # Set the paths to your mini-imagenet dataset
#     root = "./mini-imagenet"

#     torch.manual_seed(222)
#     torch.cuda.manual_seed_all(222)
#     np.random.seed(222)

#     config = [
#         ("conv2d", [32, 3, 3, 3, 1, 0]),
#         ("relu", [True]),
#         ("bn", [32]),
#         ("max_pool2d", [2, 2, 0]),
#         ("conv2d", [32, 32, 3, 3, 1, 0]),
#         ("relu", [True]),
#         ("bn", [32]),
#         ("max_pool2d", [2, 2, 0]),
#         ("conv2d", [32, 32, 3, 3, 1, 0]),
#         ("relu", [True]),
#         ("bn", [32]),
#         ("max_pool2d", [2, 2, 0]),
#         ("conv2d", [32, 32, 3, 3, 1, 0]),
#         ("relu", [True]),
#         ("bn", [32]),
#         ("max_pool2d", [2, 1, 0]),
#         ("flatten", []),
#         ("linear", [n_way, 32 * 5 * 5]),
#     ]

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     maml = Meta(
#         n_way,
#         k_spt,
#         k_qry,
#         task_num,
#         update_step,
#         update_step_test,
#         update_lr,
#         meta_lr,
#         config,
#         imgc,
#         imgsz,
#     ).to(device)

#     tmp = filter(lambda x: x.requires_grad, maml.parameters())
#     num = sum(map(lambda x: np.prod(x.shape), tmp))
#     print(maml)
#     print("Total trainable tensors:", num)

#     # batchsz here means total episode number
#     mini = MiniImagenet(
#         root,
#         mode="train",
#         n_way=n_way,
#         k_shot=k_spt,
#         k_query=k_qry,
#         batchsz=10000,
#         resize=imgsz,
#     )
#     mini_test = MiniImagenet(
#         root,
#         mode="test",
#         n_way=n_way,
#         k_shot=k_spt,
#         k_query=k_qry,
#         batchsz=100,
#         resize=imgsz,
#     )

#     # Load the MNIST dataset
#     (x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

#     # Preprocess the data
#     x_train_all = x_train_all.astype(np.float32) / 255.0
#     x_test = x_test.astype(np.float32) / 255.0

#     # Split data into n parts
#     X_train, _, y_train, _ = train_test_split(
#         x_train_all, y_train_all, test_size=0.2, random_state=42
#     )
#     n_parts = len(clients)
#     part_size = len(X_train) // n_parts
#     dataset_parts = []
#     for i in range(n_parts):
#         start = i * part_size
#         end = (i + 1) * part_size
#         X_part = X_train[start:end]
#         y_part = y_train[start:end]
#         dataset_parts.append((X_part, y_part))

#     models = [maml for _ in range(len(clients))]

#     HE = Pyfhel()
#     ckks_params = {
#         "scheme": "CKKS",
#         "n": 2 ** 14,
#         "scale": 2 ** 30,
#         "qi_sizes": [60, 30, 30, 30, 60],
#     }
#     HE.contextGen(**ckks_params)
#     HE.keyGen()
#     HE.rotateKeyGen()

#     client_keys, shared_keys, client_shared_keys = setup_AES()

#     accuracies = [[] for _ in range(len(clients))]
#     losses = [[] for _ in range(len(clients))]

#     meta_batch_size = 32  # Number of tasks per meta-update
#     inner_batch_size = 5  # Number of examples per task
#     num_inner_updates = 5  # Number of inner loop updates per task
#     num_meta_updates = 100

#     cwts = [encrypt_wt(model.get_weights(), i) for i, model in enumerate(models)]
#     for e in tqdm(range(epoch)):
#         cwts = aggregate_wt(cwts)
#         wts = decrypt_weights(cwts)
#         cwts = []
#         for wt, model, dataset, i in zip(wts, models, dataset_parts, range(len(clients))):
#             model.set_weights(wt)
#             model, accuracy = train_model(model, dataset[0], dataset[1])
#             accuracies[i].append(accuracy)
#             print("Accuracies", accuracy)
#             wtarray = model.get_weights()
#             cwts.append(encrypt_wt(wtarray, i))

#     import matplotlib.pyplot as plt

#     epochs_range = range(1, epoch + 1)

#     plt.figure(figsize=(10, 5))
#     for i, client in enumerate(clients):
#         plt.plot(
#             epochs_range,
#             accuracies[i],
#             label=f"Client {client}" if client != 0 else "Aggregate",
#         )
#     plt.xlabel("Epochs")
#     plt.ylabel("Accuracy")
#     plt.legend()
#     plt.title("Accuracy for Each Client")
#     plt.show()

#     plt.figure(figsize=(10, 5))
#     for i, client in enumerate(clients):
#         plt.plot(
#             epochs_range, losses[i], label=f"Client {client}" if client != 0 else "Aggregate"
#         )
#     plt.xlabel("Epochs")
#     plt.ylabel("Loss")
#     plt.legend()
#     plt.title("Loss for Each Client")
#     plt.show()


# if __name__ == "__main__":
#     main()
# ```

# This code combines the PyTorch model with MAML and includes the rest of the code for data loading, encryption, decryption, and aggregation. The `train_model` function has been updated to work with the PyTorch model.

# Please note that you may need to adjust the paths and settings according to your specific dataset and requirements. Also, make sure you have the necessary dependencies installed, such as PyTorch, NumPy, Pyfhel, and cryptography.

In [22]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import numpy as np
from PIL import Image
import csv
import random
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import torch.nn as nn
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
from sklearn.model_selection import train_test_split
from Pyfhel import Pyfhel




HE = Pyfhel()
ckks_params = {
    "scheme": "CKKS",
    "n": 2**14,
    "scale": 2**30,
    "qi_sizes": [60, 30, 30, 30, 60],
}
HE.contextGen(**ckks_params)
HE.keyGen()
HE.rotateKeyGen()



class MiniImagenet(Dataset):
   def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
       self.batchsz = batchsz
       self.n_way = n_way
       self.k_shot = k_shot
       self.k_query = k_query
       self.setsz = self.n_way * self.k_shot
       self.querysz = self.n_way * self.k_query
       self.resize = resize
       self.startidx = startidx
       print(
           "shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d"
           % (mode, batchsz, n_way, k_shot, k_query, resize)
       )

       if mode == "train":
           self.transform = transforms.Compose(
               [
                   lambda x: Image.open(x).convert("RGB"),
                   transforms.Resize((self.resize, self.resize)),
                   transforms.ToTensor(),
                   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
               ]
           )
       else:
           self.transform = transforms.Compose(
               [
                   lambda x: Image.open(x).convert("RGB"),
                   transforms.Resize((self.resize, self.resize)),
                   transforms.ToTensor(),
                   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
               ]
           )

       self.path = os.path.join(root, "images")
       csvdata = self.loadCSV(os.path.join(root, mode + ".csv"))
       self.data = []
       self.img2label = {}
       for i, (k, v) in enumerate(csvdata.items()):
           self.data.append(v)
           self.img2label[k] = i + self.startidx
       self.cls_num = len(self.data)

       self.create_batch(self.batchsz)

   def loadCSV(self, csvf):
       dictLabels = {}
       with open(csvf) as csvfile:
           csvreader = csv.reader(csvfile, delimiter=",")
           next(csvreader, None)
           for i, row in enumerate(csvreader):
               filename = row[0]
               label = row[1]
               if label in dictLabels.keys():
                   dictLabels[label].append(filename)
               else:
                   dictLabels[label] = [filename]
       return dictLabels

   def create_batch(self, batchsz):
       self.support_x_batch = []
       self.query_x_batch = []
       for b in range(batchsz):
           selected_cls = np.random.choice(self.cls_num, self.n_way, False)
           np.random.shuffle(selected_cls)
           support_x = []
           query_x = []
           for cls in selected_cls:
               selected_imgs_idx = np.random.choice(
                   len(self.data[cls]), self.k_shot + self.k_query, False
               )
               np.random.shuffle(selected_imgs_idx)
               indexDtrain = np.array(selected_imgs_idx[: self.k_shot])
               indexDtest = np.array(selected_imgs_idx[self.k_shot :])
               support_x.append(np.array(self.data[cls])[indexDtrain].tolist())
               query_x.append(np.array(self.data[cls])[indexDtest].tolist())

           random.shuffle(support_x)
           random.shuffle(query_x)

           self.support_x_batch.append(support_x)
           self.query_x_batch.append(query_x)

   def __getitem__(self, index):
       support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
       support_y = np.zeros((self.setsz), dtype=np.int32)
       query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
       query_y = np.zeros((self.querysz), dtype=np.int32)

       flatten_support_x = [
           os.path.join(self.path, item)
           for sublist in self.support_x_batch[index]
           for item in sublist
       ]
       support_y = np.array(
           [
               self.img2label[item[:9]]
               for sublist in self.support_x_batch[index]
               for item in sublist
           ]
       ).astype(np.int32)

       flatten_query_x = [
           os.path.join(self.path, item)
           for sublist in self.query_x_batch[index]
           for item in sublist
       ]
       query_y = np.array(
           [
               self.img2label[item[:9]]
               for sublist in self.query_x_batch[index]
               for item in sublist
           ]
       ).astype(np.int32)

       unique = np.unique(support_y)
       random.shuffle(unique)
       support_y_relative = np.zeros(self.setsz)
       query_y_relative = np.zeros(self.querysz)
       for idx, l in enumerate(unique):
           support_y_relative[support_y == l] = idx
           query_y_relative[query_y == l] = idx

       for i, path in enumerate(flatten_support_x):
           support_x[i] = self.transform(path)

       for i, path in enumerate(flatten_query_x):
           query_x[i] = self.transform(path)

       return (
           support_x,
           torch.LongTensor(support_y_relative),
           query_x,
           torch.LongTensor(query_y_relative),
       )

   def __len__(self):
       return self.batchsz


class Learner(nn.Module):
   def __init__(self, config, imgc, imgsz):
       super(Learner, self).__init__()
       self.config = config
       self.vars = nn.ParameterList()
       self.vars_bn = nn.ParameterList()

       for i, (name, param) in enumerate(self.config):
           if name == "conv2d":
               w = nn.Parameter(torch.ones(*param[:4]))
               torch.nn.init.kaiming_normal_(w)
               self.vars.append(w)
               self.vars.append(nn.Parameter(torch.zeros(param[0])))
           elif name == "linear":
               w = nn.Parameter(torch.ones(*param))
               torch.nn.init.kaiming_normal_(w)
               self.vars.append(w)
               self.vars.append(nn.Parameter(torch.zeros(param[0])))
           elif name == "bn":
               w = nn.Parameter(torch.ones(param[0]))
               self.vars.append(w)
               self.vars.append(nn.Parameter(torch.zeros(param[0])))
               running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
               running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
               self.vars_bn.extend([running_mean, running_var])

   def forward(self, x, vars=None, bn_training=True):
       if vars is None:
           vars = self.vars

       idx = 0
       bn_idx = 0
       for name, param in self.config:
           if name == "conv2d":
               w, b = vars[idx], vars[idx + 1]
               x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
               idx += 2
           elif name == "linear":
               w, b = vars[idx], vars[idx + 1]
               x = F.linear(x, w, b)
               idx += 2
           elif name == "bn":
               w, b = vars[idx], vars[idx + 1]
               running_mean, running_var = (
                   self.vars_bn[bn_idx],
                   self.vars_bn[bn_idx + 1],
               )
               x = F.batch_norm(
                   x, running_mean, running_var, weight=w, bias=b, training=bn_training
               )
               idx += 2
               bn_idx += 2
           elif name == "flatten":
               x = x.view(x.size(0), -1)
           elif name == "relu":
               x = F.relu(x, inplace=param[0])
           elif name == "max_pool2d":
               x = F.max_pool2d(x, param[0], param[1], param[2])

       assert idx == len(vars)
       assert bn_idx == len(self.vars_bn)

       return x

   def parameters(self):
       return self.vars


class Meta(nn.Module):
    def __init__(
       self,
       n_way,
       k_spt,
       k_qry,
       task_num,
       update_step,
       update_step_test,
       update_lr,
       meta_lr,
       config,
       imgc,
       imgsz,
   ):
        super(Meta, self).__init__()
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.config = config
        self.imgc = imgc
        self.imgsz = imgsz

        self.net = Learner(config, imgc, imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
               map(
                   lambda p: p[1] - self.update_lr * p[0],
                   zip(grad, self.net.parameters()),
               )
           )

            with torch.no_grad():
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            with torch.no_grad():
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                fast_weights = list(
                   map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
               )

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct

        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)

        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
           map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
       )

        with torch.no_grad():
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        with torch.no_grad():
            logits_q = net(x_qry, fast_weights, bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
            )

            logits_q = net(x_qry, fast_weights, bn_training=True)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] = corrects[k + 1] + correct

        del net
        accs = np.array(corrects) / querysz
        return accs


def train_model(model, x_train, y_train):
    meta_updates = []
    accuracy_over_time = []
    for meta_update in range(num_meta_updates):
        # Sample a meta-batch of tasks
        meta_batch_indices = torch.randperm(len(x_train))[:meta_batch_size]

        # Inner loop updates for each task
        for task_index in meta_batch_indices:
            task_data = (
                x_train[task_index : task_index + inner_batch_size],
                y_train[task_index : task_index + inner_batch_size],
            )

            # Create a copy of the model for the inner loop updates
            inner_model = deepcopy(model)
            inner_model.load_state_dict(model.state_dict())

            for inner_update in range(num_inner_updates):
                # Perform inner loop update on the task-specific model
                inner_model.train_step(task_data)

            # Update the original model with the adapted weights
            model.load_state_dict(inner_model.state_dict())

        # Evaluate on the meta-test set
        _, accuracy = model.evaluate(x_test, y_test)

        # Store the meta-update step and accuracy
        meta_updates.append(meta_update + 1)
        accuracy_over_time.append(accuracy)

    avg_accuracy = sum(accuracy_over_time) / len(accuracy_over_time)
    return model, avg_accuracy


def generate_diffie_hellman_parameters():
    parameters = dh.generate_parameters(generator=2, key_size=512)
    return parameters


def generate_diffie_hellman_keys(parameters):
    private_key = parameters.generate_private_key()
    public_key = private_key.public_key()
    return private_key, public_key


def derive_key(private_key, peer_public_key):
    shared_key = private_key.exchange(peer_public_key)
    derived_key = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"handshake data",
    ).derive(shared_key)
    return derived_key


def encrypt_message_AES(key, message):
    serialized_obj = pickle.dumps(message)
    cipher = Cipher(algorithms.AES(key), modes.ECB())
    encryptor = cipher.encryptor()
    padded_obj = serialized_obj + b" " * (16 - len(serialized_obj) % 16)
    ciphertext = encryptor.update(padded_obj) + encryptor.finalize()
    return ciphertext


def decrypt_message_AES(key, ciphertext):
    cipher = Cipher(algorithms.AES(key), modes.ECB())
    decryptor = cipher.decryptor()
    padded_obj = decryptor.update(ciphertext) + decryptor.finalize()
    serialized_obj = padded_obj.rstrip(b" ")
    obj = pickle.loads(serialized_obj)
    return obj


def setup_AES():
    num_clients = len(clients)
    parameters = generate_diffie_hellman_parameters()
    server_private_key, server_public_key = generate_diffie_hellman_keys(parameters)
    client_keys = [generate_diffie_hellman_keys(parameters) for _ in range(num_clients)]
    shared_keys = [
        derive_key(server_private_key, client_public_key)
        for _, client_public_key in client_keys
    ]
    client_shared_keys = [
        derive_key(client_private_key, server_public_key)
        for client_private_key, _ in client_keys
    ]

    return client_keys, shared_keys, client_shared_keys


def encrypt_wt(wtarray, i):
    cwt = []
    for layer in wtarray:
        flat_array = layer.astype(np.float64).flatten()
        chunks = np.array_split(flat_array, (len(flat_array) + 2**10 - 1) // 2**10)
        clayer = []
        for chunk in chunks:
            ptxt = HE.encodeFrac(chunk)
            ctxt = HE.encryptPtxt(ptxt)
            clayer.append(ctxt)
        cwt.append(clayer.copy())
    ciphertext = encrypt_message_AES(client_shared_keys[i], cwt)
    return ciphertext


def aggregate_wt(encrypted_cwts):
    cwts = []
    for i, ecwt in enumerate(encrypted_cwts):
        cwts.append(decrypt_message_AES(shared_keys[i], ecwt))
    resmodel = []
    for j in range(len(cwts[0])):  # for layers
        layer = []
        for k in range(len(cwts[0][j])):  # for chunks
            tmp = cwts[0][j][k].copy()
            for i in range(1, len(cwts)):  # for clients
                tmp = tmp + cwts[i][j][k]
            tmp = tmp / len(cwts)
            layer.append(tmp)
        resmodel.append(layer)

    res = [resmodel.copy() for _ in range(len(clients))]
    return res


def decrypt_weights(res):
    decrypted_weights = []
    for client_weights, model in zip(res, models):
        decrypted_client_weights = []
        wtarray = model.get_weights()
        for layer_weights, layer in zip(client_weights, wtarray):
            decrypted_layer_weights = []
            flat_array = layer.astype(np.float64).flatten()
            chunks = np.array_split(flat_array, (len(flat_array) + 2**10 - 1) // 2**10)
            for chunk, encrypted_chunk in zip(chunks, layer_weights):
                decrypted_chunk = HE.decryptFrac(encrypted_chunk)
                original_chunk_size = len(chunk)
                decrypted_chunk = decrypted_chunk[:original_chunk_size]
                decrypted_layer_weights.append(decrypted_chunk)
            decrypted_layer_weights = np.concatenate(decrypted_layer_weights, axis=0)
            decrypted_layer_weights = decrypted_layer_weights.reshape(layer.shape)
            decrypted_client_weights.append(decrypted_layer_weights)
        decrypted_weights.append(decrypted_client_weights)
    return decrypted_weights


def main():
    epoch = 60000
    n_way = 5
    k_spt = 1
    k_qry = 15
    imgsz = 84
    imgc = 3
    task_num = 4
    meta_lr = 1e-3
    update_lr = 0.01
    update_step = 5
    update_step_test = 10

    # Set the paths to your mini-imagenet dataset
    root = "./mini-imagenet"

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    config = [
        ("conv2d", [32, 3, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 1, 0]),
        ("flatten", []),
        ("linear", [n_way, 32 * 5 * 5]),
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    # batchsz here means total episode number
    mini = MiniImagenet(
        root,
        mode="train",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=10000,
        resize=imgsz,
    )
    mini_test = MiniImagenet(
        root,
        mode="test",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=100,
        resize=imgsz,
    )

    # Load the MNIST dataset
    (x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    # Preprocess the data
    x_train_all = x_train_all.astype(np.float32) / 255.0
    x_test = x_test.astype(np.float32) / 255.0

    # Split data into n parts
    X_train, _, y_train, _ = train_test_split(
        x_train_all, y_train_all, test_size=0.2, random_state=42
    )
# Split data into n parts
    # Split data into n parts
    n_parts = len(clients)
    dataset_parts = []
    for i in range(n_parts):
        train_loader = DataLoader(
            mini, batch_size=task_num, shuffle=True, num_workers=1, pin_memory=True
        )
        dataset_parts.append(train_loader)

    models = [maml for _ in range(len(clients))]

    HE = Pyfhel()
    ckks_params = {
        "scheme": "CKKS",
        "n": 2**14,
        "scale": 2**30,
        "qi_sizes": [60, 30, 30, 30, 60],
    }
    HE.contextGen(**ckks_params)
    HE.keyGen()
    HE.rotateKeyGen()

    client_keys, shared_keys, client_shared_keys = setup_AES()

    accuracies = [[] for _ in range(len(clients))]
    losses = [[] for _ in range(len(clients))]

    meta_batch_size = 32  # Number of tasks per meta-update
    inner_batch_size = 5  # Number of examples per task
    num_inner_updates = 5  # Number of inner loop updates per task
    num_meta_updates = 100

    cwts = [encrypt_wt(model.get_weights(), i) for i, model in enumerate(models)]
    for e in tqdm(range(epoch)):
        cwts = aggregate_wt(cwts)
        wts = decrypt_weights(cwts)
        cwts = []
        for wt, model, dataset, i in zip(
            wts, models, dataset_parts, range(len(clients))
        ):
            model.set_weights(wt)
            model, accuracy = train_model(model, dataset[0], dataset[1])
            accuracies[i].append(accuracy)
            print("Accuracies", accuracy)
            wtarray = model.get_weights()
            cwts.append(encrypt_wt(wtarray, i))

    import matplotlib.pyplot as plt

    epochs_range = range(1, epoch + 1)

    plt.figure(figsize=(10, 5))
    for i, client in enumerate(clients):
        plt.plot(
            epochs_range,
            accuracies[i],
            label=f"Client {client}" if client != 0 else "Aggregate",
        )
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title("Accuracy for Each Client")
    plt.show()

    plt.figure(figsize=(10, 5))
    for i, client in enumerate(clients):
        plt.plot(
            epochs_range,
            losses[i],
            label=f"Client {client}" if client != 0 else "Aggregate",
        )
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Loss for Each Client")
    plt.show()


if __name__ == "__main__":
    main()

Meta(
  (net): Learner(
    (vars): ParameterList(
        (0): Parameter containing: [torch.float32 of size 32x3x3x3 (cuda:0)]
        (1): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (2): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (3): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (4): Parameter containing: [torch.float32 of size 32x32x3x3 (cuda:0)]
        (5): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (6): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (7): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (8): Parameter containing: [torch.float32 of size 32x32x3x3 (cuda:0)]
        (9): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (10): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (11): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (12): Parameter containing: [torch.float32 of size 32x32x3x3 (cuda:

NameError: name 'tf' is not defined

In [24]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
import numpy as np
from PIL import Image
import csv
import random
import torch.nn.functional as F
import torch.optim as optim
from copy import deepcopy
import torch.nn as nn
from tqdm import tqdm
import copy
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import dh
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
import pickle
import sys
from sklearn.model_selection import train_test_split
from Pyfhel import Pyfhel

clients = [0, 1, 2, 3]

HE = Pyfhel()
ckks_params = {
    "scheme": "CKKS",
    "n": 2**14,
    "scale": 2**30,
    "qi_sizes": [60, 30, 30, 30, 60],
}
HE.contextGen(**ckks_params)
HE.keyGen()
HE.rotateKeyGen()
class MiniImagenet(Dataset):
    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        self.batchsz = batchsz
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.setsz = self.n_way * self.k_shot
        self.querysz = self.n_way * self.k_query
        self.resize = resize
        self.startidx = startidx
        print(
            "shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d"
            % (mode, batchsz, n_way, k_shot, k_query, resize)
        )

        if mode == "train":
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    lambda x: Image.open(x).convert("RGB"),
                    transforms.Resize((self.resize, self.resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ]
            )

        self.path = os.path.join(root, "images")
        csvdata = self.loadCSV(os.path.join(root, mode + ".csv"))
        self.data = []
        self.img2label = {}
        for i, (k, v) in enumerate(csvdata.items()):
            self.data.append(v)
            self.img2label[k] = i + self.startidx
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        dictLabels = {}
        with open(csvf) as csvfile:
            csvreader = csv.reader(csvfile, delimiter=",")
            next(csvreader, None)
            for i, row in enumerate(csvreader):
                filename = row[0]
                label = row[1]
                if label in dictLabels.keys():
                    dictLabels[label].append(filename)
                else:
                    dictLabels[label] = [filename]
        return dictLabels

    def create_batch(self, batchsz):
        self.support_x_batch = []
        self.query_x_batch = []
        for b in range(batchsz):
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                selected_imgs_idx = np.random.choice(
                    len(self.data[cls]), self.k_shot + self.k_query, False
                )
                np.random.shuffle(selected_imgs_idx)
                indexDtrain = np.array(selected_imgs_idx[: self.k_shot])
                indexDtest = np.array(selected_imgs_idx[self.k_shot :])
                support_x.append(np.array(self.data[cls])[indexDtrain].tolist())
                query_x.append(np.array(self.data[cls])[indexDtest].tolist())

            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)
            self.query_x_batch.append(query_x)

    def __getitem__(self, index):
        support_x = torch.FloatTensor(self.setsz, 3, self.resize, self.resize)
        support_y = np.zeros((self.setsz), dtype=np.int32)
        query_x = torch.FloatTensor(self.querysz, 3, self.resize, self.resize)
        query_y = np.zeros((self.querysz), dtype=np.int32)

        flatten_support_x = [
            os.path.join(self.path, item)
            for sublist in self.support_x_batch[index]
            for item in sublist
        ]
        support_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.support_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        flatten_query_x = [
            os.path.join(self.path, item)
            for sublist in self.query_x_batch[index]
            for item in sublist
        ]
        query_y = np.array(
            [
                self.img2label[item[:9]]
                for sublist in self.query_x_batch[index]
                for item in sublist
            ]
        ).astype(np.int32)

        unique = np.unique(support_y)
        random.shuffle(unique)
        support_y_relative = np.zeros(self.setsz)
        query_y_relative = np.zeros(self.querysz)
        for idx, l in enumerate(unique):
            support_y_relative[support_y == l] = idx
            query_y_relative[query_y == l] = idx

        for i, path in enumerate(flatten_support_x):
            support_x[i] = self.transform(path)

        for i, path in enumerate(flatten_query_x):
            query_x[i] = self.transform(path)

        return (
            support_x,
            torch.LongTensor(support_y_relative),
            query_x,
            torch.LongTensor(query_y_relative),
        )

    def __len__(self):
        return self.batchsz


class Learner(nn.Module):
    def __init__(self, config, imgc, imgsz):
        super(Learner, self).__init__()
        self.config = config
        self.vars = nn.ParameterList()
        self.vars_bn = nn.ParameterList()

        for i, (name, param) in enumerate(self.config):
            if name == "conv2d":
                w = nn.Parameter(torch.ones(*param[:4]))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "linear":
                w = nn.Parameter(torch.ones(*param))
                torch.nn.init.kaiming_normal_(w)
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
            elif name == "bn":
                w = nn.Parameter(torch.ones(param[0]))
                self.vars.append(w)
                self.vars.append(nn.Parameter(torch.zeros(param[0])))
                running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False)
                running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False)
                self.vars_bn.extend([running_mean, running_var])

    def forward(self, x, vars=None, bn_training=True):
        if vars is None:
            vars = self.vars

        idx = 0
        bn_idx = 0
        for name, param in self.config:
            if name == "conv2d":
                w, b = vars[idx], vars[idx + 1]
                x = F.conv2d(x, w, b, stride=param[4], padding=param[5])
                idx += 2
            elif name == "linear":
                w, b = vars[idx], vars[idx + 1]
                x = F.linear(x, w, b)
                idx += 2
            elif name == "bn":
                w, b = vars[idx], vars[idx + 1]
                running_mean, running_var = (
                    self.vars_bn[bn_idx],
                    self.vars_bn[bn_idx + 1],
                )
                x = F.batch_norm(
                    x, running_mean, running_var, weight=w, bias=b, training=bn_training
                )
                idx += 2
                bn_idx += 2
            elif name == "flatten":
                x = x.view(x.size(0), -1)
            elif name == "relu":
                x = F.relu(x, inplace=param[0])
            elif name == "max_pool2d":
                x = F.max_pool2d(x, param[0], param[1], param[2])

        assert idx == len(vars)
        assert bn_idx == len(self.vars_bn)

        return x

    def parameters(self):
        return self.vars


class Meta(nn.Module):
    def __init__(
        self,
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ):
        super(Meta, self).__init__()
        self.n_way = n_way
        self.k_spt = k_spt
        self.k_qry = k_qry
        self.task_num = task_num
        self.update_step = update_step
        self.update_step_test = update_step_test
        self.update_lr = update_lr
        self.meta_lr = meta_lr
        self.config = config
        self.imgc = imgc
        self.imgsz = imgsz

        self.net = Learner(config, imgc, imgsz)
        self.meta_optim = optim.Adam(self.net.parameters(), lr=meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        task_num = x_spt.size(0)
        querysz = x_qry.size(1)

        losses_q = [0 for _ in range(self.update_step + 1)]
        corrects = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            logits = self.net(x_spt[i], vars=None, bn_training=True)
            loss = F.cross_entropy(logits, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            fast_weights = list(
                map(
                    lambda p: p[1] - self.update_lr * p[0],
                    zip(grad, self.net.parameters()),
                )
            )

            with torch.no_grad():
                logits_q = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[0] = corrects[0] + correct

            with torch.no_grad():
                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry[i]).sum().item()
                corrects[1] = corrects[1] + correct

            for k in range(1, self.update_step):
                logits = self.net(x_spt[i], fast_weights, bn_training=True)
                loss = F.cross_entropy(logits, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                fast_weights = list(
                    map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
                )

                logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_q = F.cross_entropy(logits_q, y_qry[i])
                losses_q[k + 1] += loss_q

                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, y_qry[i]).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct

        loss_q = losses_q[-1] / task_num
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        accs = np.array(corrects) / (querysz * task_num)
        return accs

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        assert len(x_spt.shape) == 4
        querysz = x_qry.size(0)
        corrects = [0 for _ in range(self.update_step_test + 1)]
        net = deepcopy(self.net)

        logits = net(x_spt)
        loss = F.cross_entropy(logits, y_spt)
        grad = torch.autograd.grad(loss, net.parameters())
        fast_weights = list(
            map(lambda p: p[1] - self.update_lr * p[0], zip(grad, net.parameters()))
        )

        with torch.no_grad():
            logits_q = net(x_qry, net.parameters(), bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[0] = corrects[0] + correct

        with torch.no_grad():
            logits_q = net(x_qry, fast_weights, bn_training=True)
            pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
            correct = torch.eq(pred_q, y_qry).sum().item()
            corrects[1] = corrects[1] + correct

        for k in range(1, self.update_step_test):
            logits = net(x_spt, fast_weights, bn_training=True)
            loss = F.cross_entropy(logits, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(
                map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights))
            )

            logits_q = net(x_qry, fast_weights, bn_training=True)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, y_qry).sum().item()
                corrects[k + 1] = corrects[k + 1] + correct

        del net
        accs = np.array(corrects) / querysz
        return accs


def train_model(model, train_loader):
    for meta_update in range(num_meta_updates):
        # Sample a meta-batch of tasks
        x_spt, y_spt, x_qry, y_qry = next(iter(train_loader))
        x_spt, y_spt, x_qry, y_qry = (
            x_spt.to(device),
            y_spt.to(device),
            x_qry.to(device),
            y_qry.to(device),
        )

        # Perform meta-update
        accs = model(x_spt, y_spt, x_qry, y_qry)

    return model, accs[-1]


def generate_diffie_hellman_parameters():
    parameters = dh.generate_parameters(generator=2, key_size=512)
    return parameters


def generate_diffie_hellman_keys(parameters):
    private_key = parameters.generate_private_key()
    public_key = private_key.public_key()
    return private_key, public_key


def derive_key(private_key, peer_public_key):
    shared_key = private_key.exchange(peer_public_key)
    derived_key = HKDF(
        algorithm=hashes.SHA256(),
        length=32,
        salt=None,
        info=b"handshake data",
    ).derive(shared_key)
    return derived_key


def encrypt_message_AES(key, message):
    serialized_obj = pickle.dumps(message)
    cipher = Cipher(algorithms.AES(key), modes.ECB())
    encryptor = cipher.encryptor()
    padded_obj = serialized_obj + b" " * (16 - len(serialized_obj) % 16)
    ciphertext = encryptor.update(padded_obj) + encryptor.finalize()
    return ciphertext


def decrypt_message_AES(key, ciphertext):
    cipher = Cipher(algorithms.AES(key), modes.ECB())
    decryptor = cipher.decryptor()
    padded_obj = decryptor.update(ciphertext) + decryptor.finalize()
    serialized_obj = padded_obj.rstrip(b" ")
    obj = pickle.loads(serialized_obj)
    return obj


def setup_AES():
    num_clients = len(clients)
    parameters = generate_diffie_hellman_parameters()
    server_private_key, server_public_key = generate_diffie_hellman_keys(parameters)
    client_keys = [generate_diffie_hellman_keys(parameters) for _ in range(num_clients)]
    shared_keys = [
        derive_key(server_private_key, client_public_key)
        for _, client_public_key in client_keys
    ]
    client_shared_keys = [
        derive_key(client_private_key, server_public_key)
        for client_private_key, _ in client_keys
    ]

    return client_keys, shared_keys, client_shared_keys


def encrypt_wt(state_dict, i):
    cwt = {}
    for layer_name, layer_weights in state_dict.items():
        flat_array = layer_weights.cpu().numpy().astype(np.float64).flatten()
        chunks = np.array_split(flat_array, (len(flat_array) + 2**10 - 1) // 2**10)
        clayer = []
        for chunk in chunks:
            ptxt = HE.encodeFrac(chunk)
            ctxt = HE.encryptPtxt(ptxt)
            clayer.append(ctxt)
        cwt[layer_name] = clayer
    ciphertext = encrypt_message_AES(client_shared_keys[i], cwt)
    return ciphertext


def aggregate_wt(encrypted_cwts):
    cwts = []
    for i, ecwt in enumerate(encrypted_cwts):
        cwts.append(decrypt_message_AES(shared_keys[i], ecwt))

    resmodel = {}
    for layer_name in cwts[0].keys():  # for layers
        layer = []
        for k in range(len(cwts[0][layer_name])):  # for chunks
            tmp = cwts[0][layer_name][k].copy()
            for i in range(1, len(cwts)):  # for clients
                tmp = tmp + cwts[i][layer_name][k]
            tmp = tmp / len(cwts)
            layer.append(tmp)
        resmodel[layer_name] = layer

    res = [resmodel.copy() for _ in range(len(clients))]
    return res


def decrypt_weights(res):
    decrypted_weights = []
    for client_weights, model in zip(res, models):
        decrypted_client_weights = {}
        for layer_name, layer_weights in client_weights.items():
            decrypted_layer_weights = []
            for encrypted_chunk in layer_weights:
                decrypted_chunk = HE.decryptFrac(encrypted_chunk)
                decrypted_layer_weights.append(decrypted_chunk)
            decrypted_layer_weights = np.concatenate(decrypted_layer_weights, axis=0)
            decrypted_layer_weights = torch.from_numpy(decrypted_layer_weights)
            decrypted_client_weights[layer_name] = decrypted_layer_weights
        decrypted_weights.append(decrypted_client_weights)
    return decrypted_weights


def main():
    clients = [0, 1, 2, 3]
    epoch = 60000
    n_way = 5
    k_spt = 1
    k_qry = 15
    imgsz = 84
    imgc = 3
    task_num = 4
    meta_lr = 1e-3
    update_lr = 0.01
    update_step = 5
    update_step_test = 10

    # Set the paths to your mini-imagenet dataset
    root = "./mini-imagenet"

    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    config = [
        ("conv2d", [32, 3, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 2, 0]),
        ("conv2d", [32, 32, 3, 3, 1, 0]),
        ("relu", [True]),
        ("bn", [32]),
        ("max_pool2d", [2, 1, 0]),
        ("flatten", []),
        ("linear", [n_way, 32 * 5 * 5]),
    ]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    maml = Meta(
        n_way,
        k_spt,
        k_qry,
        task_num,
        update_step,
        update_step_test,
        update_lr,
        meta_lr,
        config,
        imgc,
        imgsz,
    ).to(device)

    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print("Total trainable tensors:", num)

    # batchsz here means total episode number
    mini = MiniImagenet(
        root,
        mode="train",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=10000,
        resize=imgsz,
    )
    mini_test = MiniImagenet(
        root,
        mode="test",
        n_way=n_way,
        k_shot=k_spt,
        k_query=k_qry,
        batchsz=100,
        resize=imgsz,
    )

    # Split data into n parts
    n_parts = len(clients)
    dataset_parts = []
    for i in range(n_parts):
        train_loader = DataLoader(
            mini, batch_size=task_num, shuffle=True, num_workers=1, pin_memory=True
        )
        dataset_parts.append(train_loader)

    models = [maml for _ in range(len(clients))]

    HE = Pyfhel()
    ckks_params = {
        "scheme": "CKKS",
        "n": 2**14,
        "scale": 2**30,
        "qi_sizes": [60, 30, 30, 30, 60],
    }
    HE.contextGen(**ckks_params)
    HE.keyGen()
    HE.rotateKeyGen()

    client_keys, shared_keys, client_shared_keys = setup_AES()

    accuracies = [[] for _ in range(len(clients))]
    losses = [[] for _ in range(len(clients))]

    meta_batch_size = 32  # Number of tasks per meta-update
    inner_batch_size = 5  # Number of examples per task
    num_inner_updates = 5  # Number of inner loop updates per task
    num_meta_updates = 100

    cwts = [encrypt_wt(model.state_dict(), i) for i, model in enumerate(models)]
    for e in tqdm(range(epoch)):
        cwts = aggregate_wt(cwts)
        wts = decrypt_weights(cwts)
        cwts = []
        for wt, model, dataset, i in zip(wts, models, dataset_parts, range(len(clients))):
            model.load_state_dict(wt)
            model, accuracy = train_model(model, dataset)
            accuracies[i].append(accuracy.item())
            print("Accuracies", accuracy)

            # Store the loss value for each client
            loss = 1.0 - accuracy.item()
            losses[i].append(loss)

            wtarray = model.state_dict()
            cwts.append(encrypt_wt(wtarray, i))

    import matplotlib.pyplot as plt

    epochs_range = range(1, epoch + 1)

    plt.figure(figsize=(10, 5))
    for i, client in enumerate(clients):
        plt.plot(
            epochs_range,
            accuracies[i],
            label=f"Client {client}" if client != 0 else "Aggregate",
        )
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title("Accuracy for Each Client")
    plt.show()

    plt.figure(figsize=(10, 5))
    for i, client in enumerate(clients):
        plt.plot(
            epochs_range,
            losses[i],
            label=f"Client {client}" if client != 0 else "Aggregate",
        )
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Loss for Each Client")
    plt.show()


if __name__ == "__main__":
    main()

Meta(
  (net): Learner(
    (vars): ParameterList(
        (0): Parameter containing: [torch.float32 of size 32x3x3x3 (cuda:0)]
        (1): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (2): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (3): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (4): Parameter containing: [torch.float32 of size 32x32x3x3 (cuda:0)]
        (5): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (6): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (7): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (8): Parameter containing: [torch.float32 of size 32x32x3x3 (cuda:0)]
        (9): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (10): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (11): Parameter containing: [torch.float32 of size 32 (cuda:0)]
        (12): Parameter containing: [torch.float32 of size 32x32x3x3 (cuda:

NameError: name 'client_shared_keys' is not defined