Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added some time warping experiments, couldnt get them to exactly matc…
…h paper but the multimodal is outperforming unimodal
- Loading branch information
Showing
6 changed files
with
579 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
*.png | ||
time_warping/mlruns | ||
time_warping/lightning_logs | ||
time_warping/__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
import math | ||
#This code was lifted from Andre Martins' group's codebase that they sent to me via email in an early stage. | ||
|
||
class BasisFunctions(object): | ||
def __len__(self): | ||
"""Number of basis functions.""" | ||
raise NotImplementedError | ||
|
||
def evaluate(self, t): | ||
raise NotImplementedError | ||
|
||
def integrate_t2_times_psi(self, a, b): | ||
"""Compute integral int_a^b (t**2) * psi(t).""" | ||
raise NotImplementedError | ||
|
||
def integrate_t_times_psi(self, a, b): | ||
"""Compute integral int_a^b t * psi(t).""" | ||
raise NotImplementedError | ||
|
||
def integrate_psi(self, a, b): | ||
"""Compute integral int_a^b psi(t).""" | ||
raise NotImplementedError | ||
|
||
class GaussianBasisFunctions(BasisFunctions): | ||
"""Function phi(t) = Gaussian(t; mu, sigma_sq).""" | ||
|
||
def __init__(self, mu, sigma): | ||
self.mu = mu.unsqueeze(0) | ||
self.sigma = sigma.unsqueeze(0) | ||
|
||
def __repr__(self): | ||
return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})" | ||
|
||
def __len__(self): | ||
"""Number of basis functions.""" | ||
return self.mu.size(1) | ||
|
||
def _phi(self, t): | ||
return 1.0 / math.sqrt(2 * math.pi) * torch.exp(-0.5 * t ** 2) | ||
|
||
def _Phi(self, t): | ||
return 0.5 * (1 + torch.erf(t / math.sqrt(2))) | ||
|
||
def _integrate_product_of_gaussians(self, mu, sigma_sq): | ||
sigma = torch.sqrt(self.sigma ** 2 + sigma_sq) | ||
return self._phi((mu - self.mu) / sigma) / sigma | ||
|
||
def evaluate(self, t): | ||
return self._phi((t - self.mu) / self.sigma) / self.sigma | ||
|
||
def integrate_t2_times_psi(self, a, b): | ||
"""Compute integral int_a^b (t**2) * psi(t).""" | ||
return ( | ||
(self.mu ** 2 + self.sigma ** 2) | ||
* ( | ||
self._Phi((b - self.mu) / self.sigma) | ||
- self._Phi((a - self.mu) / self.sigma) | ||
) | ||
- ( | ||
self.sigma | ||
* (b + self.mu) | ||
* self._phi((b - self.mu) / self.sigma) | ||
) | ||
+ ( | ||
self.sigma | ||
* (a + self.mu) | ||
* self._phi((a - self.mu) / self.sigma) | ||
) | ||
) | ||
|
||
def integrate_t_times_psi(self, a, b): | ||
"""Compute integral int_a^b t * psi(t).""" | ||
return self.mu * ( | ||
self._Phi((b - self.mu) / self.sigma) | ||
- self._Phi((a - self.mu) / self.sigma) | ||
) - self.sigma * ( | ||
self._phi((b - self.mu) / self.sigma) | ||
- self._phi((a - self.mu) / self.sigma) | ||
) | ||
|
||
def integrate_psi(self, a, b): | ||
"""Compute integral int_a^b psi(t).""" | ||
return self._Phi((b - self.mu) / self.sigma) - self._Phi( | ||
(a - self.mu) / self.sigma | ||
) | ||
|
||
def integrate_t2_times_psi_gaussian(self, mu, sigma_sq): | ||
"""Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t).""" | ||
S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq) | ||
mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / ( | ||
self.sigma ** 2 + sigma_sq | ||
) | ||
sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / ( | ||
self.sigma ** 2 + sigma_sq | ||
) | ||
return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde) | ||
|
||
def integrate_t_times_psi_gaussian(self, mu, sigma_sq): | ||
"""Compute integral int N(t; mu, sigma_sq) * t * psi(t).""" | ||
S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq) | ||
mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / ( | ||
self.sigma ** 2 + sigma_sq | ||
) | ||
return S_tilde * mu_tilde | ||
|
||
def integrate_psi_gaussian(self, mu, sigma_sq): | ||
"""Compute integral int N(t; mu, sigma_sq) * psi(t).""" | ||
return self._integrate_product_of_gaussians(mu, sigma_sq) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import lightning as pl | ||
import torch | ||
import pdb | ||
import numpy as np | ||
import math | ||
from utils import create_psi | ||
from torch.utils.data import Dataset, DataLoader | ||
from matplotlib import pyplot as plt | ||
import torch.nn.functional as F | ||
|
||
class DATASET(Dataset): | ||
def __init__(self, N,T,nb_basis=64): | ||
self.X = [] | ||
self.y = [] | ||
self.N = N | ||
self.T = T | ||
self.nb_basis = nb_basis | ||
Z_all = [] | ||
l_all = [] | ||
for i in range(N): | ||
#using torch and not numpy | ||
l = torch.rand(1)*20 | ||
Z = torch.rand(4)*8-4 | ||
l_all.append(l) | ||
Z_all.append(Z) | ||
X_obs,X_orig = self.generate_trajectory(torch.linspace(0,1,T),l,Z) | ||
self.X.append(X_obs) | ||
for i in range(N): | ||
Z = Z_all[i] | ||
pk = self.class_probs(Z) | ||
y = (pk[0]>0.5).float() | ||
self.y.append(y) | ||
self.X = torch.stack(self.X) | ||
self.y = torch.stack(self.y) | ||
G,F = self.get_G(self.X.shape[1],nb_basis) | ||
B = self.X.matmul(G) | ||
self.B = B | ||
|
||
def integral(self,l,t): | ||
return (1-torch.exp(-l*t))/l | ||
|
||
def g(self,l,t): | ||
C = l/(1-torch.exp(-l)) | ||
return C*self.integral(l,t) | ||
|
||
def g_deriv(self,l,t): | ||
C = l/(1-torch.exp(-l)) | ||
return C*torch.exp(-l*t) | ||
|
||
def sigmoid(self,t): | ||
return 1 / (1 + torch.exp(-t)) | ||
|
||
def f(self,t,Z): | ||
return Z[0]*torch.cos(9*math.pi*t)*(t<0.25)+Z[1]*(t**2)*(t>=0.25)*(t<0.5)+Z[2]*torch.sin(t)*(t>=0.5)*(t<0.75)+Z[3]*(torch.cos(17*math.pi*t))*(t>=0.75)*(t<=1) | ||
|
||
def class_probs(self,Z): | ||
c1 = 2*(Z[0]*math.sin(9*math.pi*0.25)/(9*math.pi)+Z[1]*(0.5**3/3-0.25**3/3)) | ||
c2 = -2*Z[2]*(math.cos(0.75)-math.cos(0.5))+2*Z[3]*(np.sin(17*math.pi)/(17*math.pi)-np.sin(17*math.pi*0.75)/(17*math.pi)) | ||
#return softmax applied to a torch array formed from c1 and c2 | ||
return F.softmax(torch.tensor([c1,c2])) | ||
|
||
def generate_trajectory(self,t,l,Z): | ||
inv_warp = self.g(l,t) | ||
X_obs = self.f(inv_warp,Z) | ||
X_orig = self.f(t,Z) | ||
return X_obs,X_orig | ||
|
||
def p_i(self,l,t,c): | ||
C = l/(1-torch.exp(-l)) | ||
if c==0: | ||
return 2*(self.g(l,t)<0.5)*self.g_deriv(l,t) | ||
else: | ||
return 2*(self.g(l,t)>=0.5)*self.g_deriv(l,t) | ||
|
||
def get_G(self, max_length,nb_basis): | ||
psis=[] | ||
Gs = [] | ||
|
||
for length in range(1,max_length+1): | ||
psi = create_psi(length,nb_basis) | ||
shift = 1 / float(2 * length) | ||
positions = torch.linspace(shift, 1 - shift, length) | ||
positions = positions.unsqueeze(1) | ||
all_basis = [basis_function.evaluate(positions) | ||
for basis_function in psi] | ||
F = torch.cat(all_basis, dim=-1).t() | ||
nb_basis = sum([len(b) for b in psi]) | ||
assert F.size(0) == nb_basis | ||
|
||
# compute G with a ridge penalty | ||
penalty = 0.01 | ||
I = torch.eye(nb_basis) | ||
G = F.t().matmul((F.matmul(F.t()) + penalty * I).inverse()) | ||
psis.append(psi) | ||
Gs.append(G) | ||
G = Gs[max_length-1] | ||
return G,F | ||
|
||
def __len__(self): | ||
return len(self.y) | ||
|
||
def __getitem__(self, idx): | ||
return self.X[idx,:], self.B[idx,:], self.y[idx] | ||
|
||
if __name__=='__main__': | ||
X_obs_all = [] | ||
X_orig_all = [] | ||
l_all = [] | ||
Z_all = [] | ||
max_l=25 | ||
max_l=20 | ||
T = 95 | ||
N = 5 | ||
dataset = DATASET(N,T) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import lightning as pl | ||
from torch.utils.data import random_split | ||
from torch.utils.data import DataLoader | ||
from data import DATASET | ||
from matplotlib import pyplot as plt | ||
from model import MODEL | ||
import pdb | ||
import mlflow | ||
|
||
if __name__=="__main__": | ||
#1e-5 seems to work | ||
#encoder hidden size | ||
hidden_size = 128 | ||
#encoder output size | ||
output_size = 128 | ||
params = { | ||
'N':10000, | ||
'T':95, | ||
'bs':25, | ||
'nb_basis':64, | ||
'heads':2, | ||
'lr':1e-5, | ||
'hidden_size':128, | ||
'output_size':128, | ||
'inducing_points':128, | ||
'attention_mechanism':'kernel_sparsemax', | ||
'scheduler':'ReduceLROnPlateau', | ||
'optimizer':'RAdam' | ||
} | ||
#initialize dataset | ||
dataset = DATASET(params['N'],params['T'],params['nb_basis']) | ||
#split dataset into 0.75, 0.15, 0.1 using random split | ||
train, val, test = random_split(dataset, [0.75,0.15,0.1]) | ||
#create train/val/test dataloaders, each with batch size bs | ||
train_loader = DataLoader(train, batch_size=params['bs']) | ||
val_loader = DataLoader(val, batch_size=params['bs']) | ||
test_loader = DataLoader(test, batch_size=params['bs']) | ||
#get a single example from the train_loader | ||
x,B,y = next(iter(train_loader)) | ||
#kernel sparsemax isn't working well | ||
model = MODEL(params['T'], params['hidden_size'], params['output_size'], params['heads'], params['nb_basis'], params['inducing_points'], params['attention_mechanism'], params['heads'], params['optimizer'], params['lr'],params['scheduler']) | ||
#initialize trainer | ||
trainer = pl.Trainer(max_epochs=10,gradient_clip_val=0.1,precision=32) | ||
mlflow.pytorch.autolog() | ||
with mlflow.start_run(): | ||
#mlflow log params | ||
mlflow.log_params(params) | ||
#train model | ||
trainer.fit(model, train_loader, val_loader) | ||
#test model | ||
trainer.test(model, test_loader) |
Oops, something went wrong.