Add variance estimate model to Bayesian linear regression and remove DataSet phantom type parameter
commit af9b579b8460dcca978d248ec929db14ee8248b2 (1 parent: bae1ecf)
Max Bolingbroke authored August 17, 2008
Algorithms/MachineLearning/Framework.hs (17 changed lines)
@@ -4,8 +4,13 @@
 -- the machine learning algorithms.
 module Algorithms.MachineLearning.Framework where
 
+import Algorithms.MachineLearning.Utilities
+
 import Numeric.LinearAlgebra
 
+import Data.List
+
+import System.Random
 
 --
 -- Ubiquitous synonyms for documentation purposes
@@ -55,21 +60,27 @@ instance Vectorable (Vector Double) where
 -- Labelled data set
 --
 
-data DataSet input = DataSet {
+data DataSet = DataSet {
         ds_inputs :: Matrix Double, -- One row per sample, one column per input variable
         ds_targets :: Vector Target -- One row per sample, each value being a single target variable
     }
 
-dataSetFromSampleList :: Vectorable a => [(a, Target)] -> DataSet a
+dataSetFromSampleList :: Vectorable a => [(a, Target)] -> DataSet
 dataSetFromSampleList elts
   = DataSet {
     ds_inputs = fromRows $ map (toVector . fst) elts,
     ds_targets = fromList $ map snd elts
   }
 
-dataSetToSampleList :: Vectorable a => DataSet a -> [(a, Target)]
+dataSetToSampleList :: Vectorable a => DataSet -> [(a, Target)]
 dataSetToSampleList ds = zip (map fromVector $ toRows $ ds_inputs ds) (toList $ ds_targets ds)
 
+binDS :: StdGen -> Int -> DataSet -> [DataSet]
+binDS gen bins ds = map dataSetFromSampleList $ chunk (ceiling $ (fromIntegral $ length samples :: Double) / (fromIntegral bins)) shuffled_samples
+  where
+    samples = zip (toRows $ ds_inputs ds) (toList $ ds_targets ds)
+    shuffled_samples = shuffle gen samples
+
 --
 -- Models
 --
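As a quick orientation for the Framework changes above, here is a minimal usage sketch (not part of the commit) of the now un-parameterised DataSet together with the new binDS helper. It assumes Target is the Double synonym declared elsewhere in the module, that a Vectorable Double instance exists (as the test data set implies), and that System.Random is in scope as imported above; the sample values are invented for illustration.

demoBinning :: IO ()
demoBinning = do
    gen <- newStdGen
    -- Build a DataSet from (input, target) pairs; the result type no longer
    -- carries the input type as a phantom parameter.
    let ds = dataSetFromSampleList ([(0.0, 0.1), (0.5, 0.9), (1.0, 0.2), (1.5, 0.4)] :: [(Double, Target)])
    -- Shuffle the samples and split them into two roughly equal bins,
    -- e.g. for a training/held-out split.
    let [train_bin, heldout_bin] = binDS gen 2 ds
    print (dataSetToSampleList train_bin   :: [(Double, Target)])
    print (dataSetToSampleList heldout_bin :: [(Double, Target)])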
Algorithms/MachineLearning/LinearRegression.hs (69 changed lines)
@@ -24,6 +24,37 @@ instance Model LinearModel where
         phi_app_x = applyVector (lm_basis_fns model) input
 
 
+data BayesianVarianceModel input = BayesianVarianceModel {
+        bvm_basis_fns :: [input -> Target],
+        bvm_weight_covariance :: Matrix Weight,
+        bvm_beta :: Precision
+    }
+
+instance Show (BayesianVarianceModel input) where
+    show model = "Weight Covariance: " ++ show (bvm_weight_covariance model) ++ "\n" ++
+                 "Beta: " ++ show (bvm_beta model)
+
+instance Model BayesianVarianceModel where
+    predict model input = recip (bvm_beta model) + phi_app_x <.> (bvm_weight_covariance model <> phi_app_x)
+      where
+        phi_app_x = applyVector (bvm_basis_fns model) input
+
+
+regressDesignMatrix :: (Vectorable input) => [input -> Target] -> Matrix Double -> Matrix Double
+regressDesignMatrix basis_fns inputs
+  = applyMatrix (map (. fromVector) basis_fns) inputs -- One row per sample, one column per basis function
+
+-- | Regularized pseudo-inverse of a matrix, with regularization coefficient lambda.
+regularizedPinv :: RegularizationCoefficient -> Matrix Double -> Matrix Double
+regularizedPinv lambda phi = regularizedPrePinv lambda 1 phi <> trans phi
+
+-- | Just the left portion of the formula for the pseudo-inverse, with coefficients alpha and beta, i.e.:
+--
+-- > (alpha * _I_ + beta * _phi_ ^ T * _phi_) ^ -1
+regularizedPrePinv :: Precision -> Precision -> Matrix Double -> Matrix Double
+regularizedPrePinv alpha beta phi = inv $ (alpha .* (ident (cols phi))) + (beta .* (trans phi <> phi))
+
+
 -- | Regress a basic linear model with no regularization at all onto the given data using the
 -- supplied basis functions.
 --
@@ -34,8 +65,11 @@ instance Model LinearModel where
 -- is also very quick to find, since there is a closed form solution.
 --
 -- Equation 3.15 in Bishop.
-regressLinearModel :: (Vectorable input) => [input -> Target] -> DataSet input -> LinearModel input
-regressLinearModel = regressLinearModelCore pinv
+regressLinearModel :: (Vectorable input) => [input -> Target] -> DataSet -> LinearModel input
+regressLinearModel basis_fns ds = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
+  where
+    design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
+    weights = pinv design_matrix <> ds_targets ds
 
 -- | Regress a basic linear model with a sum-of-squares regularization term.  This penalizes models with weight
 -- vectors of large magnitudes and hence ameliorates the over-fitting problem of 'regressLinearModel'.
@@ -46,30 +80,29 @@ regressLinearModel = regressLinearModelCore pinv
 -- the weight vector.  Like 'regressLinearModel', a closed form solution is used to find the model quickly.
 --
 -- Equation 3.28 in Bishop.
-regressRegularizedLinearModel :: (Vectorable input) => RegularizationCoefficient -> [input -> Target] -> DataSet input -> LinearModel input
-regressRegularizedLinearModel lambda = regressLinearModelCore regularizedPinv
+regressRegularizedLinearModel :: (Vectorable input) => RegularizationCoefficient -> [input -> Target] -> DataSet -> LinearModel input
+regressRegularizedLinearModel lambda basis_fns ds = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
   where
-    regularizedPinv phi = let trans_phi = trans phi
-                          in inv ((lambda .* (ident (cols phi))) + (trans_phi <> phi)) <> trans_phi
-
-regressLinearModelCore :: (Vectorable input) => (Matrix Double -> Matrix Double) -> [input -> Target] -> DataSet input -> LinearModel input
-regressLinearModelCore find_pinv basis_fns ds
-  = LinearModel { lm_basis_fns = basis_fns, lm_weights = weights }
-  where
-    designMatrix = applyMatrix (map (. fromVector) basis_fns) (ds_inputs ds) -- One row per sample, one column per basis function
-    weights = find_pinv designMatrix <> (ds_targets ds)
-
+    design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
+    weights = regularizedPinv lambda design_matrix <> ds_targets ds
 
 -- | Bayesian linear regression, using an isotropic Gaussian prior for the weights centred at the origin.  The precision
 -- of the weight prior is controlled by the parameter alpha, and our belief about the inherent noise in the data is
 -- controlled by the precision parameter beta.
 --
 -- Bayesian linear regression with this prior is entirely equivalent to calling 'regressRegularizedLinearModel' with
--- lambda = alpha / beta.
+-- lambda = alpha / beta.  However, the twist is that we can use our knowledge of the prior to also make an estimate
+-- for the variance of the true value about any input point.
 --
--- Equation 3.55 in Bishop.
+-- Equations 3.53, 3.54 and 3.59 in Bishop.
 bayesianLinearRegression :: (Vectorable input)
                          => Precision -- ^ Precision of Gaussian weight prior
                          -> Precision -- ^ Precision of noise on samples
-                         -> [input -> Target] -> DataSet input -> LinearModel input
-bayesianLinearRegression alpha beta = regressRegularizedLinearModel (alpha / beta)
+                         -> [input -> Target] -> DataSet -> (LinearModel input, BayesianVarianceModel input)
+bayesianLinearRegression alpha beta basis_fns ds
+  = (LinearModel { lm_basis_fns = basis_fns, lm_weights = weights },
+     BayesianVarianceModel { bvm_basis_fns = basis_fns, bvm_weight_covariance = weight_covariance, bvm_beta = beta })
+  where
+    design_matrix = regressDesignMatrix basis_fns (ds_inputs ds)
+    weight_covariance = regularizedPrePinv alpha beta design_matrix
+    weights = beta .* weight_covariance <> trans design_matrix <> (ds_targets ds)
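For reference, these are the equations from Bishop that the new Bayesian code corresponds to, written out in standard notation. This is a reconstruction from the equation numbers cited in the doc comment, not text from the commit: regularizedPrePinv builds the posterior weight covariance, the weights binding builds the posterior mean, and predict on BayesianVarianceModel evaluates the predictive variance.

% Posterior weight covariance (3.54): regularizedPrePinv alpha beta design_matrix
S_N^{-1} = \alpha I + \beta \, \Phi^{\top} \Phi
% Posterior weight mean (3.53): the 'weights' binding
m_N = \beta \, S_N \Phi^{\top} \mathbf{t}
% Predictive variance at input x (3.59): 'predict' for BayesianVarianceModel
\sigma_N^2(x) = \frac{1}{\beta} + \phi(x)^{\top} S_N \, \phi(x)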
Algorithms/MachineLearning/Tests/Data.hs (6 changed lines)
@@ -12,8 +12,8 @@ import Algorithms.MachineLearning.Framework
 -- @
 --
 -- Source: http://research.microsoft.com/~cmbishop/PRML/webdatasets/curvefitting.txt
-sinDataSet :: DataSet Double
-sinDataSet = dataSetFromSampleList [
+sinDataSet :: DataSet
+sinDataSet = dataSetFromSampleList ([
     (0.000000, 0.349486),
     (0.111111, 0.830839),
     (0.222222, 1.007332),
@@ -24,4 +24,4 @@ sinDataSet = dataSetFromSampleList [
     (0.777778, -0.445686),
     (0.888889, -0.563567),
     (1.000000, 0.261502)
-  ]
+  ] :: [(Double, Double)])
Algorithms/MachineLearning/Tests/Driver.hs (25 changed lines)
@@ -8,7 +8,12 @@ import Algorithms.MachineLearning.Tests.Data
 import Algorithms.MachineLearning.Utilities
 
 import GNUPlot
+
+import Data.List
+import Data.Ord
+
 import System.Cmd
+import System.Random
 
 
 basisFunctions :: [Double -> Double]
@@ -20,20 +25,28 @@ sumOfSquaresError targetsAndPredictions = sum $ map (abs . uncurry (-)) targetsA
 sample :: (Double -> Double) -> [(Double, Double)]
 sample f = map (\(x :: Rational) -> let x' = rationalToDouble x in (x', f x')) [0,0.01..1.0]
 
-evaluate :: (Model model, Show (model Double)) => model Double -> DataSet Double -> IO ()
+evaluate :: (Model model, Show (model Double)) => model Double -> DataSet -> IO ()
 evaluate model true_data = do
     putStrLn $ "Target Mean = " ++ show (vectorMean (ds_targets true_data))
     putStrLn $ "Error = " ++ show (sumOfSquaresError comparable_data)
     putStrLn $ "Model:\n" ++ show model
-    plot (dataSetToSampleList true_data) (sample fittedFunction)
   where
     fittedFunction = predict model
     comparable_data = map (fittedFunction `onLeft`) (dataSetToSampleList true_data)
 
-plot :: [(Double, Target)] -> [(Double, Target)] -> IO ()
-plot true_samples fitted_samples = do
-    plotPaths [EPS "output.ps"] [true_samples, fitted_samples]
+plot :: [[(Double, Target)]] -> IO ()
+plot sampless = do
+    plotPaths [EPS "output.ps"] (map (sortBy (comparing fst)) sampless)
     void $ rawSystem "open" ["output.ps"]
 
 main :: IO ()
-main = evaluate (regressRegularizedLinearModel 1 basisFunctions sinDataSet) sinDataSet
+main = do
+    gen <- newStdGen
+    let used_data = head $ binDS gen 2 sinDataSet
+        (model, variance_model) = bayesianLinearRegression 5 (1 / 0.3) basisFunctions used_data
+
+    -- Show some model statistics
+    evaluate model used_data
+
+    -- Show some graphical information about the model
+    plot [dataSetToSampleList used_data, sample $ predict model, sample $ predict variance_model]
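Since binDS gen 2 splits the data into two bins and the new main trains only on the first, a natural extension is to report the error on the unseen bin with the same evaluate helper. The sketch below is not part of the commit; mainWithHoldOut and the held_out binding are hypothetical names, and the pattern match assumes binDS returns exactly two bins as it does for sinDataSet.

mainWithHoldOut :: IO ()
mainWithHoldOut = do
    gen <- newStdGen
    -- binDS shuffles sinDataSet and splits it into two bins of roughly equal size.
    let [used_data, held_out] = binDS gen 2 sinDataSet
        (model, _variance_model) = bayesianLinearRegression 5 (1 / 0.3) basisFunctions used_data
    evaluate model used_data  -- error on the bin the model was fitted to
    evaluate model held_out   -- error on the bin it has never seen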
Algorithms/MachineLearning/Utilities.hs (17 changed lines)
@@ -2,6 +2,12 @@
 -- home in a "utilities" module
 module Algorithms.MachineLearning.Utilities where
 
+import Data.List
+import Data.Ord
+
+import System.Random
+
+
 square :: Num a => a -> a
 square x = x * x
 
@@ -18,4 +24,13 @@ onLeft :: (a -> c) -> (a, b) -> (c, b)
 onLeft f (x, y) = (f x, y)
 
 onRight :: (b -> c) -> (a, b) -> (a, c)
-onRight f (x, y) = (x, f y)
+onRight f (x, y) = (x, f y)
+
+shuffle :: StdGen -> [a] -> [a]
+shuffle gen xs = map snd $ sortBy (comparing fst) (zip (randoms gen :: [Double]) xs)
+
+chunk :: Int -> [a] -> [[a]]
+chunk _ [] = []
+chunk n xs = this : chunk n rest
+  where
+    (this, rest) = splitAt n xs
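A small illustrative check of the two new helpers, assuming it lives in (or imports) Algorithms.MachineLearning.Utilities; this sketch is not part of the commit and the seed 42 is an arbitrary choice for reproducibility.

utilitiesDemo :: IO ()
utilitiesDemo = do
    -- chunk splits a list into consecutive pieces of at most n elements:
    print (chunk 2 [1, 2, 3, 4, 5 :: Int])   -- [[1,2],[3,4],[5]]
    -- shuffle permutes a list by pairing each element with a random Double key
    -- and sorting on the keys; with a fixed generator the result is reproducible.
    print (shuffle (mkStdGen 42) "abcde")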
machine-learning.cabal (4 changed lines)
@@ -25,7 +25,7 @@ Library
         Exposed-Modules:        Algorithms.MachineLearning.Framework
                                 Algorithms.MachineLearning.LinearRegression
 
-        Build-Depends:          hmatrix >= 0.4.0.0
+        Build-Depends:          hmatrix >= 0.4.0.0, random >= 1.0.0.0
         if flag(splitBase)
                 Build-Depends:  base >= 3
         else
@@ -39,7 +39,7 @@ Executable machine-learning-tests
 
         -- I just need HTam for the GNUPlot module. Probably I should just write my
         -- own GNUPlot interface module instead of depending on such a weird package :-)
-        Build-Depends:          hmatrix >= 0.4.0.0, HTam
+        Build-Depends:          hmatrix >= 0.4.0.0, HTam, random >= 1.0.0.0
         if flag(splitBase)
                 Build-Depends:  base >= 3, process >= 1.0.0.0
         else
