In [14]:
import h5py
import caffe
import numpy as np

In [15]:
# ---- Evaluation configuration: edit these paths for your data / model. ----
# Alternatives that have been used before:
#   dataPath   = '/data/gengshan/pose_s2vt/hdf5/buffer_32_s2vt_100/small_train_batches/batch_0.h5'
#   MODEL_FILE = '/data/gengshan/pose_s2vt/snapshots/stored/s2vt_asl_pose_iter_11668.caffemodel'
maxWords = 100  # length (in timesteps) of each stream chunk sliced out of the HDF5 data
dataPath = '/data/gengshan/pose_s2vt/hdf5/buffer_32_s2vt_100/train_2_batches/batch_0.h5'      # HDF5 batch to evaluate
LSTM_NET_FILE = './captioner.prototxt'                                                        # network definition
MODEL_FILE = '/data/gengshan/pose_s2vt/snapshots/s2vt_asl_pose_iter_50000.caffemodel'         # trained weights
vocabFile = '/data/gengshan/pose_s2vt/whole_vocabulary.txt'                                   # vocabulary file
UNK_IDENTIFIER = '<en_unk>'                                                                   # unknown-word token

# ---- Load the trained network on GPU 0 in TEST phase. ----
caffe.set_mode_gpu()
caffe.set_device(0)
lstm_net = caffe.Net(LSTM_NET_FILE, MODEL_FILE, caffe.TEST)

In [16]:
def vocab_inds_to_sentence(vocab, inds):
    """Render a sequence of vocabulary indices as a readable sentence.

    The words are joined with single spaces and the first character is
    upper-cased. If the final word is the end-of-sentence token ``vocab[0]``,
    it is replaced by a closing '.'. Otherwise '...' is appended to mark the
    sentence as unterminated.

    Args:
        vocab: indexable sequence mapping index -> word string. ``vocab[0]``
            is treated as the end-of-sentence token.
        inds: iterable of integer indices into ``vocab``. It may be empty.

    Returns:
        The formatted sentence. An empty ``inds`` yields '...'. A lone
        end-of-sentence index yields '.'.
    """
    eos_word = vocab[0]
    words = [vocab[i] for i in inds]
    # Compare the last *word* (not a string suffix) so a sentence consisting of
    # only the EOS token is recognised too.
    if words and words[-1] == eos_word:
        words = words[:-1]
        terminator = '.'
    else:
        terminator = '...'
    sentence = ' '.join(words)
    # Guard the capitalisation: an empty sentence has no first character.
    if sentence:
        sentence = sentence[0].upper() + sentence[1:]
    return sentence + terminator

def init_vocab_from_file(vocabFilePath, unk_identifier=None):
    """Load an index-ordered word list from a vocabulary file.

    From each non-blank line, only the first whitespace-separated token is
    used; any further tokens are ignored. The unknown-word token is always
    placed at index 0, and any occurrence of it in the file is skipped.

    Args:
        vocabFilePath: path to the vocabulary text file.
        unk_identifier: unknown-word token. Defaults to the module-level
            ``UNK_IDENTIFIER`` when omitted.

    Returns:
        list of words where position == vocabulary index:
        ``[unk_identifier, word_1, word_2, ...]``.

    Raises:
        AssertionError: if a word other than the UNK token appears twice.
    """
    if unk_identifier is None:
        unk_identifier = UNK_IDENTIFIER
    # word -> index; used here only to detect duplicate entries.
    vocabulary = {unk_identifier: 0}
    vocabulary_inverted = [unk_identifier]
    num_words_dataset = 0
    with open(vocabFilePath, 'r') as f:
        for line in f:
            split_line = line.split()
            if not split_line:
                # Blank / whitespace-only line (e.g. trailing newline): nothing to add.
                continue
            word = split_line[0]
            if word == unk_identifier:
                # UNK is already pinned at index 0.
                continue
            assert word not in vocabulary, 'duplicate vocabulary word: %r' % (word,)
            num_words_dataset += 1
            vocabulary[word] = len(vocabulary_inverted)
            vocabulary_inverted.append(word)
    num_words_vocab = len(vocabulary)
    # Single-argument print(...) behaves the same under Python 2 and 3; the
    # former `print (...) % (...)` form raises TypeError on Python 3.
    print('Initialized vocabulary from file with %d unique words '
          '(from %d total words in dataset).' % (num_words_vocab, num_words_dataset))
    assert len(vocabulary_inverted) == num_words_vocab
    return vocabulary_inverted

In [17]:
vocabList = ['<EOS>'] + init_vocab_from_file(vocabFile)  # index 0 is the EOS token

# Per-sample network inputs: one entry per (maxWords-long stream chunk, batch column).
poseData = []
targetSent = []
existIdict = []
stageIdict = []
inputSent = []

# HDF5 dataset name -> list that collects its per-sample slices.
datasetTargets = [
    ('frame_fc7', poseData),
    ('input_sentence', inputSent),
    ('cont_sentence', existIdict),
    ('stage_indicator', stageIdict),
    ('target_sentence', targetSent),
]

# Slicing an h5py dataset returns an in-memory numpy array, so the file can be
# closed as soon as every slice has been read.
with h5py.File(dataPath, 'r') as wholeData:
    streamSize, batchSize, _ = wholeData['frame_fc7'].shape
    # `//` keeps Python 2's integer division under Python 3 as well; trailing
    # timesteps that do not fill a whole chunk are dropped.
    for streamIdx in range(streamSize // maxWords):
        timeSlice = slice(streamIdx * maxWords, (streamIdx + 1) * maxWords)
        for batchIdx in range(batchSize):
            # batchIdx:batchIdx+1 keeps the batch axis (size 1) in every slice.
            batchSlice = slice(batchIdx, batchIdx + 1)
            for key, target in datasetTargets:
                target.append(wholeData[key][timeSlice, batchSlice])

Initialized vocabulary from file with 13852 unique words (from 13851 total words in dataset).


In [18]:
# Greedy-decode every sample. Write one hypothesis line (res.txt) and one
# reference line (ref.txt) per sample, in the same order.
# NOTE(review): assumes the train_log/ directory already exists.
# `with` closes both files even if a forward pass raises.
with open('train_log/ref.txt', 'w') as ref, open('train_log/res.txt', 'w') as res:
    for it in range(len(inputSent)):
        probs = lstm_net.forward(frames_fc7=poseData[it],
                                 cont_sentence=existIdict[it],
                                 input_sentence=inputSent[it],
                                 stage_indicator=stageIdict[it])['probs']

        # Highest-scoring index along axis 2, zeroed wherever stage_indicator is 0.
        # NOTE(review): assumes probs is (time, 1, vocab) — confirm against the prototxt.
        predictIdx = np.squeeze(stageIdict[it] * np.argmax(probs, axis=2))

        # Drop everything before the first timestep with a non-zero stage indicator.
        predictIdx = predictIdx[np.flatnonzero(stageIdict[it])[0]:]

        # Truncate after the first EOS (index 0) but keep the EOS itself. That way
        # vocab_inds_to_sentence ends the hypothesis with '.', just as it does for
        # references. Only hypotheses that never emit EOS get '...'.
        eosPositions = np.flatnonzero(predictIdx == 0)
        if eosPositions.size:
            predictIdx = predictIdx[:eosPositions[0] + 1]

        res.write(vocab_inds_to_sentence(vocabList, [int(x) for x in predictIdx]) + '\n')
        # Flatten the (time, 1) target so each x is a scalar; -1 entries are skipped
        # (presumably ignore-label padding — confirm against the HDF5 writer).
        ref.write(vocab_inds_to_sentence(
            vocabList, [int(x) for x in np.ravel(targetSent[it]) if x != -1]) + '\n')