In [None]:
import os
import math
import time
from typing import Callable
import numpy as np
import cupy as cp
import cupyx.scipy.sparse as csp
import matplotlib.pyplot as plt
import dask.array as da
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from PIL import Image
from IPython.display import display
import package.myUtil as myUtil

In [None]:
# Experiment configuration: image size, regularization, sampling ratio, paths.
n = 256
LAMBDA = 100
RATIO = 0.05
DO_THIN_OUT = False
SAVE_AS_SPARSE = True
DATA_PATH = "../data"
# DATA_PATH = "../../OneDrive - m.titech.ac.jp/Lab/data"
IMG_NAME = "hadamard"
CAP_DATE = "241205"
EXP_DATE = "241206"
DIRECTORY = f"{DATA_PATH}/{EXP_DATE}"
# round() rather than int(): 100 * RATIO is not exact in binary floating point
# (e.g. 100 * 0.29 == 28.999...), and int() would truncate it to the wrong tag.
SETTING = f"p-{round(100 * RATIO)}_lmd-{LAMBDA}"
if DO_THIN_OUT:
    SETTING += "to"

# Create the output folders; makedirs also creates DIRECTORY as the parent of
# systemMatrix, and exist_ok makes the cell safe to re-run.
os.makedirs(DIRECTORY + "/systemMatrix", exist_ok=True)
use_list = myUtil.get_use_list(n * n, RATIO)

In [None]:
# Load captured images (G, one column per pattern) and input patterns (F).
G = myUtil.images2matrix(f"{DATA_PATH}/{IMG_NAME}{n}_cap_{CAP_DATE}/", use_list, thin_out=DO_THIN_OUT)
F = myUtil.images2matrix(f"{DATA_PATH}/{IMG_NAME}{n}_input/", use_list).astype(cp.int8)
M, K = G.shape
N, K_F = F.shape
# G and F must describe the same set of patterns (same column count); the
# original code silently overwrote K with F's value here.
if K_F != K:
    raise ValueError(f"G has {K} columns but F has {K_F}; both must cover the same patterns")
print("G shape:", G.shape, "F shape:", F.shape, "M=", M, "N=", N, "K=", K)
print("G max:", G.max(), "G min:", G.min(), "F max:", F.max(), "F min:", F.min())

# Subtract the per-pixel dark level from every captured column.
black = myUtil.calculate_bias(M, DATA_PATH, CAP_DATE)
B = cp.tile(black[:, None], K)

G = G - B

# Response to the white image, dark-level corrected, at the same resolution as G.
white_img = Image.open(f"{DATA_PATH}/capture_{CAP_DATE}/White.png").convert("L")
white = (cp.asarray(white_img) / 255).astype(cp.float32)
if DO_THIN_OUT:
    white = white[::2, ::2].ravel() - black
else:
    white = white.ravel() - black
if white.shape[0] != M:
    raise ValueError(f"white image has {white.shape[0]} pixels but G has M={M} rows")
H1 = cp.tile(white[:, None], K)

# Map the patterns to {-1, +1} via F_hat = 2F - 1 and apply the matching change
# to the measurements. NOTE(review): assumes White.png is the capture of an
# all-on pattern, so (white - black) plays the role of H @ 1 and
# 2(G - B) - H1 corresponds to H @ F_hat.
# F is int8, so F_hat is presumably int8 too; cast before any large matmul.
F_hat = 2 * F - 1
G_hat = 2 * G - H1
del F, G, H1

In [None]:
# Dask clients created by distributed_matmul, keyed by the tuple of GPU ids they
# manage. Reusing one LocalCUDACluster per GPU set avoids spawning (and leaking)
# a fresh cluster on every call.
_CUDA_CLIENTS: dict = {}


def _get_cuda_client(gpu_ids) -> "Client":
    """Return a running Dask client on a LocalCUDACluster for ``gpu_ids``, creating it on first use."""
    key = tuple(gpu_ids)
    client = _CUDA_CLIENTS.get(key)
    if client is None or client.status != "running":
        cluster = LocalCUDACluster(
            n_workers=len(key),
            threads_per_worker=1,
            CUDA_VISIBLE_DEVICES=list(key),
        )
        client = Client(cluster)
        _CUDA_CLIENTS[key] = client
    return client


def distributed_matmul(A: np.ndarray, B: np.ndarray, M_block: int, K_block: int, gpu_ids=(0, 1, 2), client=None):
    """Compute ``A @ B`` block-wise on several GPUs with Dask.

    A is split into row blocks of ``M_block`` rows and B into column blocks of
    ``K_block`` columns. Every block is converted to a CuPy array, so the block
    products run on the GPU workers.

    Parameters:
    - A: (M x N) array. NumPy arrays work; the notebook passes CuPy arrays.
    - B: (N x K) array.
    - M_block, K_block: chunk sizes along A's rows and B's columns.
    - gpu_ids: GPU ids for the LocalCUDACluster (one worker per GPU). Used only
      when ``client`` is None. A cluster is created once per distinct GPU set and
      reused on later calls.
    - client: optional existing ``dask.distributed.Client`` to run on instead.

    Returns:
    - C: persisted Dask Array holding ``A @ B`` with CuPy-backed chunks. Call
      ``.compute()`` to gather it into a single array.

    Raises:
    - ValueError: if the inner dimensions do not match, or ``gpu_ids`` is empty
      and no client is given.
    """
    from dask.distributed import wait

    M, N = A.shape
    N2, K = B.shape
    if N != N2:
        raise ValueError("行列サイズが不一致です。Aは(M×N), Bは(N×K)である必要があります。")

    if client is None:
        if len(gpu_ids) == 0:
            raise ValueError("gpu_ids must contain at least one GPU id")
        client = _get_cuda_client(gpu_ids)

    # A is chunked along rows, B along columns; the shared N dimension stays whole.
    A_d = da.from_array(A, chunks=(M_block, N))
    B_d = da.from_array(B, chunks=(N, K_block))

    # Move every block onto the GPU.
    A_c = A_d.map_blocks(cp.asarray, dtype=A_d.dtype)
    B_c = B_d.map_blocks(cp.asarray, dtype=B_d.dtype)

    C_c = da.dot(A_c, B_c)

    # Schedule on the chosen client and block until every chunk is computed.
    # dask.distributed.wait is a module-level function; Client has no .wait().
    C_c = client.persist(C_c)
    wait(C_c)
    return C_c

In [None]:
def prox_l122(Y: cp.ndarray, gamma: float, n: "int | None" = None) -> csp.csr_matrix:
    """Row-wise soft-thresholding used as the proximal step in ``fista``.

    Each entry of ``Y`` is shrunk toward zero by a threshold proportional to the
    L1 norm of its row:

        X = sign(Y) * max(|Y| - c * ||Y[i, :]||_1, 0),   c = 2*gamma / (1 + 2*gamma*n)

    Parameters:
    - Y: 2-D CuPy array (``fista`` passes the dense gradient-step iterate).
    - gamma: threshold scale (``fista`` passes ``step * lmd``).
    - n: dimension used in the constant ``c``. If None, it falls back to the
      notebook-global ``N`` (row count of ``F``), which is what the original
      implementation always used.

    Returns:
    - The thresholded matrix as a CuPy CSR sparse matrix.
    """
    if n is None:
        n = N  # notebook global from the data-loading cell
    factor = (2 * gamma) / (1 + 2 * gamma * n)
    abs_Y = cp.absolute(Y)  # reused for the row norms and the shrinkage
    row_l1 = abs_Y.sum(axis=1)
    X = cp.sign(Y) * cp.maximum(abs_Y - factor * row_l1[:, None], 0)
    return csp.csr_matrix(X)


def _frobenius_norm(X) -> float:
    """Frobenius norm of a dense CuPy array or a cupyx sparse matrix, as a Python float."""
    if csp.issparse(X):
        # Assumes canonical storage (no duplicate entries), which holds for the
        # CSR results of arithmetic and of csr_matrix(dense).
        return float(cp.linalg.norm(X.data))
    return float(cp.linalg.norm(X))


def fista(
    Ft: cp.ndarray,
    Gt: cp.ndarray,
    lmd: float,
    prox: Callable[[cp.ndarray, float], cp.ndarray],
    max_iter: int = 500,
    tol: float = 1e-3,
) -> cp.ndarray:
    """Estimate Ht in Gt ≈ Ft @ Ht with FISTA (accelerated proximal gradient).

    The smooth term's gradient is taken as Ftᵀ(Ft·Y − Gt), i.e. the data term is
    ½‖Gt − Ft·Ht‖_F². The regularizer is whatever ``prox`` implements; with
    ``prox_l122`` it is the ℓ1,2²-style penalty of the original formulation.

    Parameters:
    - Ft: (K x N) matrix. It is cast to float32, because an int8 input (±1
      patterns) would overflow in Ftᵀ Ft.
    - Gt: (K x M) matrix of measurements, also cast to float32.
    - lmd: regularization weight; ``prox`` is called with ``step * lmd``.
    - prox: proximal operator ``prox(Y, threshold)``.
    - max_iter: maximum number of iterations.
    - tol: stop when ‖Ht − Ht_old‖_F / ‖Ht‖_F < tol.

    Returns:
    - Ht: (N x M) estimate, of whatever type ``prox`` returns (CSR for
      ``prox_l122``), or an all-zero CSR matrix when ``max_iter`` is 0.
    """
    Ft = Ft.astype(cp.float32, copy=False)
    Gt = Gt.astype(cp.float32, copy=False)
    N = Ft.shape[1]
    M = Gt.shape[1]

    # Gram terms. distributed_matmul returns a Dask array with CuPy chunks, so
    # gather them into plain CuPy arrays before mixing with local CuPy/sparse ops.
    fft = distributed_matmul(Ft.T, Ft, 4096, 4096).compute()  # (N, N)
    fgt = distributed_matmul(Ft.T, Gt, 4096, 4096).compute()  # (N, M)

    # Fixed step in place of 1/L with L = ||FtᵀFt||₂ (not computed here).
    # For convergence it must not exceed 1/L — TODO confirm for the pattern set in use.
    gamma = 1 / (N * 3)

    t = 1.0
    Ht = csp.csr_matrix((N, M), dtype=cp.float32)
    Yt = cp.zeros((N, M), dtype=cp.float32)

    for i in range(max_iter):
        t_old = t
        Ht_old = Ht  # Ht is rebound below, never mutated, so no copy is needed

        # Proximal gradient step taken at the extrapolated point Yt:
        #   Ht = prox(Yt - gamma * (FtᵀFt Yt - FtᵀGt), gamma * lmd)
        grad = fft @ Yt - fgt
        Ht = prox(Yt - gamma * grad, gamma * lmd)

        # Nesterov momentum update.
        t = (1 + math.sqrt(1 + 4 * t_old**2)) / 2
        Yt = Ht + ((t_old - 1) / t) * (Ht - Ht_old)

        diff_norm = _frobenius_norm(Ht - Ht_old)
        ht_norm = _frobenius_norm(Ht)
        if ht_norm == 0.0:
            # Relative change is undefined for an all-zero iterate. Treat an
            # unchanged zero iterate as converged instead of producing 0/0 = nan.
            error = 0.0 if diff_norm == 0.0 else float("inf")
        else:
            error = diff_norm / ht_norm
        print(f"iter: {i}, error: {error}")
        if error < tol:
            break

    return Ht

In [None]:
# Solve for Hᵀ with FISTA, then transpose to get the system matrix H.
# The intermediate Ht is never bound, so nothing needs deleting afterwards.
H = fista(F_hat.T, G_hat.T, LAMBDA, prox_l122).T

In [None]:
# Persist the estimated system matrix H, as sparse CSR components or as a dense array.
h_path_stem = f"{DIRECTORY}/systemMatrix/H_matrix_{SETTING}"
if SAVE_AS_SPARSE:
    # H = Ht.T, and the transpose of a CSR matrix is CSC. Convert explicitly so
    # the saved indices/indptr really describe a CSR matrix.
    H_csr = H.tocsr() if csp.issparse(H) else csp.csr_matrix(H)
    n_cells = max(1, H_csr.shape[0] * H_csr.shape[1])  # avoid 0-division on empty shapes
    print(f"shape: {H_csr.shape}, nnz: {H_csr.nnz}({H_csr.nnz / n_cells * 100:.2f}%)")
    H_np = {
        "data": cp.asnumpy(H_csr.data),
        "indices": cp.asnumpy(H_csr.indices),
        "indptr": cp.asnumpy(H_csr.indptr),
        "shape": H_csr.shape,
        # Record the layout; this also makes the file loadable with scipy.sparse.load_npz.
        "format": "csr",
    }
    np.savez(f"{h_path_stem}.npz", **H_np)
    print(f"Saved {h_path_stem}.npz")
else:
    # cp.save expects a dense CuPy array; densify a sparse result first.
    H_dense = H.toarray() if csp.issparse(H) else H
    cp.save(f"{h_path_stem}.npy", H_dense)
    print(f"Saved {h_path_stem}.npy")

In [None]:
# Simulate the camera response to a sample image using the estimated H.
SAMPLE_NAME = "Cameraman"
sample_pil = Image.open(f"{DATA_PATH}/sample_image{n}/{SAMPLE_NAME}.png").convert("L")
sample_image = cp.asarray(sample_pil).flatten() / 255
if sample_image.size != H.shape[1]:
    raise ValueError(f"sample has {sample_image.size} pixels but H expects {H.shape[1]}")

# The simulated capture has M pixels and must form a square image.
m = math.isqrt(M)
if m * m != M:
    raise ValueError(f"M={M} is not a perfect square; cannot reshape to an m x m image")
FILENAME = f"{SAMPLE_NAME}_{SETTING}.png"
Hf = H @ sample_image + black
Hf = cp.asnumpy(Hf.reshape(m, m))
print("Hf shape:", Hf.shape)

# Clip to [0, 1] before scaling: casting out-of-range floats to uint8 would
# wrap around (e.g. -0.01 -> 253) instead of saturating.
Hf_u8 = (np.clip(Hf, 0.0, 1.0) * 255).astype(np.uint8)
Hf_pil = Image.fromarray(Hf_u8)  # a 2-D uint8 array maps to mode "L"
Hf_pil.save(f"{DIRECTORY}/{FILENAME}", format='PNG')
print(f"Saved {DIRECTORY}/{FILENAME}")
display(Hf_pil)

fig, ax = plt.subplots()
im = ax.imshow(Hf, cmap='gray', interpolation='nearest')
fig.colorbar(im, ax=ax)
ax.set_title('Grayscale Heatmap')
plt.show()