Skip to content

Commit

Permalink
new loadData() API like tf style
Browse files Browse the repository at this point in the history
  • Loading branch information
huan committed Aug 28, 2019
1 parent dd0d5ee commit 8e628cd
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions Sources/MNIST/MNIST.swift
Expand Up @@ -19,11 +19,15 @@ public class MNIST {
}

// Split data into training and test
public func splitTrainTest() -> (
Tensor<Float>,
Tensor<Int32>,
Tensor<Float> ,
Tensor<Int32>
public func loadData() -> (
(
Tensor<Float>,
Tensor<Int32>
),
(
Tensor<Float>,
Tensor<Int32>
)
) {
let data = self.images!
let labels = self.labels!
Expand All @@ -37,7 +41,10 @@ public class MNIST {
let testX = data[split..<N]
let testY = labels[split..<N]

return (trainX, trainY, testX, testY)
return (
(trainX, trainY),
(testX, testY)
)
}

// report accuracy of a batch
Expand Down

0 comments on commit 8e628cd

Please sign in to comment.