In [1]:
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch

In [2]:
class data_loader(Dataset):
    """
    Dataset serving (input, target) tensor pairs stored as ``.pt`` files.

    For every index ``i`` in ``ind`` the files ``model_input (i).pt`` and
    ``model_target (i).pt`` inside ``path`` form one sample. Only the file
    paths are recorded at construction time; the files themselves are read
    each time a sample is requested.
    """
    def __init__(self, path: str, ind: list, device):
        """
        Args:
            path (str): path to the input & target folder.
            ind (list): list of indices for which samples to load.
            device (class 'torch.device'): which pytorch device the data should
            be sent to.
        """
        self.device = device
        self.imgs_path = path
        # One [input_path, target_path] pair per requested index.
        self.data = [
            [f"{path}/model_input ({i}).pt", f"{path}/model_target ({i}).pt"]
            for i in ind
        ]

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.data)

    def _load_as_float(self, file_path):
        """Read one tensor file and return it as a float tensor on ``self.device``."""
        # map_location remaps the storage while deserialising, so a file that
        # was saved from a CUDA session can still be opened on a CPU-only
        # machine (and is not first materialised on the device it was saved from).
        # NOTE(review): torch.load unpickles the file - only open trusted data.
        tensor = torch.load(file_path, map_location=self.device)
        return tensor.to(device=self.device, dtype=torch.float)

    def __getitem__(self, idx):
        """
        Load the sample at position ``idx``.

        Returns:
            tuple: ``(input, target)``, both cast to ``torch.float`` and
            placed on ``self.device``.
        """
        input_path, target_path = self.data[idx]
        return self._load_as_float(input_path), self._load_as_float(target_path)

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [4]:
# Folder holding the "model_input (i).pt" / "model_target (i).pt" files.
# NOTE(review): machine-specific absolute path - point this at your own copy of the data.
DATA_DIR = "C:/Users/Marc/Desktop/model_data"
N_SAMPLES = 310  # sample files are indexed 1..N_SAMPLES (inclusive)

trainload = data_loader(path=DATA_DIR,
                        ind=list(range(1, N_SAMPLES + 1)),
                        device=device)
batch_size = 1

# Set up the dataloader: samples are served one at a time, in index order
# (no shuffling), and loaded in the main process (no worker processes).
trainloader = torch.utils.data.DataLoader(trainload,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=0)

In [7]:
# Record, for every batch, the sizes of dimensions 1 and 2 of the input tensor
# (dimension 0 is the batch dimension added by the DataLoader).
data = []

for batch in trainloader:
    input_shape = batch[0].shape  # batch[0] is the input, batch[1] the target
    data.append([input_shape[1], input_shape[2]])

In [10]:
# How many samples have a second recorded dimension larger than 90000?
len([dims for dims in data if dims[1] > 90000])

281

In [11]:
# Total number of recorded samples.
len(data)

310

In [6]:
# Show the recorded [dim 1, dim 2] pairs ordered by their second entry (ascending).
sorted(data, key=lambda dims: dims[1])

[[22, 16750],
 [22, 35250],
 [2, 48384],
 [22, 57344],
 [6, 60160],
 [5, 62720],
 [4, 63744],
 [6, 65280],
 [20, 74000],
 [8, 75264],
 [16, 76800],
 [22, 76800],
 [22, 76800],
 [4, 76800],
 [5, 76800],
 [22, 76800],
 [22, 76800],
 [22, 76800],
 [18, 76800],
 [17, 76800],
 [13, 76800],
 [13, 76800],
 [22, 76800],
 [21, 76800],
 [22, 76800],
 [4, 76800],
 [3, 76800],
 [6, 78080],
 [10, 90000],
 [22, 94464],
 [22, 100500],
 [20, 110800],
 [20, 121600],
 [20, 122000],
 [22, 138750],
 [20, 141600],
 [20, 150250],
 [22, 150250],
 [19, 150250],
 [2, 150250],
 [19, 150250],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [21, 153856],
 [22, 153856],
 [6, 153856],
 [4, 153856],
 [4, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [16, 153856],
 [17, 153856],
 [20, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [22, 153856],
 [15, 153856],
 [22, 153856],
 [21, 153856],
 [22, 153856],
 [2, 153856],
 [10, 15385

In [36]:
# Sum of the second recorded dimension over all samples.
sum(second for _, second in data)

93513546