In [2]:
%load_ext lab_black

In [3]:
%matplotlib inline
import torch
import torch.nn as nn
import numba
import numpy as np
import matplotlib.pyplot as plt

In [4]:
device = "cuda:0"

In [48]:
spec = [
    ("solved", numba.types.boolean),
    ("g", numba.types.float64),
    ("mu", numba.types.float64),
    ("__v_no_drag", numba.types.float64[:]),
    ("__v", numba.types.float64[:]),
    ("__y_no_drag", numba.types.float64[:]),
    ("__y", numba.types.float64[:]),
    ("__t", numba.types.float64[:]),
]


@numba.jitclass(spec)
class ODEData:  # torch.utils.data.Dataset):
    def __init__(self, g=9.81, mu=0):
        self.solved = False
        self.g = g
        self.mu = mu

    def __getitem__(self, i):
        if not self.solved:
            raise RuntimeError("Please run solve() first!")
        return self.y_no_drag[i], self.y[i]

    def solve(self, v0, y0, t, dt):
        N = int(t // dt + 1)
        self.__v_no_drag = np.zeros((N, len(v0)))
        self.__y_no_drag = np.zeros_like(self.__v_no_drag)
        self.__v = np.zeros_like(self.__v_no_drag)
        self.__y = np.zeros_like(self.__v_no_drag)
        self.__v_no_drag[0] = v0
        self.__v[0] = v0
        self.__y_no_drag[0] = y0
        self.__y[0] = y0
        self.__t = np.linspace(0, t, N)
        for i in range(N - 1):
            a_pull = -self.g * self.__y[i]
            a_drag = self.mu * self.__v[i]
            a_pull_no_drag = -self.g * self.__y_no_drag[i]
            step = a_pull_no_drag * dt
            self.__v_no_drag[i + 1] = self.__v_no_drag[i] + step * (step < 0)
            step = (a_pull - a_drag) * dt
            self.__v[i + 1] = self.__v[i] + step * (step < 0)
            step = self.__y_no_drag[i] + self.__v_no_drag[i + 1] * dt
            self.__y_no_drag[i + 1] = step * (step > 0)
            step = self.__y[i] + self.__v[i + 1] * dt
            self.__y[i + 1] = step * (step > 0)
        self.solved = True

    @property
    def y(self):
        return self.__y

    @property
    def y_no_drag(self):
        return self.__y_no_drag

    @property
    def v(self):
        return self.__v

    @property
    def v_no_drag(self):
        return self.__v_no_drag

    @property
    def t(self):
        return self.__t

    @property
    def size(self):
        return len(self.__t)

In [49]:
data = ODEData(mu=1)
data.solve(np.array([1, 0, 3]), np.array([0.5, 1, 0]), 1, 0.01)
plt.plot(data.t, data.y[:, 0], label="Drag")
plt.plot(data.t, data.y_no_drag[:, 0], label="No drag")
batch_size = data.y.shape[1]
plt.legend()
plt.show()

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of Function(<built-in function setitem>) with argument(s) of type(s): (array(float64, 1d, A), Literal[int](0), array(int64, 1d, C))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
In definition 6:
    All templates rejected with literals.
In definition 7:
    All templates rejected without literals.
In definition 8:
    All templates rejected with literals.
In definition 9:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of staticsetitem at <ipython-input-48-d4f621253348> (31)

File "<ipython-input-48-d4f621253348>", line 31:
    def solve(self, v0, y0, t, dt):
        <source elided>
        self.__y = np.zeros_like(self.__v_no_drag)
        self.__v_no_drag[0] = v0
        ^

[1] During: resolving callee type: BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'solve') for instance.jitclass.ODEData#7f168bc7ee20<solved:bool,g:float64,mu:float64,_ODEData__v_no_drag:array(float64, 1d, A),_ODEData__v:array(float64, 1d, A),_ODEData__y_no_drag:array(float64, 1d, A),_ODEData__y:array(float64, 1d, A),_ODEData__t:array(float64, 1d, A)>)
[2] During: typing of call at <string> (3)

[3] During: resolving callee type: BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'solve') for instance.jitclass.ODEData#7f168bc7ee20<solved:bool,g:float64,mu:float64,_ODEData__v_no_drag:array(float64, 1d, A),_ODEData__v:array(float64, 1d, A),_ODEData__y_no_drag:array(float64, 1d, A),_ODEData__y:array(float64, 1d, A),_ODEData__t:array(float64, 1d, A)>)
[4] During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>


In [7]:
class HybridLSTM(nn.Module):
    def __init__(self, nodes=10, layers=1, dropout=0):
        super(HybridLSTM, self).__init__()
        self.lstm = nn.LSTM(
            input_size=1, hidden_size=nodes, num_layers=layers, dropout=dropout
        )
        self.output_layer = nn.Linear(nodes, 1)

    def forward(self, y):
        output, (h_n, c_n) = self.lstm(y)
        return self.output_layer(h_n[-1])

In [12]:
def train(net, data, lr=0.1, epochs=10):
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
    loss_func = nn.MSELoss()
    net.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        y_no_drag = torch.zeros([data.size, 1, 1]).to(device)
        y = torch.zeros_like(y_no_drag)
        net_correction = torch.zeros_like(y_no_drag)
        # loss = 0
        for i in range(data.size):
            tmp = torch.zeros_like(y)
            tmp2 = torch.zeros_like(y)
            tmp[i], tmp2[i] = data[i]
            y_no_drag = y_no_drag + tmp
            y = y + tmp2
            tmp = torch.zeros_like(y)
            tmp[i] = net.forward(y_no_drag[: i + 1])
            net_correction = net_correction + tmp
        loss = loss_func(net_correction + y_no_drag, y)
        print(f"Epoch {epoch}: {loss.item()}")
        loss.backward()
        optimizer.step()

In [13]:
net = HybridLSTM().to(device)
train(net, data)

Epoch 0: 0.004300587810575962
Epoch 1: 0.002545009832829237
Epoch 2: 0.0016466024098917842
Epoch 3: 0.0011856707278639078
Epoch 4: 0.0009482927271164954
Epoch 5: 0.0008252396364696324
Epoch 6: 0.0007606817525811493
Epoch 7: 0.0007260639104060829
Epoch 8: 0.0007067751721478999
Epoch 9: 0.0006953388219699264
