In [1]:
import csv
import os
import random
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

from MiniImagenet import MiniImagenet

In [None]:
class MiniImagenet(Dataset):
    """Episodic (N-way, K-shot) few-shot dataset built from a CSV of image files.

    The CSV at ``<root>/<mode>.csv`` is read as ``filename,label`` rows (the
    first row is skipped as a header) and the images are loaded from
    ``<root>/images``. ``batchsz`` episodes are sampled once at construction
    time; item ``i`` of the dataset is episode ``i``.

    Each item is a tuple ``(support_x, support_y, query_x, query_y)`` where the
    image tensors have shape ``(n, 3, resize, resize)`` and the label tensors
    are int64 with values in ``[0, n_way)`` that are only meaningful within
    that episode.

    Args:
        root: directory containing ``images/`` and the ``<mode>.csv`` files.
        mode: CSV stem to load (e.g. ``"train"``).
        batchsz: number of episodes to pre-sample; also ``len(dataset)``.
        n_way: number of classes per episode.
        k_shot: support images per class.
        k_query: query images per class.
        resize: side length images are resized to.
        startidx: offset added to every absolute class index.
    """

    def __init__(self, root, mode, batchsz, n_way, k_shot, k_query, resize, startidx=0):
        self.batchsz = batchsz
        self.n_way = n_way
        self.k_shot = k_shot
        self.k_query = k_query
        self.setsz = self.n_way * self.k_shot
        self.querysz = self.n_way * self.k_query
        self.resize = resize
        self.startidx = startidx
        print(
            "shuffle DB :%s, b:%d, %d-way, %d-shot, %d-query, resize:%d"
            % (mode, batchsz, n_way, k_shot, k_query, resize)
        )

        # The original train / non-train branches built the exact same
        # pipeline, so a single definition serves every mode.
        self.transform = transforms.Compose(
            [
                lambda x: Image.open(x).convert("RGB"),
                transforms.Resize((self.resize, self.resize)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

        self.path = os.path.join(root, "images")
        csvdata = self.loadCSV(os.path.join(root, mode + ".csv"))
        # self.data[c] lists the filenames of class c; img2label maps the
        # label string from the CSV to that class's absolute index.
        self.data = []
        self.img2label = {}
        for position, (label, filenames) in enumerate(csvdata.items()):
            self.data.append(filenames)
            self.img2label[label] = position + self.startidx
        self.cls_num = len(self.data)

        self.create_batch(self.batchsz)

    def loadCSV(self, csvf):
        """Group the filenames in ``csvf`` by label.

        Args:
            csvf: path to a comma-separated file whose first row is a header
                and whose remaining rows are ``filename,label,...``.

        Returns:
            dict mapping each label to the list of its filenames, in file order.
        """
        dictLabels = {}
        # newline="" is what the csv module expects from the file object.
        with open(csvf, newline="") as csvfile:
            csvreader = csv.reader(csvfile, delimiter=",")
            next(csvreader, None)  # skip the header row
            for row in csvreader:
                if not row:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                filename, label = row[0], row[1]
                dictLabels.setdefault(label, []).append(filename)
        return dictLabels

    def create_batch(self, batchsz):
        """Pre-sample ``batchsz`` episodes into the ``*_x_batch`` attributes.

        Every episode holds ``n_way`` classes drawn without replacement; for
        each class ``k_shot + k_query`` distinct images are drawn and split
        into support and query. Results are stored as lists of per-class
        filename lists in ``self.support_x_batch`` / ``self.query_x_batch``.

        Raises:
            ValueError: if ``n_way`` exceeds the number of classes, or a
                sampled class has fewer than ``k_shot + k_query`` images.
        """
        self.support_x_batch = []
        self.query_x_batch = []
        per_class = self.k_shot + self.k_query

        if batchsz > 0 and self.n_way > self.cls_num:
            raise ValueError(
                "n_way=%d exceeds the %d classes available" % (self.n_way, self.cls_num)
            )

        for _ in range(batchsz):
            selected_cls = np.random.choice(self.cls_num, self.n_way, False)
            np.random.shuffle(selected_cls)
            support_x = []
            query_x = []
            for cls in selected_cls:
                filenames = self.data[cls]
                if len(filenames) < per_class:
                    raise ValueError(
                        "class index %d has only %d images; need k_shot + k_query = %d"
                        % (cls, len(filenames), per_class)
                    )
                selected_imgs_idx = np.random.choice(len(filenames), per_class, False)
                np.random.shuffle(selected_imgs_idx)
                # The first k_shot picks form the support set, the rest the query set.
                support_x.append([filenames[j] for j in selected_imgs_idx[: self.k_shot]])
                query_x.append([filenames[j] for j in selected_imgs_idx[self.k_shot :]])

            # Class order is shuffled independently for support and query.
            random.shuffle(support_x)
            random.shuffle(query_x)

            self.support_x_batch.append(support_x)
            self.query_x_batch.append(query_x)

    def __getitem__(self, index):
        """Load episode ``index``.

        Returns:
            ``(support_x, support_y, query_x, query_y)``: float32 image tensors
            of shape ``(setsz | querysz, 3, resize, resize)`` and int64 label
            tensors whose values are episode-relative class ids.
        """
        support_files = [name for group in self.support_x_batch[index] for name in group]
        query_files = [name for group in self.query_x_batch[index] for name in group]

        # NOTE(review): the label is recovered as the first 9 characters of the
        # filename, so filenames must start with their CSV label -- confirm
        # this holds for every CSV used with this class.
        support_y = np.array(
            [self.img2label[name[:9]] for name in support_files]
        ).astype(np.int32)
        query_y = np.array(
            [self.img2label[name[:9]] for name in query_files]
        ).astype(np.int32)

        # Map the absolute class indices of this episode onto 0..n_way-1 in a
        # random order, using the same mapping for support and query.
        unique = np.unique(support_y)
        random.shuffle(unique)
        support_y_relative = np.zeros(self.setsz, dtype=np.int64)
        query_y_relative = np.zeros(self.querysz, dtype=np.int64)
        for relative, absolute in enumerate(unique):
            support_y_relative[support_y == absolute] = relative
            query_y_relative[query_y == absolute] = relative

        # Buffers are uninitialised; every slot is overwritten below.
        support_x = torch.empty(self.setsz, 3, self.resize, self.resize, dtype=torch.float32)
        query_x = torch.empty(self.querysz, 3, self.resize, self.resize, dtype=torch.float32)
        for slot, name in enumerate(support_files):
            support_x[slot] = self.transform(os.path.join(self.path, name))
        for slot, name in enumerate(query_files):
            query_x[slot] = self.transform(os.path.join(self.path, name))

        return (
            support_x,
            torch.from_numpy(support_y_relative),
            query_x,
            torch.from_numpy(query_y_relative),
        )

    def __len__(self):
        """Number of pre-sampled episodes."""
        return self.batchsz

In [6]:
import os
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch.utils.data import DataLoader, Dataset, Subset
from torch.utils.data import random_split, SubsetRandomSampler
from torchvision import datasets, transforms, models 
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
import pytorch_lightning as pl
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from PIL import Image


In [7]:
# Dataset location; the training CSV sits directly under the dataset root.
root_dir = "./mini-imagenet"
csv_file = f"{root_dir}/train.csv"

In [8]:
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import pandas as pd
import os

In [9]:
# One entry per client; the partitioning cells split the data into len(clients) parts.
clients = list(range(3))

In [None]:
# Split (X_train, y_train) into len(clients) contiguous, equally sized parts.
# NOTE(review): X_train / y_train are not defined in any visible cell -- confirm
# where they come from before running this on a fresh kernel.
n_parts = len(clients)
part_size = len(X_train) // n_parts
dataset_parts = []
for i in range(n_parts):
    # Window i covers [start, end); rows past n_parts * part_size are not assigned.
    start = part_size * i
    end = start + part_size
    X_part, y_part = X_train[start:end], y_train[start:end]
    dataset_parts.append((X_part, y_part))

In [17]:
# Shuffle the rows of the training CSV and split them evenly across the clients.
# Reads `csv_file` and `clients` defined in earlier cells.
SEED = 42  # fixed so the shuffle, and hence each client's partition, is reproducible

data = pd.read_csv(csv_file)
data = data.sample(frac=1, random_state=SEED).reset_index(drop=True)

n_parts = len(clients)
dataset_parts = []
part_size = len(data) // n_parts
for i in range(n_parts):
    # Equal-sized contiguous windows; the trailing len(data) % n_parts rows
    # are left unassigned so every client gets exactly part_size rows.
    start = i * part_size
    end = start + part_size
    dataset_parts.append(data.iloc[start:end].values)

In [18]:
# Inspect the per-client partitions: one object array of [filename, label] rows per client.
dataset_parts

[array([['n0774760700001069.jpg', 'n07747607'],
        ['n0450941700000760.jpg', 'n04509417'],
        ['n0383889900001242.jpg', 'n03838899'],
        ...,
        ['n0461250400000001.jpg', 'n04612504'],
        ['n0279516900001176.jpg', 'n02795169'],
        ['n0388860500000226.jpg', 'n03888605']], dtype=object),
 array([['n0425813800001252.jpg', 'n04258138'],
        ['n0388860500000014.jpg', 'n03888605'],
        ['n0282342800000353.jpg', 'n02823428'],
        ...,
        ['n0425813800000337.jpg', 'n04258138'],
        ['n0211127700000559.jpg', 'n02111277'],
        ['n0210855100000336.jpg', 'n02108551']], dtype=object),
 array([['n0279516900000399.jpg', 'n02795169'],
        ['n0216545600000178.jpg', 'n02165456'],
        ['n0450941700001187.jpg', 'n04509417'],
        ...,
        ['n0367648300000751.jpg', 'n03676483'],
        ['n1305456000001232.jpg', 'n13054560'],
        ['n0424354600001190.jpg', 'n04243546']], dtype=object)]

In [None]:
# Wrap features and targets as tensors in a single TensorDataset.
# NOTE(review): X and y are not defined in any visible cell -- confirm their origin.
dataset = torch.utils.data.TensorDataset(*map(torch.tensor, (X, y)))