In [None]:
import imageio.v3 as iio
from typing import Dict
import numpy as np
import os

In [None]:
import matplotlib.pyplot as plt
# Image files are later read by relative name, so work from inside test_cases/.
# Guard the chdir so that re-running this cell does not try to enter
# test_cases/test_cases and fail with FileNotFoundError.
if os.path.basename(os.getcwd()) != "test_cases":
    os.chdir("test_cases")
os.getcwd()

In [None]:
def input(low_input: str, high_input: str):
    """Read the four low-resolution images and the high-resolution reference.

    The low-resolution images are read from ``{low_input}0.png`` through
    ``{low_input}3.png`` and returned in a dict keyed by 0-3. The reference is
    read from ``{high_input}.png``. Paths are relative to the current working
    directory.

    Note: this name shadows Python's built-in ``input`` in the notebook.

    Returns
    -------
    (dict[int, np.ndarray], np.ndarray)
        ``(low_images, high_image)``.
    """
    low_images = {idx: iio.imread(f"{low_input}{idx}.png") for idx in range(4)}
    high_image = iio.imread(f"{high_input}.png")
    return low_images, high_image

def rmse(img_high: np.array, img_high_calculated: np.array):
    """Print (to 4 decimals) and return the root-mean-square error of two images.

    Both arrays are converted to float64 before subtracting. With unsigned
    integer images (e.g. uint8), ``a - b`` would otherwise wrap around and the
    square would overflow, giving a wrong error.

    Parameters
    ----------
    img_high : np.ndarray
        Reference image.
    img_high_calculated : np.ndarray
        Image to compare; must have the same shape as ``img_high``.

    Returns
    -------
    float
        The RMSE.

    Raises
    ------
    ValueError
        If the two images have different shapes (silent broadcasting would
        otherwise produce a meaningless number).
    """
    reference = np.asarray(img_high, dtype=np.float64)
    estimate = np.asarray(img_high_calculated, dtype=np.float64)
    if reference.shape != estimate.shape:
        raise ValueError(
            f"shape mismatch: {reference.shape} vs {estimate.shape}"
        )
    error = float(np.sqrt(np.mean((reference - estimate) ** 2)))
    print(f"{error:.4f}")
    return error


def intercalate(img1: np.array,img2: np.array):
    """Interleave the columns of two equally-shaped images.

    Column ``2*j`` of the result is column ``j`` of ``img1``, and column
    ``2*j + 1`` is column ``j`` of ``img2``. The output therefore has twice as
    many columns as each input.

    Parameters
    ----------
    img1, img2 : np.ndarray
        Arrays of identical shape ``(N, M, ...)``. Any trailing axes (e.g.
        colour channels) are carried through unchanged.

    Returns
    -------
    np.ndarray
        Array of shape ``(N, 2*M, ...)`` with ``img1``'s dtype.

    Raises
    ------
    ValueError
        If the inputs have fewer than two dimensions or differ in shape.
    """
    img1 = np.asarray(img1)
    img2 = np.asarray(img2)
    if img1.ndim < 2 or img1.shape != img2.shape:
        raise ValueError(
            f"expected two arrays of the same >=2-D shape, got {img1.shape} and {img2.shape}"
        )

    rows, cols = img1.shape[:2]
    merged = np.empty((rows, 2 * cols) + img1.shape[2:], dtype=img1.dtype)
    # Strided slice assignment over the column axis replaces the per-row loop.
    merged[:, 0::2] = img1
    merged[:, 1::2] = img2
    return merged
    
def superresolution(img_dict: Dict[int, np.array]):
    """Combine four low-resolution images into one with twice the rows and columns.

    Images 0 and 2 are column-interleaved (via ``intercalate``) to form the
    even output rows. Images 1 and 3 are column-interleaved to form the odd
    output rows. For every low-resolution pixel ``(r, c)``:

        out[2r,   2c] = img_dict[0][r, c]    out[2r,   2c+1] = img_dict[2][r, c]
        out[2r+1, 2c] = img_dict[1][r, c]    out[2r+1, 2c+1] = img_dict[3][r, c]

    The result has the dtype of the first interleaved pair.
    """
    even_rows = intercalate(img_dict[0], img_dict[2])
    odd_rows = intercalate(img_dict[1], img_dict[3])

    height, width = even_rows.shape
    stacked = np.empty((2 * height, width), dtype=even_rows.dtype)
    stacked[0::2, :] = even_rows
    stacked[1::2, :] = odd_rows
    return stacked

def histogram(img: np.array, n_levels: int):
    """Count how many elements of ``img`` equal each level ``0 .. n_levels-1``.

    Values outside that range are ignored. Works for arrays of any shape; the
    unused ``N, M = img.shape`` unpacking that previously rejected non-2-D
    input has been dropped.

    Parameters
    ----------
    img : np.ndarray
        Image (or any array) to count.
    n_levels : int
        Number of histogram bins.

    Returns
    -------
    np.ndarray
        Integer array of length ``n_levels``.
    """
    hist = np.zeros(n_levels, dtype=int)
    for level in range(n_levels):
        hist[level] = np.count_nonzero(img == level)
    return hist

def histogram_equalization(img: np.array, n_levels: int):
    """Equalise ``img`` using its own cumulative histogram.

    Each pixel whose value is ``level`` (``0 <= level < n_levels``) becomes
    ``(n_levels - 1) * histC[level] / img.size``, stored in ``img``'s dtype.
    Storing a float in an integer dtype truncates it. ``histC`` is the
    cumulative histogram of ``img``.

    Parameters
    ----------
    img : np.ndarray
        Image to equalise.
    n_levels : int
        Number of grey levels.

    Returns
    -------
    np.ndarray
        Equalised image with the same shape and dtype as ``img``. Pixels
        outside ``0 .. n_levels-1`` keep their original value.
    """
    hist = histogram(img, n_levels)
    histC = np.cumsum(hist)

    # Start from a copy (not np.empty) so pixels outside 0..n_levels-1 are
    # never left holding uninitialised memory.
    new_img = img.copy()
    for level in range(n_levels):
        new_img[img == level] = (n_levels - 1) * histC[level] / img.size

    return new_img

def single_image_cumulative_histogram(img_dict: Dict[int, np.array], n_levels: int):
    """Equalise each low-resolution image independently, then super-resolve.

    Parameters
    ----------
    img_dict : dict[int, np.ndarray]
        The low-resolution images keyed 0-3, as expected by ``superresolution``.
    n_levels : int
        Number of grey levels used for equalisation.

    Returns
    -------
    np.ndarray
        The super-resolved image built from the individually equalised inputs.
    """
    equalized = {}
    for key, img in img_dict.items():
        result = histogram_equalization(img, n_levels)
        # A later cell redefines histogram_equalization to return
        # (image, cumulative_histogram). Accept either form so this function
        # works whichever definition is live in the kernel.
        if isinstance(result, tuple):
            result = result[0]
        equalized[key] = result

    return superresolution(equalized)

In [None]:
def img_subplot(img, subplot, title: str = None):
    """Show ``img`` in grey scale, without axes, at position ``subplot``.

    Parameters
    ----------
    img : array-like
        Image to display.
    subplot : int
        Three-digit subplot code passed to ``plt.subplot`` (e.g. ``121``).
    title : str, optional
        Panel title. When omitted no title is drawn, matching the previous
        behaviour.
    """
    plt.subplot(subplot)
    plt.imshow(img, cmap="gray")
    if title is not None:
        plt.title(title)
    plt.axis("off")

def hist_subplot(img, subplot, n_levels: int = 256):
    """Draw the grey-level histogram of ``img`` as a bar chart at ``subplot``.

    Parameters
    ----------
    img : np.ndarray
        Image whose histogram is plotted. It is now actually used; the
        previous version always plotted the global ``img_high``.
    subplot : int
        Three-digit subplot code passed to ``plt.subplot``.
    n_levels : int, default 256
        Number of grey levels (bars) to plot.
    """
    plt.subplot(subplot)
    plt.bar(range(n_levels), histogram(img, n_levels))
    plt.xlabel("Graylevel/Intensity")
    plt.ylabel("Frequency")

In [None]:
# Load 01_low0.png .. 01_low3.png (dict keyed 0-3) and the 01_high.png
# reference, using paths relative to the current working directory.
img_low, img_high = input("01_low","01_high")

In [None]:
plt.figure(figsize=(8,8))

# Panels 321-324 (top two rows): the four low-resolution inputs, keyed 0-3.
# Panel 325: the high-resolution reference.
for idx in range(4):
    img_subplot(img_low[idx], 321 + idx)
img_subplot(img_high, 325)

In [None]:
# Interleave the four low-resolution images into a single image with twice
# the rows and twice the columns.
img_high_calculated = superresolution(img_low)

In [None]:
# Side by side: reconstructed image (left) vs. high-resolution reference (right).
plt.figure(figsize=(12,12))
img_subplot(img_high_calculated, 121)
img_subplot(img_high, 122)

In [None]:
def histogram(img: np.array, n_levels: int):
    """Return an int array whose entry ``k`` counts the elements of ``img`` equal to ``k``.

    Only levels ``0 .. n_levels-1`` are counted; any other value is ignored.
    The array may have any shape.
    """
    return np.array([np.sum(img == value) for value in range(n_levels)], dtype=int)

In [None]:

def cumulative_histogram(hist: np.array, n_levels: int):
    """Return the running sum of the first ``n_levels`` bins of ``hist``.

    Parameters
    ----------
    hist : array-like
        Histogram counts, at least ``n_levels`` long.
    n_levels : int
        Number of bins to accumulate. ``0`` yields an empty array.

    Returns
    -------
    np.ndarray
        Integer array where entry ``k`` is ``hist[0] + ... + hist[k]``.

    Raises
    ------
    ValueError
        If ``hist`` has fewer than ``n_levels`` bins.
    """
    hist = np.asarray(hist)
    if len(hist) < n_levels:
        raise ValueError(f"hist has {len(hist)} bins but n_levels={n_levels}")
    return np.cumsum(hist[:n_levels]).astype(int)

def histogram_equalization(img: np.array, n_levels: int,
                           joint: bool = False, img_dict: Dict[int, np.array] = None):
    """Equalise ``img`` with its own histogram or with one pooled over ``img_dict``.

    Each pixel of ``img`` whose value is ``level`` (``0 <= level < n_levels``)
    becomes ``(n_levels - 1) * histC[level] / total_pixels``, stored in
    ``img``'s dtype (float values are truncated by that cast).

    Parameters
    ----------
    img : np.ndarray
        Image whose pixels are remapped.
    n_levels : int
        Number of grey levels.
    joint : bool, default False
        If True, ``histC`` is built from all images in ``img_dict`` together:
        their histograms are summed and ``total_pixels`` is their combined
        size. Otherwise only ``img`` is used.
    img_dict : dict[int, np.ndarray], optional
        Images pooled when ``joint`` is True; ignored otherwise.

    Returns
    -------
    (np.ndarray, np.ndarray)
        The equalised image (same shape and dtype as ``img``; pixels outside
        ``0 .. n_levels-1`` keep their value) and the cumulative histogram
        used for the mapping.

    Raises
    ------
    ValueError
        If ``joint`` is True and ``img_dict`` is None or empty.
    """
    if joint:
        if not img_dict:
            raise ValueError("joint equalization requires a non-empty img_dict")
        # Pool the images by summing their histograms. Adding the pixel arrays
        # themselves would describe a different (and overflow-prone) image.
        hist = sum(histogram(other, n_levels) for other in img_dict.values())
        total_pixels = sum(other.size for other in img_dict.values())
    else:
        hist = histogram(img, n_levels)
        total_pixels = img.size

    histC = cumulative_histogram(hist, n_levels)

    # Remap the *input* image. Start from a copy so pixels outside
    # 0..n_levels-1 never hold uninitialised memory.
    new_img = img.copy()
    for level in range(n_levels):
        new_img[img == level] = (n_levels - 1) * histC[level] / total_pixels

    return new_img, histC


In [None]:
# Equalise img_low[0] with joint=True (the histogram is pooled from the images
# in img_low), then show the result (left) next to the unmodified img_low[0]
# (right).
new_img, histC = histogram_equalization(img_low[0], 256, joint=True, img_dict=img_low)

img_subplot(new_img, 121)

img_subplot(img_low[0], 122)

In [None]:
def joint_cumulative_histogram(img_dict: Dict[int, np.array], n_levels: int):
    """Equalise every image with one histogram pooled over all of them, then super-resolve.

    Unlike ``single_image_cumulative_histogram``, every image is remapped
    with the same cumulative histogram, built from all images in
    ``img_dict``. This is done through ``histogram_equalization(...,
    joint=True)``.

    Parameters
    ----------
    img_dict : dict[int, np.ndarray]
        The low-resolution images keyed 0-3, as expected by ``superresolution``.
    n_levels : int
        Number of grey levels used for equalisation.

    Returns
    -------
    np.ndarray
        The super-resolved image built from the jointly equalised inputs.
    """
    equalized = {}
    for key, img in img_dict.items():
        # The joint-capable histogram_equalization returns (image, histC).
        equalized[key], _ = histogram_equalization(
            img, n_levels, joint=True, img_dict=img_dict
        )

    return superresolution(equalized)

In [None]:
# Equalise each low-resolution image on its own (256 grey levels), super-resolve
# the results, and print the RMSE against the high-resolution reference.
new_img = single_image_cumulative_histogram(img_low, 256)
rmse(img_high, new_img)