### GD-VAE: Viscous Burgers Non-Linear PDE: Periodic Case
[http://atzberger.org](http://atzberger.org)

### Overview

Trains a model for the dynamics of the viscous Burgers PDE $u_t = -uu_x + \nu u_{xx}$ in the case of periodic boundary conditions.  By adjusting the parameters, the latent space can be taken to be either a standard Euclidean space or a manifold latent space.

For more information, see the documentation and the GD-VAEs paper.


In [None]:
# general packages
import sys,os,pickle,time,argparse,numpy as np;
import torch,torch.nn;

# GD-VAE packages
import gd_vae_pytorch as gd_vae,gd_vae_pytorch.vae,gd_vae_pytorch.geo_map,gd_vae_pytorch.utils;

# local packages
import pkg,pkg.model_utils as model_utils,pkg.datasets as datasets;

# base name of this script (no file extension)
script_base_name = 'train_model_periodic1'

# loss functions: the dynamic-VAE loss from GD-VAE and a plain MSE criterion
mse_loss = torch.nn.MSELoss()
dyvae_loss = gd_vae.vae.dyvae_loss

# short aliases for the local model helpers
Encoder_Fully_Connected = model_utils.Encoder_Fully_Connected
Decoder_Fully_Connected = model_utils.Decoder_Fully_Connected

# short aliases for the latent-space projection maps
analytic_periodic_proj_with_time = model_utils.analytic_periodic_proj_with_time
PointCloudPeriodicProjWithTime = model_utils.PointCloudPeriodicProjWithTime

# short alias for the dataset class
PeriodicDataset = datasets.PeriodicDataset

### Parameters

In [None]:
# setup parameters
#
# Produces `run_params` (dict of settings for this run) and `run_name_base`
# (name of the run, used later to build the output directory).

# Defined before the branch so that it exists whether the parameters are
# loaded from file or built from the defaults below.
run_name_base = 'VAE__Analytic_Projection'

flag_load_from_file = False
if flag_load_from_file:
    param_filename = './script_data/study_0001/VAE__Analytic_Projection_00000/params.pickle'

    print("")
    print("param_filename = " + str(param_filename))
    print("")
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only load parameter files from a trusted source.
    with open(param_filename, 'rb') as param_file:
        run_params = pickle.load(param_file)

else:
    # default parameters shared by all runs
    default_params = {
        'input_dim': 100,
        'latent_dim': 3,
        'time_step': 0.25,
        'train_num_samples': int(1e4),
        'test_num_samples': int(1e4),
        'noise': 0.02,
        'batch_size': 100,
        'gamma': 1,  # reconstruction regularization
        'num_epochs': int(4e2),
        'learning_rate': 1e-4,
        'm1_mc': 1,  # for monte carlo estimates
        'beta': 1e-4,
        'latent_prior_std_dev': 1.0,
        'use_analytic_projection_map': True,
        'use_point_cloud_projection_map': False,
        'mse_loss': False,
        'encoder_size': [100, 400, 400],
        'decoder_size': [400, 400, 100],
        'save_every_n_epoch': 100,
        'num_points_in_cloud': None,
    }

    # per-run overrides applied on top of the defaults, keyed by run name
    run_overrides = {
        'VAE__Analytic_Projection': {},
        'VAE__Point_Cloud_Projection': {
            'use_analytic_projection_map': False,
            'use_point_cloud_projection_map': True,
            'num_points_in_cloud': 100,
        },
        'VAE__No_Projection': {
            'use_analytic_projection_map': False,
        },
        'AE__Analytic_Projection': {
            'mse_loss': True,
        },
        'AE__No_Projection': {
            'use_analytic_projection_map': False,
            'mse_loss': True,
        },
        'VAE__2d': {
            'use_analytic_projection_map': False,
            'latent_dim': 2,
        },
        'VAE__10d': {
            'use_analytic_projection_map': False,
            'latent_dim': 10,
        },
    }

    if run_name_base not in run_overrides:
        raise ValueError(f'Run Name {run_name_base} Not Recognized')

    # run_params refers to the same dict object as default_params
    run_params = default_params
    run_params.update(run_overrides[run_name_base])

# output directory for this run: output/<script name>/<run name>
base_dir = os.path.join('output', script_base_name, run_name_base)
print("base_dir = " + base_dir)
gd_vae.utils.create_dir(base_dir)

# save the run parameters alongside the outputs; the context manager closes
# the file even if pickling raises
param_filename = os.path.join(base_dir, 'params.pickle')
print("")
print("param_filename = " + param_filename)
with open(param_filename, 'wb') as param_file:
    pickle.dump(run_params, param_file)

# added after the dump above, so the saved params.pickle does not hold this key
run_params['data_folder_path'] = os.path.join(base_dir, 'data')

# directory passed to the model save calls during training
data_dir = run_params['data_folder_path']
print("")
print("data_dir = " + data_dir)
gd_vae.utils.create_dir(data_dir)

print("")
print("run_params.keys() = " + str(run_params.keys()))

# keyword arguments forwarded to the VAE loss function
extra_params = {
    'beta': run_params['beta'],  # beta value in beta-VAE
    'gamma': run_params['gamma'],  # gamma for reconstruction term as regularization
    'num_monte_carlo_samples': run_params['m1_mc'],
    'device': None,
    'latent_prior_std_dev': torch.Tensor([run_params['latent_prior_std_dev']]),
    # the MSE criterion object (run_params['mse_loss'] is the boolean switch)
    'mse_loss': mse_loss,
}

### Create training dataset

In [None]:
# spatial grid: input_dim equally spaced points starting at 0, with the
# final linspace point at 1.0 dropped
num_grid_points = run_params['input_dim']
xi = torch.linspace(0, 1.0, num_grid_points + 1)[:-1]

# settings handed to the dataset
train_data_params = dict(
    time_step=run_params['time_step'],
    noise=run_params['noise'],
)
train_dataset = PeriodicDataset(xi, run_params['train_num_samples'], train_data_params)

# shuffled mini-batches over the training set
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=run_params['batch_size'],
    shuffle=True,
)

### Encoder model

In [None]:
# encoder networks and the map used for the latent mean
encoder = Encoder_Fully_Connected(run_params['encoder_size'], run_params['latent_dim'])

use_analytic_proj = run_params['use_analytic_projection_map']
use_point_cloud_proj = run_params['use_point_cloud_projection_map']

phi = {}
if use_analytic_proj:
    # encoder mean followed by the analytic projection (takes priority when both flags are set)
    phi['model_mu'] = lambda input: analytic_periodic_proj_with_time(encoder.mean(input))
elif use_point_cloud_proj:
    # encoder mean followed by the point-cloud projection
    point_cloud_periodic_proj_with_time = PointCloudPeriodicProjWithTime(run_params['num_points_in_cloud'])
    phi['model_mu'] = lambda input: point_cloud_periodic_proj_with_time(encoder.mean(input))
else:
    # no projection: use the encoder mean directly
    phi['model_mu'] = encoder.mean
phi['model_log_sigma_sq'] = encoder.log_variance

### Decoder model

In [None]:
# decoder network; theta holds the map from latent codes to the decoder mean
decoder = Decoder_Fully_Connected(
    run_params['decoder_size'],
    run_params['latent_dim'],
)
theta = dict(model_mu=decoder.mean)

### Latent map

In [None]:
# map that evolves latent codes forward in time, and the settings passed with it
latent_map_params = dict(time_step=run_params['time_step'])
latent_map = model_utils.latent_map_forward_in_time

### Training parameters

In [None]:
# collect the parameters to be optimized, in order: encoder mean,
# encoder log-variance, decoder mean
trainable_modules = (encoder.mean, encoder.log_variance, decoder.mean)
params_to_opt = [
    param
    for module in trainable_modules
    for param in module.parameters()
]
optimizer = torch.optim.Adam(params_to_opt, lr=run_params['learning_rate'])

### Train

In [None]:
print("Training the models:")
num_steps = len(train_loader)

# save the models before the first training step as the epoch-0 checkpoint
encoder.save_encoder_model(data_dir, epoch=0)
decoder.save_decoder_model(data_dir, epoch=0)

# loop-invariant settings, read once
num_epochs = run_params['num_epochs']
use_mse_loss = run_params['mse_loss']
save_every_n_epoch = run_params['save_every_n_epoch']

print('.'*80)
for epoch in range(num_epochs):
    epoch_start_time = time.time()

    # named batch_input/batch_target so the builtin input() is not shadowed
    for i, (batch_input, batch_target) in enumerate(train_loader):

        # calculate loss function
        if use_mse_loss:
            # plain auto-encoder loss: prediction error plus gamma-weighted
            # reconstruction error
            latent = phi['model_mu'](batch_input)
            # NOTE(review): latent_map_params is not passed here, although it
            # is handed to dyvae_loss in the other branch; confirm that
            # latent_map's default time step matches run_params['time_step'].
            latent_ev = latent_map(latent)
            reconstructed = theta['model_mu'](latent)
            predicted = theta['model_mu'](latent_ev)
            loss = (mse_loss(predicted, batch_target)
                    + run_params['gamma']*mse_loss(reconstructed, batch_input))
        else:
            loss = dyvae_loss(phi['model_mu'], phi['model_log_sigma_sq'], theta['model_mu'],
                              latent_map, latent_map_params, batch_input, batch_target,
                              **extra_params)

        # perform gradient descent
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # report progress on the first batch and on every 100th batch
        if ((i + 1) % 100) == 0 or i == 0:
            msg = 'epoch: [%d/%d]; '%(epoch+1,num_epochs)
            msg += 'batch_step = [%d/%d]; '%(i + 1,num_steps)
            msg += 'loss: %.3e; '%(loss.item())
            print(msg)

    # epoch finished
    print("time taken: %.1e s"%(time.time()-epoch_start_time))
    print('.'*80)

    # periodic checkpoint of both models
    if (epoch+1)%save_every_n_epoch == 0:
        encoder.save_encoder_model(data_dir, epoch+1)
        decoder.save_decoder_model(data_dir, epoch+1)
