In [None]:
import torch
import torch.nn
import torch.optim
import torch.utils.data
import torch.nn.functional as F
from splitcross import SplitCrossEntropyLoss

import numpy as np
import networkx as nx
import math
import json
import time

import data
import os
from utils import batchify
from argparse import Namespace
from model import AWDRNNModel
from train import train, evaluate
import datetime

In [None]:
# Identifier of the training run to inspect; both the JSON log and the
# weights dump loaded below are located by appending it to a file prefix.
suffix = "2226_2020-04-18_07-35-19_999938929"

In [None]:
# Load the run's logged hyper-parameters and loss history.
# The file is opened in a `with` block so the handle is closed right after parsing.
# NOTE(review): this name has '100' between 'model_' and the suffix, while the
# weights file loaded later does not — confirm the prefix is intended.
with open('train_logs_multi_runs/log_stats_model_100' + suffix + '.json', 'r') as log_file:
    log = json.load(log_file)

In [None]:
# Expose the logged settings as attributes (args.data, args.batch_size, ...),
# the form the project helpers below read them in.
hyperparams = dict(log)
args = Namespace(**hyperparams)

In [None]:
# Load the corpus named in the run's config and batch each split onto the device.
corpus = data.Corpus(args.data)
# Device tag passed to batchify and to custom_model.to(); hard-coded, so a CUDA GPU is required.
cuda = 'cuda'


def _batched(tokens, batch_size):
    """Batch one token split with the run's args onto `cuda` via the project's batchify."""
    return batchify(tokens, batch_size, args, cuda)


# Train split at the training batch size (not referenced by any cell below).
train_data = _batched(corpus.train, args.batch_size)
# Every split that gets evaluated uses the evaluation batch size, train included.
train_eval_data = _batched(corpus.train, args.eval_batch_size)
val_data = _batched(corpus.valid, args.eval_batch_size)
test_data = _batched(corpus.test, args.eval_batch_size)

# Vocabulary size: number of entries in the corpus dictionary.
ntokens = len(corpus.dictionary)

In [None]:
# Rebuild the network with the exact hyper-parameters of the logged run.
# The values are positional, in the same order as before; the 'recepie' spelling is
# the attribute name coming from the log's keys.
model_hparams = (
    args.model, ntokens, args.emsize, args.nhid, args.nlayers,
    args.dropout, args.dropouth, args.dropouti, args.dropoute,
    args.wdrop, args.tied, args.recepie,
)
custom_model = AWDRNNModel(*model_hparams, verbose=False)

In [None]:
# Restore the trained weights of this run.
# map_location='cpu' deserialises every tensor onto the CPU instead of the device it was
# saved from. The model is moved to the GPU by a later `.to(cuda)` call, so this avoids
# a failure when the saving CUDA device is absent or numbered differently.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
custom_model.load_state_dict(
    torch.load('models_weights/dump_weights_model_' + suffix + '.pt', map_location='cpu')
)

In [None]:
# Move parameters and buffers to the GPU (nn.Module.to works in place and returns the module).
custom_model = custom_model.to(cuda)

In [None]:
# Loss used to score the model, built from the run's embedding size.
# No vocabulary splits are given — presumably a single full softmax; TODO confirm in splitcross.
vocab_splits = []
criterion = SplitCrossEntropyLoss(args.emsize, splits=vocab_splits, verbose=False)

In [None]:
# Recompute the loss of the restored model on each split.
# The forward passes are only scored, never back-propagated, so running them under
# torch.no_grad() stops autograd from recording graphs and saves GPU memory. The
# computed losses are unchanged.
# NOTE(review): assumes train.evaluate() does not itself call backward().
with torch.no_grad():
    train_loss = evaluate(custom_model, criterion, train_eval_data, args.eval_batch_size, args)
    val_loss = evaluate(custom_model, criterion, val_data, args.eval_batch_size, args)
    test_loss = evaluate(custom_model, criterion, test_data, args.eval_batch_size, args)

In [None]:
def _split_summary(name, loss):
    """One table row for a split: loss, perplexity exp(loss), bits per word loss / ln 2."""
    return '{0} loss {1:5.4f} | {0} ppl {2:8.2f} | {0} bpw {3:8.3f} |'.format(
        name, loss, math.exp(loss), loss / math.log(2))


separator = '-' * 89
print(separator)
# Rows are joined with a newline plus a '| ' prefix for every row after the first.
print('\n| '.join([
    _split_summary('train', train_loss),
    _split_summary('valid', val_loss),
    _split_summary('test', test_loss),
]))
print(separator)

In [None]:
# Last entry of each loss history stored in the training log, to compare with the
# values recomputed above.
for split_name, log_key in (('train', 'train_losses'),
                            ('valid', 'val_losses'),
                            ('test', 'test_losses')):
    print('logged ' + split_name + ' loss', log[log_key][-1])