Commit b2f1b85
added some time warping experiments, couldn't get them to exactly match paper but the multimodal is outperforming unimodal
onenoc committed Apr 7, 2024
1 parent d1642f4 commit b2f1b85
Showing 6 changed files with 579 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -1 +1,4 @@
*.png
time_warping/mlruns
time_warping/lightning_logs
time_warping/__pycache__
109 changes: 109 additions & 0 deletions time_warping/basis_functions.py
@@ -0,0 +1,109 @@
import torch
import math
#This code was adapted from Andre Martins' group's codebase, shared with me by email at an early stage.

class BasisFunctions(object):
def __len__(self):
"""Number of basis functions."""
raise NotImplementedError

def evaluate(self, t):
raise NotImplementedError

def integrate_t2_times_psi(self, a, b):
"""Compute integral int_a^b (t**2) * psi(t)."""
raise NotImplementedError

def integrate_t_times_psi(self, a, b):
"""Compute integral int_a^b t * psi(t)."""
raise NotImplementedError

def integrate_psi(self, a, b):
"""Compute integral int_a^b psi(t)."""
raise NotImplementedError

class GaussianBasisFunctions(BasisFunctions):
"""Function phi(t) = Gaussian(t; mu, sigma_sq)."""

def __init__(self, mu, sigma):
self.mu = mu.unsqueeze(0)
self.sigma = sigma.unsqueeze(0)

def __repr__(self):
return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"

def __len__(self):
"""Number of basis functions."""
return self.mu.size(1)

def _phi(self, t):
return 1.0 / math.sqrt(2 * math.pi) * torch.exp(-0.5 * t ** 2)

def _Phi(self, t):
return 0.5 * (1 + torch.erf(t / math.sqrt(2)))

    def _integrate_product_of_gaussians(self, mu, sigma_sq):
        # int N(t; mu, sigma_sq) * N(t; self.mu, self.sigma**2) dt in closed form
        sigma = torch.sqrt(self.sigma ** 2 + sigma_sq)
        return self._phi((mu - self.mu) / sigma) / sigma

def evaluate(self, t):
return self._phi((t - self.mu) / self.sigma) / self.sigma

def integrate_t2_times_psi(self, a, b):
"""Compute integral int_a^b (t**2) * psi(t)."""
return (
(self.mu ** 2 + self.sigma ** 2)
* (
self._Phi((b - self.mu) / self.sigma)
- self._Phi((a - self.mu) / self.sigma)
)
- (
self.sigma
* (b + self.mu)
* self._phi((b - self.mu) / self.sigma)
)
+ (
self.sigma
* (a + self.mu)
* self._phi((a - self.mu) / self.sigma)
)
)

def integrate_t_times_psi(self, a, b):
"""Compute integral int_a^b t * psi(t)."""
return self.mu * (
self._Phi((b - self.mu) / self.sigma)
- self._Phi((a - self.mu) / self.sigma)
) - self.sigma * (
self._phi((b - self.mu) / self.sigma)
- self._phi((a - self.mu) / self.sigma)
)

def integrate_psi(self, a, b):
"""Compute integral int_a^b psi(t)."""
return self._Phi((b - self.mu) / self.sigma) - self._Phi(
(a - self.mu) / self.sigma
)

def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
"""Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t)."""
S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / (
self.sigma ** 2 + sigma_sq
)
sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / (
self.sigma ** 2 + sigma_sq
)
return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde)

def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
"""Compute integral int N(t; mu, sigma_sq) * t * psi(t)."""
S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / (
self.sigma ** 2 + sigma_sq
)
return S_tilde * mu_tilde

def integrate_psi_gaussian(self, mu, sigma_sq):
"""Compute integral int N(t; mu, sigma_sq) * psi(t)."""
return self._integrate_product_of_gaussians(mu, sigma_sq)
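
A minimal usage sketch (not part of the commit, and assuming it is run from inside time_warping/) showing how these basis functions are evaluated, with a trapezoidal check of one closed-form integral; the grid size and tolerance are arbitrary choices:

import torch
from basis_functions import GaussianBasisFunctions

# K Gaussian basis functions with means spread over [0, 1] and a shared bandwidth.
K = 8
mu = torch.linspace(0, 1, K)
sigma = torch.full((K,), 0.1)
psi = GaussianBasisFunctions(mu, sigma)

t = torch.linspace(0, 1, 500).unsqueeze(1)  # (500, 1) evaluation points
values = psi.evaluate(t)                    # (500, K) basis matrix

# closed-form int_0^1 psi(t) dt vs. a trapezoidal estimate
exact = psi.integrate_psi(torch.tensor(0.0), torch.tensor(1.0))
approx = torch.trapz(values, t.squeeze(1), dim=0)
print(torch.allclose(exact.squeeze(0), approx, atol=1e-3))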
114 changes: 114 additions & 0 deletions time_warping/data.py
@@ -0,0 +1,114 @@
import torch
import math
from utils import create_psi
from torch.utils.data import Dataset
import torch.nn.functional as F

class DATASET(Dataset):
    def __init__(self, N, T, nb_basis=64):
        self.X = []
        self.y = []
        self.N = N
        self.T = T
        self.nb_basis = nb_basis
        for i in range(N):
            # sample the warp rate l and signal coefficients Z (using torch and not numpy)
            l = torch.rand(1) * 20
            Z = torch.rand(4) * 8 - 4
            X_obs, X_orig = self.generate_trajectory(torch.linspace(0, 1, T), l, Z)
            self.X.append(X_obs)
            pk = self.class_probs(Z)
            self.y.append((pk[0] > 0.5).float())
        self.X = torch.stack(self.X)
        self.y = torch.stack(self.y)
        # project each trajectory onto the basis: B = X G
        G, _ = self.get_G(self.X.shape[1], nb_basis)
        self.B = self.X.matmul(G)

    def integral(self, l, t):
        # int_0^t exp(-l s) ds = (1 - exp(-l t)) / l
        return (1 - torch.exp(-l * t)) / l

    def g(self, l, t):
        # monotone time warp g(l, t) = (1 - exp(-l t)) / (1 - exp(-l)),
        # normalized so that g(l, 0) = 0 and g(l, 1) = 1
        C = l / (1 - torch.exp(-l))
        return C * self.integral(l, t)

    def g_deriv(self, l, t):
        # d/dt g(l, t)
        C = l / (1 - torch.exp(-l))
        return C * torch.exp(-l * t)

def sigmoid(self,t):
return 1 / (1 + torch.exp(-t))

    def f(self, t, Z):
        # piecewise signal on [0, 1] with four regimes weighted by Z
        return (Z[0] * torch.cos(9 * math.pi * t) * (t < 0.25)
                + Z[1] * (t ** 2) * (t >= 0.25) * (t < 0.5)
                + Z[2] * torch.sin(t) * (t >= 0.5) * (t < 0.75)
                + Z[3] * torch.cos(17 * math.pi * t) * (t >= 0.75) * (t <= 1))

    def class_probs(self, Z):
        # closed-form integrals of f over [0, 0.5] and [0.5, 1], used as logits
        c1 = 2 * (Z[0] * math.sin(9 * math.pi * 0.25) / (9 * math.pi)
                  + Z[1] * (0.5 ** 3 / 3 - 0.25 ** 3 / 3))
        c2 = (-2 * Z[2] * (math.cos(0.75) - math.cos(0.5))
              + 2 * Z[3] * (math.sin(17 * math.pi) / (17 * math.pi)
                            - math.sin(17 * math.pi * 0.75) / (17 * math.pi)))
        return F.softmax(torch.stack([c1, c2]), dim=0)

    def generate_trajectory(self, t, l, Z):
        # the observed trajectory is the signal composed with the inverse warp
        inv_warp = self.g(l, t)
        X_obs = self.f(inv_warp, Z)
        X_orig = self.f(t, Z)
        return X_obs, X_orig

    def p_i(self, l, t, c):
        # density of the warped time under class c
        if c == 0:
            return 2 * (self.g(l, t) < 0.5) * self.g_deriv(l, t)
        else:
            return 2 * (self.g(l, t) >= 0.5) * self.g_deriv(l, t)

    def get_G(self, max_length, nb_basis):
        # evaluate the basis at the midpoints of max_length evenly spaced bins
        psi = create_psi(max_length, nb_basis)
        shift = 1 / float(2 * max_length)
        positions = torch.linspace(shift, 1 - shift, max_length).unsqueeze(1)
        all_basis = [basis_function.evaluate(positions)
                     for basis_function in psi]
        F_mat = torch.cat(all_basis, dim=-1).t()
        nb_basis = sum(len(b) for b in psi)
        assert F_mat.size(0) == nb_basis

        # compute G with a ridge penalty, so B = X G gives basis coefficients
        penalty = 0.01
        I = torch.eye(nb_basis)
        G = F_mat.t().matmul((F_mat.matmul(F_mat.t()) + penalty * I).inverse())
        return G, F_mat

def __len__(self):
return len(self.y)

def __getitem__(self, idx):
return self.X[idx,:], self.B[idx,:], self.y[idx]

if __name__ == '__main__':
    T = 95
    N = 5
    dataset = DATASET(N, T)
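
A quick sketch (not part of the commit) of what the basis projection buys: with B = X G, multiplying back by the basis matrix F approximately reconstructs the trajectories; the seed and sizes here are arbitrary:

import torch
from data import DATASET

torch.manual_seed(0)
dataset = DATASET(10, 95, nb_basis=64)
G, F_mat = dataset.get_G(dataset.T, dataset.nb_basis)
X_hat = dataset.B.matmul(F_mat)  # (N, T) reconstruction from coefficients
print((X_hat - dataset.X).abs().max())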
51 changes: 51 additions & 0 deletions time_warping/main.py
@@ -0,0 +1,51 @@
import lightning as pl
from torch.utils.data import random_split, DataLoader
from data import DATASET
from model import MODEL
import mlflow

if __name__ == "__main__":
    params = {
        'N': 10000,
        'T': 95,
        'bs': 25,
        'nb_basis': 64,
        'heads': 2,
        'lr': 1e-5,            # 1e-5 seems to work
        'hidden_size': 128,    # encoder hidden size
        'output_size': 128,    # encoder output size
        'inducing_points': 128,
        'attention_mechanism': 'kernel_sparsemax',
        'scheduler': 'ReduceLROnPlateau',
        'optimizer': 'RAdam'
    }
    #initialize dataset
    dataset = DATASET(params['N'], params['T'], params['nb_basis'])
    #split dataset into 0.75/0.15/0.1 train/val/test using random_split
    train, val, test = random_split(dataset, [0.75, 0.15, 0.1])
    #create train/val/test dataloaders, each with batch size bs
    train_loader = DataLoader(train, batch_size=params['bs'])
    val_loader = DataLoader(val, batch_size=params['bs'])
    test_loader = DataLoader(test, batch_size=params['bs'])
    #grab a single batch from the train_loader to sanity-check shapes
    x, B, y = next(iter(train_loader))
    #kernel sparsemax isn't working well
    model = MODEL(params['T'], params['hidden_size'], params['output_size'],
                  params['heads'], params['nb_basis'], params['inducing_points'],
                  params['attention_mechanism'], params['heads'],
                  params['optimizer'], params['lr'], params['scheduler'])
#initialize trainer
    trainer = pl.Trainer(max_epochs=10, gradient_clip_val=0.1, precision=32)
mlflow.pytorch.autolog()
with mlflow.start_run():
#mlflow log params
mlflow.log_params(params)
#train model
trainer.fit(model, train_loader, val_loader)
#test model
trainer.test(model, test_loader)
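
With autologging enabled, finished runs can be compared afterwards, e.g. the multimodal vs. unimodal variants mentioned in the commit message. A hedged sketch against the default local tracking store; "metrics.val_loss" is an assumed metric name, not one confirmed by this commit:

import mlflow

# list logged runs, sorted by an assumed validation-loss metric
runs = mlflow.search_runs(order_by=["metrics.val_loss ASC"])
print(runs[["run_id", "params.attention_mechanism"]].head())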
