In [None]:
local torch = require "torch"
local itorch = require "itorch"
require "cutorch"
require "nn"
require "cunn"
local cudnn = require "cudnn"
local data = require "data"
local helpers = require "helpers"

-- Load configuration: the checkpoint to inspect and the config file that
-- helpers.opts reads into the options table used below.
local args = {
    model = 'out/2016-07-30-lr4/model_15000.t7',
    conf = 'conf.json',
}
local opts = helpers.opts(args)

-- Initialize and normalize training and validation data. Both splits use the
-- same image geometry and validation-subject list; the trailing `true` on the
-- validation split is forwarded to data.new (presumably selecting the
-- held-out subjects — TODO confirm against data.lua).
local trainData = data.new(opts.train,
    opts.height, opts.width, opts.validationSubjects)
trainData:normalize(opts.mean, opts.std)
local validateData = data.new(opts.validate,
    opts.height, opts.width, opts.validationSubjects, true)
validateData:normalize(opts.mean, opts.std)

-- Load network from file and put it in evaluation mode for inference
local net = torch.load(args.model)
net:evaluate()

-- Each notebook cell is compiled as a separate chunk, so this cell's locals
-- are not visible to later cells; export the datasets those cells use.
_G.trainData = trainData
_G.validateData = validateData

--- Run the network on one minibatch of `dataset` and show the results.
-- For each sample in the batch: prints its name and Dice score, then displays
-- the input, the predicted label map and the ground-truth label map side by
-- side with itorch.image. Finally prints the batch's mean Dice score.
-- Relies on `net`, `opts` and `helpers` defined earlier in this cell.
-- @param dataset object whose `batch(n)` returns (batch, names), where
--   `batch.inputs` / `batch.labels` are tensors and `names` is a list aligned
--   with the batch dimension (assumed — TODO confirm in data.lua).
local function sample(dataset)
    -- Get a minibatch and move it to the GPU for the forward pass
    local batch, names = dataset:batch(opts.batchSize)
    local batchInputs = batch.inputs:cuda()
    local batchLabels = batch.labels:cuda()

    -- Forward pass and score
    local outputs = net:forward(batchInputs)
    local diceValue = helpers.dice(outputs, batchLabels)

    -- max(2) returns (values, indices); the indices are the predicted classes.
    -- Capture them directly: chaining a method onto max(2) would operate on
    -- the values only and leave the second variable nil. squeeze(2) removes
    -- only the singleton class dimension, so a batch of size 1 keeps its
    -- batch dimension. Subtracting 1 maps 1-based indices to label values.
    local _, predIdx = outputs:max(2)
    local predLabels = predIdx:squeeze(2):double() - 1
    -- Ground truth gets the same shift; everything shown goes to the CPU as
    -- DoubleTensors so itorch.image receives a single tensor type.
    local trueLabels = batchLabels:double() - 1
    local cpuInputs = batchInputs:double()

    -- Output, one row per sample in batch order
    for j, name in ipairs(names) do
        print(name, diceValue[j])
        itorch.image({cpuInputs[j], predLabels[j], trueLabels[j]})
    end
    print()
    print("Mean Dice score for batch:", diceValue:mean())
end

-- Each notebook cell is compiled as a separate chunk, so a local function is
-- not visible to later cells; export `sample` for the cells that call it.
_G.sample = sample

In [None]:
-- Show predictions and Dice scores for one minibatch of training data
sample(trainData)

In [None]:
-- Show predictions and Dice scores for one minibatch of validation data
sample(validateData)