# Define imports

In [None]:
#!pip install opencv-python

import cv2
import os
import glob
import numpy as np
from Regressor_Model_Controller import Regressor_Model_Controller
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import torch
import torchvision.transforms as transforms
import math
import random


torch.set_default_tensor_type(torch.FloatTensor)

# Define constants

In [10]:
DIPALI_HOME = '/home/dipali.patidar/DLWG/DataSet/*.jpg'
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
REGRESSOR_SAVED_MODEL_PATH= './Model/Regressor/Regressor.pth'
REGRESSOR_EPOCH = 75
REGRESSOR_LOSS_PLOT_PATH = "./Plots/Regressor/Regressor_Loss_plot.jpeg"
REGRESSOR_LR = 0.0001
REGRESSOR_WEIGHT_DECAY = 1e-5
REGRESSOR_IN_CHANNEL = 1
REGRESSOR_HIDDEN_CHANNEL = 3
REGRESSOR_OUT_DIMS = 2
REGRESSOR_BATCH_SIZE_CPU = 32
REGRESSOR_BATCH_SIZE_CUDA = 32

# Use GPU if available

In [11]:
def get_device():
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        is_cuda_present = True if torch.cuda.is_available() else False
        num_workers = 8 if is_cuda_present else 0

        return device, is_cuda_present, num_workers

# Load image data from folder

In [8]:
def image_modification(file,nimages):
    img = cv2.imread(file)

    #Resize to respect the input_shape
    inp = cv2.resize(img, (IMAGE_WIDTH , IMAGE_HEIGHT))

    #Convert img to RGB
    rgb = cv2.cvtColor(inp, cv2.COLOR_BGR2RGB)
    
    return img
    

def read_images_to_tensor(filepath):
    
    files = glob.glob(filepath)
    nimages = len(files)
    data = []
    #print(files)
    for f1 in files:
        #cv2.imshow('image',img)
        data.append(image_modification(f1,nimages))
    return torch.tensor(data)


data_tensor = read_images_to_tensor(DIPALI_HOME)

# Define input and output directories

In [12]:
data_set_dir = './DataSet'
image_paths = os.listdir(data_set_dir)
image_paths = [image_path for image_path in image_paths if image_path.endswith('.jpg')]
random.shuffle(image_paths)
image_data = [cv2.imread(os.path.join(data_set_dir, image_path)) for image_path in image_paths]

# Convert images to LAB color spaceÂ¶

In [31]:
transformed_image_data = []

transform_color = transforms.Compose([transforms.ToTensor()])

for image_path in image_paths:
    image = cv2.imread(os.path.join(data_set_dir, image_path))
    imageLAB = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    #print(imageLAB.shape)
    image_lab = transform_color(image)
    image_gray = image_lab[0:1,:,:]

    image_a = image_lab[1:2, :, :]
    #print(image_a.shape)

    image_b = image_lab[2:3, :, :]
    #print(image_b.shape)
    #print(image_b)
    #plt.imshow(image_b)
    #plt.show()
    

    transformed_image_data.append((image_gray, image_a, image_b))


In [34]:
test_size=0.1
batch_size =REGRESSOR_BATCH_SIZE_CPU
index_test = math.floor(len(image_data) * test_size)

# include both color and gray images
trainloader = torch.utils.data.DataLoader(transformed_image_data[:index_test], batch_size=batch_size,
                                          shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(transformed_image_data[index_test:], batch_size=batch_size,
                                         shuffle=False, num_workers=0)

NameError: name 'image_data' is not defined

# Train the Regressor

In [None]:
def train_regressor(augmented_dataset_batch, device):
        regressor_train_arguments = {
            "data_loader": augmented_dataset_batch,
            "saved_model_path": REGRESSOR_SAVED_MODEL_PATH,
            "epochs": REGRESSOR_EPOCH,
            "learning_rate": REGRESSOR_LR,
            "weight_decay": REGRESSOR_WEIGHT_DECAY,
            "in_channel": REGRESSOR_IN_CHANNEL,
            "hidden_channel": REGRESSOR_HIDDEN_CHANNEL,
            "out_dims": REGRESSOR_OUT_DIMS,
            "loss_plot_path": REGRESSOR_LOSS_PLOT_PATH
        }

        regressor_manager = Regressor_Model_Controller()
        regressor_manager.train(regressor_train_arguments, device)

def test_regressor(augmented_dataset_batch, device):
        regressor_arguments = {
            "data_loader": augmented_dataset_batch,
            "saved_model_path": Constants.REGRESSOR_SAVED_MODEL_PATH,
            "in_channel": Constants.REGRESSOR_IN_CHANNEL,
            "hidden_channel": Constants.REGRESSOR_HIDDEN_CHANNEL,
            "out_dims": Constants.REGRESSOR_OUT_DIMS,
            "loss_plot_path": Constants.REGRESSOR_LOSS_PLOT_PATH
        }

        regressor_manager = Regressor_Model_Controller()
        regressor_manager.test(regressor_arguments, device)