## Imports

In [None]:
import os
import sys

# Make the local `dimred` package importable. The default is the original
# author's checkout location; set the DIMRED_ROOT environment variable to point
# at your own copy of the repository instead of editing this cell.
DIMRED_ROOT = os.environ.get("DIMRED_ROOT", "/home/roman/study/la/dim-red")
# Guard against duplicate entries when the notebook cell is re-run.
if DIMRED_ROOT not in sys.path:
    sys.path.append(DIMRED_ROOT)

In [None]:
from dimred.models import PCAModel, SVDModel, KMeansModel, AutoEncoderModel
from dimred.metrics import get_metrics

In [None]:
import glob
import os
import pandas as pd
from pathlib import Path
from shutil import copyfile
import time
from tqdm import tqdm
import cv2 
from math import ceil
import matplotlib.pyplot as plt
import numpy as np

## Image utils

In [None]:
def get_ax(ax, col, row):
    """Return a single Axes from whatever ``plt.subplots`` produced.

    ``plt.subplots`` squeezes its result. It returns a lone Axes for a 1x1
    grid, a 1-D array for a single row or column, and a 2-D array otherwise.

    The first index applied is ``col`` and the second is ``row``. Each index
    is used only as far as ``ax`` actually has array dimensions; a non-array
    ``ax`` is returned unchanged.

    Note: despite the names, for a 2-D grid ``col`` selects the grid row and
    ``row`` the grid column, matching NumPy's ``ax[r][c]`` indexing.
    """
    if type(ax) is not np.ndarray:
        return ax
    picked = ax[col]
    if type(picked) is not np.ndarray:
        return picked
    return picked[row]

def plot_images(images, labels=None, cols=5, col_width=4, row_width=4, show_axis=False):
    """Plot ``images`` on a grid with ``cols`` columns and return the figure.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        labels: Optional sequence of titles, one per image.
        cols: Number of grid columns; the row count is derived from ``len(images)``.
        col_width: Width of each grid column, in inches.
        row_width: Height of each grid row, in inches.
        show_axis: When False, the x/y axes of every image are hidden.

    Returns:
        The ``matplotlib.figure.Figure`` containing the grid.
    """
    # At least one row so an empty ``images`` yields an empty figure instead of
    # plt.subplots raising on a zero-row grid.
    rows = max(1, ceil(len(images) / cols))
    # squeeze=False always returns a 2-D array of Axes, so 1xN / Nx1 / 1x1 grids
    # need no special-casing.
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * col_width, rows * row_width), squeeze=False
    )

    for i, image in enumerate(images):
        curr_ax = axes[i // cols, i % cols]
        curr_ax.imshow(image)
        if labels is not None:
            curr_ax.set_title(labels[i])
        if not show_axis:
            curr_ax.get_xaxis().set_visible(False)
            curr_ax.get_yaxis().set_visible(False)

    # Grid cells past the last image would otherwise render as empty framed axes.
    for empty_ax in axes.flat[len(images):]:
        empty_ax.axis("off")
    return fig

In [None]:
def get_images_subset(save_dir, glob_path, start, end):
    """Copy the ``[start:end)`` slice of files matching ``glob_path`` into ``save_dir``.

    The copy happens only when ``save_dir`` does not exist yet, so re-running
    is a no-op. An existing but incomplete directory is NOT refilled.

    Matches are sorted before slicing because ``glob.glob`` returns files in
    arbitrary, filesystem-dependent order. Sorting makes the selected subset
    reproducible across machines. It also keeps slices taken by separate calls
    (e.g. ``[0:400]`` and ``[400:1000]``) disjoint.
    """
    if os.path.exists(save_dir):
        return
    os.makedirs(save_dir)
    for image_path in sorted(glob.glob(glob_path))[start:end]:
        # basename (not split on "/") so Windows path separators work too.
        copyfile(image_path, os.path.join(save_dir, os.path.basename(image_path)))

In [None]:
def read_img(path: str):
    """Read the image at ``path`` and return it in RGB channel order.

    OpenCV loads images as BGR, so the array is converted to RGB before
    returning.

    Raises:
        OSError: if OpenCV cannot read the file (missing, unreadable, or an
            unsupported format). ``cv2.imread`` signals this by returning None
            instead of raising, which would otherwise surface later as a
            confusing ``cvtColor`` error.
    """
    img = cv2.imread(path)
    if img is None:
        raise OSError(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

In [None]:
def get_image_paths(images_directory, max_num=None):
    """Lazily yield paths of ``*.jpg`` files directly inside ``images_directory``.

    At most ``max_num`` paths are yielded; ``None`` means no limit. Paths come
    out in whatever order ``glob.glob`` returns them.
    """
    pattern = os.path.join(images_directory, "*.jpg")
    produced = 0
    for path in glob.glob(pattern):
        if produced == max_num:
            break
        yield path
        produced += 1

## Prepare data

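Dataset: Flickr8k, available at https://www.kaggle.com/adityajn105/flickr8k. Extract its `Images/` folder to `/datasets/flickr8k/Images/`, which is the location `glob_pattern` below expects, or edit `glob_pattern` to match where you put it.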

In [None]:
# Location of the raw Flickr8k images.
glob_pattern = "/datasets/flickr8k/Images/*.jpg"

# Local working copies: the images the experiments run on, and a separate,
# non-overlapping batch (presumably for training the learned models).
original_image_path_directory = "original_images"
training_image_path_directory = "training_images"

get_images_subset(original_image_path_directory, glob_pattern, 0, 400)
get_images_subset(training_image_path_directory, glob_pattern, 400, 1000)

# Number of images actually present in the experiment-input directory.
num_of_images = sum(
    1 for _ in glob.iglob(os.path.join(original_image_path_directory, "*.jpg"))
)

## Experiment class definition

In [None]:
class Experiment:
    """Run one compression model over images and record per-image metrics.

    Everything is written under ``<experiment_dir>/<experiment_name>/``:
    one reconstructed image per processed input (same basename) plus
    ``report.csv`` with the collected metrics.
    """

    def __init__(self, model, experiment_name, experiment_dir="experiments"):
        # The model must provide compress()/decompress(); it is also handed to get_metrics.
        self.model = model
        self.experiment_name = experiment_name
        self.experiment_dir = experiment_dir

        # One metrics record per processed image, in processing order.
        self._rows = list()
        os.makedirs(self._output_path(), exist_ok=True)

    def _output_path(self, *parts):
        """Return a path inside this experiment's output directory."""
        return os.path.join(self.experiment_dir, self.experiment_name, *parts)

    def process_image_from_path(self, image_path):
        """Record metrics for one image and save its compress/decompress round trip.

        Raises:
            OSError: if the reconstructed image cannot be written to disk.
        """
        image = read_img(image_path)
        metrics = get_metrics(compression_model=self.model, original_image=image)
        self._rows.append(metrics)

        output_image = self.model.decompress(self.model.compress(image))
        save_image_path = self._output_path(os.path.basename(image_path))
        # NOTE(review): the output keeps the input's extension, so .jpg inputs are
        # re-encoded lossily by OpenCV; the file on disk may differ slightly from
        # output_image itself.
        # cv2.imwrite reports failure through its return value rather than raising.
        if not cv2.imwrite(save_image_path, cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Could not write image: {save_image_path}")

    def save_report(self):
        """Write collected metrics to ``report.csv``, tagging rows with experiment and model name."""
        report = pd.DataFrame(self._rows)
        report["experiment_name"] = self.experiment_name
        report["model_name"] = type(self.model).__name__
        report.to_csv(self._output_path("report.csv"), index=False)

    def get_processed_images(self):
        """Load the saved reconstructions (RGB), sorted by path for a deterministic order."""
        paths = sorted(glob.glob(self._output_path("*.jpg")))
        return [read_img(image_path) for image_path in paths]

    def get_report(self):
        """Load the ``report.csv`` previously written by ``save_report``."""
        return pd.read_csv(self._output_path("report.csv"))

## Compare compression methods

In [None]:
# One Experiment per compression method. Each model is built right before
# its Experiment (same order as listing the constructors one by one).
experiments = [
    Experiment(model=make_model(), experiment_name=name)
    for name, make_model in (
        ("kmeans_w4_2000", lambda: KMeansModel.from_config("kmeans/config/kmeans_w4_2000.yaml")),
        ("pca_0.1_components", lambda: PCAModel(num_components=0.1)),
        ("svd_0.1_components", lambda: SVDModel(num_components=0.1)),
        ("autoencoder", lambda: AutoEncoderModel.from_config("autoencoder/config/autoencoder.yaml")),
    )
]

In [None]:
# Run every experiment over the input images, then write its report.
for experiment in experiments:
    # Context manager closes the progress bar even if an image fails mid-run.
    with tqdm(total=num_of_images, desc=f"Experiment: {experiment.experiment_name}") as pbar:
        for image_path in get_image_paths(original_image_path_directory):
            experiment.process_image_from_path(image_path)
            pbar.update()
    experiment.save_report()