
Add variance estimate model to Bayesian linear regression and remove DataSet phantom type parameter
1 parent bae1ecf commit af9b579b8460dcca978d248ec929db14ee8248b2 @batterseapower committed Aug 17, 2008
@@ -4,8 +4,13 @@
-- the machine learning algorithms.
module Algorithms.MachineLearning.Framework where
+import Algorithms.MachineLearning.Utilities
+
import Numeric.LinearAlgebra
+import Data.List
+
+import System.Random
--
-- Ubiquitous synonyms for documentation purposes
@@ -55,21 +60,27 @@ instance Vectorable (Vector Double) where
-- Labelled data set
--
-data DataSet input = DataSet {
+data DataSet = DataSet {
ds_inputs :: Matrix Double, -- One row per sample, one column per input variable
ds_targets :: Vector Target -- One row per sample, each value being a single target variable
}
-dataSetFromSampleList :: Vectorable a => [(a, Target)] -> DataSet a
+dataSetFromSampleList :: Vectorable a => [(a, Target)] -> DataSet
dataSetFromSampleList elts
= DataSet {
ds_inputs = fromRows $ map (toVector . fst) elts,
ds_targets = fromList $ map snd elts
}
-dataSetToSampleList :: Vectorable a => DataSet a -> [(a, Target)]
+dataSetToSampleList :: Vectorable a => DataSet -> [(a, Target)]
dataSetToSampleList ds = zip (map fromVector $ toRows $ ds_inputs ds) (toList $ ds_targets ds)
+binDS :: StdGen -> Int -> DataSet -> [DataSet]
+binDS gen bins ds = map dataSetFromSampleList $ chunk (ceiling $ (fromIntegral $ length samples :: Double) / (fromIntegral bins)) shuffled_samples
+ where
+ samples = zip (toRows $ ds_inputs ds) (toList $ ds_targets ds)
+ shuffled_samples = shuffle gen samples
+
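As a usage sketch (not part of this commit), the new binDS function above makes it easy to carve off a random held-out portion of a data set; a hypothetical splitTrainTest helper along these lines would do the same two-bin split that the test driver performs inline further down:

    -- Hypothetical helper, sketched against the definitions above: shuffle the
    -- samples and split them into a training half and a held-out half.
    splitTrainTest :: StdGen -> DataSet -> (DataSet, DataSet)
    splitTrainTest gen ds = case binDS gen 2 ds of
        (train:test:_) -> (train, test)
        _              -> error "splitTrainTest: need at least two samples"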
--
-- Models
--
@@ -24,6 +24,37 @@ instance Model LinearModel where
phi_app_x = applyVector (lm_basis_fns model) input
+data BayesianVarianceModel input = BayesianVarianceModel {
+ bvm_basis_fns :: [input -> Target],
+ bvm_weight_covariance :: Matrix Weight,
+ bvm_beta :: Precision
+ }
+
+instance Show (BayesianVarianceModel input) where
+ show model = "Weight Covariance: " ++ show (bvm_weight_covariance model) ++ "\n" ++
+ "Beta: " ++ show (bvm_beta model)
+
+instance Model BayesianVarianceModel where
+ predict model input = recip (bvm_beta model) + phi_app_x <.> (bvm_weight_covariance model <> phi_app_x)
+ where
+ phi_app_x = applyVector (bvm_basis_fns model) input
+
+
+regressDesignMatrix :: (Vectorable input) => [input -> Target] -> Matrix Double -> Matrix Double
+regressDesignMatrix basis_fns inputs
+ = applyMatrix (map (. fromVector) basis_fns) inputs -- One row per sample, one column per basis function
+
+-- | Regularized pseudo-inverse of a matrix, with regularization coefficient lambda.
+regularizedPinv :: RegularizationCoefficient -> Matrix Double -> Matrix Double
+regularizedPinv lambda phi = regularizedPrePinv lambda 1 phi <> trans phi
+
+-- | Just the left portion of the formula for the pseudo-inverse, with coefficients alpha and beta, i.e.:
+--
+-- > (alpha * _I_ + beta * _phi_ ^ T * _phi_) ^ -1
+regularizedPrePinv :: Precision -> Precision -> Matrix Double -> Matrix Double
+regularizedPrePinv alpha beta phi = inv $ (alpha .* (ident (cols phi))) + (beta .* (trans phi <> phi))
+
+
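For reference (not part of the diff), regularizedPrePinv alpha beta computes the inverse shown in the comment above, so the weights obtained via regularizedPinv are the closed-form regularized least-squares solution of Bishop's equation 3.28; with lambda = 0 and a full-column-rank design matrix this reduces to the ordinary pseudo-inverse solution of equation 3.15 used by regressLinearModel:

    \mathbf{w} = \left( \lambda \mathbf{I} + \Phi^{\top} \Phi \right)^{-1} \Phi^{\top} \mathbf{t}   % Bishop, eq. 3.28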
-- | Regress a basic linear model with no regularization at all onto the given data using the
-- supplied basis functions.
--
@@ -34,8 +65,11 @@ instance Model LinearModel where
-- is also very quick to find, since there is a closed form solution.
--
-- Equation 3.15 in Bishop.
-regressLinearModel :: (Vectorable input) => [input -> Target] -> DataSet input -> LinearModel input
-regressLinearModel = regressLinearModelCore pinv
+regressLinearModel :: (Vectorable input) => [input -> Target] -> DataSet -> LinearModel input
+regressLinearModel basis_fns ds = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
+ where
+ design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
+ weights = pinv design_matrix <> ds_targets ds
-- | Regress a basic linear model with a sum-of-squares regularization term. This penalizes models with weight
-- vectors of large magnitudes and hence ameliorates the over-fitting problem of 'regressLinearModel'.
@@ -46,30 +80,29 @@ regressLinearModel = regressLinearModelCore pinv
-- the weight vector. Like 'regressLinearModel', a closed form solution is used to find the model quickly.
--
-- Equation 3.28 in Bishop.
-regressRegularizedLinearModel :: (Vectorable input) => RegularizationCoefficient -> [input -> Target] -> DataSet input -> LinearModel input
-regressRegularizedLinearModel lambda = regressLinearModelCore regularizedPinv
+regressRegularizedLinearModel :: (Vectorable input) => RegularizationCoefficient -> [input -> Target] -> DataSet -> LinearModel input
+regressRegularizedLinearModel lambda basis_fns ds = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
where
- regularizedPinv phi = let trans_phi = trans phi
- in inv ((lambda .* (ident (cols phi))) + (trans_phi <> phi)) <> trans_phi
-
-regressLinearModelCore :: (Vectorable input) => (Matrix Double -> Matrix Double) -> [input -> Target] -> DataSet input -> LinearModel input
-regressLinearModelCore find_pinv basis_fns ds
- = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
- where
- designMatrix = applyMatrix (map (. fromVector) basis_fns) (ds_inputs ds) -- One row per sample, one column per basis function
- weights = find_pinv designMatrix <> (ds_targets ds)
-
+ design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
+ weights = regularizedPinv lambda design_matrix <> ds_targets ds
-- | Bayesian linear regression, using an isotropic Gaussian prior for the weights centred at the origin. The precision
-- of the weight prior is controlled by the parameter alpha, and our belief about the inherent noise in the data is
-- controlled by the precision parameter beta.
--
-- Bayesian linear regression with this prior is entirely equivalent to calling 'regressRegularizedLinearModel' with
--- lambda = alpha / beta.
+-- lambda = alpha / beta. However, the twist is that we can also use our knowledge of the prior to estimate
+-- the variance of the true value at any input point.
--
--- Equation 3.55 in Bishop.
+-- Equations 3.53, 3.54 and 3.59 in Bishop.
bayesianLinearRegression :: (Vectorable input)
=> Precision -- ^ Precision of Gaussian weight prior
-> Precision -- ^ Precision of noise on samples
- -> [input -> Target] -> DataSet input -> LinearModel input
-bayesianLinearRegression alpha beta = regressRegularizedLinearModel (alpha / beta)
+ -> [input -> Target] -> DataSet -> (LinearModel input, BayesianVarianceModel input)
+bayesianLinearRegression alpha beta basis_fns ds
+ = (LinearModel { lm_basis_fns = basis_fns, lm_weights = weights },
+ BayesianVarianceModel { bvm_basis_fns = basis_fns, bvm_weight_covariance = weight_covariance, bvm_beta = beta })
+ where
+ design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
+ weight_covariance = regularizedPrePinv alpha beta design_matrix
+ weights = beta .* weight_covariance <> trans design_matrix <> (ds_targets ds)
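For reference (not part of the diff), the three Bishop equations cited above are the posterior weight mean and covariance and the predictive variance; in the code, weight_covariance is S_N, weights is m_N, and the BayesianVarianceModel predict method evaluates sigma_N^2(x):

    \mathbf{m}_N = \beta\, \mathbf{S}_N \Phi^{\top} \mathbf{t}                                           % eq. 3.53
    \mathbf{S}_N^{-1} = \alpha \mathbf{I} + \beta\, \Phi^{\top} \Phi                                      % eq. 3.54
    \sigma_N^2(\mathbf{x}) = \tfrac{1}{\beta} + \phi(\mathbf{x})^{\top} \mathbf{S}_N\, \phi(\mathbf{x})   % eq. 3.59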
@@ -12,8 +12,8 @@ import Algorithms.MachineLearning.Framework
-- @
--
-- Source: http://research.microsoft.com/~cmbishop/PRML/webdatasets/curvefitting.txt
-sinDataSet :: DataSet Double
-sinDataSet = dataSetFromSampleList [
+sinDataSet :: DataSet
+sinDataSet = dataSetFromSampleList ([
(0.000000, 0.349486),
(0.111111, 0.830839),
(0.222222, 1.007332),
@@ -24,4 +24,4 @@ sinDataSet = dataSetFromSampleList [
(0.777778, -0.445686),
(0.888889, -0.563567),
(1.000000, 0.261502)
- ]
+ ] :: [(Double, Double)])
@@ -8,7 +8,12 @@ import Algorithms.MachineLearning.Tests.Data
import Algorithms.MachineLearning.Utilities
import GNUPlot
+
+import Data.List
+import Data.Ord
+
import System.Cmd
+import System.Random
basisFunctions :: [Double -> Double]
@@ -20,20 +25,28 @@ sumOfSquaresError targetsAndPredictions = sum $ map (abs . uncurry (-)) targetsA
sample :: (Double -> Double) -> [(Double, Double)]
sample f = map (\(x :: Rational) -> let x' = rationalToDouble x in (x', f x')) [0,0.01..1.0]
-evaluate :: (Model model, Show (model Double)) => model Double -> DataSet Double -> IO ()
+evaluate :: (Model model, Show (model Double)) => model Double -> DataSet -> IO ()
evaluate model true_data = do
putStrLn $ "Target Mean = " ++ show (vectorMean (ds_targets true_data))
putStrLn $ "Error = " ++ show (sumOfSquaresError comparable_data)
putStrLn $ "Model:\n" ++ show model
- plot (dataSetToSampleList true_data) (sample fittedFunction)
where
fittedFunction = predict model
comparable_data = map (fittedFunction `onLeft`) (dataSetToSampleList true_data)
-plot :: [(Double, Target)] -> [(Double, Target)] -> IO ()
-plot true_samples fitted_samples = do
- plotPaths [EPS "output.ps"] [true_samples, fitted_samples]
+plot :: [[(Double, Target)]] -> IO ()
+plot sampless = do
+ plotPaths [EPS "output.ps"] (map (sortBy (comparing fst)) sampless)
void $ rawSystem "open" ["output.ps"]
main :: IO ()
-main = evaluate (regressRegularizedLinearModel 1 basisFunctions sinDataSet) sinDataSet
+main = do
+ gen <- newStdGen
+ let used_data = head $ binDS gen 2 sinDataSet
+ (model, variance_model) = bayesianLinearRegression 5 (1 / 0.3) basisFunctions used_data
+
+ -- Show some model statistics
+ evaluate model used_data
+
+ -- Show some graphical information about the model
+ plot [dataSetToSampleList used_data, sample $ predict model, sample $ predict variance_model]
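One optional follow-on sketch (hypothetical, not in this commit): since the variance model predicts sigma^2 rather than sigma, a standard-deviation curve could be plotted alongside the mean by sampling its square root instead:

    plot [ dataSetToSampleList used_data
         , sample (predict model)
         , sample (sqrt . predict variance_model) ]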
@@ -2,6 +2,12 @@
-- home in a "utilities" module
module Algorithms.MachineLearning.Utilities where
+import Data.List
+import Data.Ord
+
+import System.Random
+
+
square :: Num a => a -> a
square x = x * x
@@ -18,4 +24,13 @@ onLeft :: (a -> c) -> (a, b) -> (c, b)
onLeft f (x, y) = (f x, y)
onRight :: (b -> c) -> (a, b) -> (a, c)
-onRight f (x, y) = (x, f y)
+onRight f (x, y) = (x, f y)
+
+shuffle :: StdGen -> [a] -> [a]
+shuffle gen xs = map snd $ sortBy (comparing fst) (zip (randoms gen :: [Double]) xs)
+
+chunk :: Int -> [a] -> [[a]]
+chunk _ [] = []
+chunk n xs = this : chunk n rest
+ where
+ (this, rest) = splitAt n xs
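A quick note on the two new helpers (not part of the diff): shuffle is the usual decorate-with-random-keys-then-sort trick, and chunk groups a list into consecutive runs of at most n elements, so the final group may be shorter. For example:

    chunk 3 [1..10 :: Int]  ==  [[1,2,3], [4,5,6], [7,8,9], [10]]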
@@ -25,7 +25,7 @@ Library
Exposed-Modules: Algorithms.MachineLearning.Framework
Algorithms.MachineLearning.LinearRegression
- Build-Depends: hmatrix >= 0.4.0.0
+ Build-Depends: hmatrix >= 0.4.0.0, random >= 1.0.0.0
if flag(splitBase)
Build-Depends: base >= 3
else
@@ -39,7 +39,7 @@ Executable machine-learning-tests
-- I just need HTam for the GNUPlot module. Probably I should just write my
-- own GNUPlot interface module instead of depending on such a weird package :-)
- Build-Depends: hmatrix >= 0.4.0.0, HTam
+ Build-Depends: hmatrix >= 0.4.0.0, HTam, random >= 1.0.0.0
if flag(splitBase)
Build-Depends: base >= 3, process >= 1.0.0.0
else
