In [1]:
%matplotlib widget

In [2]:
from matplotlib import pyplot as plot

In [3]:
import numpy

In [4]:
import torch
from torch import nn
from torch import optim
from torch import autograd

In [5]:
# Training inputs: N_TRAIN points drawn uniformly from [0, X_MAX).
# Seed numpy's global RNG first so this draw -- and every later numpy draw
# (shuffling, displacements, target noise) -- is reproducible on Restart & Run All.
SEED = 0
N_TRAIN = 100
X_MAX = 10.
numpy.random.seed(SEED)
x_train = numpy.random.rand(N_TRAIN, 1) * X_MAX

In [16]:
def true_func(x, noise=0.):
    """Ground-truth target: sin(x) plus optional i.i.d. Gaussian noise.

    Parameters
    ----------
    x : scalar or array-like
        Input(s). Converted with ``numpy.asarray`` so Python scalars and lists
        work as well as ndarrays.
    noise : float, default 0.
        Standard deviation of the additive N(0, noise**2) noise, drawn
        independently per element from numpy's global RNG.

    Returns
    -------
    numpy.ndarray (or numpy scalar) with the same shape as ``x``.
    """
    x = numpy.asarray(x)
    # A normal sample is drawn even when noise == 0 (it is then multiplied by 0),
    # so numpy's global RNG advances the same way regardless of `noise`.
    return numpy.sin(x) + noise * numpy.random.randn(*x.shape)

def true_inv(y):
    """Principal-branch inverse of the noiseless ground truth, i.e. arcsin(y).

    Results lie in [-pi/2, pi/2]. Inputs with |y| > 1 (which noisy
    ``true_func`` outputs can produce) yield NaN, with numpy's
    "invalid value" RuntimeWarning.
    """
    angle = numpy.arcsin(y)
    return angle

In [24]:
class EnergyNet(nn.Module):
    """Scalar-valued MLP E(x): Linear(n_in -> n_hid) -> Tanh -> Linear(n_hid -> 1).

    Parameters
    ----------
    n_in : int
        Number of input features (size of the last dimension of ``x``).
    n_hid : int, default 20
        Width of the single hidden layer; kept on the instance as ``self.n_hid``.
    """

    def __init__(self, n_in, n_hid=20):
        super().__init__()
        self.n_hid = n_hid
        layers = [
            nn.Linear(n_in, n_hid),
            nn.Tanh(),
            nn.Linear(n_hid, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return E(x): shape (..., 1) for ``x`` of shape (..., n_in)."""
        energy = self.net(x)
        return energy

In [25]:
# Train EnergyNet E so that its directional derivative matches noisy finite
# differences of the target: for a random input displacement dx we minimise
#     (E'(x) * dx - [true_func(x + dx) - true_func(x)])**2,
# where both true_func calls carry independent observation noise.
# E only enters the loss through its derivative, so it is learned up to an
# additive constant.
torch.manual_seed(0)  # reproducible weight initialisation

energy_net = EnergyNet(1, 50)
optimizer = optim.Adam(energy_net.parameters())

n_epochs = 500
batch_sz = 10
noise_std = 0.2  # std of the observation noise added to every true_func evaluation
disp_std = 0.2   # std of the random input displacement dx

# Exponential moving average of the batch loss; None until the first batch.
# (A -numpy.Inf sentinel would fail on NumPy >= 2.0, where numpy.Inf was removed.)
loss_run = None

for ei in range(n_epochs):
    # Shuffle into a new array so x_train itself is never mutated (cell stays re-runnable).
    x_shuffled = x_train[numpy.random.permutation(x_train.shape[0])]

    for start in range(0, x_shuffled.shape[0], batch_sz):
        optimizer.zero_grad()

        batch_x = x_shuffled[start:start + batch_sz]
        batch_y = true_func(batch_x, noise=noise_std)
        batch_disp_x = disp_std * numpy.random.randn(*batch_x.shape)
        batch_disp_y = true_func(batch_x + batch_disp_x, noise=noise_std)
        batch_delta = batch_disp_y - batch_y

        x_in = torch.from_numpy(batch_x.astype('float32')).requires_grad_(True)
        x_diff = torch.from_numpy(batch_disp_x.astype('float32'))
        target = torch.from_numpy(batch_delta.astype('float32'))

        energy = energy_net(x_in)
        # Double-backward trick for a Jacobian-vector product: the first grad,
        # vjp = J^T @ dummy, is linear in `dummy`, so differentiating <vjp, dx>
        # w.r.t. `dummy` yields J @ dx, i.e. the directional derivative E'(x) * dx.
        # Shapes follow the tensors themselves (energy output / x_in), so this
        # also works when n_in > 1.
        dummy = torch.ones_like(energy, requires_grad=True)
        (vjp,) = autograd.grad(energy, x_in, grad_outputs=dummy, create_graph=True)
        (jvp,) = autograd.grad(vjp, dummy, grad_outputs=x_diff, create_graph=True)

        batch_loss = ((jvp - target) ** 2).sum()
        batch_loss.backward()
        optimizer.step()

        if loss_run is None:
            loss_run = batch_loss.item()
        else:
            loss_run = 0.9 * loss_run + 0.1 * batch_loss.item()

print(F'Epoch {ei+1} Loss {loss_run}')

Epoch 500 Loss 0.7835099073193783


In [26]:
# Compare the learned energy with the noiseless ground truth on the training inputs.
# The training loss only constrains E'(x), so E(x) is determined up to an
# additive constant; shift the learned curve by its mean offset from the
# ground truth so the plot compares shape, not an arbitrary vertical offset.
y_true = true_func(x_train)
with torch.no_grad():  # inference only: no autograd graph needed
    y_learned = energy_net(torch.from_numpy(x_train.astype('float32'))).numpy()
y_learned_aligned = y_learned + (y_true - y_learned).mean()

fig, ax = plot.subplots()

ax.plot(x_train, y_true, 'x')
ax.plot(x_train, y_learned_aligned, 'o')

ax.grid(True)
ax.legend(['ground truth', 'learned (shifted by mean offset)'])
ax.set(xlabel='x', ylabel='y', title='Learned energy vs. ground truth sin(x)')

plot.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to  previous…

In [45]:
# Debug check: shape of the regression target from the LAST mini-batch of the
# training cell. `batch_delta` leaks out of that loop, so this cell only works
# after the training cell has run.
numpy.shape(batch_delta)

(10, 1)

In [44]:
# Debug check: shape of the noisy targets from the LAST mini-batch of the
# training cell. `batch_y` leaks out of that loop, so this cell only works
# after the training cell has run.
numpy.shape(batch_y)

(10, 1)