In [None]:
import numpy as np

import torch

import matplotlib.pyplot as plt

from tqdm import tqdm
from ray_transforms import get_ray_trafo, get_static_ray_trafo
from test_model_functions_swaped import get_net_corected_operator
# import random
import torch 


import os
from conf import config
from select_model_type_matrix import select_fwd_model_type,select_adj_model_type
from scipy import sparse
from util import get_op,Mat,plots,rand_shift_params,error_for_y,check_path

def im_norm(x):
    """Return the Euclidean (2-)norm of ``x`` treated as one flat vector."""
    flat = x.reshape(-1)
    return np.linalg.norm(flat, 2)

def sp(a,b):
    """Return the Euclidean inner product of ``a`` and ``b`` after flattening both."""
    flat_a = a.reshape(-1)
    flat_b = b.reshape(-1)
    return flat_a @ flat_b

def ISTA_obj_func(op,y,lam,x):
    """Evaluate ``||op(x) - y||_2^2 / (2*lam) + ||x||_1``.

    This is the LASSO objective ``0.5*||op(x) - y||^2 + lam*||x||_1`` divided
    by ``lam``. Both norms are taken over the flattened arrays.
    """
    residual = (op(x) - y).reshape(y.size)
    data_fidelity = 1 / (2 * lam) * np.linalg.norm(residual, 2) ** 2
    sparsity = np.linalg.norm(x.reshape(x.size), 1)
    return data_fidelity + sparsity

def Test_PGM(step_op,L,grad_R,true_op,p,y_e,x_0_selector,test_op,mu,num_iter):
    """Run a proximal-gradient-type iteration with ``test_op`` and record diagnostics.

    The iteration is ``x_{k+1} = step_op(x_k, F(x_k))`` with
    ``F(x) = test_op.adjoint(test_op(x) - y_e) + grad_R(x)``. ``F_true`` is the
    same expression built from ``true_op``. It is evaluated alongside, so the
    error caused by using the approximate operator can be measured.
    All norms ``||.||`` are Euclidean norms of the flattened arrays.

    Args:
        step_op: callable ``(x, grad) -> x_next`` producing the next iterate.
            Assumed to return a new array rather than modifying ``x`` in place.
        L: objective evaluated on the iterates (diagnostics only).
        grad_R: gradient of the differentiable part of the regulariser.
        true_op: reference operator, callable and providing ``adjoint``.
        p: ground-truth phantom, used for the reconstruction error.
        y_e: (noisy) measured data.
        x_0_selector: callable ``y_e -> x_0`` giving the initial iterate.
        test_op: operator under test, callable and providing ``adjoint``.
        mu: step size, used to form ``T = (x_k - x_{k+1}) / mu``.
        num_iter: number of iterations.

    Returns:
        dict with
            'X': stored iterates X[9], X[19], ... (X[k] is the iterate after k steps),
            'loss': ``||p - x_k||`` for k = 0..num_iter,
            'TThetaMu': ``||T||``,
            'LL': ``L(x_k) - L(x_{k+1})``,
            'AL': ``mu * (<F_true(x_k) - F(x_k), T> + ||T||**2 / 2)``,
            'FwL': ``||true_op(x_{k+1}) - test_op(x_{k+1})||``,
            'AdL': ``||true_op.adjoint(r) - test_op.adjoint(r)||`` with
                ``r = test_op(x_{k+1}) - y_e``.
    """
    def F_abl(x):
        return test_op.adjoint(test_op(x) - y_e) + grad_R(x)

    def F_true_abl(x):
        return true_op.adjoint(true_op(x) - y_e) + grad_R(x)

    x = x_0_selector(y_e)
    # Full iterate history; shape follows the initial guess.
    X = np.zeros((num_iter + 1,) + np.shape(x))
    X[0] = x
    loss = np.zeros(num_iter + 1)
    loss[0] = im_norm(p - x)
    TThetaMu = np.zeros(num_iter)
    LL = np.zeros(num_iter)
    AL = np.zeros(num_iter)
    FwL = np.zeros(num_iter)
    AdL = np.zeros(num_iter)
    # Objective at the current iterate, carried between steps so L is
    # evaluated once per iterate instead of twice.
    L_x = L(x)
    for i in range(num_iter):
        x_old = x
        # Evaluated once and reused for both the step and the AL diagnostic.
        grad = F_abl(x_old)
        x = step_op(x_old, grad)
        X[i + 1] = x
        loss[i + 1] = im_norm(p - x)
        T_Thetamu = (x_old - x) / mu
        TThetaMu[i] = im_norm(T_Thetamu)
        L_new = L(x)
        LL[i] = L_x - L_new
        L_x = L_new
        AL[i] = mu * (sp(F_true_abl(x_old) - grad, T_Thetamu) + 0.5 * TThetaMu[i] ** 2)
        test_x = test_op(x)  # reused for FwL and the adjoint residual
        FwL[i] = im_norm(true_op(x) - test_x)
        r = test_x - y_e
        AdL[i] = im_norm(true_op.adjoint(r) - test_op.adjoint(r))
    # Keep only every 10th stored iterate (indices 9, 19, ...) to limit memory.
    X = X[np.arange(9, num_iter, 10)]
    return {'X':X,'loss':loss,'TThetaMu':TThetaMu,'LL':LL,'AL':AL,'FwL':FwL,'AdL':AdL}

def Test_GD(step_op,L,grad_R,true_op,p,y_e,x_0_selector,test_op,mu,num_iter):
    """Run a gradient-type iteration with ``test_op`` and record diagnostics.

    The iteration is ``x_{k+1} = step_op(x_k, F(x_k))`` with
    ``F(x) = test_op.adjoint(test_op(x) - y_e) + grad_R(x)``. ``F_true`` is the
    same expression built from ``true_op``. All norms ``||.||`` are Euclidean
    norms of the flattened arrays.

    Args:
        step_op: operator producing the next iterate, ``x_{k+1} = step_op(x_k, F(x_k))``.
            Assumed to return a new array rather than modifying ``x`` in place.
        L: objective function evaluated on the iterates (for testing only).
        grad_R: gradient of the differentiable part of the regularisation terms.
        true_op: the precise operator we try to approximate; callable with an
            ``adjoint`` method.
        p: the phantom (ground truth).
        y_e: noisy data.
        x_0_selector: function returning the initial iterate ``x_0`` for ``y_e``.
        test_op: the operator under test; callable with an ``adjoint`` method.
        mu: constant step size. Here it only enters the 'AL' diagnostic;
            presumably ``step_op`` uses the same value — confirm at call site.
        num_iter: number of iterations to compute.

    Returns:
        dict with
            'X': stored iterates X[9], X[19], ... (X[k] is the iterate after k steps),
            'loss': ``||p - x_k||`` for k = 0..num_iter,
            'TMu': ``||F_true(x_k)||``,
            'LL': ``L(x_k) - L(x_{k+1})``,
            'AL': ``mu * <F_true(x_k), F(x_k)> / ||F_true(x_k)||**2``
                (nan/inf if the true gradient vanishes),
            'FF': ``||F_true(x_k) - F(x_k)||``,
            'FwL': ``||true_op(x_{k+1}) - test_op(x_{k+1})||``,
            'AdL': ``||true_op.adjoint(r) - test_op.adjoint(r)||`` with
                ``r = test_op(x_{k+1}) - y_e``.
    """
    def F_abl(x):
        return test_op.adjoint(test_op(x) - y_e) + grad_R(x)

    def F_true_abl(x):
        return true_op.adjoint(true_op(x) - y_e) + grad_R(x)

    x = x_0_selector(y_e)
    # Full iterate history; shape follows the initial guess.
    X = np.zeros((num_iter + 1,) + np.shape(x))
    X[0] = x
    loss = np.zeros(num_iter + 1)
    loss[0] = im_norm(p - x)
    TMu = np.zeros(num_iter)
    LL = np.zeros(num_iter)
    AL = np.zeros(num_iter)
    FF = np.zeros(num_iter)
    FwL = np.zeros(num_iter)
    AdL = np.zeros(num_iter)
    # Objective at the current iterate, carried between steps so L is
    # evaluated once per iterate instead of twice.
    L_x = L(x)
    for i in range(num_iter):
        x_old = x
        # Both gradients at x_old are computed once and reused below.
        grad = F_abl(x_old)
        x = step_op(x_old, grad)
        X[i + 1] = x
        loss[i + 1] = im_norm(p - x)
        grad_F = F_true_abl(x_old)
        TMu[i] = im_norm(grad_F)
        L_new = L(x)
        LL[i] = L_x - L_new
        L_x = L_new
        # Division by zero (nan/inf with a numpy warning) if grad_F vanishes.
        AL[i] = mu * (sp(grad_F, grad) / TMu[i] ** 2)
        FF[i] = im_norm(grad_F - grad)
        test_x = test_op(x)  # reused for FwL and the adjoint residual
        FwL[i] = im_norm(true_op(x) - test_x)
        r = test_x - y_e
        AdL[i] = im_norm(true_op.adjoint(r) - test_op.adjoint(r))
    # Keep only every 10th stored iterate (indices 9, 19, ...) to limit memory.
    X = X[np.arange(9, num_iter, 10)]
    return {'X':X,'loss':loss,'TMu':TMu,'LL':LL,'AL':AL,'FF':FF,'FwL':FwL,'AdL':AdL}

def save_plot_Test(path,dic,background = None,show = False):
    """Plot reconstruction loss and alignment diagnostics, then save the figure.

    Args:
        path: file name passed to ``fig.savefig``.
        dic: result dict. It needs 'LL', 'AL' and one of 'TM', 'TThetaMu' or
            'TMu' (the norms used to normalise 'LL' and 'AL'). 'loss' is
            plotted if present.
        background: optional dict with reference loss curves under 'static'
            and 'true'. Missing entries are plotted as empty lines.
        show: if True, display the figure instead of closing it.

    Raises:
        KeyError: if ``dic`` has none of the normalisation keys, or lacks 'LL'/'AL'.
    """
    if background is None:
        background = {}
    # Test_PGM stores the normalising norms under 'TThetaMu' and Test_GD under
    # 'TMu'. 'TM' is still accepted for older result dicts.
    norm_key = next((k for k in ('TM', 'TThetaMu', 'TMu') if k in dic), None)
    if norm_key is None:
        raise KeyError("dic must contain one of 'TM', 'TThetaMu' or 'TMu'")
    # NOTE(review): Test_GD's 'AL' is already divided by TMu**2, so for its
    # results the AL curve is normalised twice here — confirm intent.
    norm_sq = np.asarray(dic[norm_key]) ** 2
    fig,axs = plots(2,1,3/2)
    axs[0].set_title('reconstruction loss')
    axs[0].plot(background.get('static',[]),label = 'static')
    axs[0].plot(background.get('true',[]),label = 'true')
    axs[0].plot(dic.get('loss',[]),label = 'cor')
    axs[0].set_yscale('log')
    axs[0].legend()
    axs[1].set_title('alignement')
    axs[1].plot(np.asarray(dic['LL'])/norm_sq,label='LL')
    axs[1].plot(np.asarray(dic['AL'])/norm_sq,label='AL')
    axs[1].set_yscale('log')
    axs[1].legend()
    fig.savefig(path)
    if show:
        # pyplot.show takes no figure argument; it displays all open figures.
        plt.show()
    else:
        plt.close(fig)

class net_cor_op():
    """Operator whose forward and adjoint are corrected versions of a static operator.

    Both directions are built with ``get_net_corected_operator``: the forward
    from ``static_op`` and ``fw_model``, the adjoint from ``static_op.adjoint``
    and ``adjoint_model``. Instances are used as ``op(x)`` and ``op.adjoint(y)``.
    """
    def __init__(self,static_op,fw_model,fw_swaped,adjoint_model,adj_swaped,device) -> None:
        """Build the corrected forward and adjoint operators.

        Args:
            static_op: base operator; must be callable and provide ``adjoint``.
            fw_model: model (presumably a network) correcting the forward operator.
            fw_swaped: passed as ``swaped=`` for the forward correction; its
                meaning is defined by ``get_net_corected_operator``.
            adjoint_model: model (presumably a network) correcting the adjoint.
            adj_swaped: passed as ``swaped=`` for the adjoint correction.
            device: passed as ``device=`` to both corrections (presumably a
                torch device string such as ``"cuda:0"`` — confirm).
        """
        self.cor_op = get_net_corected_operator(static_op, fw_model,device = device,swaped=fw_swaped)
        self.cor_adj_op = get_net_corected_operator(static_op.adjoint, adjoint_model,device = device,swaped=adj_swaped)

    def __call__(self, x) -> np.ndarray:
        """Apply the corrected forward operator to ``x``."""
        return self.cor_op(x)

    def adjoint(self, x) -> np.ndarray:
        """Apply the corrected adjoint operator to ``x``."""
        return self.cor_adj_op(x)

def soft_shrink(x, alpha):
    """Apply the soft-thresholding (shrinkage) operator element-wise.

    Each entry's magnitude is reduced by ``alpha`` and clipped at zero, while
    its sign is kept: ``sign(x) * max(|x| - alpha, 0)``.

    Args:
        x (np.array): input values.
        alpha (float): threshold.

    Returns:
        np.array: the shrunk values.
    """
    shrunk_magnitude = np.maximum(np.abs(x) - alpha, 0)
    return np.sign(x) * shrunk_magnitude
    


In [None]:
gpu_idx = 0
# Use the selected CUDA device when torch sees a GPU, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = f"cuda:{gpu_idx}"
else:
    device = "cpu"
print(f"Using {device} device")
import astra
# Point ASTRA at the same GPU index as torch.
astra.set_gpu_index(gpu_idx)

In [None]:
# Discretisation shared by the static operator and the ray transform.
x_res = 64
y_res = 64
num_angles = 256
detector_points = 96
# The matrix file name encodes the discretisation, so derive it from the
# parameters above. A mismatched file then fails to load instead of being used.
A_static_s = sparse.load_npz(f'Matritzen/{x_res}_{y_res}_{num_angles}_{detector_points}_static.npz')
static_op = get_op(A_static_s, x_res, y_res, num_angles, detector_points)
static_ray_trafo = get_static_ray_trafo(x_res, y_res, num_angles, detector_points, detector_len=2)

In [None]:
# Load the stack of test phantoms and use the first one as the ground truth.
tp = np.load('phantoms/test_phantoms_64_1.npy')
p = tp[0, :, :]
p.shape

In [None]:
def x_0is0(y):
    """Initial guess: all zeros on the image grid (``y`` is ignored)."""
    return np.zeros(shape=(x_res, y_res))


def x_0isATy(y):
    """Initial guess: back-projection of the data with the static adjoint."""
    return static_op.adjoint(y)


def x_0isp(y):
    """Initial guess: the ground-truth phantom itself (``y`` is ignored)."""
    return p

In [None]:
runs_list = ['Test']

# Operators to evaluate and their plot labels, kept index-aligned.
operator_list = []
op_name_list = []

# Index of the pre-generated shifted-geometry matrix loaded below.
i = 10

# Disabled alternatives: u-only and v-only shifts for the same index.
# A_u = sparse.load_npz(f"Matritzen/64_64_256_96_1_u_and_v_shift/u_ray_trafo_{i}.npz")
# u_op = get_op(A_u,x_res, y_res, num_angles, detector_points)
# operator_list.append(u_op)
# op_name_list.append(r'$n=1$, $u$ shift')
# A_v = sparse.load_npz(f"Matritzen/64_64_256_96_1_u_and_v_shift/v_ray_trafo_{i}.npz")
# v_op = get_op(A_v,x_res, y_res, num_angles, detector_points)
# operator_list.append(v_op)
# op_name_list.append(r'$n=1$, $v$ shift')

A_u_v = sparse.load_npz(f"Matritzen/64_64_256_96_1_u_and_v_shift/u_v_ray_trafo_{i}.npz")
u_v_op = get_op(A_u_v, x_res, y_res, num_angles, detector_points)
operator_list.append(u_v_op)
op_name_list.append(r'$n=1$, $u$ and $v$ shift')
# shift_params = np.load(f'Matritzen/64_64_256_96_1_u_and_v_shift/shift_params_{i}.npy')

A_s = sparse.load_npz('Matritzen/Test_ray_trafo_64_256_96_100.npz')
operator_list.append(get_op(A_s, x_res, y_res, num_angles, detector_points))
op_name_list.append(r'$n=1$, $u$ and $v$ shift, new')

# Disabled alternative: larger-amplitude u/v shift.
# A_s = sparse.load_npz('Matritzen/64_64_256_96_strong_u_v_shift.npz')
# operator_list.append(get_op(A_s,x_res, y_res, num_angles, detector_points))
# op_name_list.append(r'$n=1$, $u$ and $v$ shift, new large amplitude')

A_s = sparse.load_npz("Matritzen/Test_64_256_96_5addet_u_v.npz")
operator_list.append(get_op(A_s, x_res, y_res, num_angles, detector_points))
op_name_list.append(r'$n=5$, $u$ and $v$ shift, new')