diff --git a/.gitignore b/.gitignore
index e33609d..1778d3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,4 @@
 *.png
+time_warping/mlruns
+time_warping/lightning_logs
+time_warping/__pycache__
\ No newline at end of file
diff --git a/time_warping/basis_functions.py b/time_warping/basis_functions.py
new file mode 100644
index 0000000..a397126
--- /dev/null
+++ b/time_warping/basis_functions.py
@@ -0,0 +1,109 @@
+import torch
+import math
+#This code was lifted from Andre Martins' group's codebase, which they sent to me via email at an early stage.
+
+class BasisFunctions(object):
+    def __len__(self):
+        """Number of basis functions."""
+        raise NotImplementedError
+
+    def evaluate(self, t):
+        raise NotImplementedError
+
+    def integrate_t2_times_psi(self, a, b):
+        """Compute integral int_a^b (t**2) * psi(t)."""
+        raise NotImplementedError
+
+    def integrate_t_times_psi(self, a, b):
+        """Compute integral int_a^b t * psi(t)."""
+        raise NotImplementedError
+
+    def integrate_psi(self, a, b):
+        """Compute integral int_a^b psi(t)."""
+        raise NotImplementedError
+
+class GaussianBasisFunctions(BasisFunctions):
+    """Function phi(t) = Gaussian(t; mu, sigma_sq)."""
+
+    def __init__(self, mu, sigma):
+        self.mu = mu.unsqueeze(0)
+        self.sigma = sigma.unsqueeze(0)
+
+    def __repr__(self):
+        return f"GaussianBasisFunction(mu={self.mu}, sigma={self.sigma})"
+
+    def __len__(self):
+        """Number of basis functions."""
+        return self.mu.size(1)
+
+    def _phi(self, t):
+        return 1.0 / math.sqrt(2 * math.pi) * torch.exp(-0.5 * t ** 2)
+
+    def _Phi(self, t):
+        return 0.5 * (1 + torch.erf(t / math.sqrt(2)))
+
+    def _integrate_product_of_gaussians(self, mu, sigma_sq):
+        sigma = torch.sqrt(self.sigma ** 2 + sigma_sq)
+        return self._phi((mu - self.mu) / sigma) / sigma
+
+    def evaluate(self, t):
+        return self._phi((t - self.mu) / self.sigma) / self.sigma
+
+    def integrate_t2_times_psi(self, a, b):
+        """Compute integral int_a^b (t**2) * psi(t)."""
+        return (
+            (self.mu ** 2 + self.sigma ** 2)
+            * (
+                self._Phi((b - self.mu) / self.sigma)
+                - self._Phi((a - self.mu) / self.sigma)
+            )
+            - (
+                self.sigma
+                * (b + self.mu)
+                * self._phi((b - self.mu) / self.sigma)
+            )
+            + (
+                self.sigma
+                * (a + self.mu)
+                * self._phi((a - self.mu) / self.sigma)
+            )
+        )
+
+    def integrate_t_times_psi(self, a, b):
+        """Compute integral int_a^b t * psi(t)."""
+        return self.mu * (
+            self._Phi((b - self.mu) / self.sigma)
+            - self._Phi((a - self.mu) / self.sigma)
+        ) - self.sigma * (
+            self._phi((b - self.mu) / self.sigma)
+            - self._phi((a - self.mu) / self.sigma)
+        )
+
+    def integrate_psi(self, a, b):
+        """Compute integral int_a^b psi(t)."""
+        return self._Phi((b - self.mu) / self.sigma) - self._Phi(
+            (a - self.mu) / self.sigma
+        )
+
+    def integrate_t2_times_psi_gaussian(self, mu, sigma_sq):
+        """Compute integral int N(t; mu, sigma_sq) * t**2 * psi(t)."""
+        S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
+        mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / (
+            self.sigma ** 2 + sigma_sq
+        )
+        sigma_sq_tilde = ((self.sigma ** 2) * sigma_sq) / (
+            self.sigma ** 2 + sigma_sq
+        )
+        return S_tilde * (mu_tilde ** 2 + sigma_sq_tilde)
+
+    def integrate_t_times_psi_gaussian(self, mu, sigma_sq):
+        """Compute integral int N(t; mu, sigma_sq) * t * psi(t)."""
+        S_tilde = self._integrate_product_of_gaussians(mu, sigma_sq)
+        mu_tilde = (self.mu * sigma_sq + mu * self.sigma ** 2) / (
+            self.sigma ** 2 + sigma_sq
+        )
+        return S_tilde * mu_tilde
+
+    def integrate_psi_gaussian(self, mu, sigma_sq):
+        """Compute integral int N(t; mu, sigma_sq) * psi(t)."""
+        return self._integrate_product_of_gaussians(mu, sigma_sq)
\ No newline at end of file
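Aside (hypothetical, not part of the patch): a minimal smoke test for the Gaussian basis API added above, assuming 8 centres on [0, 1] with a shared width of 0.1.

# hypothetical usage sketch -- not part of the patch
import torch
from basis_functions import GaussianBasisFunctions

mu = torch.linspace(0, 1, 8)         # 8 basis centres on [0, 1]
sigma = torch.full((8,), 0.1)        # shared width for every basis function
psi = GaussianBasisFunctions(mu, sigma)

t = torch.linspace(0, 1, 100).unsqueeze(1)  # column of evaluation points
values = psi.evaluate(t)                    # shape (100, 8): each column is one basis function
mass = psi.integrate_psi(0.0, 1.0)          # analytic integral of each basis function over [0, 1]
print(values.shape, mass)                   # mass ~ 1 for interior centres, ~ 0.5 at the endpoints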
diff --git a/time_warping/data.py b/time_warping/data.py
new file mode 100644
index 0000000..74fecf5
--- /dev/null
+++ b/time_warping/data.py
@@ -0,0 +1,114 @@
+import lightning as pl
+import torch
+import pdb
+import numpy as np
+import math
+from utils import create_psi
+from torch.utils.data import Dataset, DataLoader
+from matplotlib import pyplot as plt
+import torch.nn.functional as F
+
+class DATASET(Dataset):
+    def __init__(self, N, T, nb_basis=64):
+        self.X = []
+        self.y = []
+        self.N = N
+        self.T = T
+        self.nb_basis = nb_basis
+        Z_all = []
+        l_all = []
+        for i in range(N):
+            #sample the warp rate l and segment amplitudes Z with torch (not numpy)
+            l = torch.rand(1)*20
+            Z = torch.rand(4)*8-4
+            l_all.append(l)
+            Z_all.append(Z)
+            X_obs, X_orig = self.generate_trajectory(torch.linspace(0, 1, T), l, Z)
+            self.X.append(X_obs)
+        for i in range(N):
+            Z = Z_all[i]
+            pk = self.class_probs(Z)
+            y = (pk[0]>0.5).float()
+            self.y.append(y)
+        self.X = torch.stack(self.X)
+        self.y = torch.stack(self.y)
+        G, F = self.get_G(self.X.shape[1], nb_basis)
+        B = self.X.matmul(G)
+        self.B = B
+
+    def integral(self, l, t):
+        return (1-torch.exp(-l*t))/l
+
+    def g(self, l, t):
+        C = l/(1-torch.exp(-l))
+        return C*self.integral(l, t)
+
+    def g_deriv(self, l, t):
+        C = l/(1-torch.exp(-l))
+        return C*torch.exp(-l*t)
+
+    def sigmoid(self, t):
+        return 1 / (1 + torch.exp(-t))
+
+    def f(self, t, Z):
+        return Z[0]*torch.cos(9*math.pi*t)*(t<0.25)+Z[1]*(t**2)*(t>=0.25)*(t<0.5)+Z[2]*torch.sin(t)*(t>=0.5)*(t<0.75)+Z[3]*(torch.cos(17*math.pi*t))*(t>=0.75)*(t<=1)
+
+    def class_probs(self, Z):
+        c1 = 2*(Z[0]*math.sin(9*math.pi*0.25)/(9*math.pi)+Z[1]*(0.5**3/3-0.25**3/3))
+        c2 = -2*Z[2]*(math.cos(0.75)-math.cos(0.5))+2*Z[3]*(np.sin(17*math.pi)/(17*math.pi)-np.sin(17*math.pi*0.75)/(17*math.pi))
+        #return softmax applied to a torch array formed from c1 and c2
+        return F.softmax(torch.tensor([c1, c2]), dim=0)
+
+    def generate_trajectory(self, t, l, Z):
+        inv_warp = self.g(l, t)
+        X_obs = self.f(inv_warp, Z)
+        X_orig = self.f(t, Z)
+        return X_obs, X_orig
+
+    def p_i(self, l, t, c):
+        C = l/(1-torch.exp(-l))
+        if c==0:
+            return 2*(self.g(l, t)<0.5)*self.g_deriv(l, t)
+        else:
+            return 2*(self.g(l, t)>=0.5)*self.g_deriv(l, t)
+
+    def get_G(self, max_length, nb_basis):
+        psis = []
+        Gs = []
+
+        for length in range(1, max_length+1):
+            psi = create_psi(length, nb_basis)
+            shift = 1 / float(2 * length)
+            positions = torch.linspace(shift, 1 - shift, length)
+            positions = positions.unsqueeze(1)
+            all_basis = [basis_function.evaluate(positions)
+                         for basis_function in psi]
+            F = torch.cat(all_basis, dim=-1).t()
+            nb_basis = sum([len(b) for b in psi])
+            assert F.size(0) == nb_basis
+
+            # compute G with a ridge penalty
+            penalty = 0.01
+            I = torch.eye(nb_basis)
+            G = F.t().matmul((F.matmul(F.t()) + penalty * I).inverse())
+            psis.append(psi)
+            Gs.append(G)
+        G = Gs[max_length-1]
+        return G, F
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, idx):
+        return self.X[idx, :], self.B[idx, :], self.y[idx]
+
+if __name__=='__main__':
+    X_obs_all = []
+    X_orig_all = []
+    l_all = []
+    Z_all = []
+    max_l = 25
+    max_l = 20
+    T = 95
+    N = 5
+    dataset = DATASET(N, T)
\ No newline at end of file
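For reference (hypothetical, not part of the patch): a small sketch of how the dataset above can be exercised, assuming a tiny N so that the ridge projections in get_G stay cheap.

# hypothetical usage sketch -- not part of the patch
from data import DATASET

dataset = DATASET(N=8, T=95, nb_basis=64)
x, B, y = dataset[0]
print(x.shape)   # (95,)  observed (time-warped) trajectory
print(B.shape)   # (64,)  coefficients of x in the Gaussian basis (one row of X @ G)
print(y)         # 0./1. class label derived from class_probs(Z)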
diff --git a/time_warping/main.py b/time_warping/main.py
new file mode 100644
index 0000000..87e1fe0
--- /dev/null
+++ b/time_warping/main.py
@@ -0,0 +1,51 @@
+import lightning as pl
+from torch.utils.data import random_split
+from torch.utils.data import DataLoader
+from data import DATASET
+from matplotlib import pyplot as plt
+from model import MODEL
+import pdb
+import mlflow
+
+if __name__=="__main__":
+    #a learning rate of 1e-5 seems to work
+    #encoder hidden size
+    hidden_size = 128
+    #encoder output size
+    output_size = 128
+    params = {
+        'N':10000,
+        'T':95,
+        'bs':25,
+        'nb_basis':64,
+        'heads':2,
+        'lr':1e-5,
+        'hidden_size':128,
+        'output_size':128,
+        'inducing_points':128,
+        'attention_mechanism':'kernel_sparsemax',
+        'scheduler':'ReduceLROnPlateau',
+        'optimizer':'RAdam'
+    }
+    #initialize dataset
+    dataset = DATASET(params['N'], params['T'], params['nb_basis'])
+    #split dataset into 0.75/0.15/0.1 train/val/test fractions using random_split
+    train, val, test = random_split(dataset, [0.75, 0.15, 0.1])
+    #create train/val/test dataloaders, each with batch size bs
+    train_loader = DataLoader(train, batch_size=params['bs'])
+    val_loader = DataLoader(val, batch_size=params['bs'])
+    test_loader = DataLoader(test, batch_size=params['bs'])
+    #get a single batch from the train_loader
+    x, B, y = next(iter(train_loader))
+    #kernel sparsemax isn't working well
+    model = MODEL(params['T'], params['hidden_size'], params['output_size'], params['heads'], params['nb_basis'], params['inducing_points'], params['attention_mechanism'], params['heads'], params['optimizer'], params['lr'], params['scheduler'])  #num_classes happens to equal heads (2) here
+    #initialize trainer
+    trainer = pl.Trainer(max_epochs=10, gradient_clip_val=0.1, precision=32)
+    mlflow.pytorch.autolog()
+    with mlflow.start_run():
+        #log the run's hyperparameters to mlflow
+        mlflow.log_params(params)
+        #train model
+        trainer.fit(model, train_loader, val_loader)
+        #test model
+        trainer.test(model, test_loader)
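Aside (hypothetical, not part of the patch): a forward-pass smoke test for the MODEL class defined in model.py below, using random tensors and the 'cts_softmax' attention variant rather than the full training script.

# hypothetical smoke test -- not part of the patch
import torch
from model import MODEL

T, nb_basis, heads, bs = 95, 64, 2, 4
model = MODEL(T, 128, 128, heads, nb_basis, 128, 'cts_softmax', heads)
x = torch.randn(bs, T)         # fake trajectories
B = torch.randn(bs, nb_basis)  # fake basis coefficients (normally rows of X @ G from data.DATASET)
logits = model(x, B)           # shape (bs, heads); heads doubles as the number of classes
print(logits.shape)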
diff --git a/time_warping/model.py b/time_warping/model.py
new file mode 100644
index 0000000..3f3cc69
--- /dev/null
+++ b/time_warping/model.py
@@ -0,0 +1,248 @@
+import lightning as pl
+import torch
+from torch import nn
+from utils import add_gaussian_basis_functions, create_psi, _phi, exp_kernel, beta_exp, truncated_parabola
+import math
+import pdb
+from lightning.pytorch.utilities import grad_norm
+
+
+class FeedforwardEncoder(torch.nn.Module):
+    def __init__(self, input_dim, hidden_dim, output_dim):
+        super().__init__()
+        #add dimensions as attributes of class
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.output_dim = output_dim
+        #define the encoder
+        self.encoder = torch.nn.Sequential(
+            torch.nn.Linear(input_dim, hidden_dim),
+            torch.nn.ReLU(),
+            torch.nn.Linear(hidden_dim, output_dim)
+        )
+
+    def forward(self, x):
+        return self.encoder(x)
+
+class Attention(torch.nn.Module):
+    def __init__(self, input_dim, heads, nb_basis, inducing_points, method='cts_softmax'):
+        super(Attention, self).__init__()
+        '''
+        @param input_dim: int, dimension of input, which is actually the output dimension of the encoder
+        @param heads: int, number of heads
+        @param nb_basis: int, number of basis functions
+        @param inducing_points: int, number of inducing points
+        @param method: str, method to use for attention
+        '''
+        self.heads = heads
+        self.inducing_points = inducing_points
+        self.method = method
+        self.inducing_locations = torch.linspace(0, 1, inducing_points)
+        #map input to mu, sigma_sq, alpha
+        self.encode_mu = torch.nn.Linear(input_dim, heads)
+        self.encode_sigma_sq1 = torch.nn.Linear(input_dim, heads)
+        self.encode_sigma_sq2 = torch.nn.Softplus()
+        self.encode_alpha = torch.nn.Linear(input_dim, heads*inducing_points)
+        self.attn_weights = torch.nn.Softmax(1)
+        #add basis functions
+        GB = add_gaussian_basis_functions(nb_basis,
+                                          sigmas=[.1, .5])
+        #add basis functions as attributes of class
+        self.mu_basis = GB.mu.unsqueeze(0)
+        self.sigma_basis = GB.sigma.unsqueeze(0)
+        #initialize a,b,dx,bandwidth
+        self.a = 0
+        self.b = 1
+        self.dx = 100
+        self.bandwidth = 0.05
+
+    def _integrate_wrt_kernel_deformed(self, mu, sigma_sq, alpha):
+        T = torch.linspace(self.a, self.b, self.dx).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        #phi1_upper: size 1 x 1 x nb x dx
+        phi1_upper = self.mu_basis.unsqueeze(-1)-T
+        #phi1_lower: size 1 x 1 x nb x 1
+        phi1_lower = self.sigma_basis.unsqueeze(-1)
+        #phi1: size 1 x 1 x nb x dx
+        phi1 = _phi(phi1_upper/phi1_lower)/phi1_lower
+        K_inputs = torch.cdist(self.inducing_locations.unsqueeze(-1), torch.linspace(self.a, self.b, self.dx).unsqueeze(-1))
+        K = exp_kernel(K_inputs, self.bandwidth)
+        f = torch.matmul(alpha, K).unsqueeze(-2)#-0.5*(mu.unsqueeze(-1).unsqueeze(-1)-T)**2/sigma_sq.unsqueeze(-1).unsqueeze(-1)
+        f_max = torch.max(f, dim=-1, keepdim=True).values
+        exp_terms = beta_exp(f-f_max)
+        Z = torch.trapz(exp_terms, torch.linspace(self.a, self.b, self.dx), dim=-1).unsqueeze(-1)
+        numerical_integral = torch.trapz(phi1*exp_terms/Z, torch.linspace(self.a, self.b, self.dx), dim=-1)
+        return numerical_integral
+
+    def _integrate_wrt_truncated_parabaloid(self, mu, sigma_sq):
+        T = torch.linspace(self.a, self.b, self.dx).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        #phi1_upper: size 1 x 1 x nb x dx
+        phi1_upper = self.mu_basis.unsqueeze(-1)-T
+        #phi1_lower: size 1 x 1 x nb x 1
+        phi1_lower = self.sigma_basis.unsqueeze(-1)
+        #phi1: size 1 x 1 x nb x dx
+        phi1 = _phi(phi1_upper/phi1_lower)/phi1_lower
+        deformed_term = truncated_parabola(T, mu.unsqueeze(-1).unsqueeze(-1), sigma_sq.unsqueeze(-1).unsqueeze(-1))
+        unnormalized_density = deformed_term
+        Z = torch.trapz(unnormalized_density, torch.linspace(self.a, self.b, self.dx), dim=-1).unsqueeze(-1)
+        numerical_integral = torch.trapz(phi1*unnormalized_density/Z, torch.linspace(self.a, self.b, self.dx), dim=-1)
+        return numerical_integral
+
+    def _integrate_product_of_gaussians(self, mu, sigma_sq):
+        #T: size 1 x 1 x 1 x dx
+        T = torch.linspace(self.a, self.b, self.dx).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        #phi1_upper: size 1 x 1 x nb x dx
+        phi1_upper = self.mu_basis.unsqueeze(-1)-T
+        #phi1_lower: size 1 x 1 x nb x 1
+        phi1_lower = self.sigma_basis.unsqueeze(-1)
+        #phi1: size 1 x 1 x nb x dx
+        phi1 = _phi(phi1_upper/phi1_lower)/phi1_lower
+        #phi2_upper: size bs x heads x 1 x dx
+        phi2_upper = mu.unsqueeze(-1).unsqueeze(-1)-T
+        #phi2_lower: size bs x heads x 1 x 1
+        phi2_lower = sigma_sq.unsqueeze(-1).unsqueeze(-1).pow(0.5)
+        #phi2: size bs x heads x 1 x dx
+        phi2 = _phi(phi2_upper/phi2_lower)/phi2_lower
+        #phi1*phi2: size bs x heads x nb x dx
+        numerical_integral = torch.trapz(phi1*phi2, torch.linspace(self.a, self.b, self.dx), dim=-1)
+        return numerical_integral
+
+    def _integrate_kernel_exp_wrt_gaussian(self, mu, sigma_sq, alpha):
+        #T: size 1 x 1 x 1 x dx
+        T = torch.linspace(self.a, self.b, self.dx).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        #K: kernel matrix, inducing x dx
+        K_inputs = torch.cdist(self.inducing_locations.unsqueeze(-1), torch.linspace(self.a, self.b, self.dx).unsqueeze(-1))
+        K = exp_kernel(K_inputs, self.bandwidth)
+        #f: score, bs x heads x 1 x dx
+        f = torch.matmul(alpha, K).unsqueeze(-2)
+        #get max of f across dx
+        f_max = torch.max(f, dim=-1, keepdim=True).values
+        exp_f = torch.exp(f-f_max)
+
+        #phi1_upper: size 1 x 1 x nb x dx
+        phi1_upper = self.mu_basis.unsqueeze(-1)-T
+        #phi1_lower: size 1 x 1 x nb x 1
+        phi1_lower = self.sigma_basis.unsqueeze(-1)
+        #phi1: size 1 x 1 x nb x dx
+        phi1 = _phi(phi1_upper/phi1_lower)/phi1_lower
+        #phi2_upper: size bs x heads x 1 x dx
+        phi2_upper = mu.unsqueeze(-1).unsqueeze(-1)-T
+        #phi2_lower: size bs x heads x 1 x 1
+        phi2_lower = sigma_sq.unsqueeze(-1).unsqueeze(-1).pow(0.5)
+        #phi2: size bs x heads x 1 x dx
+        phi2 = _phi(phi2_upper/phi2_lower)/phi2_lower
+        unnormalized_density = phi2*exp_f
+        Z = torch.trapz(unnormalized_density, torch.linspace(self.a, self.b, self.dx), dim=-1).unsqueeze(-1)
+        #phi1*phi2: size bs x heads x nb x dx
+        numerical_integral = torch.trapz(phi1*unnormalized_density/Z, torch.linspace(self.a, self.b, self.dx), dim=-1)
+        return numerical_integral
+
+    def forward(self, x, B):
+        v = torch.nan_to_num(x, 0)
+        #Compute mu, sigma_sq
+        mu = self.encode_mu(v)
+        sigma_sq = self.encode_sigma_sq1(v)
+        sigma_sq = self.encode_sigma_sq2(sigma_sq)
+        #alpha: bs x heads x inducing_points
+        alpha_init = self.encode_alpha(v)
+        alpha = alpha_init.reshape((alpha_init.shape[0], self.heads, self.inducing_points))
+        if self.method=='kernel_softmax':
+            integrals = self._integrate_kernel_exp_wrt_gaussian(mu, sigma_sq, alpha)
+        elif self.method=='cts_softmax':
+            integrals = self._integrate_product_of_gaussians(mu, sigma_sq)
+        elif self.method=='cts_sparsemax':
+            integrals = self._integrate_wrt_truncated_parabaloid(mu, sigma_sq)
+        elif self.method=='kernel_sparsemax':
+            integrals = self._integrate_wrt_kernel_deformed(mu, sigma_sq, alpha)
+        integrals = torch.nan_to_num(integrals, 0)
+        c = torch.bmm(integrals, B.unsqueeze(-1)).squeeze(-1)
+        return c, mu, sigma_sq, alpha
+
+class MODEL(pl.LightningModule):
+    def __init__(self, input_dim, hidden_dim, output_dim, heads, nb_basis, inducing_points, method, num_classes, optimizer='Adam', lr=1e-4, scheduler=None):
+        super().__init__()
+        self.encoder = FeedforwardEncoder(input_dim, hidden_dim, output_dim)
+        self.attention = Attention(output_dim, heads, nb_basis, inducing_points, method)
+        self.optimizer = optimizer
+        self.lr = lr
+        self.scheduler = scheduler
+
+    def forward(self, x, B):
+        x = self.encoder(x)
+        c, mu, sigma_sq, alpha = self.attention(x, B)
+        return c
+
+    def on_before_optimizer_step(self, optimizer):
+        # Compute the 2-norm for each layer
+        # If using mixed precision, the gradients are already unscaled here
+        norms = grad_norm(self.encoder, norm_type=2)
+        self.log_dict(norms)
+        norms = grad_norm(self.attention, norm_type=2)
+        self.log_dict(norms)
+
+    def training_step(self, batch, batch_idx):
+        x, B, y = batch
+        y_orig = y.clone()
+        #map y to one hot encoding
+        y = torch.nn.functional.one_hot(y.to(torch.int64), 2).float()
+        y_hat = self(x, B)
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)+1e-5*torch.norm(y_hat)
+        self.log("train_loss", loss, prog_bar=True)
+        self.log('train output norm', torch.norm(y_hat), prog_bar=True)
+        #compute accuracy
+        y_hat = torch.nn.functional.sigmoid(y_hat)
+        y_hat = torch.argmax(y_hat, dim=1)
+        acc = torch.sum(y_hat==y_orig)/len(y_orig)
+        self.log("train_acc", acc, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, B, y = batch
+        y_orig = y.clone()
+        #map y to one hot encoding
+        y = torch.nn.functional.one_hot(y.to(torch.int64), 2).float()
+        y_hat = self(x, B)
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
+        self.log("val_loss", loss, prog_bar=True)
+        #compute accuracy
+        y_hat = torch.nn.functional.sigmoid(y_hat)
+        y_hat = torch.argmax(y_hat, dim=1)
+        acc = torch.sum(y_hat==y_orig)/len(y)
+        self.log("val_acc", acc, prog_bar=True)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        x, B, y = batch
+        y_orig = y.clone()
+        #map y to one hot encoding
+        y = torch.nn.functional.one_hot(y.to(torch.int64), 2).float()
+        y_hat = self(x, B)
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y)
+        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
+        #compute accuracy
+        y_hat = torch.nn.functional.sigmoid(y_hat)
+        y_hat = torch.argmax(y_hat, dim=1)
+        acc = torch.sum(y_hat==y_orig)/len(y_orig)
+        self.log("test_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
+        return loss
+
+    def configure_optimizers(self):
+        if self.optimizer=='Adam':
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+        elif self.optimizer=='RAdam':
+            optimizer = torch.optim.RAdam(self.parameters(), lr=self.lr)
+        elif self.optimizer=='SGD':
+            optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
+        if self.scheduler=='StepLR':
+            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
+        elif self.scheduler=='ReduceLROnPlateau':
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=True)
+        elif self.scheduler=='CosineAnnealingLR':
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
+        if self.scheduler is None:
+            return [optimizer]
+        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
+
+
+if __name__=='__main__':
+    print('hello world')
+
\ No newline at end of file
diff --git a/time_warping/utils.py b/time_warping/utils.py
new file mode 100644
index 0000000..7978460
--- /dev/null
+++ b/time_warping/utils.py
@@ -0,0 +1,54 @@
+import torch
+import math
+from basis_functions import GaussianBasisFunctions
+
+def add_gaussian_basis_functions(nb_basis, sigmas):
+    mu, sigma = torch.meshgrid(torch.linspace(0, 1, nb_basis // len(sigmas)),
+                               torch.Tensor(sigmas), indexing='ij')
+    mus = mu.flatten()
+    sigmas = sigma.flatten()
+    return GaussianBasisFunctions(mus, sigmas)
+
+def create_psi(length, nb_basis):
+    psi = []
+    nb_waves = nb_basis
+    nb_waves = max(2, nb_waves)
+    psi.append(
+        add_gaussian_basis_functions(nb_waves,
+                                     sigmas=[.1, .5],
+                                     # sigmas=[.03, .1, .3],
+                                     )
+    )
+    return psi
+
+def _phi(t):
+    '''
+    @summary: Gaussian radial basis function
+    '''
+    return 1.0/math.sqrt(2*math.pi)*torch.exp(-0.5*t**2)
+
+def exp_kernel(t, bandwidth):
+    '''
+    @summary: Exponential kernel
+    '''
+    return torch.exp(-torch.abs(t)/bandwidth)
+
+
+def beta_exp(t):
+    '''
+    @summary: beta-exponential function
+    '''
+    q = 0
+    plus = torch.nn.ReLU()
+    return plus(1+(1-q)*t)**(1./(1-q))
+
+
+def truncated_parabola(t, mu, sigma_sq):
+    '''
+    @summary: Truncated parabola function
+    @param t: torch.Tensor, input
+    @param mu: torch.Tensor, mean
+    @param sigma_sq: torch.Tensor, variance
+    '''
+    plus = torch.nn.ReLU()
+    return plus(-(t-mu)**2/(2*sigma_sq)+0.5*(3/(2*torch.sqrt(sigma_sq)))**(2./3.))
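Closing aside (hypothetical, not part of the patch): the truncated parabola in utils.py appears to be the normalized density used by the continuous-sparsemax attention variant, so one quick check is that it integrates to roughly 1 whenever its support lies inside [0, 1].

# hypothetical check -- not part of the patch
import torch
from utils import truncated_parabola

t = torch.linspace(0, 1, 1000)
density = truncated_parabola(t, mu=torch.tensor(0.5), sigma_sq=torch.tensor(0.01))
print(torch.trapz(density, t))   # ~1.0 for this mu/sigma_sq, since the support stays inside [0, 1]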