# Train Linear Regressor
---

## Import Libraries
---

In [1]:
import matplotlib.pyplot as plt
from image_util import I
from linear_regressor import LinearRegressor

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import math

class I:
    '''
    Custom Image Class
    '''
    image=[[]]
    
    def __init__(self, image: list):
        self.image = image
        
    @staticmethod
    def from_np(image: np.ndarray) -> 'I':
        # convert np array to list
        image = image.tolist()
        return I(image)
    
    @staticmethod
    def read(img_path) -> list:
        img = plt.imread(img_path)
        img.tolist()
        return img
    
    @staticmethod
    def show(img: list, gray=False) -> None:
        params = {}
        if gray:
            params['cmap'] = 'gray'
        
        plt.imshow(img, **params)
        plt.show()
    
    def display(self) -> None:
        plt.imshow(self.image)
        plt.show()
        
    @staticmethod
    def to_gray(img: list) -> list:
        gray_img = []
        for row in img:
            gray_row = []
            for px in row:
                gray_px = int(I.rgb_to_gray_px(px[0], px[1], px[2]))
                gray_row.append(gray_px)
            gray_img.append(gray_row)
        return gray_img
    
    def rgb_to_gray_px(r, g, b):
        return (r + g + b) / 3
        
    @staticmethod
    def flatten(img: list) -> list:
        '''
        flatten a 3D image to 2D
        '''
        flattened_img = []
        for row in img:
            for px in row:
                flattened_img.append(px)
        return flattened_img
        
    @staticmethod
    def extract_channels(img:list) -> list:
        '''
        extract all channels from an image
        '''
        channels = []
        for i in range(3):
            channels.append(I.extract_channel(img, i))
        return channels
    
    @staticmethod
    def extract_channel(img: list, channel: int) -> list:
        '''
        extract a channel from an image
        '''
        extracted_img = []
        for row in img:
            extracted_row = []
            for px in row:
                extracted_px = px[channel]
                extracted_row.append(extracted_px)
            extracted_img.append(extracted_row)
        return extracted_img
        
    @staticmethod
    def bl_resize(original_img, new_h, new_w):
        src_h, src_w, c = I.shape(original_img)
        resized_img = []

        # Calculate scaling factors
        scale_h = float(src_h) / new_h
        scale_w = float(src_w) / new_w

        # Iterate over each pixel in the resized image
        for y in range(new_h):
            row = []
            for x in range(new_w):
                # Calculate the corresponding position in the original image
                src_y = (y + 0.5) * scale_h - 0.5
                src_x = (x + 0.5) * scale_w - 0.5

                # Get the four surrounding pixels
                src_y1 = int(math.floor(src_y))
                src_x1 = int(math.floor(src_x))
                src_y2 = min(src_y1 + 1, src_h - 1)
                src_x2 = min(src_x1 + 1, src_w - 1)

                # Calculate the weights for interpolation
                w1 = (src_y2 - src_y) * (src_x2 - src_x)
                w2 = (src_y2 - src_y) * (src_x - src_x1)
                w3 = (src_y - src_y1) * (src_x2 - src_x)
                w4 = (src_y - src_y1) * (src_x - src_x1)

                # Perform bilinear interpolation for each color channel
                interpolated_px = []
                for ch in range(c):
                    interpolated_ch = (
                        original_img[src_y1][src_x1][ch] * w1 +
                        original_img[src_y1][src_x2][ch] * w2 +
                        original_img[src_y2][src_x1][ch] * w3 +
                        original_img[src_y2][src_x2][ch] * w4
                    )
                    interpolated_px.append(interpolated_ch)
                row.append(interpolated_px)
            resized_img.append(row)

        return resized_img
    
    @staticmethod
    def rgb_to_hsv(rgb_img):
        hsv_img = []
        for row in rgb_img:
            hsv_row = []
            for px in row:
                hsv_px = I.rgb_to_hsv_px(px[0], px[1], px[2])
                hsv_row.append(list(hsv_px))
            hsv_img.append(hsv_row)
        return hsv_img

    def rgb_to_hsv_px(r, g, b):
        r, g, b = r/255.0, g/255.0, b/255.0
        mx = max(r, g, b)
        mn = min(r, g, b)
        df = mx-mn
        if mx == mn:
            h = 0
        elif mx == r:
            h = (60 * ((g-b)/df) + 360) % 360
        elif mx == g:
            h = (60 * ((b-r)/df) + 120) % 360
        elif mx == b:
            h = (60 * ((r-g)/df) + 240) % 360
        if mx == 0:
            s = 0
        else:
            s = (df/mx)*100
        v = mx*100
        return h, s, v
    
    @staticmethod
    def shape(img_list: list):
        row = len(img_list)
        col = len(img_list[0])
        c = len(img_list[0][0])
        return row, col, c
    
    @staticmethod
    def multiply_img(k: float, img):
        '''
        multiply each pixel in an image by a constant
        '''
        new_img = []
        
        for row in img:
            new_row = []
            for px in row:
                if type(px) == list:
                    new_px = [k * channel for channel in px]
                else:
                    new_px = k * px
                new_row.append(new_px)
            new_img.append(new_row)
            
        return new_img

## Load data
---

In [2]:
# read training images
img1 = I.from_np(plt.imread('./data/no_bg/img/sapi1.jpg')).image
img2 = I.from_np(plt.imread('./data/no_bg/img/sapi2.jpeg')).image
img3 = I.from_np(plt.imread('./data/no_bg/img/sapi3.jpg')).image

print('image:', end='\t')
print(I.shape(img1), I.shape(img2), I.shape(img3))

# ground truth
gt_1 = I.from_np(plt.imread('./data/no_bg/gt/sapi1.jpg')).image
gt_2 = I.from_np(plt.imread('./data/no_bg/gt/sapi2.jpg')).image
gt_3 = I.from_np(plt.imread('./data/no_bg/gt/sapi3.jpg')).image

print('gt:', end='\t')
print(I.shape(gt_1), I.shape(gt_2), I.shape(gt_3))

image:	(578, 735, 3) (183, 276, 3) (690, 1023, 3)
gt:	(578, 735, 3) (183, 276, 3) (690, 1023, 3)


### Extract Features
---

In [3]:
# split hsv
hsv1 = I.rgb_to_hsv(img1)
hsv2 = I.rgb_to_hsv(img2)
hsv3 = I.rgb_to_hsv(img3)

h1, s1, v1 = I.extract_channels(hsv1)
h2, s2, v2 = I.extract_channels(hsv2)
h3, s3, v3 = I.extract_channels(hsv3)

In [4]:
flat_h1 = I.flatten(h1)
flat_s1 = I.flatten(s1)
flat_v1 = I.flatten(v1)

flat_h2 = I.flatten(h2)
flat_s2 = I.flatten(s2)
flat_v2 = I.flatten(v2)

flat_h3 = I.flatten(h3)
flat_s3 = I.flatten(s3)
flat_v3 = I.flatten(v3)

In [5]:
print(len(flat_h1), len(flat_h2), len(flat_h3))

424830 50508 705870


### Ground Truth
---

In [6]:
gt_1 = I.to_gray(gt_1)
gt_2 = I.to_gray(gt_2)
gt_3 = I.to_gray(gt_3)

In [29]:
flat_gt_1 = I.flatten(gt_1)
flat_gt_2 = I.flatten(gt_2)
flat_gt_3 = I.flatten(gt_3)

print(len(flat_gt_1), len(flat_gt_2), len(flat_gt_3))

424830 50508 705870


## Concatenate Features
---

In [30]:
X = [[flat_h1[i], flat_s1[i], flat_v1[i]] for i in range(len(flat_h1))]
X.extend([[flat_h2[i], flat_s2[i], flat_v2[i]] for i in range(len(flat_h2))])
X.extend([[flat_h3[i], flat_s3[i], flat_v3[i]] for i in range(len(flat_h3))])

y = flat_gt_1
y.extend(flat_gt_2)
y.extend(flat_gt_3)

In [9]:
def normalize(lst, divisor):
    result = []
    for item in lst:
        if isinstance(item, list):
            # If the element is a list, recursively divide its elements
            divided = normalize(item, divisor)
            result.append(divided)
        else:
            # Divide the element by the divisor
            divided = item / divisor
            result.append(divided)
    return result

In [31]:
print(len(X), len(y))

1181208 1181208


In [32]:
max(X), max(y)

([359.41747572815535, 82.4, 49.01960784313725], 255)

## Train
---

In [33]:
lr = LinearRegressor()

In [None]:
lr.fit(X=X, y=y, lr=0.00001, epochs=100, epsilon=0.1)

## Save Weights

In [43]:
weights = lr.m

In [44]:
print(weights)

[1.0014034909853222, 0.4012880982187062, 2.410825890893325, 0.3153325441925572]


In [45]:
m = [1.0014034909853222, 0.4012880982187062, 2.410825890893325, 0.3153325441925572]
weights, hist = lr.fit(X=X, y=y, lr=0.00001, epochs=1000, epsilon=0.1, m=m)

Epoch: 774 | Error: 4417.3992:  78%|███████▊  | 775/1000 [33:22<09:47,  2.61s/it]

KeyboardInterrupt: 

In [28]:
print(weights)

[0.05172266638498727, 0.6905957269351093, 1.4427400198297236, 0.3877351725711323]
