In [41]:
import torch
import torch.nn as nn
import numpy as np
from tc.tc_fc import TTLinear

In [42]:
hid = [5, 2, 5, 2]
rank = [1, 3, 3, 3, 1]

model = torch.nn.Sequential(
        nn.Linear(2, 100),
        nn.Tanh(),
        TTLinear(hid, hid, rank, activation=None),
        nn.Tanh(),
        TTLinear(hid, hid, rank, activation=None),
        nn.Tanh(),
        nn.Linear(100, 1))



In [43]:
x_grid = np.linspace(0, 1, 51)
t_grid = np.linspace(0, 1, 51)

x = torch.from_numpy(x_grid)
t = torch.from_numpy(t_grid)

grid = torch.cartesian_prod(x, t).float()

def nn_autograd_simple(model, points, order,axis=0):
    points.requires_grad=True
    f = model(points).sum()
    for i in range(order):
        grads, = torch.autograd.grad(f, points, create_graph=True)
        f = grads[:,axis].sum()
    return grads[:,axis]

func_bnd1 = lambda x: 10 ** 4 * np.sin((1 / 10) * x * (x - 1)) ** 2
bnd1 = torch.cartesian_prod(x, torch.from_numpy(np.array([0], dtype=np.float64))).float()
bndval1 = func_bnd1(bnd1[:, 0])

# du/dx (x,0) = 1e3*sin^2(x(x-1)/10)
func_bnd2 = lambda x: 10 ** 3 * np.sin((1 / 10) * x * (x - 1)) ** 2
bnd2 = torch.cartesian_prod(x, torch.from_numpy(np.array([0], dtype=np.float64))).float()
bop2 = {
    'du/dt':
        {
            'coeff': 1,
            'du/dt': [1],
            'pow': 1,
            'var': 0
        }
}
bndval2 = func_bnd2(bnd2[:, 0])

# u(0,t) = u(1,t)
bnd3_left = torch.cartesian_prod(torch.from_numpy(np.array([0], dtype=np.float64)), t).float()
bnd3_right = torch.cartesian_prod(torch.from_numpy(np.array([1], dtype=np.float64)), t).float()
bnd3 = [bnd3_left, bnd3_right]

# du/dt(0,t) = du/dt(1,t)
bnd4_left = torch.cartesian_prod(torch.from_numpy(np.array([0], dtype=np.float64)), t).float()
bnd4_right = torch.cartesian_prod(torch.from_numpy(np.array([1], dtype=np.float64)), t).float()
bnd4 = [bnd4_left, bnd4_right]

bop4 = {
    'du/dx':
        {
            'coeff': 1,
            'du/dx': [0],
            'pow': 1,
            'var': 0
        }
}
bcond_type = 'periodic'

bconds = [[bnd1, bndval1, 'dirichlet'],
          [bnd2, bop2, bndval2, 'operator'],
          [bnd3, bcond_type],
          [bnd4, bop4, bcond_type]]

def wave_op(model, grid):
    u_xx = nn_autograd_simple(model, grid, order=2, axis=0)
    u_tt = nn_autograd_simple(model, grid, order=2, axis=1)
    a = -(1 / 4)

    op = u_tt + a * u_xx

    return op

def op_loss(operator):
    return torch.mean(torch.square(operator))

def bcs_loss(model):
    bc1 = model(bnd1)
    bc2 = nn_autograd_simple(model, bnd2, order=1, axis=1)
    bc3 = model(bnd3_left) - model(bnd3_right)
    bc4 = nn_autograd_simple(model, bnd4_left, order=1, axis=0) - nn_autograd_simple(model, bnd4_right, order=1, axis=0)
    
    loss_bc1 = torch.mean(torch.square(bc1.reshape(-1) - bndval1))
    loss_bc2 = torch.mean(torch.square(bc2.reshape(-1) - bndval2))
    loss_bc3 = torch.mean(torch.square(bc3))
    loss_bc4 = torch.mean(torch.square(bc4))
    
    loss = loss_bc1 + loss_bc2 + loss_bc3 + loss_bc4
    return loss



In [44]:
operator = wave_op(model, grid)
loss = op_loss(operator) + 1000 * bcs_loss(model)

In [45]:
mu = 0.001
ksi = torch.randn([100,100])

In [15]:
params = list(model.parameters())

In [17]:
params[0]

Parameter containing:
tensor([[ 0.6901, -0.6154],
        [ 0.5805, -0.0923],
        [ 0.4770,  0.0150],
        [-0.4324, -0.2337],
        [ 0.6982, -0.6722],
        [-0.4653,  0.1984],
        [-0.0904,  0.2589],
        [ 0.0382, -0.1752],
        [-0.6422,  0.1519],
        [-0.0023,  0.3478],
        [ 0.1821, -0.5024],
        [ 0.1504,  0.5313],
        [-0.2626,  0.1752],
        [-0.0616,  0.3248],
        [ 0.4898, -0.2440],
        [-0.1539,  0.1076],
        [ 0.5015, -0.3953],
        [-0.1237, -0.1950],
        [ 0.2373, -0.4534],
        [-0.3294,  0.5058],
        [ 0.6162,  0.5603],
        [ 0.5616,  0.2299],
        [-0.2720, -0.4119],
        [-0.6649, -0.3691],
        [-0.5913, -0.5012],
        [ 0.0427, -0.3164],
        [ 0.5699, -0.2680],
        [ 0.2929,  0.3336],
        [-0.5287,  0.0719],
        [-0.4514, -0.1119],
        [ 0.3975,  0.6816],
        [ 0.6841,  0.4176],
        [ 0.1728, -0.6546],
        [-0.6136, -0.6792],
        [-0.3277,  0.3069]

In [21]:
params[0] + mu * ksi[0].reshape(-1,1)

tensor([[ 0.6902, -0.6153],
        [ 0.5823, -0.0905],
        [ 0.4788,  0.0168],
        [-0.4334, -0.2347],
        [ 0.6977, -0.6727],
        [-0.4654,  0.1983],
        [-0.0887,  0.2606],
        [ 0.0384, -0.1749],
        [-0.6416,  0.1525],
        [-0.0023,  0.3478],
        [ 0.1817, -0.5028],
        [ 0.1492,  0.5301],
        [-0.2617,  0.1760],
        [-0.0609,  0.3255],
        [ 0.4896, -0.2442],
        [-0.1543,  0.1071],
        [ 0.5036, -0.3933],
        [-0.1240, -0.1953],
        [ 0.2381, -0.4527],
        [-0.3301,  0.5051],
        [ 0.6167,  0.5609],
        [ 0.5614,  0.2297],
        [-0.2725, -0.4123],
        [-0.6639, -0.3681],
        [-0.5917, -0.5016],
        [ 0.0420, -0.3171],
        [ 0.5705, -0.2675],
        [ 0.2928,  0.3335],
        [-0.5283,  0.0724],
        [-0.4520, -0.1125],
        [ 0.3968,  0.6809],
        [ 0.6839,  0.4174],
        [ 0.1730, -0.6543],
        [-0.6144, -0.6800],
        [-0.3279,  0.3068],
        [ 0.5209,  0

In [23]:
normal = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0]))

In [89]:
params[3].shape

torch.Size([100])

In [85]:
torch.squeeze(normal.sample(params[1].shape))

tensor([ 0.3107, -0.2146,  0.2425, -0.5914,  0.3832,  0.0153, -0.7345, -0.4751,
         0.3524,  0.5107, -0.6375, -1.2887,  0.8339,  0.6045,  2.0861, -1.2221,
         0.2102,  2.2010, -1.2213, -1.5470,  1.3446, -1.4734, -0.4475, -0.2212,
         0.4054, -1.4152,  0.3169,  0.0759, -0.5265,  1.4931,  0.7876,  0.8344,
         1.0809, -0.5209, -0.0421, -1.4536,  1.0178,  0.9649,  0.9055,  0.3081,
         1.3881,  1.1719, -0.7269,  1.1892, -1.1570,  0.2068, -1.0432,  0.0322,
        -0.2972,  0.8783,  0.0659,  0.4017, -1.4554,  0.7266,  0.3174,  0.7007,
        -0.8411,  2.3655,  0.8974,  0.7079,  0.9068, -0.0161, -1.3229,  0.5312,
        -1.4255,  2.0900, -2.1557,  1.0151,  1.2546, -1.2771,  2.3064,  0.9315,
         0.8850, -1.4041,  2.4246, -0.7522, -0.6582, -1.5307,  1.6193, -0.7670,
        -0.4263,  1.2909,  1.1901,  1.2401,  0.2069,  0.3133, -0.0203,  0.3917,
         0.8059, -1.5660, -0.0997, -0.4297,  2.0981, -0.0628,  0.3134,  1.6776,
        -0.5759,  0.2034,  2.5837, -0.88

In [75]:
params[0] + mu * normal.sample(torch.tensor([100]))

tensor([[ 0.6910, -0.6145],
        [ 0.5804, -0.0924],
        [ 0.4775,  0.0155],
        [-0.4315, -0.2328],
        [ 0.6968, -0.6735],
        [-0.4671,  0.1966],
        [-0.0892,  0.2601],
        [ 0.0382, -0.1751],
        [-0.6422,  0.1519],
        [-0.0023,  0.3479],
        [ 0.1829, -0.5016],
        [ 0.1521,  0.5330],
        [-0.2630,  0.1747],
        [-0.0604,  0.3260],
        [ 0.4888, -0.2450],
        [-0.1538,  0.1077],
        [ 0.5030, -0.3939],
        [-0.1231, -0.1944],
        [ 0.2373, -0.4535],
        [-0.3318,  0.5034],
        [ 0.6175,  0.5616],
        [ 0.5638,  0.2320],
        [-0.2718, -0.4116],
        [-0.6641, -0.3682],
        [-0.5905, -0.5005],
        [ 0.0423, -0.3168],
        [ 0.5692, -0.2688],
        [ 0.2926,  0.3333],
        [-0.5289,  0.0717],
        [-0.4483, -0.1089],
        [ 0.3979,  0.6819],
        [ 0.6839,  0.4173],
        [ 0.1721, -0.6552],
        [-0.6146, -0.6802],
        [-0.3289,  0.3058],
        [ 0.5187,  0

In [28]:
from pso import PSO

In [None]:
pso = PSO(model,
        
        grid=grid,
        n_iter=300,
        pop_size=60,
        gd_alpha=1e-3,
        verbose=True)

In [33]:
len(torch.nn.utils.parameters_to_vector(model.parameters()))

20601

In [36]:
tttt = torch.Tensor([1,234,4235,float('nan')])

In [57]:
model

Sequential(
  (0): Linear(in_features=2, out_features=100, bias=True)
  (1): Tanh()
  (2): TTLinear(
    (TTLayer): inp_modes=[5, 2, 5, 2], out_modes=[5, 2, 5, 2], mat_ranks=[1, 3, 3, 3, 1]
    (W_cores): ParameterList(
        (0): Parameter containing: [torch.float32 of size 1x5x5x3]
        (1): Parameter containing: [torch.float32 of size 3x2x2x3]
        (2): Parameter containing: [torch.float32 of size 3x5x5x3]
        (3): Parameter containing: [torch.float32 of size 3x2x2x1]
    )
  )
  (3): Tanh()
  (4): TTLinear(
    (TTLayer): inp_modes=[5, 2, 5, 2], out_modes=[5, 2, 5, 2], mat_ranks=[1, 3, 3, 3, 1]
    (W_cores): ParameterList(
        (0): Parameter containing: [torch.float32 of size 1x5x5x3]
        (1): Parameter containing: [torch.float32 of size 3x2x2x3]
        (2): Parameter containing: [torch.float32 of size 3x5x5x3]
        (3): Parameter containing: [torch.float32 of size 3x2x2x1]
    )
  )
  (5): Tanh()
  (6): Linear(in_features=100, out_features=1, bias=True)
)

In [46]:
grads = torch.autograd.grad(loss, model.parameters(), allow_unused=True)

In [71]:
list(model.parameters())

[Parameter containing:
 tensor([[ 0.5406,  0.5869],
         [-0.1657,  0.6496],
         [-0.1549,  0.1427],
         [-0.3443,  0.4153],
         [ 0.6233, -0.5188],
         [ 0.6146,  0.1323],
         [ 0.5224,  0.0958],
         [ 0.3410, -0.0998],
         [ 0.5451,  0.1045],
         [-0.3301,  0.1802],
         [-0.3258, -0.0829],
         [-0.2872,  0.4691],
         [-0.5582, -0.3260],
         [-0.1997, -0.4252],
         [ 0.0667, -0.6984],
         [ 0.6386, -0.6007],
         [ 0.5459,  0.1177],
         [-0.2296,  0.4370],
         [ 0.1102,  0.5713],
         [ 0.0773, -0.2230],
         [ 0.1900, -0.1918],
         [ 0.2976,  0.6313],
         [ 0.4087, -0.3091],
         [ 0.4082,  0.1265],
         [ 0.3591, -0.4310],
         [-0.7000, -0.2732],
         [-0.5424,  0.5802],
         [ 0.2037,  0.2929],
         [ 0.2236, -0.0123],
         [ 0.5534, -0.5024],
         [ 0.0445, -0.4826],
         [ 0.2180, -0.2435],
         [ 0.2167, -0.1473],
         [ 0.5865, -

In [60]:
len(grads)

14

In [66]:
grads[3].shape

torch.Size([1, 5, 5, 3])

In [72]:
for g in grads:
    if g is None:
        print(type(g))

<class 'NoneType'>
<class 'NoneType'>


In [73]:
tuple(map(lambda x: torch.zeros(1) if x is None else x, grads))

(tensor([[-7.1767e+01, -1.1139e+01],
         [-1.0257e+02, -1.5775e+01],
         [-3.0692e+01, -5.1005e+00],
         [-8.5005e+01, -1.4283e+01],
         [-1.5660e+01, -2.4237e+00],
         [-7.5693e+01, -1.2532e+01],
         [-4.6440e+01, -7.9834e+00],
         [-8.5994e+01, -1.4703e+01],
         [-4.2906e+01, -7.6244e+00],
         [-1.1833e+02, -2.0372e+01],
         [-1.0558e+01, -1.8908e+00],
         [ 1.7222e+01,  2.5819e+00],
         [-6.8607e+00, -1.1505e+00],
         [-3.4274e+01, -5.8352e+00],
         [-1.0534e+01, -1.8209e+00],
         [ 1.3952e+01,  1.6590e+00],
         [ 1.3210e+01,  2.2358e+00],
         [ 3.2653e+01,  5.4454e+00],
         [ 1.2030e+01,  1.8721e+00],
         [ 2.5947e+01,  3.9356e+00],
         [-4.2787e+00, -7.3888e-01],
         [-6.0176e-01,  4.8034e-02],
         [ 1.5581e+01,  2.9748e+00],
         [ 4.4638e+01,  8.7705e+00],
         [-1.3349e+01, -2.5758e+00],
         [-6.4960e+01, -1.2111e+01],
         [ 2.6028e+01,  4.5505e+00],
 

In [74]:
sum(p.numel() for p in model.parameters())

1099