# Steepest Descent Method

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from scipy import sparse
import os
import re

In [None]:
def images_to_matrix(folder_path, convert_gray=False, is_binary=False, cap=False, seed=0, ratio=1.0):
    """Load indexed PNG images from a folder into a matrix, one image per column.

    File names must contain ``hadamard64_<index>.png`` when ``cap`` is true and
    ``hadamard_<index>.png`` otherwise; columns are ordered by ``<index>``.
    Directory entries that do not match the pattern are ignored.

    Args:
        folder_path: Directory containing the images.
        convert_gray: If true, convert each image to mode 'L' before use.
        is_binary: If true, map every non-zero pixel value to 255 and zero to 0.
        cap: Selects which file-name pattern is expected (see above).
        seed: When non-zero, the global ``random`` generator is seeded with
            this value, the sorted file list is shuffled, and only the first
            ``int(len(files) * ratio)`` files are kept (re-sorted by index).
            When zero, every file is loaded and ``ratio`` is ignored.
        ratio: Fraction of files to keep when ``seed`` is non-zero.

    Returns:
        A 2-D float array whose columns are the flattened images divided by 255.

    Raises:
        ValueError: If there is no matching image left to load.
    """
    # Raw strings avoid the invalid '\d' escape; the dot is escaped so it only
    # matches a literal '.'.
    pattern = re.compile(r'hadamard64_(\d+)\.png' if cap else r'hadamard_(\d+)\.png')

    def file_index(name):
        return int(pattern.search(name).group(1))

    # Skip stray entries (e.g. .DS_Store) instead of crashing in the sort key.
    files = sorted(
        (name for name in os.listdir(folder_path) if pattern.search(name)),
        key=file_index,
    )

    selected_files = files
    if seed != 0:
        # Seeded shuffle of the index-sorted list, so the same seed picks the
        # same positions for any folder holding the same number of files.
        random.seed(seed)
        random.shuffle(files)
        selected_files = sorted(files[:int(len(files) * ratio)], key=file_index)

    if not selected_files:
        raise ValueError(
            'no images matching {!r} to load from {!r}'.format(pattern.pattern, folder_path))

    images = []
    for file in selected_files:
        print(file)
        # The context manager releases the file handle once the pixels are copied.
        with Image.open(os.path.join(folder_path, file)) as raw:
            img = raw
            if convert_gray:
                img = img.convert('L')
            if is_binary:
                img = img.point(lambda x: 255 if x else 0, 'L')
            img_array = np.asarray(img).flatten()
        images.append(img_array / 255)
    return np.column_stack(images)


def update_H(H, G, F, gamma, threshold, max_iter=None):
    """Fit ``H`` so that ``H @ F`` approximates ``G`` by fixed-step gradient descent.

    Each iteration evaluates ``grad = 2 * (H @ F - G) @ F.T`` (the gradient of
    the squared Frobenius norm of ``H @ F - G``), takes the step
    ``H - gamma * grad`` and prints the iteration number, the gradient norm
    and the Frobenius norm of the change in ``H``. The input ``H`` is not
    modified.

    Args:
        H: Initial estimate.
        G: Target matrix.
        F: Input matrix.
        gamma: Step size.
        threshold: Stop once the Frobenius norm of the gradient evaluated
            before the current step is below this value.
        max_iter: Optional upper bound on the number of iterations. ``None``
            (the default) keeps iterating until the threshold is met. At
            least one iteration is always performed.

    Returns:
        The updated estimate of ``H``.

    Raises:
        FloatingPointError: If the gradient norm becomes inf or NaN, which
            would otherwise keep the loop running forever.
    """
    i = 1
    while True:
        grad = 2 * (H @ F - G) @ F.T
        norm = np.linalg.norm(grad, 'fro')
        # NaN/inf never compares below the threshold, so bail out explicitly.
        if not np.isfinite(norm):
            raise FloatingPointError(
                'gradient norm is {} at iteration {}; the iteration diverged '
                '(try a smaller gamma)'.format(norm, i))

        # The subtraction allocates a new array, so no explicit copy of the
        # previous iterate is needed to measure the step.
        H_next = H - gamma * grad
        error = np.linalg.norm(H_next - H, 'fro')
        H = H_next

        print('iter: {}, norm: {}, error: {}'.format(i, norm, error))
        if norm < threshold:
            break
        if max_iter is not None and i >= max_iter:
            print('stopped after {} iterations without reaching the threshold'.format(i))
            break
        i += 1
    return H


In [None]:
# Parameter settings
# H is allocated as (m**2, n**2); images are reshaped to (n, n) and (m, m) for display.
n, m = 64, 128
gamma = 1e-4      # step size passed to update_H
threshold = 1e-3  # gradient-norm tolerance passed to update_H

In [None]:
# Compute the reference ("true") system matrix from the full set of patterns
G_full = images_to_matrix('../data/hadamard'+str(n)+'_cap_W_sim/',
                     convert_gray=True, cap=True)
F_full = images_to_matrix('../data/Hadamard'+str(n)+'_input/', is_binary=True)
# H1 repeats the first column of G_full across n**2 columns. A read-only
# broadcast view gives the same shape and values without the zero-filled
# allocation (which was overwritten immediately) or the per-column loop and copy.
H1 = np.broadcast_to(G_full[:, [0]], (G_full.shape[0], n**2))
# F_full holds 0/1 values (is_binary=True, then divided by 255); map them to -1/+1.
F_hat_full = 2*F_full-1
G_hat_full = 2*G_full-H1
# NOTE(review): dividing by n**2 presumably relies on the +/-1 patterns being
# mutually orthogonal (Hadamard) - confirm against the input data.
H_true = G_hat_full@F_hat_full.T/(n**2)
# plt.figure(figsize=(12, 8))
# sns.heatmap(H_true, annot=False, cmap='viridis')

In [None]:
# Initialisation: fraction of patterns to use and the all-zero starting estimate
ratio = 0.5
H = np.zeros((m**2, n**2))

# Captured images for a seeded (seed=2) random subset of the patterns
G = images_to_matrix(
    '../data/hadamard'+str(n)+'_cap_W_sim/',
    convert_gray=True,
    cap=True,
    seed=2,
    ratio=ratio,
)

In [None]:
# Input patterns loaded with the same seed and ratio as G.
# NOTE(review): column correspondence with G assumes both folders hold the
# same number of files with matching indices - confirm.
F = images_to_matrix(
    '../data/Hadamard'+str(n)+'_input/',
    is_binary=True,
    seed=2,
    ratio=ratio,
)

In [None]:
# H1 repeats the first column of G across int(n**2*ratio) columns. A read-only
# broadcast view gives the same shape and values without the zero-filled
# allocation (which was overwritten immediately) or the per-column loop and copy.
# NOTE(review): G[:, 0] is the lowest-indexed image of the *selected* subset,
# which is not necessarily the same image as column 0 of the full set - confirm
# that this is the intended reference column.
H1 = np.broadcast_to(G[:, [0]], (G.shape[0], int(n**2*ratio)))
# F holds 0/1 values (is_binary=True, then divided by 255); map them to -1/+1.
F_hat = 2*F-1
G_hat = 2*G-H1
# plt.figure(figsize=(12, 8))
# sns.heatmap(F_hat, annot=False, cmap='viridis')
# plt.figure(figsize=(12, 8))
# sns.heatmap(G_hat, annot=False, cmap='viridis')

In [None]:
# Update H by gradient descent on the subset data
H = update_H(H, G_hat, F_hat, gamma=gamma, threshold=threshold)

# Persist the estimated matrix
np.save('../data/systemMatrix/H_matrix.npy', arr=H)

# Heatmap of the estimate
plt.figure(figsize=(12, 8))
sns.heatmap(H, cmap='viridis', annot=False)

In [None]:
# Load the test image as grayscale and flatten it to a vector scaled by 1/255
sample_image = Image.open('../data/sample_image64/Cameraman64.png').convert('L')
sample_image = np.asarray(sample_image).flatten()/255

# Show the test image at (n, n)
plt.figure(figsize=(3, 3))
plt.imshow(sample_image.reshape(n, n), cmap='gray')
plt.axis('off')
plt.show()

# Show the stored capture of the same image for comparison
grand_truth = Image.open('../data/sample_image64_cap_sim/Cameraman64.png')
plt.figure(figsize=(3, 3))
plt.imshow(grand_truth, cmap='gray')
plt.axis('off')
plt.show()

In [None]:
# Apply the estimated matrix to the test image and save/show the (m, m) result
Hf = H@sample_image
save_path = '../data/240130/Cameraman64_'+str(ratio)+'.png'
# plt.figure(figsize=(12, 8))
# sns.heatmap(Hf.reshape(m,m), annot=False, cmap='viridis')
plt.imshow(Hf.reshape(m, m), cmap='gray')
plt.axis('off')
plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Apply the reference matrix to the same test image and show the (m, m) result
Hf_true = H_true@sample_image
plt.imshow(Hf_true.reshape(m, m), cmap='gray')
plt.axis('off')
# Optional: save the figure
# plt.savefig('../data/240130/Cameraman64_.png', bbox_inches='tight', pad_inches=0)
plt.show()
# Optional: heatmap view
# plt.figure(figsize=(12, 8))
# sns.heatmap(HF, annot=False, cmap='viridis')

In [None]:
# Frobenius-norm distance between the reference matrix and the estimate
rem = np.linalg.norm(H_true - H, ord='fro')
print(rem)