In [137]:
import os
import h5py
import numpy as np

from torchdiffeq import odeint_adjoint as odeint

import torch
import torch.nn as nn
import torchcde

            
device = torch.device('cuda')
path = './train_dataset_1'
ext = ('.hdf5')
length = 100

In [138]:
fs = []
for root, dirs, files in os.walk(path):
    for file in files:
        if file.endswith(ext):
            f = h5py.File(os.path.join(root, file),'r')
            fs.append(f) 

keysList_synced = list(fs[0]['synced'].keys())
keysList_synced.remove('gyro_uncalib')
keysList_synced.remove('linacce')

train_ts = []
for f in fs:
    train_t = f['synced']['time'][:]
    train_t -= train_t[0]  # t start from 0
    train_t = train_t[0:length]
    train_ts.append(train_t)
train_ts = torch.FloatTensor(train_ts).to(device) 


keysList_synced.remove('time')


# y is observations

train_ys = []
for f in fs:
    train_y = np.empty((length, 0))
    for key in keysList_synced:
        train_y = np.column_stack([train_y, np.array(f['synced'][key][0:length])])
    train_ys.append(train_y)
train_ys = torch.FloatTensor(train_ys).to(device)

train_ys_with_t = torch.cat([train_ts.unsqueeze(2), train_ys],dim=2).to(device)  ##### include t as first feature



train_xs = []
for f in fs:
    train_x = f['pose']['tango_pos']
    train_x = np.column_stack([train_x, f['pose']['tango_ori']])
    train_x = train_x[0:length]
    train_xs.append(train_x)
train_xs = torch.FloatTensor(train_xs).to(device) 

#train_xs_with_t = torch.cat([train_ts.unsqueeze(2), train_xs],dim=2).to(device)  ##### include t as first feature

In [139]:
train_x0=train_xs[:,0,:]
m,n = train_x0.shape
train_x0_withVirance0 = torch.zeros(m, n*2).to(device)
train_x0_withVirance0[:, ::2] = train_x0[:,:]

#t0 = train_ts[:, 0]
#x0_withVirance0_with_t = torch.cat([t0.unsqueeze(1), x0_withVirance0], dim=1)

In [140]:
train_x0_withVirance0.shape

torch.Size([50, 14])

In [141]:
class ODEFunc(nn.Module):

    def __init__(self, hidden_channels):
        super(ODEFunc, self).__init__()
        
        self.hidden_channels = hidden_channels

        self.net = nn.Sequential(
            nn.Linear(hidden_channels, 64),   ###### include t
            nn.ReLU(),
            nn.Linear(64, hidden_channels),
            nn.Tanh(),
        )

    def forward(self, t, y):
        return self.net(y)

In [142]:
class CDEFunc(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels):
        
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        self.linear1 = torch.nn.Linear(hidden_channels, 64)
        self.linear2 = torch.nn.Linear(64, hidden_channels*input_channels)
        
    def forward(self, t, z):
        # z has shape (batch, hidden_channels)
        z = self.linear1(z)
        z = z.relu()
        z = self.linear2(z)
        z = z.tanh()

        z = z.view(z.size(0), self.hidden_channels, self.input_channels)   
        return z

In [143]:
class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super(NeuralCDE, self).__init__()

        self.func = CDEFunc(input_channels, hidden_channels)
        #self.initial = torch.nn.Linear(input_channels, hidden_channels)
        #self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs, x_prev, interval):
        
        X = torchcde.CubicSpline(coeffs)
        
        
        #X0 = X.evaluate(0)
        #z0 = self.initial(X0)
        z0 = x_prev

        ######################
        # Actually solve the CDE.
        ######################
        z_T = torchcde.cdeint(X=X,
                              z0=z0,
                              func=self.func,
                              t=interval)

        ######################
        # Both the initial value and the terminal value are returned from cdeint; extract just the terminal value,
        # and then apply a linear map.
        ######################
        z_T = z_T[:, 1]
        
        z_T = z_T.view(batch_size,hidden_channels,2)
        
        #z_T = z_T.relu()
        #z_T_relu = z_T
        #z_T_relu[:,:,1] = z_T[:,:,1].relu()
        #pred_y = self.readout(z_T)
        #return pred_y
        return z_T

In [144]:
class Ensemble(nn.Module):
    def __init__(self, model_ODE, model_CDE):
        super(Ensemble, self).__init__()
        self.model_ODE = model_ODE
        self.model_CDE = model_CDE
        self.Softplus = nn.Softplus()
        
    def forward(self, coeffs, interval, t, x_prev, x_0):
        ########### ODE with x0 
        
        ########### ODE with previous 
        #x_prev_with_t = torch.cat([t[0].repeat(batch_size).unsqueeze(1), x_prev],dim=1).to(device)  #####********** 50 need change
        #x_doublePrime = odeint(model_ODE, x_prev, t)
        
        
        x_ODE = odeint(model_ODE, x_prev, t)  
        x_ODE = x_ODE.view(batch_size,hidden_channels,2)
        

        
        x_prime = self.Softplus(x_ODE)
        
        ########### CDE with previous
        x_CDE = self.model_CDE(coeffs, x_prev, interval)

        x_hat = self.Softplus(x_CDE)
        
        pred = (x_hat+x_prime)/2
        pred = pred.view(batch_size,hidden_channels,2)
        
        
        return pred

In [145]:
epochs = 3
input_channels_CDE = 21  ##### with t
hidden_channels = 7     ###### with sigma 
batch_size = 50

model_CDE = NeuralCDE(input_channels=input_channels_CDE, hidden_channels=hidden_channels*2, output_channels=0).to(device)
model_ODE = ODEFunc(hidden_channels*2).to(device)  # with virance and t
model = Ensemble(model_ODE, model_CDE).to(device)

optimizer = torch.optim.Adam(model.parameters())


train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_ys_with_t).to(device)

train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_xs, train_x0_withVirance0)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)


loss_Gaussian = torch.nn.GaussianNLLLoss()

loss_l1 = torch.nn.L1Loss()


In [146]:
##### only coefficent in this interval

for epoch in range(epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_xs, batch_x0 = batch
            
            x_doublePrime = batch_x0
            x_prev = batch_x0
            
            loss=0
            loss2=0
            for i in range(1,length):
                #x_prev = batch_xs_with0[:,i-1,:]  
                
                
                
                t = train_ts[0][i-1:i]      ###### change from i+1 to i, just two t points          
                x_current = batch_xs[:,i,:]
                
                
                interval = torch.FloatTensor([0, 1]).to(device)  
                ###### only coefficent in this interval
                coeffs = batch_coeffs[:,i-1,:]
                
                pred = model(batch_coeffs, interval, t, x_prev, x_doublePrime)
                
  
                x_prev = pred.view(batch_size,hidden_channels*2)   ######## update previous
                
                pred_x = pred[:,:,0]
                var_x = pred[:,:,1] # must positive  
                
                
                loss += loss_Gaussian(pred_x, x_current, var_x)
                loss2 += loss_l1(pred_x, x_current)
                
                if i%10 == 0:
                    print('i: {}   Training loss: {}'.format(i, loss.item()))
                    
                    
            #print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))  
            loss = loss/length
            loss.backward()
            optimizer.step()
            optimizer.zero_grad() 
            
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))
        print('Epoch: {}   Training loss2: {}'.format(epoch, loss2.item()))

i: 10   Training loss: 11.093199729919434
i: 20   Training loss: 34.76270294189453
i: 30   Training loss: 76.65408325195312
i: 40   Training loss: 150.62533569335938
i: 50   Training loss: 272.8928527832031
i: 60   Training loss: 459.9002685546875
i: 70   Training loss: 727.0252685546875
i: 80   Training loss: 1089.5721435546875
i: 90   Training loss: 1563.6435546875
Epoch: 0   Training loss: 21.012041091918945
Epoch: 0   Training loss2: 543.6048583984375
i: 10   Training loss: 11.04780387878418
i: 20   Training loss: 34.02960205078125
i: 30   Training loss: 72.13883209228516
i: 40   Training loss: 134.72885131835938


KeyboardInterrupt: 