In [None]:
import torch
import numpy as np
#import matplotlib.pyplot as plt

from argparse import Namespace

from training import ModelBuilder, ModelTrainer, TrajectoryGenerator

from experiments import entropy_loss_TUR, entropy_infer_TUR, entropy_loss_ML, entropy_infer_ML

from models import fully_connected_linear

from free_diffusion import *


%matplotlib inline

### Example: on-the-fly model creation.
### A CNN is used so that only the position
### coordinates are fed to the network.

In [None]:
# Network using a CNN layer, to test feeding only the position to the network. Currently 1D only, but could easily be extended to multi-D.

class TrajByConvolution(ModelBuilder):
    """Weight network that slides a Conv2d window along the time axis of the
    position coordinate, then passes the filter responses through a fully
    connected stack built by `fully_connected_linear`.

    Options (defaults below are filled in by ModelBuilder when not supplied):
        n_input       -- width of the conv kernel along the feature axis
        n_output      -- output size of the fully connected stack
        n_hidden      -- hidden size passed to fully_connected_linear
        num_inner     -- inner-layer count passed to fully_connected_linear
        n_filters     -- number of conv output channels
        filter_length -- number of consecutive time steps per conv window
    """

    def __init__(self, options=None):
        defaults = dict(
            n_input=1,
            n_output=1,
            n_hidden=512,
            num_inner=2,
            n_filters=50,
            filter_length=2,
        )
        super().__init__(defaults, options)

    def generate_network(self):
        """Build the conv front-end (`c2d`) and the dense back-end (`linear_connected`)."""
        o = self.options

        # Kernel spans `filter_length` time steps by `n_input` features.
        # NOTE(review): forward() always slices the input down to one column,
        # so this only works with n_input == 1.
        window_conv = torch.nn.Conv2d(1, o.n_filters, (o.filter_length, o.n_input))
        self.c2d = torch.nn.Sequential(window_conv, torch.nn.ReLU(inplace=True))

        self.linear_connected = fully_connected_linear(
            o.n_filters, o.n_output, o.n_hidden, o.num_inner
        )

    def forward(self, s):
        """Apply the network to a batch of trajectories.

        Assumes `s` is (batch, steps, features) with the position in feature 0
        -- TODO confirm against TrajectoryGenerator. Under that assumption the
        result is (batch, steps - filter_length + 1, ...), one row per conv window.
        """
        # Keep only the x component; everything else (e.g. velocity) is dropped.
        positions = s[..., 0:1]

        # Treat each trajectory as a single-channel 2D image: (batch, 1, steps, 1).
        n_batch = positions.shape[0]
        n_cols = positions.shape[-1]
        image = positions.reshape(n_batch, 1, -1, n_cols)

        responses = self.c2d(image)

        # (batch, filters, windows, 1) -> (batch, windows, filters)
        per_window = responses.swapaxes(1, 2)[..., 0]
        return self.linear_connected(per_window)
    

In [None]:
# Every option is filled in from TrajByConvolution's own defaults when no
# options are supplied, so a bare construction is enough here.
WeightFunction = TrajByConvolution(options=None)

In [None]:
# At least 2 steps are required so that we get 2 values of w for calculating
# delta w: with filter_length=2 the CNN yields one value of w for each
# consecutive pair (x_i, v_i), (x_{i+1}, v_{i+1}).
# NOTE(review): this mutates the `params` dict imported from free_diffusion.
params['num_steps'] = 2
print('sim_params', params)

FreeDiffusion = TrajectoryGenerator(
    simulate_free_diffusion_underdamped,
    params,
)


In [None]:
# The optimizer is passed as a class, not an instance (presumably ModelTrainer
# instantiates it with the model's parameters). torch.optim.SGD can be
# substituted here.
optimizer = torch.optim.Adam

EntProd = ModelTrainer(
    WeightFunction,
    FreeDiffusion,
    optimizer,
    entropy_loss_TUR,   # training loss
    entropy_infer_TUR,  # inference/estimation function
)

In [None]:
# Baseline: run inference with the untrained network, keeping the trajectories.
(untrained_output,
 untrained_test_trajectories) = EntProd.infer(return_trajectories=True)

In [None]:
# Alias to the trainer's own options object; edits below affect EntProd.
training_options = EntProd.training_options

# Epoch and iteration counts with their per-step sizes (the *_s fields are
# presumably sample counts -- confirm in ModelTrainer).
training_options.n_epoch, training_options.epoch_s = 10, 30_000
training_options.n_iter, training_options.iter_s = 10, 25_000

EntProd.train();

In [None]:
# Create a two-panel figure and close it immediately so Jupyter does not render
# an empty figure; the axes are then handed to the trainer's plotting helper.
fig, ax = plt.subplots(nrows=1, ncols=2)
plt.close(fig)
EntProd.plot_training_loss(ax=ax)

In [None]:
# Inference settings: number of inference runs and per-run size (presumably
# samples -- confirm in ModelTrainer).
training_options.n_infer, training_options.infer_s = 20, 20_000

output = EntProd.infer(return_trajectories=False)

In [None]:
# Reference value: integrate `realepr` (from free_diffusion; presumably the exact
# entropy production rate for the initial condition params['init']) over [0, T].
temp_traj = FreeDiffusion.batch(10)

# NOTE(review): T uses (n_time_points - 2) steps; confirm this matches the time
# span the TUR estimator actually covers.
T = params['dt']*(temp_traj.shape[1]-2)

resolution = 1_000

# Midpoint rule on `resolution` equal bins of width T/resolution, so the sample
# spacing and the quadrature weight agree. np.linspace(0, T, resolution) would
# space samples by T/(resolution-1) and include both endpoints, which does not
# match a weight of T/resolution.
dt_quad = T / resolution
t_mid = (np.arange(resolution) + 0.5) * dt_quad

ents = realepr(t_mid, *params['init']) * dt_quad
ent_production = sum(ents)
print(ent_production)

In [None]:
# Relative error of each inference run against the reference entropy production,
# with the mean and a +/- 3 standard-error band drawn as solid black lines and
# the zero-error line dashed.
fig, ax = plt.subplots()

error = np.array(output)/ent_production - 1

# Standard error of the mean: use the sample standard deviation (ddof=1),
# since the mean is estimated from the same small set of runs.
m = np.mean(error)
s = np.std(error, ddof=1) / np.sqrt(len(error))

ax.plot(error, linestyle='none', marker='D')
for level in (m, m - 3*s, m + 3*s):
    ax.axhline(level, c='k', linewidth=.75)
ax.axhline(0, c='k', linestyle='--', linewidth=2)

ax.set_xlabel('trial')
ax.set_ylabel('relative error')
ax.set_title('TUR loss/inference');
