In [None]:
# Notebook-wide setup: interactive matplotlib figures plus selective module
# reloading, followed by a record of where this kernel is running.
%matplotlib notebook
%load_ext autoreload
# Mode 1: only modules registered via %aimport are reloaded.
%autoreload 1
# Shell escapes: print the host name and the current working directory so the
# relative paths used further down ('..', '../data/...') can be interpreted.
!hostname
!pwd

In [None]:
import sys, os, pathlib
import numpy as np
import xarray as xr
import torch
import matplotlib.pyplot as plt
import seaborn as sns

os.environ['DDEBACKEND'] = 'pytorch'
import deepxde

sys.path.append('..')
%aimport mre_pinn

torch.cuda.is_available()

In [None]:
# Build the on-disk dataset for a single entry ('0006') of the imaging cohort.
# The steps below run strictly in order on the same cohort object.
%autoreload
cohort = mre_pinn.data.ImagingCohort(['0006'])
cohort.load_images()
cohort.preprocess()
dataset = cohort.to_dataset()
# Writes the converted dataset under ../data/NAFLD3 (relative to the working
# directory); the next cell reads from this same location.
dataset.save_xarrays('../data/NAFLD3')

In [None]:
# Load example '0006' back from ../data/NAFLD3, the directory that the
# dataset-building cell writes to.
example = mre_pinn.data.MREExample.load_xarrays('../data/NAFLD3', '0006')
# Bare expression: show the example's metadata as the cell output.
example.metadata

In [None]:
# Show the summary returned/printed by the example's describe() method.
# NOTE(review): exact contents depend on mre_pinn — not visible from here.
example.describe()

In [None]:
# Run the example's baseline evaluation at frequency=40, then view the 'base'
# result in small (2x2) axes.
# NOTE(review): presumably eval_baseline stores a 'base' field on the example
# that view('base', ...) then plots, and frequency is in Hz — confirm.
example.eval_baseline(frequency=40, polar=True, postprocess=True)
example.view('base', ax_height=2, ax_width=2)

In [None]:
# Construct the wave-equation PDE object by name ('hetero') with omega=40,
# matching the frequency=40 used for the baseline evaluation cell.
# NOTE(review): 'hetero' presumably selects a heterogeneous formulation and
# detach=True presumably detaches some term from the autograd graph — confirm
# against mre_pinn.pde.
%autoreload
pde = mre_pinn.pde.WaveEquation.from_name('hetero', omega=40, detach=True)

In [None]:
%autoreload
pinn = mre_pinn.model.MREPINN(
    example,
    omega=60,
    n_layers=5,
    n_hidden=128,
    polar_input=True,
    conditional=False
)
pinn

In [None]:
%autoreload
model = mre_pinn.training.MREPINNModel(
    example, pinn, pde,
    loss_weights=[1, 0, 1e-16],
    pde_warmup_iters=5000,
    pde_step_iters=5000,
    pde_init_weight=1e-18,
    n_points=1024
)
model.compile(optimizer='adam', lr=1e-4, loss=mre_pinn.training.losses.msae_loss)

In [None]:
# Global cuDNN switches: keep cuDNN enabled but turn off its benchmark mode.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
# NOTE(review): presumably times 100 iterations of the model — confirm what
# model.benchmark measures and reports.
model.benchmark(100)

In [None]:
# Create the evaluation callback (test_every=100, interact=True) that is later
# passed to model.train(...).
test_eval = mre_pinn.testing.TestEvaluator(test_every=100, interact=True)
# Attach the model by hand so test() can be called once here, before training.
# NOTE(review): presumably the training loop would otherwise set .model on its
# callbacks itself — confirm.
test_eval.model = model
test_eval.test()

In [None]:
# Long-running cell: train with 100000 passed as the first argument
# (presumably the iteration count — confirm), with test_eval as a callback.
model.train(100000, callbacks=[test_eval])