# Training Toy Experiments

In [None]:
import torch

from tqdm import tqdm


import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import sys
sys.path.append("/trinity/home/a.kolesov/EFM")
from src.utils import Config, get_mesh, forward_interpolation
from src.models import DefineMLP, ToyNet
from src.data import MultiDimStackedGauss,MultiDimStackedMixtureGauss, SwissRollSampler
from src.efm_field import EFMGroundTruth
from src.ifm_field import IFMGroundTruth
from src.train import ToyTrain
from src.ode import LearnedODESolver, LearnDippoleEFMODESolver

## 1. 3-dimensional example: Gaussian-to-Mixture

In [None]:
# Settings for the 3-dimensional Gaussian-to-Mixture toy experiment.
# Every value below is stored on a (possibly nested) Config object that the
# samplers, networks, trainer and ODE solvers receive later in the notebook.
config = Config()

# ---------- common ----------
config.device = 'cuda'
config.DIM = 3
config.L = 6.
config.NUM_SAMPLES = 1_000
config.NUM_MESH = 30
config.DICT_MESH = {"z": [-5, 5.], "x": [1e-1, 8.], "y": [-12, 12]}

# Sizes derived from DIM; read back from the config so they stay in sync.
dim = config.DIM
plane_dim = dim - 1

# ---------- p_data ----------
config.p = Config()
p_cfg = config.p
p_cfg.dim = plane_dim
p_cfg.mean = (0., 0.)
p_cfg.x_loc = 0.
p_cfg.std = 1.
p_cfg.cov = None

# ---------- q_data ----------
config.q = Config()
q_cfg = config.q
q_cfg.dim = plane_dim
q_cfg.num_components = 2
q_cfg.mix_probs = (0.5, 0.5)
q_cfg.means = ((-8., 0.), (8., 0.))
q_cfg.stds = (1., 1.)
q_cfg.cov = None
q_cfg.x_loc = config.L

# ---------- EFM ----------
config.efm = Config()
efm_cfg = config.efm
efm_cfg.training = Config()
efm_cfg.training.stability = True
efm_cfg.training.gamma = 1e-7
efm_cfg.dist_behind = 9.

# ---------- IFM ----------
config.ifm = Config()
ifm_cfg = config.ifm
ifm_cfg.name = "3d_mixture"
ifm_cfg.field_type = "Shifted"
ifm_cfg.field_form = 'exponential'
ifm_cfg.plan_type = 'Independent'
ifm_cfg.SCALE = 1.
ifm_cfg.K = 10.
ifm_cfg.D = math.pi / (2 * ifm_cfg.K)

# ---------- model ----------
config.model = Config()
model_cfg = config.model
model_cfg.embed_dim = 128
# Layer-size lists handed to DefineMLP: [input, hidden, hidden, output].
width_x = 1024 * plane_dim
width = 1024 * dim
model_cfg.hidden_layers_x = [plane_dim, width_x, width_x, model_cfg.embed_dim]
model_cfg.hidden_layers = [model_cfg.embed_dim, width, width, dim]
model_cfg.lr = 1e-4
model_cfg.epsilon = 0.5
model_cfg.M = 4.
model_cfg.tau = 0.03
model_cfg.norm_scaler = 4.
model_cfg.training_batch = 1024

# ---------- ode ----------
config.ode = Config()
ode_cfg = config.ode
ode_cfg.step = 1.
ode_cfg.gamma = 0.
ode_cfg.behind_step = 1.5
ode_cfg.behind_num_steps = 350

In [None]:
# Samplers for the source (p) and target (q) distributions, plus one fixed
# batch from each; the same batches are passed to both ODE solvers below and
# reused by the plotting cell.
p_dist = MultiDimStackedGauss(config)
q_dist = MultiDimStackedMixtureGauss(config)

p_samples = p_dist.sample(config.model.training_batch).to(config.device)  # [B,DIM]
q_samples = q_dist.sample(config.model.training_batch).to(config.device)

# Per-method results, keyed by 'efm' / 'ifm'.
ode = {}
field = {}
maps, traj = {}, {}


def _load_toy_net(ckpt_path):
    """Build a fresh ToyNet from ``config.model`` and load ``ckpt_path`` into it.

    A new network is created on every call so that an object which already
    holds a reference to a previously loaded network (e.g. an ODE solver)
    never has its weights replaced behind its back.
    """
    model_x = DefineMLP(config.model.hidden_layers_x).to(config.device)
    model = DefineMLP(config.model.hidden_layers).to(config.device)
    toy_net = ToyNet(model_x, model, config)
    # map_location="cpu" makes loading independent of the device the
    # checkpoint tensors were saved from.
    toy_net.cpu().load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    return toy_net.to(config.device)


### EFM ###
# Disabled training recipe for the EFM checkpoint (a bare string literal, so
# it is never executed).
"""
config.model.training_steps = 5_000
config.model.training_batch = 2048
config.model.training_batch_small = 1024
config.model.interpolation = 'efm'

model_x = DefineMLP(config.model.hidden_layers_x).to(config.device)
model =  DefineMLP(config.model.hidden_layers).to(config.device)
net = ToyNet(model_x,  model, config)
optimizer = torch.optim.Adam(net.parameters(), lr=config.model.lr)

net, losses = ToyTrain(EFMGroundTruth, p_dist, q_dist, net, optimizer, config)
torch.save(net.cpu().state_dict(), "/trinity/home/a.kolesov/EFM/ckpt/toy/efm/mesh_bs=2048_eps=0.5.pth")
"""

### IFM ###
# Disabled training recipe for the IFM checkpoint (also a bare string literal).
"""
config.model.sigma_end = 0.8
config.model.training_steps = 2_000
config.model.training_batch = 1024
config.model.training_batch_small = 256


model_x = DefineMLP(config.model.hidden_layers_x).to(config.device)
model =  DefineMLP(config.model.hidden_layers).to(config.device)
net = ToyNet(model_x,  model, config)
optimizer = torch.optim.Adam(net.parameters(), lr=config.model.lr)

net, losses = ToyTrain(IFMGroundTruth, p_dist, q_dist, net, optimizer, config)
torch.save(net.cpu().state_dict(),
           "/trinity/home/a.kolesov/EFM/ckpt/toy/ifm/both_side_bs=1024_sigma=0.8_eps=0.5_M=4.pth")
"""

with torch.no_grad():
    ### EFM ###
    net = _load_toy_net("/trinity/home/a.kolesov/EFM/ckpt/toy/efm/mesh_bs=2048_eps=0.5.pth")

    # Coarse mesh used to visualise both learned fields. It is built once:
    # the plotting cell draws both quiver plots on this single `mesh`.
    config.NUM_MESH = 7
    config.DICT_MESH = {"z": [-5, 5.], "x": [1e-5, 8.], "y": [-12, 12]}
    mesh = get_mesh(config).to(config.device)
    field['efm'] = net(mesh.clone())

    ode['efm'] = LearnDippoleEFMODESolver(net, config)
    init = p_dist.sample(config.model.training_batch).to(config.device)
    init[:, 0] = config.model.epsilon  # starting point problem, D dependency;
    maps['efm'], traj['efm'] = ode['efm'](init, p_samples, q_samples)

    ### IFM ###
    # Fresh network: ode['efm'] keeps its own EFM-weighted instance.
    net = _load_toy_net(
        "/trinity/home/a.kolesov/EFM/ckpt/toy/ifm/both_side_bs=1024_sigma=0.8_eps=0.5_M=4.pth")
    field['ifm'] = net(mesh.clone())

    ode['ifm'] = LearnedODESolver(net, config)
    init = p_dist.sample(config.model.training_batch).to(config.device)
    init[:, 0] = config.model.epsilon  # starting point problem, D dependency;
    maps['ifm'], traj['ifm'] = ode['ifm'](init, p_samples, q_samples)

# Each solver returns its trajectory as a sequence of tensors; stack them
# along a new leading axis so the plotting cell can index traj[key][step, idx, coord].
traj['efm'] = torch.stack(traj['efm'], dim=0)
traj['ifm'] = torch.stack(traj['ifm'], dim=0)

In [None]:
# 2x2 figure: top row shows the network field on the mesh (EFM | IFM),
# bottom row shows the final ODE states and a few trajectories (EFM | IFM).
def _xyz(points):
    """Return the three coordinate columns of ``points`` as CPU tensors."""
    return tuple(points[:, axis].cpu() for axis in range(3))


def _draw_samples(axes, spec):
    """Scatter the fixed p/q sample batches using the styling in ``spec``."""
    axes.scatter(*_xyz(p_samples), color=spec['p_color'], edgecolor='black',
                 s=80, label=spec['p_label'])
    axes.scatter(*_xyz(q_samples), color=spec['q_color'], edgecolor='black',
                 s=80, label=spec['q_label'])


fig = plt.figure()
fig.set_figheight(8)
fig.set_figwidth(8)

# One styling record per method; the left column is EFM, the right is IFM.
panels = (
    dict(key='efm', title="EFM Ground Truth",
         p_color='blue', p_label=r'$x_{+}  \sim \mathbb{P}(x_{+})$',
         q_color='red', q_label=r'$x_{-}  \sim \mathbb{Q}(x_{-})$',
         end_color='lightgreen', end_label=r'$y \sim  T(x_{+})$',
         end_kwargs={'zorder': 20}),
    dict(key='ifm', title="IFM Ground Truth",
         p_color='gray', p_label=r'$x_{q}  \sim \mathbb{P}(x_{q})$',
         q_color='yellow', q_label=r'$x_{\bar{q}}  \sim \mathbb{P}(x_{\bar{q}})$',
         end_color='darkorange', end_label=r'$y \sim  T(x_{q})$',
         end_kwargs={}),
)

# Top row (subplots 1 and 2): samples plus the unit-length field arrows.
for slot, spec in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 2, slot, projection='3d')
    _draw_samples(ax, spec)
    ax.quiver(*_xyz(mesh), *_xyz(field[spec['key']]),
              color='black', length=1, normalize=True)
    ax.set_title(spec['title'])
    ax.legend()

# Overwrite (in place) the first coordinate of the final EFM states with 7.3
# before they are drawn; this also moves the last point of each plotted path.
traj['efm'][-1, :, 0] = 7.3

# Bottom row (subplots 3 and 4): samples, final states and 30 trajectories.
for slot, spec in enumerate(panels, start=3):
    ax = fig.add_subplot(2, 2, slot, projection='3d')
    _draw_samples(ax, spec)

    paths = traj[spec['key']]
    ax.scatter(*_xyz(paths[-1]), color=spec['end_color'], edgecolor='black',
               s=80, label=spec['end_label'], **spec['end_kwargs'])
    for idx in range(200, 230):
        ax.plot(*_xyz(paths[:, idx]), color='black', linewidth=0.5, zorder=3)

    ax.legend()

fig.tight_layout(pad=0.01)

## 2. 3-dimensional example: Gaussian-to-Swiss Roll

In [None]:
# Settings for the 3-dimensional Gaussian-to-Swiss-Roll toy experiment.
# Every value below is stored on a (possibly nested) Config object that the
# samplers, networks, trainer and ODE solvers receive later in the notebook.
config = Config()

##### common #####
config.device = 'cuda'
config.DIM = 3
config.L = 6.
config.NUM_SAMPLES = 1_000
config.NUM_MESH = 30
config.DICT_MESH = {"z": [-4, 4.], "x": [1e-1, 10.], "y": [-5, 5]}
##### common #####

##### p_data #####
config.p = Config()
config.p.dim = config.DIM - 1
config.p.mean = (0., 0.)
config.p.x_loc = 0.
config.p.std = 1.
config.p.cov = None
##### p_data #####

##### q_data #####
config.q = Config()
config.q.dim = config.DIM - 1
# Tied to config.L (same value, 6.) instead of repeating the literal, matching
# the Gaussian-to-Mixture configuration so the two cannot drift apart.
config.q.x_loc = config.L
config.q.swiss_noise = None
##### q_data #####


##### EFM #####
config.efm = Config()

config.efm.training = Config()
config.efm.training.stability = True
config.efm.training.gamma = 1e-7
config.efm.dist_behind = 10.
##### EFM #####

##### IFM #####
config.ifm = Config()
# NOTE(review): the name is still "3d_mixture" although this section targets
# the Swiss roll; left unchanged because its consumers are not visible here.
config.ifm.name = "3d_mixture"
config.ifm.field_type = "Shifted"
config.ifm.field_form = 'exponential'
config.ifm.plan_type = 'Independent'
config.ifm.SCALE = 1.
config.ifm.K = 10.
config.ifm.D = math.pi / (2 * config.ifm.K)
##### IFM #####

##### model #####
config.model = Config()
config.model.embed_dim = 128
# Layer-size lists handed to DefineMLP: [input, hidden, hidden, output].
config.model.hidden_layers_x = [config.DIM - 1, 1024 * (config.DIM - 1),
                                1024 * (config.DIM - 1), config.model.embed_dim]
config.model.hidden_layers = [config.model.embed_dim, 1024 * config.DIM,
                              1024 * config.DIM, config.DIM]
config.model.lr = 1e-4
config.model.epsilon = 0.5
config.model.M = 4.
config.model.tau = 0.03
config.model.norm_scaler = 4.
config.model.training_batch = 1024
##### model #####

##### ode #####
config.ode = Config()
# NOTE(review): the inference cell of this section overwrites step,
# behind_step and behind_num_steps before any solver is built.
config.ode.step = .25
config.ode.gamma = 0.
config.ode.behind_step = 1.5
config.ode.behind_num_steps = 350
##### ode #####

In [None]:
# Samplers for the source (p) and target (q) distributions, plus one fixed
# batch from each; the same batches are passed to both ODE solvers below and
# reused by the plotting cell.
p_dist = MultiDimStackedGauss(config)
q_dist = SwissRollSampler(config)

p_samples = p_dist.sample(config.model.training_batch).to(config.device)  # [B,DIM]
q_samples = q_dist.sample(config.model.training_batch).to(config.device)

# Per-method results, keyed by 'efm' / 'ifm'.
ode = {}
field = {}
maps, traj = {}, {}


def _load_toy_net(ckpt_path):
    """Build a fresh ToyNet from ``config.model`` and load ``ckpt_path`` into it.

    A new network is created on every call so that an object which already
    holds a reference to a previously loaded network (e.g. an ODE solver)
    never has its weights replaced behind its back.
    """
    model_x = DefineMLP(config.model.hidden_layers_x).to(config.device)
    model = DefineMLP(config.model.hidden_layers).to(config.device)
    toy_net = ToyNet(model_x, model, config)
    # map_location="cpu" makes loading independent of the device the
    # checkpoint tensors were saved from.
    toy_net.cpu().load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    return toy_net.to(config.device)


### EFM ###
# Disabled training recipe for the EFM checkpoint (a bare string literal, so
# it is never executed).
"""
config.model.training_steps = 5_000
config.model.training_batch = 2048
config.model.training_batch_small = 1024
config.model.interpolation = 'efm'


model_x = DefineMLP(config.model.hidden_layers_x).to(config.device)
model =  DefineMLP(config.model.hidden_layers).to(config.device)
net = ToyNet(model_x,  model, config)
optimizer = torch.optim.Adam(net.parameters(), lr=config.model.lr)

net, losses = ToyTrain(EFMGroundTruth, p_dist, q_dist, net, optimizer, config)
torch.save(net.cpu().state_dict(),
           "/trinity/home/a.kolesov/EFM/ckpt/toy/efm/swiss_mesh_bs=2048_eps=0.5.pth")

"""

### IFM ###
# Disabled training recipe for the IFM checkpoint (also a bare string literal).
"""
config.model.sigma_end = 0.8
config.model.training_steps = 2_000
config.model.training_batch = 1024
config.model.training_batch_small = 256


model_x = DefineMLP(config.model.hidden_layers_x).to(config.device)
model =  DefineMLP(config.model.hidden_layers).to(config.device)
net = ToyNet(model_x,  model, config)
optimizer = torch.optim.Adam(net.parameters(), lr=config.model.lr)

net, losses = ToyTrain(IFMGroundTruth, p_dist, q_dist, net, optimizer, config)
torch.save(net.cpu().state_dict(),
           "/trinity/home/a.kolesov/EFM/ckpt/toy/ifm/swiss_both_side_bs=1024_sigma=0.8_eps=0.5_M=4.pth")
"""

# ODE settings used for inference in this section; they replace the values
# assigned in the configuration cell and apply to both solvers below.
config.ode.step = 0.5
config.ode.behind_num_steps = 60
config.ode.behind_step = 0.5

with torch.no_grad():
    ### EFM ###
    net = _load_toy_net("/trinity/home/a.kolesov/EFM/ckpt/toy/efm/swiss_mesh_bs=2048_eps=0.5.pth")

    # Coarse mesh used to visualise both learned fields. It is built once:
    # the plotting cell draws both quiver plots on this single `mesh`.
    config.NUM_MESH = 7
    config.DICT_MESH = {"z": [-5, 5.], "x": [1e-5, 8.], "y": [-4, 4]}
    mesh = get_mesh(config).to(config.device)
    field['efm'] = net(mesh.clone())

    ode['efm'] = LearnDippoleEFMODESolver(net, config)
    init = p_dist.sample(config.model.training_batch).to(config.device)
    init[:, 0] = config.model.epsilon  # starting point problem, D dependency;
    maps['efm'], traj['efm'] = ode['efm'](init, p_samples, q_samples)

    ### IFM ###
    # Fresh network: ode['efm'] keeps its own EFM-weighted instance.
    net = _load_toy_net(
        "/trinity/home/a.kolesov/EFM/ckpt/toy/ifm/swiss_both_side_bs=1024_sigma=0.8_eps=0.5_M=4.pth")
    field['ifm'] = net(mesh.clone())

    ode['ifm'] = LearnedODESolver(net, config)
    init = p_dist.sample(config.model.training_batch).to(config.device)
    # The IFM start points use an extra 0.5 offset on the first coordinate.
    init[:, 0] = config.model.epsilon + 0.5  # starting point problem, D dependency;
    maps['ifm'], traj['ifm'] = ode['ifm'](init, p_samples, q_samples)

# Each solver returns its trajectory as a sequence of tensors; stack them
# along a new leading axis so the plotting cell can index traj[key][step, idx, coord].
traj['efm'] = torch.stack(traj['efm'], dim=0)
traj['ifm'] = torch.stack(traj['ifm'], dim=0)

In [None]:
# 2x2 figure: top row shows the network field on the mesh (EFM | IFM),
# bottom row shows the final ODE states and a few trajectories (EFM | IFM).
def _xyz(points):
    """Return the three coordinate columns of ``points`` as CPU tensors."""
    return tuple(points[:, axis].cpu() for axis in range(3))


def _draw_samples(axes, spec):
    """Scatter the fixed p/q sample batches using the styling in ``spec``."""
    axes.scatter(*_xyz(p_samples), color=spec['p_color'], edgecolor='black',
                 s=80, label=spec['p_label'])
    axes.scatter(*_xyz(q_samples), color=spec['q_color'], edgecolor='black',
                 s=80, label=spec['q_label'])


fig = plt.figure()
fig.set_figheight(8)
fig.set_figwidth(8)

# One styling record per method; the left column is EFM, the right is IFM.
panels = (
    dict(key='efm', title="EFM Ground Truth",
         p_color='blue', p_label=r'$x_{+}  \sim \mathbb{P}(x_{+})$',
         q_color='red', q_label=r'$x_{-}  \sim \mathbb{Q}(x_{-})$',
         end_color='lightgreen', end_label=r'$y \sim  T(x_{+})$',
         end_kwargs={'zorder': 20}),
    dict(key='ifm', title="IFM Ground Truth",
         p_color='gray', p_label=r'$x_{q}  \sim \mathbb{P}(x_{q})$',
         q_color='yellow', q_label=r'$x_{\bar{q}}  \sim \mathbb{P}(x_{\bar{q}})$',
         end_color='darkorange', end_label=r'$y \sim  T(x_{q})$',
         end_kwargs={}),
)

# Top row (subplots 1 and 2): samples plus the unit-length field arrows.
for slot, spec in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 2, slot, projection='3d')
    _draw_samples(ax, spec)
    ax.quiver(*_xyz(mesh), *_xyz(field[spec['key']]),
              color='black', length=1, normalize=True)
    ax.set_title(spec['title'])
    ax.legend()

# Overwrite (in place) the first coordinate of the final EFM states with 7.3
# before they are drawn; this also moves the last point of each plotted path.
traj['efm'][-1, :, 0] = 7.3

# Bottom row (subplots 3 and 4): samples, final states and 30 trajectories.
for slot, spec in enumerate(panels, start=3):
    ax = fig.add_subplot(2, 2, slot, projection='3d')
    _draw_samples(ax, spec)

    paths = traj[spec['key']]
    ax.scatter(*_xyz(paths[-1]), color=spec['end_color'], edgecolor='black',
               s=80, label=spec['end_label'], **spec['end_kwargs'])
    for idx in range(200, 230):
        ax.plot(*_xyz(paths[:, idx]), color='black', linewidth=0.5, zorder=3)

    ax.legend()

fig.tight_layout(pad=0.01)