In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from __future__ import absolute_import, division, unicode_literals

import sys
import numpy as np
import logging
import sklearn
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
import os
import json
import pytorch_lightning as pl
import torch
from torch import utils

import matplotlib.pyplot as plt


sys.path.insert(0, 'nli/')
from setup import load_model, prep_sent, find_checkpoint
from data import NLIDataModule, DataSetPadding

In [3]:
class Args:
    """Attribute bag holding the settings that `NLIResults` reads.

    Every constructor argument is stored on the instance under the same
    name. The only derived value is `ckpt_path`, which falls back to
    `model_type` when it is left as None.
    """

    def __init__(self, model_type = 'avg_word_emb', ckpt_path = None, version = 'version_0', path_to_vocab = 'store/vocab.pkl', num_workers = 3):
        # Only an explicit None triggers the fallback; any other value
        # (including an empty string) is kept as given.
        if ckpt_path is None:
            ckpt_path = model_type

        self.num_workers = num_workers
        self.path_to_vocab = path_to_vocab
        self.version = version
        self.ckpt_path = ckpt_path
        self.model_type = model_type

# Settings instance used by the rest of the notebook.
args = Args(model_type='avg_word_emb')

In [None]:
class NLIResults:
    """Evaluate a loaded NLI model: validation/test accuracy and predictions.

    Bundles the model returned by `load_model`, an `NLIDataModule` built
    from the model's vocabulary, and a `pl.Trainer`, so that each metric
    can be obtained with a single method call.
    """

    def __init__(self, args):
        """Load the model and set up the data module and trainer.

        Args:
            args: object exposing `model_type`, `path_to_vocab`,
                `ckpt_path`, `version` and `num_workers` attributes.
        """
        # load_model hands back the model together with its vocabulary,
        # which the data module needs.
        self.model, vocab = load_model(args.model_type, args.path_to_vocab, args.ckpt_path, args.version)
        self.datamodule = NLIDataModule(vocab=vocab, batch_size=64, num_workers=args.num_workers)
        self.trainer = pl.Trainer(
            logger = False,
            accelerator='gpu' if torch.cuda.is_available() else 'cpu',
        )

    def test(self):
        """Run the trainer's test loop and return the logged `test_acc`."""
        # trainer.test returns a list of metric dicts; the first entry is used.
        test_acc = self.trainer.test(self.model, datamodule=self.datamodule, verbose = False)[0]['test_acc']
        return test_acc

    def validate(self):
        """Run the trainer's validation loop and return the logged `val_acc`."""
        # trainer.validate returns a list of metric dicts; the first entry is used.
        val_acc = self.trainer.validate(self.model, datamodule=self.datamodule, verbose = False)[0]['val_acc']
        return val_acc

    def get_nli_accs(self):
        """Return `{'val': ..., 'test': ...}` with each accuracy multiplied by 100.

        Validation is run before test. The metrics are presumably fractions
        in [0, 1], making the results percentages -- TODO confirm against
        the model's logging code.
        """
        return {'val': self.validate()*100., 'test': self.test()*100., }

    def get_nli_preds(self):
        """Return softmax class probabilities for every predicted example.

        `Trainer.predict` yields one output per batch, and each output is
        unpacked as a `(y_hat, y)` pair. The `y_hat` tensors of all batches
        are concatenated along the batch dimension before the softmax, so
        the result has one row per example instead of covering only the
        first batch. The second element of each pair (presumably the
        labels) is ignored.

        Returns:
            Tensor of probabilities; softmax is taken over dim 1.

        Raises:
            ValueError: if the trainer returned no batch outputs.
        """
        batch_outputs = self.trainer.predict(self.model, datamodule=self.datamodule)
        if not batch_outputs:
            raise ValueError('trainer.predict returned no predictions')
        # Previously only batch_outputs[0] was used, silently dropping
        # every batch after the first.
        logits = torch.cat([y_hat for y_hat, _ in batch_outputs], dim=0)
        return torch.nn.functional.softmax(logits, dim=1)