In [50]:
import os
import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms

In [77]:
# Data root, given relative to the parent of the current working directory.
BASE_DIR = os.path.join(os.pardir, 'Neural Networks for Images')

# Tab-separated "<id>\t<description>" file and the image folder tree.
labels_path = os.path.join(BASE_DIR, 'words.txt')
image_path = os.path.join(BASE_DIR, 'imagenet-mini')

class ImageNetMiniDataset(Dataset):
    """Map-style dataset over a ``<img_dir>/<split>/<class folder>/<image>`` tree.

    ``<split>`` is ``'train'`` when ``train`` is true and ``'test'`` otherwise.
    Each sample is ``(image, label)`` where ``label`` is the name of the class
    folder that contains the image. The id -> description mapping read from
    ``annotations_file`` is stored in ``img_labels``; it is not applied to the
    labels returned by ``__getitem__``.

    Args:
        annotations_file: Path to a text file with one ``<id>\\t<description>``
            entry per line.
        img_dir: Root directory containing the split sub-directories.
        resize: Size passed to ``cv2.resize`` as ``dsize``, i.e. ``(width, height)``.
        train: Selects the ``'train'`` (True) or ``'test'`` (False) sub-directory.
        transform: Optional callable applied to the resized image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, annotations_file, img_dir, resize=(244,244), train=True, transform=None, target_transform=None):
        self.img_labels = self.read_labels(annotations_file)
        self.img_dir_paths = self.read_image_paths(img_dir,train)
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.resize_dim = resize

    def read_labels(self,path):
        """Parse the tab-separated labels file into ``{id: description}``.

        Line endings are stripped from the descriptions, and blank lines or
        lines without a tab separator are skipped instead of raising.
        """
        labels_dict = dict()
        with open(path, 'r') as f:
            for line in f:
                fields = line.rstrip('\r\n').split('\t')
                if len(fields) < 2:
                    # Blank or malformed line: nothing to map.
                    continue
                labels_dict[fields[0]] = fields[1]
        return labels_dict

    def read_image_paths(self,img_dir,train):
        """Return a list of ``(image path, class folder name)`` for one split.

        Entries are sorted because ``os.listdir`` returns names in arbitrary
        order; sorting keeps sample indices stable between runs. Anything in
        the split directory that is not a directory is ignored.
        """
        sub_path = 'train' if train else 'test'
        split_dir = os.path.join(img_dir, sub_path)
        result = []
        for class_name in sorted(os.listdir(split_dir)):
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                # Stray files (e.g. OS metadata) are not class folders.
                continue
            for file_name in sorted(os.listdir(class_dir)):
                result.append((os.path.join(class_dir, file_name), class_name))
        return result

    def __len__(self):
        """Number of image files discovered for the selected split."""
        return len(self.img_dir_paths)

    def __getitem__(self, idx):
        """Load, resize and (optionally) transform the sample at ``idx``.

        Raises:
            OSError: If OpenCV cannot read or decode the image file.
        """
        img_path, label = self.img_dir_paths[idx]
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(img_path)
        if image is None:
            raise OSError("cv2.imread could not read image file: {}".format(img_path))
        # NOTE: OpenCV decodes colour images in BGR channel order; no channel
        # reordering is done here.
        image = cv2.resize(image, self.resize_dim, interpolation=cv2.INTER_AREA)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

# Image preprocessing pipeline handed to the dataset: a single ToTensor step.
transform = transforms.Compose([transforms.ToTensor()])


class LazyDataLoader:
    """Minimal batch iterator that loads samples from a dataset on demand.

    The dataset must support ``len()`` and integer indexing, and each item
    must be an ``(x, y)`` pair where ``x`` can be combined with
    ``torch.stack`` and ``y`` is a string whose characters after the first
    parse as an integer (it is converted with ``int(y[1:])``).

    Every call to ``iter()`` starts a new pass over the dataset. Each batch is
    a tuple ``(torch.stack(xs), torch.Tensor(ys))``; the final batch holds the
    remaining samples when ``len(dataset)`` is not a multiple of
    ``batch_size``.

    Args:
        dataset: Indexable dataset as described above.
        batch_size: Maximum number of samples per batch (must be positive).
        shuffle: If true, samples are visited in a new random order
            (``np.random.shuffle``) on every pass; otherwise in index order.
        num_workers: Parallel loading is not implemented. Any non-zero value
            triggers a warning and samples are loaded sequentially.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, dataset, batch_size=10, shuffle=False, num_workers=0):
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer, got {!r}".format(batch_size))
        if num_workers != 0:
            import warnings
            warnings.warn(
                "LazyDataLoader does not implement multi-worker loading; "
                "samples will be loaded sequentially.",
                RuntimeWarning,
            )
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        full_batches, remainder = divmod(len(dataset), batch_size)
        self.total_batches = full_batches + (1 if remainder else 0)
        # Number of samples in the final, partial batch (0 when it is full).
        self.extras = remainder
        # Per-batch index arrays for the current pass; built by __iter__.
        self.counter = None
        self.current_batch = 0

    def __len__(self):
        """Number of batches in one pass over the dataset."""
        return self.total_batches

    def __iter__(self):
        """Start a new pass: compute the visiting order and reset the cursor."""
        n_samples = len(self.dataset)
        order = np.arange(n_samples)
        if self.shuffle:
            np.random.shuffle(order)
        # Slice only real indices so no batch ever contains padding or is empty.
        self.counter = [
            order[start:start + self.batch_size]
            for start in range(0, n_samples, self.batch_size)
        ]
        self.current_batch = 0
        return self

    def fetch_data(self, index):
        """Return the dataset item stored at ``index``."""
        return self.dataset[index]

    def __next__(self):
        """Load and return the next batch, or raise ``StopIteration`` at the end."""
        if self.counter is None:
            # next() was called without iter(): start the first pass implicitly.
            self.__iter__()
        if self.current_batch >= len(self.counter):
            raise StopIteration
        x_all = []
        y_all = []
        for i in self.counter[self.current_batch]:
            x, y = self.fetch_data(int(i))
            x_all.append(x)
            # Drop the leading character of the label and keep the numeric part.
            y_all.append(int(y[1:]))
        self.current_batch += 1
        return torch.stack(x_all), torch.Tensor(y_all)

In [78]:
# Training split of the dataset: images resized to 244x244, then passed through `transform`.
image_net_train_dataset = ImageNetMiniDataset(
    annotations_file=labels_path,
    img_dir=image_path,
    resize=(244, 244),
    transform=transform,
)

# --- Disabled scratch code: visual spot-check of a few samples ---
# print(len(image_net_train_dataset.img_labels))
# for i in range(25):
    # sample_idx = random.randint(0,len(image_net_train_dataset))
    # data = image_net_train_dataset[1]
    # plt.imshow(img.permute(1,2,0))
    # plt.title(label)
    # plt.show()


# --- Disabled scratch code: collect the label values seen across all batches ---
# train_dataloader = LazyDataLoader(image_net_train_dataset, batch_size=1000, shuffle=True, num_workers=0)
# counter = 0
# result = torch.tensor([])
# total_count = None
# for data in train_dataloader:
#     # print(data[0].shape)
#     # print(data[1])
#     result = torch.cat((data[1],result),dim=0)
#     counter += 1
#     print(counter)
# print(torch.unique(result))
# values = next(train_dataloader_iter)
# print(values)


(tensor([ 1440764.,  1443537.,  1484850.,  1491361.,  1494475.,  1496331.,
         1498041.,  1514668.,  1514859.,  1518878.,  1530575.,  1531178.,
         1532829.,  1534433.,  1537544.,  1558993.,  1560419.,  1580077.,
         1582220.,  1592084.,  1601694.,  1608432.,  1614925.,  1616318.,
         1622779.,  1629819.,  1630670.,  1631663.,  1632458.,  1632777.,
         1641577.,  1644373.,  1644900.,  1664065.,  1665541.,  1667114.,
         1667778.,  1669191.,  1675722.,  1677366.,  1682714.,  1685808.,
         1687978.,  1688243.,  1689811.,  1692333.,  1693334.,  1694178.,
         1695060.,  1697457.,  1698640.,  1704323.,  1728572.,  1728920.,
         1729322.,  1729977.,  1734418.,  1735189.,  1737021.,  1739381.,
         1740131.,  1742172.,  1744401.,  1748264.,  1749939.,  1751748.,
         1753488.,  1755581.,  1756291.,  1768244.,  1770081.,  1770393.,
         1773157.,  1773549.,  1773797.,  1774384.,  1774750.,  1775062.,
         1776313.,  1784675.,  179554