## Notebook to explore interpolation methods and accuracy

In [1]:
from utils import *
from tqdm import tqdm
from os.path import join, exists

RAW_DIR = 'raw_ds' if LOCAL else '/nfsd/automatica/grandinmat/raw_ds' # where the raw data is stored


Running JOBID: local, on mps, GPU_MEM: 17.18 GB


In [5]:
# load the data
print("Loading data...")

# list all the files inside RAW_DIR/ds
files = sorted([f for f in os.listdir(RAW_DIR) if f.endswith('.mat')])
print(f'Found {len(files)} files.')
vals = {n:[] for n in DS_NAMES}  # dictionary to hold the numpy arrays
files_iter = tqdm(files, desc="Loading files", unit="file") if LOCAL else files
for f in files_iter:
    try:
        d = loadmat(join(RAW_DIR, f))
        for n in DS_NAMES: assert n in list(d.keys()) or n in [BR, BZ], f'Key {n} not found in {f}' 
        t = d['t'].flatten()  # time
        nt = t.shape[0]  # number of time points
        for n in DS_NAMES:
            if n in [BR, BZ]: continue  # skip BR and BZ for now, they will be calculated later
            v = d[n] # get the variable
            v = v.reshape((*DS_SIZES[n], nt)) # try to reshape it to the expected shape
            assert v.shape == (*DS_SIZES[n], nt), f'Variable {n} shape mismatch: {v.shape} != {((*DS_SIZES[n], nt))}'
            assert not np.isnan(v).any(), f'Variable {n} contains NaN values in {f}'
            assert np.isfinite(v).all(), f'Variable {n} contains infinite values in {f}'
            vals[n].append(v)  # append to the list
    except Exception as e:
        print(f'Error loading {f}: {e}')
        continue

print(f'Loaded {len(vals[FX])} files.')
assert len(vals[FX]) > 0, f'No samples: {len(vals[FX])}'

# convert lists to numpy arrays
print('---------------------------------------------------')
for n in DS_NAMES:
    if n in [BR, BZ]: continue  # skip BR and BZ for now, they will be calculated later
    vals[n] = np.concatenate(vals[n], axis=-1).astype(DTYPE)  # concatenate along the first dimension and convert to DTYPE
    sh, ndim = vals[n].shape, vals[n].ndim
    vals[n] = vals[n].transpose(ndim-1, *range(ndim-1))  # move the last dimension to the first
    # print(f'{n} -> shape: {vals[n].shape}, dtype: {vals[n].dtype}')  # print the shape and dtype of the variable
print('---------------------------------------------------')


# calculate BR and BZ from Fx
print('Calculating BR and BZ...') # TODO: this can be done using 2d convolutions
vals[BR], vals[BZ] = meqBrBz(vals[FX])

for n in DS_NAMES: print(f'{n} -> shape: {vals[n].shape}, dtype: {vals[n].dtype}')  # print the shape and dtype of the variable
for n in INPUT_NAMES: assert vals[n].ndim == 2, f'Variable {n} is not 2D: {vals[n].ndim}'

Loading data...
Found 500 files.


Loading files: 100%|██████████| 500/500 [00:02<00:00, 182.19file/s]


Loaded 500 files.
---------------------------------------------------
---------------------------------------------------
Calculating BR and BZ...
Bm0 -> shape: (50394, 38), dtype: float32
Ff0 -> shape: (50394, 38), dtype: float32
Ft0 -> shape: (50394, 1), dtype: float32
Ia0 -> shape: (50394, 19), dtype: float32
Ip0 -> shape: (50394, 1), dtype: float32
Iu0 -> shape: (50394, 38), dtype: float32
rBt0 -> shape: (50394, 1), dtype: float32
Fx -> shape: (50394, 65, 28), dtype: float32
Iy -> shape: (50394, 63, 26), dtype: float32
Br -> shape: (50394, 65, 28), dtype: float32
Bz -> shape: (50394, 65, 28), dtype: float32
rq -> shape: (50394, 129), dtype: float32
zq -> shape: (50394, 129), dtype: float32


In [None]:
rand_idx = np.random.randint(0, vals[FX].shape[0], 1)  # random idx
# rand_idx = 18796
rand_idx = 42278

npts = 1213
lin = 'linear'
quint = 'quintic'
zidx = 29
r0, r1 = 0.81, 0.86  # radial range
rg0, rg1 = RRD[0, 0], RRD[0, -1]  # radial grid range
zg0, zg1 = ZZD[0, 0], ZZD[-1, 0]  # vertical grid range

pts = np.zeros((npts, 2))
r = np.linspace(r0, r1, npts)  # radial points
# z = np.linspace(-0.2, 0.2, npts)  # vertical points
# z = np.random.uniform(-0.2, 0.2) * np.ones(npts)
z = np.ones(npts) * ZZD[zidx, 0]
pts[:, 0] = r
pts[:, 1] = z

Fx = vals[FX][rand_idx,:,:].squeeze()
Br = vals[BR][rand_idx,:,:].squeeze()
Bz = vals[BZ][rand_idx,:,:].squeeze()

# get the points on that z 
rg = RRD[0,:]  # radial grid
Fxg = Fx[zidx,:]  # get the Fx values at the specified z index
Brg = Br[zidx,:]  # get the Br values at the specified z index
Bzg = Bz[zidx,:]  # get the Bz values at the specified z index
# keep only the points in the radial range
cond = (rg >= r0) & (rg <= r1)
rg = rg[cond]
Fxg = Fxg[cond]
Brg = Brg[cond]
Bzg = Bzg[cond]

# interpolate Fx, Br, Bz at the points pts
# quintic interpolation
Fxq = interp_pts(Fx, pts, gr=RRD[0,:], gz=ZZD[:,0], method=quint)
Brq = interp_pts(Br, pts, gr=RRD[0,:], gz=ZZD[:,0], method=quint)
Bzq = interp_pts(Bz, pts, gr=RRD[0,:], gz=ZZD[:,0], method=quint)
# liner interpolation
Fxl = interp_pts(Fx, pts, gr=RRD[0,:], gz=ZZD[:,0], method=lin)
Brl = interp_pts(Br, pts, gr=RRD[0,:], gz=ZZD[:,0], method=lin)
Bzl = interp_pts(Bz, pts, gr=RRD[0,:], gz=ZZD[:,0], method=lin)

# test calculating Br, Bz with a finer grid
K = 13
rf = np.linspace(rg0, rg1, 28*K)
zf = np.linspace(zg0, zg1, 65*K)
rrf, zzf = np.meshgrid(rf, zf)  # radial and vertical grid for the finer grid
rzf = np.stack((rrf, zzf), axis=-1)  # stack to get the points
Fxf = interp_pts(Fx, rzf.reshape(-1,2), gr=RRD[0,:], gz=ZZD[:,0], method=quint)
Brf, Bzf = meqBrBz(Fxf, rr=rrf, zz=zzf)  # calculate Br, Bz from the flux map

# plt.figure(figsize=(12, 12))
plt.figure(figsize=(12, 4))
# plt.subplot(3, 1, 1)
# plt.scatter(rg, Fxg, label='Fxg')
# plt.plot(r, Fxq, label='Fxq', linestyle=':')
# plt.plot(r, Fxl, label='Fxl', linestyle=':')
# plt.legend()
# plt.subplot(3, 1, 2)
plt.scatter(rg, Brg, label='Brg')
plt.plot(r, Brq, label='Brq', linestyle=':')
plt.plot(r, Brl, label='Brl', linestyle=':')
plt.plot(rf, Brf[zidx,:], label='Brf', linestyle='--')
plt.legend()
# plt.subplot(3, 1, 3)
# plt.scatter(rg, Bzg, label='Bzg')
# plt.plot(r, Bzq, label='Bzq', linestyle=':')
# plt.plot(r, Bzl, label='Bzl', linestyle=':')
# plt.legend()
# plt.suptitle(f'{rand_idx} Interpolation of Fx, Br, Bz at z={ZZD[zidx, 0]:.2f} m')
# plt.tight_layout()
plt.show()




AssertionError: pts.ndim = 3, pts.shape = (845, 364, 2)