In [None]:
# Re-import edited modules automatically before each cell runs; interactive
# (ipympl "widget") figures.  Alternative backends are kept commented out.
%load_ext autoreload
%autoreload 2
%matplotlib widget
#%matplotlib notebook
#%matplotlib qt
    
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
from boutdata.collect import collect
from tqdm import tqdm
import xbout

def getgrid(ds, basedir="~/soft/zoidberg-w7x", **kwargs):
    """Open the zoidberg grid file referenced by a BOUT++ dataset.

    The grid file name is read from ``ds.options["grid"]`` and resolved
    relative to ``basedir``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset providing ``ds.options["grid"]`` (e.g. opened with xBOUT).
    basedir : str, optional
        Directory containing the grid files (default ``~/soft/zoidberg-w7x``,
        the location previously hard-coded here).
    **kwargs
        Forwarded to ``xr.open_dataset``, e.g. ``drop_variables=["offset_3x3"]``.

    Returns
    -------
    xarray.Dataset
        The grid, still open so its (lazily loaded) data can be read later.
        Call ``.close()`` on it when done.
    """
    gridfile = ds.options["grid"]
    # The dataset used to be returned from inside a ``with`` block, so it was
    # already closed when the caller got it.  Every later data access then
    # relied on xarray silently re-opening the file.  Hand back an open handle.
    return xr.open_dataset(f"{basedir}/{gridfile}", **kwargs)

def minmax(x):
    """Print the minimum and the maximum of ``x`` on a single line.

    Uses ``np.min``/``np.max``, so a NaN anywhere in ``x`` is printed as
    the result.
    """
    lo = np.min(x)
    hi = np.max(x)
    print(lo, hi)

# Name prefix of the residuum variables in the debug output.
pre = "residuum_"

In [None]:
# Run directory.  Earlier candidates are kept for reference; only the last
# assignment ever took effect:
#   /ptmp/dave/hermes-2/7-emc3.c156/
#   /ptmp/dave/hermes-3/examples/tokamak/diffusion-flow-evolveT.c0/
#   /u/dave/soft/hermes-3/fci-auto/examples/tokamak/diffusion-flow-evolveT/
#   /u/dave/soft/hermes-3/fci-auto/examples/stellarator/diffusion-flow-evolveT/
path = "/ptmp/dave/hermes-3/examples/stellarator/diffusion-flow-evolveT.c23/"

# Open the BOUT.debug.* files.  ``path`` must end in "/" because the
# settings file name is appended to it without a separator.
ds = xbout.open_boutdataset(
    datapath=f"{path}/BOUT.debug.*.nc",
    geometry='fci',
    gridfilepath='/u/dave/soft/hermes-3/auto-fci/W7X.nc',
    inputfilepath=path + "BOUT.settings",
    info=False,
    grid_kw=dict(drop_variables=["offset_3x3"]),
    # use_modules=False,
)

# Divide Ne by its "normalised_by" attribute, then set that attribute to 1
# to record that the division has been applied.
if "Ne" in ds and "normalised_by" in ds["Ne"].attrs:
    ds["Ne"] /= ds["Ne"].attrs["normalised_by"]
    ds["Ne"].attrs["normalised_by"] = 1

def gettol(ds):
    """Return the solver tolerances ``(atol, rtol)`` stored in ``ds.options``.

    ``solver:atol`` and ``solver:rtol`` are looked up in that order.  A
    missing option (``KeyError``) falls back to ``atol = 1e-12`` or
    ``rtol = 1e-5`` respectively.
    """
    def option(name, fallback):
        try:
            return ds.options[f"solver:{name}"]
        except KeyError:
            return fallback

    atol = option("atol", 1e-12)
    rtol = option("rtol", 1e-5)
    return atol, rtol
# Residuum-variable prefix (re-set here so this cell is self-contained), and
# the dict that check_res fills with each variable's worst-point index.
pre, worst = "residuum_", {}
def get_res(ds, k):
    """Compare the residuum of variable ``k`` with the solver tolerance.

    Parameters
    ----------
    ds : xarray.Dataset
        Debug dataset containing both ``k`` and its residuum ``pre + k``.
    k : str
        Name of the evolved variable.

    Returns
    -------
    res
        The residuum ``ds[pre + k]``.
    err
        Per-point tolerance ``|ds[k]| * rtol + atol``, with ``(atol, rtol)``
        from ``gettol(ds)``.
    where
        Boolean mask, True where ``|res|`` is not within ``err``.  Points
        where either side is NaN are counted as violations.
    """
    var = ds[k]
    res = ds[pre + k]
    atol, rtol = gettol(ds)
    err = np.abs(var) * rtol + atol
    # ``abs(res) > err`` is False whenever a NaN is involved, so a diverged
    # (NaN) residuum would be reported as converged.  Negating ``<=`` flags it.
    where = ~(np.abs(res) <= err)
    return res, err, where


def check_res(ds):
    """Print a convergence report for every residuum stored in ``ds``.

    For each variable ``k`` with a ``pre + k`` residuum, the report covers:
    non-finite tolerances, non-finite residua, whether ``get_res`` flags any
    point, and the index of the largest ``|residuum|``.  That index is also
    stored in the module-level ``worst`` dict under ``k``.
    """
    keys = [k.removeprefix(pre) for k in ds if k.startswith(pre)]
    for k in keys:
        print(f"Residuum for {k}:")
        res, err, where = get_res(ds, k)
        sumwhere = np.sum(where.compute())
        bad_err = err.size - np.sum(np.isfinite(err))
        if bad_err:
            print(f"Not finite at {bad_err} of {err.size} points")
        # err only reflects the variable itself; NaN/inf in the residuum has
        # to be reported separately.
        bad_res = res.size - np.sum(np.isfinite(res))
        if bad_res:
            print(f"Residuum not finite at {bad_res} of {res.size} points")
        if sumwhere:
            print(f"Above limit at {sumwhere} not fullfilled")
        else:
            print("Limit fullfilled everywhere")
        w = np.unravel_index(np.argmax(np.abs(res.values)), res.shape)
        print(f"Maximum residuum at {w}")
        worst[k] = w
        # NOTE(review): this prints the largest *tolerance* (err), not the
        # largest residuum.  Confirm which one was intended.
        print("Maximum value is", np.max(err).values)

check_res(ds)

In [None]:
# Plot k0 and its tracked ddt contributions (symlog) along x, through the
# (y, z) position of the largest residuum of k0.
plt.figure()
k0 = "Ph+"
w = worst[k0]
isel = dict(y=w[1], z=w[2])  # optionally also restrict x, e.g. x=slice(2, -2)
for k in [k0, *(f"track_ddt_{k0}_{n}" for n in range(20))]:
    if k in ds:
        # Legend label: the rhs.name attribute if present, else the variable
        # name, shortened to at most `cutoff` characters.
        label = ds[k].attrs.get("rhs.name", k)
        cutoff = 50
        label = label if len(label) <= cutoff else label[:cutoff - 3] + "..."
        ds[k].isel(**isel).plot(label=label)
plt.yscale("symlog")
plt.legend()

# Disabled: compare log|Pe| with two Pe ddt terms and their signs.
if 0:
    plt.figure()
    for k in "Pe", "track_ddt_Pe_3", "track_ddt_Pe_5":
        np.log(np.abs(ds[k].isel(**isel))).plot(label=k)
    yll, ylu = plt.ylim()
    ylu -= yll
    ylu /= 4
    yll += ylu * 2
    for k in "Pe", "track_ddt_Pe_3", "track_ddt_Pe_5":
        (yll + ylu * np.sign(ds[k].isel(**isel))).plot(label="sign "+k)
    plt.legend()
#(ds.Pe.isel(**isel)).plot()
#(ds.track_ddt_Pe_3.isel(**isel)/1e-3).plot()
#(ds.track_ddt_Pe_5.isel(**isel)/1e-3).plot()


In [None]:
from boutdata.collect import collect

In [None]:
# Pe = collect("Pe", path=path, prefix="BOUT.debug", info=False)

In [None]:
# Index tuple of the largest |residuum| per variable, as recorded by check_res.
worst

In [None]:
# ``Pe`` was only defined by the commented-out ``collect`` call in an earlier
# cell, so this raised NameError on a fresh kernel.  Load it here instead.
Pe = collect("Pe", path=path, prefix="BOUT.debug", info=False)
# New figure: with the persistent widget/qt backends, a bare plt.plot would
# draw into the last figure created.
plt.figure()
plt.plot(Pe[2, :, 637])

In [None]:
# Attributes of Ne (normalisation, metadata, ...).
ds["Ne"].attrs

In [None]:
# Pe along the first axis at index 0 of the other two axes.
ds["Pe"][:, 0, 0].values

In [None]:
# Total number of Pe points.
np.prod(ds["Pe"].shape)

In [None]:
# NVi and its tracked ddt terms 2..5 along x (x restricted to slice(2, -2))
# through the (y, z) position of the largest NVi residuum.
plt.figure()
w = worst["NVi"]
isel = dict(y=w[1], z=w[2], x=slice(2, -2))
for k in ["NVi"] + [f"track_ddt_NVi_{i}" for i in [2, 3, 4, 5]]:
    # Not every run writes every tracked term; skip missing ones instead of
    # raising KeyError.
    if k not in ds:
        continue
    label = ds[k].attrs.get("rhs.name", k)
    ds[k].isel(**isel).plot(label=label)
plt.yscale("symlog")
plt.legend()

In [None]:
# k0 and its tracked ddt terms (indices 2 .. n-1, n depending on k0) along x
# at y=4, z=43, on a symlog scale.  Labels combine name, rhs.name, operation.
plt.figure()
isel = dict(y=4, z=43, x=slice(2, -2))
k0 = "Pe"
for k in [k0, *(f"track_ddt_{k0}_{i}" for i in range(2, {"NVi": 6, "Pe": 8, "Ne": 5}[k0]))]:
    label = " ".join([k, *(ds[k].attrs.get(attr, "") for attr in ("rhs.name", "operation"))])
    ds[k].isel(**isel).plot(label=label)
plt.yscale("symlog")
plt.legend()


In [None]:
def check_res_more(ds, key):
    """Per-y-slice convergence report and plot for variable ``key``.

    Prints how many points in each y slice violate the tolerance (criterion
    from ``get_res``).  Also plots, against the y index and on a log scale,
    the maximum over axes 0 and 2 of |residuum|, |value| and the tolerance.
    """
    print(f"Check {key} ...")
    var = ds[key]
    # Reuse get_res so this report and check_res apply the same criterion.
    res, err, where = get_res(ds, key)
    for i, j in enumerate(np.sum(where, axis=(0, 2))):
        if j:
            print(f"in slice jy={i}  there are {j} non-converged")
    jy = np.arange(len(ds.y))
    plt.figure()
    plt.plot(jy, np.max(np.abs(res), axis=(0, 2)), label="Max | Residuum |")
    plt.plot(jy, np.max(np.abs(var), axis=(0, 2)), label="Max | Value |")
    plt.plot(jy, np.max(np.abs(err), axis=(0, 2)), label="Max | Error |")
    plt.title(key)
    plt.legend()
    plt.gca().set_yscale('log')

check_res_more(ds, 'Pe')

In [None]:
from boututils.datafile import DataFile as DF

# Read Ne directly from the BOUT.debug.0.nc file with boututils' DataFile.
with DF(path+"/BOUT.debug.0.nc") as f:
    Ne = f["Ne"]
print(Ne.shape)
# New figure: with the persistent widget/qt backends, a bare plt.plot would
# draw into the previous cell's (log-scaled) figure.
plt.figure()
plt.plot(Ne[:, 0, 1])

def check_ddt_component(i, val):
    """Add one tracked ddt component to the current figure, or print it.

    A ``val`` without ``.shape`` (e.g. a metadata entry) is printed as
    ``i val`` and nothing is plotted.  Otherwise ``|val|`` is clipped to
    [1e-100, 1e200] (presumably so zeros/infinities do not break the log
    axis used by check_ddt).  The first and last two entries of axis 0 are
    dropped, the result is reduced with ``nanmax`` over axes 0 and 2, and it
    is plotted with a label built from ``i`` and the ``rhs.name`` /
    ``operation`` attributes.
    """
    try:
        val.shape
    except AttributeError:
        print(i, val)
        return
    # Missing attributes become empty strings instead of raising KeyError.
    parts = [i] + [val.attrs.get(k, "") for k in ["rhs.name", "operation"]]
    label = " ".join(str(x) for x in parts)
    tmp = np.clip(np.abs(val.values), 1e-100, 1e200)  # NaN passes through
    plt.plot(np.nanmax(tmp[2:-2], axis=(0, 2)), label=label)
    
def check_ddt(ds, key):
    """Plot every tracked ddt component of ``key`` into one new figure.

    Components are named ``track_ddt_<key>_<i>`` for i = 1, 2, ... (< 1000).
    Each is looked up first among the dataset's variables, then in
    ``ds.attrs["metadata"]``.  The scan stops at the first index found in
    neither.  Every candidate name and the final count are printed, and the
    y axis is made logarithmic.
    """
    vals = []
    for idx in range(1, 1000):
        name = f"track_ddt_{key}_{idx}"
        print(name)
        try:
            found = ds[name]
        except KeyError:
            try:
                found = ds.attrs["metadata"][name]
            except KeyError:
                break
        vals.append(found)

    plt.figure()
    for idx, val in enumerate(vals, start=1):
        check_ddt_component(idx, val)
    print(len(vals))
    plt.legend()
    plt.gca().set_yscale('log')
    # Presumably the limits of a log axis with nothing plotted: widen them.
    if plt.ylim() == (1, 10):
        plt.ylim(1e-10, 1e250)


#check_ddt(ds, "Ne")
#check_ddt(ds, "Pe")
check_ddt(ds, "NVi")
check_ddt(ds, "Pi")

In [None]:
# Entry stored under "track_ddt_Pe_1" in the dataset metadata (check_ddt
# falls back to these entries when the variable itself is absent).
ds.attrs["metadata"]["track_ddt_Pe_1"]
#ds.attrs

In [None]:
# Re-plot the tracked ddt components of Pi.
check_ddt(ds, "Pi")

In [None]:
key0 = "Pi"
yid = worst[key0][1]
yid=10
print(yid)
xlim = None,
ylim = None,
#xlim = 4.5,4.8
#ylim = 0.5,0.7
xslc =slice(2, -2)
%matplotlib qt
_, _, where = get_res(ds, key0)
for i in []:
    plt.figure(figsize=(10, 15))
    key = f"track_ddt_{key0}_{i}"
    label = key
    for fu in "operation", "rhs.name":
        #print(ds[key].attrs[fu])
        label += " " + ds[key].attrs.get(fu,"")
    RZ = [ds[k].isel(y=yid, x=xslc).T for k in "RZ"]
    theta = np.linspace(0, np.pi*2, endpoint=False)
    drz = [f(theta)*.01 for f in [np.sin, np.cos]]
    whereh = where.isel(y=yid, x=xslc)
    print( np.array(np.where(whereh)).T)
    if np.sum(whereh) < 20:
        for i,j in np.array(np.where(whereh)).T:
            plt.plot(*[dx+X.values[j, i] for dx, X in zip(drz, RZ)], "r-")

    ds[key].isel(y=yid, x=xslc).bout.pcolormesh(ax=plt.gca())
    #print(ds[key].attrs)
    #plt.plot(*RZ)
    plt.title(label)
    plt.ylim(*ylim)
    plt.xlim(*xlim)
if 1:
    ds[f"{key0}"].isel(y=yid, x=xslc).bout.pcolormesh(vmax=0.23, vmin=0.15)
    if np.sum(whereh) < 20:
        for i,j in np.array(np.where(whereh)).T:
            plt.plot(*[dx+X.values[j, i] for dx, X in zip(drz, RZ)], "r-")
    plt.ylim(*ylim)
    plt.xlim(*xlim)
    #plt.plot(*RZ)
None

In [None]:
# Grid slice y=10: mark points where forward_xt_prime exceeds 128.
# NOTE(review): the meaning of the 128 threshold is not documented — confirm.
grid = getgrid(ds)
gi = grid.isel(y=10)
plt.figure()
plt.pcolormesh(gi.R, gi.Z, gi.forward_xt_prime > 128)
gi.dims

In [None]:
def argmax(d):
    """Return the index tuple of the first maximum of the array ``d``.

    Follows ``np.argmax`` semantics (a NaN counts as the maximum).
    """
    shape = d.shape
    flat = np.argmax(d)
    return np.unravel_index(flat, shape)

# Index tuple of the maximum of track_ddt_Ne_3; used to pick the delp2 values below.
ijk = argmax(ds.track_ddt_Ne_3.values)

In [None]:
gridname = ds.attrs["options"]["grid"]
# Read the nine delp2_3x3_<n> fields directly from the zoidberg grid file.
with DF(f"/u/dave/soft/zoidberg-w7x/{gridname}") as gridfile:
    delp2 = [gridfile[f"delp2_3x3_{n}"] for n in range(9)]

In [None]:
# Stack the nine delp2 fields (new leading axis) and take their values at ijk.
delp2 = np.array(delp2)
idelp = delp2[(slice(None), *ijk)]

In [None]:
# Show the nine delp2 values at ijk as a 3x3 image.
plt.figure()
plt.imshow(idelp.reshape(3,3))
plt.colorbar()

In [None]:
# ddt_NVi_Gnv{p,m,p_temp,m_temp} and then NVi itself at y=4 (x restricted to
# slice(2, -2)), each zoomed to the same R/Z window.
for i in "p", "m", "p_temp", "m_temp":
    ds[f"ddt_NVi_Gnv{i}"].isel(y=4, x=slice(2,-2)).bout.pcolormesh()
    plt.ylim(None, -.7)
    plt.xlim(5.5, 5.7)
ds[f"NVi"].isel(y=4, x=slice(2,-2)).bout.pcolormesh()
plt.ylim(None, -.7)
plt.xlim(5.5, 5.7)

In [None]:
    # The whole cell is indented; IPython strips the common leading indent.
    # NOTE(review): forward_xt_prime reaches ``ds`` via the later cell that
    # copies the grid's "*ward*" variables, unless the debug output already
    # contains it — check the execution order.
    ds["forward_xt_prime"].isel(y=4, x=slice(2, -2)).bout.pcolormesh(vmin=130)
    plt.ylim(None, -.7)
    plt.xlim(5.5, 5.7)

In [None]:
    # The whole cell is indented; IPython strips the common leading indent.
    # NOTE(review): backward_xt_prime reaches ``ds`` via the later cell that
    # copies the grid's "*ward*" variables, unless the debug output already
    # contains it — check the execution order.
    ds["backward_xt_prime"].isel(y=4, x=slice(2, -2)).bout.pcolormesh(
        vmin=130, vmax=130.001
    )
    plt.ylim(None, -.7)
    plt.xlim(5.5, 5.7)

In [None]:
# rhs.name attribute of the tracked NVi ddt terms 2, 3 and 4.
for i in (2, 3, 4):
    print(ds[f"track_ddt_NVi_{i}"].attrs["rhs.name"])

In [None]:
# Tracked NVi ddt terms 2-4 at y=2, titled with their rhs.name, zoomed in.
for i in 2,3,4:
    key = f"track_ddt_NVi_{i}"
    ds[key].isel(y=2, x=slice(2, -2)).bout.pcolormesh()
    plt.title(ds[key].attrs["rhs.name"])
    plt.ylim(.7, None)
    plt.xlim(5.1, 5.8)

In [None]:
i = 2
# Variables inspected here before: f"track_ddt_Ne_{i}" and "residuum_Ne".
key = "NVi"
ds[key].isel(y=2, x=slice(2, -2)).bout.pcolormesh()

In [None]:
# Plain xarray plot of the variable selected above (``key``) at y=1.
plt.figure()
ds[key].isel(y=1).plot()

In [None]:
# dz of the grid at y=27, drawn on the (R, Z) mesh.
grid = getgrid(ds)
plt.figure(figsize=(15,15))
gi = grid.isel(y=27)
#grid.dz.isel(y=27).plot()
plt.pcolormesh(gi.R, gi.Z, gi.dz)

In [None]:
# NVi**2 / Ne, computed point-wise.
nvivi = ds["NVi"] ** 2 / ds["Ne"]

In [None]:
# NVi at y=4, zoomed to R in [5.4, 5.7], Z in [-1, -0.75].
ds.NVi.isel(y=4).bout.pcolormesh()
plt.ylim(-1, -.75)
plt.xlim(5.4, 5.7)

In [None]:
# NVi at y=2, zoomed to R in [5.35, 5.6], Z in [-1, -0.75].
ds.NVi.isel(y=2).bout.pcolormesh()
plt.ylim(-1, -.75)
plt.xlim(5.35, 5.6)

In [None]:
# (Re)open the grid file; the next cell copies variables from it into ds.
grid = getgrid(ds)

In [None]:
# Copy every grid variable whose name contains "ward" (e.g. forward_* and
# backward_*) into ds, then copy the geometry/sizes/metadata of Ne onto all
# variables.
# NOTE(review): getattr(ds.Ne, "sizes") returns the DataArray's dimension-size
# property, not an entry of ds.Ne.attrs — confirm that is what should be set.
# NOTE(review): _set_attrs_on_all_vars is a private xBOUT helper.  The loop
# also reuses the name `w`, overwriting the worst-index tuple from above.
ward = [x for x in grid if "ward" in x]
for w in ward:
    ds[w] = grid[w]
for k in "geometry", "sizes", "metadata":
    ds = xbout.utils._set_attrs_on_all_vars(ds, k, getattr(ds.Ne, k))

In [None]:
#ds.emc3.plot_rz("backward_xt_prime", phii=i)
dsi = ds.isel(y=2)
plt.figure()
plt.pcolormesh(dsi.R, dsi.Z, (dsi.forward_xt_prime > 128))
plt.ylim(-1, -.75)
plt.xlim(5.35, 5.6)

In [None]:
# "metadata" attribute as seen on a slice of a copied grid variable.
ds.backward_xt_prime.isel(y=2).metadata

In [None]:
# Leftover IPython help lookup (shows the docstring of setattr); safe to delete.
setattr?

In [None]:
# Time-evolution output (BOUT.dmp*) of the run in ``path``.
# NOTE(review): gridfilepath is a directory here, not a grid file — confirm
# that xBOUT accepts this.
evo = xbout.open_boutdataset(
    datapath=f"{path}/BOUT.dmp*.nc",
    geometry='fci',
    gridfilepath='/u/dave/soft/hermes-2/',
    inputfilepath=path + "BOUT.settings",
    info=False,
)

In [None]:
# Maximum |kappa_ipar| over x and z, plotted against the remaining dimensions.
plt.figure()
(np.abs(evo.kappa_ipar)).max(dim=('x', 'z')).plot()

In [None]:
# Max over x and z of log|kappa_ipar| at y=4, enumerated along the remaining
# dimension (presumably t).  Nothing is plotted, so the empty figure that
# used to be created here is dropped.
print(list(enumerate(np.log(np.abs(evo.kappa_ipar)).isel(y=4).max(dim=('x', 'z')).values)))

In [None]:
# Pi at y=4 for the first (up to) 14 time steps, zoomed in.  The loop is
# clamped to the available steps instead of raising IndexError on shorter runs.
for t in range(min(14, len(evo.t))):
    evo.isel(t=t, y=4).Pi.bout.pcolormesh(vmax=500)
    plt.ylim(-1, -.75)
    plt.xlim(5.5, 5.75)

In [None]:
# NVi at y=4 for every 4th time step from 0 to 20.
for t in range(0, 23, 4):
    evo.isel(t=t, y=4).NVi.bout.pcolormesh()
    #plt.ylim(-1, -.75)
    #plt.xlim(5.5, 5.75)

In [None]:
# Repeat the convergence check with boutdata.collect instead of xBOUT.
keys = []
pre = "residuum_"
# Use a separate name: ``with ... as ds`` re-bound the global xBOUT dataset
# that later cells use to this raw file, which is closed after the ``with``.
with xr.open_dataset(f"{path}/BOUT.debug.0.nc") as dbg:
    for k in dbg:
        if k.startswith(pre):
            keys.append(k.removeprefix(pre))

print(keys)

# NOTE(review): atol differs from gettol()'s fallback (1e-12), and the solver
# options are ignored here — confirm which tolerances are intended.
rtol = 1e-5
atol = 1e-10
def load(k):
    """Collect the residuum and the value of ``k`` from the BOUT.deb2.* files."""
    # NOTE(review): the keys come from BOUT.debug.0.nc but the data is read
    # with prefix 'BOUT.deb2' — confirm both file sets belong together.
    res = collect(pre + k, prefix='BOUT.deb2', path=path, info=False)
    var = collect(k, prefix='BOUT.deb2', path=path, info=False)
    return res, var

for k in keys:
    res, var = load(k)
    err = np.abs(var) * rtol + atol
    # NaN-safe: a NaN residuum (or tolerance) counts as not converged.
    where = ~(np.abs(res) <= err)
    print(k, np.sum(where), np.unravel_index(np.argmax(np.abs(res)), res.shape), np.max(err))


In [None]:
# A second run (``patho``), opened the same way as ``evo``.
# NOTE(review): gridfilepath is a directory here, not a grid file — confirm
# that xBOUT accepts this.
patho = "/raven/ptmp/dave/hermes-2/7-emc3.c92.c0/"
evoo = xbout.open_boutdataset(
    datapath=f"{patho}/BOUT.dmp*.nc",
    geometry='fci',
    gridfilepath='/u/dave/soft/hermes-2/',
    inputfilepath=patho + "BOUT.settings",
    info=False,
)

In [None]:
# NVi at y=4 for every time step of the second run.
for t in range(0, len(evoo.t)):
    evoo.isel(t=t, y=4).NVi.bout.pcolormesh()
    #plt.ylim(-1, -.75)
    #plt.xlim(5.5, 5.75)

In [None]:
# Run directory currently in use.
path

In [None]:
import zoidberg as zb

In [None]:
from zoidberg import diff

In [None]:
# Raw R coordinate values of the dataset.
ds["R"].values

In [None]:
# Point-to-point distance in the (R, Z) plane from zoidberg's diff.c2, for the
# argument sets (0, False) and (2, True).  Prints mean, max and the 50/90/99th
# percentiles, divided by 12e-4.
# NOTE(review): the meaning of the argument pairs and of the 12e-4 reference
# length is not documented here — confirm.
for args in ((0, False), (2, True)):
    dist = np.sqrt(diff.c2(ds.R, *args)**2 + diff.c2(ds.Z, *args)**2)
    print(np.array([np.mean(dist), np.max(dist), *np.percentile(dist, [50, 90, 99])])/12e-4)

In [None]:
# Only the last expression of a cell is displayed, so the bare ``ds.R.shape``
# that used to stand here was evaluated and silently discarded.
print(ds.R.shape)
128*2.54 + 2, 36, 768*5.79

In [None]:
fn = "/u/dave/soft/zoidberg-w7x/v17/W7X-conf0-132x36x768.emc3.inner:f.vessel:f.island:f.fci.nc"
# Keep the file open: later cells read ``grid`` and write a modified copy.
# The former ``with ...: pass`` closed it immediately, so every later read
# relied on xarray re-opening the file behind the scenes.
grid = xr.open_dataset(fn, drop_variables=["offset_3x3"])

In [None]:
# Grid lines of slice y=2 in the (R, Z) plane, drawn along both index directions.
gi = grid.isel(y=2)
plt.plot(gi.R, gi.Z)
plt.plot(gi.R.T, gi.Z.T)

In [None]:
# Set forward/backward_xt_prime to 2 at x index 2 and save a copy of the grid.
# ``grid.isel(x=2)[...] = 2`` assigned into a temporary Dataset and left
# ``grid`` unchanged.  Positional assignment on the variable modifies it in place.
grid["forward_xt_prime"][dict(x=2)] = 2
grid["backward_xt_prime"][dict(x=2)] = 2
grid.to_netcdf(fn[:-3] + ".force_inner.nc")