## Notebook to explore interpolation methods and accuracy

In [None]:
from utils import *  # kept first so the explicit imports below win over same-named utils exports
import os
from os.path import join, exists

from tqdm import tqdm

# Directory that holds the raw .mat files: a relative path when running
# locally (LOCAL comes from utils), an absolute path otherwise.
if LOCAL:
    RAW_DIR = 'raw_ds'
else:
    RAW_DIR = '/nfsd/automatica/grandinmat/raw_ds'


In [None]:
# load the data
print("Loading data...")

# list all the .mat files directly inside RAW_DIR (sorted for a deterministic order)
files = sorted(f for f in os.listdir(RAW_DIR) if f.endswith('.mat'))
print(f'Found {len(files)} files.')
vals = {n: [] for n in DS_NAMES}  # variable name -> list of per-file arrays, one entry per accepted file
files_iter = tqdm(files, desc="Loading files", unit="file") if LOCAL else files
for f in files_iter:
    # All variables of a file are read and validated into `file_vals` first and
    # only appended to `vals` (in the `else` branch) once the whole file passed.
    # Appending variable by variable would leave the per-variable lists with
    # different lengths whenever a later variable of the same file failed,
    # silently misaligning samples across variables.
    try:
        d = loadmat(join(RAW_DIR, f))
        for n in DS_NAMES:
            if n not in d and n not in [BR, BZ]:
                raise ValueError(f'Key {n} not found in {f}')
        t = d['t'].flatten()  # time
        nt = t.shape[0]  # number of time points in this file
        file_vals = {}
        for n in DS_NAMES:
            if n in [BR, BZ]: continue  # skip BR and BZ for now, they will be calculated later
            # reshape raises ValueError if the element count does not fit (*DS_SIZES[n], nt).
            # NOTE(review): C-order reshape; if a variable is ever stored flattened
            # (MATLAB is column-major) this may need order='F' -- confirm.
            v = d[n].reshape((*DS_SIZES[n], nt))
            if np.isnan(v).any():
                raise ValueError(f'Variable {n} contains NaN values in {f}')
            if not np.isfinite(v).all():
                raise ValueError(f'Variable {n} contains infinite values in {f}')
            file_vals[n] = v
    except Exception as e:
        # best effort: report and skip files that fail to load or validate
        print(f'Error loading {f}: {e}')
    else:
        for name, arr in file_vals.items():
            vals[name].append(arr)

print(f'Loaded {len(vals[FX])} files.')
assert len(vals[FX]) > 0, f'No samples: {len(vals[FX])}'

# convert the per-file lists to single numpy arrays with the sample axis first
print('---------------------------------------------------')
for n in DS_NAMES:
    if n in [BR, BZ]: continue  # skip BR and BZ for now, they will be calculated later
    # each per-file array is (*DS_SIZES[n], nt): join the files along the LAST
    # (time) axis, cast to DTYPE, then move that axis to the front -> (N, *DS_SIZES[n])
    vals[n] = np.moveaxis(np.concatenate(vals[n], axis=-1).astype(DTYPE), -1, 0)
print('---------------------------------------------------')

# every loaded variable must contribute the same number of samples, otherwise
# the rows of X and Y assembled below would not correspond to each other
n_samples = {n: vals[n].shape[0] for n in DS_NAMES if n not in [BR, BZ]}
assert len(set(n_samples.values())) == 1, f'Sample count mismatch across variables: {n_samples}'

# calculate BR and BZ from Fx
print('Calculating BR and BZ...')  # TODO: this can be done using 2d convolutions
vals[BR], vals[BZ] = meqBrBz(vals[FX])

for n in DS_NAMES: print(f'{n} -> shape: {vals[n].shape}, dtype: {vals[n].dtype}')  # print the shape and dtype of the variable

# inputs must be (N, features) so they can be stacked column-wise into X
for n in INPUT_NAMES: assert vals[n].ndim == 2, f'Variable {n} is not 2D: {vals[n].ndim}'

# assign values to variables: X places the 2D (N, features) inputs side by side
# along axis 1 (the feature axis), so row i of X holds every input of sample i
X = np.concatenate([vals[n] for n in INPUT_NAMES], axis=1)
print(f'-----------------------------------------\n X -> shape: {X.shape}, dtype: {X.dtype}')

# Make IY the same shape as FX. IY covers only the interior of the FX grid
# (one cell less on every side of both spatial axes), so pad it by one cell
# per side, replicating the nearest edge value; corners get the corner value.
iy_padded = np.pad(vals[IY], ((0, 0), (1, 1), (1, 1)), mode='edge')
assert iy_padded.shape == vals[FX].shape, f'Padded IY shape {iy_padded.shape} != FX shape {vals[FX].shape}'
vals[IY] = iy_padded.astype(vals[FX].dtype, copy=False)  # keep the same dtype as FX

# collect the arrays that make up Y, keyed by variable name
# (insertion order fixes the order of the shape printout below)
Y = {n: vals[n] for n in (FX, IY, BR, BZ, RQ, ZQ)}

assert Y[FX].shape[0] > 0, f'No samples: {Y[FX].shape}'
N_OR = Y[FX].shape[0]  # number of original samples
print(f'Loaded {N_OR} samples.')

assert X.dtype == DTYPE, f'X dtype mismatch: {X.dtype} != {DTYPE}'
assert all(y.dtype == DTYPE for y in Y.values()), f'Y dtype mismatch: {[y.dtype for y in Y.values()]} != {DTYPE}'

assert X.shape == (N_OR, NIN), f'X shape mismatch: {X.shape} != ({N_OR}, {NIN})'

# check the shapes
print(f'X -> {X.shape}')
print(f'Y -> {[y.shape for y in Y.values()]}')

# print sizes in MB (bytes / 1024**2); generator avoids a throwaway list
print(f'X size: {X.nbytes / 1024**2:.2f} MB, Y size: {sum(y.nbytes for y in Y.values()) / 1024**2:.2f} MB')