In [None]:
import torch
class a_opt(torch.optim.lr_scheduler.LRScheduler):
    """Bare LRScheduler subclass used only to probe ``isinstance`` behaviour.

    The base-class constructor is deliberately not invoked, so instances carry
    no scheduler state at all.
    """

    def __init__(self, x):
        """Accept one positional argument ``x`` and ignore it."""
        return None
        
# Wrap an AdamW optimizer in the stub scheduler and display whether the
# result still counts as an LRScheduler instance.
model = torch.nn.LSTM(input_size=100, hidden_size=100)
c = a_opt(torch.optim.AdamW(params=model.parameters(), lr=0.1))
isinstance(c, torch.optim.lr_scheduler.LRScheduler)

In [None]:
import torch
class outProd(torch.autograd.Function):
    """Outer product of two vectors with a hand-written backward pass.

    ``branch`` holds ``b`` values and ``trunk`` holds ``p`` values; each may be
    a column vector (``(b, 1)`` / ``(p, 1)``) or a 1-D tensor.  The result is
    the ``(b, p)`` matrix ``out[i, j] = branch[i] * trunk[j]``.

    The previous implementation expanded ``trunk`` of shape ``(p, 1)`` to
    ``(b, p)``, which raises unless ``b == p`` and otherwise yields
    ``branch[i] * trunk[i]`` rather than an outer product.
    """

    @staticmethod
    def forward(ctx, branch, trunk):
        """Return the ``(b, p)`` outer product of ``branch`` and ``trunk``."""
        ctx.save_for_backward(branch, trunk)
        # Column vector times row vector broadcasts to the full (b, p) matrix.
        return branch.reshape(-1, 1) * trunk.reshape(1, -1)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients shaped exactly like ``branch`` and ``trunk``.

        ``grad_output`` has shape ``(b, p)``.  Each input element contributes
        to a full row (branch) or column (trunk) of the output, so the
        incoming gradient is summed over the other axis.
        """
        branch, trunk = ctx.saved_tensors
        branch_grad = trunk_grad = None
        # Skip work for inputs that do not require a gradient.
        if ctx.needs_input_grad[0]:
            branch_grad = (grad_output * trunk.reshape(1, -1)).sum(dim=1).reshape(branch.shape)
        if ctx.needs_input_grad[1]:
            trunk_grad = (grad_output * branch.reshape(-1, 1)).sum(dim=0).reshape(trunk.shape)
        return branch_grad, trunk_grad


# Run the custom Function on two 10-element column vectors and display the
# gradient of the result with respect to `y`.
x = torch.ones(10, 1)
y = torch.linspace(0, 1, 10).unsqueeze(-1).requires_grad_()
a = outProd.apply(x, y)
print(a)
torch.autograd.grad(
    a,
    y,
    grad_outputs=torch.ones_like(a),
    retain_graph=True,
)

In [None]:
# Reference outer product via torch.outer.
# torch.outer only accepts 1-D operands, so the (10, 1) and (1, 10) tensors
# are flattened at the call site; `x` and `y` keep their 2-D shapes and `y`
# stays the leaf tensor that gradients flow back to.
x = torch.ones(10, 1)
y = torch.linspace(0, 1, 10)[None,...].requires_grad_()
a = torch.outer(x.flatten(), y.flatten())

In [None]:
# Broadcasting check: a (10, 10) matrix of ones times a (1, 10) row vector
# repeats the row on every line of the printed result.
a = torch.ones((10, 10))
b = torch.linspace(0, 1, 10).unsqueeze(0)
print(a * b)

In [None]:
import torch.nn as nn
class originDeepONetCard(nn.Module):
    """Reference DeepONet head: einsum contraction of branch and trunk features.

    Both sub-networks are single ``Linear(1, 10)`` layers.
    """

    def __init__(self):
        super().__init__()
        self.branch = nn.Linear(in_features=1, out_features=10)
        self.trunk = nn.Linear(in_features=1, out_features=10)

    def forward(self, x):
        """Combine ``x[0]`` (branch input) and ``x[1]`` (trunk input).

        Returns a matrix with one row per branch sample and one column per
        trunk sample, contracting over the shared feature axis.
        """
        branch_feat, trunk_feat = self.branch(x[0]), self.trunk(x[1])
        return torch.einsum("b i, p i -> b p", branch_feat, trunk_feat)
    
class DeepONetCard(nn.Module):
    """DeepONet head that forms the Cartesian product by explicit broadcasting.

    Computes the same quantity as an einsum over the feature axis, but via an
    element-wise product followed by a sum.

    Args:
        width: Number of features produced by both the branch and the trunk
            layer.  Defaults to 10, the value that used to be hard-coded in
            both ``__init__`` and ``forward``.
    """

    def __init__(self, width=10):
        super().__init__()
        self.branch = nn.Linear(1, width)
        self.trunk = nn.Linear(1, width)

    def forward(self, x):
        """Combine ``x[0]`` of shape ``(b, 1)`` with ``x[1]`` of shape ``(p, 1)``.

        Returns:
            Tensor of shape ``(b, p)`` where entry ``[i, j]`` is the dot
            product of branch feature row ``i`` and trunk feature row ``j``.
        """
        branch = self.branch(x[0])  # (b, width)
        trunk = self.trunk(x[1])    # (p, width)
        # (b, 1, width) * (1, p, width) broadcasts to (b, p, width); no explicit
        # expand or hard-coded feature width is needed.
        return (branch.unsqueeze(1) * trunk.unsqueeze(0)).sum(-1)
    
# Build both variants, feed them the same (branch, trunk) input pair and print
# the gradient of each output with respect to the trunk input `b`.
netA, netB = originDeepONetCard(), DeepONetCard()
a = torch.ones(10, 1)
b = torch.linspace(0, 1, 10).unsqueeze(-1).requires_grad_()
x = (a, b)
A, B = netA(x), netB(x)
print(torch.autograd.grad(A, b, grad_outputs=torch.ones_like(A), retain_graph=True)[0])
print(torch.autograd.grad(B, b, grad_outputs=torch.ones_like(B), retain_graph=True)[0])

In [None]:
import torch
from torch import nn
import time
class net(torch.nn.Module):
    """Minimal model: a single linear layer mapping 10 features to 1 output."""

    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(in_features=10, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        prediction = self.a(x)
        return prediction

# Probe: batched vector-Jacobian products through a single Linear layer.
y = net()
a = torch.randn(1, 10, requires_grad=True)
# Rebinds `a` to a (1000, 10) expanded view of the leaf created on the previous
# line; from here on the leaf itself is no longer reachable by name.
a = a.expand(1000, 10)
# NOTE(review): `b` is not used anywhere else in this cell.
b = a.mean(0, keepdim=True)
out = y(a)
print(out.shape)
# Shape (1000, 1000, 1). With is_grads_batched=True the leading dimension of
# grad_outputs indexes independent backward passes, so slice k is a one-hot
# selector for row k of `out`.
grad_batched = torch.eye(1000)[...,None]
grad = torch.autograd.grad(out, a, grad_outputs= grad_batched,retain_graph=True, is_grads_batched=True)[0]
print(grad)
# Print the layer's parameters alongside the gradients printed above.
for i in y.parameters():
    print(i)

In [None]:
import torch
from torch import nn
import time
class DeepOnet(nn.Module):
    """Small DeepONet: linear branch net, two-layer linear trunk net."""

    def __init__(self):
        super().__init__()
        self.branch = nn.Linear(10, 10)
        self.trunk = nn.Sequential(
            nn.Linear(1, 100),
            nn.Linear(100, 10),
        )

    def forward(self, b, t, mode = "a"):
        """Return the branch/trunk contraction, one row per branch sample.

        With ``mode == "a"`` the trunk input ``t`` is used as given.  For any
        other ``mode`` it is first averaged over dimension 0 (the original
        inline notes describe the averaged tensor as ``(p, 1)``).
        """
        if mode != "a":
            t = t.mean(0)
        branch_out = self.branch(b)
        trunk_out = self.trunk(t)
        return torch.einsum("b i, p i -> b p", branch_out, trunk_out)

# Number of branch samples used by the gradient benchmarks in the next cells.
batch = 600
# The model and both inputs are created directly on the GPU, so this cell
# requires a CUDA device.
y = DeepOnet().cuda()
trunk_inp = torch.randn(1000, 1, requires_grad=True, device="cuda")
branch_inp = torch.randn(batch, 10, requires_grad=True, device="cuda")
# mode="a" passes the trunk input through unchanged; `out` has one row per
# branch sample and one column per trunk point.
out = y(branch_inp, trunk_inp, mode="a")

# 1000 x 1000 x 1; b, p, 1
# NOTE(review): the note above presumably describes a batched (b, p, 1) trunk
# input as consumed by the non-"a" mode — confirm.

In [None]:
# Baseline: one autograd.grad call per row of `out`, stacked afterwards.
torch.cuda.empty_cache()
# empty_cache() does not reset the peak counter; without this reset
# max_memory_allocated() would report the peak of every cell run so far.
torch.cuda.reset_peak_memory_stats()
# CUDA kernels run asynchronously: drain pending work before starting the
# clock and again before reading it, otherwise only launch time is measured.
torch.cuda.synchronize()
t = time.time()
grad2 = []
for i in out:
    grad2.append(torch.autograd.grad(i, trunk_inp, grad_outputs=torch.ones_like(i), retain_graph=True, create_graph=True)[0])
grad2 = torch.stack(grad2)
torch.cuda.synchronize()
print(torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
print(grad2.shape)
print(time.time() - t)

In [None]:
# Same gradients as the per-row loop, obtained with a single batched
# autograd.grad call (is_grads_batched=True).
torch.cuda.empty_cache()
# Reset the peak counter so max_memory_allocated() reflects this cell only.
torch.cuda.reset_peak_memory_stats()
# Drain pending GPU work so the timer covers exactly the work below.
torch.cuda.synchronize()
t = time.time()
# Note: unlike the per-row cell, the timed region here includes the forward pass.
out = y(branch_inp, trunk_inp, mode="a")

# One one-hot selector per output row. Sizes and device are taken from `out`
# instead of repeating `batch`, 1000 and "cuda", which silently go stale when
# the setup cell changes or cells are run out of order.
grad_batched = torch.eye(out.shape[0], device=out.device)[..., None].expand(-1, -1, out.shape[1])
print(grad_batched.shape, out.shape)
grad3 = torch.autograd.grad(out, trunk_inp, grad_outputs=grad_batched, retain_graph=True, is_grads_batched=True, create_graph=True)[0]
torch.cuda.synchronize()
print(torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
print(grad3.shape)

print(time.time() - t)

In [None]:
import torch
import deepxde.deepxde as dde
import time

# DeepXDE Cartesian-product DeepONet: branch layer sizes [101, 100, 100],
# trunk layer sizes [2, 100, 100], GELU activation, Glorot-normal initializer.
# Note: this rebinds `net`, which an earlier cell used as a class name.
net = dde.nn.pytorch.DeepONetCartesianProd([101, 100, 100], [2, 100, 100], "gelu", "Glorot normal")

# Rebinds `batch` (an earlier cell set it to 600).
batch = 50
branch = torch.randn(batch, 101)
# requires_grad_() so later cells can differentiate the output w.r.t. the
# trunk coordinates.
trunk = torch.randn(10000, 2).requires_grad_()

# Forward pass on the (branch, trunk) pair; later cells differentiate `result`.
result = net((branch, trunk))

In [None]:
# Baseline for the DeepXDE network: one autograd.grad call per row of `result`.
# CUDA calls are guarded so the cell still runs on a CPU-only machine.
if torch.cuda.is_available():
    # Report the peak of this cell only, and start timing from an idle GPU.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
t = time.time()
grad2 = []
for i in result:
    grad2.append(torch.autograd.grad(i, trunk, grad_outputs=torch.ones_like(i), retain_graph=True, create_graph=True)[0])
grad2 = torch.stack(grad2)
if torch.cuda.is_available():
    # Wait for asynchronous kernels so the elapsed time is meaningful.
    torch.cuda.synchronize()
print(torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
print(grad2.shape)
print(time.time() - t)

In [None]:
# Batched counterpart of the per-row loop for the DeepXDE network.
if torch.cuda.is_available():
    # Report the peak of this cell only, and start timing from an idle GPU.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
t = time.time()
# One one-hot selector per row of `result`. Device and sizes come from
# `result` itself rather than the literals "cuda", `batch` and 10000, so the
# selector always matches the tensor being differentiated.
grad_batched = torch.eye(result.shape[0], device=result.device)[..., None].expand(-1, -1, result.shape[1])
grad3 = torch.autograd.grad(result, trunk, grad_outputs=grad_batched, retain_graph=True, is_grads_batched=True, create_graph=True)[0]
if torch.cuda.is_available():
    # Wait for asynchronous kernels so the elapsed time is meaningful.
    torch.cuda.synchronize()
print(torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated())
print(grad3.shape)

print(time.time() - t)

In [None]:
# Count entries where the looped and batched gradients disagree by more than
# 1e-6. The absolute value is required: a signed difference never flags
# entries where grad3 is the larger of the two.
print(((grad2 - grad3).abs() > 1e-6).sum())

In [None]:
import deepxde.deepxde as dde
import numpy as np
# Draw one sample from a Gaussian random field and evaluate it on 101 points in [0, 1].
f = dde.data.GRF(length_scale=0.1)
fea = f.random(1)
func = f.eval_batch(fea, np.linspace(0, 1, 101)[:, None])
# NOTE(review): the GRF sample is overwritten immediately by a constant
# all-ones input of shape (1, 101) — presumably a debugging override; confirm.
func = np.ones((1, 101))
from datasets.solver import advection_solver
# Project-local solver; returns the evaluation coordinates `xt` and the solution `u`.
xt, u = advection_solver(func)

import matplotlib.pyplot as plt
fig , (ax1, ax2) = plt.subplots(1, 2)
# Left panel: the input function on the 101-point grid.
ax1.scatter(np.linspace(0, 1, 101), func[0])
print(u[0])
# Flatten coordinates to (N, 2) and the solution to (N, 1) so each point can be
# coloured by its solution value in the right panel.
xt2 = xt.reshape(-1, 2)
u2 = u.reshape(-1, 1)
ax2.scatter(xt2[:, 0], xt2[:, 1], c = u2)


In [None]:
# Display the first row of u's transpose (for a 2-D `u`, its first column).
u.T[0]

In [None]:
import torch

# Load a saved checkpoint and list its top-level keys.
# NOTE(review): torch.load unpickles the file; if the checkpoint is a plain
# state dict, consider passing weights_only=True (and map_location) — confirm
# the file contents first.
stat = torch.load("./results/adr_H1_norm_GRF_norm_19.pth")
print(stat.keys())

In [None]:
# Rebinds `stat` to a different checkpoint; later cells index it by parameter
# name and pass it to load_state_dict.
# NOTE(review): torch.load unpickles the file; consider weights_only=True if
# this is a plain state dict — confirm.
stat = torch.load("./results/adr_pial_300000.pth")
print(stat.keys())

In [None]:
# Show the weight shapes of the first trunk and first branch layer, one per line.
print(
    *(stat[name].shape for name in ('trunk.linears.0.weight', 'branch.linears.0.weight')),
    sep="\n",
)

In [None]:
import deepxde.deepxde as dde
import numpy as np
import torch
from datasets import solver

def periodic(x):
    """Map ``(x0, x1)`` rows to five features that are periodic in ``x0``.

    The output columns are ``cos(2*pi*x0)``, ``sin(2*pi*x0)``,
    ``cos(4*pi*x0)``, ``sin(4*pi*x0)`` and the untouched second column ``x1``.
    """
    first, second = x[:, 0], x[:, 1]
    columns = [
        torch.cos(first * 2 * np.pi),
        torch.sin(first * 2 * np.pi),
        torch.cos(2 * first * 2 * np.pi),
        torch.sin(2 * first * 2 * np.pi),
        second,
    ]
    # Turn every feature into a column vector before joining them side by side.
    return torch.cat([col.reshape(-1, 1) for col in columns], 1)

# Sample one input function from a Gaussian random field and evaluate it on a
# 100-point grid over [0, 1].
fsp = dde.data.GRF(length_scale = 0.05)
fea = fsp.random(1)
vx = fsp.eval_batch(fea, np.linspace(0, 1, 100))

# Branch layer sizes [100, 100, 100]; trunk layer sizes [5, 100, 100, 100].
# The trunk input width of 5 matches the five columns produced by `periodic`.
net = dde.nn.DeepONetCartesianProd(
    [100, 100, 100],
    [5, 100, 100, 100],
    "gelu",
    "Glorot normal",
)

net.apply_feature_transform(periodic)

# `stat` is the checkpoint loaded in an earlier cell.
# NOTE(review): strict=False suppresses errors for missing/unexpected keys and
# the returned key report is discarded here, so a mismatched checkpoint would
# load silently — confirm this is intended.
net.load_state_dict(stat, strict = False)
print(vx.shape)
# Reference solution from the project-local solver on a 100 x 100 grid.
xt, u = solver.diffusion_reaction_solver(vx[0], Nx = 100, Nt = 100)
print(xt.shape)
print(u.shape)



# NOTE(review): geom / timedomain / geomtime are only referenced by the
# commented-out uniform_points line below.
geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

# Flatten the solver's coordinates to a list of two-column rows.
xt_uniform = xt.reshape(-1, 2)
#xt_uniform = geomtime.uniform_points(10000)
# Network input: (branch input, trunk coordinates); coordinates are cast to float32.
inputs = (torch.as_tensor(vx), torch.as_tensor(xt_uniform).float())
# print(inputs[0], inputs[1])
up = net(inputs)

In [None]:
# Bring the prediction and the reference solution into flat arrays for plotting.
print(u.shape, up.shape)
# First entry of the network output, detached and copied to a NumPy array on the CPU.
u_p = up[0].detach().cpu().numpy()
# Reference solution flattened to 1-D.
# NOTE(review): assumes u's flattening order matches the row order of xt_uniform — confirm.
u_t = u.flatten()
print(u_p.shape, u_t.shape)

In [None]:
import matplotlib.pyplot as plt
# Four panels: input function, reference solution, network prediction, and the
# point-wise absolute error. The fourth axes used to be created and configured
# but never drawn on, and panels 2-4 had no titles.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(20, 5))
ax1.scatter(np.linspace(0, 1, 100), vx[0])
ax2.scatter(x = xt_uniform[:, 0], y = xt_uniform[:, 1], c = u_t)
ax3.scatter(x = xt_uniform[:, 0], y = xt_uniform[:, 1], c = u_p)
# u_t and u_p are both used as per-point colours for the same coordinates, so
# they have one value per point and can be subtracted element-wise.
ax4.scatter(x = xt_uniform[:, 0], y = xt_uniform[:, 1], c = np.abs(u_t - u_p))
ax1.set_title("vx")
ax2.set_title("reference")
ax3.set_title("prediction")
ax4.set_title("absolute error")
ax1.set_xlim(0, 1)
ax2.set_aspect("equal")
ax3.set_aspect("equal")
ax4.set_aspect("equal")
plt.tight_layout()
plt.show()


In [None]:
from utils.test_model import normONet
import torch
from torch import nn
# Smoke test for the project-local normONet with random inputs.
t = torch.randn(10000, 2)
b = torch.randn(100, 101)

net = normONet()

# NOTE(review): the tuple is passed as (t, b) here, whereas the DeepXDE
# networks elsewhere in this notebook receive (branch, trunk) — confirm the
# argument order normONet expects.
net((t, b))

In [None]:
import matplotlib.pyplot as plt
