Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Add multivariate basis functions

  • Loading branch information...
commit ce57d3fedfcd796577be231ea1d81fd520e96b48 1 parent 70a539a
@batterseapower authored
Showing with 17 additions and 2 deletions.
  1. +17 −2 Algorithms/MachineLearning/BasisFunctions.hs
View
19 Algorithms/MachineLearning/BasisFunctions.hs
@@ -2,11 +2,12 @@
module Algorithms.MachineLearning.BasisFunctions where
import Algorithms.MachineLearning.Framework
+import Algorithms.MachineLearning.LinearAlgebra
import Algorithms.MachineLearning.Utilities
-- | Basis function that is 1 everywhere
-constantBasis :: Double -> Double
+constantBasis :: a -> Double
constantBasis = const 1
-- | /Unnormalized/ 1D Gaussian, suitable for use as a basis function.
@@ -19,4 +20,18 @@ gaussianBasis mean variance x = exp (negate $ (square (x - mean)) / (2 * varianc
-- | Family of gaussian basis functions with constant variance and the given means, with
-- a constant basis function to capture the mean of the target variable.
gaussianBasisFamily :: [Mean] -> Variance -> [Double -> Double]
-gaussianBasisFamily means variance = constantBasis : map (flip gaussianBasis variance) means
+
+-- | /Unnormalized/ multi-dimensional Gaussian, suitable for use as a basis function.
+multivariateGaussianBasis :: Vector Mean -- ^ Mean of the Gaussian
+ -> Matrix Variance -- ^ Covariance matrix
+ -> Vector Double -- ^ Point to sample
+ -> Double
+multivariateGaussianBasis mean covariance x = exp (negate $ (deviation <.> (inv covariance <> deviation)) / 2)
+ where deviation = x - mean
+
+-- | Family of multi-dimensional gaussian basis functions with constant, isotropic variance and
+-- the given means, with a constant basis function to capture the mean of the target variable.
+multivariateIsotropicGaussianBasisFamily :: [Vector Mean] -> Variance -> [Vector Double -> Double]
+multivariateIsotropicGaussianBasisFamily means common_variance = constantBasis : map (flip multivariateGaussianBasis covariance) means
+ where covariance = (1 / common_variance) .* ident (dim (head means))
Please sign in to comment.
Something went wrong with that request. Please try again.