In [0]:
import matplotlib.pyplot as plt
import autograd.numpy as np
import autograd

from sklearn.model_selection import train_test_split
from scipy.integrate import solve_ivp

# Считаем sin(0.5) через дифур

In [0]:
def f(t, y):
    return [y[1], -y[0]]

In [0]:
solution = solve_ivp(f, [0, 0.5], [0, 1])

In [4]:
solution.y[:,-1]

array([0.4794257 , 0.87758156])

# Классы для моделей

In [0]:
class ResidualF:
    def __init__(self, x_dim, param):
        self.x_dim = x_dim
        self.param_dim = param
    
    def f(self, x, t, param):
        pass
        
    def grad_x(self, t, param):
        def cur_f(x):
            return self.f(x, t, param)
        return autograd.elementwise_grad(cur_f)
    
    def grad_param(self, x, t):
        def cur_f(param):
            return self.f(x, t, param)
        return autograd.jacobian(cur_f)
        

In [0]:
class LinearF(ResidualF):
    def __init__(self, x_dim):
        param_dim = x_dim * (x_dim + 1)
        super().__init__(x_dim, param_dim)
        
    def f(self, x, t, param):
        assert(len(x) == self.x_dim)
        assert(len(param) == self.param_dim)
        
        param = param.reshape(self.x_dim, self.x_dim + 1)
        x_aug = np.concatenate([x, [1]])
        return np.dot(param, x_aug)

In [7]:
linear = LinearF(2)
x = np.asarray([1, 2])
t = 3
param = np.concatenate([[1, 10, 100], [100, 10, 1]])

linear.f(x, t, param)

array([121, 121])

In [8]:
linear.grad_x(t, param)(x)

array([101,  20])

In [9]:
linear.grad_param(x, t)(param)

array([[1, 2, 1, 0, 0, 0],
       [0, 0, 0, 1, 2, 1]])

# Оптимзасция с помощью SGD

In [0]:
def forward(f, param, x):
    def d_y(t, y):
        return f.f(y, t, param)
    
    return solve_ivp(d_y, [0, 1], x).y[:, -1]
    
def backward(f, param, x_start, x_pred, d_loss):
    def pack(x, a, d_param):
        assert len(x) == f.x_dim
        assert len(a) == f.x_dim
        assert len(d_param) == f.param_dim
        return np.concatenate([x, a, d_param])
    
    def unpack(y):
        return y[:f.x_dim], y[f.x_dim : f.x_dim * 2], y[-f.param_dim:]
    
    def d_y(t, y):
        x, a_x, a_param = unpack(y)
        return pack(f.f(x, t, param),
                - a_x * f.grad_x(t, param)(x),
                - np.dot(a_x, f.grad_param(x, t)(param))
               )
    
    y_1 = pack(x_pred, d_loss, np.zeros(f.param_dim))
    solution = solve_ivp(d_y, [1, 0], y_1)
    
    x, a, d_param = unpack(solution.y[:,-1])
    return d_param

def train(f, param, X, X_expected, iters = 50):
    lr = 0.1
    for i in range(iters):
        losses = []
        grads = []
        
        for x, x_expected in zip(X, X_expected):
            x_pred = forward(f, param, x)
            d_loss = (x_pred - x_expected)
            losses.append(0.5 * d_loss ** 2)
            d_param = backward(f, param, x, x_pred, d_loss)
            grads.append(d_param)
            
        print("iter " + str(i))
        print("average loss: " + str(np.average(losses)))
        print()
        ave_grad = np.average(grads, axis=0)
        param -= ave_grad * lr 

# Обучаем на датасете

In [0]:
dim = 2

linear = LinearF(dim)
param = np.random.randn(linear.param_dim)

X = [np.random.randn(dim) for _ in range(100)]
X_expected = np.asarray([2 * x + np.random.randn(dim) * 0.1 + 1 for x in X])

In [12]:
train(linear, param, X, X_expected)

iter 0
average loss: 3.0240618703767255

iter 1
average loss: 2.1955968688945653

iter 2
average loss: 1.5320615228292902

iter 3
average loss: 1.0237704228292706

iter 4
average loss: 0.6929184793472646

iter 5
average loss: 0.4929510048101811

iter 6
average loss: 0.3576535602281513

iter 7
average loss: 0.25799926739524465

iter 8
average loss: 0.1834563732665244

iter 9
average loss: 0.12884215207480826

iter 10
average loss: 0.09010921517391834

iter 11
average loss: 0.06338363027191533

iter 12
average loss: 0.0452119146894416

iter 13
average loss: 0.03287967128201629

iter 14
average loss: 0.024453958889918536

iter 15
average loss: 0.01863556722928673

iter 16
average loss: 0.014571702760054529

iter 17
average loss: 0.011703421926321643

iter 18
average loss: 0.009660639891419472

iter 19
average loss: 0.008194724188190181

iter 20
average loss: 0.007136129509751907

iter 21
average loss: 0.006367659521135751

iter 22
average loss: 0.005807349588255755

iter 23
average loss: 