In [42]:
import os
from argparse import ArgumentParser
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader, random_split

import pytorch_lightning as pl

from PIL import Image

# Paired-image augmentation transforms and the Dataset class

In [73]:
# Random Crop Augment
class RandomCrop():
    """Cut the same randomly placed square window out of both images of a pair.

    Parameters
    ----------
    min_crop : int
        Smallest side length, in pixels, that may be sampled.
    max_crop : int
        Largest side length, in pixels, that may be sampled (inclusive).

    The sampled side length is clamped to the shorter side of the first image,
    so pairs smaller than ``min_crop``/``max_crop`` are cropped to the largest
    square that fits instead of raising.
    """

    def __init__(self, min_crop = 256, max_crop = 1024):
        self.min_crop = min_crop
        self.max_crop = max_crop

    @staticmethod
    def random_crop(I1, crop_size):
        """Pick a random crop box that lies inside ``I1``.

        Parameters
        ----------
        I1 : object exposing ``size`` as ``(width, height)`` (PIL convention).
        crop_size : tuple
            ``(crop_width, crop_height)`` of the requested window.

        Returns
        -------
        tuple
            ``(left, upper, right, lower)``, the box layout ``Image.crop`` takes.

        ``torch.randint`` raises if the requested window is larger than the
        image along either axis.
        """
        width, height = I1.size  # the second image of the pair is assumed to share this size
        crop_w, crop_h = crop_size

        # Window covers the whole image: nothing to randomise.
        if width == crop_w and height == crop_h:
            return 0, 0, width, height

        # `high` is exclusive, hence the +1 so the last valid offset is reachable.
        left = torch.randint(0, width - crop_w + 1, size=(1, )).item()
        upper = torch.randint(0, height - crop_h + 1, size=(1, )).item()
        return left, upper, left + crop_w, upper + crop_h

    def __call__(self, img_set):
        """Crop both images of ``img_set`` with one shared random square box."""
        I1, I2 = img_set

        # Never ask for a window larger than the image itself.
        shortest_side = min(I1.size)
        low = min(self.min_crop, shortest_side)
        high = min(self.max_crop, shortest_side)

        # torch.randint excludes its upper bound; +1 makes `max_crop` reachable
        # and keeps min_crop == max_crop from raising.
        cs = torch.randint(low, high + 1, size=(1, )).item()
        bbox = self.random_crop(I1, (cs, cs))

        return I1.crop(bbox), I2.crop(bbox)
    
class RandomRotate():
    """Rotate both images of a pair by the same random whole-degree angle.

    Parameters
    ----------
    min_angle : int
        Smallest angle, in degrees, that may be sampled.
    max_angle : int
        Largest angle, in degrees, that may be sampled (inclusive).
    """

    def __init__(self, min_angle = -25, max_angle = 25):
        self.min_angle = min_angle
        self.max_angle = max_angle

    def __call__(self, img_set):
        """Return ``(rotated_I1, rotated_I2)``, both rotated by one shared angle."""
        I1, I2 = img_set
        # torch.randint excludes its upper bound. The +1 makes `max_angle`
        # reachable (so the default range is symmetric) and keeps
        # min_angle == max_angle from raising.
        angle = torch.randint(self.min_angle, self.max_angle + 1, size=(1, )).item()
        return TF.rotate(I1, angle, Image.BILINEAR), TF.rotate(I2, angle, Image.BILINEAR)

class RandomFlip():
    """Mirror both images of a pair identically, left-right and/or top-bottom.

    Each of the two flips is decided independently with probability 0.5, and
    whatever is decided is applied to both images.
    """

    def __init__(self):
        pass

    def __call__(self, img_set):
        """Return ``(I1, I2)`` after applying the same random flips to each."""
        I1, I2 = img_set

        # Draw from torch's RNG, like the other random transforms in this file,
        # instead of NumPy's global RNG. The flips then follow
        # torch.manual_seed, and DataLoader workers (which get distinct torch
        # seeds, while NumPy state may be duplicated) do not repeat each other.
        horizontal_flip, vertical_flip = (torch.rand(2) < 0.5).tolist()

        if horizontal_flip:
            I1 = I1.transpose(Image.FLIP_LEFT_RIGHT)
            I2 = I2.transpose(Image.FLIP_LEFT_RIGHT)
        if vertical_flip:
            I1 = I1.transpose(Image.FLIP_TOP_BOTTOM)
            I2 = I2.transpose(Image.FLIP_TOP_BOTTOM)

        return I1, I2

class ColorJitter():
    """Colour-jitter the first image of a pair; pass the second through untouched.

    The four arguments are forwarded unchanged to ``transforms.ColorJitter``.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.color_jitter = transforms.ColorJitter(
            brightness=brightness,
            contrast=contrast,
            saturation=saturation,
            hue=hue,
        )

    def __call__(self, img_set):
        """Return ``(jittered_first, second)``."""
        first, second = img_set
        jittered = self.color_jitter(first)
        return jittered, second

class CenterCrop():
    """Apply one ``transforms.CenterCrop(crop)`` to both images of a pair."""

    def __init__(self, crop):
        self.center_crop = transforms.CenterCrop(crop)

    def __call__(self, img_set):
        """Return ``(cropped_first, cropped_second)``."""
        first, second = img_set
        crop_fn = self.center_crop
        return crop_fn(first), crop_fn(second)
    

class Resize():
    """Apply one ``transforms.Resize(size)`` to both images of a pair."""

    def __init__(self, size):
        self.resize = transforms.Resize(size)

    def __call__(self, img_set):
        """Return ``(resized_first, resized_second)``."""
        first, second = img_set
        return tuple(self.resize(image) for image in (first, second))
    
class ToTensor():
    """Run ``transforms.ToTensor()`` on both images of a pair."""

    def __init__(self):
        self.tensor = transforms.ToTensor()

    def __call__(self, img_set):
        """Return ``(converted_first, converted_second)``."""
        first, second = img_set
        return tuple(map(self.tensor, (first, second)))
    

class Normalize():
    """Apply one ``transforms.Normalize(mean, std)`` to both members of a pair."""

    def __init__(self, mean = (0.5,0.5, 0.5), std = (0.5,0.5, 0.5)):
        self.norm = transforms.Normalize(mean, std)

    def __call__(self, img_set):
        """Return ``(normalized_first, normalized_second)``."""
        first, second = img_set
        normalized_first = self.norm(first)
        normalized_second = self.norm(second)
        return normalized_first, normalized_second

# Default augmentation pipeline for an image pair; the steps run top to bottom.
data_transform = transforms.Compose([
    RandomRotate(-25, 25),                      # one shared random angle
    RandomCrop(256, 1024),                      # one shared random square window
    Resize(size=340),
    CenterCrop(crop=256),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.5),  # first image only
    RandomFlip(),
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
    
class texture_dataset(Dataset):
    """Dataset that yields two crops taken from one side-by-side image file.

    Each item is opened as RGB and split into the box ``(0, 0, 1024, 1024)``
    (returned first, named diffuse here) and the box ``(1024, 0, 2048, 1024)``
    (returned second, named normal here).

    Parameters
    ----------
    imgs : sequence
        Sized, indexable collection whose items are passed to ``Image.open``.
    transform : callable or None
        Called with ``[I_diffuse, I_normal]`` and expected to return the
        transformed pair. A falsy value disables transformation.
    """

    def __init__(self, imgs, transform = data_transform):
        # The original assigned the undefined name `img`, which raised a
        # NameError as soon as the dataset was constructed.
        self.imgs = imgs
        self.num_samples = len(self.imgs)
        self.transform = transform

    def __getitem__(self, i):
        """Load item ``i`` and return its ``(I_diffuse, I_normal)`` pair."""
        # convert() returns a new, fully loaded image, so the handle opened
        # by Image.open can be released right away.
        with Image.open(self.imgs[i]) as sheet:
            I = sheet.convert("RGB")

        I_diffuse = I.crop((0, 0, 1024, 1024))
        I_normal = I.crop((1024, 0, 2048, 1024))
        if self.transform:
            I_diffuse, I_normal = self.transform([I_diffuse, I_normal])
        return I_diffuse, I_normal

    def __len__(self):
        """Number of entries in ``imgs`` at construction time."""
        return self.num_samples

In [77]:
# Load one sample file as RGB and cut it into two side-by-side 1024x1024 boxes:
# the left one is bound to I_diffuse, the right one to I_normal.
I = Image.open(r"G:\texture\data_1\0000.png").convert("RGB")
I_diffuse, I_normal = (
    I.crop(box) for box in ((0, 0, 1024, 1024), (1024, 0, 2048, 1024))
)

In [78]:
# Run the augmentation pipeline once on the sample pair loaded above.
# I1_c / I2_c are the transformed I_diffuse / I_normal respectively.
I1_c, I2_c = data_transform([I_diffuse, I_normal])