In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all loaded modules
# before each execution, so edits to the local modules imported by this
# notebook are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from __future__ import absolute_import, division, unicode_literals

import sys
import numpy as np
import logging
import sklearn
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
import os
import json
import pytorch_lightning as pl
import torch

sys.path.insert(0, 'nli/')
from setup import load_model, prep_sent, find_checkpoint
from data import NLIDataModule


In [4]:
class TransferResults:
    """Transfer-task results read from a ``results.txt`` JSON file.

    The file is expected to hold a mapping ``task name -> dict``. Only tasks
    whose dict contains an ``'acc'`` key take part in the accuracy summary;
    for those tasks the ``'devacc'`` and ``'ndev'`` entries are read.

    Attributes:
        results: The parsed JSON content.
        tasks_with_acc: Set of task names whose entry has an ``'acc'`` key.
    """

    def __init__(self, args, tasks_with_acc_given = None):
        """Load the results belonging to the checkpoint described by ``args``.

        Args:
            args: Object with ``ckpt_path`` and ``version`` attributes, passed
                to ``find_checkpoint`` to locate the version directory.
            tasks_with_acc_given: Optional iterable of task names that are
                expected to be exactly the tasks carrying an ``'acc'`` entry.

        Raises:
            AssertionError: If ``tasks_with_acc_given`` is provided and does
                not match the tasks found in the file.
        """
        # Read the results
        _, version_path = find_checkpoint(args.ckpt_path, args.version)
        with open(self._resolve_results_path(version_path), 'r') as f:
            self.results = json.load(f)

        # Assert that the tasks with acc are the same as the ones given
        self.tasks_with_acc = self.get_tasks_with_acc(tasks_with_acc_given)

    @staticmethod
    def _resolve_results_path(version_path):
        """Return the results file for ``version_path``.

        The per-version file is preferred so that different models/versions do
        not all read the same results. When it does not exist, the previous
        location (``results.txt`` in the working directory) is used instead.
        """
        versioned_file = os.path.join(version_path, 'results.txt')
        if os.path.isfile(versioned_file):
            return versioned_file
        return 'results.txt'

    def get_tasks_with_acc(self, tasks_with_acc_given):
        """Return the set of tasks whose results contain an ``'acc'`` key.

        Args:
            tasks_with_acc_given: ``None`` to skip the check, otherwise an
                iterable (set, list, tuple, ...) of the expected task names.

        Raises:
            AssertionError: If the expected tasks differ from the found ones.
        """
        task_with_acc = {task for task in self.results if 'acc' in self.results[task]}
        if tasks_with_acc_given is not None:
            # Coerce to a set: comparing a set with a list/tuple using == is
            # always False, even when both hold the same task names.
            assert task_with_acc == set(tasks_with_acc_given), f'{task_with_acc} != {tasks_with_acc_given}'
        return task_with_acc

    def get_transfer_accs(self):
        """Summarise the dev accuracy over all tasks in ``tasks_with_acc``.

        Returns:
            Dict with ``'macro'`` (unweighted mean of the per-task ``devacc``)
            and ``'micro'`` (mean of ``devacc`` weighted by ``ndev``).

        Raises:
            ValueError: If there is no task to average over, or if the tasks
                report zero dev samples in total.
        """
        dev_accs = {}
        num_dev_samples = {}
        for task, task_data in self.results.items():
            if task not in self.tasks_with_acc:
                continue
            dev_accs[task] = task_data['devacc']
            num_dev_samples[task] = task_data['ndev']

        if not dev_accs:
            raise ValueError('No tasks with accuracy available to compute transfer accuracies')

        # Calculate macro accuracy
        macro_acc = sum(dev_accs.values()) / len(dev_accs)

        # Calculate micro accuracy
        total_dev_samples = sum(num_dev_samples.values())
        if total_dev_samples == 0:
            raise ValueError('Total number of dev samples is zero; cannot compute micro accuracy')
        micro_acc = sum(dev_accs[task] * num_dev_samples[task] / total_dev_samples for task in dev_accs)

        return {'micro': micro_acc, 'macro': macro_acc}
    


In [5]:
class Args:
    """Plain container for the settings used to locate and load a model.

    Attributes:
        model_type: Name of the model type.
        ckpt_path: Checkpoint path; falls back to ``model_type`` when ``None``
            is passed.
        version: Name of the checkpoint version.
        path_to_vocab: Path of the vocabulary file.
        num_workers: Worker count handed to the data module.
    """

    def __init__(
        self,
        model_type = 'avg_word_emb',
        ckpt_path = None,
        version = 'version_0',
        path_to_vocab = 'store/vocab.pkl',
        num_workers = 3,
    ):
        # Only a missing (None) checkpoint path is replaced by the model type.
        if ckpt_path is None:
            ckpt_path = model_type

        self.model_type = model_type
        self.ckpt_path = ckpt_path
        self.version = version
        self.path_to_vocab = path_to_vocab
        self.num_workers = num_workers

In [9]:
class NLIResults:
    """Runs a loaded model through the trainer's test and validation loops.

    Attributes:
        model: Model returned by ``load_model``.
        datamodule: ``NLIDataModule`` built from the vocabulary returned by
            ``load_model`` (batch size 64).
        trainer: ``pl.Trainer`` without a logger, on GPU when CUDA is
            available and on CPU otherwise.
    """

    def __init__(self, args):
        model, vocab = load_model(args.model_type, args.path_to_vocab, args.ckpt_path, args.version)
        self.model = model
        self.datamodule = NLIDataModule(vocab=vocab, batch_size=64, num_workers=args.num_workers)

        device_kind = 'gpu' if torch.cuda.is_available() else 'cpu'
        self.trainer = pl.Trainer(logger=False, accelerator=device_kind)

    def _first_metric(self, run_loop, metric_name):
        """Run a trainer loop quietly and pick ``metric_name`` from its first result dict."""
        outputs = run_loop(self.model, datamodule=self.datamodule, verbose=False)
        return outputs[0][metric_name]

    def test(self):
        """Return the ``'test_acc'`` value reported by ``trainer.test``."""
        return self._first_metric(self.trainer.test, 'test_acc')

    def validate(self):
        """Return the ``'val_acc'`` value reported by ``trainer.validate``."""
        return self._first_metric(self.trainer.validate, 'val_acc')

    def get_nli_accs(self):
        """Return test and validation accuracy, each multiplied by 100."""
        # The test loop runs before the validation loop.
        test_scaled = self.test() * 100.
        val_scaled = self.validate() * 100.
        return {'test': test_scaled, 'val': val_scaled}
    



GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Testing: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [15]:
# Model types under comparison; each one starts with its own empty accuracy dict.
model_types = ['avg_word_emb', 'uni_lstm', 'bi_lstm', 'max_pool_lstm']
results = dict((name, {}) for name in model_types)

In [18]:
# Write a placeholder 'nli' entry into every model's accuracy dict.
# The loop target is deliberately not called `accs`: a for-loop target at cell
# level is an ordinary notebook global, so reusing that name would rebind the
# `accs` dict that other cells build and display.
for model_accs in results.values():
    model_accs['nli'] = 'test'

print(results)

{'avg_word_emb': {'nli': 'test'}, 'uni_lstm': {'nli': 'test'}, 'bi_lstm': {'nli': 'test'}, 'max_pool_lstm': {'nli': 'test'}}


In [14]:
# Settings for the 'avg_word_emb' model type; every other Args field keeps its default.
args = Args(model_type = 'avg_word_emb')

# NLIResults loads the model and builds the data module and trainer.
nli_results = NLIResults(args)
# The set lists the tasks expected to carry an 'acc' entry in the results file;
# TransferResults asserts when the file disagrees.
transfer_avg_word_emb = TransferResults(args, {'MR', 'CR'})

# Micro/macro dev accuracy over the transfer tasks, then NLI test/val accuracy
# (the latter runs the trainer's test and validation loops).
transfer_accs = transfer_avg_word_emb.get_transfer_accs()
nli_accs = nli_results.get_nli_accs()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


KeyboardInterrupt: 

In [23]:
# Start from an empty accuracy dict.
accs = dict()

In [32]:
# Copy `accs` and layer the NLI accuracies on top (later keys win on a clash).
accs2 = dict(accs)
accs2.update(nli_accs)
print(accs2)

# Build a fresh dict that additionally holds the transfer accuracies, then display it.
merged_with_transfer = dict(accs2)
merged_with_transfer.update(transfer_accs)
accs2 = merged_with_transfer
del merged_with_transfer
accs2

{'test': 0.6547231078147888, 'val': 0.6559642553329468}


{'test': 0.6547231078147888,
 'val': 0.6559642553329468,
 'micro': 51.42129389762415,
 'macro': 52.435}

{'micro': 51.42129389762415, 'macro': 52.435}

In [12]:
# Merge the NLI and transfer accuracies into one dict (transfer values win on
# a key clash) and round every value to 1 decimal place.
accs = {
    name: round(value, 1)
    for name, value in {**nli_accs, **transfer_accs}.items()
}

In [13]:
# Display the current contents of `accs`.
accs

{'test': 0.7, 'val': 0.7, 'micro': 51.4, 'macro': 52.4}

In [19]:
# Look up the checkpoint for `args`; only the second return value (the
# version path) is kept, the first one is discarded.
_, version_path = find_checkpoint(args.ckpt_path, args.version)

In [20]:
# Display the path returned by find_checkpoint.
version_path

'logs/avg_word_emb/version_0'