Skip to content

Commit

Permalink
Merge pull request #22 from tensorflow/embedding-lookup-fix
Browse files Browse the repository at this point in the history
Embedding lookup fix
  • Loading branch information
blackgnezdo committed Nov 9, 2016
2 parents 4ec78a8 + d9115c7 commit 9e005e3
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 18 deletions.
48 changes: 30 additions & 18 deletions tensorflow-ops/src/TensorFlow/EmbeddingOps.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@ module TensorFlow.EmbeddingOps where

import Control.Monad (zipWithM)
import Data.Int (Int32, Int64)
import Data.List (genericLength)
import TensorFlow.Build (Build, colocateWith, render)
import TensorFlow.Ops () -- Num instance for Tensor
import TensorFlow.Ops (scalar, shape, vector) -- Also Num instance for Tensor
import TensorFlow.Tensor (Tensor, Value)
import TensorFlow.Types (OneOf, TensorType)
import qualified TensorFlow.GenOps.Core as CoreOps
Expand Down Expand Up @@ -56,21 +55,34 @@ embeddingLookup :: forall a b v .
-> Tensor Value b
-- ^ A `Tensor` with type `int32` or `int64`
-- containing the ids to be looked up in `params`.
-- The ids are required to be flat on entry and have
-- fewer than 2^31 entries.
-- The ids are required to have fewer than 2^31
-- entries.
-> Build (Tensor Value a)
-- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
-- Pre-fix implementation: partitions the ids with the "mod" strategy,
-- gathers from each partition, and stitches the results back together.
-- NOTE(review): ids are used unflattened here and the result shape is
-- not restored, which is what the replacement below fixes.
embeddingLookup params ids =
    CoreOps.dynamicStitch pindices <$> partitionedResult
  where np = genericLength params
        -- Which partition each id belongs to (id `mod` np).
        pAssignments = CoreOps.cast (ids `CoreOps.mod` np)
        -- Row of each id within its assigned partition (id `div` np).
        newIds = ids `CoreOps.div` np
        -- 0..size(ids)-1, used to restore the original order after stitching.
        originalIndices = CoreOps.range 0 (CoreOps.size ids) 1
        -- Partition list of ids based on assignments into np separate lists
        gatherIds = CoreOps.dynamicPartition np newIds pAssignments
        -- Similarly, partition the original indices.
        pindices = CoreOps.dynamicPartition np originalIndices pAssignments
        -- Do np separate lookups, finding embeddings for plist[p] in params[p]
        partitionedResult = zipWithM
            (\p g -> colocateWith p $ render $ CoreOps.gather p g)
            params gatherIds
-- Single-partition fast path: every id indexes directly into the one
-- params tensor, so a plain gather (colocated with the params) suffices.
embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
embeddingLookup params@(p0 : _) ids = do
    -- Do np separate lookups, finding embeddings for plist[p] in params[p]
    partitionedResult <- zipWithM
                         (\p g -> colocateWith p $ render $ CoreOps.gather p g)
                         params gatherIds
    -- Stitch the per-partition results back into the original id order.
    let unshapedResult = CoreOps.dynamicStitch pindices partitionedResult
    -- Shape restoration is not as optimal as it would be with client
    -- side shape tracking.
    paramShape <- colocateWith p0 (render (shape p0))
    -- Result shape is shape(ids) ++ shape(params)[1:].
    let finalShape = CoreOps.concat 0 [shape ids, tailShape]
        -- Trailing dimensions of params (everything past the leading row
        -- dimension); the -1 size in slice means "through the end".
        tailShape = CoreOps.slice paramShape (singleton 1) (singleton (-1))
    render $ CoreOps.reshape unshapedResult finalShape
  where
    -- Avoids genericLength here which would be evaluated by TF.
    np = fromIntegral (length params)
    -- Flatten ids so the partitioning below operates on a vector.
    flatIds = CoreOps.reshape ids (singleton (-1))
    -- Which partition each id belongs to (id `mod` np).
    pAssignments = CoreOps.cast (flatIds `CoreOps.mod` np)
    -- Row of each id within its assigned partition (id `div` np).
    newIds = flatIds `CoreOps.div` np
    -- 0..size(ids)-1, used to restore the original order after stitching.
    originalIndices = CoreOps.range 0 (CoreOps.size flatIds) 1
    -- Partition list of ids based on assignments into np separate lists
    gatherIds = CoreOps.dynamicPartition np newIds pAssignments
    -- Similarly, partition the original indices.
    pindices = CoreOps.dynamicPartition np originalIndices pAssignments
    -- A rank-1 Int32 constant, as required by reshape/slice shape args.
    singleton i = vector [i :: Int32]

embeddingLookup [] _ = error "embeddingLookup requires params to be non empty"
53 changes: 53 additions & 0 deletions tensorflow-ops/tests/EmbeddingOpsTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Google.Test (googleTest)
import TensorFlow.EmbeddingOps (embeddingLookup)
import Test.Framework.Providers.QuickCheck2 (testProperty)
import Test.HUnit ((@=?))
import Test.Framework.Providers.HUnit (testCase)
import Test.QuickCheck (Arbitrary(..), Property, choose, vectorOf)
import Test.QuickCheck.Monadic (monadicIO, run)

Expand All @@ -34,6 +35,56 @@ import qualified TensorFlow.Session as TF
import qualified TensorFlow.Tensor as TF
import qualified TensorFlow.Types as TF


-- | Builds the given graph in a fresh session and fetches its result.
buildAndRun toBuild = TF.runSession (TF.buildAnd TF.run toBuild)

-- | Looks up ids spread across two embedding partitions and checks that
-- both the fetched values and the result shape match the equivalent
-- Python implementation.
testEmbeddingLookupHasRightShapeWithPartition =
    testCase "testEmbeddingLookupHasRightShapeWithPartition" $ do
        let partShape = TF.Shape [1, 3] -- One 3-dim embedding row per partition.
            firstRow  = [ 1, 1, 1 ] :: [Int32]
            secondRow = [ 0, 0, 0 ] :: [Int32]
            partitions = [ TF.constant partShape firstRow
                         , TF.constant partShape secondRow
                         ]
            idValues  = [0, 1] :: [Int32]
            lookupIds = TF.constant (TF.Shape [1, 2]) idValues

        (values, outShape) <- buildAndRun $ do
            vs <- embeddingLookup partitions lookupIds
            return (vs, TF.shape vs)

        -- This is the shape that is returned in the equiv. Python.
        outShape @=? V.fromList [ 1, 2, 3 ]

        -- Ids [0, 1] select one row from each partition.
        values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]


-- | Looks up ids in a single-partition embedding and checks that both
-- the fetched values and the result shape match the equivalent Python
-- implementation.
testEmbeddingLookupHasRightShape =
    testCase "testEmbeddingLookupHasRightShape" $ do
        let embShape = TF.Shape [2, 3] -- A 3-dim embedding of two items.
            embeddingInit = [ 1, 1, 1
                            , 0, 0, 0 ] :: [Int32]
            embeddingTable = TF.constant embShape embeddingInit
            idValues  = [0, 1] :: [Int32]
            lookupIds = TF.constant (TF.Shape [1, 2]) idValues

        (values, outShape) <- buildAndRun $ do
            vs <- embeddingLookup [embeddingTable] lookupIds
            return (vs, TF.shape vs)

        -- This is the shape that is returned in the equiv. Python.
        outShape @=? V.fromList [ 1, 2, 3 ]

        -- Ids [0, 1] pull out both embedding rows in order.
        values @=? V.fromList [ 1, 1, 1, 0, 0, 0 ]


-- Verifies that direct gather is the same as dynamic split into
-- partitions, followed by embedding lookup.
testEmbeddingLookupUndoesSplit :: forall a. (TF.TensorType a, Show a, Eq a)
Expand Down Expand Up @@ -85,4 +136,6 @@ main :: IO ()
-- | Test entry point: the split/lookup round-trip property plus the two
-- shape regression tests.
main = googleTest allTests
  where
    allTests =
        [ testProperty "EmbeddingLookupUndoesSplit"
            (testEmbeddingLookupUndoesSplit :: LookupExample Double -> Property)
        , testEmbeddingLookupHasRightShape
        , testEmbeddingLookupHasRightShapeWithPartition
        ]

0 comments on commit 9e005e3

Please sign in to comment.