## 最適化問題

$$ \min_{\bm{h}} \frac{1}{2}\|\bm{g}-(I\otimes F^\top) \bm{h}\|_2^2+\lambda_1\|\bm{h}\|_{1,2}^2 + \lambda_2\|D\bm{h}\|_{1,2}$$


In [1]:
import os
import math
from typing import Callable
import cupy as cp
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import package.myUtil as myUtil

In [2]:
# Side lengths: mult_D / mult_Dt reshape the unknown to (n, n, m, m).
n = 64
m = 128
N = n**2
M = m**2
# Regularisation weights, passed to primal_dual_splitting as lambda1 / lambda2.
LAMBDA1 = 1
LAMBDA2 = 1
# Step sizes read by primal_dual_splitting (tau: primal update, sigma: dual update).
TAU = 1e-4
SIGMA = 1e-4
SEED = 5  # not referenced in the visible part of this file
RATIO = 0.05  # forwarded as `ratio=` to myUtil.images_to_matrix
ITER = 1000  # default `max_iter` of primal_dual_splitting
DATA_PATH = "../../OneDrive - m.titech.ac.jp/Lab/data"
# DATA_PATH = "../data"
IMG_NAME = "hadamard"
DIRECTORY = DATA_PATH + "/241005"
# Tag embedded in output file names; TAU and SIGMA are encoded by the
# integer part of their base-10 logarithm.
SETTING = f"{IMG_NAME}_pr-du_p-{int(100*RATIO)}_lmd1-{LAMBDA1}_lmd2-{LAMBDA2}_t{int(math.log10(TAU))}_s{int(math.log10(SIGMA))}"

# exist_ok=True replaces the check-then-create pattern, which can fail if the
# directory appears between the existence check and the makedirs call.
os.makedirs(DIRECTORY, exist_ok=True)
os.makedirs(DIRECTORY + "/systemMatrix", exist_ok=True)

In [3]:
def print_memory_usage(message):
    """Print the usage of CuPy's default memory pool, prefixed by *message*.

    Reports the pool's ``used_bytes()`` and ``total_bytes()`` divided by
    1024**3 with two decimals.
    """
    pool = cp.get_default_memory_pool()
    used_bytes, total_bytes = pool.used_bytes(), pool.total_bytes()
    print(f"{message}: Used memory: {used_bytes / 1024**3:.2f} GB, Total memory: {total_bytes / 1024**3:.2f} GB")

In [4]:
def mult_D(h):
    """Apply backward differences along each of the four tensor axes.

    ``h`` is reshaped (column-major) to ``(n, n, m, m)`` using the
    module-level ``n`` and ``m``. For every axis the difference
    ``t[i] - t[i-1]`` is stored at index ``i >= 1``; index 0 stays zero.

    Returns:
        1-D array of length ``4 * n*n*m*m``: the four difference tensors,
        each flattened column-major, concatenated in axis order.
    """
    grid = h.reshape((n, n, m, m), order='F')
    rank = grid.ndim

    pieces = []
    for axis in range(rank):
        upper = [slice(None)] * rank
        lower = [slice(None)] * rank
        upper[axis] = slice(1, None)
        lower[axis] = slice(None, -1)
        upper, lower = tuple(upper), tuple(lower)

        diff = cp.zeros_like(grid)
        diff[upper] = grid[upper] - grid[lower]
        pieces.append(diff.ravel(order='F'))

    return cp.concatenate(pieces)


def mult_Dt(y):
    """Apply the adjoint of ``mult_D``.

    ``y`` is read as four consecutive blocks of length ``n*n*m*m``
    (module-level ``n``, ``m``), one per tensor axis, each reshaped
    column-major to ``(n, n, m, m)``. For every axis the entries at index
    ``i >= 1`` are added at ``i`` and subtracted at ``i - 1``; entries at
    index 0 are ignored, mirroring the zero row that ``mult_D`` produces.

    Returns:
        1-D array of length ``n*n*m*m`` (column-major flattening) with the
        same dtype as ``y``.
    """
    shape = (n, n, m, m)
    rank = len(shape)
    length = n * n * m * m

    # Allocate with y's dtype: the cp.zeros default (float64) would silently
    # upcast float32 input and double the memory used for the result.
    h_tensor = cp.zeros(shape, dtype=y.dtype)

    for axis in range(rank):
        diff = y[axis * length:(axis + 1) * length].reshape(shape, order='F')

        upper = [slice(None)] * rank
        lower = [slice(None)] * rank
        upper[axis] = slice(1, None)
        lower[axis] = slice(None, -1)
        upper, lower = tuple(upper), tuple(lower)

        h_tensor[lower] -= diff[upper]
        h_tensor[upper] += diff[upper]

    return h_tensor.ravel(order='F')

In [5]:
def calculate_1st_term(Gt, Ft, Ht):
    """Return the squared norm of the residual ``Gt - Ft @ Ht``.

    Uses the default ``cp.linalg.norm``; no 1/2 factor is applied here.
    """
    print("calculate_1st_term start")
    residual = Gt - Ft @ Ht
    return cp.linalg.norm(residual) ** 2


def calculate_2nd_term(H):
    """Return the sum over rows of the squared l1 norm of each row of ``H``.

    The absolute values are summed along axis 1, squared, then summed.
    """
    print("calculate_2nd_term start")
    row_l1 = cp.abs(H).sum(axis=1)
    return cp.sum(row_l1**2)


def calculate_3rd_term(h):
    """Return the sum of per-element l2 norms of the four differences of ``h``.

    ``mult_D(h)`` is viewed column-major as a ``(-1, 4)`` array, so each row
    holds the four axis differences of one tensor element; the row norms are
    then summed.
    """
    print("calculate_3rd_term start")
    grads = mult_D(h).reshape(-1, 4, order="F")
    return cp.sum(cp.linalg.norm(grads, axis=1))

In [6]:
def prox_l122(Ht: cp.ndarray, gamma: float) -> cp.ndarray:
    """Row-wise soft-thresholding used as the prox of the squared l1,2 term.

    Every entry of ``Ht`` is shrunk towards zero by
    ``2*gamma / (1 + 2*gamma*N)`` times the l1 norm of its row, where ``N``
    is the module-level constant.

    NOTE(review): the l1 norms are taken along axis 1, yet the factor uses
    the module-level ``N`` rather than ``Ht.shape[1]`` -- confirm that this
    is the intended group size.

    Args:
        Ht: 2-D array; l1 norms are computed along axis 1.
        gamma: weight entering the threshold factor.

    Returns:
        Array of the same shape as ``Ht``.
    """
    magnitude = cp.absolute(Ht)
    l1_norms = cp.sum(magnitude, axis=1)
    factor = (2 * gamma) / (1 + 2 * gamma * N)
    # No pre-allocated output buffer: the expression below creates the result,
    # so a zeros_like allocation beforehand would only waste GPU memory.
    return cp.sign(Ht) * cp.maximum(magnitude - factor * l1_norms[:, None], 0)


def prox_tv(y: cp.ndarray, gamma: float) -> cp.ndarray:
    """Group soft-thresholding over groups of four entries of ``y``.

    ``y`` is viewed column-major as a ``(-1, 4)`` array; every row is scaled
    by ``max(1 - gamma / (||row||_2 + 1e-16), 0)`` and the result is
    flattened back column-major.
    """
    groups = y.reshape(-1, 4, order="F")
    norms = cp.linalg.norm(groups, axis=1, keepdims=True)
    # The small constant keeps the division finite for all-zero rows.
    shrink = cp.maximum(1 - gamma / (norms + 1e-16), 0)
    return (shrink * groups).ravel(order="F")


def prox_conj(y: cp.ndarray, prox: Callable[[cp.ndarray, float], cp.ndarray], gamma: float) -> cp.ndarray:
    """Evaluate ``y - gamma * prox(y / gamma, 1 / gamma)``.

    Args:
        y: point at which to evaluate.
        prox: callable taking ``(array, weight)``.
        gamma: scaling applied as shown in the formula above.
    """
    scaled_point = y / gamma
    inner = prox(scaled_point, 1 / gamma)
    return y - gamma * inner


def primal_dual_splitting(
    Ft: cp.ndarray, Gt: cp.ndarray, lambda1: float, lambda2: float, max_iter: int = ITER
) -> tuple[cp.ndarray, dict]:
    """Run the primal-dual iteration for the regularised least-squares problem.

    The primal variable ``Ht`` has shape ``(Ft.shape[1], Gt.shape[1])`` and
    starts at zero. Each iteration takes a gradient step on the data term
    plus the ``mult_Dt`` of the dual variable, followed by ``prox_l122``;
    then updates the dual variable with the conjugate prox of the
    ``lambda2``-weighted ``prox_tv``. Step sizes are the module-level
    ``TAU`` and ``SIGMA``.

    ``mult_D`` / ``mult_Dt`` reshape with the module-level ``n`` and ``m``,
    so ``Ft.shape[1] * Gt.shape[1]`` must equal ``n*n*m*m``.

    Every 20th iteration the relative residuals and the three objective
    terms are printed; the loop stops early when a residual is NaN or when
    both relative residuals fall below 1e-3.

    Args:
        Ft: matrix multiplied with ``Ht`` in the data term.
        Gt: target of the data term, same row count as ``Ft``.
        lambda1: weight passed (times tau) to ``prox_l122``.
        lambda2: weight of the difference (TV) term.
        max_iter: maximum number of iterations.

    Returns:
        ``(Ht, info)`` where ``info`` holds ``"iterations"`` and the final
        absolute ``"primal_residual"`` / ``"dual_residual"``.
    """
    N = Ft.shape[1]
    M = Gt.shape[1]
    print_memory_usage("Before initializing variables")
    Ht = cp.zeros((N, M), dtype=cp.float32)
    Ht_old = cp.zeros_like(Ht)
    y = cp.zeros(4 * N * M, dtype=cp.float32)
    y_old = cp.zeros_like(y)
    print_memory_usage("After initializing variables")

    tau = TAU
    sigma = SIGMA
    print(f"tau={tau}, sigma={sigma}")

    def prox_weighted_tv(v: cp.ndarray, gamma: float) -> cp.ndarray:
        # Prox of gamma * (lambda2 * TV): the weight scales the threshold.
        return prox_tv(v, lambda2 * gamma)

    k = -1  # keeps the summary below well defined when max_iter == 0
    for k in range(max_iter):
        Ht_old[:] = Ht[:]
        y_old[:] = y[:]

        Ht[:] = prox_l122(
            Ht_old - tau * (Ft.T @ (Ft @ Ht - Gt) + (mult_Dt(y_old)).reshape(N, M, order="F")),
            lambda1 * tau,
        )

        # Moreau decomposition with step sigma:
        #   prox_{sigma g*}(v) = v - sigma * prox_{g / sigma}(v / sigma),
        # with g = lambda2 * TV. Passing lambda2 / sigma as the step instead
        # cancels lambda2 algebraically (projection onto the unit ball
        # regardless of lambda2) and divides by zero for lambda2 == 0.
        y[:] = prox_conj(
            y_old + sigma * mult_D((2 * Ht - Ht_old).ravel(order="F")),
            prox_weighted_tv,
            sigma,
        )

        if k % 20 == 19:
            primal_residual = cp.linalg.norm(Ht - Ht_old) / cp.linalg.norm(Ht)
            dual_residual = cp.linalg.norm(y - y_old) / cp.linalg.norm(y)
            print(f"iter={k}, primal_res={primal_residual:.8e}, dual_res={dual_residual:.8e}")
            print("1st", calculate_1st_term(Gt, Ft, Ht))
            print("2nd", calculate_2nd_term(Ht))
            print("3rd", calculate_3rd_term(Ht))
            if cp.isnan(primal_residual) or cp.isnan(dual_residual):
                print("NaN detected in residuals, stopping optimization.")
                break
            if primal_residual < 1e-3 and dual_residual < 1e-3:
                print("Convergence criteria met.")
                break
        else:
            print(f"iter={k}")

    # Absolute (not relative) change over the last iteration.
    primal_residual = cp.linalg.norm(Ht - Ht_old)
    dual_residual = cp.linalg.norm(y - y_old)
    print(f"Final iteration {k+1}, primal_res={primal_residual:.8e}, dual_res={dual_residual:.8e}")
    print_memory_usage("After optimization")

    info = {
        "iterations": k + 1,
        "primal_residual": primal_residual,
        "dual_residual": dual_residual,
    }

    return Ht, info

In [None]:
# Load the image sets as matrices via myUtil.images_to_matrix.
# The two calls are kept in this order on purpose: the helper is opaque here
# and may depend on call order.
INFO = "cap_R_230516_128"
# INFO = "cap_240814"
G, _ = myUtil.images_to_matrix(f"{DATA_PATH}/{IMG_NAME}{n}_{INFO}/", ratio=RATIO, resize=True, ressize=m)
F, _ = myUtil.images_to_matrix(f"{DATA_PATH}/{IMG_NAME}{n}_input/", ratio=RATIO)
K = F.shape[1]
print("K=", K)

# Image "<IMG_NAME>_1.png" of the captured set: grayscale, resized to (m, m),
# flattened and scaled by 1/255 into a single column.
white_img = (
    Image.open(f"{DATA_PATH}/{IMG_NAME}{n}_{INFO}/{IMG_NAME}_1.png")
    .convert("L")
    .resize((m, m))
)
white = (np.asarray(white_img).ravel() / 255)[:, np.newaxis]

# Repeat that column K times so it can be subtracted from 2 * G.
H1 = np.tile(white, K)
F_hat = 2 * F - 1
G_hat = 2 * G - H1
# G_vec = G_hat.ravel(order="F")

In [None]:
# Upload the transposed matrices to the GPU. F_hat is cast to int8 (any
# fractional part is dropped by the cast) and G_hat to float32.
F_hat_T_gpu = cp.asarray(np.transpose(F_hat)).astype(cp.int8)
G_hat_T_gpu = cp.asarray(np.transpose(G_hat)).astype(cp.float32)

print(f"F device: {F_hat_T_gpu.device}")
print(f"g device: {G_hat_T_gpu.device}")

# Drop the host-side arrays; only the GPU copies are used from here on.
del F, G, H1, F_hat, G_hat

In [None]:
# Solve for the system matrix; max_iter keeps its default (ITER).
h, info = primal_dual_splitting(
    Ft=F_hat_T_gpu,
    Gt=G_hat_T_gpu,
    lambda1=LAMBDA1,
    lambda2=LAMBDA2,
)

In [None]:
# View the solver output as an (N, M) matrix (column-major).
Ht = h.reshape(N, M, order="F")
# Saving the estimated matrix is currently disabled:
# np.save(f"{DIRECTORY}/systemMatrix/H_matrix_{SETTING}.npy", H)
# print(f"Saved {DIRECTORY}/systemMatrix/H_matrix_{SETTING}.npy")

# Load a sample image as a flat grayscale vector scaled by 1/255.
SAMPLE_NAME = "Cameraman"
sample_image = Image.open(f"{DATA_PATH}/sample_image{n}/{SAMPLE_NAME}.png").convert("L")
sample_image = cp.asarray(sample_image).ravel() / 255

# Apply the transposed matrix, bring the result back to the host as an
# (m, m) image, clip to [0, 1] and convert to 8-bit.
Hf = Ht.T @ sample_image
Hf_img = np.clip(cp.asnumpy(Hf.reshape(m, m)), 0, 1)
Hf_pil = Image.fromarray((Hf_img * 255).astype(np.uint8), mode="L")

FILENAME = f"{SAMPLE_NAME}_{SETTING}.png"
# figsize in "inches" combined with dpi=1 yields one output pixel per unit.
fig, ax = plt.subplots(figsize=Hf_img.shape[::-1], dpi=1, tight_layout=True)
ax.imshow(Hf_pil, cmap="gray")
ax.set_axis_off()
fig.savefig(f"{DIRECTORY}/{FILENAME}", dpi=1)
# plt.show()
print(f"Saved {DIRECTORY}/{FILENAME}")