In [49]:
import os
from collections import OrderedDict, defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

In [67]:
class ClientDataset(Dataset):
    """In-memory dataset over one client's pre-loaded image and label tensors.

    Args:
        img_tensors: Indexable collection of images, indexed along its first axis.
        lbl_tensors: Tensor of labels; its first dimension defines ``len(dataset)``.
        transform: Optional callable applied to each image returned by
            ``__getitem__``. ``None`` (default) returns images unchanged.
    """

    def __init__(self, img_tensors, lbl_tensors, transform=None):
        self.img_tensors = img_tensors
        self.lbl_tensors = lbl_tensors
        self.transform = transform

    def __len__(self):
        return self.lbl_tensors.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx``, with ``transform`` applied to the image."""
        # Samplers may hand over tensor indices; convert to plain Python ints/lists.
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img = self.img_tensors[idx]
        # The transform was previously stored but never used; apply it here.
        if self.transform is not None:
            img = self.transform(img)
        return img, self.lbl_tensors[idx]
    
def create_client_data_loaders(client_nums, data_folder, batch_size, random_mode=False, transform=None):
    """Build one ``DataLoader`` per client from tensors saved on disk.

    For each client index ``idx`` in ``range(client_nums)`` this loads
    ``client_{idx}_img.pt`` and ``client_{idx}_lbl.pt`` from ``data_folder``
    and wraps them in a ``ClientDataset`` plus a ``DataLoader``.

    Args:
        client_nums: Number of clients; files for indices ``0 .. client_nums-1``
            must exist in ``data_folder``.
        data_folder: Directory (str or path-like) holding the per-client ``.pt``
            files. A trailing path separator is optional.
        batch_size: Batch size used for every client's ``DataLoader``.
        random_mode: If True, each ``DataLoader`` shuffles its samples.
        transform: Optional per-image callable forwarded to ``ClientDataset``.

    Returns:
        List of ``DataLoader`` objects; position ``i`` belongs to client ``i``.
    """
    data_loaders = []
    for idx in range(client_nums):
        # os.path.join works whether or not data_folder ends with a separator,
        # unlike plain string concatenation.
        img_tensor_file = os.path.join(data_folder, f'client_{idx}_img.pt')
        lbl_tensor_file = os.path.join(data_folder, f'client_{idx}_lbl.pt')
        # NOTE(review): torch.load unpickles the file -- only load trusted data.
        img_tensors = torch.load(img_tensor_file)
        lbl_tensors = torch.load(lbl_tensor_file)

        client_dataset = ClientDataset(img_tensors, lbl_tensors, transform=transform)
        data_loaders.append(DataLoader(client_dataset, batch_size=batch_size, shuffle=random_mode))
    return data_loaders

In [78]:
# Load the per-client image/label tensors and wrap each client's data in a DataLoader.
# Set the FED_DATA_DIR environment variable to point at another data location
# without editing the notebook; otherwise the original per-user path is used.
username = 'fnx11'
data_folder = os.path.join(
    os.environ.get('FED_DATA_DIR', f'/home/{username}/thesis/codes/Playground/data/fed_data/'),
    '',  # joining '' guarantees a trailing separator before file names are appended
)
client_nums = 20
batch_size = 8
client_data_loaders = create_client_data_loaders(client_nums, data_folder, batch_size)

### Inspect a sample batch from one client's loader

In [79]:
import matplotlib.pyplot as plt
%matplotlib inline

# helper function to un-normalize and display an image
def imshow(img):
    """Invert a ``(x - 0.5) / 0.5`` normalization and draw ``img`` with matplotlib.

    The first axis of ``img`` is moved to the end before plotting. The batches
    printed later in this notebook are ``(N, 3, 32, 32)``, so each ``img`` is
    presumably ``(C, H, W)`` and becomes ``(H, W, C)`` here -- confirm for other data.
    """
    restored = 0.5 + 0.5 * img
    channels_last = np.transpose(restored, (1, 2, 0))
    plt.imshow(channels_last)

def visualize(images, labels, class_names=None, n_images=8):
    """Plot up to ``n_images`` images from a batch, titled with their class names.

    Args:
        images: Batch tensor of images. Each item is passed to ``imshow``, which
            moves its first axis to the end (presumably ``(C, H, W)`` -> ``(H, W, C)``).
        labels: Sequence of integer class indices aligned with ``images``.
        class_names: Sequence mapping class index -> display name. Defaults to the
            notebook-level ``classes`` list, looked up at call time.
        n_images: Maximum number of images to draw (default 8). Fewer are drawn
            when the batch is smaller, e.g. the last, partial batch of a loader.
    """
    if class_names is None:
        class_names = classes  # notebook global, defined in a later cell
    # detach/cpu make .numpy() safe for tensors that require grad or live on GPU.
    images = images.detach().cpu().numpy()
    n_shown = min(n_images, len(images))
    # Two rows; the column count must be an int -- matplotlib rejects floats like 8/2.
    n_cols = max(1, (n_images + 1) // 2)
    fig = plt.figure(figsize=(25, 4))
    for idx in range(n_shown):
        ax = fig.add_subplot(2, n_cols, idx + 1, xticks=[], yticks=[])
        imshow(images[idx])
        ax.set_title(class_names[int(labels[idx])])
    

In [84]:
# Pull one batch from a single client's loader to sanity-check tensor shapes.
client_id = 15
client_iter = iter(client_data_loaders[client_id])
# Class index -> name; visualize() looks this list up to title each image.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
# next() raises StopIteration on an empty loader instead of silently skipping.
images, labels = next(client_iter)
assert images.shape[0] == labels.shape[0], "image/label batch sizes differ"
# To plot this batch: visualize(images, labels)
print(images.shape)
print(labels.shape)

torch.Size([8, 3, 32, 32])
torch.Size([8])
