Browse files

moved hcasadi stuff into it's own separate package

  • Loading branch information...
1 parent d9f6284 commit dc25fd22c24f79cf024ad02e7392b5fc027062a8 @ghorn committed Dec 22, 2011
View
10 Casadi.hs
@@ -1,10 +0,0 @@
--- Casadi.hs
-
-{-# OPTIONS_GHC -Wall #-}
-
-module Casadi
- (
- module Casadi.Api
- ) where
-
-import Casadi.Api
View
61 Casadi/Api.hs
@@ -1,61 +0,0 @@
--- Api.hs
-
-{-# OPTIONS_GHC -Wall #-}
-{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}
-
-{-
-this module is all the user should ever need to see. It is re-exported as Casadi.
-In Casadi.* many functions have longer names to make maintaining the bindings easier, but these longer names
-are either used to instance the Matrix typeclass or are given shorter names in this module.
--}
-
-module Casadi.Api
- (
- SX(..)
- , SXFunction(..)
- , SXMatrix(..)
- , DMatrix(..)
- , Matrix(..)
- , Boundable(..)
- , sxSymbolic
- , sxMatrixSymbolic
- , sxFunction
- , sxFunctionEvaluate
- , sxInt
- , sxDouble
- , sxBound
- , gradient
- , hessian
- , jacobian
- ) where
-
-import Casadi.SX
-import Casadi.DMatrix
-import Casadi.SXMatrix
-import Casadi.SXFunction
-import Casadi.Matrix
-import System.IO.Unsafe(unsafePerformIO)
-
-sxSymbolic :: String -> SX
-{-# NOINLINE sxSymbolic #-}
-sxSymbolic name = unsafePerformIO $ do
- sym <- sxCreateSymbolic name
- return sym
-
-sxMatrixSymbolic :: String -> (Int,Int) -> SXMatrix
-{-# NOINLINE sxMatrixSymbolic #-}
-sxMatrixSymbolic prefix dim' = unsafePerformIO $ do
- mat <- sxMatrixCreateSymbolic prefix dim'
- return mat
-
-sxFunction :: Matrix a b c => [SXMatrix] -> [SXMatrix] -> [a] -> [a]
-sxFunction ins outs = sxFunctionEvaluate $ sxFunctionCreate ins outs
-
-sxInt :: Int -> SX
-sxInt = sxFromInt
-
-sxDouble :: Double -> SX
-sxDouble = sxFromDouble
-
---sxIntegral :: Integral a => a -> SX
---sxIntegral = sxFromIntegral
View
24 Casadi/CasadiInterfaceUtils.hs
@@ -1,24 +0,0 @@
--- CasadiInterfaceUtils.hs
-
-{-# OPTIONS_GHC -Wall #-}
-
-module Casadi.CasadiInterfaceUtils
- (
- withForeignPtrs2
- , withForeignPtrs3
- ) where
-
-import Foreign.Ptr
-import Foreign.ForeignPtr
-
--- convenience functions for cpp wrappers
-withForeignPtrs2 :: (Ptr a -> Ptr b -> IO c) -> ForeignPtr a -> ForeignPtr b -> IO c
-withForeignPtrs2 f0 p0 p1 = withForeignPtr p1 $ \p1' -> (f1 p1')
- where
- f1 p1' = withForeignPtr p0 (\p0' -> f0 p0' p1')
-
-withForeignPtrs3 :: (Ptr a -> Ptr b -> Ptr c -> IO d) -> ForeignPtr a -> ForeignPtr b -> ForeignPtr c -> IO d
-withForeignPtrs3 f0 p0 p1 p2 = withForeignPtr p2 $ \p2' -> (f2 p2')
- where
- f2 p2' = withForeignPtr p1 (\p1' -> f1 p1' p2' )
- f1 p1' p2' = withForeignPtr p0 (\p0' -> f0 p0' p1' p2')
View
328 Casadi/DMatrix.hs
@@ -1,328 +0,0 @@
--- DMatrix.hs
-
-{-# OPTIONS_GHC -Wall #-}
---{-# OPTIONS_GHC -Wall -fno-cse -fno-full-laziness #-}
-{-# LANGUAGE ForeignFunctionInterface, MultiParamTypeClasses #-}
-
-module Casadi.DMatrix
- (
- DMatrix(..)
- , DMatrixRaw(..)
- , dMatrixNewZeros
- ) where
-
-import Casadi.CasadiInterfaceUtils
-import Casadi.Matrix
-import Casadi.SXFunctionRaw
-
-import Foreign.C
-import Foreign.Marshal
-import Foreign.ForeignPtr
-import Foreign.Ptr
-import Control.Exception(mask_)
-import System.IO.Unsafe(unsafePerformIO)
-import Control.DeepSeq
-import Data.List(intersperse)
-import Data.Tuple(swap)
-
-
--- the DMatrix data type
-data DMatrixRaw = DMatrixRaw
-newtype DMatrix = DMatrix (ForeignPtr DMatrixRaw)
-
-instance NFData DMatrix where
- rnf x = x `seq` ()
-
--- foreign imports
-foreign import ccall unsafe "&dMatrixDelete" c_dMatrixDelete
- :: FunPtr (Ptr DMatrixRaw -> IO ())
-foreign import ccall unsafe "dMatrixZeros" c_dMatrixZeros
- :: CInt -> CInt -> IO (Ptr DMatrixRaw)
-foreign import ccall unsafe "dMatrixAt" c_dMatrixAt
- :: (Ptr DMatrixRaw) -> CInt -> CInt -> IO CDouble
-foreign import ccall unsafe "dMatrixSetToList" c_dMatrixSetToList
- :: CInt -> Ptr CDouble -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixSetFromList" c_dMatrixSetFromList
- :: CInt -> Ptr CDouble -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixSetFromLists" c_dMatrixSetFromLists
- :: CInt -> CInt -> Ptr CDouble -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixSize1" c_dMatrixSize1
- :: (Ptr DMatrixRaw) -> IO CInt
-foreign import ccall unsafe "dMatrixSize2" c_dMatrixSize2
- :: (Ptr DMatrixRaw) -> IO CInt
-foreign import ccall unsafe "dMatrixPlus" c_dMatrixPlus
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixMinus" c_dMatrixMinus
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixNegate" c_dMatrixNegate
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMM" c_dMM
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixTranspose" c_dMatrixTranspose
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixIsEqual" c_dMatrixIsEqual
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO CInt
-foreign import ccall unsafe "dMatrixScale" c_dMatrixScale
- :: CDouble -> (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixInv" c_dMatrixInv
- :: (Ptr DMatrixRaw) -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "dMatrixVertcat" c_dMatrixVertcat
- :: Ptr (Ptr DMatrixRaw) -> CInt -> (Ptr DMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxFunctionEvaluateDMatrix" c_sxFunctionEvaluateDMatrix
- :: CInt -> Ptr (Ptr DMatrixRaw) -> CInt -> Ptr (Ptr DMatrixRaw) -> Ptr SXFunctionRaw -> IO ()
-
-
------------------ create -------------------------
-dMatrixNewZeros :: (Int, Int) -> IO DMatrix
-dMatrixNewZeros (n,m) = mask_ $ do
- let n' = safeToCInt n
- m' = safeToCInt m
-
- safeToCInt :: Int -> CInt
- safeToCInt x
- | and [toInteger x <= maxCInt, toInteger x >= minCInt] = fromIntegral x
- | otherwise = error "Error - dMatrixNewZeros dimensions too big"
- where
- maxCInt = fromIntegral (maxBound :: CInt)
- minCInt = fromIntegral (minBound :: CInt)
-
- mat <- c_dMatrixZeros n' m' >>= newForeignPtr c_dMatrixDelete
- return $ DMatrix mat
-
-
-dMatrixZeros :: (Int, Int) -> DMatrix
-{-# NOINLINE dMatrixZeros #-}
-dMatrixZeros dimensions = unsafePerformIO $ do
- mOut <- dMatrixNewZeros dimensions
- return mOut
-
-
-dMatrixToList :: DMatrix -> [Double]
-{-# NOINLINE dMatrixToList #-}
-dMatrixToList (DMatrix dMatRaw) = unsafePerformIO $ do
- let (n,m) = size (DMatrix dMatRaw)
- if (m /= 1)
- then error "dMatrixToList can only be used on an n by 1 matrix"
- else do dListPtr <- mallocArray n
- withForeignPtr dMatRaw $ c_dMatrixSetToList (fromIntegral n) dListPtr
- listOut <- peekArray n dListPtr
-
- return $ map realToFrac listOut
-
-
-dMatrixFromList :: [Double] -> DMatrix
-{-# NOINLINE dMatrixFromList #-}
-dMatrixFromList dList = unsafePerformIO $ do
- dListPtr <- newArray (map realToFrac dList)
- DMatrix m0 <- dMatrixNewZeros (length dList, 1)
- withForeignPtr m0 $ c_dMatrixSetFromList (fromIntegral $ length dList) dListPtr
- return $ DMatrix m0
-
-
-dMatrixFromLists :: [[Double]] -> DMatrix
-{-# NOINLINE dMatrixFromLists #-}
-dMatrixFromLists dLists = unsafePerformIO $ do
- let rows' = length dLists
- cols' = length (head dLists)
- dListPtr <- newArray $ map realToFrac (concat dLists)
- DMatrix m0 <- dMatrixNewZeros (rows', cols')
- withForeignPtr m0 $ c_dMatrixSetFromLists (fromIntegral rows') (fromIntegral cols') dListPtr
- return $ DMatrix m0
-
-
---------------- getters/setters ---------------------
-dMatrixAt :: DMatrix -> (Int,Int) -> IO Double
-dMatrixAt (DMatrix matIn) (n,m) = do
- dOut <- withForeignPtr matIn (\matIn' -> c_dMatrixAt matIn' (fromIntegral n) (fromIntegral m))
- return $ realToFrac dOut
-
-
----------------- dimensions --------------------
-dMatrixSize :: DMatrix -> (Int,Int)
-{-# NOINLINE dMatrixSize #-}
-dMatrixSize (DMatrix matIn) = unsafePerformIO $ do
- n <- withForeignPtr matIn c_dMatrixSize1
- m <- withForeignPtr matIn c_dMatrixSize2
- return (fromIntegral n, fromIntegral m)
-
-
-dMatrixToLists :: DMatrix -> [[Double]]
-{-# NOINLINE dMatrixToLists #-}
-dMatrixToLists mat = unsafePerformIO $ do
- let f row = mapM (\col -> dMatrixAt mat (row, col)) [0..m-1]
- (n,m) = dMatrixSize mat
- mapM f [0..n-1]
-
-
-------------------------- math ---------------------------------
-dMatrixPlus :: DMatrix -> DMatrix -> DMatrix
-{-# NOINLINE dMatrixPlus #-}
-dMatrixPlus (DMatrix m0) (DMatrix m1) = unsafePerformIO $ do
- let size'
- | sizeM0 == sizeM1 = sizeM0
- | otherwise = error "dMatrixPlus can't add matrices of different dimensions"
- where
- sizeM0 = dMatrixSize (DMatrix m0)
- sizeM1 = dMatrixSize (DMatrix m1)
- DMatrix mOut <- dMatrixNewZeros size'
- withForeignPtrs3 c_dMatrixPlus m0 m1 mOut
- return $ DMatrix mOut
-
-
-dMatrixMinus :: DMatrix -> DMatrix -> DMatrix
-{-# NOINLINE dMatrixMinus #-}
-dMatrixMinus (DMatrix m0) (DMatrix m1) = unsafePerformIO $ do
- let size'
- | sizeM0 == sizeM1 = sizeM0
- | otherwise = error $ "dMatrixMinus can't subtract " ++ (show (dMatrixSize (DMatrix m0))) ++ " matrix\n" ++ (show (DMatrix m0)) ++ "\nby " ++ (show (dMatrixSize (DMatrix m1))) ++ " matrix\n" ++ (show (DMatrix m1))
- where
- sizeM0 = dMatrixSize (DMatrix m0)
- sizeM1 = dMatrixSize (DMatrix m1)
- DMatrix mOut <- dMatrixNewZeros size'
- withForeignPtrs3 c_dMatrixMinus m0 m1 mOut
- return $ DMatrix mOut
-
-
-dMatrixNegate :: DMatrix -> DMatrix
-{-# NOINLINE dMatrixNegate #-}
-dMatrixNegate (DMatrix m0) = unsafePerformIO $ do
- DMatrix mOut <- dMatrixNewZeros (dMatrixSize (DMatrix m0))
- withForeignPtrs2 c_dMatrixNegate m0 mOut
- return $ DMatrix mOut
-
-
-dMM :: DMatrix -> DMatrix -> DMatrix
-{-# NOINLINE dMM #-}
-dMM (DMatrix m0) (DMatrix m1) = unsafePerformIO $ do
- let size'
- | colsM0 == rowsM1 = (rowsM0, colsM1)
- | otherwise = error "dMM sees incompatible dimensions"
- where
- (rowsM0, colsM0) = dMatrixSize (DMatrix m0)
- (rowsM1, colsM1) = dMatrixSize (DMatrix m1)
-
- DMatrix mOut <- dMatrixNewZeros size'
- withForeignPtrs3 c_dMM m0 m1 mOut
- return $ DMatrix mOut
-
-
-dMatrixTranspose :: DMatrix -> DMatrix
-{-# NOINLINE dMatrixTranspose #-}
-dMatrixTranspose (DMatrix mIn) = unsafePerformIO $ do
- DMatrix mOut <- dMatrixNewZeros $ swap (dMatrixSize (DMatrix mIn))
- withForeignPtrs2 c_dMatrixTranspose mIn mOut
- return $ DMatrix mOut
-
-
-dMatrixIsEqual :: DMatrix -> DMatrix -> Bool
-{-# NOINLINE dMatrixIsEqual #-}
-dMatrixIsEqual (DMatrix m0) (DMatrix m1) = unsafePerformIO $ do
- isEq <- withForeignPtrs2 c_dMatrixIsEqual m0 m1
- if (isEq == 1)
- then
- return True
- else
- return False
-
-
-dMatrixScale :: Double -> DMatrix -> DMatrix
-{-# NOINLINE dMatrixScale #-}
-dMatrixScale scalar (DMatrix mIn) = unsafePerformIO $ do
- DMatrix mOut <- dMatrixNewZeros (1,1)
- withForeignPtrs2 (c_dMatrixScale $ realToFrac scalar) mIn mOut
- return $ DMatrix mOut
-
-
-dMatrixInv :: DMatrix -> DMatrix
-{-# NOINLINE dMatrixInv #-}
-dMatrixInv (DMatrix mIn) = unsafePerformIO $ do
- DMatrix mOut <- dMatrixNewZeros (1,1)
- withForeignPtrs2 c_dMatrixInv mIn mOut
- return $ DMatrix mOut
-
-
-dMatrixVertcat :: [DMatrix] -> DMatrix
-{-# NOINLINE dMatrixVertcat #-}
-dMatrixVertcat inputs = unsafePerformIO $ do
- -- turn input DMatrix lists into [Ptr DMatrixRaw]
- let unsafeInputPtrs :: [Ptr DMatrixRaw]
- unsafeInputPtrs = map (\(DMatrix mat) -> unsafeForeignPtrToPtr mat) inputs
- nIn = fromIntegral $ length inputs
-
- -- turn [Ptr SXMatrixRaw] into Ptr (Ptr DMatrixRaw)
- inputPtrArray <- newArray unsafeInputPtrs
-
- DMatrix mOutRaw <- dMatrixNewZeros (sum $ map rows inputs, cols (head inputs))
-
- withForeignPtr mOutRaw $ c_dMatrixVertcat inputPtrArray nIn
-
- -- touch all [ForeignPtr DMatrixRaw] for unsafeForeignPtrToPtr safety
- mapM_ (\(DMatrix d) -> touchForeignPtr d) inputs
-
- return (DMatrix mOutRaw)
-
-
------------------ typeclass stuff ------------------
-instance Show DMatrix where
- show d = f (dMatrixSize d)
- where
- f (1,1) = show $ toLists d
- f (_,1) = show $ toList d
- f (1,_) = show $ toLists d
- f (_,_) = '[': (concat $ intersperse "\n " $ map show (toLists d)) ++ "]"
-
-instance Eq DMatrix where
- (==) = dMatrixIsEqual
- (/=) d0 d1 = not $ d0 == d1
-
-
-instance Num DMatrix where
- (+) = dMatrixPlus
- (-) = dMatrixMinus
- (*) m0 m1
- | dMatrixSize m0 == (1,1) = dMatrixScale s0 m1
- | dMatrixSize m1 == (1,1) = dMatrixScale s1 m0
- | otherwise = dMM m0 m1
- where
- [s0] = dMatrixToList m0
- [s1] = dMatrixToList m1
-
- abs = error "abs not defined for instance Num DMatrix"
- signum = error "signum not defined for instance Num DMatrix"
- fromInteger i = dMatrixFromList [fromIntegral i]
-
- negate = dMatrixNegate
-
-
-instance Fractional DMatrix where
- (/) m0 m1 = m0 * (recip m1)
- recip mat = dMatrixInv mat
- fromRational x = dMatrixFromList [fromRational x :: Double]
-
-
-instance Matrix DMatrix Double DMatrixRaw where
- trans = dMatrixTranspose
- size = dMatrixSize
- rows = fst . dMatrixSize
- cols = snd . dMatrixSize
- toList = dMatrixToList
- toLists = dMatrixToLists
- fromList = dMatrixFromList
- fromLists = dMatrixFromLists
- vertcat = dMatrixVertcat
- inv = dMatrixInv
- scale = dMatrixScale
- zeros = dMatrixZeros
-
- c_sxFunctionEvaluate _ = c_sxFunctionEvaluateDMatrix
- getForeignPtr (DMatrix r) = r
- newZeros = dMatrixNewZeros
-
-instance Boundable DMatrix where
- bound xs (lbs, ubs) = fromList $ zipWith boundDouble (toList xs) (zip (toList lbs) (toList ubs))
- where
- boundDouble x (lb, ub)
- | ub < lb = error $ "in boundDouble, ub (" ++ show ub ++ ") < lb (" ++ show lb ++ ")"
- | x < lb = lb
- | x > ub = ub
- | otherwise = x
View
39 Casadi/Matrix.hs
@@ -1,39 +0,0 @@
--- Matrix.hs
-
-{-# OPTIONS_GHC -Wall #-}
-{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies #-}
-
-module Casadi.Matrix( Matrix(..)
- , Boundable(..)
- ) where
-
-import Casadi.SXFunctionRaw
-
-import Control.DeepSeq
-import Foreign.Ptr
-import Foreign.ForeignPtr
-import Foreign.C
-
-class (NFData a, NFData b, Num a, Fractional a, Num b, Fractional b, Floating b, Boundable a) => Matrix a b c | a -> b c where
- trans :: a -> a
- size :: a -> (Int, Int)
- rows :: a -> Int
- cols :: a -> Int
- toList :: a -> [b]
- toLists :: a -> [[b]]
- fromList :: [b] -> a
- fromLists :: [[b]] -> a
- vertcat :: [a] -> a
- inv :: a -> a
- toSingleton :: a -> b
- scale :: b -> a -> a
- zeros :: (Int,Int) -> a
-
- newZeros :: (Int, Int) -> IO a
- getForeignPtr :: a -> ForeignPtr c
- c_sxFunctionEvaluate :: a -> CInt -> Ptr (Ptr c) -> CInt -> Ptr (Ptr c) -> Ptr SXFunctionRaw -> IO ()
-
- toSingleton = head . toList
-
-class Boundable a where
- bound :: a -> (a, a) -> a
View
329 Casadi/SX.hs
@@ -1,329 +0,0 @@
--- SX.hs
-
-{-# OPTIONS_GHC -Wall #-}
---{-# OPTIONS_GHC -Wall -fno-cse -fno-full-laziness #-}
-{-# LANGUAGE ForeignFunctionInterface #-}
-
-module Casadi.SX
- (
- SX(..)
- , SXRaw(..)
- , sxNewDouble
- , sxNewInt
- , sxNewIntegral
- , sxFromInt
- , sxFromIntegral
- , sxFromDouble
- , sxCreateSymbolic
- , sxBound
- ) where
-
-import Casadi.CasadiInterfaceUtils
-import Casadi.Matrix
-
-import Foreign.C
-import Foreign.Ptr
-import Foreign.ForeignPtr
-import Control.Exception(mask_)
-import System.IO.Unsafe(unsafePerformIO)
-import Control.DeepSeq
- --import Data.Ratio(numerator, denominator)
-
--- the SX data type
-data SXRaw = SXRaw
-newtype SX = SX (ForeignPtr SXRaw)
-
-instance NFData SX where
- rnf x = x `seq` ()
-
--- foreign imports
-foreign import ccall unsafe "sxInterface.hpp sxCreateSymbolic" c_sxCreateSymbolic :: Ptr CChar -> IO (Ptr SXRaw)
-foreign import ccall unsafe "sxInterface.hpp sxNewDouble" c_sxNewDouble :: CDouble -> IO (Ptr SXRaw)
-foreign import ccall unsafe "sxInterface.hpp sxNewInt" c_sxNewInt :: CInt -> IO (Ptr SXRaw)
-foreign import ccall unsafe "sxInterface.hpp &sxDelete" c_sxDelete :: FunPtr (Ptr SXRaw -> IO ())
-foreign import ccall unsafe "sxInterface.hpp sxShow" c_sxShow :: Ptr CChar -> CInt -> (Ptr SXRaw) -> IO ()
-
-foreign import ccall unsafe "sxInterface.hpp sxEqual" c_sxEqual :: Ptr SXRaw -> Ptr SXRaw -> IO CInt
-foreign import ccall unsafe "sxInterface.hpp sxPlus" c_sxPlus :: Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxMinus" c_sxMinus :: Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxTimes" c_sxTimes :: Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxDivide" c_sxDivide :: Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxNegate" c_sxNegate :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxAbs" c_sxAbs :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxSignum" c_sxSignum :: Ptr SXRaw -> IO CInt
-
-foreign import ccall unsafe "sxInterface.hpp sxPi" c_sxPi :: Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxExp" c_sxExp :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxSqrt" c_sxSqrt :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxLog" c_sxLog :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxPow" c_sxPow :: Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxSin" c_sxSin :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxCos" c_sxCos :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxTan" c_sxTan :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxArcsin" c_sxArcsin :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxArccos" c_sxArccos :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-foreign import ccall unsafe "sxInterface.hpp sxArctan" c_sxArctan :: Ptr SXRaw -> Ptr SXRaw -> IO ()
-
-foreign import ccall unsafe "sxInterface.hpp sxBound" c_sxBound :: Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> Ptr SXRaw -> IO ()
-
--- creation
-sxCreateSymbolic :: String -> IO SX
-sxCreateSymbolic name = mask_ $ do
- cName <- newCString name
- sym <- c_sxCreateSymbolic cName >>= newForeignPtr c_sxDelete
- return $ SX sym
-
-sxNewDouble :: Double -> IO SX
-sxNewDouble val = mask_ $ do
- f <- c_sxNewDouble (realToFrac val) >>= newForeignPtr c_sxDelete
- return $ SX f
-
-sxNewInt :: Int -> IO SX
-sxNewInt val = mask_ $ do
- f <- c_sxNewInt (fromIntegral val) >>= newForeignPtr c_sxDelete
- return $ SX f
-
-
-sxNewIntegral :: Integral a => a -> IO SX
-sxNewIntegral val
- | withinCIntBounds val = sxNewInt (fromIntegral val)
- | otherwise = error "input out of range of CInt in sxNewIntegral"
- where
- withinCIntBounds x = and [fromIntegral x <= maxCInt, fromIntegral x >= minCInt]
- maxCInt = toInteger (maxBound :: CInt)
- minCInt = toInteger (minBound :: CInt)
-
-sxShow :: SX -> String
-{-# NOINLINE sxShow #-}
-sxShow (SX s) = unsafePerformIO $ do
- (stringRef, stringLength) <- newCStringLen $ replicate 512 ' '
- withForeignPtr s $ c_sxShow stringRef (fromIntegral stringLength)
- peekCString stringRef
-
-sxEqual :: SX -> SX -> Bool
-{-# NOINLINE sxEqual #-}
-sxEqual (SX sx0) (SX sx1) = unsafePerformIO $ do
- equalInt <- withForeignPtrs2 c_sxEqual sx0 sx1
- let equalBool
- | equalInt == 0 = False
- | otherwise = True
- return equalBool
-
-sxPlus :: SX -> SX -> SX
-{-# NOINLINE sxPlus #-}
-sxPlus (SX sx0) (SX sx1) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs3 c_sxPlus sx0 sx1 sxOut
- return (SX sxOut)
-
-sxMinus :: SX -> SX -> SX
-{-# NOINLINE sxMinus #-}
-sxMinus (SX sx0) (SX sx1) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs3 c_sxMinus sx0 sx1 sxOut
- return (SX sxOut)
-
-sxTimes :: SX -> SX -> SX
-{-# NOINLINE sxTimes #-}
-sxTimes (SX sx0) (SX sx1) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs3 c_sxTimes sx0 sx1 sxOut
- return (SX sxOut)
-
-sxDivide :: SX -> SX -> SX
-{-# NOINLINE sxDivide #-}
-sxDivide (SX sx0) (SX sx1) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs3 c_sxDivide sx0 sx1 sxOut
- return (SX sxOut)
-
-sxNegate :: SX -> SX
-{-# NOINLINE sxNegate #-}
-sxNegate (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxNegate sx sxOut
- return (SX sxOut)
-
-sxAbs :: SX -> SX
-{-# NOINLINE sxAbs #-}
-sxAbs (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxAbs sx sxOut
- return (SX sxOut)
-
-sxSignum :: SX -> SX
-{-# NOINLINE sxSignum #-}
-sxSignum (SX sx) = unsafePerformIO $ do
- sign <- withForeignPtr sx c_sxSignum
- if (sign == 1)
- then sxNewInt 1
- else sxNewInt (-1)
-
-
-
-sxPi :: SX
-{-# NOINLINE sxPi #-}
-sxPi = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtr sxOut $ c_sxPi
- return (SX sxOut)
-
-sxExp :: SX -> SX
-{-# NOINLINE sxExp #-}
-sxExp (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxExp sx sxOut
- return (SX sxOut)
-
-sxSqrt :: SX -> SX
-{-# NOINLINE sxSqrt #-}
-sxSqrt (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxSqrt sx sxOut
- return (SX sxOut)
-
-sxLog :: SX -> SX
-{-# NOINLINE sxLog #-}
-sxLog (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxLog sx sxOut
- return (SX sxOut)
-
-sxPow :: SX -> SX -> SX
-{-# NOINLINE sxPow #-}
-sxPow (SX sx0) (SX sx1) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs3 c_sxPow sx0 sx1 sxOut
- return (SX sxOut)
-
-sxSin :: SX -> SX
-{-# NOINLINE sxSin #-}
-sxSin (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxSin sx sxOut
- return (SX sxOut)
-
-sxCos :: SX -> SX
-{-# NOINLINE sxCos #-}
-sxCos (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxCos sx sxOut
- return (SX sxOut)
-
-sxTan :: SX -> SX
-{-# NOINLINE sxTan #-}
-sxTan (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxTan sx sxOut
- return (SX sxOut)
-
-sxArcsin :: SX -> SX
-{-# NOINLINE sxArcsin #-}
-sxArcsin (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxArcsin sx sxOut
- return (SX sxOut)
-
-sxArccos :: SX -> SX
-{-# NOINLINE sxArccos #-}
-sxArccos (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxArccos sx sxOut
- return (SX sxOut)
-
-sxArctan :: SX -> SX
-{-# NOINLINE sxArctan #-}
-sxArctan (SX sx) = unsafePerformIO $ do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 c_sxArctan sx sxOut
- return (SX sxOut)
-
-
-
-sxFromInt :: Int -> SX
-{-# NOINLINE sxFromInt #-}
-sxFromInt n = unsafePerformIO $ do
- sxOut <- sxNewInt n
- return sxOut
-
-sxFromIntegral :: Integral a => a -> SX
-{-# NOINLINE sxFromIntegral #-}
-sxFromIntegral n = unsafePerformIO $ do
- sxOut <- sxNewIntegral n
- return sxOut
-
-sxFromDouble :: Double -> SX
-{-# NOINLINE sxFromDouble #-}
-sxFromDouble val = unsafePerformIO $ do
- s <- sxNewDouble val
- return s
-
-
-sxBound :: SX -> (SX, SX) -> SX
-{-# NOINLINE sxBound #-}
-sxBound (SX sxIn) (SX sxLb, SX sxUb) = unsafePerformIO $ do
- (SX sxOut) <- sxNewDouble 0
-
- let sxLb' = unsafeForeignPtrToPtr sxLb
- sxUb' = unsafeForeignPtrToPtr sxUb
- sxIn' = unsafeForeignPtrToPtr sxIn
- sxOut' = unsafeForeignPtrToPtr sxOut
-
- c_sxBound sxLb' sxUb' sxIn' sxOut'
-
- touchForeignPtr sxLb
- touchForeignPtr sxUb
- touchForeignPtr sxIn
- touchForeignPtr sxOut
-
- return (SX sxOut)
-
-
--- typeclass stuff
-instance Eq SX where
- (==) = sxEqual
- (/=) sx0 sx1 = not $ sx0 == sx1
-
-instance Show SX where
- show sx = sxShow sx
-
-instance Num SX where
- (+) = sxPlus
- (*) = sxTimes
- (-) = sxMinus
- negate = sxNegate
- abs = sxAbs
- signum = sxSignum
- fromInteger = sxFromIntegral
-
-instance Fractional SX where
- (/) = sxDivide
- recip sx = (sxFromInt 1)/sx
--- fromRational x = (sxFromIntegral num)/(sxFromIntegral den)
--- where
--- num = numerator x
--- den = denominator x
- fromRational x = sxFromDouble (fromRational x)
-
-instance Floating SX where
- pi = sxPi
- exp = sxExp
- sqrt = sxSqrt
- log = sxLog
- (**) = sxPow
--- logBase :: a -> a -> a
- sin = sxSin
- tan = sxTan
- cos = sxCos
- asin = sxArcsin
- atan = sxArctan
- acos = sxArccos
- sinh = error "hyperbolic functions not yet implemented for SX"
- tanh = error "hyperbolic functions not yet implemented for SX"
- cosh = error "hyperbolic functions not yet implemented for SX"
- asinh = error "hyperbolic functions not yet implemented for SX"
- atanh = error "hyperbolic functions not yet implemented for SX"
- acosh = error "hyperbolic functions not yet implemented for SX"
-
-instance Boundable SX where
- bound = sxBound
View
347 Casadi/SXFunction.hs
@@ -1,347 +0,0 @@
--- SXFunction.hs
-
-{-# OPTIONS_GHC -Wall #-}
-{-# LANGUAGE ForeignFunctionInterface, ScopedTypeVariables #-}
-
-module Casadi.SXFunction
- (
- SXFunction(..)
- , sxFunctionCreate
- , sxFunctionEvaluate
- , sxFunctionEvaluateLists
- , sxFunctionGetInputsSX
- , sxFunctionGetOutputsSX
- , sxFunctionGradientAt
- , sxFunctionGradients
- , sxFunctionJacobianAt
- , sxFunctionHessianAt
- , sxFunctionCompile
- ) where
-
-import Casadi.SXFunctionRaw(SXFunctionRaw(..))
-import Casadi.SXMatrix
-import Casadi.DMatrix
-import Casadi.Matrix
-import Casadi.CasadiInterfaceUtils
-
-import Foreign.C
-import Foreign.ForeignPtr
-import Foreign.Ptr
-import Foreign.Marshal(newArray)
-import Control.Exception(mask_)
-import System.IO.Unsafe(unsafePerformIO)
-import System.IO
-import Text.Printf
-import Control.DeepSeq
-import System.Process
-import System.Exit(ExitCode(..))
-import Debug.Trace
-import Data.Time.Clock
-
--- the SXFunction data type
-data SXFunction = SXFunction { sxFunRaw :: ForeignPtr SXFunctionRaw
- , sxFunNumInputs :: Int
- , sxFunNumOutputs :: Int
- , sxFunInputDims :: [(Int,Int)]
- , sxFunOutputDims :: [(Int,Int)]
- }
-
-instance NFData SXFunction where
- rnf x = x `seq` ()
-
--- foreign imports
-foreign import ccall unsafe "sxFunctionCreate" c_sxFunctionCreate
- :: Ptr (Ptr SXMatrixRaw) -> CInt -> Ptr (Ptr SXMatrixRaw) -> CInt -> IO (Ptr SXFunctionRaw)
-foreign import ccall unsafe "&sxFunctionDelete" c_sxFunctionDelete
- :: FunPtr (Ptr SXFunctionRaw -> IO ())
-foreign import ccall unsafe "sxFunctionGetNumInputs" c_sxFunctionGetNumInputs
- :: Ptr SXFunctionRaw -> IO CInt
-foreign import ccall unsafe "sxFunctionGetNumOutputs" c_sxFunctionGetNumOutputs
- :: Ptr SXFunctionRaw -> IO CInt
-foreign import ccall unsafe "sxFunctionGetInputsSX" c_sxFunctionGetInputsSX
- :: Ptr SXFunctionRaw -> CInt -> Ptr SXMatrixRaw -> IO ()
-foreign import ccall unsafe "sxFunctionGetOutputsSX" c_sxFunctionGetOutputsSX
- :: Ptr SXFunctionRaw -> CInt -> Ptr SXMatrixRaw -> IO ()
-foreign import ccall unsafe "sxFunctionGetInputSize1" c_sxFunctionGetInputSize1
- :: CInt -> Ptr SXFunctionRaw -> IO CInt
-foreign import ccall unsafe "sxFunctionGetInputSize2" c_sxFunctionGetInputSize2
- :: CInt -> Ptr SXFunctionRaw -> IO CInt
-foreign import ccall unsafe "sxFunctionGetOutputSize1" c_sxFunctionGetOutputSize1
- :: CInt -> Ptr SXFunctionRaw -> IO CInt
-foreign import ccall unsafe "sxFunctionGetOutputSize2" c_sxFunctionGetOutputSize2
- :: CInt -> Ptr SXFunctionRaw -> IO CInt
-foreign import ccall unsafe "sxFunctionGradient" c_sxFunctionGradient
- :: Ptr SXFunctionRaw -> CInt -> Ptr SXMatrixRaw -> IO ()
-foreign import ccall unsafe "sxFunctionJacobian" c_sxFunctionJacobian
- :: Ptr SXFunctionRaw -> CInt -> CInt -> Ptr SXMatrixRaw -> IO ()
-foreign import ccall unsafe "sxFunctionHessian" c_sxFunctionHessian
- :: Ptr SXFunctionRaw -> CInt -> CInt -> Ptr SXMatrixRaw -> IO ()
-
-foreign import ccall unsafe "generateCCode" c_generateCCode
- :: Ptr CChar -> Ptr SXFunctionRaw -> IO CDouble
-foreign import ccall unsafe "createExternalFunction" c_createExternalFunction
- :: Ptr CChar -> IO (Ptr SXFunctionRaw)
-
-
-sxFunctionCreate :: [SXMatrix] -> [SXMatrix] -> SXFunction
-{-# NOINLINE sxFunctionCreate #-}
-sxFunctionCreate inputs outputs = unsafePerformIO $ mask_ $ do
- -- turn input/output SXMatrix lists into [Ptr SXMatrixRaw]
- let unsafeInputPtrs :: [Ptr SXMatrixRaw]
- unsafeInputPtrs = map (\(SXMatrix mat) -> unsafeForeignPtrToPtr mat) inputs
-
- unsafeOutputPtrs :: [Ptr SXMatrixRaw]
- unsafeOutputPtrs = map (\(SXMatrix mat) -> unsafeForeignPtrToPtr mat) outputs
-
- nIn = fromIntegral $ length inputs
- nOut = fromIntegral $ length outputs
-
- -- turn [Ptr SXMatrixRaw] into Ptr (Ptr SXMatrixRaw)
- inputPtrArray <- newArray unsafeInputPtrs
- outputPtrArray <- newArray unsafeOutputPtrs
-
- -- create SXFunction
- funRaw <- c_sxFunctionCreate inputPtrArray nIn outputPtrArray nOut >>= newForeignPtr c_sxFunctionDelete
-
- -- touch all [ForeignPtr SXMatrixRaw] for unsafeForeignPtrToPtr safety
- mapM_ (\(SXMatrix d) -> touchForeignPtr d) inputs
- mapM_ (\(SXMatrix d) -> touchForeignPtr d) outputs
-
- -- prepare output data structure
- let funOut = SXFunction { sxFunRaw = funRaw
- , sxFunNumInputs = length inputs
- , sxFunNumOutputs = length outputs
- , sxFunInputDims = map size inputs
- , sxFunOutputDims = map size outputs}
-
- -- make sure dimensions are right
- checkSXFunctionDimensions funOut
-
- return funOut
-
-
-------------------- getters -----------------------
-checkSXFunctionDimensions :: SXFunction -> IO ()
-checkSXFunctionDimensions fun = do
-
- let sxFunctionNumInputs :: SXFunction -> IO Int
- sxFunctionNumInputs fun' = do
- num <- withForeignPtr (sxFunRaw fun') c_sxFunctionGetNumInputs
- return $ fromIntegral num
-
- sxFunctionNumOutputs :: SXFunction -> IO Int
- sxFunctionNumOutputs fun' = do
- num <- withForeignPtr (sxFunRaw fun') c_sxFunctionGetNumOutputs
- return $ fromIntegral num
-
- sxFunctionGetInputDim :: SXFunction -> Int -> IO (Int, Int)
- sxFunctionGetInputDim fun' idx = do
- size1 <- withForeignPtr (sxFunRaw fun') $ c_sxFunctionGetInputSize1 (fromIntegral idx)
- size2 <- withForeignPtr (sxFunRaw fun') $ c_sxFunctionGetInputSize2 (fromIntegral idx)
- return (fromIntegral size1, fromIntegral size2)
-
- sxFunctionGetOutputDim :: SXFunction -> Int -> IO (Int, Int)
- sxFunctionGetOutputDim fun' idx = do
- size1 <- withForeignPtr (sxFunRaw fun') $ c_sxFunctionGetOutputSize1 (fromIntegral idx)
- size2 <- withForeignPtr (sxFunRaw fun') $ c_sxFunctionGetOutputSize2 (fromIntegral idx)
- return (fromIntegral size1, fromIntegral size2)
-
- numInputs <- sxFunctionNumInputs fun
- numOutputs <- sxFunctionNumOutputs fun
- inputDim <- mapM (sxFunctionGetInputDim fun) [0..numInputs - 1]
- outputDim <- mapM (sxFunctionGetOutputDim fun) [0..numOutputs - 1]
-
- let ret
- | numInputs /= sxFunNumInputs fun = error "checkSXFunctionDimensions got bad numInputs"
- | numOutputs /= sxFunNumOutputs fun = error "checkSXFunctionDimensions got bad numOutputs"
- | inputDim /= sxFunInputDims fun = error "checkSXFunctionDimensions got bad inputDim"
- | outputDim /= sxFunOutputDims fun = error "checkSXFunctionDimensions got bad outputDim"
- | otherwise = ()
- return $ ret `seq` ret
-
-
-sxFunctionGetInputsSX :: SXFunction -> Int -> SXMatrix
-{-# NOINLINE sxFunctionGetInputsSX #-}
-sxFunctionGetInputsSX fun idx = trace "why are you using sxFunctionGetInputsSX?" $ unsafePerformIO $ do
- if idx >= sxFunNumInputs fun
- then error $ printf "Error in sxFunctionGetInputsSX - requested input index: %d >= numInputs (SXFunction fun): %d" idx (sxFunNumInputs fun)
- else do return ()
-
- SXMatrix mat <- sxMatrixNewZeros (1::Int,1::Int)
- withForeignPtrs2 (\fun' mat' -> c_sxFunctionGetInputsSX fun' (fromIntegral idx) mat') (sxFunRaw fun) mat
- return (SXMatrix mat)
-
-
-sxFunctionGetOutputsSX :: SXFunction -> Int -> SXMatrix
-{-# NOINLINE sxFunctionGetOutputsSX #-}
-sxFunctionGetOutputsSX fun idx = trace "why are you using sxFunctionGetOutputsSX?" $ unsafePerformIO $ do
- if idx >= sxFunNumOutputs fun
- then error $ printf "Error in sxFunctionGetOutputsSX - requested output index: %d >= numOutputs (SXFunction fun): %d" idx (sxFunNumOutputs fun)
- else do return ()
-
- SXMatrix mat <- sxMatrixNewZeros (1::Int,1::Int)
- withForeignPtrs2 (\fun' mat' -> c_sxFunctionGetOutputsSX fun' (fromIntegral idx) mat') (sxFunRaw fun) mat
- return (SXMatrix mat)
-
-
------------------------ AD -----------------------
-sxFunctionGradientAt :: SXFunction -> Int -> SXMatrix
-{-# NOINLINE sxFunctionGradientAt #-}
-sxFunctionGradientAt fun idxInput = unsafePerformIO $ do
- -- don't take gradient with respect to non-existant input
- if idxInput >= sxFunNumInputs fun
- then error $ printf "Error in sxFunctionGradientAt - requested gradient index: %d >= numInputs fun: %d" idxInput (sxFunNumInputs fun)
- else do return ()
-
- -- don't take gradient of vector valued function
- if (1,1) /= head (sxFunOutputDims fun)
- then error $ printf "Error in sxFunctionGradientAt - requested gradient of non-scalar"
- else do return ()
-
- SXMatrix mat <- sxMatrixNewZeros (1::Int,1::Int)
- withForeignPtrs2 (\fun' mat' -> c_sxFunctionGradient fun' (fromIntegral idxInput) mat') (sxFunRaw fun) mat
- return $ (SXMatrix mat)
-
-
-sxFunctionGradients :: SXFunction -> [SXMatrix]
-sxFunctionGradients fun = map (sxFunctionGradientAt fun) $ take (sxFunNumInputs fun) [0..]
-
-
-sxFunctionJacobianAt :: SXFunction -> (Int, Int) -> SXMatrix
-{-# NOINLINE sxFunctionJacobianAt #-}
-sxFunctionJacobianAt fun (idx0, idx1) = unsafePerformIO $ do
- -- don't take jacobian with respect to non-existant output
- if idx0 >= sxFunNumOutputs fun
- then error $ printf "Error in sxFunctionJacobianAt - requested jacobian index: (%d,%d) is outside numOutputs fun: %d" idx0 idx1 (sxFunNumOutputs fun)
- else do return ()
-
- -- don't take jacobian with respect to non-existant input
- if idx1 >= sxFunNumOutputs fun
- then error $ printf "Error in sxFunctionJacobianAt - requested jacobian index: (%d,%d) is outside numInputs fun: %d" idx0 idx1 (sxFunNumInputs fun)
- else do return ()
-
- SXMatrix mat <- sxMatrixNewZeros (1::Int,1::Int)
- withForeignPtrs2 (\fun' mat' -> c_sxFunctionJacobian fun' (fromIntegral idx0) (fromIntegral idx1) mat') (sxFunRaw fun) mat
- return $ (SXMatrix mat)
-
-
-sxFunctionHessianAt :: SXFunction -> (Int, Int) -> SXMatrix
-{-# NOINLINE sxFunctionHessianAt #-}
-sxFunctionHessianAt fun (idx0, idx1) = unsafePerformIO $ do
- -- don't take hessian with respect to non-existant input
- if any (\x -> x >= sxFunNumInputs fun) [idx0, idx1]
- then error $ printf "Error in sxFunctionHessianAt - requested hessian index: (%d,%d) >= numInputs fun: %d" idx0 idx1 (sxFunNumInputs fun)
- else do return ()
-
- -- don't take hessian of vector valued function
- if (1,1) /= head (sxFunOutputDims fun)
- then error $ printf "Error in sxFunctionHessianAt - requested hessian of non-scalar"
- else do return ()
-
- SXMatrix mat <- sxMatrixNewZeros (1::Int,1::Int)
- withForeignPtrs2 (\fun' mat' -> c_sxFunctionHessian fun' (fromIntegral idx0) (fromIntegral idx1) mat') (sxFunRaw fun) mat
- return $ (SXMatrix mat)
-
-
---------------------- evaluate -----------------------------
-sxFunctionEvaluate :: forall a b c. Matrix a b c => SXFunction -> [a] -> [a]
-{-# NOINLINE sxFunctionEvaluate #-}
-sxFunctionEvaluate fun inputs = unsafePerformIO $ do
- do if (map size inputs) == sxFunInputDims fun
- then do return ()
- else error "sxFunctionEvaluate got bad input dimensions"
-
- let unsafeInputPtrs :: [Ptr c]
- unsafeInputPtrs = map (\m -> unsafeForeignPtrToPtr (getForeignPtr m)) inputs
-
- outputs <- mapM newZeros (sxFunOutputDims fun)
- let unsafeOutputPtrs :: [Ptr c]
- unsafeOutputPtrs = map (\m -> unsafeForeignPtrToPtr (getForeignPtr m)) outputs
-
- inputPtrArray <- newArray unsafeInputPtrs
- outputPtrArray <- newArray unsafeOutputPtrs
-
- let nIn = fromIntegral $ length inputs
- nOut = fromIntegral$ length outputs
-
- let eval :: CInt -> Ptr (Ptr c) -> CInt -> Ptr (Ptr c) -> Ptr SXFunctionRaw -> IO ()
- eval = c_sxFunctionEvaluate (head inputs)
- withForeignPtr (sxFunRaw fun) (eval nIn inputPtrArray nOut outputPtrArray)
-
- mapM_ (\d -> touchForeignPtr (getForeignPtr d)) inputs
- mapM_ (\d -> touchForeignPtr (getForeignPtr d)) outputs
-
- return outputs
-
-
-sxFunctionEvaluateLists :: SXFunction -> [[[Double]]] -> [[[Double]]]
-{-# NOINLINE sxFunctionEvaluateLists #-}
-sxFunctionEvaluateLists fun inputs = unsafePerformIO $ do
- let outNew = map toLists $ sxFunctionEvaluate fun $ (map fromLists inputs :: [DMatrix])
- return outNew
-
-
-getMd5 :: String -> IO String
-getMd5 filename = do
- (_, hStdout, _, p) <- runInteractiveCommand $ "md5sum " ++ filename
- exitCode <- waitForProcess p
- md5Out <- hGetContents hStdout
- if exitCode == ExitSuccess
- then do return $ head (lines md5Out)
- else do error $ "getMd5 couldn't read \"" ++ filename ++ "\""
-
----------------------- code gen ---------------------
-sxFunctionCompile :: SXFunction -> String -> ([DMatrix] -> [DMatrix])
-sxFunctionCompile fun name = unsafePerformIO $ do
-
- let srcname = name ++ ".c"
- objname = name ++ ".so"
- hashname = name ++ ".so.md5"
-
- cSrcname <- newCString srcname
- cObjname <- newCString objname
-
- let funPtr = unsafeForeignPtrToPtr (sxFunRaw fun)
-
- -- generate code
- genTime <- c_generateCCode cSrcname funPtr
- touchForeignPtr (sxFunRaw fun)
- putStrLn $ "Generated " ++ srcname ++ " in " ++ show (realToFrac genTime::Double) ++ " seconds"
-
- -- check md5
- let getOldMd5 = do catch (readFile ("./" ++ hashname)) $ \_ -> do return $ hashname ++ " does not exist"
-
- oldMd5 <- getOldMd5
- newMd5 <- getMd5 srcname
- putStrLn $ "oldMd5: \"" ++ oldMd5 ++ "\""
- putStrLn $ "newMd5: \"" ++ newMd5 ++ "\""
-
- if oldMd5 /= newMd5
- -- compile new object
- then do let compileString = "gcc -O1 -fPIC -shared " ++ srcname ++ " -o " ++ objname
- putStrLn compileString
- p <- runCommand compileString
- exitCode <- timeComputation "compiled in " $ waitForProcess p
- if exitCode /= ExitSuccess
- then do error "compilation failure"
- else do writeFile ("./" ++ hashname) newMd5
- return ()
- -- use old object
- else do putStrLn $ "md5 of " ++ srcname ++ " matches " ++ hashname ++ " - reusing " ++ objname
-
- extFun <- c_createExternalFunction cObjname >>= newForeignPtr c_sxFunctionDelete
-
- return $ sxFunctionEvaluate $ SXFunction { sxFunRaw = extFun
- , sxFunNumInputs = sxFunNumInputs fun
- , sxFunNumOutputs = sxFunNumOutputs fun
- , sxFunInputDims = sxFunInputDims fun
- , sxFunOutputDims = sxFunOutputDims fun}
-
-timeComputation :: String -> IO t -> IO t
-timeComputation msg a = do
- start <- getCurrentTime
- v <- a
- end <- getCurrentTime
- let diffTime = (realToFrac $ diffUTCTime end start)::Double
- putStrLn $ msg ++ show diffTime ++ " seconds"
- return v
View
5 Casadi/SXFunctionRaw.hs
@@ -1,5 +0,0 @@
--- SXFunctionRaw.hs
-
-module Casadi.SXFunctionRaw( SXFunctionRaw(..) ) where
-
-data SXFunctionRaw = SXFunctionRaw
View
365 Casadi/SXMatrix.hs
@@ -1,365 +0,0 @@
--- SXMatrix.hs
-
-{-# OPTIONS_GHC -Wall #-}
---{-# OPTIONS_GHC -Wall -fno-cse -fno-full-laziness #-}
-{-# LANGUAGE ForeignFunctionInterface, MultiParamTypeClasses #-}
-
-module Casadi.SXMatrix
- (
- SXMatrix(..)
- , SXMatrixRaw(..)
- , sxMatrixCreateSymbolic
- , sxMatrixNewZeros
- , gradient
- , hessian
- , jacobian
- ) where
-
-import Casadi.SX
-import Casadi.SXFunctionRaw
-import Casadi.CasadiInterfaceUtils
-import Casadi.Matrix
-
-import Foreign.C
-import Foreign.ForeignPtr
-import Foreign.Ptr
-import Control.Exception(mask_)
-import System.IO.Unsafe(unsafePerformIO)
-import Control.DeepSeq
-
--- the SXMatrix data type
-data SXMatrixRaw = SXMatrixRaw
-newtype SXMatrix = SXMatrix (ForeignPtr SXMatrixRaw)
-
-instance NFData SXMatrix where
- rnf x = x `seq` ()
-
--- foreign imports
-foreign import ccall unsafe "sxMatrixCreateSymbolic" c_sxMatrixCreateSymbolic
- :: Ptr CChar -> CInt -> CInt -> IO (Ptr SXMatrixRaw)
-foreign import ccall unsafe "sxMatrixDuplicate" c_sxMatrixDuplicate
- :: (Ptr SXMatrixRaw) -> IO (Ptr SXMatrixRaw)
-foreign import ccall unsafe "&sxMatrixDelete" c_sxMatrixDelete
- :: FunPtr (Ptr SXMatrixRaw -> IO ())
-foreign import ccall unsafe "sxMatrixZeros" c_sxMatrixZeros
- :: CInt -> CInt -> IO (Ptr SXMatrixRaw)
-foreign import ccall unsafe "sxMatrixShow" c_sxMatrixShow
- :: Ptr CChar -> CInt -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixAt" c_sxMatrixAt
- :: (Ptr SXMatrixRaw) -> CInt -> CInt -> (Ptr SXRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixSet" c_sxMatrixSet
- :: (Ptr SXRaw) -> CInt -> CInt -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixSize1" c_sxMatrixSize1
- :: (Ptr SXMatrixRaw) -> IO CInt
-foreign import ccall unsafe "sxMatrixSize2" c_sxMatrixSize2
- :: (Ptr SXMatrixRaw) -> IO CInt
-foreign import ccall unsafe "sxMatrixPlus" c_sxMatrixPlus
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixMinus" c_sxMatrixMinus
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixNegate" c_sxMatrixNegate
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMM" c_sxMM
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixTranspose" c_sxMatrixTranspose
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixIsEqual" c_sxMatrixIsEqual
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO CInt
-foreign import ccall unsafe "sxMatrixScale" c_sxMatrixScale
- :: (Ptr SXRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "sxMatrixInv" c_sxMatrixInv
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-
-foreign import ccall unsafe "myGradient" c_myGradient
- :: (Ptr SXRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "myHessian" c_myHessian
- :: (Ptr SXRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-foreign import ccall unsafe "myJacobian" c_myJacobian
- :: (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> (Ptr SXMatrixRaw) -> IO ()
-
-foreign import ccall unsafe "sxFunctionEvaluateSXMatrix" c_sxFunctionEvaluateSXMatrix
- :: CInt -> Ptr (Ptr SXMatrixRaw) -> CInt -> Ptr (Ptr SXMatrixRaw) -> Ptr SXFunctionRaw -> IO ()
-
------------------ create -------------------------
-sxMatrixCreateSymbolic :: String -> (Int, Int) -> IO SXMatrix
-sxMatrixCreateSymbolic prefix (n,m) = mask_ $ do
- cPrefix <- newCString prefix
- mat <- c_sxMatrixCreateSymbolic cPrefix (fromIntegral n) (fromIntegral m) >>= newForeignPtr c_sxMatrixDelete
- return $ SXMatrix mat
-
-
-sxMatrixDuplicate :: SXMatrix -> IO SXMatrix
-sxMatrixDuplicate (SXMatrix old) = mask_ $ do
- new <- withForeignPtr old c_sxMatrixDuplicate >>= newForeignPtr c_sxMatrixDelete
- return $ SXMatrix new
-
-
-sxMatrixNewZeros :: (Int, Int) -> IO SXMatrix
-sxMatrixNewZeros (n,m) = mask_ $ do
- let n' = safeToCInt n
- m' = safeToCInt m
-
- safeToCInt :: Int -> CInt
- safeToCInt x
- | and [toInteger x <= maxCInt, toInteger x >= minCInt] = fromIntegral x
- | otherwise = error "Error - sxMatrixNewZeros dimensions too big"
- where
- maxCInt = fromIntegral (maxBound :: CInt)
- minCInt = fromIntegral (minBound :: CInt)
-
- mat <- c_sxMatrixZeros n' m' >>= newForeignPtr c_sxMatrixDelete
- return $ SXMatrix mat
-
-
-sxMatrixZeros :: (Int,Int) -> SXMatrix
-{-# NOINLINE sxMatrixZeros #-}
-sxMatrixZeros dim = unsafePerformIO $ do
- mat <- sxMatrixNewZeros dim
- return mat
-
-
-sxMatrixFromList :: [SX] -> SXMatrix
-{-# NOINLINE sxMatrixFromList #-}
-sxMatrixFromList sxList = unsafePerformIO $ do
- m0 <- sxMatrixNewZeros (length sxList, 1)
- let indexedSXList = zip sxList $ take (length sxList) [0..]
- return $ foldl (\acc (sx,idx) -> sxMatrixSet acc (idx,0) sx) m0 indexedSXList
-
-
----------------- show -------------------
-sxMatrixShow :: SXMatrix -> String
-{-# NOINLINE sxMatrixShow #-}
-sxMatrixShow (SXMatrix s) = unsafePerformIO $ do
- (stringRef, stringLength) <- newCStringLen $ replicate 4096 ' '
- withForeignPtr s $ c_sxMatrixShow stringRef (fromIntegral stringLength)
- peekCString stringRef
-
-
---------------- getters/setters ---------------------
-sxMatrixAt :: SXMatrix -> (Int,Int) -> IO SX
-sxMatrixAt (SXMatrix matIn) (n,m) = do
- SX sxOut <- sxNewInt 0
- withForeignPtrs2 (\matIn' sxOut' -> c_sxMatrixAt matIn' (fromIntegral n) (fromIntegral m) sxOut') matIn sxOut
- return (SX sxOut)
-
-sxMatrixSet :: SXMatrix -> (Int,Int) -> SX -> SXMatrix
-{-# NOINLINE sxMatrixSet #-}
-sxMatrixSet (SXMatrix matIn) (n,m) (SX val) = unsafePerformIO $ do
- SXMatrix matOut <- sxMatrixDuplicate (SXMatrix matIn)
- let n' = fromIntegral n
- m' = fromIntegral m
- withForeignPtrs2 (\val' matOut' -> c_sxMatrixSet val' n' m' matOut') val matOut
- return (SXMatrix matOut)
-
-
----------------- dimensions --------------------
-sxMatrixSize :: SXMatrix -> (Int,Int)
-{-# NOINLINE sxMatrixSize #-}
-sxMatrixSize (SXMatrix matIn) = unsafePerformIO $ do
- n <- withForeignPtr matIn c_sxMatrixSize1
- m <- withForeignPtr matIn c_sxMatrixSize2
- return (fromIntegral n, fromIntegral m)
-
-
-sxMatrixToLists :: SXMatrix -> [[SX]]
-{-# NOINLINE sxMatrixToLists #-}
-sxMatrixToLists mat = unsafePerformIO $ do
- let f row = mapM (\col -> sxMatrixAt mat (row, col)) [0..m-1]
- (n,m) = sxMatrixSize mat
- mapM f [0..n-1]
-
-
--- turns n by 1 matrix into a list of SX, returns error if matrix is not n by 1
-sxMatrixToList :: SXMatrix -> [SX]
-{-# NOINLINE sxMatrixToList #-}
-sxMatrixToList mat = unsafePerformIO $ do
- let (n,m) = sxMatrixSize mat
- if m == 1
- then mapM (\row -> sxMatrixAt mat (row, 0)) [0..n-1]
- else error "sxMatrixToList can only be used on an n by 1 matrix"
-
-
-------------------------- math ---------------------------------
-sxMatrixPlus :: SXMatrix -> SXMatrix -> SXMatrix
-{-# NOINLINE sxMatrixPlus #-}
-sxMatrixPlus (SXMatrix m0) (SXMatrix m1) = unsafePerformIO $ do
- let size'
- | sizeM0 == sizeM1 = sizeM0
- | otherwise = error "sxMatrixPlus can't add matrices of different dimensions"
- where
- sizeM0 = sxMatrixSize (SXMatrix m0)
- sizeM1 = sxMatrixSize (SXMatrix m1)
- SXMatrix mOut <- sxMatrixNewZeros size'
- withForeignPtrs3 c_sxMatrixPlus m0 m1 mOut
- return $ SXMatrix mOut
-
-
-sxMatrixMinus :: SXMatrix -> SXMatrix -> SXMatrix
-{-# NOINLINE sxMatrixMinus #-}
-sxMatrixMinus (SXMatrix m0) (SXMatrix m1) = unsafePerformIO $ do
- let size'
- | sizeM0 == sizeM1 = sizeM0
- | otherwise = error "sxMatrixMinus can't add matrices of different dimensions"
- where
- sizeM0 = sxMatrixSize (SXMatrix m0)
- sizeM1 = sxMatrixSize (SXMatrix m1)
- SXMatrix mOut <- sxMatrixNewZeros size'
- withForeignPtrs3 c_sxMatrixMinus m0 m1 mOut
- return $ SXMatrix mOut
-
-
-sxMatrixNegate :: SXMatrix -> SXMatrix
-{-# NOINLINE sxMatrixNegate #-}
-sxMatrixNegate (SXMatrix m0) = unsafePerformIO $ do
- SXMatrix mOut <- sxMatrixNewZeros (sxMatrixSize (SXMatrix m0))
- withForeignPtrs2 c_sxMatrixNegate m0 mOut
- return $ SXMatrix mOut
-
-
-sxMM :: SXMatrix -> SXMatrix -> SXMatrix
-{-# NOINLINE sxMM #-}
-sxMM (SXMatrix m0) (SXMatrix m1) = unsafePerformIO $ do
- let size'
- | colsM0 == rowsM1 = (rowsM0, colsM1)
- | otherwise = error $ "sxMM can't multiply " ++ (show (size (SXMatrix m0))) ++ " matrix by " ++ (show (size (SXMatrix m1))) ++ " matrix"
- where
- (rowsM0, colsM0) = sxMatrixSize (SXMatrix m0)
- (rowsM1, colsM1) = sxMatrixSize (SXMatrix m1)
-
- SXMatrix mOut <- sxMatrixNewZeros size'
- withForeignPtrs3 c_sxMM m0 m1 mOut
- return $ SXMatrix mOut
-
-
-sxMatrixTranspose :: SXMatrix -> SXMatrix
-{-# NOINLINE sxMatrixTranspose #-}
-sxMatrixTranspose (SXMatrix mIn) = unsafePerformIO $ do
- SXMatrix mOut <- sxMatrixNewZeros $ sxMatrixSize (SXMatrix mIn)
- withForeignPtrs2 c_sxMatrixTranspose mIn mOut
- return $ SXMatrix mOut
-
-
-sxMatrixIsEqual :: SXMatrix -> SXMatrix -> Bool
-{-# NOINLINE sxMatrixIsEqual #-}
-sxMatrixIsEqual (SXMatrix m0) (SXMatrix m1) = unsafePerformIO $ do
- isEq <- withForeignPtrs2 c_sxMatrixIsEqual m0 m1
- if (isEq == 1)
- then
- return True
- else
- return False
-
-
-sxMatrixScale :: SX -> SXMatrix -> SXMatrix
-{-# NOINLINE sxMatrixScale #-}
-sxMatrixScale (SX scalar) (SXMatrix mIn) = unsafePerformIO $ do
- SXMatrix mOut <- sxMatrixNewZeros (1,1)
- withForeignPtrs3 c_sxMatrixScale scalar mIn mOut
- return $ SXMatrix mOut
-
-
-sxMatrixInv :: SXMatrix -> SXMatrix
-{-# NOINLINE sxMatrixInv #-}
-sxMatrixInv (SXMatrix mIn) = unsafePerformIO $ do
- SXMatrix mOut <- sxMatrixNewZeros (1,1)
- withForeignPtrs2 c_sxMatrixInv mIn mOut
- return $ SXMatrix mOut
-
-
-sxMatrixFromIntegral :: Integral a => a -> SXMatrix
-{-# NOINLINE sxMatrixFromIntegral #-}
-sxMatrixFromIntegral i = unsafePerformIO $ do
- s <- sxNewIntegral i
- return $ sxMatrixFromList [s]
-
-
------------------ ad -----------------------
-gradient :: SXMatrix -> SXMatrix -> SXMatrix
-{-# NOINLINE gradient #-}
-gradient expr (SXMatrix argsRaw) = unsafePerformIO $ do
- if (1,1) /= size expr
- then do error $ "error: can't take gradient of non-scalar, dimensions: " ++ show (size expr)
- else do let [(SX expRaw)] = toList expr
- SXMatrix mOut <- sxMatrixNewZeros (1,1)
- withForeignPtrs3 c_myGradient expRaw argsRaw mOut
- return $ (SXMatrix mOut)
-
-hessian :: SXMatrix -> SXMatrix -> SXMatrix
-{-# NOINLINE hessian #-}
-hessian expr (SXMatrix argsRaw) = unsafePerformIO $ do
- if (1,1) /= size expr
- then do error $ "error: can't take hessian of non-scalar, dimensions: " ++ show (size expr)
- else do let [(SX expRaw)] = toList expr
- SXMatrix mOut <- sxMatrixNewZeros (1,1)
- withForeignPtrs3 c_myHessian expRaw argsRaw mOut
- return $ (SXMatrix mOut)
-
-jacobian :: SXMatrix -> SXMatrix -> SXMatrix
-{-# NOINLINE jacobian #-}
-jacobian (SXMatrix expRaw) (SXMatrix argsRaw) = unsafePerformIO $ do
- SXMatrix mOut <- sxMatrixNewZeros (1,1)
- withForeignPtrs3 c_myJacobian expRaw argsRaw mOut
- return $ (SXMatrix mOut)
-
-
------------------ typeclass stuff ------------------
-instance Show SXMatrix where
- show m = f (sxMatrixSize m)
- where
- f (1,1) = "[[" ++ (sxMatrixShow m) ++ "]]"
- f (_,1) = sxMatrixShow m
- f (1,_) = init $ sxMatrixShow m
- f (_,_) = init $ sxMatrixShow m
-
-
-instance Eq SXMatrix where
- (==) = sxMatrixIsEqual
- (/=) sx0 sx1 = not $ sx0 == sx1
-
-
-instance Num SXMatrix where
- (+) = sxMatrixPlus
- (-) = sxMatrixMinus
- (*) m0 m1
- | sxMatrixSize m0 == (1,1) = sxMatrixScale s0 m1
- | sxMatrixSize m1 == (1,1) = sxMatrixScale s1 m0
- | otherwise = sxMM m0 m1
- where
- [s0] = sxMatrixToList m0
- [s1] = sxMatrixToList m1
-
- abs = error "abs not defined for instance Num SXMatrix"
- signum = error "signum not defined for instance Num SXMatrix"
-
- fromInteger = sxMatrixFromIntegral
-
- negate = sxMatrixNegate
-
-
-instance Fractional SXMatrix where
- (/) m0 m1 = m0 * (recip m1)
- recip mat = sxMatrixInv mat
- fromRational x = sxMatrixFromList [fromRational x :: SX]
-
-
-instance Matrix SXMatrix SX SXMatrixRaw where
- trans = sxMatrixTranspose
- size = sxMatrixSize
- rows = fst . sxMatrixSize
- cols = snd . sxMatrixSize
- toList = sxMatrixToList
- toLists = sxMatrixToLists
- fromList = sxMatrixFromList
- fromLists = error "sxMatrixFromLists not yet implemented"
- vertcat mats = fromList $ concat $ map toList mats
- inv = sxMatrixInv
- scale = sxMatrixScale
- zeros = sxMatrixZeros
-
- c_sxFunctionEvaluate _ = c_sxFunctionEvaluateSXMatrix
- getForeignPtr (SXMatrix r) = r
- newZeros = sxMatrixNewZeros
-
-instance Boundable SXMatrix where
- bound xs (lbs, ubs) = fromList $ zipWith bound (toList xs) (zip (toList lbs) (toList ubs))
-
View
58 Casadi/Utils.hs
@@ -1,58 +0,0 @@
--- Utils.hs
-
-{-# OPTIONS_GHC -Wall #-}
-
-module Casadi.Utils
- (
- getDerivs
- , timeComputation
- ) where
-
-import Casadi.SX
-import Casadi.SXMatrix
-import Casadi.Matrix
-import Casadi.SXFunction
-import qualified Numeric.LinearAlgebra as LA
-import Data.Time.Clock
-
-getDerivs :: ([SX] -> SX) -> Int
- -> IO (LA.Vector Double -> Double, LA.Vector Double -> LA.Vector Double, LA.Vector Double -> LA.Matrix Double)
-getDerivs f n = do
-
- x <- sxMatrixCreateSymbolic "x" (n,1)
-
- let xSX = toList x
- sxFun = sxFunctionCreate [x] [fromList [f xSX]]
-
- gradSX = sxFunctionGradientAt sxFun 0
- hessSX = sxFunctionHessianAt sxFun (0,0)
-
- sxGrad = sxFunctionCreate [x] [gradSX]
- sxHess = sxFunctionCreate [x] [hessSX]
-
- return (evalF sxFun, evalG sxGrad, evalH sxHess)
-
-evalF :: SXFunction -> LA.Vector Double -> Double
-evalF sxFun xVec = output
- where
- [[[output]]] = sxFunctionEvaluateLists sxFun [[LA.toList xVec]]
-
-evalG :: SXFunction -> LA.Vector Double -> LA.Vector Double
-evalG sxGrad xVec = LA.fromList (concat output)
- where
- [output] = sxFunctionEvaluateLists sxGrad [[LA.toList xVec]]
-
-evalH :: SXFunction -> LA.Vector Double -> LA.Matrix Double
-evalH sxHess xVec = LA.fromLists output
- where
- [output] = sxFunctionEvaluateLists sxHess [[LA.toList xVec]]
-
-
-timeComputation :: String -> IO t -> IO t
-timeComputation msg a = do
- start <- getCurrentTime
- v <- a
- end <- getCurrentTime
- let diffTime = (realToFrac $ diffUTCTime end start)::Double
- putStrLn $ msg ++ show diffTime ++ " seconds"
- return v
View
39 NLP/Ipopt.hs
@@ -1,39 +0,0 @@
--- Ipopt.hs
-
-{-# OPTIONS_GHC -Wall #-}
-{-# LANGUAGE ForeignFunctionInterface #-}
-
-module NLP.Ipopt
- (
- Ipopt(..)
- , IpoptExactHessian(..)
- ) where
-
-import Casadi.SX
-import Casadi.SXMatrix
-import Casadi.DMatrix
-import NLP.NLP
-
-import Foreign.C
-import Foreign.Ptr
-
-data Ipopt = Ipopt
-data IpoptExactHessian = IpoptExactHessian
-
-instance NLPRaw Ipopt where
- c_createSolver = c_ipoptSolverCreate
- c_deleteSolver = c_ipoptSolverDelete
- c_solve = c_ipoptSolverSolve
-instance NLPRaw IpoptExactHessian where
- c_createSolver = c_ipoptSolverCreateExactHessian
- c_deleteSolver = c_ipoptSolverDeleteExactHessian
- c_solve = c_ipoptSolverSolveExactHessian
-
-
--- foreign imports
-foreign import ccall unsafe "ipoptSolverCreate" c_ipoptSolverCreate :: Ptr SXMatrixRaw -> Ptr SXRaw -> Ptr SXMatrixRaw -> IO (Ptr Ipopt)
-foreign import ccall unsafe "ipoptSolverCreateExactHessian" c_ipoptSolverCreateExactHessian :: Ptr SXMatrixRaw -> Ptr SXRaw -> Ptr SXMatrixRaw -> IO (Ptr IpoptExactHessian)
-foreign import ccall unsafe "&ipoptSolverDelete" c_ipoptSolverDelete :: FunPtr (Ptr Ipopt -> IO ())
-foreign import ccall unsafe "&ipoptSolverDelete" c_ipoptSolverDeleteExactHessian :: FunPtr (Ptr IpoptExactHessian -> IO ())
-foreign import ccall unsafe "ipoptSolverSolve" c_ipoptSolverSolve :: Ptr Ipopt -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> IO CDouble
-foreign import ccall unsafe "ipoptSolverSolve" c_ipoptSolverSolveExactHessian :: Ptr IpoptExactHessian -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> IO CDouble
View
72 NLP/NLP.hs
@@ -1,72 +0,0 @@
--- NLP.hs
-
-{-# OPTIONS_GHC -Wall #-}
-{-# LANGUAGE ForeignFunctionInterface, MultiParamTypeClasses #-}
-
-module NLP.NLP
- (
- NLPSolver(..)
- , NLPRaw(..)
- , createSolver
- , solveNlp
- ) where
-
-import Casadi
-import Casadi.SX
-import Casadi.SXMatrix
-import Casadi.DMatrix
-import Casadi.CasadiInterfaceUtils
-
-import Foreign.C
-import Foreign.Ptr
-import Foreign.ForeignPtr
-import Control.Exception(mask_)
-import Text.Printf
-
-
-data NLPSolver a = NLPSolver { nlpPtr :: (ForeignPtr a)
- , nInputs :: Int
- , nConstraints :: Int
- }
-
-class NLPRaw a where
- c_createSolver :: Ptr SXMatrixRaw -> Ptr SXRaw -> Ptr SXMatrixRaw -> IO (Ptr a)
- c_deleteSolver :: FunPtr (Ptr a -> IO ())
- c_solve :: Ptr a -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> IO CDouble
-
-createSolver :: NLPRaw a => a -> SXMatrix -> SX -> SXMatrix -> IO (NLPSolver a)
-createSolver _ (SXMatrix inputs) (SX objFun) (SXMatrix constraints) = mask_ $ do
- ss <- withForeignPtrs3 c_createSolver inputs objFun constraints >>= newForeignPtr c_deleteSolver
- return $ NLPSolver { nlpPtr = ss, nInputs = rows (SXMatrix inputs), nConstraints = rows (SXMatrix constraints) }
-
-
-solveNlp :: NLPRaw a => NLPSolver a -> DMatrix -> (DMatrix,DMatrix) -> (DMatrix,DMatrix) -> IO (DMatrix, Double)
-solveNlp nlp
- x0'@(DMatrix x0Raw)
- (xLb'@(DMatrix xLbRaw), xUb'@(DMatrix xUbRaw))
- (gLb'@(DMatrix gLbRaw), gUb'@(DMatrix gUbRaw)) =
- if any (\x -> (nInputs nlp) /= rows x) [x0',xLb',xUb'] || any (\x -> (nConstraints nlp) /= rows x) [gLb',gUb']
- then error $ printf ("\nError - Bad dimensions in ipoptSolve\n" ++
- " Solve call saw nx = %d, nxLb = %d, nxUb = %d, ngLb = %d, ngUb = %d\n" ++
- " Solver has %d inputs and %d nonlcons")
- (rows x0') (rows xLb') (rows xUb') (rows gLb') (rows gUb') (nInputs nlp) (nConstraints nlp)
- else do
- DMatrix solRaw <- dMatrixNewZeros $ size (DMatrix x0Raw)
-
- let x0 = unsafeForeignPtrToPtr x0Raw
- xLb = unsafeForeignPtrToPtr xLbRaw
- xUb = unsafeForeignPtrToPtr xUbRaw
- gLb = unsafeForeignPtrToPtr gLbRaw
- gUb = unsafeForeignPtrToPtr gUbRaw
- sol = unsafeForeignPtrToPtr solRaw
-
- optVal <- withForeignPtr (nlpPtr nlp) (\s -> c_solve s x0 xLb xUb gLb gUb sol)
-
- touchForeignPtr x0Raw
- touchForeignPtr xLbRaw
- touchForeignPtr xUbRaw
- touchForeignPtr gLbRaw
- touchForeignPtr gUbRaw
- touchForeignPtr solRaw
-
- return (DMatrix solRaw, realToFrac optVal)
View
29 NLP/Snopt.hs
@@ -1,29 +0,0 @@
--- Snopt.hs
-
-{-# OPTIONS_GHC -Wall #-}
-{-# LANGUAGE ForeignFunctionInterface #-}
-
-module NLP.Snopt
- (
- Snopt(..)
- ) where
-
-import Casadi.SX
-import Casadi.SXMatrix
-import Casadi.DMatrix
-import NLP.NLP
-
-import Foreign.C
-import Foreign.Ptr
-
-data Snopt = Snopt
-
-instance NLPRaw Snopt where
- c_createSolver = c_snoptSolverCreate
- c_deleteSolver = c_snoptSolverDelete
- c_solve = c_snoptSolverSolve
-
--- foreign imports
-foreign import ccall unsafe "snoptSolverCreate" c_snoptSolverCreate :: Ptr SXMatrixRaw -> Ptr SXRaw -> Ptr SXMatrixRaw -> IO (Ptr Snopt)
-foreign import ccall unsafe "&snoptSolverDelete" c_snoptSolverDelete :: FunPtr (Ptr Snopt -> IO ())
-foreign import ccall unsafe "snoptSolverSolve" c_snoptSolverSolve :: Ptr Snopt -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> Ptr DMatrixRaw -> IO CDouble
View
38 README
@@ -2,44 +2,10 @@ Hopt-o-mex is currently an incomplete Haskell port of my "opt-o-mex" c++ optimal
So far I've implemented euler/rk4 integrators, dynamics linearization functions, differential dynamic programming, a visualizer, and some examples.
-Automatic differentiation is currently done using two packages: Edward Kmett's pure Haskell "ad" package, and a C++ project called "CasADi" which I wrote haskell bindings for.
+Automatic differentiation is currently done using "hcasadi," my haskell bindings to the CasADi C++ AD library.
-
-################# INSTALLATION #################
-
-THIS WILL NOT BUILD OUT OF THE BOX. My cabal build foo is weak and some things are hard-coded to my local installation paths. Also you have to figure out how to build the CasADi project. If you are interested in Hopt-o-mex ping me and I'll help you out.
-
-
-#### HASKELL BINDINGS FOR CASADI:
-To build the CasADi bindings you need to install CasADi yourself, make a CASADI environment variable, then
-
- cd hopt-o-mex/hcasadi_cppsrc
- make
-
-You will need to go into casadi/CMakeLists.txt and set "WITH_JIT" to ON in order to use the SXFunction code-generation capabilities
-
-#### IPOPT
-Install ipopt on your own - their website has good instructions. Then you may have to append /usr/local/lib/coin and /usr/local/lib/coin/ThirdParty to your LD_LIBRARY_PATH env var.
-
-When you cabal configure it will warn you about missing ipopt/coinmumps/coinmetis libraries - you can ignore this.
-
-You might just want to comment out the whole ipoptTest executable in hopt-o-mex.cabal if you are having problems.
-
-
-### SNOPT
-add the SNOPT environment variable
-
-
-#### HOPT-O-MEX
-Currently you must edit hopt-o-mex.cabal and replace "/home/greg/hopt-o-mex" with wherever you put it. Fixing this is high on my todo list.
-
- cd hopt-o-mex
- cabal configure
- cabal build
-
-You will probably be prompted to install required dependencies.
+See hcasadi's README for installation difficulties
#### RUNNING
-After the build is complete:
For a quick demo of ddp run 'hopt-o-mex/dist/build/springDdp/springDdp'
For a quick and awesome demo of ddp/visualizer run 'hopt-o-mex/dist/build/cartpoleDdp/cartpoleDdp'
View
2 hcasadi_cppsrc/.gitignore
@@ -1,2 +0,0 @@
-libcasadi_interface.a
-*.o
View
41 hcasadi_cppsrc/Makefile
@@ -1,41 +0,0 @@
-# Hopt-o-mex casadi interface makefile
-# Greg Horn 2011
-
-Q = @
-CXX = g++
-
-LIB = libcasadi_interface.a
-
-SRC = \
- codegen.cpp \
- adInterface.cpp \
- dMatrixInterface.cpp \
- sxInterface.cpp \
- sxMatrixInterface.cpp \
- sxFunctionInterface.cpp \
- ipoptSolverInterface.cpp
-
-ifdef SNOPT
- SRC += \
- snoptSolverInterface.cpp \
- SnoptSolver.cpp
-endif
-
-OBJ = $(SRC:%.cpp=%.o)
-INCLUDES = -I$(CASADI) -I$(SNOPT) -I$(SNOPT)/include
-LDFLAGS =
-FLAGS = -O2 #-Wall -Wextra -Wshadow #-Werror #-g
-
-.PHONY: clean
-
-$(LIB): $(OBJ)
- @echo AR $@
- $(Q)ar rcs $(LIB) $(OBJ) $(LDFLAGS)
-
-%.o : %.cpp
- @echo CXX $@
- $(Q)$(CXX) $(FLAGS) $(INCLUDES) -c $< -o $@
-
-clean:
- rm -f $(LIB)
- rm -f $(OBJ)
View
425 hcasadi_cppsrc/SnoptSolver.cpp
@@ -1,425 +0,0 @@
-// SnoptSolver.cpp
-// Greg Horn
-
-#include <stdio.h>
-#include <string.h>
-#include <iostream>
-
-#include <cstdlib>
-
-#include "SnoptSolver.hpp"
-
-static SnoptSolver * si;
-
-using namespace std;
-using namespace CasADi;
-
-SnoptSolver::~SnoptSolver()
-{
- delete []iGfun;
- delete []jGvar;
-
- delete []iAfun;
- delete []jAvar;
- delete []A;
-
- delete []x;
- delete []xlow;
- delete []xupp;
- delete []xmul;
- delete []xstate;
-
- delete []F;
- delete []Flow;
- delete []Fupp;
- delete []Fmul;
- delete []Fstate;
- delete []Foffset;
-}
-
-void
-SnoptSolver::setGuess(const DMatrix & _xGuess){
- for (int k=0; k<n; k++)
- x[k] = _xGuess.indexed(k,0).at(0);
-}
-
-void
-SnoptSolver::setXBounds(const DMatrix & _xlb, const DMatrix & _xub){
- for (int k=0; k<n; k++){
- xlow[k] = _xlb.indexed(k,0).at(0);
- xupp[k] = _xub.indexed(k,0).at(0);
- }
-}
-
-void
-SnoptSolver::setFBounds(const DMatrix & _Flb, const DMatrix & _Fub){
- // set bound on objective function
- Flow[0] = -SNOPT_INFINITY;
- Fupp[0] = SNOPT_INFINITY;
-
- // set bound on nonlinear constraints
- for (int k=0; k<neF-1; k++){
- Flow[k+1] = _Flb.indexed(k,0).at(0);
- Fupp[k+1] = _Fub.indexed(k,0).at(0);
- }
-
- // correct for constant offset in F
- for (int k=0; k<neF; k++){
- Flow[k] -= Foffset[k];
- Fupp[k] -= Foffset[k];
- }
-
- // cout << endl;
- // for (int k=0; k<neF; k++)
- // cout << "Flow[" << k << "]: " << Flow[k] << ", Fupp[" << k << "]: " << Fupp[k] << endl;
-}
-
-double
-SnoptSolver::getSolution(DMatrix & _xOpt){
- for (int k=0; k<n; k++)
- _xOpt.indexed_assignment(k, 0, x[k]);
- return F[objRow - FIRST_FORTRAN_INDEX] + objAdd;
-}
-
-
-SnoptSolver::SnoptSolver(const SXMatrix & designVariables, const SX & objFun, const SXMatrix & constraints)
-{
- // make sure mempy is ok for copying doublereal to double
- if (sizeof(double) != sizeof(doublereal)){
- cout << "\n\n----------------------------------------------------\n";
- cout << "(cpp), sizeof(doublereal) != sizeof(double)\n";
- cout << "----------------------------------------------------\n\n";
- throw 1;
- }
-
- si = this;
- SXMatrix ftotal = vertcat( SXMatrix(objFun), constraints );
-
- SXFunction Ftotal(designVariables, ftotal);
- Ftotal.init();
-
- /************ design variables ************/
- n = Ftotal.input().size();
- x = new doublereal[n];
- xlow = new doublereal[n];
- xupp = new doublereal[n];
- xmul = new doublereal[n];
- xstate = new integer[n];
- for (int k=0; k<n; k++){
- x[k] = 0;
- xlow[k] = -SNOPT_INFINITY;
- xupp[k] = SNOPT_INFINITY;
- xmul[k] = 0;
- xstate[k] = 0;
- }
-
- /*********** objective/constraint functions ***********/
- neF = Ftotal.output().size();
- objRow = FIRST_FORTRAN_INDEX;
- F = new doublereal[neF];
- Flow = new doublereal[neF];
- Fupp = new doublereal[neF];
- Fmul = new doublereal[neF];
- Fstate = new integer[neF];
- Foffset = new doublereal[neF];
- for (int k=0; k<neF; k++){
- F[k] = 0;
- Flow[k] = -SNOPT_INFINITY;
- Fupp[k] = 0;
- Fmul[k] = 0;
- Fstate[k] = 0;
- Foffset[k] = 0;
- }
- Fupp[ objRow - FIRST_FORTRAN_INDEX ] = SNOPT_INFINITY;
-
- /****************** jacobian *********************/
- SXMatrix fnonlinear = ftotal;
-
- SXFunction gradF(Ftotal.jacobian());
-
- vector<int> rowind,col;
- gradF.output().sparsity().getSparsityCRS(rowind,col);
-
- // split jacobian into constant and nonconstant elements (linear and nonlinear parts of F)
- vector<doublereal> A_;
- vector<integer> iAfun_;
- vector<integer> jAvar_;
-
- vector<SX> G_;
- vector<integer> iGfun_;
- vector<integer> jGvar_;
-
- for(int r=0; r<rowind.size()-1; ++r)
- for(int el=rowind[r]; el<rowind[r+1]; ++el){
- SXMatrix jacobElem = gradF.outputSX().getElement(r, col[el]);
- jacobElem = evaluateConstants(jacobElem);
- if (jacobElem.at(0).isConstant()){
- A_.push_back( jacobElem.at(0).getValue() );
- iAfun_.push_back( r + FIRST_FORTRAN_INDEX );
- jAvar_.push_back( col[el] + FIRST_FORTRAN_INDEX );
-
- // subtract out linear part
- SXMatrix linearpart = jacobElem.at(0).getValue()*designVariables[col[el]];
- fnonlinear[r] -= linearpart.at(0);
- simplify(fnonlinear.at(r));
- } else {
- G_.push_back( gradF.outputSX().getElement(r, col[el]) );
- iGfun_.push_back( r + FIRST_FORTRAN_INDEX );
- jGvar_.push_back( col[el] + FIRST_FORTRAN_INDEX );
- }
- }
-
- // remove constants from fnonlinear
- fnonlinear = evaluateConstants(fnonlinear);
- for (int k=0; k<neF; k++){
- if (fnonlinear.at(k).isConstant()){
- Foffset[k] = fnonlinear.at(k).getValue();
- fnonlinear[k] -= fnonlinear.at(k);
- simplify(fnonlinear.at(k));
- }
- }
- objAdd = Foffset[objRow - FIRST_FORTRAN_INDEX];
-
- // nonlinear function
- Fnonlinear = SXFunction( designVariables, fnonlinear );
- Fnonlinear.init();
-
- // linear part
- neA = A_.size();
- lenA = neA;
- if (lenA == 0) lenA = 1;
-
- A = new doublereal[lenA];
- iAfun = new integer[lenA];
- jAvar = new integer[lenA];
-
- if (neA > 0){
- copy( A_.begin(), A_.end(), A);
- copy( iAfun_.begin(), iAfun_.end(), iAfun);
- copy( jAvar_.begin(), jAvar_.end(), jAvar);
- }
-
- // nonlinear part
- neG = G_.size();
- lenG = neG;
- if (lenG == 0) lenG = 1;
-
- iGfun = new integer[lenG];
- jGvar = new integer[lenG];
-
- if (neG > 0){
- copy( iGfun_.begin(), iGfun_.end(), iGfun);
- copy( jGvar_.begin(), jGvar_.end(), jGvar);
- }
-
- Gfcn = SXFunction( designVariables, G_ );
- Gfcn.init();
-
- // for (int k=0; k<neA; k++){
- // cout << "A[" << iAfun[k] << "," << jAvar[k] << "]: " << A[k] << endl;
- // }
- // cout << "\n";
-
- // for (int k=0; k<neG; k++){
- // cout << "G[" << iGfun[k] << "," << jGvar[k] << "]: " << Gfcn.outputSX().at(k) << endl;
- // }
- // cout << "\n";
-
- // cout << "Fnonlinear:\n";
- // for (int k=0; k<neF; k++)
- // cout << "F[" << k << "]: " << Fnonlinear.outputSX().at(k) << endl;
-}
-
-
-void
-SnoptSolver::solve()
-{
-// #define LENRW 20000
-// #define LENIW 10000
-#define LENCW 500
-
-#define LENRW 600000
-#define LENIW 150000
- //#define LENCW 5000
-
- integer minrw, miniw, mincw;
- integer lenrw = LENRW, leniw = LENIW, lencw = LENCW;
- doublereal rw[LENRW];
- integer iw[LENIW];
- char cw[8*LENCW];
-
- integer Cold = 0, Basis = 1, Warm = 2;
-
- integer INFO;
-
- integer nxname = 1, nFname = 1, npname;
- char xnames[1*8], Fnames[1*8];
- char Prob[200];
-
- integer iSpecs = 4, spec_len;
- integer iSumm = 6;
- integer iPrint = 9, prnt_len;
-
- char printname[200];
- char specname[200];
-
- integer nS, nInf;
- doublereal sInf;
- integer DerOpt, Major, strOpt_len;
- char strOpt[200];
-
- /* open output files using snfilewrappers.[ch] */
- sprintf(specname , "%s", "sntoya.spc"); spec_len = strlen(specname);
- sprintf(printname, "%s", "sntoya.out"); prnt_len = strlen(printname);
-
- /* Open the print file, fortran style */
- snopenappend_
- ( &iPrint, printname, &INFO, prnt_len );
-
- /* ================================================================== */
- /* First, sninit_ MUST be called to initialize optional parameters */
- /* to their default values. */
- /* ================================================================== */
-
- sninit_
- ( &iPrint, &iSumm, cw, &lencw, iw, &leniw, rw, &lenrw, 8*500 );
-
- strcpy(Prob,"snopta");
- npname = strlen(Prob);
- INFO = 0;
-
- /* Read in specs file (optional) */
- /* snfilewrapper_ will open the specs file, fortran style, */
- /* then call snspec_ to read in specs. */
-
- // snfilewrapper_
- // ( specname, &iSpecs, &INFO, cw, &lencw,
- // iw, &leniw, rw, &lenrw, spec_len, 8*lencw);
-
- // if( INFO != 101 )
- // {
- // printf("Warning: trouble reading specs file %s \n", specname);
- // }
-
-
- // sprintf(strOpt,"%s","Solution yes");
- // strOpt_len = strlen(strOpt);
- // snset_
- // ( strOpt, &iPrint, &iSumm, &INFO,
- // cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
- /* ------------------------------------------------------------------ */
- /* Tell SnoptA that userfg computes derivatives. */
- /* ------------------------------------------------------------------ */
-
- DerOpt = 1;
- sprintf(strOpt,"%s","Derivative option");
- strOpt_len = strlen(strOpt);
- snseti_
- ( strOpt, &DerOpt, &iPrint, &iSumm, &INFO,
- cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
- Major = 250;
- //Major = 2500;
- strcpy( strOpt,"Major Iterations limit");
- strOpt_len = strlen(strOpt);
- snseti_
- ( strOpt, &Major, &iPrint, &iSumm, &INFO,
- cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
-
- integer Minor = 1000;
- strcpy( strOpt,"Minor Iterations limit");
- strOpt_len = strlen(strOpt);
- snseti_
- ( strOpt, &Minor, &iPrint, &iSumm, &INFO,
- cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
- //integer Niter = 100000;
- integer Niter = 10000;
- strcpy( strOpt,"Iterations limit");
- strOpt_len = strlen(strOpt);
- snseti_
- ( strOpt, &Niter, &iPrint, &iSumm, &INFO,
- cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
- doublereal major_opt_tol = 1e-2;
- strcpy(strOpt,"Major optimality tolerance");
- strOpt_len = strlen(strOpt);
- snsetr_
- ( strOpt, &major_opt_tol, &iPrint, &iSumm, &INFO,
- cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
- // integer verifyLevel = 3;
- // strcpy( strOpt,"Verify level");
- // strOpt_len = strlen(strOpt);
- // snseti_
- // ( strOpt, &verifyLevel, &iPrint, &iSumm, &INFO,
- // cw, &lencw, iw, &leniw, rw, &lenrw, strOpt_len, 8*500 );
-
- /* ------------------------------------------------------------------ */
- /* Solve the problem */
- /* ------------------------------------------------------------------ */
- snopta_
- ( &Cold, &neF, &n, &nxname, &nFname,
- &objAdd, &objRow, Prob, (U_fp)userfcn,
- iAfun, jAvar, &lenA, &neA, A,
- iGfun, jGvar, &lenG, &neG,
- xlow, xupp, xnames, Flow, Fupp, Fnames,
- x, xstate, xmul, F, Fstate, Fmul,
- &INFO, &mincw, &miniw, &minrw,
- &nS, &nInf, &sInf,
- cw, &lencw, iw, &leniw, rw, &lenrw,
- cw, &lencw, iw, &leniw, rw, &lenrw,
- npname, 8*nxname, 8*nFname,
- 8*500, 8*500);
-
- // extern int snopta_
- // ( integer *start, integer *nef, integer *n, integer *nxname, integer *nfname,
- // doublereal *objadd, integer *objrow, char *prob, U_fp usrfun,
- // integer *iafun, integer *javar, integer *lena, integer *nea, doublereal *a,
- // integer *igfun, integer *jgvar, integer *leng, integer *neg,
- // doublereal *xlow, doublereal *xupp, char *xnames, doublereal *flow, doublereal *fupp, char *fnames,
- // doublereal *x, integer *xstate, doublereal *xmul, doublereal *f, integer *fstate, doublereal *fmul,
- // integer *inform, integer *mincw, integer *miniw, integer *minrw,
- // integer *ns, integer *ninf, doublereal *sinf,
- // char *cu, integer *lencu, integer *iu, integer *leniu, doublereal *ru, integer *lenru,
- // char *cw, integer *lencw, integer *iw, integer *leniw, doublereal *rw, integer *lenrw,
- // ftnlen prob_len, ftnlen xnames_len, ftnlen fnames_len,
- // ftnlen cu_len, ftnlen cw_len );
-
- snclose_( &iPrint );
-// snclose_( &iSpecs );
-}
-
-
-int SnoptSolver::userfcn
-( integer *Status, integer *n, doublereal x[],
- integer *needF, integer *neF, doublereal F[],
- integer *needG, integer *neG, doublereal G[],
- char *cu, integer *lencu,
- integer iu[], integer *leniu,
- doublereal ru[], integer *lenru )
-{
- if( *needF > 0 ) {
- si->Fnonlinear.setInput(x);
- si->Fnonlinear.evaluate();
- si->Fnonlinear.getOutput(F);
-
- // cout << endl;
- // for (int k=0; k<*neF; k++)
- // cout << "F[" << k << "]: " << F[k] << endl;
- }
-
- if( *needG > 0 ){
- si->Gfcn.setInput(x);
- si->Gfcn.evaluate();
- si->Gfcn.getOutput(G);
-
- // cout << endl;
- // for (int k=0; k<*neG; k++)
- // cout << "G[" << k << "]: " << G[k] << endl;
- }
-
- return 0;
-}
View
90 hcasadi_cppsrc/SnoptSolver.hpp
@@ -1,90 +0,0 @@
-// SnoptSolver.hpp
-// Greg Horn
-
-#pragma once
-
-#include <casadi/sx/sx_tools.hpp>
-#include <casadi/fx/fx_tools.hpp>
-#include <casadi/stl_vector_tools.hpp>
-#include <casadi/fx/sx_function.hpp>
-#include <casadi/sx/sx.hpp>
-
-using namespace std;
-using namespace CasADi;
-
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-#include <cexamples/snopt.h>
-#include <cexamples/snfilewrapper.h>
-
-#ifdef __cplusplus
-}
-#endif
-
-#define FIRST_FORTRAN_INDEX 1
-#define SNOPT_INFINITY 1e25
-
-class SnoptSolver
-{
-public:
- ~SnoptSolver(void);
- SnoptSolver(const SXMatrix & designVariables, const SX & objFun, const SXMatrix & constraints);
-
- void setGuess(const DMatrix & _xGuess);
- void setXBounds(const DMatrix & _xlb, const DMatrix & _xub);
- void setFBounds(const DMatrix & _Flb, const DMatrix & _Fub);
- double getSolution(DMatrix & _xOpt);
-
- void solve(void);
-
-private:
- // function for nonlinear part of ftotal
- SXFunction Fnonlinear;
-