In [140]:
import os
import torch
from abc import ABC
from pathlib import Path
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import TensorDataset,DataLoader,random_split
from torchvision import transforms

from typing import Optional, Union, Tuple

class BasicDataSet(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
    
class DictDataSet(Dataset):
    """
    # Define your data dictionary
    data_dict = {'input': torch.randn(2, 10), 'target': torch.randn(2, 5)}

    # Create your dataset
    my_dataset = DictDataSet(data_dict)

    # Create a DataLoader from your dataset
    batch_size = 2
    dataloader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)
    """
    def __init__(self, data_dict,transforms=None,key_of_transforms=None):
        self.data_dict = data_dict
        self.keys = list(data_dict.keys())
        if transforms is not None:
            self.transforms = transforms
            self.key_of_transforms = key_of_transforms
        print(self.keys)
        print(self.key_of_transforms)
        
    def __len__(self):
        return len(self.data_dict[self.keys[0]])

    def __getitem__(self, idx):
        batch_dict = {}
        for key in self.keys:
            if key == self.key_of_transforms:
                batch_dict[key] = self.transforms(self.data_dict[key][idx]).squeeze()
            else:
                batch_dict[key] = self.data_dict[key][idx]
        return batch_dict

    
with open("./data/simulations_metadata.cp","rb") as file:
    simulations_metadata = pickle.load(file)
    
with open("./data/client_metadata.cp","rb") as file:
    client_metadata = pickle.load(file)

from torchvision import transforms

class BaseDataLoader(ABC):

    name_="base_data_loader"

    def __init__(self,X:Union[torch.Tensor,dict]=None,type="client",batch_size:int = 32,training_proportion:float = 0.9,device:torch.device=torch.device("cpu"),rank:int=0,**kwargs):
        super(BaseDataLoader,self).__init__()
        self.training_proportion = training_proportion
        self.batch_size = batch_size
        
        if type == "client":
            self.transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[client_metadata["rescaling"][0]], std=[client_metadata["rescaling"][0]]),
                transforms.Resize((11, 24))
            ])
            self.key_of_transforms = "y_diff"
        elif type== "simulations":
            self.transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[simulations_metadata["rescaling"][0]], std=[simulations_metadata["rescaling"][0]]),
                transforms.Resize((11, 24))
            ])
            self.key_of_transforms = "nds"
            
        self.define_dataset_and_dataloaders(X)
        
    def define_dataset_and_dataloaders(self,X,training_proportion=None,batch_size=None):
        if training_proportion is not None:
            self.training_proportion=training_proportion
        if batch_size is not None:
            self.batch_size = batch_size

        if isinstance(X,torch.Tensor):
            dataset = TensorDataset(X)
        elif isinstance(X,dict):
            dataset = DictDataSet(X,self.transforms,self.key_of_transforms)

        self.total_data_size = len(dataset)
        self.training_data_size = int(self.training_proportion * self.total_data_size)
        self.test_data_size = self.total_data_size - self.training_data_size

        training_dataset, test_dataset = random_split(dataset, [self.training_data_size, self.test_data_size])
        self._train_iter = DataLoader(training_dataset, batch_size=self.batch_size)
        self._test_iter = DataLoader(test_dataset, batch_size=self.batch_size)

    def train(self):
        return self._train_iter

    def test(self):
        return self._test_iter


In [165]:
import os
import pickle

with open("./data/client_.cp","rb") as file:
    clients_data = pickle.load(file)
    #clients_data["y_diff"] = torch.Tensor(clients_data["y_diff"])
with open("./data/simulations_.cp","rb") as file:
    simulations_data = pickle.load(file)
    #simulations_data["nds"] = torch.Tensor(simulations_data["nds"])

client_data_loader = BaseDataLoader(clients_data,type="client")
simulations_data_loader = BaseDataLoader(simulations_data,type="simulations")

['names', 'y_diff', 'PoDm_dist']
y_diff
['nds', 'PoDmD', 'name']
nds


In [168]:
client_databatch = next(client_data_loader.train().__iter__())
simulations_databatch = next(simulations_data_loader.train().__iter__())

In [169]:
databatch['y_diff'].shape

torch.Size([32, 11, 24])

In [171]:
databatch["nds"].shape

torch.Size([32, 11, 24])

In [82]:
.shape

torch.Size([32, 88, 24])

In [None]:
transforms.Compose([])