In [1]:
from sklearn.model_selection import train_test_split
import os,gc
import cv2
import random
import numpy as np
from PIL import Image, ImageStat
from torch.utils.data import Dataset, DataLoader,random_split
from torch import randperm,unique
from torch._utils import _accumulate
import torchvision.transforms.functional as F
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.optim import lr_scheduler
import time
from skimage import io, transform
import sklearn.metrics as skm
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch import manual_seed,zeros_like,zeros,ones,unique,autograd,device,cuda,cat,save,load,tensor,utils,rand
from torch import sum as torch_sum
# Seed torch's RNG (manual_seed is imported from torch above) and Python's
# `random` module with the same fixed value.
# NOTE(review): NumPy's RNG is not seeded in this cell -- confirm whether any
# NumPy-based randomness needs to be repeatable as well.
manual_seed(17)
random.seed(17)
# TODO: resize settings. Neither value is read by the code visible in this
# part of the file; presumably consumed further down -- verify before changing.
resize_mode = ""
resize_factor = 1

In [2]:
class AuxaliaryFunctions():
    """Static helper functions shared by the dataset classes.

    Every method is a ``@staticmethod``; call them on the class itself:
      file_extension: classify a file name by its image extension.
      get_image: open an image from disk and optionally transform it.
      apply_transformers: convert an image and its target(s) to tensors.
    """

    @staticmethod
    def file_extension(file):
        """Return the image extension of ``file``, or the string ``"null"``.

        Args:
          file: file name (or path) to inspect.

        Returns:
          ``".jpg"``, ``".jpeg"`` or ``".png"`` when ``file`` ends with that
          suffix; ``"null"`` for anything else, including names ending in
          ``"_heatmap.png"``. Matching is case-sensitive.
        """
        # Heatmap targets are PNGs too, but they must not be listed as inputs.
        # (A name ending in "_heatmap.png" can never end in .jpg/.jpeg, so
        # testing it first does not change the result for those suffixes.)
        if file.endswith("_heatmap.png"):
            return "null"
        for extension in (".jpg", ".jpeg", ".png"):
            if file.endswith(extension):
                return extension
        return "null"

    @staticmethod
    def get_image(root_dir, filename, only_image_transforms):
        """Open ``root_dir/filename`` with PIL and optionally transform it.

        Args:
          root_dir: directory that ``filename`` is relative to.
          filename: path of the image file, relative to ``root_dir``.
          only_image_transforms: callable applied to the opened image, or a
            falsy value (e.g. ``None``) to return the image untouched.

        Returns:
          The opened image, or whatever ``only_image_transforms`` returns
          for it.
        """
        img = Image.open(os.path.join(root_dir, filename))

        # Truthiness is kept on purpose so that any falsy placeholder keeps
        # meaning "no transform", exactly as before.
        if only_image_transforms:
            img = only_image_transforms(img)

        return img

    @staticmethod
    def apply_transformers(image, target, target2=None, normalize_transformer=None):
        """Convert an image and its target(s) to tensors via ``F.to_tensor``.

        Args:
          image: input image accepted by ``F.to_tensor``.
          target: target image accepted by ``F.to_tensor``.
          target2: optional second target; converted only when not ``None``.
          normalize_transformer: optional callable applied to the image
            tensor only (never to the targets).

        Returns:
          Tuple ``(image, target, target2)``; ``target2`` is ``None`` when it
          was not supplied.
        """
        image = F.to_tensor(image)
        target = F.to_tensor(target)

        # Compare with None explicitly: truth-testing an array-like target
        # with more than one element raises instead of answering "present".
        if target2 is not None:
            target2 = F.to_tensor(target2)

        # Same truthiness convention as get_image for the optional callable.
        if normalize_transformer:
            image = normalize_transformer(image)

        return image, target, target2

In [3]:
class DetectionDataset(Dataset):
    """Dataset yielding ``(image, heatmap)`` tensor pairs.

    On construction the ``dataset``, ``forceTrain`` and ``forceTest``
    subfolders of ``root_dir`` are scanned for image files (as recognised by
    ``AuxaliaryFunctions.file_extension``). For every image ``<stem><ext>``
    the companion paths ``<stem>_heatmap.png`` and ``<stem>.xml`` are recorded
    alongside it; their existence is not checked at construction time.

    Args:
      root_dir: directory containing the three subfolders.
      only_image_transforms: optional callable applied to the input image
        only (not to the heatmap) before tensor conversion.
      normalize_transformer: optional callable applied to the image tensor.

    Attributes:
      filenames, heatmap_filenames, xml_filenames: parallel lists of paths
        relative to ``root_dir``.
    """

    def __init__(self, root_dir, only_image_transforms=None, normalize_transformer=None):
        super().__init__()
        self.root_dir = root_dir
        self.only_image_transforms = only_image_transforms
        self.normalize_transformer = normalize_transformer

        self.filenames = []
        self.heatmap_filenames = []
        self.xml_filenames = []

        for subfolder in ("dataset", "forceTrain", "forceTest"):
            folder = os.path.join(root_dir, subfolder)
            # os.listdir gives entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable between runs and machines, which
            # any seeded split of this dataset relies on.
            for file in sorted(os.listdir(folder)):
                extension = AuxaliaryFunctions.file_extension(file)
                if extension == "null":
                    continue

                self.filenames.append(os.path.join(subfolder, file))
                self.heatmap_filenames.append(
                    os.path.join(subfolder, self._swap_suffix(file, extension, "_heatmap.png"))
                )
                self.xml_filenames.append(
                    os.path.join(subfolder, self._swap_suffix(file, extension, ".xml"))
                )

    @staticmethod
    def _swap_suffix(name, old_suffix, new_suffix):
        """Replace the trailing ``old_suffix`` of ``name`` with ``new_suffix``.

        Only the end of the name is touched. ``str.replace`` would also
        rewrite earlier occurrences of the extension inside the name
        (e.g. ``"a.png_v2.png"``). ``name`` must end with ``old_suffix``.
        """
        return name[:len(name) - len(old_suffix)] + new_suffix

    def __len__(self):
        """Return the number of image/heatmap pairs."""
        return len(self.heatmap_filenames)

    def set_transformers(self, normalize_transformer, only_image_transforms):
        """Replace both transform callables after construction."""
        self.normalize_transformer = normalize_transformer
        self.only_image_transforms = only_image_transforms

    def __getitem__(self, ind):
        """Return the ``(image, heatmap)`` tensors for sample ``ind``."""
        img = AuxaliaryFunctions.get_image(
            self.root_dir, self.filenames[ind], self.only_image_transforms
        )
        heatmap = Image.open(os.path.join(self.root_dir, self.heatmap_filenames[ind]))
        # The heatmap is only converted to a tensor; the normalizer is applied
        # to the image alone inside apply_transformers.
        img, heatmap, _ = AuxaliaryFunctions.apply_transformers(
            img, heatmap, normalize_transformer=self.normalize_transformer
        )
        return img, heatmap
    
class SegmentationDataset(Dataset):
    """Dataset yielding ``(image, label)`` tensor pairs for segmentation.

    On construction the ``image`` folders inside the ``dataset`` and
    ``forceTrain`` subfolders of ``root_dir`` are scanned for image files (as
    recognised by ``AuxaliaryFunctions.file_extension``). The target of
    ``<sub>/image/<stem><ext>`` is expected at ``<sub>/target/<stem>.png``;
    its existence is only checked when the sample is loaded.

    Args:
      root_dir: directory containing the subfolders.
      only_image_transforms: optional callable applied to the input image
        only (not to the label) before tensor conversion.
      normalize_transformer: optional callable applied to the image tensor.

    Attributes:
      filenames, target_filenames: parallel lists of paths relative to
        ``root_dir``.
    """

    def __init__(self, root_dir, only_image_transforms=None, normalize_transformer=None):
        super().__init__()
        self.root_dir = root_dir
        self.only_image_transforms = only_image_transforms
        self.normalize_transformer = normalize_transformer

        self.filenames = []
        self.target_filenames = []

        for subfolder in ("dataset", "forceTrain"):
            image_dir = os.path.join(root_dir, subfolder, "image")
            # os.listdir gives entries in arbitrary order; sorting keeps the
            # index -> sample mapping stable between runs and machines, which
            # any seeded split of this dataset relies on.
            for file in sorted(os.listdir(image_dir)):
                extension = AuxaliaryFunctions.file_extension(file)
                if extension == "null":
                    continue

                self.filenames.append(os.path.join(subfolder, "image", file))
                self.target_filenames.append(
                    os.path.join(subfolder, "target", self._swap_suffix(file, extension, ".png"))
                )

    @staticmethod
    def _swap_suffix(name, old_suffix, new_suffix):
        """Replace the trailing ``old_suffix`` of ``name`` with ``new_suffix``.

        Only the end of the name is touched. ``str.replace`` would also
        rewrite earlier occurrences of the extension inside the name
        (e.g. ``"a.jpg.jpg"``). ``name`` must end with ``old_suffix``.
        """
        return name[:len(name) - len(old_suffix)] + new_suffix

    def __getitem__(self, ind):
        """Return the ``(image, label)`` tensors for sample ``ind``.

        The label holds 0 where the target pixel is 0, 2 where it is at full
        intensity, and 1 for every value strictly in between.

        Raises:
          FileNotFoundError: if the target image cannot be read.
        """
        img = AuxaliaryFunctions.get_image(
            self.root_dir, self.filenames[ind], self.only_image_transforms
        )

        target_path = os.path.join(self.root_dir, self.target_filenames[ind])
        label = cv2.imread(target_path, 0)  # flag 0: load as single-channel grayscale
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of an opaque TypeError.
        if label is None:
            raise FileNotFoundError(
                "Could not read segmentation target: {}".format(target_path)
            )
        label = Image.fromarray(np.uint8(label))

        img, label, _ = AuxaliaryFunctions.apply_transformers(
            img, label, normalize_transformer=self.normalize_transformer
        )

        # to_tensor rescales 8-bit pixels to [0, 1]. Order matters: map the
        # exact 1.0 pixels to 2 first so the open interval (0, 1) tested on
        # the next line cannot match them afterwards.
        label[(1 == label)] = 2
        label[((0 < label) & (1 > label))] = 1
        label = label.squeeze()

        return img, label

    def __len__(self):
        """Return the number of image/target pairs."""
        return len(self.filenames)

    def set_transformers(self, normalize_transformer, only_image_transforms):
        """Replace both transform callables after construction."""
        self.normalize_transformer = normalize_transformer
        self.only_image_transforms = only_image_transforms