In [88]:
from copy import deepcopy
from argparse import Namespace

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.legacy import data

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers.tqdm_logger import ProgressBar

In [224]:
def define_argparse():
    """Build the inference configuration consumed by ``main``.

    The notebook has no command line, so the options are hard-coded here and
    wrapped in an ``argparse.Namespace`` for attribute-style access.

    Returns:
        Namespace with the following attributes:
            model_fn: path of the checkpoint handed to ``torch.load``.
            batch_size: number of lines classified per forward pass.
            gpu_id: CUDA device index; a negative value selects the CPU.
            drop_rnn: when True, the RNN classifier is left out of the ensemble.
            drop_cnn: when True, the CNN classifier is left out of the ensemble.
            topk: number of highest-scoring labels printed per line.
    """
    # Each option appears exactly once: a repeated key in a dict literal is
    # silently overridden by its last occurrence.
    p = {
        'model_fn': './model.pth',
        'batch_size': 256,
        'gpu_id': -1,
        'drop_rnn': False,
        'drop_cnn': False,
        'topk': 1,
    }

    return Namespace(**p)


def main(text, config):
    """Classify each input line with the saved RNN/CNN ensemble and print the labels.

    The checkpoint at ``config.model_fn`` is loaded, every classifier that is
    present in it (and not disabled via ``config.drop_rnn`` /
    ``config.drop_cnn``) is rebuilt, and the exponentiated model outputs are
    averaged across models. For every input line the ``config.topk`` best
    labels are printed, followed by a tab and the tokenised line.

    Args:
        text: iterable of raw strings. Each one is stripped, split on single
            spaces and truncated to the ``max_length`` stored in the
            checkpoint's training config.
        config: namespace providing ``model_fn``, ``gpu_id``, ``batch_size``,
            ``topk``, ``drop_rnn`` and ``drop_cnn``.

    Returns:
        None. Results are written to stdout.

    Raises:
        RuntimeError: if no classifier is available to run, i.e. every model
            is either missing from the checkpoint or dropped by ``config``.
    """
    # Compute the device once so the checkpoint, the models and the input
    # batches all agree on where tensors live.
    device = 'cpu' if config.gpu_id < 0 else 'cuda:{}'.format(config.gpu_id)

    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    saved_data = torch.load(config.model_fn, map_location=device)

    rnn_dict = saved_data['rnn']
    cnn_dict = saved_data['cnn']
    train_config = saved_data['config']
    vocab = saved_data['vocab']
    classes = saved_data['classes']

    # The field is only used for padding/numericalising, so it just needs the
    # vocabulary saved with the checkpoint.
    text_field = data.Field(batch_first=True)
    text_field.vocab = vocab

    lines = [
        t.strip().split(' ')[:train_config.max_length]
        for t in text
    ]
    if not lines:
        # Nothing to classify; avoids concatenating an empty list of batches.
        return

    ensemble = []
    if rnn_dict is not None and not config.drop_rnn:
        model = RNNclassifier(input_size=len(vocab),
                              emb_dim=train_config.emb_dim,
                              hidden_size=train_config.hidden_size,
                              n_layers=train_config.n_layers,
                              n_classes=len(classes),
                              dropout=train_config.dropout)
        model.load_state_dict(rnn_dict)
        ensemble.append(model)

    if cnn_dict is not None and not config.drop_cnn:
        model = CNNclassifier(input_size=len(vocab),
                              emb_dim=train_config.emb_dim,
                              window_sizes=train_config.window_sizes,
                              n_filters=train_config.n_filters,
                              use_batchnorm=train_config.use_batchnorm,
                              dropout=train_config.dropout,
                              n_classes=len(classes))
        model.load_state_dict(cnn_dict)
        ensemble.append(model)

    if not ensemble:
        raise RuntimeError(
            'No model to run: the checkpoint has no usable RNN/CNN weights '
            'or both were dropped via drop_rnn/drop_cnn.'
        )

    with torch.no_grad():
        y_hats = []
        for model in ensemble:
            # Freshly built modules live on the CPU; move them next to the
            # input batches, otherwise GPU inference fails on a device mismatch.
            model.to(device)
            model.eval()

            y_hat = []
            for i in range(0, len(lines), config.batch_size):
                x = text_field.numericalize(
                    text_field.pad(lines[i:i + config.batch_size]),
                    device=device,
                )

                y_hat.append(model(x).cpu())
                # each element = (bs, class)
            y_hats.append(torch.cat(y_hat, dim=0))
            # y_hats[-1] = (n_lines, class)

        # Assumes the models emit log-probabilities, so exp() recovers
        # probabilities before averaging — TODO confirm against the classifiers.
        y_hats = torch.stack(y_hats, dim=0).exp()
        # y_hats = (n_models, n_lines, class)
        y_hats = torch.mean(y_hats, dim=0)
        # y_hats = (n_lines, class)

        _, indices = torch.topk(y_hats, config.topk, dim=-1)

    # indices = (n_lines, topk); plain ints index classes.itos directly.
    for line, top in zip(lines, indices.tolist()):
        print('{}\t{}\n'.format(
            ' '.join(classes.itos[k] for k in top),
            ' '.join(line)
        ))

In [229]:
# Two sample messages to classify: one promotional, one conversational.
text = [
    'Good news for you! 50% discount',
    "What's up john? Long time no see.",
]

# Build the default inference settings, then print the predicted label
# for each sample.
config = define_argparse()
main(text, config)

spam	Good news for you! 50% discount

ham	What's up john? Long time no see.

