# MAML

In [None]:
from maml import *
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print("Current device:", device)

Current device: cuda


`MAML(num_inner_steps, inner_lr, outer_lr, num_data_i, num_data_b, num_data_f, low, high, eqname='burgers', zero_shot=False, load=False, modelpath=None, savename=None)`

Initializes First-Order Model-Agnostic Meta-Learning to train Physics-Informed Neural Networks.

        Args:
            num_inner_steps (int): number of inner-loop optimization steps
            inner_lr (float): learning rate for inner-loop optimization
            outer_lr (float): learning rate for outer-loop optimization
            num_data_i (int): number of initial data
            num_data_b (int): number of boundary data
            num_data_f (int): number of domain data
            low (float): low boundary of x
            high (float): high boundary of x
            eqname (String): type of equation, available options: 'burgers', 'poisson'
            zero_shot (boolean): whether to train zero_shot model or not
            load (boolean): whether to load pre-trained model weights from modelpath
            modelpath (String): model path to load
            savename (String): model path to save
            
After initializing a MAML instance, call its `train()` method to train the FO-MAML model.

`train(self, train_steps, num_train_tasks, num_val_tasks):`

Train the MAML. Optimizes MAML meta-parameters.

        Args:
            train_steps (int): the number of steps this model should train for
            num_train_tasks (int): the number of train tasks
            num_val_tasks (int): the number of validation tasks
        
        Returns:
            train_loss (dict) contains inner_loss, inner_loss_i (if exists), inner_loss_b, inner_loss_f during training
            val_loss (dict) contains inner_loss_pre_adapt, inner_loss_i_pre_adapt (if exists), inner_loss_b_pre_adapt, inner_loss_f_pre_adapt, inner_loss, inner_loss_i (if exists), inner_loss_b, inner_loss_f during validation with ID tasks
            val_ood_loss (dict) contains inner_loss_pre_adapt, inner_loss_i_pre_adapt (if exists), inner_loss_b_pre_adapt, inner_loss_f_pre_adapt, inner_loss, inner_loss_i (if exists), inner_loss_b, inner_loss_f during validation with OOD tasks
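            nrmse (dict) contains nrmse_val, nrmse_val_ood, nrmse_val_pre_adapt, nrmse_val_ood_pre_adapt, nrmse_val_post_adapt
            model: the trained model (the notebook below unpacks it as the fifth return value of train())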
            


In [None]:
# Hyperparameters for the zero-shot Poisson run. They are passed positionally, in the order
# of the documented signature:
#   MAML(num_inner_steps, inner_lr, outer_lr, num_data_i, num_data_b, num_data_f, low, high, ...)
NUM_INNER_STEPS = 5
INNER_LR = 0.01
OUTER_LR = 0.0005
NUM_DATA_I = 0  # no initial data for this run
NUM_DATA_B = 2
NUM_DATA_F = 1
X_LOW, X_HIGH = -10, 10

maml = MAML(
    NUM_INNER_STEPS, INNER_LR, OUTER_LR,
    NUM_DATA_I, NUM_DATA_B, NUM_DATA_F,
    X_LOW, X_HIGH,
    eqname='poisson',
    zero_shot=True,
    load=False,  # per the argument docs, modelpath is only read when load=True
    modelpath='models/poisson_zs_2000_ref.data',
)
# Alternative configurations kept for reference:
# maml = MAML(1, 0.01, 0.0001, 1, 2, 1, -1, 1, eqname='burgers', zero_shot=True, load=True, modelpath='models/model_ref/burgers_zs_1000_ref.data')
# maml = MAML(5, 0.01, 0.0001, 0, 2, 1, low=-1, high=1, eqname='poisson')

# train(train_steps, num_train_tasks, num_val_tasks)
TRAIN_STEPS, NUM_TRAIN_TASKS, NUM_VAL_TASKS = 5000, 100, 100
# NOTE(review): the plotting cells below read `val_loss_df` / `val_ood_loss_df`, which are not
# created in this cell or any other visible cell - confirm where they are derived from
# `val_loss` / `val_ood_loss`.
train_loss, val_loss, val_ood_loss, nrmse, model = maml.train(TRAIN_STEPS, NUM_TRAIN_TASKS, NUM_VAL_TASKS)

## Plot losses and metrics

### Post-adapt validation losses (in-distribution)

In [None]:
# Post-adapt validation losses on in-distribution tasks: total loss plus its `_f` and `_b` terms.
# NOTE(review): `val_loss_df` is not defined in any visible cell - confirm how it is built.
data_len = len(val_loss_df['inner_loss'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(val_loss_df['inner_loss']), label='loss')
ax.plot(x, np.array(val_loss_df['inner_loss_f']), label='loss_f')
# The `_b` term is drawn scaled by 10.
ax.plot(x, np.array(val_loss_df['inner_loss_b']) * 10, label='loss_b')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Post-adapt validation losses (in-distribution)')

### Pre-adapt validation losses (in-distribution)

In [None]:
# Pre-adapt validation losses on in-distribution tasks: total loss plus its `_f` and `_b` terms.
data_len = len(val_loss_df['inner_loss'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(val_loss_df['inner_loss_pre_adapt']), label='loss')
ax.plot(x, np.array(val_loss_df['inner_loss_f_pre_adapt']), label='loss_f')
# The `_b` term is drawn scaled by 10.
ax.plot(x, np.array(val_loss_df['inner_loss_b_pre_adapt']) * 10, label='loss_b')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Pre-adapt validation losses (in-distribution)')

### Validation metrics (in-distribution)

In [None]:
# NRMSE on in-distribution validation tasks, before and after adaptation, on a log scale.
data_len = len(nrmse['nrmse_val'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(nrmse['nrmse_val_pre_adapt']), label='pre-adapt')
ax.plot(x, np.array(nrmse['nrmse_val']), label='post-adapt')
ax.set_yscale('log')
ax.set_xlabel('Epochs')
ax.set_ylabel('NRMSE')
ax.legend()
ax.set_title('Validation metrics (in-distribution)')
# plt.plot(np.array(val_loss['inner_loss'])[:, -1])

### Post-adapt validation losses (out-of-distribution)

In [None]:
# Post-adapt validation losses on out-of-distribution tasks: total loss plus its `_f` and `_b` terms.
# NOTE(review): `val_ood_loss_df` is not defined in any visible cell - confirm how it is built.
# Size the x axis from the OOD frame that is actually plotted (it was previously taken from the
# in-distribution frame, which only works while both frames have the same length).
data_len = len(val_ood_loss_df['inner_loss'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(val_ood_loss_df['inner_loss']), label='loss')
ax.plot(x, np.array(val_ood_loss_df['inner_loss_f']), label='loss_f')
# The `_b` term is drawn scaled by 10.
ax.plot(x, np.array(val_ood_loss_df['inner_loss_b']) * 10, label='loss_b')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Post-adapt validation losses (out-of-distribution)')

### Pre-adapt validation losses (out-of-distribution)

In [None]:
# Pre-adapt validation losses on out-of-distribution tasks: total loss plus its `_f` and `_b` terms.
# Size the x axis from the OOD frame that is actually plotted (it was previously taken from the
# in-distribution frame, which only works while both frames have the same length).
data_len = len(val_ood_loss_df['inner_loss'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(val_ood_loss_df['inner_loss_pre_adapt']), label='loss')
ax.plot(x, np.array(val_ood_loss_df['inner_loss_f_pre_adapt']), label='loss_f')
# The `_b` term is drawn scaled by 10.
ax.plot(x, np.array(val_ood_loss_df['inner_loss_b_pre_adapt']) * 10, label='loss_b')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Pre-adapt validation losses (out-of-distribution)')

### Validation metrics (out-of-distribution)

In [None]:
# NRMSE on out-of-distribution validation tasks, before and after adaptation, on a log scale.
# Size the x axis from the OOD series that is plotted (it was previously taken from the
# in-distribution 'nrmse_val' series, which only works while both have the same length).
data_len = len(nrmse['nrmse_val_ood'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(nrmse['nrmse_val_ood_pre_adapt']), label='pre-adapt')
ax.plot(x, np.array(nrmse['nrmse_val_ood']), label='post-adapt')
ax.set_yscale('log')
ax.set_xlabel('Epochs')
ax.set_ylabel('NRMSE')
ax.legend()
ax.set_title('Validation metrics (out-of-distribution)')
# plt.plot(np.array(val_loss['inner_loss'])[:, -1])

### Pre-adapt vs. Post-adapt val. losses (in-distribution)

In [None]:
# Total validation loss on in-distribution tasks, before vs. after adaptation.
data_len = len(val_loss_df['inner_loss'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(val_loss_df['inner_loss_pre_adapt']), label='pre-adapt')
ax.plot(x, np.array(val_loss_df['inner_loss']), label='post-adapt')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_title('Pre-adapt vs post-adapt validation losses (in-distribution)')

### Pre-adapt vs. Post-adapt val. losses (out-of-distribution)

In [None]:
# Total validation loss on out-of-distribution tasks, before vs. after adaptation (log scale).
# Size the x axis from the OOD frame that is actually plotted (it was previously taken from the
# in-distribution frame, which only works while both frames have the same length).
data_len = len(val_ood_loss_df['inner_loss'])
# Spread the recorded points evenly over the 5000 training steps.
x = np.arange(data_len) * 5000 / data_len

fig, ax = plt.subplots()
ax.plot(x, np.array(val_ood_loss_df['inner_loss_pre_adapt']), label='pre-adapt')
ax.plot(x, np.array(val_ood_loss_df['inner_loss']), label='post-adapt')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.set_yscale('log')
ax.set_title('Pre-adapt vs post-adapt validation losses (out-of-distribution)')

## Plot solutions - Poisson

In [26]:
from copy import deepcopy

In [None]:
# (alpha, beta) of the Poisson task to inspect; the cells below read them as task[0] and task[1].
task = (-0.830, -0.617)

In [None]:
# Evaluation grid: 100 evenly spaced x values on [-10, 10], stored as a column vector.
test_x = np.linspace(-10, 10, num=100).reshape(-1, 1)
# Repeat the task parameters so there is one (alpha, beta) pair per grid point.
test_alpha = np.full_like(test_x, task[0])
test_beta = np.full_like(test_x, task[1])
# Each network input row is (x, alpha, beta).
test_in = np.concatenate((test_x, test_alpha, test_beta), axis=1)
# Prediction of the trained model before any task-specific adaptation.
test_u = model(torch.Tensor(test_in).to(device))
X = test_x

# Exact solution
Y = np.sin(test_alpha * X) + np.cos(test_beta * X) + 0.1 * X


MAML adaptation to task

In [None]:

# Run MAML's inner loop on `task`, starting from the trained weights, and keep only the first
# returned value (the adapted parameters). The stray `model_adapted=` chained assignment that
# briefly bound `model_adapted` to the whole return tuple has been removed.
# NOTE(review): `_inner_loop` is a private method and is called with train=True - confirm this
# is the intended mode for evaluation-time adaptation.
phi, *_ = maml._inner_loop(model.state_dict(), task, train=True)
# Load the adapted weights into a copy so `model` itself keeps the meta-learned parameters.
model_adapted = deepcopy(model)
model_adapted.load_state_dict(phi)
# Adapted prediction on the same test grid, as a NumPy array for plotting.
Y2 = model_adapted(torch.Tensor(test_in).to(device)).detach().cpu().numpy()

In [None]:
# Pre-adapt prediction as a NumPy array (also reused by the metric cell below).
C = test_u.detach().cpu().numpy()

# Overlay the pre-adapt prediction, the exact solution and the adapted prediction.
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(X, C, 'b-', label='MAML')
ax.plot(X, Y, 'r--', label='Answer')
ax.plot(X, Y2, 'g--', label='MAML adapted')
ax.legend()

### Print NRMSE

In [None]:
# Relative L2 difference between the pre-adapt prediction C and the adapted prediction Y2,
# normalised by the sum of squares of C.
# NOTE(review): the exact solution Y is not used here, so this is not an error against the
# ground truth - confirm which pair is intended. Earlier variant, comparing C with Y:
# np.sqrt( np.sum((C-Y)**2) / np.sum(C**2) )
diff_sq = np.sum((C - Y2) ** 2)
ref_sq = np.sum(C ** 2)
np.sqrt(diff_sq / ref_sq)

0.92754614

## Plot solutions - Burgers

In [None]:
from burgers import *

In [None]:
# Exact viscous Burgers solution on a 101 x 101 (x, t) grid, shown as a coloured scatter.
vtn = 101
vxn = 101
nu = 0.01 / np.pi
vx = np.linspace(-1, 1, vxn)
vt = np.linspace(0, 1, vtn)

vu = burgers_viscous_time_exact1(nu, vxn, vx, vtn, vt)

# meshgrid(vx, vt) yields (vtn, vxn) arrays, so the flattened points run over x fastest.
x, t = np.meshgrid(vx, vt)
x = x.reshape(-1, 1)
t = t.reshape(-1, 1)

# Flatten the solution in the same (t, x) order as the points; passing `vu` untransposed
# paired each colour with the wrong point.
# Assumes `vu` is indexed (x, t), as the later `truth = ... .T.reshape(-1, 1)` cell does - TODO confirm.
plt.scatter(x, t, c=vu.T.reshape(-1), cmap='seismic')
plt.xlabel('x')
plt.ylabel('t')
plt.colorbar()

In [None]:
# Same device selection as at the top of the notebook: GPU when CUDA is usable, else CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print("Current device:", device)

In [None]:
# Rebuild the 101 x 101 (x, t) grid and evaluate the network and the exact solution on it.
vtn = vxn = 101
nu = 0.01 / np.pi
vx = np.linspace(-1, 1, vxn)
vt = np.linspace(0, 1, vtn)
# Flatten both mesh arrays into column vectors (one row per grid point).
x, t = (grid.reshape(-1, 1) for grid in np.meshgrid(vx, vt))
# One viscosity value per grid point; not consumed by the active prediction line below.
alpha = np.full(x.shape, nu)
# NOTE(review): `model_2` is not defined in any visible cell - confirm where it comes from.
xt_inputs = torch.Tensor(np.concatenate((x, t), axis=1)).to(device)
pred = model_2(xt_inputs).detach().cpu().numpy()
# pred = model(torch.Tensor(np.hstack((x, t))).to(device)).detach().cpu().numpy()
# Transpose before flattening so the exact values line up with the (t, x)-ordered grid points.
truth = burgers_viscous_time_exact1(nu, vxn, vx, vtn, vt).T.reshape(-1, 1)

In [None]:
# Network prediction over the (x, t) grid, coloured by predicted value.
fig, ax = plt.subplots()
points = ax.scatter(x, t, c=pred, cmap='seismic')
fig.colorbar(points, ax=ax)