In [None]:
#General libraries
import os, argparse
import pickle
import numpy as np

#Torch libraries
import torch 

#Custom libraries
from utils_ide import Train_val_split, Dynamics_Dataset, Test_Dynamics_Dataset
from utils_ide import fix_random_seeds,to_np
from source.ide_func import NNIDEF, NeuralIDE, NNIDEF_wODE, NeuralIDE_wODE
import source.kernels as kernels
from source.experiments import Full_experiment, IDE_spiral_experiment_no_adjoint
from torch.utils.data import SubsetRandomSampler

# Use the first CUDA GPU when PyTorch reports one is available; otherwise fall back to the CPU.
# `device` is a plain string here.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
    

# Argument parser holding the default configuration for the run.
# Defaults shown in help texts are rendered with %(default)s so the text
# cannot drift away from the actual default value.
parser = argparse.ArgumentParser(description='Neural IDE')
parser.add_argument('-root_path', metavar='DIR', default='',
                    help='path to dataset')
# NOTE: argparse does not validate a default against `choices`, so the
# 'stl10' default is accepted even though it is not a listed choice.
parser.add_argument('-dataset-name', default='stl10',
                    help='dataset name', choices=['acrobot_dataset'])

parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')
parser.add_argument('--epochs', default=5000, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=20, type=int,
                    metavar='N',
                    help='mini-batch size (default: %(default)s), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                    metavar='W', help='weight decay (default: 1e-3)',
                    dest='weight_decay')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
parser.add_argument('--fp16-precision', action='store_true',
                    help='Whether or not to use 16-bit precision GPU training.')

parser.add_argument('--out_dim', default=128, type=int,
                    help='feature dimension (default: 128)')
parser.add_argument('--log-every-n-steps', default=100, type=int,
                    help='Log every n steps')
parser.add_argument('--temperature', default=0.07, type=float,
                    help='softmax temperature (default: 0.07)')
parser.add_argument('--n-views', default=2, type=int, metavar='N',
                    help='Number of views for contrastive learning training.')
parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')
parser.add_argument('--model', default='simclr', choices=['simclr','lipschitz_simclr','vae','gan'], 
                    help='Models to be used')
# The original literal relied on adjacent-string concatenation, which silently
# dropped the quotes around 'evaluate'; a double-quoted string keeps them.
parser.add_argument('--mode', default='train', choices=['train','evaluate'], 
                    help="Set to 'evaluate' if inference is desired")
parser.add_argument('--training_split', default=0.25,type=float, 
                    help='Fraction of the samples that will be used for validation')
parser.add_argument('--resume_from_checkpoint', default=None, 
                    help='Give string to run number. Ex: "run12"')
parser.add_argument('--plot_freq', default=1, type=int,help='')
parser.add_argument('--experiment_name', default=None,help='')


In [None]:
from source.solver import IDESolver_monoidal
import matplotlib.pyplot as plt
import pickle

In [None]:
# Display how many CUDA devices PyTorch reports (shown as this cell's output).
torch.cuda.device_count()

In [None]:
# Same CUDA-or-CPU choice as a string (`dev`), then wrapped in a torch.device object.
# This rebinds `device` from the plain string set earlier to a torch.device.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

In [None]:
# Start from the parser defaults only (an empty argument list means no command-line
# tokens are read), then override / extend the namespace for this notebook run.
args = parser.parse_args([])

_run_overrides = {
    'model': 'nide',  # 'simclr',
    'mode': 'train',  # alternative: 'evaluate'
    'dataset_name': 'ide_spiral',
    'seed': 1,
    'batch_size': 5,  # batch size
    'experiment_name': 'ide_spiral_nide_1',
    'plot_freq': 1,
    'device': device,
    'num_dim_plot': 2,
    'lr': 1e-3,
    'min_lr': 1e-12,
    'T_max': 51,
    'plat_patience': 10,
    'factor': 0.1,
    'train_split': 0.2,
    # 'temperature': 0.001,
    # 'lr_scheduler': 'ReduceLROnPlateau',
    'lr_scheduler': 'CosineAnnealingLR',
    # 'resume_from_checkpoint': 'run149',
}
# A Namespace keeps its attributes in __dict__, so this is the same as
# assigning each attribute one by one.
vars(args).update(_run_overrides)

# Seed the random generators before the remaining settings, as in the original cell order.
fix_random_seeds(args.seed)
args.number_MC_samplings = 1000
args.warmup = 11

In [None]:
# Time grid: n_points evenly spaced samples on [t_min, t_max].
t_min, t_max = 0, 1
n_points = 20

# Integer position of every time sample, shape (n_points,).
index_np = np.arange(n_points, dtype=int)
# Sample times as a column vector, shape (n_points, 1).
times_np = np.linspace(t_min, t_max, num=n_points).reshape(-1, 1)
# print('times_np: ',times_np)

###########################################################
# Torch version of the grid: moved to `device`, flattened to 1-D, cast to float32.
times = torch.from_numpy(times_np[:, :, None]).to(device).reshape(-1).float()
# print('times :',times)
###########################################################
###########################################################

In [None]:
# Load the dataset from disk. The context manager closes the file handle as soon
# as unpickling finishes (the original left it open until garbage collection).
# SECURITY NOTE: unpickling can execute arbitrary code; only load this file if it
# comes from a trusted source.
with open("50_IE_Spirals.pkl", "rb") as data_file:
    Data = pickle.load(data_file)
# The loaded object is expected to expose `.shape` (presumably a tensor — confirm).
print(Data.shape)

In [None]:
# Normalization step: subtract the mean taken along axis 1 (kept as a size-1 axis
# so it broadcasts), then divide by the largest value of the centred data.
# NOTE: this cell rebinds `Data`, so running it twice normalises twice.
Data = Data - Data.mean(1, keepdim=True)
scaling_factor = to_np(Data).max()
print('scaling_factor: ',scaling_factor)
Data = Data / scaling_factor

# Training uses the first 450 entries along axis 0; any other mode uses the remainder.
_mode_rows = slice(None, 450) if args.mode == 'train' else slice(450, None)
Data = Data[_mode_rows, ...]

In [None]:
# n_points int64 indices spread evenly from 0 to the last position of Data's axis 1,
# repeated as identical rows (one row per position along that axis).
_n_axis1 = Data.shape[1]
_row_of_ids = np.linspace(0, _n_axis1 - 1, num=n_points, dtype=np.int64)
ids = np.tile(_row_of_ids, (_n_axis1, 1))

In [None]:
# Display the first row of `ids` (every row is the same tiled index vector).
ids[0]

In [None]:
# Keep only the positions listed in ids[0] along axis 1 of Data.
# NOTE: this rebinds `Data`; re-running the cell after it has shrunk re-indexes the
# already-subsampled array, so run it once per fresh load.
Data = Data[:,ids[0],:]

In [None]:
# Display the shape of Data after subsampling.
Data.shape

In [None]:
from torch import nn
class NN_feedforward(nn.Module):
    """Fully connected network: four linear layers with ELU after the first three.

    Args:
        in_dim: size of the last axis of the input.
        hid_dim: width of the three hidden layers.
        out_dim: size of the last axis of the output.
    """

    def __init__(self, in_dim, hid_dim,out_dim):
        super().__init__()

        self.lin1 = nn.Linear(in_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, hid_dim)
        self.lin3 = nn.Linear(hid_dim, hid_dim)
        self.lin4 = nn.Linear(hid_dim, out_dim)
        # In-place ELU: it overwrites the freshly produced linear output.
        self.ELU = nn.ELU(inplace=True)
        
        self.in_dim = in_dim

    def forward(self,y):
        """Map `y` (last axis of size in_dim) to a tensor with last axis out_dim.

        The input is moved to the device that holds this module's weights, so the
        class no longer depends on a notebook-level `args` global and keeps
        working wherever the module has been moved with `.to(...)`.
        """
        y_in = y.to(self.lin1.weight.device)
        
        h = self.ELU(self.lin1(y_in))
        h = self.ELU(self.lin2(h))
        h = self.ELU(self.lin3(h))
        out = self.lin4(h)
        
        return out

In [None]:
class Simple_NN(nn.Module):
    """Fully connected network whose input is a scalar `x` concatenated with `y`.

    The first layer takes in_dim + 1 features because one extra column holding
    `x` is prepended to `y` in `forward`.

    Args:
        in_dim: size of the last axis of `y`.
        hid_dim: width of the three hidden layers.
        out_dim: size of the last axis of the output.
    """

    def __init__(self, in_dim, hid_dim,out_dim):
        super().__init__()

        self.lin1 = nn.Linear(in_dim+1, hid_dim)
        self.lin2 = nn.Linear(hid_dim, hid_dim)
        self.lin3 = nn.Linear(hid_dim, hid_dim)
        self.lin4 = nn.Linear(hid_dim, out_dim)
        # In-place ELU: it overwrites the freshly produced linear output.
        self.ELU = nn.ELU(inplace=True)
        
        self.in_dim = in_dim

    def forward(self,x,y):
        """Evaluate the network on `y` with the single value `x` as an extra feature.

        Args:
            x: tensor holding exactly one element (it is viewed as shape (1, 1)).
            y: tensor whose first axis is the batch; assumed 2-D, (batch, in_dim),
               since `x` is expanded to (batch, 1) before concatenation — TODO confirm.

        Both inputs are moved to the device that holds this module's weights, so
        the class no longer depends on a notebook-level `device` global.
        """
        target_device = self.lin1.weight.device
        y = y.to(target_device)
        # One copy of x per row of y, so it can be concatenated as a leading column.
        x = x.view(1,1).repeat(y.shape[0],1).to(target_device)
        
        y_in = torch.cat([x,y],-1)
        h = self.ELU(self.lin1(y_in))
        h = self.ELU(self.lin2(h))
        h = self.ELU(self.lin3(h))
        out = self.lin4(h)
        
        return out

In [None]:

n_steps = 10000 #number of iterations for training. default=3k epochs
print('Data.shape: ',Data.shape)


# Number of trailing samples (along axis 0) held out for validation.
train_split = int(Data.shape[0]*args.train_split)
# Explicit boundary index instead of negative slicing: with train_split == 0,
# Data[:-0] is empty and Data[-0:] is everything, which would silently swap the
# two sets. Clamped at 0 so an oversized split still behaves like the original.
split_idx = max(Data.shape[0] - train_split, 0)

Dataset_train = Dynamics_Dataset(Data[:split_idx,...],times)
Dataset_val = Dynamics_Dataset(Data[split_idx:,...],times)
# The 'test' loader below wraps the full Data (train and validation together).
Dataset_all = Dynamics_Dataset(Data,times)


dataloaders = {'train': torch.utils.data.DataLoader(Dataset_train, batch_size = args.batch_size),
               'val': torch.utils.data.DataLoader(Dataset_val, batch_size = args.batch_size),
               'test': torch.utils.data.DataLoader(Dataset_all, batch_size = args.batch_size),
              }


# Kernel network sized from the last axis of Data; the list is presumably the
# hidden-layer widths — confirm against source.kernels.
nn_kernel = kernels.kernel_NN_nbatch(Data.shape[-1], Data.shape[-1],[64,128,256,512,256,128,64])
ker = nn_kernel.to(device)
F_func = NN_feedforward(2,256,2).to(device)
ode_func = Simple_NN(2,256,2).to(device)


kernel_type_nn = True
ode_nn = True


In [None]:
import matplotlib.pyplot as plt
# Scatter the first two components of the first sample (index 0 along axis 0).
Data_print = to_np(Data)
_first_sample = Data_print[0]
plt.scatter(_first_sample[:, 0], _first_sample[:, 1], s=10, label='Original Data')

In [None]:
# Display the size of the flattened parameters of the kernel network `ker`.
kernels.flatten_kernel_parameters(ker).size()

In [None]:
# Display the size of the flattened parameters of `F_func`.
kernels.flatten_F_parameters(F_func).size()

In [None]:
# Display the size of the flattened parameters of `ode_func`
# (uses the same flatten_F_parameters helper as for F_func).
kernels.flatten_F_parameters(ode_func).size()

In [None]:
def _alpha_fn(x):
    """Ignore `x` and return a one-element tensor holding 0 on args.device."""
    return torch.Tensor([0]).to(args.device)


def _beta_fn(x):
    """Return `x` moved to args.device."""
    return x.to(args.device)


# Attach both callables to the run configuration; args.device is looked up at call time.
args.alpha = _alpha_fn
args.beta = _beta_fn

In [None]:
# Run the experiment entry point imported from source.experiments.
# All arguments are positional, in the order used by the original call.
IDE_spiral_experiment_no_adjoint(
    ker,          # kernel network
    F_func,       # NN_feedforward instance
    ode_func,     # Simple_NN instance
    Data,         # normalised, subsampled data
    dataloaders,  # dict with 'train' / 'val' / 'test' loaders
    times,        # 1-D time grid
    ode_nn,       # flag set to True above
    None,         # eighth positional argument left unset; meaning not visible here
    args,         # run configuration namespace
)