# In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

# Creating a dataset to be trained on
class ImgDataset:
    """
    Img Dataset.

    Wraps ``torchvision.datasets.ImageFolder`` for a directory that contains
    ``train`` and ``val`` sub-directories (each with one sub-directory per
    class). On construction it loads the raw training images once to compute
    per-channel normalization statistics, then builds train/val DataLoaders
    that resize, tensorize and normalize images with those statistics.
    """
    def __init__(self, batch_size=4, dataset_path='path'):  # Fill in the correct path for the data set
        self.batch_size = batch_size
        self.dataset_path = dataset_path
        self.train_dataset = self.get_train_numpy()
        self.x_mean, self.x_std = self.compute_train_statistics()
        self.transform = self.get_transforms()
        self.train_loader, self.val_loader = self.get_dataloaders()

    def get_train_numpy(self):
        """
        Load every training image into one float array scaled to [0, 1].

        Returns:
            np.ndarray of shape (N, H, W, 3). All training images must share
            the same H and W; ``np.stack`` raises ValueError otherwise.
        """
        train_dataset = torchvision.datasets.ImageFolder(os.path.join(self.dataset_path, 'train'))
        # With ImageFolder's default loader each sample is an RGB PIL image, so
        # np.asarray yields an (H, W, 3) uint8 array. Stacking avoids hard-coding
        # the image size.
        train_x = np.stack([np.asarray(img) for img, _ in train_dataset])
        return train_x / 255.0

    def compute_train_statistics(self):
        """
        Compute per-channel mean and std of ``self.train_dataset``.

        Reduces over the sample, row and column axes of the (N, H, W, C)
        array, so each result has shape (C,), i.e. (3,) for RGB images.
        """
        print(self.train_dataset.shape)
        x_mean = np.mean(self.train_dataset, axis=(0, 1, 2))  # per-channel mean
        x_std = np.std(self.train_dataset, axis=(0, 1, 2))  # per-channel std
        return x_mean, x_std

    def get_transforms(self):
        """Build the transform pipeline: resize -> tensor -> normalize."""
        transform_list = [
            # Resize to 32x32 (can be changed). The statistics were computed on
            # the full-size images; resizing roughly preserves channel means.
            transforms.Resize((32, 32)),
            transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], (C, H, W)
            transforms.Normalize(self.x_mean, self.x_std),  # normalize with training-set statistics
        ]
        transform = transforms.Compose(transform_list)
        return transform

    def get_dataloaders(self):
        """
        Create the train and validation DataLoaders.

        Assumes separate ``train`` and ``val`` directories under
        ``self.dataset_path``. The training loader shuffles; the validation
        loader does not.
        """
        train_set = torchvision.datasets.ImageFolder(os.path.join(self.dataset_path, 'train'), transform=self.transform)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=self.batch_size, shuffle=True)

        val_set = torchvision.datasets.ImageFolder(os.path.join(self.dataset_path, 'val'), transform=self.transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader

    def plot_image(self, image, label):
        """
        Display a normalized (C, H, W) image tensor titled with ``label``.

        The normalization from ``get_transforms`` is undone first. The result
        is clipped to [0, 1], the range ``imshow`` expects for float RGB data.
        """
        image = np.transpose(image.numpy(), (1, 2, 0))
        image = image * self.x_std.reshape(1, 1, 3) + self.x_mean.reshape(1, 1, 3)  # un-normalize
        # Float round-off can push values slightly outside [0, 1]. imshow would
        # clip anyway, but it emits a warning when it does.
        image = np.clip(image, 0.0, 1.0)
        plt.title(label)
        plt.imshow(image)
        plt.show()

    def get_semantic_label(self, label):
        """
        Map an integer class index back to its class name.

        Raises:
            KeyError: if ``label`` is not a known class index.
        """
        # Put correct labels here. ImageFolder numbers classes in sorted order
        # of the folder names, so these must match the folders on disk.
        mapping = {'COVID': 0, 'Not COVID': 1}
        reverse_mapping = {v: k for k, v in mapping.items()}
        return reverse_mapping[label]

if __name__ == '__main__':
    # Build the dataset, print the normalization statistics, and show one
    # training batch as a grid titled with its class names.
    dataset = ImgDataset()
    print(dataset.x_mean, dataset.x_std)
    # Iterators have no `.next()` method in Python 3, and recent PyTorch
    # removed that alias from DataLoader iterators; use the builtin next().
    images, labels = next(iter(dataset.train_loader))
    dataset.plot_image(
        torchvision.utils.make_grid(images),
        ', '.join([dataset.get_semantic_label(label.item()) for label in labels])
    )
