In [1]:
import pickle
from concurrent.futures import ProcessPoolExecutor

import torch

In [None]:
class LMProb():
    """Score token sequences with a serialized language model.

    The model is loaded with ``torch.load`` and moved to CPU in eval mode; the
    vocabulary is a pickled object that must provide ``len()`` and
    ``getid(word)``.
    """

    def __init__(self, model_path, dict_path):
        """Load the model and its dictionary from disk.

        Args:
            model_path: path to the serialized model file.
            dict_path: path to the pickled dictionary file.
        """
        # NOTE: torch.load and pickle.load both unpickle arbitrary objects;
        # only point these paths at files you trust.
        with open(model_path, 'rb') as model_file:
            self.model = torch.load(model_file, map_location={'cuda:0': 'cpu'})
        self.model = self.model.cpu()
        self.model.eval()

        with open(dict_path, 'rb') as dict_file:
            self.dictionary = pickle.load(dict_file)
        print(len(self.dictionary))

    def get_prob(self, words, verbose=False):
        """Return the mean per-token log-probability of a sentence.

        The sentence is wrapped in ``<s>`` / ``</s>`` markers and fed to the
        model one token at a time; each step scores the next token.

        Args:
            words: sequence of token strings (without boundary markers).
            verbose: when True, print the ids and per-token log-probabilities.

        Returns:
            A 0-dim double tensor: the sum of the log-probabilities of every
            token after ``<s>`` (including ``</s>``) divided by their count.
        """
        with torch.no_grad():
            pad_words = ['<s>'] + list(words) + ['</s>']
            idxs = [self.dictionary.getid(w) for w in pad_words]
            # Shape (1, 1): a single token id, as in the original input layout.
            inp = torch.tensor([[int(idxs[0])]], dtype=torch.long)

            if verbose:
                print('words =', pad_words)
                print('idxs =', idxs)

            hidden = self.model.init_hidden(1)
            log_probs = []
            for target in idxs[1:]:
                output, hidden = self.model(inp, hidden)
                # log_softmax equals log(exp(x_i) / sum(exp(x))) but does not
                # overflow to inf/inf = nan when the raw scores are large.
                log_dist = torch.log_softmax(output.squeeze().double(), dim=-1)
                log_probs.append(log_dist[target])
                # Feed the token just scored as the next input.
                inp.fill_(int(target))

            total = sum(log_probs)

            if verbose:
                for i in range(len(log_probs)):
                    print('  {} => {:d},\tlogP(w|s)={:.4f}'.format(pad_words[i+1], idxs[i+1], log_probs[i]))
                print('\n  => sum_prob = {:.4f}'.format(total))

        # pad_words always has at least two entries, so log_probs is never empty.
        return total / len(log_probs)

In [None]:
# Index-aligned settings: position 0 is the annotation (natural-language) model,
# position 1 is the code model. The scoring loop pairs these lists by index, so
# every list must use the same order.
# NOTE(review): read_file_paths and dicts were previously listed code-first,
# which paired the annotation model with the code data/dictionary — confirm
# this ordering against how the models were trained.
lm_model_paths   = ['out/anno.pt', 'out/code.pt']
read_file_paths  = ['data/java/train.token.nl', 'data/java/train.token.code']
dicts            = ['language_models/java/dict_nl.pkl', 'language_models/java/dict_code.pkl']
score_paths = ['scores/train.token.anno.score', 'scores/train.token.code.score']

In [None]:
def get_score(line, num):
    """Score one input line with the module-level ``lm_model``.

    Args:
        line: raw text line; surrounding whitespace is stripped and the rest
            is split on single spaces.
        num: caller-supplied index, returned unchanged so results can be
            matched back to their input line.

    Returns:
        A ``(num, score)`` tuple where ``score`` is ``lm_model.get_prob(tokens)``.
    """
    tokens = line.strip().split(' ')
    return num, lm_model.get_prob(tokens)

# for anno / code: score every line of each input file with its language model
# and write one score per line, in input order.
for model_path, dict_path, read_path, score_path in zip(
        lm_model_paths, dicts, read_file_paths, score_paths):
    # get_score reads the module-level `lm_model`, so bind it before the pool starts.
    # NOTE(review): this relies on worker processes seeing that global (true with
    # the fork start method) — confirm on platforms that spawn workers.
    lm_model = LMProb(model_path, dict_path)

    with open(read_path, 'rt') as fp:
        lines = fp.readlines()

    with ProcessPoolExecutor(max_workers=8) as executor:
        # executor.map yields results in the same order as `lines`.
        results = list(executor.map(get_score, lines, range(len(lines))))

    # get_prob returns a 0-dim tensor; convert to a plain float so the file holds
    # numbers rather than "tensor(..., dtype=torch.float64)" reprs.
    scores = [float(lm_score) for _, lm_score in results]

    with open(score_path, 'wt') as fp:
        for score in scores:
            fp.write(str(score) + '\n')