In [1]:
import glob
import numpy as np
import pandas as pd
import cv2
import random 
import math
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.datasets as datasets
from skimage import transform
from torch.autograd import Variable

In [3]:
class DataHandler(Dataset):
    """Image-classification dataset driven by a labels CSV and an image folder.

    Layout read by this class::

        <data_dir>/*.csv             CSV with at least 'file' and 'health' columns
        <data_dir>/bee_imgs/<file>   one image per CSV row

    Each item is ``(image, label_index)`` where ``label_index`` is the position
    of the row's 'health' value among the sorted unique 'health' values.

    After construction the instance also exposes ``train_sampler`` and
    ``test_sampler`` (``SubsetRandomSampler`` objects over disjoint index sets)
    for use with ``DataLoader``.

    Args:
        data_dir: Directory containing the labels CSV and the ``bee_imgs`` folder.
        transforms: Optional callable applied to the colour-converted image.
        bw: If True convert images to grayscale, otherwise to RGB.
        test_split: Fraction of samples (0..1) assigned to the test sampler.
        seed: Optional seed for a reproducible train/test split. With the
            default ``None`` the global NumPy RNG state is used.

    Raises:
        FileNotFoundError: If ``data_dir`` contains no ``*.csv`` file.
        ValueError: If ``test_split`` is outside ``[0, 1]``.
    """

    def __init__(self, data_dir, transforms=None, bw=False, test_split=0.2, seed=None):
        self.data_dir = data_dir
        self.transforms = transforms
        # cv2.imread yields BGR images, so both conversions start from BGR.
        self.colormap = cv2.COLOR_BGR2GRAY if bw else cv2.COLOR_BGR2RGB

        self._setup_data()

        # For splitting data
        self.test_split = test_split
        self.seed = seed
        self.split_data()

    def _setup_labels(self, data_df):
        """Build the label vocabulary and integer targets from the 'health' column."""
        self.y_labels_source = data_df['health']

        # np.unique already returns its result sorted, so class indices are
        # stable across runs for the same set of labels.
        unique_keys = np.unique(self.y_labels_source)
        self.num_classes = len(unique_keys)
        self.key_to_idx = {ele: i for i, ele in enumerate(unique_keys)}
        self.idx_to_key = {val: key for key, val in self.key_to_idx.items()}

        self.y = [self.key_to_idx[key] for key in self.y_labels_source]

    def _setup_data(self):
        """Read the labels CSV and derive targets and image file paths.

        Raises:
            FileNotFoundError: If no ``*.csv`` file exists in ``data_dir``.
        """
        # Setup y (target output)
        csv_candidates = glob.glob(f"{self.data_dir}/*.csv")
        if not csv_candidates:
            raise FileNotFoundError(
                f"No labels CSV (*.csv) found in data directory: {self.data_dir}"
            )
        # The first match is used; glob order is arbitrary if several CSVs exist.
        labels_csv = csv_candidates[0]
        data_df = pd.read_csv(labels_csv)
        self._setup_labels(data_df)

        # Setup X (model input)
        self.X_filenames = data_df['file']
        self.X_filepaths = [f"{self.data_dir}/bee_imgs/{f}" for f in self.X_filenames]

        self.num_files = len(self.X_filepaths)

    def __getitem__(self, idx):
        """Load, colour-convert and (optionally) transform the image at ``idx``.

        Returns:
            Tuple ``(img, y)`` with the processed image and its integer label.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        img_path = self.X_filepaths[idx]
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        img = cv2.cvtColor(img, self.colormap)

        y = self.y[idx]

        if self.transforms is not None:
            img = self.transforms(img)

        return img, y

    def __len__(self):
        """Return the number of samples listed in the labels CSV."""
        return self.num_files

    # Splitting data into test/train
    def split_data(self, seed=None):
        """Shuffle all indices and build disjoint train/test samplers.

        Args:
            seed: Optional seed overriding the one given at construction. When
                both are ``None`` the global NumPy RNG is used.

        Raises:
            ValueError: If ``self.test_split`` is outside ``[0, 1]``.
        """
        if not 0 <= self.test_split <= 1:
            raise ValueError(
                f"test_split must be between 0 and 1, got {self.test_split!r}"
            )
        if seed is None:
            seed = getattr(self, "seed", None)

        indices = list(range(self.num_files))

        # Shuffle
        if seed is None:
            np.random.shuffle(indices)
        else:
            # A dedicated generator keeps the split reproducible without
            # touching the global RNG state.
            np.random.default_rng(seed).shuffle(indices)

        # Split
        num_test = int(len(indices) * self.test_split)
        test_set_indices = indices[:num_test]
        train_set_indices = indices[num_test:]

        # Sampler
        self.train_sampler = SubsetRandomSampler(train_set_indices)
        self.test_sampler = SubsetRandomSampler(test_set_indices)

In [4]:
class CustomNormalize:
    """Callable transform applying the affine map ``x -> (x * 2) - 1``.

    The map sends 0 to -1 and 1 to 1, so inputs in ``[0, 1]`` land in
    ``[-1, 1]``. The input type is not constrained here; anything that
    supports ``*`` and ``-`` with integers works.
    NOTE(review): presumably used after a step that scales images to [0, 1] -
    confirm against the transform pipeline.
    """

    def __call__(self, x):
        """Return ``x`` doubled and then shifted down by one."""
        return (x * 2) - 1