In [None]:
import os

import numpy as np
import cupy as cp
import scipy.sparse as ssp
import cupyx.scipy.sparse as csp
from PIL import Image

import admm
# Second, stable handle on the solver module: the notebook later rebinds the
# name `admm` to a solver instance, and this alias keeps that cell re-runnable.
import admm as admm_module

In [None]:
# --- Data locations ---------------------------------------------------------
# Root of the data tree. Previously used alternative root:
#   '../../../OneDrive - m.titech.ac.jp/Lab/data'
DATA_PATH = '../data'
CAP_DATE = "241114"  # captured images are read from capture_<CAP_DATE>/
EXP_DATE = "241118"  # folder holding systemMatrix/ (input) and reconst/ (output)

# --- Object and system-matrix selection -------------------------------------
OBJ_NAME = "Cameraman"
# Suffix of the H matrix file; previously used alternative:
#   "FISTA_p-5_lmd-100_m-255"
H_SETTING = "gf"

# --- Sizes ------------------------------------------------------------------
# n: the reconstruction is reshaped to an n x n image before saving.
# NOTE(review): m is not referenced in the visible cells — confirm it is still needed.
n, m = 128, 255

In [None]:
def get_sparse_matrix_memory_size(sparse_matrix):
    """Return the number of bytes held by a compressed sparse matrix's buffers.

    Adds up ``nbytes`` of the ``data``, ``indices`` and ``indptr`` arrays, so
    the argument must expose all three (e.g. a CSR/CSC matrix from SciPy or
    ``cupyx.scipy.sparse``). Python object overhead is not included.
    """
    buffers = (sparse_matrix.data, sparse_matrix.indices, sparse_matrix.indptr)
    return sum(buf.nbytes for buf in buffers)


def format_size(bytes_size):
    """Render a byte count as a human-readable string with two decimals.

    Steps through B, KB, MB, GB and TB dividing by 1024 each time, stopping at
    the first unit where the value is below 1024; anything larger is reported
    in PB.
    """
    value = bytes_size
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024:
            break
        value /= 1024
    else:
        # Divided past TB without fitting: fall back to petabytes.
        unit = "PB"
    return f"{value:.2f} {unit}"

In [None]:
def create_D_mono(n, sparse_module=None):
    """Build the stacked forward-difference operator ``D = [Dx; Dy]`` for an n x n image.

    The image is taken as flattened row-major (pixel ``(r, c)`` at index
    ``r * n + c``):

    * row ``i`` of ``Dx`` is ``e_i - e_{i+1}`` (horizontal difference) and is
      all-zero for pixels in the last column (``c == n - 1``);
    * row ``i`` of ``Dy`` is ``e_i - e_{i+n}`` (vertical difference) and is
      all-zero for pixels in the last row (``r == n - 1``).

    The operator is assembled directly in sparse form from Kronecker products
    of a 1-D difference matrix with the identity. No dense ``(n**2, n**2)``
    array is ever materialized; for n=128 such an array alone is 2 GiB in
    float64.

    Args:
        n: side length of the square image (n >= 2 expected).
        sparse_module: sparse backend providing ``eye``, ``kron``, ``vstack`` and
            ``csr_matrix``. Defaults to ``cupyx.scipy.sparse`` (GPU); pass
            ``scipy.sparse`` to build the same operator on the CPU.

    Returns:
        float64 CSR matrix of shape ``(2 * n**2, n**2)``.
    """
    sp = csp if sparse_module is None else sparse_module

    # 1-D forward difference: (n-1) x n with rows e_k - e_{k+1}, then a zero
    # row appended so the boundary sample has no difference.
    forward = sp.eye(n - 1, n, k=0, format="csr") - sp.eye(n - 1, n, k=1, format="csr")
    d1 = sp.vstack([forward, sp.csr_matrix((1, n))], format="csr")
    identity = sp.eye(n, format="csr")

    # kron(I, d1) differences along each image row; kron(d1, I) along columns.
    Dx = sp.kron(identity, d1, format="csr")
    Dy = sp.kron(d1, identity, format="csr")

    # Explicit CSR: callers read .indices/.indptr (e.g. the memory-size helper).
    return sp.vstack([Dx, Dy], format="csr")


# Build the difference operator on the same n x n grid as the reconstruction
# (the solution is reshaped to (n, n) before saving), so D acts on vectors of
# length n**2. A hard-coded side length here would silently mismatch n.
D = create_D_mono(n)
print(f"Size of D: {format_size(get_sparse_matrix_memory_size(D))}")

In [None]:
# Load the captured measurement as a grayscale ("L") image, move it to the GPU,
# and flatten it (row-major) into the measurement vector g.
# NOTE(review): g keeps the image's integer dtype and 0-255 range — confirm
# admm.Admm expects that rather than a [0, 1] float vector.
with Image.open(f"{DATA_PATH}/capture_{CAP_DATE}/{OBJ_NAME}.png") as capture_img:
    captured = cp.asarray(capture_img.convert("L"))
g = captured.ravel()

In [None]:
# Load the precomputed system matrix H for the chosen setting and cast it to
# float32 on the GPU.
PREFIX = "int_"
h_path = f"{DATA_PATH}/{EXP_DATE}/systemMatrix/H_matrix_{PREFIX}{H_SETTING}.npy"
H = cp.load(h_path).astype(cp.float32)
print("H shape:", H.shape, "type(H):", type(H), "H.dtype:", H.dtype)

In [None]:
# Thresholding
# Thresholding: entries below H_THRESHOLD (negatives included) are set to zero,
# which reduces the number of nonzeros kept by the CSR conversion that follows.
H_THRESHOLD = 1e-4
is_negligible = H < H_THRESHOLD
H = cp.where(is_negligible, 0, H)
del is_negligible  # free the temporary boolean mask

In [None]:
# Convert H to CSR and drop the dense GPU copy; only H_sparse is used afterwards.
# Because H is deleted, re-running this cell requires re-running the H loading
# and thresholding cells first.
H_sparse = csp.csr_matrix(H)
del H
h_sparse_nbytes = get_sparse_matrix_memory_size(H_sparse)
print("Non zero elements in H:", H_sparse.nnz)
print("Size of H", format_size(h_sparse_nbytes))

In [None]:
# This rebinds the name `admm` from the module to the solver instance (later
# cells call `admm.solve()` and read `admm.tau`, `admm.mu1`, ...). Building it
# through the `admm_module` alias instead of `admm.Admm` keeps the cell
# re-runnable: on a second run `admm` is already the instance, not the module.
admm = admm_module.Admm(H_sparse, g, D)

In [None]:
# Run the solver: solve() yields the reconstruction `f` plus a second value kept
# as `err` (not used in the cells shown here).
[f, err] = admm.solve()

In [None]:
# Convert the solution to an n x n 8-bit grayscale image (values clipped to [0, 1]).
f = cp.clip(f, 0, 1)
f = cp.asnumpy(f.reshape(n, n))
# Round to the nearest gray level: a bare astype() truncates, biasing every
# pixel downward by up to one level.
f_image = Image.fromarray(np.rint(f * 255).astype(np.uint8), mode="L")

# log10 of the solver hyper-parameters; used to tag the output filename.
tau = np.log10(admm.tau)
mu1 = np.log10(admm.mu1)
mu2 = np.log10(admm.mu2)
mu3 = np.log10(admm.mu3)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
reconst_dir = f"{DATA_PATH}/{EXP_DATE}/reconst"
os.makedirs(reconst_dir, exist_ok=True)
SAVE_PATH = f"{reconst_dir}/{OBJ_NAME}_{H_SETTING}_admm_t-{tau}_m{mu1}m{mu2}m{mu3}.png"
f_image.save(SAVE_PATH, format="PNG")
print(SAVE_PATH)