In [1]:
import pandas as pd

## Dataset exploration

In [12]:
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.common import Params
import os
import logging
from scicite.dataset_readers.citation_data_reader_scicite import SciciteDatasetReader
from allennlp.data.instance import Instance

from typing import Dict, Iterable, Tuple

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
os.environ['SEED'] = '21016'
os.environ['PYTORCH_SEED'] = str(int(os.environ['SEED']) // 3)
os.environ['NUMPY_SEED'] = str(int(os.environ['PYTORCH_SEED']) // 3)
os.environ["elmo"] = "true"

In [13]:
parameter_filename = "./experiment_configs/scicite-experiment-0.05-0.05.json"
overrides = ""
params = Params.from_file(parameter_filename, overrides)

dataset_reader = DatasetReader.from_params(params.pop('dataset_reader'))


04/10/2024 22:23:04 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.dataset_readers.dataset_reader.DatasetReader'> from params {'multilabel': False, 'type': 'scicite_datasetreader', 'use_sparse_lexicon_features': False, 'with_elmo': 'true'} and extras {}
04/10/2024 22:23:04 - INFO - allennlp.common.params -   dataset_reader.type = scicite_datasetreader
04/10/2024 22:23:04 - INFO - allennlp.common.params -   dataset_reader.lazy = False
04/10/2024 22:23:04 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.tokenizers.tokenizer.Tokenizer'> from params {} and extras {}
04/10/2024 22:23:04 - INFO - allennlp.common.params -   dataset_reader.tokenizer.type = word
04/10/2024 22:23:04 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.tokenizers.word_tokenizer.WordTokenizer'> from params {} and extras {}
04/10/2024 22:23:04 - INFO - allennlp.common.params -   dataset_reader.tokenizer.start_tokens 

In [14]:
# params.as_dict()

In [15]:
validation_dataset_reader_params = params.pop("validation_dataset_reader", None)
validation_and_test_dataset_reader: DatasetReader = dataset_reader
if validation_dataset_reader_params is not None:
    logger.info("Using a separate dataset reader to load validation and test data.")
    validation_and_test_dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)

train_data_path = params.pop('train_data_path')
logger.info("Reading training data from %s", train_data_path)
train_data = dataset_reader.read(train_data_path)

datasets: Dict[str, Iterable[Instance]] = {"train": train_data}

04/10/2024 22:23:04 - INFO - allennlp.common.params -   validation_dataset_reader = None
04/10/2024 22:23:04 - INFO - allennlp.common.params -   train_data_path = scicite_data/train.jsonl
04/10/2024 22:23:04 - INFO - __main__ -   Reading training data from scicite_data/train.jsonl
8243it [00:02, 3274.03it/s]


In [16]:
train_data[0].fields

{'citation_text': <allennlp.data.fields.text_field.TextField at 0x7f13cc548ba8>,
 'labels': <allennlp.data.fields.label_field.LabelField at 0x7f13cc548978>,
 'year_diff': <allennlp.data.fields.array_field.ArrayField at 0x7f13cc548ac8>,
 'citing_paper_id': <allennlp.data.fields.metadata_field.MetadataField at 0x7f13cc548cf8>,
 'cited_paper_id': <allennlp.data.fields.metadata_field.MetadataField at 0x7f13cc548d30>,
 'citation_excerpt_index': <allennlp.data.fields.metadata_field.MetadataField at 0x7f13cc548e80>,
 'citation_id': <allennlp.data.fields.metadata_field.MetadataField at 0x7f13cc548da0>}

### Train_multitask_two_tasks 

In [17]:
"""
The `train_multitask` subcommand that can be used to train the model in the multitask fashion
It requires a configuration file and a directory in
which to write the results.
.. code-block:: bash
   $ allennlp train --help
   usage: allennlp train [-h] -s SERIALIZATION_DIR [-r] [-o OVERRIDES]
                         [--file-friendly-logging]
                         [--include-package INCLUDE_PACKAGE]
                         param_path
   Train the specified model on the specified dataset.
   positional arguments:
   param_path            path to parameter file describing the model to be
                           trained
   optional arguments:
   -h, --help            show this help message and exit
   -s SERIALIZATION_DIR, --serialization-dir SERIALIZATION_DIR
                           directory in which to save the model and its logs
   -r, --recover         recover training from the state in serialization_dir
   -o OVERRIDES, --overrides OVERRIDES
                           a JSON structure used to override the experiment
                           configuration
   --include-package INCLUDE_PACKAGE
                           additional packages to include
   --file-friendly-logging
                           outputs tqdm status on separate lines and slows tqdm
                           refresh rate
"""
import random
from typing import Dict, Iterable, Tuple
import argparse
import logging
import os
import re

import torch

from allennlp.commands.evaluate import evaluate
from allennlp.commands.subcommand import Subcommand
from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import Params
from allennlp.common.util import prepare_environment, prepare_global_logging, \
                                 get_frozen_and_tunable_parameter_names, dump_metrics
from allennlp.data import Vocabulary
from allennlp.data.instance import Instance
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.models.archival import archive_model, CONFIG_NAME
from allennlp.models.model import Model, _DEFAULT_WEIGHTS
from allennlp.training.trainer import Trainer

# from scicite.training.multitask_trainer_two_tasks import MultiTaskTrainer2
from scicite.training.vocabulary_multitask import VocabularyMultitask

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class TrainMultiTask2(Subcommand):
    """ Class for training the model with two scaffold tasks """
    def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser:
        # pylint: disable=protected-access
        description = '''Train the specified model on the specified dataset.'''
        subparser = parser.add_parser(name, description=description, help='Train a model')

        subparser.add_argument('param_path',
                               type=str,
                               help='path to parameter file describing the model to be trained')

        subparser.add_argument('-s', '--serialization-dir',
                               required=True,
                               type=str,
                               help='directory in which to save the model and its logs')

        subparser.add_argument('-r', '--recover',
                               action='store_true',
                               default=False,
                               help='recover training from the state in serialization_dir')

        subparser.add_argument('-o', '--overrides',
                               type=str,
                               default="",
                               help='a JSON structure used to override the experiment configuration')

        subparser.add_argument('--file-friendly-logging',
                               action='store_true',
                               default=False,
                               help='outputs tqdm status on separate lines and slows tqdm refresh rate')

        subparser.set_defaults(func=train_model_from_args)

        return subparser

def train_model_from_args(args: argparse.Namespace):
    """
    Just converts from an ``argparse.Namespace`` object to string paths.
    """
    train_model_from_file(args.param_path,
                          args.serialization_dir,
                          args.overrides,
                          args.file_friendly_logging,
                          args.recover)


def train_model_from_file(parameter_filename: str,
                          serialization_dir: str,
                          overrides: str = "",
                          file_friendly_logging: bool = False,
                          recover: bool = False) -> Model:
    """
    A wrapper around :func:`train_model` which loads the params from a file.

    Parameters
    ----------
    param_path : ``str``
        A json parameter file specifying an AllenNLP experiment.
    serialization_dir : ``str``
        The directory in which to save results and logs. We just pass this along to
        :func:`train_model`.
    overrides : ``str``
        A JSON string that we will use to override values in the input parameter file.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we make our output more friendly to saved model files.  We just pass this
        along to :func:`train_model`.
    recover : ``bool`, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory.  This is only intended for use when something actually crashed during the middle
        of a run.  For continuing training a model on new data, see the ``fine-tune`` command.
    """
    # Load the experiment config from a file and pass it to ``train_model``.
    params = Params.from_file(parameter_filename, overrides)
    return train_model(params, serialization_dir, file_friendly_logging, recover)


def datasets_from_params(params: Params) -> Tuple[Dict[str, Iterable[Instance]], Dict[str, Iterable[Instance]], Dict[str, Iterable[Instance]]]:
    """
    Load all the datasets specified by the config.
    This includes the main dataset and the two scaffold auxiliary datasets
    """
    dataset_reader = DatasetReader.from_params(params.pop('dataset_reader'))
    validation_dataset_reader_params = params.pop("validation_dataset_reader", None)

    validation_and_test_dataset_reader: DatasetReader = dataset_reader
    if validation_dataset_reader_params is not None:
        logger.info("Using a separate dataset reader to load validation and test data.")
        validation_and_test_dataset_reader = DatasetReader.from_params(validation_dataset_reader_params)

    train_data_path = params.pop('train_data_path')
    logger.info("Reading training data from %s", train_data_path)
    train_data = dataset_reader.read(train_data_path)

    datasets: Dict[str, Iterable[Instance]] = {"train": train_data}

    # 2. Auxillary training data.
    dataset_reader_aux = DatasetReader.from_params(params.pop('dataset_reader_aux'))
    train_data_path_aux = params.pop('train_data_path_aux')
    logger.info("Reading auxiliary training data from %s", train_data_path_aux)
    train_data_aux = dataset_reader_aux.read(train_data_path_aux)

    dataset_reader_aux2 = DatasetReader.from_params(params.pop('dataset_reader_aux2'))
    train_data_path_aux2 = params.pop('train_data_path_aux2')
    logger.info("Reading second auxiliary training data for from %s", train_data_path_aux2)
    train_data_aux2 = dataset_reader_aux2.read(train_data_path_aux2)

    # If only using a fraction of the auxiliary data.
    aux_sample_fraction = params.pop("aux_sample_fraction", 1.0)
    if aux_sample_fraction < 1.0:
        sample_size = int(aux_sample_fraction * len(train_data_aux))
        train_data_aux = random.sample(train_data_aux, sample_size)
        train_data_aux2 = random.sample(train_data_aux2, sample_size)

    # Balance the datasets by inflating the size of the smaller dataset to the size of the larger dataset.
    train_size = len(train_data)
    aux_train_size = len(train_data_aux)
    aux2_train_size = len(train_data_aux2)

    # Make second auxillary dataset the same size of the first auxiliary dataset
    if aux2_train_size > aux_train_size:
        train_data_aux2 = random.sample(train_data_aux2, aux_train_size)
    else:
        train_data_aux = random.sample(train_data_aux, aux2_train_size)

    # inflate training size to be as large as auxiliary training data
    if train_size > aux_train_size:
        difference = train_size - aux_train_size
        aux_sample = [random.choice(train_data_aux) for _ in range(difference)]
        train_data_aux = train_data_aux + aux_sample
        logger.info("Inflating auxiliary train data from {} to {} samples".format(
            aux_train_size, len(train_data_aux)))
    else:
        difference = aux_train_size - train_size
        train_sample = [random.choice(train_data) for _ in range(difference)]
        train_data = train_data + train_sample
        logger.info("Inflating train data from {} to {} samples".format(
            train_size, len(train_data)))

    datasets["train"] = train_data
    datasets_aux = {"train_aux": train_data_aux}
    datasets_aux2 = {"train_aux": train_data_aux2}

    validation_data_path = params.pop('validation_data_path', None)
    if validation_data_path is not None:
        logger.info("Reading validation data from %s", validation_data_path)
        validation_data = validation_and_test_dataset_reader.read(validation_data_path)
        datasets["validation"] = validation_data

    # Auxiliary validation data.
    validation_data_path_aux = params.pop('validation_data_path_aux', None)
    if validation_data_path_aux is not None:
        logger.info(f"Reading auxilliary validation data from {validation_data_path_aux}")
        validation_data_aux = dataset_reader_aux.read(validation_data_path_aux)
        datasets_aux["validation_aux"] = validation_data_aux
    else:
        validation_data_aux = None
    validation_data_path_aux2 = params.pop('validation_data_path_aux2', None)
    if validation_data_path_aux2 is not None:
        logger.info(f"Reading auxilliary validation data from {validation_data_path_aux2}")
        validation_data_aux2 = dataset_reader_aux2.read(validation_data_path_aux2)
        datasets_aux2["validation_aux"] = validation_data_aux2
    else:
        validation_data_aux2 = None

    test_data_path = params.pop("test_data_path", None)
    if test_data_path is not None:
        logger.info("Reading test data from %s", test_data_path)
        test_data = validation_and_test_dataset_reader.read(test_data_path)
        datasets["test"] = test_data

    # Auxillary test data
    test_data_path_aux = params.pop("test_data_path_aux", None)
    if test_data_path_aux is not None:
        logger.info(f"Reading auxiliary test data from {test_data_path_aux}")
        test_data_aux = dataset_reader_aux.read(test_data_path_aux)
        datasets_aux["test_aux"] = test_data_aux
    else:
        test_data_aux = None

    test_data_path_aux2 = params.pop("test_data_path_aux2", None)
    if test_data_path_aux2 is not None:
        logger.info(f"Reading auxillary test data from {test_data_path_aux2}")
        test_data_aux2 = dataset_reader_aux2.read(test_data_path_aux2)
        datasets_aux2["test_aux"] = test_data_aux2
    else:
        test_data_aux2 = None

    return datasets, datasets_aux, datasets_aux2

def create_serialization_dir(params: Params, serialization_dir: str, recover: bool) -> None:
    """
    This function creates the serialization directory if it doesn't exist.  If it already exists
    and is non-empty, then it verifies that we're recovering from a training with an identical configuration.

    Parameters
    ----------
    params: ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir: ``str``
        The directory in which to save results and logs.
    recover: ``bool``
        If ``True``, we will try to recover from an existing serialization directory, and crash if
        the directory doesn't exist, or doesn't match the configuration we're given.
    """
    if os.path.exists(serialization_dir) and os.listdir(serialization_dir):
        if not recover:
            raise ConfigurationError(f"Serialization directory ({serialization_dir}) already exists and is "
                                     f"not empty. Specify --recover to recover training from existing output.")

        logger.info(f"Recovering from prior training at {serialization_dir}.")

        recovered_config_file = os.path.join(serialization_dir, CONFIG_NAME)
        if not os.path.exists(recovered_config_file):
            raise ConfigurationError("The serialization directory already exists but doesn't "
                                     "contain a config.json. You probably gave the wrong directory.")
        else:
            loaded_params = Params.from_file(recovered_config_file)

            # Check whether any of the training configuration differs from the configuration we are
            # resuming.  If so, warn the user that training may fail.
            fail = False
            flat_params = params.as_flat_dict()
            flat_loaded = loaded_params.as_flat_dict()
            for key in flat_params.keys() - flat_loaded.keys():
                logger.error(f"Key '{key}' found in training configuration but not in the serialization "
                             f"directory we're recovering from.")
                fail = True
            for key in flat_loaded.keys() - flat_params.keys():
                logger.error(f"Key '{key}' found in the serialization directory we're recovering from "
                             f"but not in the training config.")
                fail = True
            for key in flat_params.keys():
                if flat_params.get(key, None) != flat_loaded.get(key, None):
                    logger.error(f"Value for '{key}' in training configuration does not match that the value in "
                                 f"the serialization directory we're recovering from: "
                                 f"{flat_params[key]} != {flat_loaded[key]}")
                    fail = True
            if fail:
                raise ConfigurationError("Training configuration does not match the configuration we're "
                                         "recovering from.")
    else:
        if recover:
            raise ConfigurationError(f"--recover specified but serialization_dir ({serialization_dir}) "
                                     "does not exist.  There is nothing to recover from.")
        os.makedirs(serialization_dir, exist_ok=True)


def train_model(params: Params,
                serialization_dir: str,
                file_friendly_logging: bool = False,
                recover: bool = False) -> Model:
    """
    Trains the model specified in the given :class:`Params` object, using the data and training
    parameters also specified in that object, and saves the results in ``serialization_dir``.

    Parameters
    ----------
    params : ``Params``
        A parameter object specifying an AllenNLP Experiment.
    serialization_dir : ``str``
        The directory in which to save results and logs.
    file_friendly_logging : ``bool``, optional (default=False)
        If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
        down tqdm's output to only once every 10 seconds.
    recover : ``bool``, optional (default=False)
        If ``True``, we will try to recover a training run from an existing serialization
        directory.  This is only intended for use when something actually crashed during the middle
        of a run.  For continuing training a model on new data, see the ``fine-tune`` command.

    Returns
    -------
    best_model: ``Model``
        The model with the best epoch weights.
    """
    prepare_environment(params)

    create_serialization_dir(params, serialization_dir, recover)
    prepare_global_logging(serialization_dir, file_friendly_logging)

    check_for_gpu(params.get('trainer').get('cuda_device', -1))

    params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

    all_datasets, all_datasets_aux, all_datasets_aux2 = datasets_from_params(params)
    datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
    datasets_for_vocab_creation_aux = set(params.pop("auxiliary_datasets_for_vocab_creation", all_datasets_aux))
    datasets_for_vocab_creation_aux2 = set(params.pop("auxiliary_datasets_for_vocab_creation_2", all_datasets_aux2))


    mixing_ratio = params.pop_float("mixing_ratio")
    mixing_ratio2 = params.pop_float("mixing_ratio2")

    cutoff_epoch = params.pop("cutoff_epoch", -1)

    for dataset in datasets_for_vocab_creation:
        if dataset not in all_datasets:
            raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

    logger.info("From dataset instances, %s will be considered for vocabulary creation.",
                ", ".join(datasets_for_vocab_creation))
    vocab_instances_aux = [
        instance for key, dataset in all_datasets_aux.items()
        for instance in dataset
        if key in datasets_for_vocab_creation_aux
    ]
    vocab_instances_aux.extend([
        instance for key, dataset in all_datasets_aux2.items()
        for instance in dataset
        if key in datasets_for_vocab_creation_aux2
    ])
    vocab = VocabularyMultitask.from_params(
            params.pop("vocabulary", {}),
            (instance for key, dataset in all_datasets.items()
             for instance in dataset
             if key in datasets_for_vocab_creation),
            instances_aux=vocab_instances_aux
    )
    model = Model.from_params(vocab=vocab, params=params.pop('model'))

    # Initializing the model can have side effect of expanding the vocabulary
    vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))
    
    iterator = DataIterator.from_params(params.pop("iterator"))
    iterator.index_with(vocab)

    iterator_aux = DataIterator.from_params(params.pop("iterator_aux"))
    iterator_aux.index_with(vocab)

    iterator_aux2 = DataIterator.from_params(params.pop("iterator_aux2"))
    iterator_aux2.index_with(vocab)

    validation_iterator_params = params.pop("validation_iterator", None)
    if validation_iterator_params:
        validation_iterator = DataIterator.from_params(validation_iterator_params)
        validation_iterator.index_with(vocab)
    else:
        validation_iterator = None

    # TODO: if validation in multi-task need to add validation iterator as above

    train_data = all_datasets.get('train')
    validation_data = all_datasets.get('validation')
    test_data = all_datasets.get('test')

    train_data_aux = all_datasets_aux.get('train_aux')
    validation_data_aux = all_datasets_aux.get('validation_aux')
    test_data_aux = all_datasets_aux.get('test_aux')

    train_data_aux2 = all_datasets_aux2.get('train_aux')
    validation_data_aux2 = all_datasets_aux2.get('validation_aux')
    test_data_aux2 = all_datasets_aux2.get('test_aux')

    trainer_params = params.pop("trainer")
    no_grad_regexes = trainer_params.pop("no_grad", ())
    for name, parameter in model.named_parameters():
        if any(re.search(regex, name) for regex in no_grad_regexes):
            parameter.requires_grad_(False)

    frozen_parameter_names, tunable_parameter_names = \
                   get_frozen_and_tunable_parameter_names(model)
    logger.info("Following parameters are Frozen  (without gradient):")
    for name in frozen_parameter_names:
        logger.info(name)
    logger.info("Following parameters are Tunable (with gradient):")
    for name in tunable_parameter_names:
        logger.info(name)

    trainer = MultiTaskTrainer2.from_params(model=model,
                                            serialization_dir=serialization_dir,
                                            iterator=iterator,
                                            iterator_aux=iterator_aux,
                                            iterator_aux2=iterator_aux2,
                                            train_data=train_data,
                                            train_data_aux=train_data_aux,
                                            train_data_aux2=train_data_aux2,
                                            mixing_ratio=mixing_ratio,
                                            mixing_ratio2=mixing_ratio2,
                                            cutoff_epoch=cutoff_epoch,
                                            validation_data_aux=validation_data_aux,
                                            validation_data_aux2=validation_data_aux2,
                                            validation_data=validation_data,
                                            params=trainer_params,
                                            validation_iterator=validation_iterator)
    print(trainer._cuda_devices[0])
    evaluate_on_test = params.pop_bool("evaluate_on_test", False)
    evaluate_aux_on_test = params.pop_bool("evaluate_aux_on_test", False)
    params.assert_empty('base train command')

    try:
        metrics = trainer.train()
    except KeyboardInterrupt:
        # if we have completed an epoch, try to create a model archive.
        if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
            logging.info("Training interrupted by the user. Attempting to create "
                         "a model archive using the current best epoch weights.")
            archive_model(serialization_dir, files_to_archive=params.files_to_archive)
        raise

    # Now tar up results
    archive_model(serialization_dir, files_to_archive=params.files_to_archive)

    logger.info("Loading the best epoch weights.")
    best_model_state_path = os.path.join(serialization_dir, 'best.th')
    best_model_state = torch.load(best_model_state_path)
    best_model = model
    best_model.load_state_dict(best_model_state)

    if test_data and evaluate_on_test:
        logger.info("The model will be evaluated using the best epoch weights.")
        test_metrics = evaluate(
                best_model, test_data, validation_iterator or iterator,
                cuda_device=trainer._cuda_devices[0] # pylint: disable=protected-access
        )
        for key, value in test_metrics.items():
            metrics["test_" + key] = value

    elif test_data:
        logger.info("To evaluate on the test set after training, pass the "
                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

    if test_data_aux and evaluate_aux_on_test:
        # for instance in test_data_aux:
        #     instance.index_fields(vocab)
        # for instance in test_data_aux2:
        #     instance.index_fields(vocab)
        test_metrics_aux = evaluate(best_model, test_data_aux, iterator_aux,
                                    cuda_device=trainer._cuda_devices[0])  # pylint: disable=protected-access
        test_metrics_aux2 = evaluate(best_model, test_data_aux2, iterator_aux2,
                                     cuda_device=trainer._cuda_devices[0])  # pylint: disable=protected-access

        for key, value in test_metrics_aux.items():
            metrics["test_aux_" + key] = value
        for key, value in test_metrics_aux2.items():
            metrics["test_aux2_" + key] = value

    elif test_data_aux:
        logger.info("To evaluate on the auxiliary test set after training, pass the "
                    "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

    dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True)

    return best_model


### MultiTaskTrainer2

In [18]:
"""
This module is an extended trainer based on the allennlp's default trainer to handle multitask training
    for two auxiliary tasks

A :class:`~allennlp.training.trainer.Trainer` is responsible for training a
:class:`~allennlp.models.model.Model`.

Typically you might create a configuration file specifying the model and
training parameters and then use :mod:`~allennlp.commands.train`
rather than instantiating a ``Trainer`` yourself.
"""
# pylint: disable=too-many-lines

import logging
import os
import shutil
import time
import re
import datetime
import traceback
from typing import Dict, Optional, List, Tuple, Union, Iterable, Any, Set

import torch
import torch.optim.lr_scheduler
from torch.nn.parallel import replicate, parallel_apply
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from tensorboardX import SummaryWriter

from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import peak_memory_mb, gpu_memory_mb, dump_metrics
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.models.model import Model
from allennlp.nn import util
from allennlp.training.learning_rate_schedulers import LearningRateScheduler
from allennlp.training.optimizers import Optimizer

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def is_sparse(tensor):
    return tensor.is_sparse


def sparse_clip_norm(parameters, max_norm, norm_type=2) -> float:
    """Clips gradient norm of an iterable of parameters.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    Supports sparse gradients.

    Parameters
    ----------
    parameters : ``(Iterable[torch.Tensor])``
        An iterable of Tensors that will have gradients normalized.
    max_norm : ``float``
        The max norm of the gradients.
    norm_type : ``float``
        The type of the used p-norm. Can be ``'inf'`` for infinity norm.

    Returns
    -------
    Total norm of the parameters (viewed as a single vector).
    """
    # pylint: disable=invalid-name,protected-access
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if norm_type == float('inf'):
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = 0
        for p in parameters:
            if is_sparse(p.grad):
                # need to coalesce the repeated indices before finding norm
                grad = p.grad.data.coalesce()
                param_norm = grad._values().norm(norm_type)
            else:
                param_norm = p.grad.data.norm(norm_type)
            total_norm += param_norm ** norm_type
        total_norm = total_norm ** (1. / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            if is_sparse(p.grad):
                p.grad.data._values().mul_(clip_coef)
            else:
                p.grad.data.mul_(clip_coef)
    return total_norm


def move_optimizer_to_cuda(optimizer):
    """
    Move the optimizer state to GPU, if necessary.
    After calling, any parameter specific state in the optimizer
    will be located on the same device as the parameter.
    """
    for param_group in optimizer.param_groups:
        for param in param_group['params']:
            if param.is_cuda:
                param_state = optimizer.state[param]
                for k in param_state.keys():
                    if isinstance(param_state[k], torch.Tensor):
                        param_state[k] = param_state[k].cuda(device=param.get_device())


class TensorboardWriter:
    """
    Wraps a pair of ``SummaryWriter`` instances but is a no-op if they're ``None``.
    Allows Tensorboard logging without always checking for Nones first.
    """
    def __init__(self, train_log: SummaryWriter = None, validation_log: SummaryWriter = None) -> None:
        self._train_log = train_log
        self._validation_log = validation_log

    @staticmethod
    def _item(value: Any):
        if hasattr(value, 'item'):
            val = value.item()
        else:
            val = value
        return val

    def add_train_scalar(self, name: str, value: float, global_step: int) -> None:
        # get the scalar
        if self._train_log is not None:
            self._train_log.add_scalar(name, self._item(value), global_step)

    def add_train_histogram(self, name: str, values: torch.Tensor, global_step: int) -> None:
        if self._train_log is not None:
            if isinstance(values, torch.Tensor):
                values_to_write = values.cpu().data.numpy().flatten()
                self._train_log.add_histogram(name, values_to_write, global_step)

    def add_validation_scalar(self, name: str, value: float, global_step: int) -> None:

        if self._validation_log is not None:
            self._validation_log.add_scalar(name, self._item(value), global_step)


def time_to_str(timestamp: int) -> str:
    """
    Convert seconds past Epoch to human readable string.
    """
    datetimestamp = datetime.datetime.fromtimestamp(timestamp)
    return '{:04d}-{:02d}-{:02d}-{:02d}-{:02d}-{:02d}'.format(
            datetimestamp.year, datetimestamp.month, datetimestamp.day,
            datetimestamp.hour, datetimestamp.minute, datetimestamp.second
    )


def str_to_time(time_str: str) -> datetime.datetime:
    """
    Convert human readable string to datetime.datetime.
    """
    pieces: Any = [int(piece) for piece in time_str.split('-')]
    return datetime.datetime(*pieces)


class MultiTaskTrainer2:
    def __init__(self,
                 model: Model,
                 optimizer: torch.optim.Optimizer,
                 iterator: DataIterator,
                 train_dataset: Iterable[Instance],
                 train_dataset_aux: Iterable[Instance],
                 train_dataset_aux2: Optional[Iterable[Instance]],
                 mixing_ratio: float = 0.17,
                 mixing_ratio2: float = 0.17,
                 cutoff_epoch: int = -1,
                 validation_dataset: Optional[Iterable[Instance]] = None,
                 validation_dataset_aux: Optional[Iterable] = None,
                 validation_dataset_aux2: Optional[Iterable[Instance]] = None,
                 patience: Optional[int] = None,
                 validation_metric: str = "-loss",
                 validation_iterator: DataIterator = None,
                 shuffle: bool = True,
                 num_epochs: int = 20,
                 serialization_dir: Optional[str] = None,
                 num_serialized_models_to_keep: int = 20,
                 keep_serialized_model_every_num_seconds: int = None,
                 model_save_interval: float = None,
                 cuda_device: Union[int, List] = -1,
                 grad_norm: Optional[float] = None,
                 grad_clipping: Optional[float] = None,
                 learning_rate_scheduler: Optional[LearningRateScheduler] = None,
                 summary_interval: int = 100,
                 histogram_interval: int = None,
                 should_log_parameter_statistics: bool = True,
                 should_log_learning_rate: bool = False,
                 iterator_aux: Optional[DataIterator] = None,
                 iterator_aux2: Optional[DataIterator] = None) -> None:
        """
        Parameters
        ----------
        model : ``Model``, required.
            An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
            their ``forward`` method returns a dictionary with a "loss" key, containing a
            scalar tensor representing the loss function to be optimized.
        optimizer : ``torch.nn.Optimizer``, required.
            An instance of a Pytorch Optimizer, instantiated with the parameters of the
            model to be optimized.
        iterator : ``DataIterator``, required.
            A method for iterating over a ``Dataset``, yielding padded indexed batches.
        train_dataset : ``Dataset``, required.
            A ``Dataset`` to train on. The dataset should have already been indexed.
        train_dataset_aux : ``Dataset``, required.
            A ``Dataset`` for auxiliary task 1 to train on.
        train_dataset_aux2 : ``Dataset``, required.
            A ``Dataset`` for second auxiliary task to train on. The dataset should have already been indexed.
        mixing_ratio: a float specifying the influence of the first auxiliary task on the final loss
        mixing_ratio2: a float specifying the influence of the second auxiliary task on the final loss
        cutoff_epoch: multitask training starts from the epoch after the epoch specified by cutoff_epoch
        validation_dataset : ``Dataset``, optional, (default = None).
            A ``Dataset`` to evaluate on. The dataset should have already been indexed.
        validation_dataset_aux : a validation dataset for the first auxiliary task
        validation_dataset_aux_2 : a validation dataset for the second auxiliary task
        patience : Optional[int] > 0, optional (default=None)
            Number of epochs to be patient before early stopping: the training is stopped
            after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
            If None, early stopping is disabled.
        validation_metric : str, optional (default="loss")
            Validation metric to measure for whether to stop training using patience
            and whether to serialize an ``is_best`` model each epoch. The metric name
            must be prepended with either "+" or "-", which specifies whether the metric
            is an increasing or decreasing function.
        validation_iterator : ``DataIterator``, optional (default=None)
            An iterator to use for the validation set.  If ``None``, then
            use the training `iterator`.
        shuffle: ``bool``, optional (default=True)
            Whether to shuffle the instances in the iterator or not.
        num_epochs : int, optional (default = 20)
            Number of training epochs.
        serialization_dir : str, optional (default=None)
            Path to directory for saving and loading model files. Models will not be saved if
            this parameter is not passed.
        num_serialized_models_to_keep : ``int``, optional (default=20)
            Number of previous model checkpoints to retain.  Default is to keep 20 checkpoints.
            A value of None or -1 means all checkpoints will be kept.
        keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
            If num_serialized_models_to_keep is not None, then occasionally it's useful to
            save models at a given interval in addition to the last num_serialized_models_to_keep.
            To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
            between permanently saved checkpoints.  Note that this option is only used if
            num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
        model_save_interval : ``float``, optional (default=None)
            If provided, then serialize models every ``model_save_interval``
            seconds within single epochs.  In all cases, models are also saved
            at the end of every epoch if ``serialization_dir`` is provided.
        cuda_device : ``int``, optional (default = -1)
            An integer specifying the CUDA device to use. If -1, the CPU is used.
        grad_norm : ``float``, optional, (default = None).
            If provided, gradient norms will be rescaled to have a maximum of this value.
        grad_clipping : ``float``, optional (default = ``None``).
            If provided, gradients will be clipped `during the backward pass` to have an (absolute)
            maximum of this value.  If you are getting ``NaNs`` in your gradients during training
            that are not solved by using ``grad_norm``, you may need this.
        learning_rate_scheduler : ``PytorchLRScheduler``, optional, (default = None)
            A Pytorch learning rate scheduler. The learning rate will be decayed with respect to
            this schedule at the end of each epoch. If you use
            :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`, this will use the ``validation_metric``
            provided to determine if learning has plateaued.  To support updating the learning
            rate on every batch, this can optionally implement ``step_batch(batch_num_total)`` which
            updates the learning rate given the batch number.
        summary_interval: ``int``, optional, (default = 100)
            Number of batches between logging scalars to tensorboard
        histogram_interval : ``int``, optional, (default = ``None``)
            If not None, then log histograms to tensorboard every ``histogram_interval`` batches.
            When this parameter is specified, the following additional logging is enabled:
                * Histograms of model parameters
                * The ratio of parameter update norm to parameter norm
                * Histogram of layer activations
            We log histograms of the parameters returned by
            ``model.get_parameters_for_histogram_tensorboard_logging``.
            The layer activations are logged for any modules in the ``Model`` that have
            the attribute ``should_log_activations`` set to ``True``.  Logging
            histograms requires a number of GPU-CPU copies during training and is typically
            slow, so we recommend logging histograms relatively infrequently.
            Note: only Modules that return tensors, tuples of tensors or dicts
            with tensors as values currently support activation logging.
        should_log_parameter_statistics : ``bool``, optional, (default = True)
            Whether to send parameter statistics (mean and standard deviation
            of parameters and gradients) to tensorboard.
        should_log_learning_rate : ``bool``, optional, (default = False)
            Whether to send parameter specific learning rate to tensorboard.
        iterator_aux : ``DataIterator``, required.
            A method for iterating over a ``Dataset`` for the first auxiliary task, yielding padded indexed batches.
        iterator_aux2 : ``DataIterator``, required.
            A method for iterating over a ``Dataset`` for the second auxiliary task, yielding padded indexed batches.
        """
        self._model = model
        self._iterator = iterator
        self._validation_iterator = validation_iterator
        self._shuffle = shuffle
        self._optimizer = optimizer
        self._train_data = train_dataset
        self._validation_data = validation_dataset
        self._train_dataset_aux = train_dataset_aux
        self._train_dataset_aux2 = train_dataset_aux2
        self._validation_data_aux = validation_dataset_aux
        self._validation_data_aux2 = validation_dataset_aux2

        self._cutoff_epoch = cutoff_epoch
        self._mixing_ratio = mixing_ratio
        self._mixing_ratio2 = mixing_ratio2
        self._iterator_aux = iterator_aux
        self._iterator_aux2 = iterator_aux2

        if patience is None:  # no early stopping
            if validation_dataset:
                logger.warning('You provided a validation dataset but patience was set to None, '
                               'meaning that early stopping is disabled')
        elif (not isinstance(patience, int)) or patience <= 0:
            raise ConfigurationError('{} is an invalid value for "patience": it must be a positive integer '
                                     'or None (if you want to disable early stopping)'.format(patience))
        self._patience = patience
        self._num_epochs = num_epochs

        self._serialization_dir = serialization_dir
        self._num_serialized_models_to_keep = num_serialized_models_to_keep
        self._keep_serialized_model_every_num_seconds = keep_serialized_model_every_num_seconds
        self._serialized_paths: List[Any] = []
        self._last_permanent_saved_checkpoint_time = time.time()
        self._model_save_interval = model_save_interval

        self._grad_norm = grad_norm
        self._grad_clipping = grad_clipping
        self._learning_rate_scheduler = learning_rate_scheduler

        increase_or_decrease = validation_metric[0]
        if increase_or_decrease not in ["+", "-"]:
            raise ConfigurationError("Validation metrics must specify whether they should increase "
                                     "or decrease by pre-pending the metric name with a +/-.")
        self._validation_metric = validation_metric[1:]
        self._validation_metric_decreases = increase_or_decrease == "-"

        if not isinstance(cuda_device, int) and not isinstance(cuda_device, list):
            raise ConfigurationError("Expected an int or list for cuda_device, got {}".format(cuda_device))

        if isinstance(cuda_device, list):
            logger.warning(f"Multiple GPU support is experimental not recommended for use. "
                           "In some cases it may lead to incorrect results or undefined behavior.")
            self._multiple_gpu = True
            self._cuda_devices = cuda_device
        else:
            self._multiple_gpu = False
            self._cuda_devices = [cuda_device]

        if self._cuda_devices[0] != -1:
            self._model = self._model.cuda(self._cuda_devices[0])

        self._cuda_device = self._cuda_devices[0]

        self._log_interval = 10  # seconds
        self._summary_interval = summary_interval
        self._histogram_interval = histogram_interval
        self._log_histograms_this_batch = False
        self._should_log_parameter_statistics = should_log_parameter_statistics
        self._should_log_learning_rate = should_log_learning_rate

        # We keep the total batch number as a class variable because it
        # is used inside a closure for the hook which logs activations in
        # ``_enable_activation_logging``.
        self._batch_num_total = 0

        self._last_log = 0.0  # time of last logging

        if serialization_dir is not None:
            train_log = SummaryWriter(os.path.join(serialization_dir, "log", "train"))
            validation_log = SummaryWriter(os.path.join(serialization_dir, "log", "validation"))
            self._tensorboard = TensorboardWriter(train_log, validation_log)
        else:
            self._tensorboard = TensorboardWriter()
        self._warned_tqdm_ignores_underscores = False

    def _enable_gradient_clipping(self) -> None:
        if self._grad_clipping is not None:
            # Pylint is unable to tell that we're in the case that _grad_clipping is not None...
            # pylint: disable=invalid-unary-operand-type
            clip_function = lambda grad: grad.clamp(-self._grad_clipping, self._grad_clipping)
            for parameter in self._model.parameters():
                if parameter.requires_grad:
                    parameter.register_hook(clip_function)

    def _enable_activation_logging(self) -> None:
        """
        Log activations to tensorboard
        """
        if self._histogram_interval is not None:
            # To log activation histograms to the forward pass, we register
            # a hook on forward to capture the output tensors.
            # This uses a closure on self._log_histograms_this_batch to
            # determine whether to send the activations to tensorboard,
            # since we don't want them on every call.
            for _, module in self._model.named_modules():
                if not getattr(module, 'should_log_activations', False):
                    # skip it
                    continue

                def hook(module_, inputs, outputs):
                    # pylint: disable=unused-argument,cell-var-from-loop
                    log_prefix = 'activation_histogram/{0}'.format(module_.__class__)
                    if self._log_histograms_this_batch:
                        if isinstance(outputs, torch.Tensor):
                            log_name = log_prefix
                            self._tensorboard.add_train_histogram(log_name,
                                                                  outputs,
                                                                  self._batch_num_total)
                        elif isinstance(outputs, (list, tuple)):
                            for i, output in enumerate(outputs):
                                log_name = "{0}_{1}".format(log_prefix, i)
                                self._tensorboard.add_train_histogram(log_name,
                                                                      output,
                                                                      self._batch_num_total)
                        elif isinstance(outputs, dict):
                            for k, tensor in outputs.items():
                                log_name = "{0}_{1}".format(log_prefix, k)
                                self._tensorboard.add_train_histogram(log_name,
                                                                      tensor,
                                                                      self._batch_num_total)
                        else:
                            # skip it
                            pass

                module.register_forward_hook(hook)

    def _rescale_gradients(self) -> Optional[float]:
        """
        Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled.
        """
        if self._grad_norm:
            parameters_to_clip = [p for p in self._model.parameters()
                                  if p.grad is not None]
            return sparse_clip_norm(parameters_to_clip, self._grad_norm)
        return None

    def _data_parallel(self, batch):
        """
        Do the forward pass using multiple GPUs.  This is a simplification
        of torch.nn.parallel.data_parallel to support the allennlp model
        interface.
        """
        inputs, module_kwargs = scatter_kwargs((), batch, self._cuda_devices, 0)
        used_device_ids = self._cuda_devices[:len(inputs)]
        replicas = replicate(self._model, used_device_ids)
        outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)

        # Only the 'loss' is needed.
        # a (num_gpu, ) tensor with loss on each GPU
        losses = gather([output['loss'].unsqueeze(0) for output in outputs], used_device_ids[0], 0)
        return {'loss': losses.mean()}

    def _batch_loss(self, batch: torch.Tensor,
                    for_training: bool,
                    batch_aux: torch.Tensor=None,
                    batch_aux2: torch.Tensor=None) -> torch.Tensor:
        """
        Does a forward pass on the given batch and auxiliary data batches and returns the ``loss`` value in the result.
        If ``for_training`` is `True` also applies regularization penalty.
        """
        if self._multiple_gpu:
            output_dict = self._data_parallel(batch)
            if batch_aux is not None:
                raise ConfigurationError('multi-gpu not supported for multi-task training.')
        else:
            batch = util.move_to_device(batch, self._cuda_devices[0])
            print(batch)
            output_dict = self._model(**batch)

        try:
            loss = output_dict["loss"]
            if for_training:
                loss += self._model.get_regularization_penalty()
        except KeyError:
            if for_training:
                raise RuntimeError("The model you are trying to optimize does not contain a"
                                   " 'loss' key in the output of model.forward(inputs).")
            loss = None

        if batch_aux is not None and batch_aux2 is not None:
            batch_aux = util.move_to_device(batch_aux, self._cuda_devices[0])
            batch_aux2 = util.move_to_device(batch_aux2, self._cuda_devices[0])
            output_dict_aux = self._model(**batch_aux)
            output_dict_aux2 = self._model(**batch_aux2)
            try:
                loss_aux = output_dict_aux["loss"]
                loss_aux2 = output_dict_aux2["loss"]
                if for_training:
                    loss_aux += self._model.get_regularization_penalty()
                    loss_aux2 += self._model.get_regularization_penalty()
            except KeyError:
                raise ConfigurationError("The auxiliary model you are trying to optimize does not contain a"
                                         " 'loss' key in the output of model.forward(inputs).")

            # multi-task loss
            loss = loss + self._mixing_ratio * loss_aux + self._mixing_ratio2 * loss_aux2
        return loss

    def _get_metrics(self, total_loss: float, num_batches: int, reset: bool = False) -> Dict[str, float]:
        """
        Gets the metrics but sets ``"loss"`` to
        the total loss divided by the ``num_batches`` so that
        the ``"loss"`` metric is "average loss per batch".
        """
        metrics = self._model.get_metrics(reset=reset)
        metrics["loss"] = float(total_loss / num_batches) if num_batches > 0 else 0.0
        return metrics

    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        logger.info(f"Peak CPU memory usage MB: {peak_memory_mb()}")
        for gpu, memory in gpu_memory_mb().items():
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self._model.train()

        # Get tqdm for the training batches
        train_generator = self._iterator(self._train_data,
                                         num_epochs=1,
                                         shuffle=self._shuffle)
        train_generator_aux = self._iterator_aux(self._train_dataset_aux,
                                                 num_epochs=1,
                                                 shuffle=self._shuffle)
        train_generator_aux2 = self._iterator_aux2(self._train_dataset_aux2,
                                                  num_epochs=1,
                                                  shuffle=self._shuffle)

        multitask_training = False
        if epoch > self._cutoff_epoch:
            multitask_training = True
            logger.info("Multitask Training")
        else:
            logger.info("Training")

        num_training_batches = self._iterator.get_num_batches(self._train_data)
        num_training_batches_aux = self._iterator_aux.get_num_batches(self._train_dataset_aux)
        num_training_batches_aux2 = self._iterator_aux2.get_num_batches(self._train_dataset_aux2)
        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        if self._histogram_interval is not None:
            histogram_parameters = set(self._model.get_parameters_for_histogram_tensorboard_logging())

        logger.info("Training")
        train_generator_tqdm = Tqdm.tqdm(train_generator,
                                         total=num_training_batches)
        # train_aux_generator_tqdm = Tqdm.tqdm(train_generator_aux,
        #                                      total=num_training_batches_aux)
        for batch, batch_aux, batch_aux2 in zip(train_generator_tqdm, train_generator_aux, train_generator_aux2):
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self._log_histograms_this_batch = self._histogram_interval is not None and (
                    batch_num_total % self._histogram_interval == 0)

            self._optimizer.zero_grad()

            if multitask_training:
                loss = self._batch_loss(batch,
                                        for_training=True,
                                        batch_aux=batch_aux,
                                        batch_aux2=batch_aux2)
            else:
                loss = self._batch_loss(batch, for_training=True)

            loss.backward()

            train_loss += loss.item()

            batch_grad_norm = self._rescale_gradients()

            # This does nothing if batch_num_total is None or you are using an
            # LRScheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)

            if self._log_histograms_this_batch:
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {name: param.detach().cpu().clone()
                                 for name, param in self._model.named_parameters()}
                self._optimizer.step()
                for name, param in self._model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1, ))
                    param_norm = torch.norm(param.view(-1, )).cpu()
                    self._tensorboard.add_train_scalar("gradient_update/" + name,
                                                       update_norm / (param_norm + 1e-7),
                                                       batch_num_total)
            else:
                self._optimizer.step()

            # Update the description with the latest metrics
            metrics = self._get_metrics(train_loss, batches_this_epoch)
            description = self._description_from_metrics(metrics)

            train_generator_tqdm.set_description(description, refresh=False)

            # Log parameter values to Tensorboard
            if batch_num_total % self._summary_interval == 0:
                if self._should_log_parameter_statistics:
                    self._parameter_and_gradient_statistics_to_tensorboard(batch_num_total, batch_grad_norm)
                if self._should_log_learning_rate:
                    self._learning_rates_to_tensorboard(batch_num_total)
                self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"], batch_num_total)
                self._metrics_to_tensorboard(batch_num_total,
                                             {"epoch_metrics/" + k: v for k, v in metrics.items()})

            if self._log_histograms_this_batch:
                self._histograms_to_tensorboard(batch_num_total, histogram_parameters)

            # Save model if needed.
            if self._model_save_interval is not None and (
                    time.time() - last_save_time > self._model_save_interval
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                        '{0}.{1}'.format(epoch, time_to_str(int(last_save_time))), [], is_best=False
                )

        return self._get_metrics(train_loss, batches_this_epoch, reset=True)

    def _should_stop_early(self, metric_history: List[float]) -> bool:
        """
        uses patience and the validation metric to determine if training should stop early
        """
        if self._patience and self._patience < len(metric_history):
            # Pylint can't figure out that in this branch `self._patience` is an int.
            # pylint: disable=invalid-unary-operand-type

            # Is the best score in the past N epochs worse than or equal the best score overall?
            if self._validation_metric_decreases:
                return min(metric_history[-self._patience:]) >= min(metric_history[:-self._patience])
            else:
                return max(metric_history[-self._patience:]) <= max(metric_history[:-self._patience])

        return False

    def _parameter_and_gradient_statistics_to_tensorboard(self, # pylint: disable=invalid-name
                                                          epoch: int,
                                                          batch_grad_norm: float) -> None:
        """
        Send the mean and std of all parameters and gradients to tensorboard, as well
        as logging the average gradient norm.
        """
        # Log parameter values to Tensorboard
        for name, param in self._model.named_parameters():
            self._tensorboard.add_train_scalar("parameter_mean/" + name,
                                               param.data.mean(),
                                               epoch)
            self._tensorboard.add_train_scalar("parameter_std/" + name, param.data.std(), epoch)
            if param.grad is not None:
                if is_sparse(param.grad):
                    # pylint: disable=protected-access
                    grad_data = param.grad.data._values()
                else:
                    grad_data = param.grad.data

                # skip empty gradients
                if torch.prod(torch.tensor(grad_data.shape)).item() > 0: # pylint: disable=not-callable
                    self._tensorboard.add_train_scalar("gradient_mean/" + name,
                                                       grad_data.mean(),
                                                       epoch)
                    self._tensorboard.add_train_scalar("gradient_std/" + name,
                                                       grad_data.std(),
                                                       epoch)
                else:
                    # no gradient for a parameter with sparse gradients
                    logger.info("No gradient for %s, skipping tensorboard logging.", name)
        # norm of gradients
        if batch_grad_norm is not None:
            self._tensorboard.add_train_scalar("gradient_norm",
                                               batch_grad_norm,
                                               epoch)

    def _learning_rates_to_tensorboard(self, batch_num_total: int):
        """
        Send current parameter specific learning rates to tensorboard
        """
        # optimizer stores lr info keyed by parameter tensor
        # we want to log with parameter name
        names = {param: name for name, param in self._model.named_parameters()}
        for group in self._optimizer.param_groups:
            if 'lr' not in group:
                continue
            rate = group['lr']
            for param in group['params']:
                # check whether params has requires grad or not
                effective_rate = rate * float(param.requires_grad)
                self._tensorboard.add_train_scalar(
                        "learning_rate/" + names[param],
                        effective_rate,
                        batch_num_total
                )

    def _histograms_to_tensorboard(self, epoch: int, histogram_parameters: Set[str]) -> None:
        """
        Send histograms of parameters to tensorboard.
        """
        for name, param in self._model.named_parameters():
            if name in histogram_parameters:
                self._tensorboard.add_train_histogram("parameter_histogram/" + name,
                                                      param,
                                                      epoch)

    def _metrics_to_tensorboard(self,
                                epoch: int,
                                train_metrics: dict,
                                val_metrics: dict = None) -> None:
        """
        Sends all of the train metrics (and validation metrics, if provided) to tensorboard.
        """
        metric_names = set(train_metrics.keys())
        if val_metrics is not None:
            metric_names.update(val_metrics.keys())
        val_metrics = val_metrics or {}

        for name in metric_names:
            train_metric = train_metrics.get(name)
            if train_metric is not None:
                self._tensorboard.add_train_scalar(name, train_metric, epoch)
            val_metric = val_metrics.get(name)
            if val_metric is not None:
                self._tensorboard.add_validation_scalar(name, val_metric, epoch)

    def _metrics_to_console(self,  # pylint: disable=no-self-use
                            train_metrics: dict,
                            val_metrics: dict = None) -> None:
        """
        Logs all of the train metrics (and validation metrics, if provided) to the console.
        """
        val_metrics = val_metrics or {}
        dual_message_template = "%s |  %8.3f  |  %8.3f"
        no_val_message_template = "%s |  %8.3f  |  %8s"
        no_train_message_template = "%s |  %8s  |  %8.3f"
        header_template = "%s |  %-10s"

        metric_names = set(train_metrics.keys())
        if val_metrics:
            metric_names.update(val_metrics.keys())

        name_length = max([len(x) for x in metric_names])

        logger.info(header_template, "Training".rjust(name_length + 13), "Validation")
        for name in metric_names:
            train_metric = train_metrics.get(name)
            val_metric = val_metrics.get(name)

            if val_metric is not None and train_metric is not None:
                logger.info(dual_message_template, name.ljust(name_length), train_metric, val_metric)
            elif val_metric is not None:
                logger.info(no_train_message_template, name.ljust(name_length), "N/A", val_metric)
            elif train_metric is not None:
                logger.info(no_val_message_template, name.ljust(name_length), train_metric, "N/A")

    def _validation_loss(self) -> Tuple[float, int]:
        """
        Computes the validation loss. Returns it and the number of batches.
        """
        logger.info("Validating")

        self._model.eval()

        if self._validation_iterator is not None:
            val_iterator = self._validation_iterator
        else:
            val_iterator = self._iterator

        val_generator = val_iterator(self._validation_data,
                                     num_epochs=1,
                                     shuffle=False)
        num_validation_batches = val_iterator.get_num_batches(self._validation_data)
        val_generator_tqdm = Tqdm.tqdm(val_generator,
                                       total=num_validation_batches)
        batches_this_epoch = 0
        val_loss = 0
        for batch in val_generator_tqdm:

            loss = self._batch_loss(batch, for_training=False)
            if loss is not None:
                # You shouldn't necessarily have to compute a loss for validation, so we allow for
                # `loss` to be None.  We need to be careful, though - `batches_this_epoch` is
                # currently only used as the divisor for the loss function, so we can safely only
                # count those batches for which we actually have a loss.  If this variable ever
                # gets used for something else, we might need to change things around a bit.
                batches_this_epoch += 1
                val_loss += loss.detach().cpu().numpy()

            # Update the description with the latest metrics
            val_metrics = self._get_metrics(val_loss, batches_this_epoch)
            description = self._description_from_metrics(val_metrics)
            val_generator_tqdm.set_description(description, refresh=False)

        return val_loss, batches_this_epoch

    def train(self) -> Dict[str, Any]:
        """
        Trains the supplied model with the supplied parameters.
        """
        try:
            epoch_counter, validation_metric_per_epoch = self._restore_checkpoint()
        except RuntimeError:
            traceback.print_exc()
            raise ConfigurationError("Could not recover training from the checkpoint.  Did you mean to output to "
                                     "a different serialization directory or delete the existing serialization "
                                     "directory?")

        self._enable_gradient_clipping()
        self._enable_activation_logging()

        logger.info("Beginning training.")

        train_metrics: Dict[str, float] = {}
        val_metrics: Dict[str, float] = {}
        metrics: Dict[str, Any] = {}
        epochs_trained = 0
        training_start_time = time.time()

        for epoch in range(epoch_counter, self._num_epochs):
            epoch_start_time = time.time()
            train_metrics = self._train_epoch(epoch)

            if self._validation_data is not None:
                with torch.no_grad():
                    # We have a validation set, so compute all the metrics on it.
                    val_loss, num_batches = self._validation_loss()
                    val_metrics = self._get_metrics(val_loss, num_batches, reset=True)

                    # Check validation metric for early stopping
                    this_epoch_val_metric = val_metrics[self._validation_metric]

                    # Check validation metric to see if it's the best so far
                    is_best_so_far = self._is_best_so_far(this_epoch_val_metric, validation_metric_per_epoch)
                    validation_metric_per_epoch.append(this_epoch_val_metric)
                    if self._should_stop_early(validation_metric_per_epoch):
                        logger.info("Ran out of patience.  Stopping training.")
                        break

            else:
                # No validation set, so just assume it's the best so far.
                is_best_so_far = True
                val_metrics = {}
                this_epoch_val_metric = None

            self._metrics_to_tensorboard(epoch, train_metrics, val_metrics=val_metrics)
            self._metrics_to_console(train_metrics, val_metrics)

            # Create overall metrics dict
            training_elapsed_time = time.time() - training_start_time
            metrics["training_duration"] = time.strftime("%H:%M:%S", time.gmtime(training_elapsed_time))
            metrics["training_start_epoch"] = epoch_counter
            metrics["training_epochs"] = epochs_trained
            metrics["epoch"] = epoch

            for key, value in train_metrics.items():
                metrics["training_" + key] = value
            for key, value in val_metrics.items():
                metrics["validation_" + key] = value

            if is_best_so_far:
                # Update all the best_ metrics.
                # (Otherwise they just stay the same as they were.)
                metrics['best_epoch'] = epoch
                for key, value in val_metrics.items():
                    metrics["best_validation_" + key] = value

            if self._serialization_dir:
                dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics)

            if self._learning_rate_scheduler:
                # The LRScheduler API is agnostic to whether your schedule requires a validation metric -
                # if it doesn't, the validation metric passed here is ignored.
                self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)

            self._save_checkpoint(epoch, validation_metric_per_epoch, is_best=is_best_so_far)

            epoch_elapsed_time = time.time() - epoch_start_time
            logger.info("Epoch duration: %s", time.strftime("%H:%M:%S", time.gmtime(epoch_elapsed_time)))

            if epoch < self._num_epochs - 1:
                training_elapsed_time = time.time() - training_start_time
                estimated_time_remaining = training_elapsed_time * \
                    ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1)
                formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
                logger.info("Estimated training time remaining: %s", formatted_time)

            epochs_trained += 1

        return metrics

    def _is_best_so_far(self,
                        this_epoch_val_metric: float,
                        validation_metric_per_epoch: List[float]):
        if not validation_metric_per_epoch:
            return True
        elif self._validation_metric_decreases:
            return this_epoch_val_metric < min(validation_metric_per_epoch)
        else:
            return this_epoch_val_metric > max(validation_metric_per_epoch)

    def _description_from_metrics(self, metrics: Dict[str, float]) -> str:
        if (not self._warned_tqdm_ignores_underscores and
                    any(metric_name.startswith("_") for metric_name in metrics)):
            logger.warning("Metrics with names beginning with \"_\" will "
                           "not be logged to the tqdm progress bar.")
            self._warned_tqdm_ignores_underscores = True
        return ', '.join(["%s: %.4f" % (name, value) for name, value in
                          metrics.items() if not name.startswith("_")]) + " ||"

    def _save_checkpoint(self,
                         epoch: Union[int, str],
                         val_metric_per_epoch: List[float],
                         is_best: Optional[bool] = None) -> None:
        """
        Saves a checkpoint of the model to self._serialization_dir.
        Is a no-op if self._serialization_dir is None.

        Parameters
        ----------
        epoch : Union[int, str], required.
            The epoch of training.  If the checkpoint is saved in the middle
            of an epoch, the parameter is a string with the epoch and timestamp.
        is_best: bool, optional (default = None)
            A flag which causes the model weights at the given epoch to
            be copied to a "best.th" file. The value of this flag should
            be based on some validation metric computed by your model.
        """
        if self._serialization_dir is not None:
            model_path = os.path.join(self._serialization_dir, "model_state_epoch_{}.th".format(epoch))
            model_state = self._model.state_dict()
            torch.save(model_state, model_path)

            training_state = {'epoch': epoch,
                              'val_metric_per_epoch': val_metric_per_epoch,
                              'optimizer': self._optimizer.state_dict(),
                              'batch_num_total': self._batch_num_total}
            if self._learning_rate_scheduler is not None:
                training_state["learning_rate_scheduler"] = \
                    self._learning_rate_scheduler.lr_scheduler.state_dict()
            training_path = os.path.join(self._serialization_dir,
                                         "training_state_epoch_{}.th".format(epoch))
            torch.save(training_state, training_path)
            if is_best:
                logger.info("Best validation performance so far. "
                            "Copying weights to '%s/best.th'.", self._serialization_dir)
                shutil.copyfile(model_path, os.path.join(self._serialization_dir, "best.th"))

            if self._num_serialized_models_to_keep and self._num_serialized_models_to_keep >= 0:
                self._serialized_paths.append([time.time(), model_path, training_path])
                if len(self._serialized_paths) > self._num_serialized_models_to_keep:
                    paths_to_remove = self._serialized_paths.pop(0)
                    # Check to see if we should keep this checkpoint, if it has been longer
                    # then self._keep_serialized_model_every_num_seconds since the last
                    # kept checkpoint.
                    remove_path = True
                    if self._keep_serialized_model_every_num_seconds is not None:
                        save_time = paths_to_remove[0]
                        time_since_checkpoint_kept = save_time - self._last_permanent_saved_checkpoint_time
                        if time_since_checkpoint_kept > self._keep_serialized_model_every_num_seconds:
                            # We want to keep this checkpoint.
                            remove_path = False
                            self._last_permanent_saved_checkpoint_time = save_time
                    if remove_path:
                        for fname in paths_to_remove[1:]:
                            os.remove(fname)

    def find_latest_checkpoint(self) -> Tuple[str, str]:
        """
        Return the location of the latest model and training state files.
        If there isn't a valid checkpoint then return None.
        """
        have_checkpoint = (self._serialization_dir is not None and
                           any("model_state_epoch_" in x for x in os.listdir(self._serialization_dir)))

        if not have_checkpoint:
            return None

        serialization_files = os.listdir(self._serialization_dir)
        model_checkpoints = [x for x in serialization_files if "model_state_epoch" in x]
        # Get the last checkpoint file.  Epochs are specified as either an
        # int (for end of epoch files) or with epoch and timestamp for
        # within epoch checkpoints, e.g. 5.2018-02-02-15-33-42
        found_epochs = [
                # pylint: disable=anomalous-backslash-in-string
                re.search("model_state_epoch_([0-9\.\-]+)\.th", x).group(1)
                for x in model_checkpoints
        ]
        int_epochs: Any = []
        for epoch in found_epochs:
            pieces = epoch.split('.')
            if len(pieces) == 1:
                # Just a single epoch without timestamp
                int_epochs.append([int(pieces[0]), 0])
            else:
                # has a timestamp
                int_epochs.append([int(pieces[0]), pieces[1]])
        last_epoch = sorted(int_epochs, reverse=True)[0]
        if last_epoch[1] == 0:
            epoch_to_load = str(last_epoch[0])
        else:
            epoch_to_load = '{0}.{1}'.format(last_epoch[0], last_epoch[1])

        model_path = os.path.join(self._serialization_dir,
                                  "model_state_epoch_{}.th".format(epoch_to_load))
        training_state_path = os.path.join(self._serialization_dir,
                                           "training_state_epoch_{}.th".format(epoch_to_load))

        return (model_path, training_state_path)

    def _restore_checkpoint(self) -> Tuple[int, List[float]]:
        """
        Restores a model from a serialization_dir to the last saved checkpoint.
        This includes an epoch count and optimizer state, which is serialized separately
        from  model parameters. This function should only be used to continue training -
        if you wish to load a model for inference/load parts of a model into a new
        computation graph, you should use the native Pytorch functions:
        `` model.load_state_dict(torch.load("/path/to/model/weights.th"))``

        If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
        this function will do nothing and return 0.

        Returns
        -------
        epoch: int
            The epoch at which to resume training, which should be one after the epoch
            in the saved training state.
        """
        latest_checkpoint = self.find_latest_checkpoint()

        if latest_checkpoint is None:
            # No checkpoint to restore, start at 0
            return 0, []

        model_path, training_state_path = latest_checkpoint

        # Load the parameters onto CPU, then transfer to GPU.
        # This avoids potential OOM on GPU for large models that
        # load parameters onto GPU then make a new GPU copy into the parameter
        # buffer. The GPU transfer happens implicitly in load_state_dict.
        model_state = torch.load(model_path, map_location=util.device_mapping(-1))
        training_state = torch.load(training_state_path, map_location=util.device_mapping(-1))
        self._model.load_state_dict(model_state)
        self._optimizer.load_state_dict(training_state["optimizer"])
        if self._learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state:
            self._learning_rate_scheduler.lr_scheduler.load_state_dict(
                    training_state["learning_rate_scheduler"])
        move_optimizer_to_cuda(self._optimizer)

        # We didn't used to save `validation_metric_per_epoch`, so we can't assume
        # that it's part of the trainer state. If it's not there, an empty list is all
        # we can do.
        if "val_metric_per_epoch" not in training_state:
            logger.warning("trainer state `val_metric_per_epoch` not found, using empty list")
            val_metric_per_epoch: List[float] = []
        else:
            val_metric_per_epoch = training_state["val_metric_per_epoch"]

        if isinstance(training_state["epoch"], int):
            epoch_to_return = training_state["epoch"] + 1
        else:
            epoch_to_return = int(training_state["epoch"].split('.')[0]) + 1

        # For older checkpoints with batch_num_total missing, default to old behavior where
        # it is unchanged.
        batch_num_total = training_state.get('batch_num_total')
        if batch_num_total is not None:
            self._batch_num_total = batch_num_total

        return epoch_to_return, val_metric_per_epoch

    # Requires custom from_params.
    @classmethod
    def from_params(cls,
                    model: Model,
                    serialization_dir: str,
                    iterator: DataIterator,
                    iterator_aux: DataIterator,
                    iterator_aux2: DataIterator,
                    train_data: Iterable[Instance],
                    train_data_aux: Iterable[Instance],
                    train_data_aux2: Iterable[Instance],
                    mixing_ratio: float,
                    mixing_ratio2: float,
                    cutoff_epoch: int,
                    validation_data: Optional[Iterable[Instance]],
                    validation_data_aux: Optional[Iterable[Instance]],
                    validation_data_aux2: Optional[Iterable[Instance]],
                    params: Params,
                    validation_iterator: DataIterator = None) -> 'MultiTaskTrainer2':

        patience = params.pop_int("patience", None)
        validation_metric = params.pop("validation_metric", "-loss")
        shuffle = params.pop_bool("shuffle", True)
        num_epochs = params.pop_int("num_epochs", 20)
        cuda_device = params.pop_int("cuda_device", -1)
        grad_norm = params.pop_float("grad_norm", None)
        grad_clipping = params.pop_float("grad_clipping", None)
        lr_scheduler_params = params.pop("learning_rate_scheduler", None)

        if cuda_device >= 0:
            model = model.cuda(cuda_device)
        parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))

        if lr_scheduler_params:
            scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params)
        else:
            scheduler = None

        num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
        keep_serialized_model_every_num_seconds = params.pop_int(
                "keep_serialized_model_every_num_seconds", None)
        model_save_interval = params.pop_float("model_save_interval", None)
        summary_interval = params.pop_int("summary_interval", 100)
        histogram_interval = params.pop_int("histogram_interval", None)
        should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True)
        should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)

        params.assert_empty(cls.__name__)
        return MultiTaskTrainer2(model, optimizer, iterator,
                                    train_data,
                                    train_data_aux,
                                    train_data_aux2,
                                    mixing_ratio,
                                    mixing_ratio2,
                                    cutoff_epoch,
                                    validation_data,
                                    validation_data_aux,
                                    validation_data_aux2,
                                    patience=patience,
                                    validation_metric=validation_metric,
                                    validation_iterator=validation_iterator,
                                    shuffle=shuffle,
                                    num_epochs=num_epochs,
                                    serialization_dir=serialization_dir,
                                    cuda_device=cuda_device,
                                    grad_norm=grad_norm,
                                    grad_clipping=grad_clipping,
                                    learning_rate_scheduler=scheduler,
                                    num_serialized_models_to_keep=num_serialized_models_to_keep,
                                    keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
                                    model_save_interval=model_save_interval,
                                    summary_interval=summary_interval,
                                    histogram_interval=histogram_interval,
                                    should_log_parameter_statistics=should_log_parameter_statistics,
                                    should_log_learning_rate=should_log_learning_rate,
                                    iterator_aux=iterator_aux,
                                    iterator_aux2=iterator_aux2)


In [19]:
# train_model_from_file('experiment_configs/custom_config.json', './runs/test1')

### Debug

In [20]:
params = Params.from_file('experiment_configs/custom_config.json', "")
serialization_dir = './runs/test5'
file_friendly_logging = False
recover = False

In [21]:

# train_model(params, serialization_dir, file_friendly_logging, recover)
"""
Trains the model specified in the given :class:`Params` object, using the data and training
parameters also specified in that object, and saves the results in ``serialization_dir``.

Parameters
----------
params : ``Params``
    A parameter object specifying an AllenNLP Experiment.
serialization_dir : ``str``
    The directory in which to save results and logs.
file_friendly_logging : ``bool``, optional (default=False)
    If ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow
    down tqdm's output to only once every 10 seconds.
recover : ``bool``, optional (default=False)
    If ``True``, we will try to recover a training run from an existing serialization
    directory.  This is only intended for use when something actually crashed during the middle
    of a run.  For continuing training a model on new data, see the ``fine-tune`` command.

Returns
-------
best_model: ``Model``
    The model with the best epoch weights.
"""
prepare_environment(params)

create_serialization_dir(params, serialization_dir, recover)
prepare_global_logging(serialization_dir, file_friendly_logging)

check_for_gpu(params.get('trainer').get('cuda_device', -1))

params.to_file(os.path.join(serialization_dir, CONFIG_NAME))

all_datasets, all_datasets_aux, all_datasets_aux2 = datasets_from_params(params)
# print(all_datasets)
datasets_for_vocab_creation = set(params.pop("datasets_for_vocab_creation", all_datasets))
datasets_for_vocab_creation_aux = set(params.pop("auxiliary_datasets_for_vocab_creation", all_datasets_aux))
datasets_for_vocab_creation_aux2 = set(params.pop("auxiliary_datasets_for_vocab_creation_2", all_datasets_aux2))


mixing_ratio = params.pop_float("mixing_ratio")
mixing_ratio2 = params.pop_float("mixing_ratio2")

cutoff_epoch = params.pop("cutoff_epoch", -1)

for dataset in datasets_for_vocab_creation:
    if dataset not in all_datasets:
        raise ConfigurationError(f"invalid 'dataset_for_vocab_creation' {dataset}")

logger.info("From dataset instances, %s will be considered for vocabulary creation.",
            ", ".join(datasets_for_vocab_creation))
vocab_instances_aux = [
    instance for key, dataset in all_datasets_aux.items()
    for instance in dataset
    if key in datasets_for_vocab_creation_aux
]
vocab_instances_aux.extend([
    instance for key, dataset in all_datasets_aux2.items()
    for instance in dataset
    if key in datasets_for_vocab_creation_aux2
])
vocab = VocabularyMultitask.from_params(
        params.pop("vocabulary", {}),
        (instance for key, dataset in all_datasets.items()
         for instance in dataset
         if key in datasets_for_vocab_creation),
        instances_aux=vocab_instances_aux
)
model = Model.from_params(vocab=vocab, params=params.pop('model'))

# Initializing the model can have side effect of expanding the vocabulary
vocab.save_to_files(os.path.join(serialization_dir, "vocabulary"))

iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)

iterator_aux = DataIterator.from_params(params.pop("iterator_aux"))
iterator_aux.index_with(vocab)

iterator_aux2 = DataIterator.from_params(params.pop("iterator_aux2"))
iterator_aux2.index_with(vocab)

validation_iterator_params = params.pop("validation_iterator", None)
if validation_iterator_params:
    validation_iterator = DataIterator.from_params(validation_iterator_params)
    validation_iterator.index_with(vocab)
else:
    validation_iterator = None

# TODO: if validation in multi-task need to add validation iterator as above

train_data = all_datasets.get('train')
validation_data = all_datasets.get('validation')
test_data = all_datasets.get('test')

train_data_aux = all_datasets_aux.get('train_aux')
validation_data_aux = all_datasets_aux.get('validation_aux')
test_data_aux = all_datasets_aux.get('test_aux')

train_data_aux2 = all_datasets_aux2.get('train_aux')
validation_data_aux2 = all_datasets_aux2.get('validation_aux')
test_data_aux2 = all_datasets_aux2.get('test_aux')

trainer_params = params.pop("trainer")
no_grad_regexes = trainer_params.pop("no_grad", ())
for name, parameter in model.named_parameters():
    if any(re.search(regex, name) for regex in no_grad_regexes):
        parameter.requires_grad_(False)

frozen_parameter_names, tunable_parameter_names = \
               get_frozen_and_tunable_parameter_names(model)
logger.info("Following parameters are Frozen  (without gradient):")
for name in frozen_parameter_names:
    logger.info(name)
logger.info("Following parameters are Tunable (with gradient):")
for name in tunable_parameter_names:
    logger.info(name)

trainer = MultiTaskTrainer2.from_params(model=model,
                                        serialization_dir=serialization_dir,
                                        iterator=iterator,
                                        iterator_aux=iterator_aux,
                                        iterator_aux2=iterator_aux2,
                                        train_data=train_data,
                                        train_data_aux=train_data_aux,
                                        train_data_aux2=train_data_aux2,
                                        mixing_ratio=mixing_ratio,
                                        mixing_ratio2=mixing_ratio2,
                                        cutoff_epoch=cutoff_epoch,
                                        validation_data_aux=validation_data_aux,
                                        validation_data_aux2=validation_data_aux2,
                                        validation_data=validation_data,
                                        params=trainer_params,
                                        validation_iterator=validation_iterator)
# print(trainer._cuda_devices[0])
evaluate_on_test = params.pop_bool("evaluate_on_test", False)
evaluate_aux_on_test = params.pop_bool("evaluate_aux_on_test", False)
params.assert_empty('base train command')
'''
try:
    metrics = trainer.train()
except KeyboardInterrupt:
    # if we have completed an epoch, try to create a model archive.
    if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):
        logging.info("Training interrupted by the user. Attempting to create "
                     "a model archive using the current best epoch weights.")
        archive_model(serialization_dir, files_to_archive=params.files_to_archive)
    raise

# Now tar up results
archive_model(serialization_dir, files_to_archive=params.files_to_archive)

logger.info("Loading the best epoch weights.")
best_model_state_path = os.path.join(serialization_dir, 'best.th')
best_model_state = torch.load(best_model_state_path)
best_model = model
best_model.load_state_dict(best_model_state)

if test_data and evaluate_on_test:
    logger.info("The model will be evaluated using the best epoch weights.")
    test_metrics = evaluate(
            best_model, test_data, validation_iterator or iterator,
            cuda_device=trainer._cuda_devices[0] # pylint: disable=protected-access
    )
    for key, value in test_metrics.items():
        metrics["test_" + key] = value

elif test_data:
    logger.info("To evaluate on the test set after training, pass the "
                "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

if test_data_aux and evaluate_aux_on_test:
    # for instance in test_data_aux:
    #     instance.index_fields(vocab)
    # for instance in test_data_aux2:
    #     instance.index_fields(vocab)
    test_metrics_aux = evaluate(best_model, test_data_aux, iterator_aux,
                                cuda_device=trainer._cuda_devices[0])  # pylint: disable=protected-access
    test_metrics_aux2 = evaluate(best_model, test_data_aux2, iterator_aux2,
                                 cuda_device=trainer._cuda_devices[0])  # pylint: disable=protected-access

    for key, value in test_metrics_aux.items():
        metrics["test_aux_" + key] = value
    for key, value in test_metrics_aux2.items():
        metrics["test_aux2_" + key] = value

elif test_data_aux:
    logger.info("To evaluate on the auxiliary test set after training, pass the "
                "'evaluate_on_test' flag, or use the 'allennlp evaluate' command.")

dump_metrics(os.path.join(serialization_dir, "metrics.json"), metrics, log=True)
'''

04/10/2024 22:23:08 - INFO - allennlp.common.params -   random_seed = 21016
04/10/2024 22:23:08 - INFO - allennlp.common.params -   numpy_seed = 5000
04/10/2024 22:23:08 - INFO - allennlp.common.params -   pytorch_seed = 8000
04/10/2024 22:23:08 - INFO - allennlp.common.checks -   Pytorch version: 1.10.2
04/10/2024 22:23:08 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.dataset_readers.dataset_reader.DatasetReader'> from params {'multilabel': 'false', 'type': 'scicite_datasetreader', 'use_sparse_lexicon_features': 'false', 'with_elmo': 'true'} and extras {}
04/10/2024 22:23:08 - INFO - allennlp.common.params -   dataset_reader.type = scicite_datasetreader
04/10/2024 22:23:08 - INFO - allennlp.common.params -   dataset_reader.lazy = False
04/10/2024 22:23:08 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.tokenizers.tokenizer.Tokenizer'> from params {} and extras {}
04/10/2024 22:23:08 - INFO - allennlp.common.params

pred mode: False


04/10/2024 22:24:37 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.iterators.data_iterator.DataIterator'> from params {'batch_size': 16, 'sorting_keys': [['citation_text', 'num_tokens']], 'type': 'bucket'} and extras {}
04/10/2024 22:24:37 - INFO - allennlp.common.params -   iterator.type = bucket
04/10/2024 22:24:37 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.iterators.bucket_iterator.BucketIterator'> from params {'batch_size': 16, 'sorting_keys': [['citation_text', 'num_tokens']]} and extras {}
04/10/2024 22:24:37 - INFO - allennlp.common.params -   iterator.sorting_keys = [['citation_text', 'num_tokens']]
04/10/2024 22:24:37 - INFO - allennlp.common.params -   iterator.padding_noise = 0.1
04/10/2024 22:24:37 - INFO - allennlp.common.params -   iterator.biggest_batch_first = False
04/10/2024 22:24:37 - INFO - allennlp.common.params -   iterator.batch_size = 16
04/10/2024 22:24:37 - INFO - allennlp.common.param

'\ntry:\n    metrics = trainer.train()\nexcept KeyboardInterrupt:\n    # if we have completed an epoch, try to create a model archive.\n    if os.path.exists(os.path.join(serialization_dir, _DEFAULT_WEIGHTS)):\n        logging.info("Training interrupted by the user. Attempting to create "\n                     "a model archive using the current best epoch weights.")\n        archive_model(serialization_dir, files_to_archive=params.files_to_archive)\n    raise\n\n# Now tar up results\narchive_model(serialization_dir, files_to_archive=params.files_to_archive)\n\nlogger.info("Loading the best epoch weights.")\nbest_model_state_path = os.path.join(serialization_dir, \'best.th\')\nbest_model_state = torch.load(best_model_state_path)\nbest_model = model\nbest_model.load_state_dict(best_model_state)\n\nif test_data and evaluate_on_test:\n    logger.info("The model will be evaluated using the best epoch weights.")\n    test_metrics = evaluate(\n            best_model, test_data, validation_ite

04/10/2024 22:24:41 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.dataset_readers.dataset_reader.DatasetReader'> from params {'multilabel': 'false', 'type': 'scicite_datasetreader', 'use_sparse_lexicon_features': 'false', 'with_elmo': 'true'} and extras {}
04/10/2024 22:24:41 - INFO - allennlp.common.params -   dataset_reader.type = scicite_datasetreader
04/10/2024 22:24:41 - INFO - allennlp.common.params -   dataset_reader.lazy = False
04/10/2024 22:24:41 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.tokenizers.tokenizer.Tokenizer'> from params {} and extras {}
04/10/2024 22:24:41 - INFO - allennlp.common.params -   dataset_reader.tokenizer.type = word
04/10/2024 22:24:41 - INFO - allennlp.common.from_params -   instantiating class <class 'allennlp.data.tokenizers.word_tokenizer.WordTokenizer'> from params {} and extras {}
04/10/2024 22:24:41 - INFO - allennlp.common.params -   dataset_reader.tokenizer.start_tok

{'citation_text': {'elmo': tensor([[[259,  84, 109,  ..., 261, 261, 261],
         [259, 100, 112,  ..., 261, 261, 261],
         [259,  79,  80,  ..., 261, 261, 261],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[259,  85, 105,  ..., 261, 261, 261],
         [259, 102, 121,  ..., 261, 261, 261],
         [259, 106, 116,  ..., 261, 261, 261],
         ...,
         [259, 105, 122,  ..., 261, 261, 261],
         [259,  47, 260,  ..., 261, 261, 261],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[259,  84, 102,  ..., 261, 261, 261],
         [259,  45, 260,  ..., 261, 261, 261],
         [259,  73,  66,  ..., 261, 261, 261],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]],

        ...,

        [[259,  85, 105,  ..., 261, 261, 261],
         [259, 115


04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.type = None
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.extend = False
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.directory_path = None
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.min_count = None
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.max_vocab_size = None
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.non_padded_namespaces = ('*tags', '*labels')
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.min_pretrained_embeddings = None
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.only_include_pretrained_words = False
04/10/2024 22:35:58 - INFO - allennlp.common.params -   vocabulary.tokens_to_add = None
04/10/2024 22:35:58 - INFO - scicite.training.vocabulary_multitask -   Fitting token dictionary from dataset.
0it [00:00, ?it/s]
94189it [00:04, 22152.85it/s]

04/10/2024 

### Datareader

In [22]:
params = Params.from_file('experiment_configs/custom_config.json', "")
data = datasets_from_params(params)
# from scicite_datasetreader text_to_instance() -> Instance

In [23]:
# data[0]['train'][0].fields

In [24]:
# data[0]['train'][1].fields['citation_text'].tokens

In [34]:
data[0]['train'][0].fields['labels'].label

'background'

In [26]:
vocab_instances_aux[0].fields['citation_text'].tokens[:5]

[Therefore, ,, in, our, case]

In [27]:
# vocab._token_to_index

In [28]:
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)

iterator_aux = DataIterator.from_params(params.pop("iterator_aux"))
iterator_aux.index_with(vocab)

iterator_aux2 = DataIterator.from_params(params.pop("iterator_aux2"))
iterator_aux2.index_with(vocab)

In [29]:
iterator

<allennlp.data.iterators.bucket_iterator.BucketIterator at 0x7f13636ff1d0>

In [30]:
# train_data[0].fields['citation_text'].tokens

In [31]:
_shuffle = True
train_generator = iterator(train_data,
                                 num_epochs=1,
                                 shuffle=_shuffle)

train_generator_aux = iterator_aux(train_data_aux,
                                         num_epochs=1,
                                         shuffle=_shuffle)
train_generator_aux2 = iterator_aux2(train_data_aux2,
                                      num_epochs=1,
                                          shuffle=_shuffle)

In [32]:
train_generator

<generator object DataIterator.__call__ at 0x7f1370749db0>

In [33]:
num_training_batches=1
train_generator_tqdm = Tqdm.tqdm(train_generator,
                                 total=num_training_batches)

In [37]:
for batch, batch_aux, batch_aux2 in zip(train_generator_tqdm, train_generator_aux, train_generator_aux2):
    print(batch, batch_aux, batch_aux2)
    break

In [65]:
# batch['citation_text']

{'elmo': tensor([[[259,  87, 106,  ..., 261, 261, 261],
          [259,  66, 260,  ..., 261, 261, 261],
          [259, 116, 118,  ..., 261, 261, 261],
          ...,
          [  0,   0,   0,  ...,   0,   0,   0],
          [  0,   0,   0,  ...,   0,   0,   0],
          [  0,   0,   0,  ...,   0,   0,   0]],
 
         [[259,  71, 112,  ..., 261, 261, 261],
          [259,  72, 118,  ..., 261, 261, 261],
          [259,  98, 111,  ..., 261, 261, 261],
          ...,
          [  0,   0,   0,  ...,   0,   0,   0],
          [  0,   0,   0,  ...,   0,   0,   0],
          [  0,   0,   0,  ...,   0,   0,   0]],
 
         [[259,  72, 106,  ..., 261, 261, 261],
          [259, 117, 105,  ..., 261, 261, 261],
          [259, 108, 111,  ..., 261, 261, 261],
          ...,
          [259, 117, 105,  ..., 261, 261, 261],
          [259, 101, 106,  ..., 261, 261, 261],
          [  0,   0,   0,  ...,   0,   0,   0]],
 
         ...,
 
         [[259,  85, 105,  ..., 261, 261, 261],
          

In [38]:
# output_dict['prediction'] = labels
output_dict = {}
citation_text = []
for batch_text in batch['citation_text']['tokens']:
    citation_text.append([vocab.get_token_from_index(token_id.item()) for token_id in batch_text])
# output_dict['citation_text'] = citation_text
# output_dict['all_labels'] = [vocab.get_index_to_token_vocabulary(namespace="labels")
#                              for _ in range(output_dict['logits'].shape[0])]

In [39]:
citation_text

[['Slant',
  'column',
  'NO2',
  'is',
  'retrieved',
  'using',
  'a',
  'DOAS',
  'linear',
  'least',
  'squares',
  'fit',
  '(',
  'Platt',
  'and',
  'Stutz',
  ',',
  '@@NUM@@',
  ';',
  'Wenig',
  'et',
  'al',
  '.',
  ',',
  '@@NUM@@',
  ')',
  'as',
  'in',
  'the',
  'operational',
  'DOAS',
  'retrievals',
  '(',
  'Boersma',
  'et',
  'al',
  '.',
  ',',
  '2007;10',
  'Bucsela',
  'et',
  'al',
  '.',
  ',',
  '@@NUM@@',
  ')',
  '.',
  '@@PADDING@@',
  '@@PADDING@@',
  '@@PADDING@@',
  '@@PADDING@@'],
 ['This',
  'explanation',
  'is',
  'consistent',
  'with',
  'the',
  'observation',
  'that',
  'Swiss',
  'mice',
  'were',
  'unable',
  'to',
  'respond',
  'to',
  'surgical',
  'removal',
  'of',
  'some',
  'of',
  'their',
  'mammary',
  'tissue',
  'by',
  'elevation',
  'of',
  'milk',
  'production',
  'in',
  'the',
  'remaining',
  'tissue',
  '(',
  'Hammond',
  'et',
  'al',
  '.',
  ',',
  '@@NUM@@',
  ')',
  'which',
  'is',
  'also',
  'consistent',
  

## Datareader convert

In [53]:
""" Data reader for AllenNLP """


from typing import Dict, List
import json
import logging

import torch
from allennlp.data import Field
from overrides import overrides
from allennlp.common import Params
from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField, MultiLabelField, ListField, ArrayField, MetadataField
from allennlp.data.instance import Instance
from allennlp.data.tokenizers import Tokenizer, WordTokenizer
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer, ELMoTokenCharactersIndexer, PretrainedTransformerIndexer

from scicite.resources.lexicons import ALL_ACTION_LEXICONS, ALL_CONCEPT_LEXICONS
from scicite.data import DataReaderJurgens
from scicite.data import DataReaderS2, DataReaderS2ExcerptJL
from scicite.compute_features import is_in_lexicon

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

from scicite.constants import S2_CATEGORIES, NONE_LABEL_NAME


# @DatasetReader.register("scicite_datasetreader")
class SciciteDatasetReader(DatasetReader):
    """
    Reads a JSON-lines file containing citations from the Semantic Scholar database, and creates a
    dataset suitable for document classification using these papers.

    The output of ``read`` is a list of ``Instance`` s with the fields:
        citation_text: ``TextField``
        label: ``LabelField``

    where the ``label`` is derived from the methodology/comparison labels.

    Parameters
    ----------
    lazy : ``bool`` (optional, default=False)
        Passed to ``DatasetReader``.  If this is ``True``, training will start sooner, but will
        take longer per batch.  This also allows training with datasets that are too large to fit
        in memory.
    tokenizer : ``Tokenizer``, optional
        Tokenizer to use to split the title and abstrct into words or other kinds of tokens.
        Defaults to ``WordTokenizer()``.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        Indexers used to define input token representations. Defaults to ``{"tokens":
        SingleIdTokenIndexer()}``.
    reader_format : can be `flat` or `nested`. `flat` for flat json format and nested for
        Json format where the each object contains multiple excerpts
    """
    def __init__(self,
                 lazy: bool = False,
                 tokenizer: Tokenizer = None,
                 use_lexicon_features: bool=False,
                 use_sparse_lexicon_features: bool = False,
                 multilabel: bool = False,
                 with_elmo: bool = False,
                 reader_format: str = 'flat') -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer() # using WordTokenizer() because config['tokenizer'] not specified 
        if with_elmo:
            # self._token_indexers = {"tokens": SingleIdTokenIndexer()}
            self._token_indexers = {"elmo": ELMoTokenCharactersIndexer(),
                                    "tokens": SingleIdTokenIndexer()}
        else:
            self._token_indexers = {"tokens": SingleIdTokenIndexer()}

        self.use_lexicon_features = use_lexicon_features
        self.use_sparse_lexicon_features = use_sparse_lexicon_features
        if self.use_lexicon_features or self.use_sparse_lexicon_features:
            self.lexicons = {**ALL_ACTION_LEXICONS, **ALL_CONCEPT_LEXICONS}
        self.multilabel = multilabel
        self.reader_format = reader_format

    @overrides
    def _read(self, jsonl_file: str):
        if self.reader_format == 'flat':
            reader_s2 = DataReaderS2ExcerptJL(jsonl_file)
        elif self.reader_format == 'nested':
            reader_s2 = DataReaderS2(jsonl_file)
        for citation in reader_s2.read():
            yield self.text_to_instance(
                citation_text=citation.text,
                intent=citation.intent,
                citing_paper_id=citation.citing_paper_id,
                cited_paper_id=citation.cited_paper_id,
                citation_excerpt_index=citation.citation_excerpt_index
            )

    @overrides
    def text_to_instance(self,
                         citation_text: str,
                         citing_paper_id: str,
                         cited_paper_id: str,
                         intent: List[str] = None,
                         citing_paper_title: str = None,
                         cited_paper_title: str = None,
                         citing_paper_year: int = None,
                         cited_paper_year: int = None,
                         citing_author_ids: List[str] = None,
                         cited_author_ids: List[str] = None,
                         extended_context: str = None,
                         section_number: int = None,
                         section_title: str = None,
                         cite_marker_begin: int = None,
                         cite_marker_end: int = None,
                         sents_before: List[str] = None,
                         sents_after: List[str] = None,
                         cleaned_cite_text: str = None,
                         citation_excerpt_index: str = None,
                         venue: str = None) -> Instance:  # type: ignore

        citation_tokens = self._tokenizer.tokenize(citation_text)

        fields = {
            'citation_text': TextField(citation_tokens, self._token_indexers),
        }

        if self.use_sparse_lexicon_features:
            # convert to regular string
            sent = [token.text.lower() for token in citation_tokens]
            lexicon_features, _ = is_in_lexicon(self.lexicons, sent)
            fields["lexicon_features"] = ListField([LabelField(feature, skip_indexing=True)
                                                    for feature in lexicon_features])

        if intent:
            if self.multilabel:
                fields['labels'] = MultiLabelField([S2_CATEGORIES[e] for e in intent], skip_indexing=True,
                                                   num_labels=len(S2_CATEGORIES))
            else:
                if not isinstance(intent, str):
                    raise TypeError(f"Undefined label format. Should be a string. Got: f'{intent}'")
                fields['labels'] = LabelField(intent)

        if citing_paper_year and cited_paper_year and \
                citing_paper_year > -1 and cited_paper_year > -1:
            year_diff = citing_paper_year - cited_paper_year
        else:
            year_diff = -1
        fields['year_diff'] = ArrayField(torch.Tensor([year_diff]))
        fields['citing_paper_id'] = MetadataField(citing_paper_id)
        fields['cited_paper_id'] = MetadataField(cited_paper_id)
        fields['citation_excerpt_index'] = MetadataField(citation_excerpt_index)
        fields['citation_id'] = MetadataField(f"{citing_paper_id}>{cited_paper_id}")
        return Instance(fields)

    @classmethod
    def from_params(cls, params: Params) -> 'SciciteDatasetReader':
        lazy = params.pop('lazy', False)
        tokenizer = Tokenizer.from_params(params.pop('tokenizer', {}))
        use_lexicon_features = params.pop_bool("use_lexicon_features", False)
        use_sparse_lexicon_features = params.pop_bool("use_sparse_lexicon_features", False)
        multilabel = params.pop_bool("multilabel")
        with_elmo = params.pop_bool("with_elmo", False)
        reader_format = params.pop("reader_format", 'flat')
        params.assert_empty(cls.__name__)
        return cls(lazy=lazy, tokenizer=tokenizer,
                   use_lexicon_features=use_lexicon_features,
                   use_sparse_lexicon_features=use_sparse_lexicon_features,
                   multilabel=multilabel,
                   with_elmo=with_elmo,
                   reader_format=reader_format)


ImportError: cannot import name 'PretrainedTransformerIndexer'

In [43]:
citation_text = "this is a happy pancake (123)!"
citation_tokens = WordTokenizer().tokenize(citation_text)
special_token = "<NUM>"
citation_tokens_preprocessed = [special_token if x.like_num else x for x in citation_tokens]

In [44]:
# if self.convert_num:
# citation_tokens = [self.NUM_TOKEN if x.like_num else x for x in citation_tokens]
processed_text = ""
for word in citation_text.split():
    processed_text += special_token if word.isdigit() else word
    processed_text += " "

citation_text = processed_text
citation_tokens = WordTokenizer.tokenize(citation_text)

TypeError: tokenize() missing 1 required positional argument: 'text'

In [45]:
citation_text.split()

['this', 'is', 'a', 'happy', 'pancake', '(123)!']

In [46]:
token = nlp("@@NUM@@")[0]
# token = doc[0]
token

NameError: name 'nlp' is not defined

In [47]:
# SingleIdTokenIndexer --> one embedding for each word
_token_indexers = {"elmo": ELMoTokenCharactersIndexer(),
                                    "tokens": SingleIdTokenIndexer(), "bert": PretrainedTransformerIndexer()}
fields = {
    'citation_text': TextField(citation_tokens, _token_indexers),
}

In [48]:
from scicite.training.vocabulary_multitask import VocabularyMultitask

params = Params.from_file('experiment_configs/custom_config.json', "")
serialization_dir = './runs/test'
file_friendly_logging = False
recover = False
vocab = VocabularyMultitask.from_params(
        params.pop("vocabulary", {}),
        (instance for key, dataset in all_datasets.items()
         for instance in dataset
         if key in datasets_for_vocab_creation),
        instances_aux=vocab_instances_aux
)

In [49]:
fields['citation_text'].as_tensor

<bound method TextField.as_tensor of <allennlp.data.fields.text_field.TextField object at 0x7f12eb487eb8>>

In [52]:
token_indices = fields['citation_text'].as_tensor(fields['citation_text'].get_padding_lengths()).get("tokens").detach().cpu().numpy()
token_indices

ConfigurationError: 'You must call .index(vocabulary) on a field before determining padding lengths.'

In [51]:
vocab._index_to_token['tokens'][9526]

'undesirable'

## TextFieldEmbedder

In [15]:
from allennlp.modules import FeedForward, Seq2VecEncoder, Seq2SeqEncoder, TextFieldEmbedder, Embedding, TimeDistributed
params = Params.from_file('experiment_configs/custom_config.json', "")
params

<allennlp.common.params.Params at 0x7fc90067c160>

In [16]:
# elmo = {
#       "tokens": {
#         "type": "embedding",
#         "pretrained_file": "/home/kanhon/Desktop/kanhon/NUS Computing/CS4248/github_dir/cs4248/scaffold/Project/sci-cite/scicite/pretrained_weights/GoogleNews-vectors-negative300.txt",
#         "embedding_dim": 300,
#         "trainable": "false"
#       },
#       "elmo": {
#         "type": "elmo_token_embedder",
#         "options_file": "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json",
#         "weight_file": "https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5",
#         "do_layer_norm": "true",
#         "dropout": 0.5
#       }
# }
text_field_embedder = TextFieldEmbedder.from_params(params.pop('model').pop("text_field_embedder"), vocab=vocab)

NameError: name 'vocab' is not defined

In [17]:
text_field_embedder

NameError: name 'text_field_embedder' is not defined

In [None]:
# from allennlp.models import *
# print(pretrained.get_pretrained_models())

ModuleNotFoundError: No module named 'transformers'

## Model

In [78]:
import operator
from copy import deepcopy
from distutils.version import StrictVersion
from typing import Dict, Optional

import allennlp
import numpy as np
import torch
import torch.nn.functional as F
from allennlp.common import Params
from allennlp.data import Instance
from allennlp.data import Vocabulary
from allennlp.data.dataset import Batch
from allennlp.data.fields import TextField, LabelField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.models.model import Model
from allennlp.modules import FeedForward, Seq2VecEncoder, Seq2SeqEncoder, TextFieldEmbedder, Embedding, TimeDistributed
from allennlp.nn import InitializerApplicator, RegularizerApplicator
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy, F1Measure
from overrides import overrides
from torch.nn import Parameter, Linear

from scicite.constants import  Scicite_Format_Nested_Jsonlines


import torch.nn as nn



In [90]:

# @Model.register("scaffold_bilstm_attention_classifier")
class ScaffoldBilstmAttentionClassifier1(Model):
    """
    This ``Model`` performs text classification for citation intents.  We assume we're given a
    citation text, and we predict some output label.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 citation_text_encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 classifier_feedforward_2: FeedForward,
                 classifier_feedforward_3: FeedForward,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None,
                 report_auxiliary_metrics: bool = False,
                 predict_mode: bool = False,
                 ) -> None:
        """
        Additional Args:
            lexicon_embedder_params: parameters for the lexicon attention model
            use_sparse_lexicon_features: whether to use sparse (onehot) lexicon features
            multilabel: whether the classification is multi-label
            data_format: s2 or jurgens
            report_auxiliary_metrics: report metrics for aux tasks
            predict_mode: predict unlabeled examples
        """
        super(ScaffoldBilstmAttentionClassifier1, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.num_classes_sections = self.vocab.get_vocab_size("section_labels")
        self.num_classes_cite_worthiness = self.vocab.get_vocab_size("cite_worthiness_labels")
        self.citation_text_encoder = citation_text_encoder
        self.classifier_feedforward = classifier_feedforward
        self.classifier_feedforward_2 = classifier_feedforward_2
        self.classifier_feedforward_3 = classifier_feedforward_3

        self.label_accuracy = CategoricalAccuracy()
        self.label_f1_metrics = {}
        self.label_f1_metrics_sections = {}
        self.label_f1_metrics_cite_worthiness = {}
        # for i in range(self.num_classes):
        #     self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] =\
        #         F1Measure(positive_label=i)

        for i in range(self.num_classes):
            self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] =\
                F1Measure(positive_label=i)
        for i in range(self.num_classes_sections):
            self.label_f1_metrics_sections[vocab.get_token_from_index(index=i, namespace="section_labels")] =\
                F1Measure(positive_label=i)
        for i in range(self.num_classes_cite_worthiness):
            self.label_f1_metrics_cite_worthiness[vocab.get_token_from_index(index=i, namespace="cite_worthiness_labels")] =\
                F1Measure(positive_label=i)
        self.loss = torch.nn.CrossEntropyLoss()

        self.attention_seq2seq = Attention(citation_text_encoder.get_output_dim())

        self.report_auxiliary_metrics = report_auxiliary_metrics
        self.predict_mode = predict_mode

        initializer(self)

    @overrides
    def forward(self,
                citation_text: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                lexicon_features: Optional[torch.IntTensor] = None,
                year_diff: Optional[torch.Tensor] = None,
                citing_paper_id: Optional[str] = None,
                cited_paper_id: Optional[str] = None,
                citation_excerpt_index: Optional[str] = None,
                citation_id: Optional[str] = None,
                section_label: Optional[torch.Tensor] = None,
                is_citation: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the model
        Args:
            citation_text: citation text of shape (batch, sent_len, embedding_dim)
            labels: labels
            lexicon_features: lexicon sparse features (batch, lexicon_feature_len)
            year_diff: difference between cited and citing years
            citing_paper_id: id of the citing paper
            cited_paper_id: id of the cited paper
            citation_excerpt_index: index of the excerpt
            citation_id: unique id of the citation
            section_label: label of the section
            is_citation: citation worthiness label
        """
        # pylint: disable=arguments-differ
        citation_text_embedding = self.text_field_embedder(citation_text)
        print("citation_text_embedding", citation_text_embedding.shape)
        # print("citation_text", citation_text) # {'elmo': tensor, 'tokens': tensor}
        # tokens are converted into numbers using vocab
        # print("citation_text_embedding", citation_text_embedding) # tensor
        citation_text_mask = util.get_text_field_mask(citation_text)

        # shape: [batch, sent, output_dim]
        encoded_citation_text = self.citation_text_encoder(citation_text_embedding, citation_text_mask)

        # shape: [batch, output_dim]
        attn_dist, encoded_citation_text = self.attention_seq2seq(encoded_citation_text, return_attn_distribution=True)

        # In training mode, labels are the citation intents
        # If in predict_mode, predict the citation intents
        if labels is not None:
            logits = self.classifier_feedforward(encoded_citation_text)
            class_probs = F.softmax(logits, dim=1)

            output_dict = {"logits": logits}

            loss = self.loss(logits, labels)
            output_dict["loss"] = loss

            # compute F1 per label
            for i in range(self.num_classes):
                metric = self.label_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="labels")]
                metric(class_probs, labels)
            output_dict['labels'] = labels

        if section_label is not None:  # this is the first scaffold task
            logits = self.classifier_feedforward_2(encoded_citation_text)
            class_probs = F.softmax(logits, dim=1)
            output_dict = {"logits": logits}
            loss = self.loss(logits, section_label)
            output_dict["loss"] = loss
            for i in range(self.num_classes_sections):
                metric = self.label_f1_metrics_sections[self.vocab.get_token_from_index(index=i, namespace="section_labels")]
                metric(logits, section_label)

        if is_citation is not None:  # second scaffold task
            logits = self.classifier_feedforward_3(encoded_citation_text)
            class_probs = F.softmax(logits, dim=1)
            output_dict = {"logits": logits}
            loss = self.loss(logits, is_citation)
            output_dict["loss"] = loss
            for i in range(self.num_classes_cite_worthiness):
                metric = self.label_f1_metrics_cite_worthiness[
                    self.vocab.get_token_from_index(index=i, namespace="cite_worthiness_labels")]
                metric(logits, is_citation)

        if self.predict_mode:
            logits = self.classifier_feedforward(encoded_citation_text)
            class_probs = F.softmax(logits, dim=1)
            output_dict = {"logits": logits}

        output_dict['citing_paper_id'] = citing_paper_id
        output_dict['cited_paper_id'] = cited_paper_id
        output_dict['citation_excerpt_index'] = citation_excerpt_index
        output_dict['citation_id'] = citation_id
        output_dict['attn_dist'] = attn_dist  # also return attention distribution for analysis
        output_dict['citation_text'] = citation_text['tokens']
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        class_probabilities = F.softmax(output_dict['logits'], dim=-1)
        predictions = class_probabilities.cpu().data.numpy()
        argmax_indices = np.argmax(predictions, axis=-1)
        labels = [self.vocab.get_token_from_index(x, namespace="labels")
                 for x in argmax_indices]
        output_dict['probabilities'] = class_probabilities
        output_dict['positive_labels'] = labels
        output_dict['prediction'] = labels
        citation_text = []
        for batch_text in output_dict['citation_text']:
            citation_text.append([self.vocab.get_token_from_index(token_id.item()) for token_id in batch_text])
        output_dict['citation_text'] = citation_text
        output_dict['all_labels'] = [self.vocab.get_index_to_token_vocabulary(namespace="labels")
                                     for _ in range(output_dict['logits'].shape[0])]
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metric_dict = {}

        sum_f1 = 0.0
        for name, metric in self.label_f1_metrics.items():
            metric_val = metric.get_metric(reset)
            metric_dict[name + '_P'] = metric_val[0]
            metric_dict[name + '_R'] = metric_val[1]
            metric_dict[name + '_F1'] = metric_val[2]
            if name != 'none':  # do not consider `none` label in averaging F1
                sum_f1 += metric_val[2]

        names = list(self.label_f1_metrics.keys())
        total_len = len(names) if 'none' not in names else len(names) - 1
        average_f1 = sum_f1 / total_len
        # metric_dict['combined_metric'] = (accuracy + average_f1) / 2
        metric_dict['average_F1'] = average_f1

        if self.report_auxiliary_metrics:
            sum_f1 = 0.0
            for name, metric in self.label_f1_metrics_sections.items():
                metric_val = metric.get_metric(reset)
                metric_dict['aux-sec--' + name + '_P'] = metric_val[0]
                metric_dict['aux-sec--' + name + '_R'] = metric_val[1]
                metric_dict['aux-sec--' + name + '_F1'] = metric_val[2]
                if name != 'none':  # do not consider `none` label in averaging F1
                    sum_f1 += metric_val[2]
            names = list(self.label_f1_metrics_sections.keys())
            total_len = len(names) if 'none' not in names else len(names) - 1
            average_f1 = sum_f1 / total_len
            # metric_dict['combined_metric'] = (accuracy + average_f1) / 2
            metric_dict['aux-sec--' + 'average_F1'] = average_f1

            sum_f1 = 0.0
            for name, metric in self.label_f1_metrics_cite_worthiness.items():
                metric_val = metric.get_metric(reset)
                metric_dict['aux-worth--' + name + '_P'] = metric_val[0]
                metric_dict['aux-worth--' + name + '_R'] = metric_val[1]
                metric_dict['aux-worth--' + name + '_F1'] = metric_val[2]
                if name != 'none':  # do not consider `none` label in averaging F1
                    sum_f1 += metric_val[2]
            names = list(self.label_f1_metrics_cite_worthiness.keys())
            total_len = len(names) if 'none' not in names else len(names) - 1
            average_f1 = sum_f1 / total_len
            # metric_dict['combined_metric'] = (accuracy + average_f1) / 2
            metric_dict['aux-worth--' + 'average_F1'] = average_f1

        return metric_dict

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'ScaffoldBilstmAttentionClassifier':
        with_elmo = params.pop_bool("with_elmo", False)
        if with_elmo:
            embedder_params = params.pop("elmo_text_field_embedder")
        else:
            embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(embedder_params, vocab=vocab)
        # citation_text_encoder = Seq2VecEncoder.from_params(params.pop("citation_text_encoder"))
        citation_text_encoder = Seq2SeqEncoder.from_params(params.pop("citation_text_encoder"))
        classifier_feedforward = FeedForward.from_params(params.pop("classifier_feedforward"))
        classifier_feedforward_2 = FeedForward.from_params(params.pop("classifier_feedforward_2"))
        classifier_feedforward_3 = FeedForward.from_params(params.pop("classifier_feedforward_3"))

        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))

        use_lexicon = params.pop_bool("use_lexicon_features", False)
        use_sparse_lexicon_features = params.pop_bool("use_sparse_lexicon_features", False)
        data_format = params.pop('data_format')

        report_auxiliary_metrics = params.pop_bool("report_auxiliary_metrics", False)

        predict_mode = params.pop_bool("predict_mode", False)
        print(f"pred mode: {predict_mode}")

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   citation_text_encoder=citation_text_encoder,
                   classifier_feedforward=classifier_feedforward,
                   classifier_feedforward_2=classifier_feedforward_2,
                   classifier_feedforward_3=classifier_feedforward_3,
                   initializer=initializer,
                   regularizer=regularizer,
                   report_auxiliary_metrics=report_auxiliary_metrics,
                   predict_mode=predict_mode)


def new_parameter(*size):
    out = Parameter(torch.FloatTensor(*size))
    torch.nn.init.xavier_normal_(out)
    return out


class Attention(nn.Module):
    """ Simple multiplicative attention"""
    def __init__(self, attention_size):
        super(Attention, self).__init__()
        self.attention = new_parameter(attention_size, 1)

    def forward(self, x_in, reduction_dim=-2, return_attn_distribution=False):
        """
        return_attn_distribution: if True it will also return the original attention distribution

        this reduces the one before last dimension in x_in to a weighted sum of the last dimension
        e.g., x_in.shape == [64, 30, 100] -> output.shape == [64, 100]
        Usage: You have a sentence of shape [batch, sent_len, embedding_dim] and you want to
            represent sentence to a single vector using attention [batch, embedding_dim]

        Here we use it to aggregate the lexicon-aware representation of the sentence
        In two steps we convert [batch, sent_len, num_words_in_category, num_categories] into [batch, num_categories]
        """
        # calculate attn weights
        attn_score = torch.matmul(x_in, self.attention).squeeze()
        # add one dimension at the end and get a distribution out of scores
        attn_distrib = F.softmax(attn_score.squeeze(), dim=-1).unsqueeze(-1)
        scored_x = x_in * attn_distrib
        weighted_sum = torch.sum(scored_x, dim=reduction_dim)
        if return_attn_distribution:
            return attn_distrib.reshape(x_in.shape[0], -1), weighted_sum
        else:
            return weighted_sum


In [91]:
params = Params.from_file('experiment_configs/custom_config.json', "")

In [92]:
model = ScaffoldBilstmAttentionClassifier1.from_params(vocab=vocab, params=params.pop('model'))

BasicTextFieldEmbedder(
  (token_embedder_elmo): ElmoTokenEmbedder(
    (_elmo): Elmo(
      (_elmo_lstm): _ElmoBiLm(
        (_token_embedder): _ElmoCharacterEncoder(
          (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
          (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
          (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
          (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
          (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
          (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
          (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
          (_highways): Highway(
            (_layers): ModuleList(
              (0): Linear(in_features=2048, out_features=4096, bias=True)
              (1): Linear(in_features=2048, out_features=4096, bias=True)
            )
          )
          (_projection): Linear(in_features=2048, out_features=512, bias=True)
        )


In [108]:
model.cuda(0)
batch = util.move_to_device(batch, 0)
output_dict = model(**batch)

In [144]:
# pylint: disable=arguments-differ
citation_text = batch['citation_text'] 
citation_text_embedding = model.text_field_embedder(citation_text)
print("citation_text_embedding", citation_text_embedding.shape)
# print("citation_text", citation_text) # {'elmo': tensor, 'tokens': tensor}
# tokens are converted into numbers using vocab
# print("citation_text_embedding", citation_text_embedding) # tensor
citation_text_mask = util.get_text_field_mask(citation_text)
print("citation_text_mask", citation_text_mask.shape)

# shape: [batch, sent, output_dim]
encoded_citation_text = model.citation_text_encoder(citation_text_embedding, citation_text_mask)
print("encoded_citation_text", encoded_citation_text.shape)

# shape: [batch, output_dim]
# attn_dist, encoded_citation_text = model.attention_seq2seq(encoded_citation_text, return_attn_distribution=True)
# print("encoded_citation_text", encoded_citation_text.shape)

# In training mode, labels are the citation intents

In [145]:
batch['citation_text']

{'elmo': tensor([[[259,  87, 106,  ..., 261, 261, 261],
         [259,  66, 260,  ..., 261, 261, 261],
         [259, 116, 118,  ..., 261, 261, 261],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[259,  71, 112,  ..., 261, 261, 261],
         [259,  72, 118,  ..., 261, 261, 261],
         [259,  98, 111,  ..., 261, 261, 261],
         ...,
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   0]],

        [[259,  72, 106,  ..., 261, 261, 261],
         [259, 117, 105,  ..., 261, 261, 261],
         [259, 108, 111,  ..., 261, 261, 261],
         ...,
         [259, 117, 105,  ..., 261, 261, 261],
         [259, 101, 106,  ..., 261, 261, 261],
         [  0,   0,   0,  ...,   0,   0,   0]],

        ...,

        [[259,  85, 105,  ..., 261, 261, 261],
         [259, 115, 102,  ..., 261, 

In [155]:
encoded_citation_text[:, -1, :].shape

torch.Size([16, 200])

In [136]:
params1={'model': {
      "type": "gru",
      "bidirectional": 'true',
      "input_size": 1324,
      "hidden_size": 100,
      "num_layers": 2,
      "dropout": 0.3
    }
       }
Model.from_params(vocab=vocab, params=params1.pop('model'))

AttributeError: 'dict' object has no attribute 'pop_choice'

## Remove stopwords

In [7]:
import nltk
from nltk.corpus import stopwords
from allennlp.data.tokenizers import Tokenizer, WordTokenizer

 
nltk.download('stopwords')
sw = stopwords.words('english')

[nltk_data] Downloading package stopwords to /home/kanhon/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [16]:
st = '''
Activated Are PBMC are the basis of the standard PBMC blast assay for HIV-1 neutralization, whereas the various GHOST and HeLa cell lines have all been used in neutralization assays (42, 66)
'''

In [17]:
st_t = WordTokenizer().tokenize(st)

In [18]:
[x for x in st_t if not x.is_stop]

[Activated,
 Are,
 PBMC,
 basis,
 standard,
 PBMC,
 blast,
 assay,
 HIV-1,
 neutralization,
 ,,
 GHOST,
 HeLa,
 cell,
 lines,
 neutralization,
 assays,
 (,
 42,
 ,,
 66,
 )]

## Dataset remove citation brackets

In [3]:
import re
import json
import pandas as pd

In [7]:

with open(f"./scicite_data/train.jsonl", encoding="utf8") as f:
    lines = f.read().splitlines()
    lines = [json.loads(x) for x in lines]
df_train = pd.DataFrame(lines)

In [15]:
samples = df_train['string'].iloc[:10].tolist()

In [1]:
sample = "Previous empirical analyses of subnational consumer subsidies found that a majority of solar adoption was attributable to subsidies (e.g., Hughes and Podolefsky, 2015; Burr, 2016; De Groote and Verboven, 2016; Gillingham and Tsvetanov, 2017)."

In [2]:
sample

'Previous empirical analyses of subnational consumer subsidies found that a majority of solar adoption was attributable to subsidies (e.g., Hughes and Podolefsky, 2015; Burr, 2016; De Groote and Verboven, 2016; Gillingham and Tsvetanov, 2017).'

In [38]:
for sample in samples: 
    print(sample)
    # print(re.findall("[(](\D+\d{4})*?[)]", sample))
    # print(re.findall("([\[]\d+([,-]?\s*\W*\d+)*[\]])", sample))
    print(re.sub("[\(\[](\D+\d{4})*?[\)\]]", "@@@CITE@@@", sample))
    print(re.sub("([\[\(]\d+([,-]?\s*\W*\d+)*[\]\)])", "@@@CITE@@@", sample))

However, how frataxin interacts with the Fe-S cluster biosynthesis components remains unclear as direct one-to-one interactions with each component were reported (IscS [12,22], IscU/Isu1 [6,11,16] or ISD11/Isd11 [14,15]).
However, how frataxin interacts with the Fe-S cluster biosynthesis components remains unclear as direct one-to-one interactions with each component were reported (IscS [12,22], IscU/Isu1 [6,11,16] or ISD11/Isd11 [14,15]).
However, how frataxin interacts with the Fe-S cluster biosynthesis components remains unclear as direct one-to-one interactions with each component were reported (IscS @@@CITE@@@, IscU/Isu1 @@@CITE@@@ or ISD11/Isd11 @@@CITE@@@).
In the study by Hickey et al. (2012), spikes were sampled from the field at the point of physiological
robinson et al.: genomic regions influencing root traits in barley 11 of 13
maturity, dried, grain threshed by hand, and stored at −20C to preserve grain dormancy before germination testing.
In the study by Hickey et al. (20

## JSONL

In [30]:
import json
import pandas as pd

In [31]:
with open(f"./scicite_data/dev.jsonl", encoding="utf8") as f:
    lines = f.read().splitlines()
    lines = [json.loads(x) for x in lines]

In [32]:
df_dev = pd.DataFrame(lines)

In [33]:
df_dev

Unnamed: 0,source,citeEnd,sectionName,citeStart,string,label,label2,citingPaperId,citedPaperId,isKeyCitation,id,unique_id,excerpt_index,label_confidence,label2_confidence
0,explicit,68.0,Discussion,64.0,These results are in contrast with the finding...,result,supportive,8f1fbe460a901d994e9b81d69f77bfbe32719f4c,5e413c7872f5df231bf4a4f694504384560e98ca,False,8f1fbe460a901d994e9b81d69f77bfbe32719f4c>5e413...,8f1fbe460a901d994e9b81d69f77bfbe32719f4c>5e413...,0,,
1,explicit,241.0,Discussion,222.0,…nest burrows in close proximity of one anothe...,background,,d9f3207db0c79a3b154f3875c9760cc6b056904b,2cc6ff899bf17666ad35893524a4d61624555ed7,False,d9f3207db0c79a3b154f3875c9760cc6b056904b>2cc6f...,d9f3207db0c79a3b154f3875c9760cc6b056904b>2cc6f...,10,0.7337,
2,explicit,94.0,. 6 Discussion,71.0,This is clearly in contrast to the results of ...,result,supportive,226f798d30e5523c5b9deafb826ddb04d47c11dc,,False,226f798d30e5523c5b9deafb826ddb04d47c11dc>None,226f798d30e5523c5b9deafb826ddb04d47c11dc>None_0,0,,
3,explicit,170.0,,148.0,"…in a subset of alcoholics (Chen et al., 2004;...",background,,59dba7cd80edcce831d20b35f9eb597bba290154,273996fbf99465211eb8306abe8c56c5835f332e,False,59dba7cd80edcce831d20b35f9eb597bba290154>27399...,59dba7cd80edcce831d20b35f9eb597bba290154>27399...,0,1.0000,
4,explicit,89.0,DISCUSSION,85.0,This result is consistent with the conclusions...,result,not_supportive,0640f6e098a9d241cd680473e8705357ae101e04,e33da0584b8db37816d510fd9ba7c1216858fd5f,False,0640f6e098a9d241cd680473e8705357ae101e04>e33da...,0640f6e098a9d241cd680473e8705357ae101e04>e33da...,0,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
911,explicit,99.0,Discussion,71.0,Our results are consistent with those of a pre...,result,not_supportive,d9ef5d1cb543d0f1330908b36f7106f23cbad404,ae7c65d4d7710bf5f36309faea886648f9a51d4a,False,d9ef5d1cb543d0f1330908b36f7106f23cbad404>ae7c6...,d9ef5d1cb543d0f1330908b36f7106f23cbad404>ae7c6...,0,,
912,explicit,136.0,1. Introduction,129.0,Some of these peptides act as neurotoxins on t...,background,,d937bc4d0722d5b45366b3c4dfde4732224bc048,9200a75534f44836ca7651c9d63d11b884947fa6,True,d937bc4d0722d5b45366b3c4dfde4732224bc048>9200a...,d937bc4d0722d5b45366b3c4dfde4732224bc048>9200a...,4,1.0000,
913,explicit,150.0,4. Discussion,144.0,"Therefore, despite an apparent higher number o...",background,,3f50975c58d861e4fbd3b4fd065f0658b1aa1e10,d16a1d95e6947da69797bb0cb59148057174e35a,True,3f50975c58d861e4fbd3b4fd065f0658b1aa1e10>d16a1...,3f50975c58d861e4fbd3b4fd065f0658b1aa1e10>d16a1...,0,1.0000,
914,explicit,28.0,INTRODUCTION,13.0,According to Xu et al (2011) the factors that ...,method,,d776ade6bb4898c032b971d2cec145976408e838,22e3889f93c19c15b746c1339ce9d7439ccb632e,False,d776ade6bb4898c032b971d2cec145976408e838>22e38...,d776ade6bb4898c032b971d2cec145976408e838>22e38...,0,,


In [73]:
output_path = f"./output_pretrained_test.txt"
with open(output_path, encoding="utf8") as f:
    output_pretrained = f.read()
    output_pretrained = "[" + output_pretrained + "]"
    # output_pretrained = [json.loads(x) for x in output_pretrained]
    output_pretrained = json.loads(output_pretrained.replace('\n', ', ')[:-3]+']')

In [74]:
df_dev_predict = pd.DataFrame(output_pretrained)
df_dev_predict

Unnamed: 0,citingPaperId,citedPaperId,prediction,unique_id,label
0,2c6797dab4c118cb73197f65ba39dacc99ac743d,95c37bc99982d33873fd141ee00857160fd717a0,method,2c6797dab4c118cb73197f65ba39dacc99ac743d>95c37...,background
1,fa7145adc9f8cfb8af7a189d9040c13c84ced094,20e23b4f76761d246a7c3b00b80e139e2008f77d,result,fa7145adc9f8cfb8af7a189d9040c13c84ced094>20e23...,result
2,98a8d8c0c5dae246720d4f339b88e8a9f44e3002,bd222c7ec83dadefba513738290b3624f6dd6b21,result,98a8d8c0c5dae246720d4f339b88e8a9f44e3002>bd222...,background
3,aeb178ef1910a61152cd74209c28641199c82855,754c04953c261072fa367f4104e3deff082d9484,method,aeb178ef1910a61152cd74209c28641199c82855>754c0...,method
4,e4d2591ac3bb65e2ec59f092884a7b15b8018592,f0fb468a54fe8021bc7986a1618222c4fcd16df4,method,e4d2591ac3bb65e2ec59f092884a7b15b8018592>f0fb4...,background
...,...,...,...,...,...
1856,3cf9c7cd259a356839f42ecf143af3a8f6ef8b54,74cbd6d0eeb051b036f806d8a86c3a85859f9d7d,result,3cf9c7cd259a356839f42ecf143af3a8f6ef8b54>74cbd...,result
1857,e609824e9ea6bee5aca817238d81d1cdd6b462ad,f7bfdcf8892a561b6030ed541924551fb78acf1f,background,e609824e9ea6bee5aca817238d81d1cdd6b462ad>f7bfd...,background
1858,19317f7188bc6ecad985c46277969c0ac03dbcf8,0c86f0d577f04534edc14a509a68ae80ce6fbb74,method,19317f7188bc6ecad985c46277969c0ac03dbcf8>0c86f...,method
1859,62ac94ab9227b84f1317edad1b6312e311981961,df5084196ea93af9250fae27c981ea3d7959599d,background,62ac94ab9227b84f1317edad1b6312e311981961>df508...,background


In [75]:
# df_dev_predict.dropna(subset=['prediction'])
df_dev_predict = df_dev_predict[~(df_dev_predict['prediction'] == '')]
df_dev_predict

Unnamed: 0,citingPaperId,citedPaperId,prediction,unique_id,label
0,2c6797dab4c118cb73197f65ba39dacc99ac743d,95c37bc99982d33873fd141ee00857160fd717a0,method,2c6797dab4c118cb73197f65ba39dacc99ac743d>95c37...,background
1,fa7145adc9f8cfb8af7a189d9040c13c84ced094,20e23b4f76761d246a7c3b00b80e139e2008f77d,result,fa7145adc9f8cfb8af7a189d9040c13c84ced094>20e23...,result
2,98a8d8c0c5dae246720d4f339b88e8a9f44e3002,bd222c7ec83dadefba513738290b3624f6dd6b21,result,98a8d8c0c5dae246720d4f339b88e8a9f44e3002>bd222...,background
3,aeb178ef1910a61152cd74209c28641199c82855,754c04953c261072fa367f4104e3deff082d9484,method,aeb178ef1910a61152cd74209c28641199c82855>754c0...,method
4,e4d2591ac3bb65e2ec59f092884a7b15b8018592,f0fb468a54fe8021bc7986a1618222c4fcd16df4,method,e4d2591ac3bb65e2ec59f092884a7b15b8018592>f0fb4...,background
...,...,...,...,...,...
1856,3cf9c7cd259a356839f42ecf143af3a8f6ef8b54,74cbd6d0eeb051b036f806d8a86c3a85859f9d7d,result,3cf9c7cd259a356839f42ecf143af3a8f6ef8b54>74cbd...,result
1857,e609824e9ea6bee5aca817238d81d1cdd6b462ad,f7bfdcf8892a561b6030ed541924551fb78acf1f,background,e609824e9ea6bee5aca817238d81d1cdd6b462ad>f7bfd...,background
1858,19317f7188bc6ecad985c46277969c0ac03dbcf8,0c86f0d577f04534edc14a509a68ae80ce6fbb74,method,19317f7188bc6ecad985c46277969c0ac03dbcf8>0c86f...,method
1859,62ac94ab9227b84f1317edad1b6312e311981961,df5084196ea93af9250fae27c981ea3d7959599d,background,62ac94ab9227b84f1317edad1b6312e311981961>df508...,background


In [76]:
# df_dev_merged = df_dev.merge(df_dev_predict, how='left', left_on="unique_id", right_on="unique_id").dropna(subset=['prediction'])

In [77]:
# def compare_res(df_row):
#     if df_row["label"] == df_row["prediction"]:
#         return True
#     else:
#         return False

In [78]:
# df_dev_merged["correct_pred"] = df_dev_merged.apply(compare_res, axis=1)

In [79]:
# df_dev_merged[df_dev_merged["correct_pred"] == True]["correct_pred"].count()

In [80]:
# df_dev_merged["correct_pred"].count()

In [81]:
from sklearn.metrics import classification_report, f1_score

In [82]:
print(classification_report(df_dev_predict['label'], df_dev_predict['prediction']))

              precision    recall  f1-score   support

  background       0.88      0.87      0.88       992
      method       0.88      0.81      0.85       605
      result       0.71      0.86      0.78       259

   micro avg       0.85      0.85      0.85      1856
   macro avg       0.82      0.85      0.83      1856
weighted avg       0.86      0.85      0.85      1856



In [83]:
print(f1_score(df_dev_predict['label'], df_dev_predict['prediction'], average='macro'))
# df_dev_predict['prediction'].unique()

0.8328905208191819


In [45]:
# df_dev_merged[df_dev_merged['prediction']!= df_dev_merged['prediction']]

## Experiment results

In [52]:
# lambda_pairs = [(0,0), (0.05,0.05), (0.1,0.1), (0.1, 0.2), (0.1, 0.3), (0.2, 0.2), (0.3, 0.3)]

# for lamb in lambda_pairs:
#     print(lamb)
#     with open(f"./runs/experiments-_{lamb[0]}_{lamb[1]}/metrics.json", encoding="utf8") as f:
#         metrics = f.read()
#         metrics = json.loads(metrics)
#     print(metrics['best_validation_average_F1'])

In [6]:
import glob
for fl in glob.glob("./runs/experiment*/metrics.json"):
    print(fl)
    with open(fl, encoding="utf8") as f:
        metrics = f.read()
        metrics = json.loads(metrics)
    try:
        print(metrics['best_validation_average_F1'])
    except:
        print(metrics['validation_average_F1'])

./runs/experiment-0.05-0.05-w2v-1/metrics.json
0.8178103926306495
./runs/experiment-0.05-0.05-citetokens/metrics.json
0.8195389041330493
./runs/experiments-_0.1_0.1/metrics.json
0.8131806035859791
./runs/experiments-_0.1_0.3/metrics.json
0.8077549709636348
./runs/experiments-_0.3_0.3/metrics.json
0.8186008842495505
./runs/experiments-_0_0/metrics.json
0.8273689041357627
./runs/experiment-0.05-0.05-numtokens2/metrics.json
0.8116501788703033
./runs/experiment-0.05-0.05-removecite/metrics.json
0.8194678580833311
./runs/experiment-0.05-0.05-numtokens/metrics.json
0.8332833985516181
./runs/experiment-0.05-0.05-elmo-lstm/metrics.json
0.8194471827306655
./runs/experiment-0.05-0.05-elmo-forwardgru/metrics.json
0.81613775692173
./runs/experiments-_0.2_0.2/metrics.json
0.8216554689057259
./runs/experiments-_0.05_0.05/metrics.json
0.8294245106130006
./runs/experiment-0.05-0.05-elmo-forwardgru-noattention/metrics.json
0.3827258501262281
./runs/experiment-0.05-0.05-no-elmo/metrics.json
0.7909244219

In [7]:
import glob
for fl in glob.glob("./runs/experiment*/metrics.json"):
    print(fl)
    with open(fl, encoding="utf8") as f:
        metrics = f.read()
        metrics = json.loads(metrics)
    print(metrics['test_average_F1'])

./runs/experiment-0.05-0.05-w2v-1/metrics.json
0.8310028720107511
./runs/experiment-0.05-0.05-citetokens/metrics.json
0.8204258964176881
./runs/experiments-_0.1_0.1/metrics.json
0.8203714018464888
./runs/experiments-_0.1_0.3/metrics.json
0.8050099186479746
./runs/experiments-_0.3_0.3/metrics.json
0.8320188112872264
./runs/experiments-_0_0/metrics.json
0.8231164879744503
./runs/experiment-0.05-0.05-numtokens2/metrics.json
0.8330661047322474
./runs/experiment-0.05-0.05-removecite/metrics.json
0.8277156189814104
./runs/experiment-0.05-0.05-numtokens/metrics.json
0.8064126960158635
./runs/experiment-0.05-0.05-elmo-lstm/metrics.json
0.8146276763446183
./runs/experiment-0.05-0.05-elmo-forwardgru/metrics.json
0.825328793589981
./runs/experiments-_0.2_0.2/metrics.json
0.8338795863810212
./runs/experiments-_0.05_0.05/metrics.json
0.8276466110976536
./runs/experiment-0.05-0.05-elmo-forwardgru-noattention/metrics.json
0.38035496254387585
./runs/experiment-0.05-0.05-no-elmo/metrics.json
0.76827555

In [10]:
# best_model_state_path = os.path.join(serialization_dir, 'best.th')
# best_model_state = torch.load(best_model_state_path)
# best_model = model
# best_model.load_state_dict(best_model_state)