In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload

import sys
import torch
import holoviews as hv
from holoviews import opts
hv.extension('bokeh', 'matplotlib')

sys.path.append('..')
import mre_siren

In [25]:
%autoreload

class Slicer(object):
    def __getitem__(self, idx):
        return idx
s = Slicer()

data = mre_siren.bioqic.BIOQICDataset(
    data_root='../data/BIOQIC',
    select=dict(freq=70, MEG='y'),
    downsample=4,
    dtype=torch.float32,
    device='cuda',
    verbose=False
)

data.view(var='phase_unshifted', scale=8)

In [27]:
data.arr.shape, data.x.shape, data.u.shape

((8, 7, 32, 20), torch.Size([35840, 4]), torch.Size([35840, 1]))

In [28]:
%autoreload
from torch.utils.data import DataLoader
data_loader = DataLoader(data, batch_size=data.arr.size, shuffle=True)

In [43]:
%autoreload

from torch.nn import functional as F
from torch import optim
import pandas as pd

def compute_pde_loss(x, u_pred):
    return torch.zeros(1).cuda()

model = mre_siren.models.SIREN(
    n_input=data.x.shape[1], n_output=2, n_hidden=256, n_layers=3
).float().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 1000
out_prefix = 'TEST'

metric_index = ['epoch', 'batch']
metrics = pd.DataFrame(columns=metric_index).set_index(metric_index)

for i in range(n_epochs):

    for j, (x, u) in enumerate(data_loader):
    
        u_pred, G_pred = torch.split(model.forward(x), 1, dim=1)

        mse_loss = F.mse_loss(u, u_pred)
        pde_loss = compute_pde_loss(x, u_pred)
        loss = mse_loss + pde_loss
        
        metrics.loc[(i,j), 'mse_loss'] = mse_loss.item()
        metrics.loc[(i,j), 'pde_loss'] = pde_loss.item()
        metrics.loc[(i,j), 'loss'] = loss.item()

        print(f'[epoch = {i}, batch = {j}] mse_loss = {mse_loss.item():.4f}, ' + 
            f'pde_loss = {pde_loss.item():.4f}, loss = {loss.item():.4f}')

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()


[epoch = 0, batch = 0] mse_loss = 409.1010, pde_loss = 0.0000, loss = 409.1010
[epoch = 1, batch = 0] mse_loss = 408.7883, pde_loss = 0.0000, loss = 408.7883
[epoch = 2, batch = 0] mse_loss = 408.5021, pde_loss = 0.0000, loss = 408.5021
[epoch = 3, batch = 0] mse_loss = 408.2135, pde_loss = 0.0000, loss = 408.2135
[epoch = 4, batch = 0] mse_loss = 407.9086, pde_loss = 0.0000, loss = 407.9086
[epoch = 5, batch = 0] mse_loss = 407.5825, pde_loss = 0.0000, loss = 407.5825
[epoch = 6, batch = 0] mse_loss = 407.2309, pde_loss = 0.0000, loss = 407.2309
[epoch = 7, batch = 0] mse_loss = 406.8469, pde_loss = 0.0000, loss = 406.8469
[epoch = 8, batch = 0] mse_loss = 406.4449, pde_loss = 0.0000, loss = 406.4449
[epoch = 9, batch = 0] mse_loss = 406.0289, pde_loss = 0.0000, loss = 406.0289
[epoch = 10, batch = 0] mse_loss = 405.6173, pde_loss = 0.0000, loss = 405.6173
[epoch = 11, batch = 0] mse_loss = 405.2132, pde_loss = 0.0000, loss = 405.2132
[epoch = 12, batch = 0] mse_loss = 404.7845, pde_l

[epoch = 103, batch = 0] mse_loss = 317.0431, pde_loss = 0.0000, loss = 317.0431
[epoch = 104, batch = 0] mse_loss = 315.7532, pde_loss = 0.0000, loss = 315.7532
[epoch = 105, batch = 0] mse_loss = 314.3314, pde_loss = 0.0000, loss = 314.3314
[epoch = 106, batch = 0] mse_loss = 312.8423, pde_loss = 0.0000, loss = 312.8423
[epoch = 107, batch = 0] mse_loss = 311.4303, pde_loss = 0.0000, loss = 311.4303
[epoch = 108, batch = 0] mse_loss = 309.9875, pde_loss = 0.0000, loss = 309.9875
[epoch = 109, batch = 0] mse_loss = 308.6180, pde_loss = 0.0000, loss = 308.6180
[epoch = 110, batch = 0] mse_loss = 307.2470, pde_loss = 0.0000, loss = 307.2470
[epoch = 111, batch = 0] mse_loss = 305.7144, pde_loss = 0.0000, loss = 305.7144
[epoch = 112, batch = 0] mse_loss = 304.2006, pde_loss = 0.0000, loss = 304.2006
[epoch = 113, batch = 0] mse_loss = 302.7054, pde_loss = 0.0000, loss = 302.7054
[epoch = 114, batch = 0] mse_loss = 301.2779, pde_loss = 0.0000, loss = 301.2779
[epoch = 115, batch = 0] mse

[epoch = 205, batch = 0] mse_loss = 162.8089, pde_loss = 0.0000, loss = 162.8089
[epoch = 206, batch = 0] mse_loss = 161.3113, pde_loss = 0.0000, loss = 161.3113
[epoch = 207, batch = 0] mse_loss = 159.8920, pde_loss = 0.0000, loss = 159.8920
[epoch = 208, batch = 0] mse_loss = 158.4949, pde_loss = 0.0000, loss = 158.4949
[epoch = 209, batch = 0] mse_loss = 157.0310, pde_loss = 0.0000, loss = 157.0310
[epoch = 210, batch = 0] mse_loss = 155.7499, pde_loss = 0.0000, loss = 155.7499
[epoch = 211, batch = 0] mse_loss = 154.3539, pde_loss = 0.0000, loss = 154.3539
[epoch = 212, batch = 0] mse_loss = 152.7763, pde_loss = 0.0000, loss = 152.7763
[epoch = 213, batch = 0] mse_loss = 151.3555, pde_loss = 0.0000, loss = 151.3555
[epoch = 214, batch = 0] mse_loss = 150.0220, pde_loss = 0.0000, loss = 150.0220
[epoch = 215, batch = 0] mse_loss = 148.6231, pde_loss = 0.0000, loss = 148.6231
[epoch = 216, batch = 0] mse_loss = 147.2197, pde_loss = 0.0000, loss = 147.2197
[epoch = 217, batch = 0] mse

[epoch = 308, batch = 0] mse_loss = 70.5513, pde_loss = 0.0000, loss = 70.5513
[epoch = 309, batch = 0] mse_loss = 69.9890, pde_loss = 0.0000, loss = 69.9890
[epoch = 310, batch = 0] mse_loss = 69.4129, pde_loss = 0.0000, loss = 69.4129
[epoch = 311, batch = 0] mse_loss = 68.8487, pde_loss = 0.0000, loss = 68.8487
[epoch = 312, batch = 0] mse_loss = 68.3202, pde_loss = 0.0000, loss = 68.3202
[epoch = 313, batch = 0] mse_loss = 67.8224, pde_loss = 0.0000, loss = 67.8224
[epoch = 314, batch = 0] mse_loss = 67.3130, pde_loss = 0.0000, loss = 67.3130
[epoch = 315, batch = 0] mse_loss = 66.7774, pde_loss = 0.0000, loss = 66.7774
[epoch = 316, batch = 0] mse_loss = 66.2506, pde_loss = 0.0000, loss = 66.2506
[epoch = 317, batch = 0] mse_loss = 65.7640, pde_loss = 0.0000, loss = 65.7640
[epoch = 318, batch = 0] mse_loss = 65.2999, pde_loss = 0.0000, loss = 65.2999
[epoch = 319, batch = 0] mse_loss = 64.8294, pde_loss = 0.0000, loss = 64.8294
[epoch = 320, batch = 0] mse_loss = 64.3537, pde_los

[epoch = 412, batch = 0] mse_loss = 33.6550, pde_loss = 0.0000, loss = 33.6550
[epoch = 413, batch = 0] mse_loss = 33.4256, pde_loss = 0.0000, loss = 33.4256
[epoch = 414, batch = 0] mse_loss = 33.2008, pde_loss = 0.0000, loss = 33.2008
[epoch = 415, batch = 0] mse_loss = 32.9822, pde_loss = 0.0000, loss = 32.9822
[epoch = 416, batch = 0] mse_loss = 32.7693, pde_loss = 0.0000, loss = 32.7693
[epoch = 417, batch = 0] mse_loss = 32.5611, pde_loss = 0.0000, loss = 32.5611
[epoch = 418, batch = 0] mse_loss = 32.3566, pde_loss = 0.0000, loss = 32.3566
[epoch = 419, batch = 0] mse_loss = 32.1575, pde_loss = 0.0000, loss = 32.1575
[epoch = 420, batch = 0] mse_loss = 31.9630, pde_loss = 0.0000, loss = 31.9630
[epoch = 421, batch = 0] mse_loss = 31.7720, pde_loss = 0.0000, loss = 31.7720
[epoch = 422, batch = 0] mse_loss = 31.5819, pde_loss = 0.0000, loss = 31.5819
[epoch = 423, batch = 0] mse_loss = 31.3906, pde_loss = 0.0000, loss = 31.3906
[epoch = 424, batch = 0] mse_loss = 31.1974, pde_los

[epoch = 516, batch = 0] mse_loss = 17.3785, pde_loss = 0.0000, loss = 17.3785
[epoch = 517, batch = 0] mse_loss = 17.2724, pde_loss = 0.0000, loss = 17.2724
[epoch = 518, batch = 0] mse_loss = 17.1667, pde_loss = 0.0000, loss = 17.1667
[epoch = 519, batch = 0] mse_loss = 17.0621, pde_loss = 0.0000, loss = 17.0621
[epoch = 520, batch = 0] mse_loss = 16.9588, pde_loss = 0.0000, loss = 16.9588
[epoch = 521, batch = 0] mse_loss = 16.8575, pde_loss = 0.0000, loss = 16.8575
[epoch = 522, batch = 0] mse_loss = 16.7579, pde_loss = 0.0000, loss = 16.7579
[epoch = 523, batch = 0] mse_loss = 16.6599, pde_loss = 0.0000, loss = 16.6599
[epoch = 524, batch = 0] mse_loss = 16.5627, pde_loss = 0.0000, loss = 16.5627
[epoch = 525, batch = 0] mse_loss = 16.4655, pde_loss = 0.0000, loss = 16.4655
[epoch = 526, batch = 0] mse_loss = 16.3683, pde_loss = 0.0000, loss = 16.3683
[epoch = 527, batch = 0] mse_loss = 16.2712, pde_loss = 0.0000, loss = 16.2712
[epoch = 528, batch = 0] mse_loss = 16.1750, pde_los

[epoch = 620, batch = 0] mse_loss = 9.8297, pde_loss = 0.0000, loss = 9.8297
[epoch = 621, batch = 0] mse_loss = 9.7811, pde_loss = 0.0000, loss = 9.7811
[epoch = 622, batch = 0] mse_loss = 9.7326, pde_loss = 0.0000, loss = 9.7326
[epoch = 623, batch = 0] mse_loss = 9.6846, pde_loss = 0.0000, loss = 9.6846
[epoch = 624, batch = 0] mse_loss = 9.6371, pde_loss = 0.0000, loss = 9.6371
[epoch = 625, batch = 0] mse_loss = 9.5900, pde_loss = 0.0000, loss = 9.5900
[epoch = 626, batch = 0] mse_loss = 9.5432, pde_loss = 0.0000, loss = 9.5432
[epoch = 627, batch = 0] mse_loss = 9.4968, pde_loss = 0.0000, loss = 9.4968
[epoch = 628, batch = 0] mse_loss = 9.4507, pde_loss = 0.0000, loss = 9.4507
[epoch = 629, batch = 0] mse_loss = 9.4051, pde_loss = 0.0000, loss = 9.4051
[epoch = 630, batch = 0] mse_loss = 9.3599, pde_loss = 0.0000, loss = 9.3599
[epoch = 631, batch = 0] mse_loss = 9.3151, pde_loss = 0.0000, loss = 9.3151
[epoch = 632, batch = 0] mse_loss = 9.2709, pde_loss = 0.0000, loss = 9.2709

[epoch = 726, batch = 0] mse_loss = 6.4680, pde_loss = 0.0000, loss = 6.4680
[epoch = 727, batch = 0] mse_loss = 6.4541, pde_loss = 0.0000, loss = 6.4541
[epoch = 728, batch = 0] mse_loss = 6.4041, pde_loss = 0.0000, loss = 6.4041
[epoch = 729, batch = 0] mse_loss = 6.3566, pde_loss = 0.0000, loss = 6.3566
[epoch = 730, batch = 0] mse_loss = 6.3287, pde_loss = 0.0000, loss = 6.3287
[epoch = 731, batch = 0] mse_loss = 6.2999, pde_loss = 0.0000, loss = 6.2999
[epoch = 732, batch = 0] mse_loss = 6.2594, pde_loss = 0.0000, loss = 6.2594
[epoch = 733, batch = 0] mse_loss = 6.2273, pde_loss = 0.0000, loss = 6.2273
[epoch = 734, batch = 0] mse_loss = 6.1993, pde_loss = 0.0000, loss = 6.1993
[epoch = 735, batch = 0] mse_loss = 6.1638, pde_loss = 0.0000, loss = 6.1638
[epoch = 736, batch = 0] mse_loss = 6.1340, pde_loss = 0.0000, loss = 6.1340
[epoch = 737, batch = 0] mse_loss = 6.1043, pde_loss = 0.0000, loss = 6.1043
[epoch = 738, batch = 0] mse_loss = 6.0751, pde_loss = 0.0000, loss = 6.0751

[epoch = 833, batch = 0] mse_loss = 4.3047, pde_loss = 0.0000, loss = 4.3047
[epoch = 834, batch = 0] mse_loss = 4.2909, pde_loss = 0.0000, loss = 4.2909
[epoch = 835, batch = 0] mse_loss = 4.2771, pde_loss = 0.0000, loss = 4.2771
[epoch = 836, batch = 0] mse_loss = 4.2635, pde_loss = 0.0000, loss = 4.2635
[epoch = 837, batch = 0] mse_loss = 4.2499, pde_loss = 0.0000, loss = 4.2499
[epoch = 838, batch = 0] mse_loss = 4.2363, pde_loss = 0.0000, loss = 4.2363
[epoch = 839, batch = 0] mse_loss = 4.2228, pde_loss = 0.0000, loss = 4.2228
[epoch = 840, batch = 0] mse_loss = 4.2094, pde_loss = 0.0000, loss = 4.2094
[epoch = 841, batch = 0] mse_loss = 4.1960, pde_loss = 0.0000, loss = 4.1960
[epoch = 842, batch = 0] mse_loss = 4.1827, pde_loss = 0.0000, loss = 4.1827
[epoch = 843, batch = 0] mse_loss = 4.1695, pde_loss = 0.0000, loss = 4.1695
[epoch = 844, batch = 0] mse_loss = 4.1563, pde_loss = 0.0000, loss = 4.1563
[epoch = 845, batch = 0] mse_loss = 4.1431, pde_loss = 0.0000, loss = 4.1431

[epoch = 940, batch = 0] mse_loss = 3.1078, pde_loss = 0.0000, loss = 3.1078
[epoch = 941, batch = 0] mse_loss = 3.0987, pde_loss = 0.0000, loss = 3.0987
[epoch = 942, batch = 0] mse_loss = 3.0897, pde_loss = 0.0000, loss = 3.0897
[epoch = 943, batch = 0] mse_loss = 3.0807, pde_loss = 0.0000, loss = 3.0807
[epoch = 944, batch = 0] mse_loss = 3.0718, pde_loss = 0.0000, loss = 3.0718
[epoch = 945, batch = 0] mse_loss = 3.0628, pde_loss = 0.0000, loss = 3.0628
[epoch = 946, batch = 0] mse_loss = 3.0540, pde_loss = 0.0000, loss = 3.0540
[epoch = 947, batch = 0] mse_loss = 3.0451, pde_loss = 0.0000, loss = 3.0451
[epoch = 948, batch = 0] mse_loss = 3.0363, pde_loss = 0.0000, loss = 3.0363
[epoch = 949, batch = 0] mse_loss = 3.0275, pde_loss = 0.0000, loss = 3.0275
[epoch = 950, batch = 0] mse_loss = 3.0187, pde_loss = 0.0000, loss = 3.0187
[epoch = 951, batch = 0] mse_loss = 3.0100, pde_loss = 0.0000, loss = 3.0100
[epoch = 952, batch = 0] mse_loss = 3.0013, pde_loss = 0.0000, loss = 3.0013

In [44]:
# final evaluation
x, u = data.x, data.u
u_pred, G_pred = torch.split(model.forward(x), 1, dim=1)

mse_loss = F.mse_loss(u, u_pred)
pde_loss = compute_pde_loss(x, u_pred)
loss = mse_loss + pde_loss

print(f'[epoch = {i+1}, batch = 0] mse_loss = {mse_loss}, ' + 
    f'pde_loss = {pde_loss}, loss = {loss}')

[epoch = 1000, batch = 0] mse_loss = 2.6184914112091064, pde_loss = tensor([0.], device='cuda:0'), loss = tensor([2.6185], device='cuda:0', grad_fn=<AddBackward0>)


In [45]:
data.ds['phase_pred'] = (data.ds.dims, u_pred.reshape(data.arr.shape).detach().cpu().numpy())

data.view(var=['phase_unshifted', 'phase_pred'], scale=8).cols(1)