In [1]:
import os
import csv

from PIL import Image as im

import torch

import torchvision
from torchvision import transforms

import pytorch_lightning as pl

In [2]:
class PLDataModule(pl.LightningDataModule):
    """Lightning data module that loads an image-folder dataset into memory.

    Layout read from ``data_root``:

    * ``monkey_labels.txt`` -- CSV file whose first row names the columns;
      the ``Label`` column lists the per-class directory names.
    * ``training/training/<label>/`` and ``validation/validation/<label>/``
      -- one directory of ``.jpg`` / ``.jpeg`` / ``.png`` images per class.

    Args:
        data_root: Root directory of the already-extracted dataset.
        img_size: Side length, in pixels, of the square images produced.
    """

    # Lower-case file extensions that are treated as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_root, img_size):
        # The base class creates internal state of its own; skipping this call
        # leaves the data module only partially initialised.
        super().__init__()
        self.root = data_root
        self.metadata = self.get_metadata()
        self.img_size = img_size
        # ToTensor runs first, so Resize and CenterCrop operate on tensors.
        # The image is resized 20 px larger than the target on each side
        # length and then centre-cropped down to img_size x img_size.
        self.common_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((img_size + 20, img_size + 20),
                              transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(img_size)
        ])

    def prepare_data(self):
        """No-op: the dataset is expected to be downloaded and extracted already."""
        # There is nothing to prepare.
        # The data should have already been downloaded and
        # extracted in the root that is passed to initialize this class.
        pass

    def setup(self, stage='fit'):
        """Load every image of one split into memory.

        Args:
            stage: ``'fit'`` loads the training split and stores it in
                ``train_imgs_tensor`` / ``train_labels_tensor``. Any other
                value loads the validation split and stores it in
                ``val_imgs_tensor`` / ``val_labels_tensor``.

        The image tensor has shape ``(N, 3, img_size, img_size)``; the label
        tensor has shape ``(N,)`` and dtype ``torch.long``.
        """
        data_dict = self.get_img_path_labels(stage)
        paths = data_dict['image_path']
        imgs_tensor = torch.empty([len(paths), 3, self.img_size, self.img_size])
        # Labels are class indices, so store them as integers rather than in
        # an uninitialised float buffer.
        labels_tensor = torch.tensor(data_dict['label'], dtype=torch.long)

        for i, path in enumerate(paths):
            # The context manager closes the file handle promptly.
            # convert('RGB') guarantees three channels, so grayscale, palette
            # and RGBA images all fit the pre-allocated 3-channel slot.
            with im.open(path) as raw_img:
                imgs_tensor[i] = self.common_transforms(raw_img.convert('RGB'))

        if stage == 'fit':
            self.train_imgs_tensor = imgs_tensor
            self.train_labels_tensor = labels_tensor
        else:
            self.val_imgs_tensor = imgs_tensor
            self.val_labels_tensor = labels_tensor

    #############################
    ## Miscellaneous functions ##
    #############################
    def get_img_path_labels(self, stage='fit'):
        """Collect image paths and integer labels for one split.

        Args:
            stage: ``'fit'`` selects ``training/training``; anything else
                selects ``validation/validation``.

        Returns:
            dict with two parallel lists: ``'image_path'`` (file paths) and
            ``'label'`` (the integer that follows the first character of the
            class directory name, e.g. ``'n7'`` -> ``7``).
        """
        data_dict = {'image_path': [], 'label': []}
        split = 'training' if stage == 'fit' else 'validation'
        dir_ = os.path.join(self.root, split, split)

        for cl in self.metadata['Label']:
            class_path = os.path.join(dir_, cl)
            # Parse all digits after the prefix character so that labels with
            # more than one digit (e.g. 'n10') are not truncated.
            label = int(cl[1:])
            # os.listdir returns entries in arbitrary order; sort them so the
            # sample order is reproducible between runs and machines.
            for img in sorted(os.listdir(class_path)):
                if img.lower().endswith(self._IMAGE_EXTENSIONS):
                    data_dict['image_path'].append(os.path.join(class_path, img))
                    data_dict['label'].append(label)

        return data_dict

    def get_metadata(self):
        """Read ``monkey_labels.txt`` into a column-oriented dict.

        Returns:
            dict mapping each stripped header name to the list of stripped
            cell values found in that column, in file order. An empty file
            yields an empty dict; blank rows are skipped.
        """
        metadata = dict()
        labels_path = os.path.join(self.root, 'monkey_labels.txt')
        # newline='' lets the csv module do its own line-ending handling.
        with open(labels_path, mode='r', newline='') as file:
            csv_file = csv.reader(file)
            header_row = next(csv_file, None)
            if header_row is None:
                return metadata

            headers = [h.strip() for h in header_row]
            for h in headers:
                metadata[h] = []

            for row in csv_file:
                cells = [c.strip() for c in row]
                if not any(cells):
                    # Blank or whitespace-only line.
                    continue
                # zip stops at the shorter sequence, so a row with more cells
                # than headers no longer raises IndexError.
                for h, cell in zip(headers, cells):
                    metadata[h].append(cell)
        return metadata

In [3]:
# Build the data module for the dataset stored under ./../monkeys/.
d = PLDataModule(data_root='./../monkeys/', img_size=224)