Skip to content

Commit

Permalink
Merge branch 'better-entity-def' into beta
Browse files Browse the repository at this point in the history
  • Loading branch information
snoyberg committed Dec 27, 2011
2 parents a5bf39c + c3dd94f commit e259bd3
Show file tree
Hide file tree
Showing 15 changed files with 867 additions and 540 deletions.
165 changes: 95 additions & 70 deletions persistent-postgresql/Database/Persist/Postgresql.hs
Expand Up @@ -15,6 +15,7 @@ import Database.Persist hiding (Update)
import Database.Persist.Base hiding (Add, Update)
import Database.Persist.GenericSql hiding (Key(..))
import Database.Persist.GenericSql.Internal
import Database.Persist.EntityDef

import qualified Database.HDBC as H
import qualified Database.HDBC.PostgreSQL as H
Expand Down Expand Up @@ -81,12 +82,12 @@ prepare' conn sql = do
, withStmt = withStmt' stmt
}

insertSql' :: RawName -> [RawName] -> Either Text (Text, Text)
insertSql' :: DBName -> [DBName] -> Either Text (Text, Text)
insertSql' t cols = Left $ pack $ concat
[ "INSERT INTO "
, escape t
, T.unpack $ escape t
, "("
, intercalate "," $ map escape cols
, intercalate "," $ map (T.unpack . escape) cols
, ") VALUES("
, intercalate "," (map (const "?") cols)
, ") RETURNING id"
Expand Down Expand Up @@ -141,22 +142,25 @@ pFromSql (H.SqlLocalTime d) = PersistUTCTime $ localTimeToUTC utc d
pFromSql x = PersistText $ pack $ H.fromSql x -- FIXME

migrate' :: PersistEntity val
=> (Text -> IO Statement)
=> [EntityDef]
-> (Text -> IO Statement)
-> val
-> IO (Either [Text] [(Bool, Text)])
migrate' getter val = do
let name = rawTableName $ entityDef val
old <- getColumns getter name
migrate' allDefs getter val = do
let name = entityDB $ entityDef val
old <- getColumns getter $ entityDef val
case partitionEithers old of
([], old'') -> do
let old' = partitionEithers old''
let new = mkColumns val
let new = second (map udToPair) $ mkColumns allDefs val
if null old
then do
let addTable = AddTable $ concat
[ "CREATE TABLE "
, escape name
, "(id SERIAL PRIMARY KEY UNIQUE"
, T.unpack $ escape name
, "("
, T.unpack $ escape $ entityID $ entityDef val
, " SERIAL PRIMARY KEY UNIQUE"
, concatMap (\x -> ',' : showColumn x) $ fst new
, ")"
]
Expand All @@ -172,25 +176,30 @@ migrate' getter val = do

data AlterColumn = Type SqlType | IsNull | NotNull | Add Column | Drop
| Default String | NoDefault | Update String
| AddReference RawName | DropReference RawName
type AlterColumn' = (RawName, AlterColumn)
| AddReference DBName | DropReference DBName
type AlterColumn' = (DBName, AlterColumn)

data AlterTable = AddUniqueConstraint RawName [RawName]
| DropConstraint RawName
data AlterTable = AddUniqueConstraint DBName [DBName]
| DropConstraint DBName

data AlterDB = AddTable String
| AlterColumn RawName AlterColumn'
| AlterTable RawName AlterTable
| AlterColumn DBName AlterColumn'
| AlterTable DBName AlterTable

-- | Returns all of the columns in the given table currently in the database.
getColumns :: (Text -> IO Statement)
-> RawName -> IO [Either Text (Either Column UniqueDef')]
getColumns getter name = do
stmt <- getter "SELECT column_name,is_nullable,udt_name,column_default FROM information_schema.columns WHERE table_name=? AND column_name <> 'id'"
cs <- withStmt stmt [PersistText $ pack $ unRawName name] helper
-> EntityDef
-> IO [Either Text (Either Column (DBName, [DBName]))]
getColumns getter def = do
stmt <- getter "SELECT column_name,is_nullable,udt_name,column_default FROM information_schema.columns WHERE table_name=? AND column_name <> ?"
let vals =
[ PersistText $ unDBName $ entityDB def
, PersistText $ unDBName $ entityID def
]
cs <- withStmt stmt vals helper
stmt' <- getter
"SELECT constraint_name, column_name FROM information_schema.constraint_column_usage WHERE table_name=? AND column_name <> 'id' ORDER BY constraint_name, column_name"
us <- withStmt stmt' [PersistText $ pack $ unRawName name] helperU
"SELECT constraint_name, column_name FROM information_schema.constraint_column_usage WHERE table_name=? AND column_name <> ? ORDER BY constraint_name, column_name"
us <- withStmt stmt' vals helperU
return $ cs ++ us
where
getAll pop front = do
Expand All @@ -202,22 +211,23 @@ getColumns getter name = do
Just _ -> getAll pop front -- FIXME error message?
helperU pop = do
rows <- getAll pop id
return $ map (Right . Right . (RawName . fst . head &&& map (RawName . snd)))
$ groupBy ((==) `on` fst) rows
return $ map (Right . Right . (DBName . fst . head &&& map (DBName . snd)))
$ groupBy ((==) `on` fst)
$ map (T.pack *** T.pack) rows
helper pop = do
x <- pop
case x of
Nothing -> return []
Just x' -> do
col <- getColumn getter name x'
col <- getColumn getter (entityDB def) x'
let col' = case col of
Left e -> Left e
Right c -> Right $ Left c
cols <- helper pop
return $ col' : cols

getAlters :: ([Column], [UniqueDef'])
-> ([Column], [UniqueDef'])
getAlters :: ([Column], [(DBName, [DBName])])
-> ([Column], [(DBName, [DBName])])
-> ([AlterColumn'], [AlterTable])
getAlters (c1, u1) (c2, u2) =
(getAltersC c1 c2, getAltersU u1 u2)
Expand All @@ -226,6 +236,10 @@ getAlters (c1, u1) (c2, u2) =
getAltersC (new:news) old =
let (alters, old') = findAlters new old
in alters ++ getAltersC news old'

getAltersU :: [(DBName, [DBName])]
-> [(DBName, [DBName])]
-> [AlterTable]
getAltersU [] old = map (DropConstraint . fst) old
getAltersU ((name, cols):news) old =
case lookup name old of
Expand All @@ -239,7 +253,7 @@ getAlters (c1, u1) (c2, u2) =
: getAltersU news old'

getColumn :: (Text -> IO Statement)
-> RawName -> [PersistValue]
-> DBName -> [PersistValue]
-> IO (Either Text Column)
getColumn getter tname
[PersistByteString x, PersistByteString y,
Expand All @@ -250,10 +264,10 @@ getColumn getter tname
case getType $ bsToChars z of
Left s -> return $ Left s
Right t -> do
let cname = RawName $ bsToChars x
let cname = DBName $ T.pack $ bsToChars x
ref <- getRef cname
return $ Right $ Column cname (bsToChars y == "YES")
t d'' ref
t (fmap T.pack d'') ref
where
getRef cname = do
let sql = pack $ concat
Expand All @@ -266,11 +280,11 @@ getColumn getter tname
let ref = refName tname cname
stmt <- getter sql
withStmt stmt
[ PersistText $ pack $ unRawName tname
, PersistText $ pack $ unRawName ref
[ PersistText $ unDBName tname
, PersistText $ unDBName ref
] $ \pop -> do
Just [PersistInt64 i] <- pop
return $ if i == 0 then Nothing else Just (RawName "", ref)
return $ if i == 0 then Nothing else Just (DBName "", ref)
d' = case d of
PersistNull -> Right Nothing
PersistByteString a -> Right $ Just $ bsToChars a
Expand Down Expand Up @@ -306,7 +320,7 @@ findAlters col@(Column name isNull type_ def ref) cols =
(False, True) ->
let up = case def of
Nothing -> id
Just s -> (:) (name, Update s)
Just s -> (:) (name, Update $ T.unpack s)
in up [(name, NotNull)]
_ -> []
modType = if type_ == type_' then [] else [(name, Type type_)]
Expand All @@ -315,23 +329,23 @@ findAlters col@(Column name isNull type_ def ref) cols =
then []
else case def of
Nothing -> [(name, NoDefault)]
Just s -> [(name, Default s)]
Just s -> [(name, Default $ T.unpack s)]
in (modRef ++ modDef ++ modNull ++ modType,
filter (\c -> cName c /= name) cols)

showColumn :: Column -> String
showColumn (Column n nu t def ref) = concat
[ escape n
[ T.unpack $ escape n
, " "
, showSqlType t
, " "
, if nu then "NULL" else "NOT NULL"
, case def of
Nothing -> ""
Just s -> " DEFAULT " ++ s
Just s -> " DEFAULT " ++ T.unpack s
, case ref of
Nothing -> ""
Just (s, _) -> " REFERENCES " ++ escape s
Just (s, _) -> " REFERENCES " ++ T.unpack (escape s)
]

showSqlType :: SqlType -> String
Expand All @@ -354,106 +368,110 @@ showAlterDb (AlterColumn t (c, ac)) =
isUnsafe _ = False
showAlterDb (AlterTable t at) = (False, pack $ showAlterTable t at)

showAlterTable :: RawName -> AlterTable -> String
showAlterTable :: DBName -> AlterTable -> String
showAlterTable table (AddUniqueConstraint cname cols) = concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ADD CONSTRAINT "
, escape cname
, T.unpack $ escape cname
, " UNIQUE("
, intercalate "," $ map escape cols
, intercalate "," $ map (T.unpack . escape) cols
, ")"
]
showAlterTable table (DropConstraint cname) = concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " DROP CONSTRAINT "
, escape cname
, T.unpack $ escape cname
]

showAlter :: RawName -> AlterColumn' -> String
showAlter :: DBName -> AlterColumn' -> String
showAlter table (n, Type t) =
concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ALTER COLUMN "
, escape n
, T.unpack $ escape n
, " TYPE "
, showSqlType t
]
showAlter table (n, IsNull) =
concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ALTER COLUMN "
, escape n
, T.unpack $ escape n
, " DROP NOT NULL"
]
showAlter table (n, NotNull) =
concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ALTER COLUMN "
, escape n
, T.unpack $ escape n
, " SET NOT NULL"
]
showAlter table (_, Add col) =
concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ADD COLUMN "
, showColumn col
]
showAlter table (n, Drop) =
concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " DROP COLUMN "
, escape n
, T.unpack $ escape n
]
showAlter table (n, Default s) =
concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ALTER COLUMN "
, escape n
, T.unpack $ escape n
, " SET DEFAULT "
, s
]
showAlter table (n, NoDefault) = concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ALTER COLUMN "
, escape n
, T.unpack $ escape n
, " DROP DEFAULT"
]
showAlter table (n, Update s) = concat
[ "UPDATE "
, escape table
, T.unpack $ escape table
, " SET "
, escape n
, T.unpack $ escape n
, "="
, s
, " WHERE "
, escape n
, T.unpack $ escape n
, " IS NULL"
]
showAlter table (n, AddReference t2) = concat
[ "ALTER TABLE "
, escape table
, T.unpack $ escape table
, " ADD CONSTRAINT "
, escape $ refName table n
, T.unpack $ escape $ refName table n
, " FOREIGN KEY("
, escape n
, T.unpack $ escape n
, ") REFERENCES "
, escape t2
, T.unpack $ escape t2
]
showAlter table (_, DropReference cname) = concat
[ "ALTER TABLE "
, T.unpack (escape table)
, " DROP CONSTRAINT "
, T.unpack $ escape cname
]
showAlter table (_, DropReference cname) =
"ALTER TABLE " ++ escape table ++ " DROP CONSTRAINT " ++ escape cname

escape :: RawName -> String
escape (RawName s) =
'"' : go s ++ "\""
escape :: DBName -> Text
escape (DBName s) =
T.pack $ '"' : go (T.unpack s) ++ "\""
where
go "" = ""
go ('"':xs) = "\"\"" ++ go xs
Expand Down Expand Up @@ -498,3 +516,10 @@ safeRead name t = case reads s of
[] -> MLeft $ concat ["Invalid value for ", name, ": ", s]
where
s = T.unpack t

refName :: DBName -> DBName -> DBName
refName (DBName table) (DBName column) =
DBName $ T.concat [table, "_", column, "_fkey"]

udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, map snd $ uniqueFields ud)

0 comments on commit e259bd3

Please sign in to comment.