# Preliminares 

## Importar dependencias 

In [None]:
from inrmri.fourier_features import BasicSpatialFourierFeaturesNet 
from inrmri.utils import meshgrid_from_subdiv_autolims
from inrmri.data_splitter import SimpleDataLoader

from jax import random 
import jax.numpy as np 

import matplotlib.pyplot as plt 
from inrmri.utils import to_complex 
from jax import jit 
from inrmri.fourier import make_kspace, make_image
from inrmri.basic_nn import mse, simple_train
import optax 

## Descargar datos 

Para este notebook se necesitan dos archivos: `cine_sax_slice5_frame5.npy` y `cs.npy`. Las rutas a ambos archivos se guardan en las variables `KSPACE_FILE` y `CSMAP_FILE`, respectivamente, y los archivos pueden descargarse de:

- `cine_sax_slice5_frame5.npy`: https://drive.google.com/file/d/13MAxtswKV-31tAu8XPq1Mea2yYdRCTyv/view?usp=sharing 
- `cs.npy`: https://drive.google.com/file/d/11He4AX0rC1n9hgehzcEzWU9ZtPD-q03b/view 

Los datos `cine_sax_slice5_frame5.npy` corresponden a una adquisición de resonancia magnética cardiaca, específicamente un cine cardiaco (un vídeo del corazón moviéndose). Los datos corresponden a un frame y una slice de una adquisición cartesiana multibobina. 

> Para ver otros frames o slices de la misma adquisición, se puede descargar el archivo [`cine_sax.mat`](https://drive.google.com/file/d/1P8FRU2FOp97Sbbi7JwPMHvV_keHBkjiX/view). Se necesita la librería [`mat73`](https://pypi.org/project/mat73/) para leerlo. El array `cine_sax_slice5_frame5` puede recuperarse a partir de `cine_sax.mat` haciendo:
>
> ```python
> import mat73 
> cine_sax = mat73.loadmat(FULL_CINE_SAX_PATH)['kspace_full']
> selected_slice = 5 
> selected_frame = 5 
> cine_sax_slice5_frame5 = cine_sax[:,:,:,selected_slice, selected_frame] 
> ```

In [None]:
# Locations of the input data. The defaults are the original author's
# machine-specific paths; set the KSPACE_FILE / CSMAP_FILE environment
# variables to point the notebook at local copies instead.
import os

KSPACE_FILE = os.environ.get(
    'KSPACE_FILE',
    '/home/tabita/ACIP-MRI/NF-cMRI/carthesian_data/cine_sax_slice5_frame5.npy',
)
CSMAP_FILE = os.environ.get(
    'CSMAP_FILE',
    '/home/tabita/ACIP-MRI/NF-cMRI/carthesian_data/cs.npy',
)

# Reconstrucción de una adquisición cartesiana usando _Implicit Neural Representations_

## Cargar datos 

In [None]:
# Measured multi-coil k-space for one slice/frame.
kspace = np.load(KSPACE_FILE)

# Wrap it with a leading sample axis and a trailing singleton axis,
# giving the (1, px, py, coil, 1) layout used by the training code.
Y_full = kspace[np.newaxis, ..., np.newaxis]

# Coordinate grid over the (px, py) plane, with the same leading sample axis.
spatial_shape = Y_full.shape[1:3]
grid = meshgrid_from_subdiv_autolims(spatial_shape, endpoint=True)
X_full = grid[np.newaxis]

In [None]:
# Magnitude image seen by each coil, reconstructed from the loaded k-space.
fig, axs = plt.subplots(2, 5, figsize=(5 * 2, 2 * 4))
axs = axs.flatten()

# Shared display ceiling: 30% of the brightest magnitude over all coil images.
# The magnitude must be taken BEFORE max(): max over complex values orders them
# lexicographically (real part first), not by magnitude.
maxval = 0.3 * np.abs(make_image(Y_full, axes=(1, 2))).max()

# Y_full is laid out as (1, px, py, coil, 1); only draw as many panels as coils
# so a dataset with fewer than 10 coils does not index out of range.
n_coils = Y_full.shape[3]
for i, ax in enumerate(axs[:n_coils]):
    ax.imshow(np.abs(make_image(Y_full[0, :, :, i, 0])), vmax=maxval, cmap='bone')
    ax.set_title(f'Imagen obtenida\n por la bobina {i}')
plt.tight_layout()

In [None]:
# Coil sensitivity maps, indexed below as cs[0, :, :, coil].
cs = np.load(CSMAP_FILE)
print("cs.shape: ", cs.shape)

fig, axs = plt.subplots(2, 5, figsize=(5 * 2, 2 * 4))
axs = axs.flatten()

# Shared display ceiling: the largest sensitivity magnitude over all coils.
# abs() comes first because max() over complex values is not a max of magnitudes.
maxval = np.abs(cs).max()

# Draw one panel per coil (coil is the last axis), never more than available.
n_coils = cs.shape[-1]
for i, ax in enumerate(axs[:n_coils]):
    ax.imshow(np.abs(cs[0, :, :, i]), vmax=maxval)
    ax.set_title(f'Bobina {i}')
plt.tight_layout()

## Construir red y entrenar 

In [None]:
# Independent PRNG streams: Fourier-feature matrix, network weights, training.
key_B, key_params, key_train = random.split(random.PRNGKey(0), 3)

# 2-D spatial coordinates in; two outputs (real and imaginary parts) out.
ffnet = BasicSpatialFourierFeaturesNet(
    input_size=2,
    mapping_size=1000,
    sigma=7.5,
    output_size=2,
    key_B=key_B,
)

# Visual sanity check: show the first 20 Fourier features evaluated on the grid.
FFs = ffnet.useFFBox(X_full)
n_rows, n_cols = 2, 10
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 2, n_rows * 4), dpi=150)
for feature_idx, ax in enumerate(axs.ravel()):
    ax.imshow(FFs[0, :, :, feature_idx])
    ax.set_axis_off()
    ax.set_title(f'Feature {feature_idx}')


In [None]:
def ffnet_image(params, grid):
    """Evaluate the global ``ffnet`` on ``grid`` and return a complex image.

    ``ffnet.eval_coordinates`` produces two real output channels (noted as
    shape (1, px, py, 2): real and imaginary parts); ``to_complex`` folds them
    into complex values.
    """
    return to_complex(ffnet.eval_coordinates(params, grid))

@jit
def loss(params, X, Y):
    """Data-consistency loss between the network's predicted k-space and ``Y``.

    The network image is multiplied by every coil sensitivity map in the global
    ``cs``, moved to k-space with ``make_kspace`` over axes (1, 2), given a
    trailing singleton axis, and compared with ``Y`` through ``mse``.

    Args:
        params: parameters of the global ``ffnet``.
        X: coordinate grid fed to the network (``X_full`` in this notebook).
        Y: measured multi-coil k-space, laid out as (1, px, py, coil, 1).

    Returns:
        The value returned by ``mse(predicted_kspace, Y)``.

    Raises:
        ValueError: while tracing, if the predicted k-space shape differs from
            ``Y.shape`` (depending on ``mse``'s implementation, a mismatch could
            otherwise be broadcast silently into a wrong loss).
    """
    # Presumably (1, px, py, 1): the trailing singleton broadcasts over coils.
    img = ffnet_image(params, X)
    coil_weighted_img = img * cs  # (1, px, py, coils)
    # Extra trailing axis so the prediction matches Y's (1, px, py, coil, 1) layout.
    pred_kspace = make_kspace(coil_weighted_img, axes=(1, 2))[..., None]
    # Shapes are static under jit, so this check runs once, at trace time,
    # replacing the manual print-and-compare debugging.
    if pred_kspace.shape != Y.shape:
        raise ValueError(
            f"predicted k-space shape {pred_kspace.shape} does not match "
            f"measured k-space shape {Y.shape}"
        )
    return mse(pred_kspace, Y)

# Initialise the network; the list gives the layer widths passed to init_params.
params = ffnet.init_params(key_params, [256, 256, 256])

learning_rate = 1e-3
n_iterations = 5000
optimizer = optax.adam(learning_rate)

# X_full / Y_full hold a single sample (leading axis of size 1).
results = simple_train(
    loss,
    X_full,
    Y_full,
    params,
    optimizer,
    key_train,
    batch_size=1,  # do not change the batch size
    nIter=n_iterations,
)

## Ver resultados

In [None]:
# %%
# Side-by-side comparison for one coil: the reference image reconstructed from
# the measured k-space vs. the trained network image weighted by that coil's
# sensitivity map.
chosen_csmap = 5
vmax = 0.18

# Evaluate the trained network once; reused for the plot and the shape check.
recon = ffnet_image(results['last_param'], X_full)

fig = plt.figure(figsize=(8, 7))
plt.subplot(121)
# Y_full is (1, px, py, coil, 1): index the trailing singleton axis as well so
# make_image receives a 2-D (px, py) slice.
plt.imshow(np.abs(make_image(Y_full[0, :, :, chosen_csmap, 0])), vmin=0., vmax=vmax, cmap='bone')
plt.colorbar()
plt.axis('off')
plt.title(f'Reference en coil {chosen_csmap}')
plt.subplot(122)
plt.imshow(np.abs(recon[0, ..., 0] * cs[0, :, :, chosen_csmap]), vmin=0., vmax=vmax, cmap='bone')
plt.title(f'Imagen aproximada con \nFourier Features en coil {chosen_csmap}')
plt.axis('off')
plt.colorbar()
# %%

recon.shape, cs.shape
# %%