### For more info, see the [Github repo](https://github.com/rlaugier/nifits) and the [documentation](https://rlaugier.github.io/nifits_doc.github.io/)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

# 1. Create example data
## 1.1 Set up a 3T kernel-nuller

Note: the ``NI_IOTAGS`` and ``NI_DSAMP`` extensions are not mandatory. In this example, they are trivial and transparent. Use the flags ``include_iotags`` and ``include_downsampling`` to activate or deactivate these extensions.

In [None]:
# Combiner matrix for a 3T kernel nuller (from github.com/rlaugier/kernuller).
# Its symbolic form is
#     Matrix([
#     [sqrt(3)/3,                sqrt(3)/3,                sqrt(3)/3],
#     [sqrt(3)/3,  sqrt(3)*exp(2*I*pi/3)/3, sqrt(3)*exp(-2*I*pi/3)/3],
#     [sqrt(3)/3, sqrt(3)*exp(-2*I*pi/3)/3,  sqrt(3)*exp(2*I*pi/3)/3]])
# i.e. unit phasors scaled by sqrt(3)/3 = 1/sqrt(3). It is evaluated directly
# with numpy instead of being parsed and evaluated through sympy.
_phasor = np.exp(2j * np.pi / 3)
combiner = np.array(((1, 1, 1),
                     (1, _phasor, np.conj(_phasor)),
                     (1, np.conj(_phasor), _phasor)),
                    dtype=np.complex128) / np.sqrt(3)
del _phasor

# Kernel matrix: a single kernel, the difference between outputs 1 and 2
# (the outputs flagged as dark below).
kmat = np.array([[0.0, 1.0, -1.0],])

# Optional extensions; both are trivial/transparent in this example.
include_iotags = True
include_downsampling = True
if include_iotags:
    # Per-output and per-input tags, each with a leading axis of length 1.
    outbright = np.array([True, False, False])[None,:]
    outphot = np.array([False, False, False])[None,:]
    outdark = np.array([False, True, True])[None,:]
    inpol = np.array(["s","s","s"])[None,:]
    outpol = np.array(["s","s","s"])[None,:]

# collector positions
baseline = 15  # in meter
# Collector diameter
telescope_diam = 3.0  # in meter

# rotation angles over observation
n_sample_time = 100
rotation_angle = np.linspace(0., 2*np.pi, n_sample_time)  # in rad

# Initial (x, y) positions of the 3 collectors, shape (2, 3).
collector_positions_init = np.array(((-baseline/2, baseline/2, 0),
                                     (0, 0, baseline/2)))

# One 2x2 rotation matrix per sample, shape (2, 2, n_sample_time).
rotation_matrix = np.array(((np.cos(rotation_angle), -np.sin(rotation_angle)),
                            (np.sin(rotation_angle), np.cos(rotation_angle))))

# Rotated positions, shape (n_sample_time, 2, 3).
collector_position = np.dot(np.swapaxes(rotation_matrix, -1, 0), collector_positions_init)

# observing wavelengths
n_wl_bin = 5
wl_bins = np.linspace(4.0e-6, 18.0e-6, n_wl_bin)  # in meter

# collector area
scaled_area = 1  # in meter^2

# Measurement covariance: the same diagonal matrix over all
# (kernel, wavelength) outputs, repeated for every frame.
cov = 1e1 * np.eye(kmat.shape[0] * wl_bins.shape[0])
covs = np.repeat(cov[None, :, :], n_sample_time, axis=0)
print(covs.shape)

In [None]:
from astropy.table import Table

# OI_ARRAY content: one row per collector, storing the initial (unrotated)
# x/y positions with z set to 0.
myarraytable = Table(names=["TEL_NAME", "STA_NAME", "STA_INDEX", "DIAMETER", "STAXYZ"],
                    dtype=[str, str, int, float, "(3,)double"],
                    units=[None, None, None, "m", "m"])
for i, (atelx, ately) in enumerate(collector_positions_init.T):
    myarraytable.add_row([f"Tel {i}", "", i, telescope_diam, np.array([atelx, ately, 0.])])

# 2. Initialize a nifits object
## 2.1 Showcasing a list of NIFITS extensions

In [None]:
import nifits.io.oifits as io

# Print the docstring of every NIFITS extension class, one section per class.
for ext_name in io.NIFITS_EXTENSIONS:
    ext_class = io.getclass(ext_name)
    print()
    print(ext_name, " :")
    print("---------------")
    print(ext_class.__doc__)
    print("==============================================================")

In [None]:
# Nominally placed at Paranal (default VLTI header), this array is modelled as
# three fixed collectors rotating about the line of sight, as a space
# interferometer would. OI_ARRAY is informative only: forward modelling uses
# the values stored in NI_MOD.
oi_array = io.OI_ARRAY(
    data_table=myarraytable,
    header=io.OI_ARRAY_DEFAULT_VLTI_HEADER,
)

In [None]:
# One copy of the combiner matrix per wavelength bin: shape (n_wl_bin, 3, 3).
ni_catm = io.NI_CATM(data_array=np.repeat(combiner[np.newaxis, :, :], len(wl_bins), axis=0))

In [None]:
# Wrap the kernel matrix `kmat` (one row per kernel) in an NI_KMAT extension.
mykmat = io.NI_KMAT(data_array=kmat)

In [None]:
from copy import copy
# Work on a copy of the library's default NI_FOV header, then add the
# collector diameter and its unit as "NIFITS ..." (HIERARCH-style) keywords.
my_FOV_header = copy(io.NI_FOV_DEFAULT_HEADER)
my_FOV_header["NIFITS FOV_TELDIAM"] = telescope_diam
my_FOV_header["NIFITS FOV_TELDIAM_UNIT"] = "m"
# Build the FoV extension for the wavelength bins; n=n_sample_time presumably
# gives one entry per frame.
ni_fov = io.NI_FOV.simple_from_header(header=my_FOV_header, lamb=wl_bins,
                                  n=n_sample_time)

In [None]:
# Single target entry.
# NOTE(review): raep0/decep0 are presumably in degrees (OIFITS OI_TARGET
# convention) — confirm against the nifits API.
oi_target = io.OI_TARGET.from_scratch()
oi_target.add_target(target='Test Target', 
                      raep0=14.3, 
                      decep0=-60.4)

In [None]:
# Covariance of the kernel outputs: the per-frame matrices `covs` built in the
# setup cell, in (photons/s)^2. The object is bound to both names.
mykcov = ni_kcov = io.NI_KCOV(data_array=covs, unit=(u.ph/u.s)**2)

In [None]:
from astropy.table import Table, Column
from astropy.time import Time

# NI_MOD: one row per frame describing the instrument state at that time.
n_telescopes = combiner.shape[1]
total_obs_time = 10*3600      # s
times_relative = np.linspace(0, total_obs_time, n_sample_time)
dateobs = Time("2035-06-23T00:00:00.000") + times_relative*u.s
mjds = dateobs.to_value("mjd")
seconds = (dateobs - dateobs[0]).to_value("s")
# Collector indices repeated for every frame, shape (n_sample_time, n_telescopes).
app_index = np.arange(n_telescopes)[None,:]*np.ones(n_sample_time)[:,None]
target_ids = 0 * np.ones(n_sample_time)
# Per-frame integration time, taken as the local spacing between frames (s).
int_times = np.gradient(seconds)
# Unit modulation phasors, shape (n_sample_time, n_wl_bin, n_telescopes).
mod_phas = np.ones((n_sample_time, n_wl_bin, n_telescopes), dtype=complex)
# Collector (x, y) positions, shape (n_sample_time, n_telescopes, 2).
appxy = collector_position.transpose((0,2,1))
# Geometric collecting area of each collector (m^2).
arrcol = np.ones((n_sample_time, n_telescopes)) * np.pi*telescope_diam**2 / 4
fov_index = np.ones(n_sample_time)

app_index         = Column(data=app_index, name="APP_INDEX",
                   unit=None, dtype=int)
target_id         = Column(data=target_ids, name="TARGET_ID",
                   unit=None, dtype=int)
# TIME is the elapsed time since the first frame, in seconds.
times_relative    = Column(data=seconds, name="TIME",
                   unit="s", dtype=float)
mjds              = Column(data=mjds, name="MJD",
                   unit="day", dtype=float)
# INT_TIME holds the per-frame integration times, not the elapsed time.
int_times         = Column(data=int_times, name="INT_TIME",
                   unit="s", dtype=float)
mod_phas          = Column(data=mod_phas, name="MOD_PHAS",
                   unit=None, dtype=complex)
appxy             = Column(data=appxy, name="APPXY",
                   unit="m", dtype=float)
arrcol            = Column(data=arrcol, name="ARRCOL",
                   unit="m^2", dtype=float)
fov_index         = Column(data=fov_index, name="FOV_INDEX",
                   unit=None, dtype=int)
mymod_table = Table()
mymod_table.add_columns((app_index, target_id, times_relative, mjds,
                        int_times, mod_phas, appxy, arrcol, fov_index))
mynimod = io.NI_MOD(mymod_table)

## 2.2 Creating the NIFITS parent object

In [None]:

from astropy.io import fits

# Wavelength table: bin centres (EFF_WAVE) and widths (EFF_BAND), with widths
# taken as the local spacing between bins.
wl_data = np.hstack((wl_bins[:,None], np.gradient(wl_bins)[:,None]))
wl_table = Table(data=wl_data, names=("EFF_WAVE", "EFF_BAND"), dtype=(float, float))
del wl_data
oi_wavelength = io.OI_WAVELENGTH(data_table=wl_table,)

# Optional downsampling, kept transparent here: the oversampled wavelength
# table equals the science one and the downsampling matrix is the identity.
if include_downsampling:
    ni_oswavelength = io.NI_OSWAVELENGTH(data_table=wl_table,)
    ni_dsamp = io.NI_DSAMP(data_array=np.eye(len(wl_table)))
else:
    ni_oswavelength = None
    ni_dsamp = None

# Optional input/output tags, built from the boolean and polarization arrays
# defined in the setup cell.
if include_iotags:
    ni_iotags = io.NI_IOTAGS.from_arrays(outbright=outbright, outdark=outdark, outphot=outphot,
                                         inpola=inpol, outpola=outpol)
else:
    ni_iotags = None

myheader = fits.Header()
mynifit = io.nifits(header=myheader,
                        oi_array=oi_array,
                        ni_catm=ni_catm,
                        ni_fov=ni_fov,
                        oi_target=oi_target,
                        oi_wavelength=oi_wavelength,
                        ni_mod=mynimod,
                        ni_kmat=mykmat,
                        ni_kcov=mykcov,
                        ni_dsamp=ni_dsamp,
                        ni_oswavelength=ni_oswavelength,
                        ni_iotags=ni_iotags)

mynifit.__dict__.keys()

## 2.3 Saving and opening

In [None]:
# Create the output directory used below. exist_ok makes the cell safe to
# re-run, and the code works outside IPython's `mkdir` shell alias.
import os
os.makedirs("log", exist_ok=True)

In [None]:
# Serialize the NIFITS object to log/testfits.nifits (static_only=False,
# replacing any existing file) and display the primary header.
myhdu = mynifit.to_nifits(
    filename="log/testfits.nifits",
    static_only=False,
    writefile=True,
    overwrite=True,
)
myhdu[0].header

# Test header check


import warnings
def check_item(func):
    """
    Wrap a FITS header ``__getitem__`` so that deprecated keywords still resolve.

    The wrapped lookup first tries the requested keyword unchanged. If that
    raises ``KeyError`` and the keyword is a string of several space-separated
    words (a ``HIERARCH``-style keyword such as ``"NIFITS FOV_TELDIAM"``), it
    retries with the last word only (``"FOV_TELDIAM"``). A successful fallback
    emits a warning that the file uses the deprecated form.

    Parameters
    ----------
    func : callable
        The original ``__getitem__``, called as ``func(header, key)``.

    Returns
    -------
    callable
        The wrapped ``__getitem__``.

    Raises
    ------
    KeyError
        (From the wrapped lookup.) Raised if neither the full nor the short
        keyword exists. The original ``KeyError`` is re-raised unchanged when
        the key has no short form (non-string or single-word keys).
    """
    import functools

    @functools.wraps(func)
    def inner(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except KeyError as err:
            key = args[1] if len(args) > 1 else None
            # Only multi-word string keywords have a short (deprecated) form.
            # Anything else (int, tuple, single word) keeps its original error.
            words = key.split() if isinstance(key, str) else []
            if len(words) < 2:
                raise
            mykw = words[-1]
            try:
                item = func(args[0], mykw, **kwargs)
            except KeyError:
                raise KeyError(f"Neither {key} nor {mykw} found.") from err
        warnings.warn(f"keyword deprecated in the file. Expected `{key}` (`HIERARCH` keyword)\n Returning `{mykw}`\n This file will become obsolete.",
                      stacklevel=2)
        return item
    return inner
# Globally replace astropy's Header.__getitem__ with the tolerant wrapper, so
# every later header[...] lookup in this session falls back to the short
# (deprecated) keyword with a warning.
# NOTE(review): re-running this cell wraps the already-wrapped method again.
fits.Header.__getitem__ = check_item(fits.Header.__getitem__)



In [None]:
# Re-read the file written above into a new nifits object and show its header.
with fits.open("log/testfits.nifits") as anhdu:
    newfits = io.nifits.from_nifits(anhdu)
newfits.header

|  Column      |  format                   |  unit            | Empty |
|:------------:|:------------------------- |:---------------- | ---- | 
|  `APP_INDEX` |  $n_a \times$ int         |  NA              |     |
|  `TARGET_ID` |  int                      |  NA              |     |
|  `TIME`      |  float                    |  s               |     |
|  `MJD`       |  float                    |  day             |     |
|  `INT_TIME`  |  float                    |  s               |     |
|  `MOD_PHAS`  |  $n_{\lambda}, n_a $ cpx  |                  |     |
|  `APPXY`     |  $n_a, 2 $ float          | m               |      |
|  `ARRCOL`    |  $n_a $ float             |  $\mathrm{m}^2$  |     |
|  `FOV_INDEX` |  $n_a $ int               |  NA              |     |

# 3. Testing the back end
## 3.1 Loading a file into the backend

In [None]:
import nifits.backend as be

In [None]:
# Two backends:
# - `mybe` is built directly from the file re-read above.
# - `abe` is assembled step by step from the in-memory object; only the
#   instrument definition is added (the observation-data call is commented out).
mybe = be.NI_Backend(newfits)
abe = be.NI_Backend()
abe.add_instrument_definition(mynifit)
# abe.add_observation_data(mynifit)
abe.create_fov_function_all()
print(abe.nifits.header)

## 3.2 Freehanding the backend

In [None]:
# Square sampling grid spanning +/- halfrange mas, with coordinates in radians.
halfrange = 1000  # mas
halfrange_rad = halfrange*u.mas.to(u.rad)
xs = np.linspace(-halfrange_rad, halfrange_rad, 100)
xx, yy = np.meshgrid(xs, xs)
# Plot extent in mas: [x_min, x_max, y_min, y_max].
map_extent = [-halfrange, halfrange] * 2

## 3.2.1 The field of view function

In [None]:
# FoV (injection) function on the grid, plotted for the first and last
# wavelength bins.
# NOTE(review): the leading index 0 is assumed to select the first frame.
map_fov = abe.nifits.ni_fov.xy2phasor(xx.flatten(), yy.flatten())
for wl_index in (0, -1):
    fov_amplitude = np.abs(map_fov[0, wl_index, :].reshape(xx.shape))
    plt.figure(dpi=100)
    # origin="lower" puts row 0 of the meshgrid (y = -halfrange) at the bottom.
    # contour always places Z[0, 0] at (x0, y0), so with the default
    # origin="upper" the image would be flipped relative to the contour.
    plt.imshow(fov_amplitude, extent=map_extent, origin="lower")
    plt.colorbar()
    plt.contour(fov_amplitude, levels=(0.5,), extent=map_extent)
    plt.title(f"Wavelength bin {wl_index}")
    plt.show()

In [None]:
# Diffraction scale lambda/D of one collector for each wavelength bin, in mas.
print("lambda/D : ", (wl_bins/telescope_diam)*u.rad.to(u.mas))

## 3.3 The forward-modeled outputs

`backend.get_all_outs` takes in arrays of coordinates and returns either the raw outputs (`kernels=False`) or the kernel outputs (`kernels=True`).

### 3.3.1 With random samples

In [None]:
# 10000 random test positions, uniform in a +/-500 mas square (unseeded, so
# they change from run to run), then converted to radians.
xys_mas = np.random.uniform(low=-500, high=+500, size=(2,10000)) 
xys = xys_mas * u.mas.to(u.rad)
# xysm = xys[:,:]
# Raw (non-kernel) outputs for all points, timed with IPython's %time.
%time z = abe.get_all_outs(xys[0,:], xys[1,:], kernels=False)

In [None]:
# Time series of one raw output for sample point 1000.
# NOTE(review): axis order assumed (frame, wavelength, output, point), inferred
# from the indexing in the following cells — see the get_all_outs docstring.
plt.figure()
plt.plot(z[:,1,1,1000])
plt.show()

### Note the shape of the output: 

In [None]:
# Print the get_all_outs docstring, which describes the shape of its output.
print(be.NI_Backend.get_all_outs.__doc__)

In [None]:
# Evaluate the kernel outputs and the FoV function on the same random points.
kz = abe.get_all_outs(xys[0,:], xys[1,:], kernels=True)
x_inj = abe.nifits.ni_fov.xy2phasor(xys[0,:], xys[1,:])

# Three scatter maps (first frame, last wavelength bin):
# raw output 1, kernel 0, and the FoV amplitude.
_maps = (
    (z[0,-1,1,:], "viridis", "The response of output 1 [$m^2$] collecting power."),
    (kz[0,-1,0,:], "coolwarm", "The differential response map [$m^2$] collecting power."),
    (np.abs(x_inj[0,-1,:]), "viridis", "The FoV function"),
)
for _values, _cmap, _title in _maps:
    plt.figure()
    plt.scatter(xys_mas[0,:], xys_mas[1,:], c=_values, cmap=_cmap, s=6)
    plt.colorbar()
    plt.gca().set_aspect("equal")
    plt.title(_title)
    plt.xlabel("Relative position [mas]")
    plt.show()
del _maps, _values, _cmap, _title

print(np.max(z))

In [None]:
# Scratch check: outer product of two random vectors, shape (10, 5).
a = np.random.normal(size=10)
b = np.random.normal(size=5)
np.outer(a, b)

In [None]:
# Display the shape of the aperture positions stored in the backend's NI_MOD.
abe.nifits.ni_mod.appxy.shape

# 4. Additional convenience tools

## 4.1 Handling point collections with `PointCollection`

These can be used as a smooth and unified interface for simulating parametrically positioned point-like or extended objects.

Further down the line, these can be "summed" together with a `+` operator to create arbitrary sampled maps.

In [None]:
# Print the PointCollection docstring.
print(be.PointCollection.__doc__)

You can do it for one point and the computation is relatively fast.

In [None]:
# A single point at (10, 20) mas. Its radian coordinates are passed to
# get_all_outs, timed with IPython's %timeit.
one_point = be.PointCollection(np.array((10.,)), np.array((20.,)), unit=u.mas)
%timeit z = abe.get_all_outs(*one_point.coords_rad, kernels=True)

### 4.1.1 A boring cartesian grid using `PointCollection.from_centered_square_grid`

In [None]:
%%time
# Regular square grid of points (arguments 600., 100 — presumably width in mas
# and points per side; see the PointCollection docstring).
acollec = be.PointCollection.from_centered_square_grid(600., 100, md=np)
# FoV amplitude on the grid: first frame, last wavelength bin.
x_inj = abe.nifits.ni_fov.xy2phasor(*acollec.coords_rad)
plt.figure()
plt.scatter(*acollec.coords, c=np.abs(x_inj[0,-1,:]), cmap="viridis", s=5)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title("The Fov function")
plt.xlabel("Relative position [mas]")
plt.show()

# Kernel outputs on the same grid; the map shows kernel 0.
z = abe.get_all_outs(*acollec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*acollec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=5)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title(f"The response map [$m^2$] collecting power.")
plt.xlabel("Relative position [mas]")
plt.show()

print(np.max(z))


### Or reshaping to a 2D array with `.orig_shape` and `.extent`

In [None]:
# The same kernel map shown as an image, with the flat point list reshaped
# back to the grid's original 2D shape.
# NOTE(review): imshow defaults to origin="upper"; confirm acollec.extent
# accounts for the grid's row order, otherwise the y axis appears flipped.
plt.figure()
plt.imshow(z[0,-1,0,:].reshape(acollec.orig_shape),
           cmap="coolwarm",
           extent=acollec.extent)
plt.colorbar()
plt.xlabel("")
plt.show()

### 4.1.2 A point-sampled disk using `PointCollection.from_uniform_disk`
N.B. This merges well with `scipy.interpolate.griddata` 

In [None]:
%%time
# Point-sampled disk (arguments 600., 600 — presumably size in mas and number
# of points; see the PointCollection docstring).
acollec = be.PointCollection.from_uniform_disk(600., 600)

# Kernel outputs on the disk; the map shows first frame, last bin, kernel 0.
z = abe.get_all_outs(*acollec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*acollec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=80)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title(f"The response map [$m^2$] collecting power.")
plt.xlabel("Relative position [mas]")
plt.show()

# FoV amplitude on the same points.
x_inj = abe.nifits.ni_fov.xy2phasor(*acollec.coords_rad)
plt.figure()
plt.scatter(*acollec.coords, c=np.abs(x_inj[0,-1,:]), cmap="viridis", s=80)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title("The Fov function")
plt.xlabel("Relative position [mas]")
plt.show()
print(np.max(z))


### Can be resampled with `griddata` from scipy.

In [None]:
from scipy.interpolate import griddata
# Nearest-neighbour resampling of the disk-sampled kernel map onto a regular
# square grid, for display as an image.
# NOTE(review): imshow defaults to origin="upper"; confirm agrid.extent matches
# the row order of agrid.coords_shaped.
agrid = be.PointCollection.from_centered_square_grid(800., 512, md=np)
interped = griddata(acollec.coords, z[0,-1,0,:], agrid.coords_shaped, method="nearest")
plt.figure()
plt.imshow(interped, cmap="coolwarm", extent=agrid.extent)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title(f"The response map [$m^2$] collecting power.")
plt.xlabel("Relative position [mas]")
plt.show()

### Of course, this contains the whole time series:

In [None]:
# Kernel map (last wavelength bin, kernel 0) for every 5th frame of the
# series, one small panel per frame.
print("One every 5 frames of the series")
fig, axes = plt.subplots(2,10, sharex=True, sharey=True, figsize=(10,2.2), dpi=150)
# zip() stops at the shorter sequence, so a different number of frames cannot
# index past the 20 available panels.
for ax, my_t_index in zip(axes.flat, range(len(newfits.ni_fov))[::5]):
    plt.sca(ax)
    plt.scatter(*acollec.coords, c=z[my_t_index,-1,0,:], cmap="coolwarm", s=2)
    plt.title(my_t_index, fontsize=7)
    plt.gca().set_aspect("equal")
# tight_layout sets all subplot margins itself, so no prior subplots_adjust.
plt.tight_layout()
plt.show()

### 4.1.3 A grid over a given region using `PointCollection.from_grid`
This can help to work with **regions of interest** (ROI).

In [None]:
%%time
# Grid restricted to a region of interest: x from 50 to 600, y from -500 to
# 500 (100 samples each, presumably in mas like the plot axes).
offset_collec = be.PointCollection.from_grid(np.linspace(50, 600, 100), np.linspace(-500,500, 100))
z = abe.get_all_outs(*offset_collec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*offset_collec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=5)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title(f"The response map [$m^2$] collecting power.")
plt.xlabel("Relative position [mas]")
plt.xlim(-600, 600)
plt.ylim(-600, 600)
plt.show()

### 4.1.4 More advanced tools: transformations and addition

In [None]:
from scipy.spatial.transform import Rotation
# A small disk transformed by a 3x3 rotation matrix built from Euler angles
# (60, 0, 30) degrees in the "xyz" sequence; the matrix is bound to two names.
acollec = be.PointCollection.from_uniform_disk(200., 400)
mymat = mymatrix = Rotation.from_euler("xyz", [60,0,30], degrees=True).as_matrix()
acollec.transform(mymat,)
bcollec = be.PointCollection.from_uniform_disk(600., 400)
offset_collec = be.PointCollection.from_grid(np.linspace(-50, -600, 50), np.linspace(-500,500, 50))
pcollec = be.PointCollection.from_uniform_disk(100., 300, offset=np.array((250., -50.)))
# NOTE(review): `md` presumably selects the array module (numpy here); unclear
# why only acollec needs it set explicitly.
acollec.md = np
# Collections are combined with "+" into a single point set.
combined_collec = acollec + bcollec + offset_collec + pcollec
z = abe.get_all_outs(*combined_collec.coords_rad, kernels=True)

plt.figure()
plt.scatter(*combined_collec.coords, c=z[0,-1,0,:], cmap="coolwarm", s=2)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title(f"The response map [$m^2$] collecting power.")
plt.xlabel("Relative position [mas]")
plt.xlim(-600, 600)
plt.ylim(-600, 600)
plt.show()

In [None]:
# Nearest-neighbour resampling of the combined (irregular) collection onto a
# regular square grid, for display as an image.
agrid = be.PointCollection.from_centered_square_grid(600., 512, md=np)
interped = griddata(combined_collec.coords, z[0,-1,0,:], agrid.coords_shaped, method="nearest")
plt.figure(dpi=200)
# NOTE(review): imshow defaults to origin="upper"; confirm agrid.extent
# matches the row order of agrid.coords_shaped.
plt.imshow(interped, cmap="coolwarm", extent=agrid.extent)
plt.colorbar()
plt.gca().set_aspect("equal")
plt.title(f"The response map [$m^2$] collecting power.")
plt.xlabel("Relative position [mas]")
plt.show()

## 4.2 Time-varying samples
A way to compute the transmission map for arrays of points moving between frames.

You can use `MovingCollection`, which takes in a list of `PointCollection` objects and offers a similarly friendly interface, tailored to the `Backend.get_moving_outs()` method. Here is how it's done:


from dataclasses import dataclass, fields
from nifits.backend import PointCollection
from einops import rearrange




In [None]:
# One small disk per frame. Its centre slides along x from -300 to +300 at
# y = -50, and the disks are wrapped in a MovingCollection for get_moving_outs.
collects = [
    be.PointCollection.from_uniform_disk(100., 300, offset=np.array((x_offset, -50.)))
    for x_offset in np.linspace(-300, 300, len(abe.nifits.ni_mod))
]
combined_collec = be.MovingCollection(collects)
z = abe.get_moving_outs(*combined_collec.coords_rad, kernels=True)


In [None]:
# Kernel map (last wavelength bin, kernel 0) of the moving disk for every
# 10th frame, one small panel per frame.
fig, axes = plt.subplots(5,2, sharex=True, sharey=True, figsize=(4,10))
# zip() stops at the shorter sequence, so a different number of frames cannot
# index past the 10 available panels.
for ax, my_t_index in zip(axes.flat, range(len(newfits.ni_fov))[::10]):
    plt.sca(ax)
    plt.scatter(*combined_collec.coords[:,my_t_index,:], c=z[my_t_index,-1,0,:], cmap="coolwarm", s=2)
    plt.title(my_t_index, fontsize=7)
    plt.gca().set_aspect("equal")
# Axes are shared, so these limits apply to every panel.
plt.xlim(-600, 600)
plt.ylim(-600, 600)
plt.tight_layout()
plt.show()

# 5. Your turn!

This space is dedicated to your own experimentations.