In [None]:
# Upstream notebooks/tasks that must have run before this one
# (presumably consumed by whatever pipeline executes these notebooks -- confirm).
depends_on = [
    "preproc_jigsaw",
    "jigsaw_create_augmented_data",
]

In [None]:
# Reload edited Python modules automatically before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
# settings for seamlessly running on colab
import os
# Default to a local run, but leave an IS_COLAB value that is already present in
# the environment (e.g. exported before launching on Colab) untouched, instead of
# forcing "False" and requiring an edit to this cell.
os.environ.setdefault("IS_COLAB", "False")

In [None]:
# On Colab only: mount Google Drive so the ./gdrive/... data directory is reachable.
if os.environ["IS_COLAB"] == "True":
    from google.colab import drive
    drive.mount('/content/gdrive')

In [None]:
%%bash
# Colab-only setup: install the packages this notebook imports.
# NOTE(review): versions are unpinned, so results may drift between runs.
if [ "$IS_COLAB" = "True" ]; then
    pip install git+https://github.com/facebookresearch/fastText.git
    pip install torch
    pip install torchvision
    pip install allennlp
    pip install dnspython
fi

In [None]:
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd
from functools import partial
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util

import logging
logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

In [None]:
import time
from contextlib import contextmanager

class Config(dict):
    """A dict whose entries are mirrored as attributes (``cfg["lr"]`` and ``cfg.lr``).

    Use ``set`` to change a value afterwards; it updates both views.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for key in kwargs:
            setattr(self, key, kwargs[key])

    def set(self, key, val):
        """Store ``val`` under ``key`` as both a dict entry and an attribute."""
        self[key] = val
        setattr(self, key, val)

@contextmanager
def timer(name):
    """Context manager that prints ``[name] done in N s`` when its body finishes.

    The elapsed wall-clock time is rounded to whole seconds. Nothing is printed
    if the body raises.
    """
    started_at = time.time()
    yield
    elapsed = time.time() - started_at
    print(f'[{name}] done in {elapsed:.0f} s')
    
import functools
import traceback
import sys

def get_ref_free_exc_info():
    """Return ``sys.exc_info()`` with the locals of the traceback's frames cleared.

    Free traceback from references to locals/globals to avoid circular reference
    leading to gc.collect() unable to reclaim memory.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    traceback.clear_frames(exc_tb)
    return (exc_type, exc_value, exc_tb)

def gpu_mem_restore(func):
    """Decorator: reclaim GPU RAM if CUDA ran out of memory, or execution was interrupted.

    Any exception raised by ``func`` (including KeyboardInterrupt) is re-raised
    with its original type and arguments. Its traceback frames have their locals
    cleared first, so the memory those frames referenced can be freed.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except BaseException:  # deliberately broad: also covers KeyboardInterrupt
            _, exc_value, exc_tb = get_ref_free_exc_info()  # must!
            # Re-raise the original exception object. Rebuilding it as
            # ``type(exc)(exc)`` would wrap the message in a second layer, and
            # fails outright for exceptions whose constructors need several
            # arguments (e.g. UnicodeDecodeError).
            raise exc_value.with_traceback(exc_tb) from None
    return wrapper

In [None]:
# for papermill
# Default run parameters; papermill overrides them when the notebook is executed
# with parameters (presumably this cell is tagged "parameters" -- confirm).
testing = True  # when True, every CSV read is cut to its first 1000 rows
debugging = False
seed = 1
computational_batch_size = 256
batch_size = 256
lr = 0.001
epochs = 1
hidden_sz = 64
dataset = "jigsaw"  # subdirectory of the data root
n_classes = 6  # one output per entry of label_cols
max_seq_len = 512  # tokenizers truncate to this many tokens
download_data = False  # fetch missing files from S3 (not copied into `config`)
ft_model_path = "../data/jigsaw/ft_model.bin"
max_vocab_size = 300000
dropoute = 0.5
val_ratio = 0.05  # > 0 selects the train_wo_val.csv / val.csv split files
use_augmented = False
mixup_ratio = 0.2
bias_init = True
neg_splits = 3  # > 1 enables BiasedSampler negative down-sampling
model_type = "standard"  # "standard", or a name containing "bert" / "elmo"
run_id = None  # None -> RUN_ID is derived from the current time

In [None]:
# TODO: Can we make this play better with papermill?
# Bundle the parameters above into one Config (dict + attribute access).
# Later cells read these values through `config`; note that `download_data`
# is not included and is read directly as a global.
config = Config(
    testing=testing,
    debugging=debugging,
    seed=seed,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    ft_model_path=ft_model_path,
    max_vocab_size=max_vocab_size,
    dropoute=dropoute,
    val_ratio=val_ratio,
    use_augmented=use_augmented,
    bias_init=bias_init,
    neg_splits=neg_splits,
    model_type=model_type,
    mixup_ratio=mixup_ratio,
    run_id=run_id,
)

In [None]:
from allennlp.common.checks import ConfigurationError

In [None]:
# Only the plain model and BERT/ELMo variants (matched by substring) are supported.
if not (config.model_type == "standard"
        or "bert" in config.model_type
        or "elmo" in config.model_type):
    raise ConfigurationError(f"Invalid model type {config.model_type} given")

In [None]:
import datetime
now = datetime.datetime.now()
# Use an explicitly supplied run id; otherwise derive one from the launch time.
if config.run_id is None:
    RUN_ID = now.strftime("%m_%d_%H:%M:%S")
else:
    RUN_ID = config.run_id

In [None]:
# True when torch can see a CUDA device.
USE_GPU = torch.cuda.is_available()

In [None]:
# Data lives in ../data locally, or under the mounted Google Drive on Colab.
if os.environ["IS_COLAB"] != "True":
    DATA_ROOT = Path("../data") / config.dataset
else:
    DATA_ROOT = Path("./gdrive/My Drive/Colab_Workspace/Colab Notebooks/data") / config.dataset
    # Use Config.set so the attribute AND the dict entry change together;
    # plain attribute assignment would leave config["ft_model_path"] stale.
    config.set("ft_model_path", str(DATA_ROOT / "ft_model.bin"))

In [None]:
# Create the data directory from Python: the shell form `!mkdir -p {DATA_ROOT}`
# splits the Colab path ("My Drive", "Colab Notebooks") at its spaces.
DATA_ROOT.mkdir(parents=True, exist_ok=True)

In [None]:
import subprocess
# Fetch any missing input files from S3 (only when download_data is set).
if download_data:
    if config.val_ratio > 0.0:
        fnames = ["train_wo_val.csv", "test_proced.csv", "val.csv", "ft_model.bin"]
    else:
        fnames = ["train.csv", "test_proced.csv", "ft_model.bin"]
    # NOTE(review): the augmented reader later reads train_extra_interpolated.csv,
    # which is not in this list -- confirm whether it should be downloaded too.
    if config.use_augmented: fnames.append("train_extra.csv")
    for fname in fnames:
        if not (DATA_ROOT / fname).exists():
            # Pass an argument list (no shell) so a DATA_ROOT containing spaces,
            # like the Colab "My Drive" folder, reaches `aws` as one argument.
            # run() also waits for the process and closes its pipes.
            result = subprocess.run(
                ["aws", "s3", "cp", f"s3://nnfornlp/raw_data/jigsaw/{fname}", str(DATA_ROOT)],
                stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            )
            print(result.stdout)
            if result.returncode != 0:
                print(f"Failed to download {fname}: "
                      f"{result.stderr.decode(errors='replace')}")

In [None]:
# Show which input files are present (pure Python, safe for paths with spaces).
sorted(p.name for p in DATA_ROOT.iterdir())

Set the random seed manually so that results can be replicated.

In [None]:
import random

# Seed every RNG this notebook draws from: BiasedSampler permutes with np.random,
# the bucket iterator shuffles batches with `random`, and torch initialises weights.
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# Load Data

In [None]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader, StanfordSentimentTreeBankDatasetReader

### Prepare dataset

In [None]:
# Dataset name -> reader class, populated by the @register decorator.
reader_registry = {}

In [None]:
def register(name: str):
    """Decorator factory: store the decorated object in ``reader_registry[name]``.

    The decorated object itself is returned unchanged.
    """
    def decorator(obj: Callable):
        reader_registry[name] = obj
        return obj
    return decorator

In [None]:
# The six Jigsaw label columns, in the order used for label vectors.
label_cols = ["toxic", "severe_toxic", "obscene",
              "threat", "insult", "identity_hate"]

from enum import IntEnum
# Column-position enum: ColIdx.TOXIC == 0, ..., ColIdx.IDENTITY_HATE == 5.
ColIdx = IntEnum('ColIdx', {col.upper(): pos for pos, col in enumerate(label_cols)})

In [None]:
from allennlp.data.fields import TextField, SequenceLabelField, LabelField, MetadataField, ArrayField

@register("jigsaw")
class JigsawDatasetReader(DatasetReader):
    """Reads Jigsaw toxic-comment CSVs into AllenNLP instances.

    Each CSV row must provide ``comment_text``, ``id`` and every column in
    ``label_cols``. When ``config.testing`` is set, only the first 1000 rows are read.

    Args:
        tokenizer: maps a comment string to a list of token strings.
        token_indexers: indexers for the ``tokens`` field (defaults to a
            SingleIdTokenIndexer).
        max_seq_len: if not None, token lists are cut to this length.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[str], id: str,
                         labels: np.ndarray) -> Instance:
        """Build an instance with fields ``tokens``, ``id``, ``meta`` and ``label``.

        ``meta["lengths"]`` holds the character length of every token.
        """
        fields = {
            "tokens": TextField([Token(x) for x in tokens], self.token_indexers),
            "id": MetadataField(id),
            "meta": MetadataField({"lengths": np.array([len(t) for t in tokens])}),
            "label": ArrayField(array=labels),
        }
        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(1000)
        # Extract the columns once instead of using df.iterrows(): much faster,
        # and the label matrix keeps a numeric dtype rather than the object
        # dtype a mixed-type row Series produces.
        labels = df[label_cols].values
        for text, row_id, row_labels in zip(df["comment_text"], df["id"], labels):
            tokens = self.tokenizer(text)
            # Honour max_seq_len even when the tokenizer does not truncate
            # (e.g. the default whitespace split).
            if self.max_seq_len is not None:
                tokens = tokens[:self.max_seq_len]
            yield self.text_to_instance(tokens, row_id, row_labels)

### Prepare token handlers

In [None]:
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer

In [None]:
# Choose the token indexer for config.model_type plus a tokenizer whose output
# that indexer consumes; every tokenizer truncates its output.
def _make_spacy_tokenizer() -> Callable[[str], List[str]]:
    """Build a spaCy word tokenizer that keeps at most config.max_seq_len tokens.

    The SpacyWordSplitter is created once and reused for every comment, instead
    of being constructed anew on each call.
    """
    splitter = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)

    def tokenizer(x: str) -> List[str]:
        return [w.text for w in splitter.split_words(x)[:config.max_seq_len]]
    return tokenizer

if config.model_type == "standard":
    from allennlp.data.token_indexers import SingleIdTokenIndexer
    token_indexer = SingleIdTokenIndexer(
        lowercase_tokens=False,  # don't lowercase by default
    )
    tokenizer = _make_spacy_tokenizer()
elif "elmo" in config.model_type:
    from allennlp.data.token_indexers.elmo_indexer import ELMoTokenCharactersIndexer
    token_indexer = ELMoTokenCharactersIndexer()
    tokenizer = _make_spacy_tokenizer()
elif "bert" in config.model_type:
    from allennlp.data.token_indexers import PretrainedBertIndexer
    token_indexer = PretrainedBertIndexer(
        pretrained_model=config.model_type,
        max_pieces=config.max_seq_len,
        do_lowercase=True,
    )
    # Apparently the sequence has to be truncated here as well.
    # max_seq_len - 2 presumably leaves room for the [CLS]/[SEP] pieces -- confirm.
    def tokenizer(s: str):
        return token_indexer.wordpiece_tokenizer(s)[:config.max_seq_len - 2]

In [None]:
# Reader wired to the tokenizer/indexer pair chosen for the configured model type.
reader = JigsawDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer}
)

In [None]:
# Read every split with the same reader; a separate validation file is only used
# when a validation ratio is configured.
if config.val_ratio > 0.0:
    train_ds, val_ds, test_ds = map(reader.read, (DATA_ROOT / "train_wo_val.csv",
                                                  DATA_ROOT / "val.csv",
                                                  DATA_ROOT / "test_proced.csv"))
else:
    train_ds, test_ds = map(reader.read, (DATA_ROOT / "train.csv",
                                          DATA_ROOT / "test_proced.csv"))

In [None]:
if config.use_augmented:
    # TODO: Handle data leak for validation!
    # NOTE(review): instances come from train_extra_interpolated.csv, while the
    # download cell fetches train_extra.csv and the aux labels are read from
    # train_extra.csv -- confirm the two files are row-aligned.
    train_aug_ds = reader.read(DATA_ROOT / "train_extra_interpolated.csv")

In [None]:
# Sanity check: number of training instances (at most 1000 when config.testing).
len(train_ds)

### Prepare labels

In [None]:
# Label matrices (one row per training example, one column per label) used to
# build the negative-sampling mask; rows must line up with the instances read above.
if config.val_ratio > 0.0:
    train_labels = pd.read_csv(DATA_ROOT / "train_wo_val.csv")[label_cols].values
else:
    train_labels = pd.read_csv(DATA_ROOT / "train.csv")[label_cols].values
if config.testing: train_labels = train_labels[:len(train_ds), :]
if config.use_augmented:
    train_aux_labels = pd.read_csv(DATA_ROOT / "train_extra.csv")[label_cols].values
    # The augmented reader was capped independently, so truncate to its own length.
    if config.testing: train_aux_labels = train_aux_labels[:len(train_aug_ds), :]

### Prepare vocabulary

In [None]:
# Build the vocabulary from the training split only, capped at config.max_vocab_size.
vocab = Vocabulary.from_instances(train_ds, max_vocab_size=config.max_vocab_size)

### Prepare iterator

In [None]:
from allennlp.data.iterators import BucketIterator, DataIterator

In [None]:
from sklearn.model_selection import KFold

class Sampler:
    """Identity sampler: every pass uses the dataset unchanged."""
    def sample(self, ds: "List[Instance]") -> "List[Instance]":
        return ds

class BiasedSampler(Sampler):
    """Keeps every positive example and rotates through folds of the negatives.

    Each ``sample`` call returns all positives plus one of ``n_splits`` disjoint
    KFold folds of the negatives, shuffled together with ``np.random``. A new
    KFold pass starts every ``n_splits`` calls.

    Args:
        mask: one flag per example, truthy for positives (coerced to bool).
        n_splits: number of negative folds (KFold requires at least 2).
    """
    def __init__(self, mask: np.ndarray, n_splits: int):
        # Coerce to bool so ``~`` is a logical NOT. On an int 0/1 array it is a
        # bitwise NOT (-1/-2), which would make every example count as negative.
        self.mask = np.asarray(mask, dtype=bool)
        self.n_splits = n_splits
        self.pos = np.where(self.mask)[0]
        self.neg = np.where(~self.mask)[0]
        self._n_splits_iterated = 0

    def sample(self, ds: "List[Instance]") -> "List[Instance]":
        """Return the positives plus the next fold of negatives, in random order.

        ``ds`` must support integer indexing (e.g. a list of instances).
        """
        if self._n_splits_iterated % self.n_splits == 0:
            self.folds = KFold(n_splits=self.n_splits).split(self.neg)
        _, neg_idxs = next(self.folds)

        p = np.random.permutation(len(self.pos) + len(neg_idxs))
        smpl = np.r_[self.pos, self.neg[neg_idxs]][p]

        self._n_splits_iterated += 1
        return [ds[i] for i in smpl]

In [None]:
class ScoredSampler:
    """Keeps the ``ratio`` fraction of examples with the highest score.

    Scores default to the mask itself (1 where set, 0 elsewhere) and can be
    replaced with ``set_score``. It exposes the same ``sample(ds)`` method that
    the iterator calls on ``Sampler`` objects.

    Args:
        mask: one flag per example; its length fixes the dataset size.
        ratio: fraction of examples returned by every ``sample`` call.
    """
    def __init__(self, mask: np.ndarray, ratio: float):
        self.mask = mask
        self.ratio = ratio
        # The dataset size is the mask length.
        self.n_samples = int(len(self.mask) * self.ratio)
        self.score = mask.astype("int")

    def set_score(self, score: np.ndarray):
        """Replace the per-example scores (exactly one score per example)."""
        assert len(score) == len(self.mask)
        self.score = np.asarray(score)

    def sample(self, ds: "List[Instance]") -> "List[Instance]":
        """Sample top n targets sorted by score descending (ties keep dataset order)."""
        smpl = np.argsort(-self.score, kind="stable")[:self.n_samples]
        return [ds[i] for i in smpl]

In [None]:
import random
from collections import deque
from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.common.util import lazy_groups_of, add_noise_to_dict_values
from allennlp.data.dataset import Batch
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.vocabulary import Vocabulary

TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]  # pylint: disable=invalid-name

def sort_by_padding(instances: List[Instance],
                    sorting_keys: List[Tuple[str, str]],  # pylint: disable=invalid-sequence-index
                    vocab: Vocabulary,
                    padding_noise: float = 0.0) -> List[Instance]:
    """
    Return ``instances`` ordered by their padding lengths.

    ``sorting_keys`` is a list of ``(field_name, padding_key)`` tuples; the
    lengths they select are compared in the order given. When ``padding_noise``
    is positive, the lengths are perturbed first with ``add_noise_to_dict_values``.
    Every instance is indexed with ``vocab`` along the way.
    """
    def _lengths_key(instance: Instance) -> List[float]:
        # Index first so that get_padding_lengths sees token ids.
        instance.index_fields(vocab)
        lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
        if padding_noise > 0.0:
            lengths = {name: add_noise_to_dict_values(field_lengths, padding_noise)
                       for name, field_lengths in lengths.items()}
        return [lengths[name][key] for name, key in sorting_keys]

    keyed = [(_lengths_key(instance), instance) for instance in instances]
    keyed.sort(key=lambda pair: pair[0])
    return [instance for _, instance in keyed]

class CustomBucketIterator(DataIterator):
    """
    An iterator which by default, pads batches with respect to the maximum input lengths `per
    batch`. Additionally, you can provide a list of field names and padding keys which the dataset
    will be sorted by before doing this batching, causing inputs with similar length to be batched
    together, making computation more efficient (as less time is wasted on padded elements of the
    batch). Before batching, a ``sampler`` chooses which instances take part in each pass.

    Parameters
    ----------
    sorting_keys : List[Tuple[str, str]]
        To bucket inputs into batches, we want to group the instances by padding length, so that we
        minimize the amount of padding necessary per batch. In order to do this, we need to know
        which fields need what type of padding, and in what order.

        For example, ``[("sentence1", "num_tokens"), ("sentence2", "num_tokens"), ("sentence1",
        "num_token_characters")]`` would sort a dataset first by the "num_tokens" of the
        "sentence1" field, then by the "num_tokens" of the "sentence2" field, and finally by the
        "num_token_characters" of the "sentence1" field.  TODO(mattg): we should have some
        documentation somewhere that gives the standard padding keys used by different fields.
    padding_noise : float, optional (default=.1)
        When sorting by padding length, we add a bit of noise to the lengths, so that the sorting
        isn't deterministic.  This parameter determines how much noise we add, as a percentage of
        the actual padding value for each instance.
    biggest_batch_first : bool, optional (default=False)
        This is largely for testing, to see how large of a batch you can safely use with your GPU.
        This will let you try out the largest batch that you have in the data `first`, so that if
        you're going to run out of memory, you know it early, instead of waiting through the whole
        epoch to find out at the end that you're going to crash.

        Note that if you specify ``max_instances_in_memory``, the first batch will only be the
        biggest from among the first "max instances in memory" instances.
    batch_size : int, optional, (default = 32)
        The size of each batch of instances yielded when calling the iterator.
    instances_per_epoch : int, optional, (default = None)
        See :class:`BasicIterator`.
    max_instances_in_memory : int, optional, (default = None)
        See :class:`BasicIterator`.
    maximum_samples_per_batch : ``Tuple[str, int]``, (default = None)
        See :class:`BasicIterator`.
    negative_sample_rate : int, optional (default = None)
        Accepted for interface compatibility; currently unused.
    sampler : Sampler, optional (default = None)
        Its ``sample`` method is called with the instances at the start of every
        pass to choose (and order) the instances that get bucketed, e.g. a
        ``BiasedSampler`` for negative down-sampling. Samplers may index into the
        instances, so pass an indexable collection such as a list. When ``None``,
        the identity ``Sampler`` is used and every instance is kept.
    """

    def __init__(self,
                 sorting_keys: List[Tuple[str, str]],
                 padding_noise: float = 0.1,
                 biggest_batch_first: bool = False,
                 batch_size: int = 32,
                 instances_per_epoch: int = None,
                 max_instances_in_memory: int = None,
                 cache_instances: bool = False,
                 track_epoch: bool = False,
                 maximum_samples_per_batch: Tuple[str, int] = None,
                 negative_sample_rate: Optional[int]=None,
                 sampler: Sampler=None) -> None:
        if not sorting_keys:
            raise ConfigurationError("BucketIterator requires sorting_keys to be specified")

        super().__init__(cache_instances=cache_instances,
                         track_epoch=track_epoch,
                         batch_size=batch_size,
                         instances_per_epoch=instances_per_epoch,
                         max_instances_in_memory=max_instances_in_memory,
                         maximum_samples_per_batch=maximum_samples_per_batch)
        self._sorting_keys = sorting_keys
        self._padding_noise = padding_noise
        self._biggest_batch_first = biggest_batch_first
        # Use the caller's sampler; fall back to the identity sampler only when none is given.
        self.sampler = sampler if sampler is not None else Sampler()

    @overrides
    def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
        for instance_list in self._memory_sized_lists(self.sampler.sample(instances)):
            instance_list = sort_by_padding(instance_list,
                                            self._sorting_keys,
                                            self.vocab,
                                            self._padding_noise)
            batches = []
            excess: Deque[Instance] = deque()
            for batch_instances in lazy_groups_of(iter(instance_list), self._batch_size):
                for possibly_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances, excess):
                    batches.append(Batch(possibly_smaller_batches))
            if excess:
                batches.append(Batch(excess))

            move_to_front = self._biggest_batch_first and len(batches) > 1
            if move_to_front:
                # We'll actually pop the last _two_ batches, because the last one might not be full.
                last_batch = batches.pop()
                penultimate_batch = batches.pop()
            if shuffle:
                # NOTE: if shuffle is false, the data will still be in a different order
                # because of the bucket sorting.
                random.shuffle(batches)
            if move_to_front:
                batches.insert(0, penultimate_batch)
                batches.insert(0, last_batch)

            yield from batches

In [None]:
# TODO: Allow for customization
if config.neg_splits > 1:
    if config.use_augmented:
        full_trn_labels = np.concatenate([train_labels, train_aux_labels], axis=0)
    else:
        full_trn_labels = train_labels
    sampler = BiasedSampler(full_trn_labels.sum(1) >= 1,
                            config.neg_splits)
else:
    sampler = Sampler()
iterator = CustomBucketIterator(
    batch_size=config.batch_size, 
    biggest_batch_first=True,
    sorting_keys=[("tokens", "num_tokens")],
    sampler=sampler,
)
iterator.index_with(vocab)

### Read sample

In [None]:
# Pull a single batch to inspect its structure and tensor shapes. The iterator was
# built with biggest_batch_first=True, so this should be one of the longest-sequence batches.
batch = next(iter(iterator(train_ds)))

In [None]:
# Full tensor dict for the batch.
batch

In [None]:
# Token-id tensor produced by the "tokens" indexer.
batch["tokens"]["tokens"]

In [None]:
# Its shape, expected to be (batch size, longest sequence in the batch).
batch["tokens"]["tokens"].shape

# Prepare Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.nn.util import get_text_field_mask

In [None]:
class Attention(nn.Module):
    """Additive (tanh) attention pooling over the sequence dimension.

    Every position of ``x`` is scored with ``tanh(l1(x)) @ vw``. The scores are
    normalised into weights along ``dim``, and the weighted sum of ``x`` over
    dim 1 is returned together with the weights.

    Args:
        inp_sz: size of the last dimension of ``x``.
        dim: dimension the weights are normalised over. The pooling sum always
            reduces dim 1, so this should stay 1.
        eps: added to the normaliser so fully masked rows get zero weights
            instead of NaN.
    """
    def __init__(self, inp_sz, dim=1, eps=1e-9):
        super().__init__()
        self.inp_sz, self.dim, self.eps = inp_sz, dim, eps
        self.l1 = nn.Linear(inp_sz, inp_sz)
        nn.init.xavier_uniform_(self.l1.weight.data)
        nn.init.zeros_(self.l1.bias.data)

        vw = torch.zeros(inp_sz, 1)
        nn.init.xavier_uniform_(vw)
        self.vw = nn.Parameter(vw)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch, seq_len, inp_sz) -- the einsum below fixes the rank.
            mask: optional (batch, seq_len) boolean tensor, True at positions to ignore.
        Returns:
            ``(pooled, weights)`` with shapes (batch, inp_sz) and (batch, seq_len).
        """
        e = torch.tanh(self.l1(x))
        e = torch.einsum("bij,jk->bi", [e, self.vw])
        if mask is not None:
            # exp(-inf) == 0, so ignored positions end up with zero weight.
            e = e.masked_fill(mask, float("-inf"))
        # Subtract the per-row max before exp() so large scores cannot overflow to
        # inf (inf / inf = NaN). Fully masked rows have max -inf; shift those by 0.
        e_max = e.max(dim=self.dim, keepdim=True)[0]
        e_max = e_max.masked_fill(e_max == float("-inf"), 0.0)
        a = torch.exp(e - e_max)

        a = a / (torch.sum(a, dim=self.dim, keepdim=True) + self.eps)

        weighted_input = x * a.unsqueeze(-1)
        return torch.sum(weighted_input, dim=1), a

In [None]:
def init_gru_weights(gru: nn.GRU):
    """Initialise a GRU in place: orthogonal hidden-to-hidden weights and
    Xavier-uniform input-to-hidden weights. Biases are left untouched."""
    for param_name, tensor in gru.named_parameters():
        if "weight_hh" in param_name:
            initializer = nn.init.orthogonal_
        elif "weight_ih" in param_name:
            initializer = nn.init.xavier_uniform_
        else:
            continue
        initializer(tensor.data)

In [None]:
class BiGRUAttentionEncoder(Seq2VecEncoder):
    """Bidirectional GRU followed by attention pooling into one vector per example.

    Input: embeddings of shape (batch, seq_len, embed_sz) plus an optional
    (batch, seq_len) mask that is nonzero at real tokens. Output shape:
    (batch, 2 * hidden_sz).
    """
    def __init__(self, embed_sz: int, hidden_sz: int, num_layers=2):
        super().__init__()
        self.embed_sz = embed_sz
        self.hidden_sz = hidden_sz
        # AllenNLP passes encoders batch-major tensors, so the GRU must be
        # batch_first; otherwise it would recur across the batch dimension.
        self.gru = nn.GRU(self.embed_sz, self.hidden_sz,
                          num_layers=num_layers, bidirectional=True,
                          batch_first=True)
        init_gru_weights(self.gru)
        self.attention = Attention(self.hidden_sz * 2, dim=1)

    @overrides
    def get_input_dim(self) -> int:
        return self.embed_sz

    @overrides
    def get_output_dim(self) -> int:
        return self.hidden_sz * 2

    @overrides
    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        # NOTE(review): padded positions still pass through the GRU (no packing),
        # so the backward direction starts on padding; only attention ignores them.
        x, _ = self.gru(x)
        # Attention expects True at positions to ignore, i.e. padding.
        padding_mask = None if mask is None else (mask == 0)
        x, _ = self.attention(x, mask=padding_mask)
        return x

In [None]:
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy, Metric

def prod(x: Iterable):
    """Return the product of the items of ``x`` (1 for an empty iterable)."""
    return functools.reduce(lambda running, item: running * item, x, 1)

class MultilabelAccuracy(Metric):
    """Element-wise accuracy for multi-label predictions.

    Every (example, label) cell counts as one prediction. The logit goes through
    a sigmoid and is predicted positive when that probability is >= ``thres``.
    Counts accumulate across calls until ``reset``.

    Args:
        thres: probability threshold (0.5 means "positive when logit >= 0").
    """
    def __init__(self, thres=0.5):
        self.thres = thres
        self.correct_count = 0
        self.total_count = 0

    def __call__(self, logits: torch.FloatTensor,
                 t: torch.LongTensor) -> float:
        """Add one batch to the running counts and return that batch's accuracy.

        Args:
            logits: raw scores, presumably the same shape as ``t``.
            t: 0/1 targets.
        """
        # The model emits raw logits, so threshold probabilities rather than logits.
        probs = torch.sigmoid(logits.detach()).cpu().numpy()
        t = t.detach().cpu().numpy()
        cc = ((probs >= self.thres) == t).sum()
        tc = probs.size
        self.correct_count += cc
        self.total_count += tc
        return cc / tc

    def get_metric(self, reset: bool=False):
        """Accuracy over all cells seen since the last reset (0.0 if none were seen)."""
        acc = self.correct_count / self.total_count if self.total_count else 0.0
        if reset:
            self.reset()
        return acc

    @overrides
    def reset(self):
        self.correct_count = 0
        self.total_count = 0
    
class MultilabelCrossEntropyLoss(nn.Module):
    """Mean binary cross-entropy over every (example, label) cell, computed from logits.

    Targets may be hard 0/1 labels or soft labels in [0, 1]; mixup produces the latter.
    """
    def forward(self, lgt, tgt: torch.LongTensor):
        # PyTorch's fused, numerically stable BCE-with-logits computes
        # max(x, 0) - x*t + log(1 + exp(-|x|)) averaged over all cells, using
        # log1p for better precision than the hand-written log(1 + .).
        return nn.functional.binary_cross_entropy_with_logits(lgt, tgt.type_as(lgt))

In [None]:
from allennlp.nn.util import move_to_device, has_tensor

def permute(obj, p: torch.Tensor):
    """
    Reorder every tensor inside ``obj`` along its first dimension with index ``p``.

    Dicts, lists and tuples are rebuilt recursively (tuples come back as plain
    tuples). Anything that does not contain a tensor is returned unchanged.
    """
    if not has_tensor(obj):
        return obj
    if isinstance(obj, torch.Tensor):
        return obj[p]
    if isinstance(obj, dict):
        return {key: permute(value, p) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        permuted_items = [permute(item, p) for item in obj]
        return permuted_items if isinstance(obj, list) else tuple(permuted_items)
    return obj

In [None]:
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from torch.distributions.beta import Beta

class BaselineModel(Model):
    """Text classifier: word embeddings -> Seq2Vec encoder -> linear projection.

    ``forward`` returns the class logits and, when labels are given, the loss and
    accuracy. ``mixup`` returns a loss for a batch whose embeddings and labels are
    interpolated with a shuffled copy of the same batch.

    Note: the notebook-level ``vocab`` is passed to ``Model.__init__``.

    Args:
        word_embeddings: embedder applied to the ``tokens`` field.
        encoder: reduces the embedded sequence (plus mask) to one vector per example.
        out_sz: number of output logits (defaults to ``config.n_classes``).
        multilabel: per-label sigmoid loss/accuracy instead of softmax cross-entropy.
        mixup_alpha: Beta(alpha, alpha) concentration for the mixup ratios.
    """
    def __init__(self, word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 out_sz: int=config.n_classes,
                 multilabel: bool=True, mixup_alpha: float=0.2):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.multilabel = multilabel
        self.lambda_sampler = Beta(torch.tensor([mixup_alpha]), torch.tensor([mixup_alpha]))
        # TODO: Handle multiclass case
        if self.multilabel:
            self.accuracy = MultilabelAccuracy()
            self.per_label_acc = {c: MultilabelAccuracy() for c in label_cols}
            self.loss = MultilabelCrossEntropyLoss()
        else:
            self.loss = nn.CrossEntropyLoss()
            self.accuracy = CategoricalAccuracy()

    def forward(self, tokens: Dict[str, torch.Tensor],
                label: Optional[torch.Tensor]=None, **meta) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        # Labels are absent at prediction time; only score the batch when they are given.
        if label is not None:
            output["accuracy"] = self.accuracy(class_logits, label)
            output["loss"] = self.loss(class_logits, label)

        return output

    def mixup(self, tokens: Dict[str, torch.Tensor],
              label: torch.Tensor, **meta) -> TensorDict:
        """Loss on a mixup batch: every example is blended with a random partner
        from the same batch, in both embedding space and label space."""
        # generate new tokens and labels
        bs = label.size(0)
        shuf = torch.randperm(bs).to(label.device)
        tokens2 = permute(tokens, shuf)
        labels2 = permute(label, shuf)
        # TODO: Think of how to handle this masking intelligently
        mask1, mask2 = (get_text_field_mask(t) for t in (tokens, tokens2))
        embs1, embs2 = (self.word_embeddings(t) for t in (tokens, tokens2))
        # One mixing ratio per example. The Beta has batch shape (1,), so
        # sampling (bs,) gives a (bs, 1) tensor.
        ratios = self.lambda_sampler.sample((bs,)).to(label.device)
        # Embeddings are assumed (bs, seq_len, embed_dim): add a trailing axis.
        emb_ratios = ratios.unsqueeze(-1)
        embs = emb_ratios * embs1 + (1 - emb_ratios) * embs2
        # Labels are (bs, n_classes): keep ratios at (bs, 1). A (bs, 1, 1) ratio
        # would silently broadcast the labels to (bs, bs, n_classes).
        label = ratios * label + (1 - ratios) * labels2

        # remaining process is the same
        state = self.encoder(embs, mask1 * mask2) # TODO: Handle masking
        class_logits = self.projection(state)

        output = {"loss": self.loss(class_logits, label)}
        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}

### Prepare embeddings

In [None]:
import fastText

def get_fasttext_embeddings(model_path: Union[str, Path], vocab: Vocabulary):
    """Build an embedding matrix for ``vocab`` from a fastText binary model.

    Row ``i`` holds the fastText vector of the token with index ``i`` in the
    vocabulary's default namespace. The matrix has ``vocab_size + 5`` rows, where
    ``vocab_size = min(vocab.get_vocab_size(), config.max_vocab_size)``; rows
    not assigned a token stay zero.

    Side effect: records ``vocab_size`` and ``embedding_dim`` on ``config``.

    Args:
        model_path: path to the fastText ``.bin`` model to load.
        vocab: vocabulary whose tokens are looked up.
    Returns:
        numpy array of shape (vocab_size + 5, embedding_dim).
    """
    vocab_size = min(vocab.get_vocab_size(), config.max_vocab_size)
    # Load the model the caller asked for rather than always reading config.ft_model_path.
    ft_model = fastText.load_model(str(model_path))
    embedding_dim = ft_model.get_dimension()

    # register parameters
    config.set("vocab_size", vocab_size)
    config.set("embedding_dim", embedding_dim)

    # NOTE(review): the +5 presumably leaves slack for entries (e.g. padding/OOV)
    # counted beyond max_vocab_size -- confirm.
    embeddings = np.zeros((vocab_size + 5, embedding_dim))
    for idx, token in vocab.get_index_to_token_vocabulary().items():
        embeddings[idx, :] = ft_model.get_word_vector(token)

    return embeddings

In [None]:
# Look up a fastText vector for every vocabulary token, timing the whole step.
with timer("Loading embeddings"):
    embedding_weights = get_fasttext_embeddings(config.ft_model_path, vocab)

In [None]:
class CustomEmbedding(Embedding):
    """Embedding layer with embedding dropout (whole vocabulary rows are dropped).

    While training with ``dropout > 0``, every row of the weight matrix is kept
    with probability ``1 - dropout`` and rescaled by ``1 / (1 - dropout)``, so a
    dropped word is zeroed everywhere it occurs in the batch. An optional
    ``scale`` multiplies the (masked) weight matrix.

    The lookup uses the weights of the inner ``self.embed``.
    NOTE(review): ``super().__init__`` also allocates its own, unused weight matrix
    of the same size -- consider removing it to save memory.
    """
    def __init__(self, num_embeddings, embedding_dim,
                 padding_index=None, max_norm=None,
                 weight=None, dropout=0., scale=None):
        super().__init__(num_embeddings, embedding_dim)
        self.dropout = dropout
        self.scale = scale
        self.padding_idx = padding_index
        self.embed = Embedding(num_embeddings, embedding_dim,
                               padding_index=padding_index, max_norm=max_norm,
                               weight=weight)

    def forward(self, words):
        weight = self.embed.weight
        if self.dropout > 0.0 and self.training:
            keep_prob = 1 - self.dropout
            # One Bernoulli draw per vocabulary row, broadcast over the embedding dim.
            row_mask = weight.new_empty((weight.size(0), 1)).bernoulli_(keep_prob) / keep_prob
            masked_embed_weight = row_mask * weight
        else:
            masked_embed_weight = weight
        # `is not None`: truth-testing a multi-element tensor raises. Broadcasting
        # handles both scalar and tensor scales.
        if self.scale is not None:
            masked_embed_weight = self.scale * masked_embed_weight

        # Pass padding_idx through unchanged (None means "no padding row").
        # An explicit -1 would make F.embedding treat the LAST row as padding
        # and stop its gradient.
        X = torch.nn.functional.embedding(
            words, masked_embed_weight,
            self.padding_idx, self.embed.max_norm, self.embed.norm_type,
            self.embed.scale_grad_by_freq, self.embed.sparse,
        )
        return X

In [None]:
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary
from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder

class CustomTextFieldEmbedder(TextFieldEmbedder):
    """
    A ``TextFieldEmbedder`` adapted from AllenNLP's ``BasicTextFieldEmbedder`` that wraps a
    collection of :class:`TokenEmbedder` objects. Each ``TokenEmbedder`` embeds or encodes the
    representation output from one :class:`~allennlp.data.TokenIndexer`. As the data produced
    by a :class:`~allennlp.data.fields.TextField` is a dictionary mapping names to these
    representations, we take ``TokenEmbedders`` with corresponding names. Each ``TokenEmbedder``
    embeds its input, and the results are concatenated along the last dimension in sorted
    order of embedder name.

    Parameters
    ----------

    token_embedders : ``Dict[str, TokenEmbedder]``, required.
        A dictionary mapping token embedder names to implementations.
        These names should match the corresponding indexer used to generate
        the tensor passed to the TokenEmbedder.
    embedder_to_indexer_map : ``Dict[str, List[str]]``, optional, (default = None)
        Optionally, you can provide a mapping between the names of the TokenEmbedders
        that you are using to embed your TextField and an ordered list of indexer names
        which are needed for running it. In most cases, your TokenEmbedder will only
        require a single tensor, because it is designed to run on the output of a
        single TokenIndexer. For example, the ELMo Token Embedder can be used in
        two modes, one of which requires both character ids and word ids for the
        same text. Note that the list of token indexer names is `ordered`, meaning
        that the tensors produced by the indexers will be passed to the embedders
        in the order you specify in this list.
    allow_unmatched_keys : ``bool``, optional (default = False)
        If True, then don't enforce the keys of the ``text_field_input`` to
        match those in ``token_embedders`` (useful if the mapping is specified
        via ``embedder_to_indexer_map``).
    """
    def __init__(self,
                 token_embedders: Dict[str, TokenEmbedder],
                 embedder_to_indexer_map: Dict[str, List[str]] = None,
                 allow_unmatched_keys: bool = False) -> None:
        # Zero-argument super(): this class does not inherit from BasicTextFieldEmbedder,
        # so ``super(BasicTextFieldEmbedder, self)`` would raise TypeError.
        super().__init__()
        self._token_embedders = token_embedders
        self._embedder_to_indexer_map = embedder_to_indexer_map
        # Register each embedder as a submodule so its parameters are tracked/moved with this module.
        for key, embedder in token_embedders.items():
            name = 'token_embedder_%s' % key
            self.add_module(name, embedder)
        self._allow_unmatched_keys = allow_unmatched_keys

    @overrides
    def get_output_dim(self) -> int:
        """Sum of the output dims of all wrapped embedders (their outputs are concatenated)."""
        output_dim = 0
        for embedder in self._token_embedders.values():
            output_dim += embedder.get_output_dim()
        return output_dim

    def augment(self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims: int=0):
        """Placeholder hook; currently does nothing and returns None."""
        pass

    def forward(self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims: int = 0) -> torch.Tensor:
        """Embed every key with its embedder and concatenate along the last dimension.

        Raises ``ConfigurationError`` when the input keys do not match the embedder keys
        and ``allow_unmatched_keys`` is false.
        """
        embedder_keys = self._token_embedders.keys()
        input_keys = text_field_input.keys()

        # Check for unmatched keys
        if not self._allow_unmatched_keys:
            if embedder_keys < input_keys:
                # token embedder keys are a strict subset of text field input keys.
                message = (f"Your text field is generating more keys ({list(input_keys)}) "
                           f"than you have token embedders ({list(embedder_keys)}. "
                           f"If you are using a token embedder that requires multiple keys "
                           f"(for example, the OpenAI Transformer embedder or the BERT embedder) "
                           f"you need to add allow_unmatched_keys = True "
                           f"(and likely an embedder_to_indexer_map) to your "
                           f"BasicTextFieldEmbedder configuration. "
                           f"Otherwise, you should check that there is a 1:1 embedding "
                           f"between your token indexers and token embedders.")
                raise ConfigurationError(message)

            elif self._token_embedders.keys() != text_field_input.keys():
                # some other mismatch
                message = "Mismatched token keys: %s and %s" % (str(self._token_embedders.keys()),
                                                                str(text_field_input.keys()))
                raise ConfigurationError(message)

        embedded_representations = []
        # Sorted so the concatenation order is deterministic.
        keys = sorted(embedder_keys)
        for key in keys:
            # If we pre-specified a mapping explictly, use that.
            if self._embedder_to_indexer_map is not None:
                tensors = [text_field_input[indexer_key] for
                           indexer_key in self._embedder_to_indexer_map[key]]
            else:
                # otherwise, we assume the mapping between indexers and embedders
                # is bijective and just use the key directly.
                tensors = [text_field_input[key]]
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, 'token_embedder_{}'.format(key))
            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)
            token_vectors = embedder(*tensors)
            embedded_representations.append(token_vectors)
        return torch.cat(embedded_representations, dim=-1)

    # This is some unusual logic, it needs a custom from_params.
    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'CustomTextFieldEmbedder':  # type: ignore
        # pylint: disable=arguments-differ,bad-super-call
        # ``warnings`` is not imported elsewhere in this notebook section; import it locally.
        import warnings

        # The original `from_params` for this class was designed in a way that didn't agree
        # with the constructor. The constructor wants a 'token_embedders' parameter that is a
        # `Dict[str, TokenEmbedder]`, but the original `from_params` implementation expected those
        # key-value pairs to be top-level in the params object.
        #
        # This breaks our 'configuration wizard' and configuration checks. Hence, going forward,
        # the params need a 'token_embedders' key so that they line up with what the constructor wants.
        # For now, the old behavior is still supported, but produces a DeprecationWarning.

        embedder_to_indexer_map = params.pop("embedder_to_indexer_map", None)
        if embedder_to_indexer_map is not None:
            embedder_to_indexer_map = embedder_to_indexer_map.as_dict(quiet=True)
        allow_unmatched_keys = params.pop_bool("allow_unmatched_keys", False)

        token_embedder_params = params.pop('token_embedders', None)

        if token_embedder_params is not None:
            # New way: explicitly specified, so use it.
            token_embedders = {
                    name: TokenEmbedder.from_params(subparams, vocab=vocab)
                    for name, subparams in token_embedder_params.items()
            }

        else:
            # Warn that the original behavior is deprecated
            warnings.warn(DeprecationWarning("the token embedders for BasicTextFieldEmbedder should now "
                                             "be specified as a dict under the 'token_embedders' key, "
                                             "not as top-level key-value pairs"))

            token_embedders = {}
            keys = list(params.keys())
            for key in keys:
                embedder_params = params.pop(key)
                token_embedders[key] = TokenEmbedder.from_params(vocab=vocab, params=embedder_params)

        params.assert_empty(cls.__name__)
        return cls(token_embedders, embedder_to_indexer_map, allow_unmatched_keys)

In [None]:
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

# Build the token embedder selected by ``config.model_type``.
if config.model_type == "standard":
    # fastText-initialised word embeddings with word-level dropout.
    # NOTE(review): the "+ 5" headroom over vocab_size must match the number of rows in
    # ``embedding_weights`` — confirm against get_fasttext_embeddings.
    token_embedding = CustomEmbedding(num_embeddings=config.vocab_size + 5,
                                      embedding_dim=config.embedding_dim,
                                      weight=torch.tensor(embedding_weights, dtype=torch.float),
                                      dropout=config.dropoute, padding_index=0)
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
elif "elmo" in config.model_type:
    from allennlp.modules.token_embedders import ElmoTokenEmbedder

    options_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json'
    weight_file = 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5'

    elmo_embedder = ElmoTokenEmbedder(options_file, weight_file)
    word_embeddings = BasicTextFieldEmbedder({"tokens": elmo_embedder})
elif "bert" in config.model_type:
    from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertEmbedder
    bert_embedder = PretrainedBertEmbedder(
             pretrained_model=config.model_type,
             top_layer_only=True, # conserve memory
    )
    word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": bert_embedder},
                                                                 # we'll be ignoring masks so we'll need to set this to True
                                                                allow_unmatched_keys = True)
else:
    # Fail fast: otherwise a ``word_embeddings`` left over from an earlier run of this notebook
    # would silently be reused (or the next cell would fail with a NameError).
    raise ValueError(f"Unsupported model_type: {config.model_type!r}")

In [None]:
if "bert" not in config.model_type:
    # Recurrent encoder over the word embeddings.
    encoder = BiGRUAttentionEncoder(
            word_embeddings.get_output_dim(),
            config.hidden_sz,
        )
else:
    BERT_DIM = word_embeddings.get_output_dim()

    class BertSentencePooler(Seq2VecEncoder):
        """Reduces a BERT output sequence to one vector by taking its first position
        (presumably the [CLS] wordpiece — confirm against the BERT indexer settings)."""
        def forward(self, embs: torch.Tensor,
                    mask: torch.Tensor = None) -> torch.Tensor:
            # ``mask`` is accepted for the Seq2VecEncoder interface but not used.
            # extract first token tensor
            return embs[:, 0]

        @overrides
        def get_output_dim(self) -> int:
            return BERT_DIM

    # Seq2VecEncoder's only constructor argument is the ``stateful`` flag; passing ``vocab``
    # there (as before) merely set ``stateful`` to a truthy object.
    encoder = BertSentencePooler()

In [None]:
# Assemble the classifier from the embedder and encoder built above;
# out_sz is the number of output logits (one per class).
model = BaselineModel(
    word_embeddings, 
    encoder, 
    out_sz=config.n_classes,
)

Initialize the output-layer bias to the log-odds of each label's frequency in the training labels (its prior).

In [None]:
if config.bias_init:
    # Set each output bias to the log-odds of its label's training frequency p, so that
    # sigmoid(bias) == p before training.
    class_bias = torch.zeros(len(label_cols))
    for i, _ in enumerate(label_cols):
        # Clip so a label that never (or always) occurs gives a finite bias instead of -inf/+inf,
        # which would otherwise produce NaN losses.
        p = float(np.clip(train_labels[:, i].mean(), 1e-6, 1 - 1e-6))
        class_bias[i] = np.log(p / (1 - p))

    # copy_ keeps the parameter's existing device/dtype (safe even after model.cuda())
    # and fails loudly on a shape mismatch instead of silently replacing the tensor.
    with torch.no_grad():
        model.projection.bias.copy_(class_bias)

In [None]:
# Move the model's parameters to the GPU when one is in use; on CPU nothing needs to happen.
if USE_GPU:
    model.cuda()

### Basic sanity checks

In [None]:
# Move the sample batch to the training device (0 when USE_GPU, else -1 for CPU).
batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)

In [None]:
# Inspect the batch dict after moving it to the device.
batch

In [None]:
# Step through the model's components by hand on one batch.
tokens = batch["tokens"]
# The label tensor, not the whole batch dict.
labels = batch["label"]

# NOTE(review): ``== 0`` inverts get_text_field_mask (which is 1 at real tokens), so ``mask``
# is True at padding positions — confirm this is what BiGRUAttentionEncoder expects.
mask = get_text_field_mask(tokens) == 0
embeddings = model.word_embeddings(tokens)
state = model.encoder(embeddings, mask)
class_logits = model.projection(state)

In [None]:
# Token-id tensors of the text field (a dict; the word ids are under "tokens").
tokens

In [None]:
# Full forward pass; displays the model's output dict.
model(**batch)

In [None]:
# Keep the loss tensor for the backward-pass check below.
loss = model(**batch)["loss"]

In [None]:
# Display the loss value.
loss

In [None]:
# Backpropagate once to check that gradients reach the parameters; leaves .grad populated.
loss.backward()

In [None]:
# Gradient of each encoder parameter after the backward pass (None = no gradient reached it).
[param.grad for param in model.encoder.parameters()]

In [None]:
# Batch size as seen by the label tensor...
batch["label"].shape[0]

In [None]:
# ...and by the token-id tensor; the two should agree.
batch["tokens"]["tokens"].shape[0]

In [None]:
# Build a second, shuffled view of the batch: mixup pairs every example with a random partner.
p = torch.randperm(batch["label"].size(0))
tokens1 = batch["tokens"]
tokens2 = {"tokens": tokens1["tokens"][p]}
labels1 = batch["label"]
labels2 = labels1[p]
mask1 = get_text_field_mask(tokens1)
mask2 = get_text_field_mask(tokens2)
embs1 = model.word_embeddings(tokens1)
embs2 = model.word_embeddings(tokens2)

In [None]:
# Shape of the first view's embeddings (presumably (batch, seq_len, embedding_dim) — confirm).
embs1.shape

In [None]:
# Draw mixup ratios with sample shape (batch_size, 1).
# NOTE(review): ``mask`` comes from the manual forward-pass cell; only its batch size is used,
# so ``labels1.size(0)`` would remove that cross-cell dependency.
ratios = model.lambda_sampler.sample((mask.size(0), 1))

In [None]:
# interpolate
embs = ratios * embs1 + (1-ratios) * embs2
label = ratios * labels1 + (1-ratios) * labels2

In [None]:
# Shape of the interpolated embeddings (expected to equal embs1.shape).
embs.shape

In [None]:
# Un-shuffled token ids of the first view.
tokens1

In [None]:
# Run the model's own mixup path (what CustomTrainer adds to the loss when mixup_ratio > 0).
model.mixup(**batch)

# Train

In [None]:
from allennlp.training import trainer as _trainer
from allennlp.training.trainer import *
import math
logger = _trainer.logger

# Gradient accumulation: each optimizer step covers ``config.batch_size`` examples, processed
# as computational batches of ``config.computational_batch_size``. Clamped to >= 1 so a
# computational batch larger than the target batch cannot yield 0 (which would make the
# modulo in ``_train_epoch`` raise ZeroDivisionError).
N_BATCHES_PER_UPDATE = max(1, config.batch_size // config.computational_batch_size)

class CustomTrainer(Trainer):
    """AllenNLP ``Trainer`` with an optional auxiliary mixup loss, gradient accumulation over
    ``N_BATCHES_PER_UPDATE`` computational batches, and NaN/Inf guards."""

    def __init__(self, *args, mixup_ratio: float=0., **kwargs):
        """Accepts the ``Trainer`` arguments plus ``mixup_ratio``: the weight given to the loss
        returned by ``model.mixup(**batch)`` when training (0 disables mixup)."""
        super().__init__(*args, **kwargs)
        self.mixup_ratio = mixup_ratio

    @overrides
    def batch_loss(self, batch: TensorDict, for_training=True) -> torch.Tensor:
        """Move ``batch`` to ``self._cuda_devices[0]``, run the model and return its loss.

        When ``for_training`` is true, the model's regularization penalty and, if
        ``mixup_ratio > 0``, ``mixup_ratio`` times the mixup loss are added. Returns None if the
        output has no ``"loss"`` key and ``for_training`` is false.

        Raises
        ------
        RuntimeError
            If ``for_training`` is true and the model output has no ``"loss"`` key.
        """
        batch = nn_util.move_to_device(batch, self._cuda_devices[0])
        output_dict = self.model(**batch)
        try:
            loss = output_dict["loss"]
            if for_training:
                loss += self.model.get_regularization_penalty()
        except KeyError:
            if for_training:
                raise RuntimeError("The model you are trying to optimize does not contain a"
                                   " 'loss' key in the output of model.forward(inputs).")
            return None

        # apply mixup loss
        if for_training and self.mixup_ratio > 0.:
            mixup_output_dict = self.model.mixup(**batch)
            loss += mixup_output_dict["loss"] * self.mixup_ratio
        return loss

    @gpu_mem_restore
    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics. Adapted from AllenNLP's ``Trainer._train_epoch``.

        Differences from upstream:
        * every computational batch is backpropagated with its loss scaled by
          ``1 / N_BATCHES_PER_UPDATE``; gradients are cleared at the start of each window of
          ``N_BATCHES_PER_UPDATE`` batches and the optimizer steps at its end (and on the last
          batch of the epoch, so a trailing partial window is applied rather than dropped);
        * a NaN loss, or NaN/Inf parameters, raise ``ValueError``.
        Per-batch logging, tensorboard output and checkpointing only run on optimizer-step batches.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        peak_cpu_usage = peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self.model.train()

        # Get tqdm for the training batches
        train_generator = self.iterator(self.train_data,
                                        num_epochs=1,
                                        shuffle=self.shuffle)
        num_training_batches = self.iterator.get_num_batches(self.train_data)
        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        if self._histogram_interval is not None:
            histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

        logger.info("Training")
        train_generator_tqdm = Tqdm.tqdm(train_generator,
                                         total=num_training_batches)
        cumulative_batch_size = 0
        for batch in train_generator_tqdm:
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self._log_histograms_this_batch = self._histogram_interval is not None and (
                    batch_num_total % self._histogram_interval == 0)

            ###########
            # Custom  #
            ###########
            # Clear gradients only when a new accumulation window starts, so they sum across
            # the window's batches and remain available for gradient logging after the step.
            if (batches_this_epoch - 1) % N_BATCHES_PER_UPDATE == 0:
                self.optimizer.zero_grad()

            loss = self.batch_loss(batch, for_training=True)
            if torch.isnan(loss):
                raise ValueError("nan loss encountered")
            train_loss += loss.item()
            # Backpropagate every computational batch; the scaling makes a full window's
            # accumulated gradient the mean over its batches rather than the sum.
            (loss / N_BATCHES_PER_UPDATE).backward()

            # Step at the end of each window, and on the final batch so a trailing partial window
            # is not discarded (assumes get_num_batches is exact for this dataset).
            is_update_step = (batches_this_epoch % N_BATCHES_PER_UPDATE == 0
                              or batches_this_epoch == num_training_batches)
            if not is_update_step:
                continue
            ###############
            # End Custom  #
            ###############

            batch_grad_norm = self.rescale_gradients()

            ###########
            # Custom  #
            ###########
            for name, param in self.model.named_parameters():
                if torch.isnan(param.data).any() or torch.isinf(param.data).any():
                    raise ValueError(f"Nan/Inf weights in param {name}: \n {param}")
            ###############
            # End Custom  #
            ###############

            # This does nothing if batch_num_total is None or you are using an
            # LRScheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)

            if self._log_histograms_this_batch:
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {name: param.detach().cpu().clone()
                                 for name, param in self.model.named_parameters()}
                self.optimizer.step()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1, ))
                    param_norm = torch.norm(param.view(-1, )).cpu()
                    self._tensorboard.add_train_scalar("gradient_update/" + name,
                                                       update_norm / (param_norm + 1e-7),
                                                       batch_num_total)
            else:
                self.optimizer.step()

            # Update the description with the latest metrics
            metrics = self._get_metrics(train_loss, batches_this_epoch)
            description = self._description_from_metrics(metrics)

            train_generator_tqdm.set_description(description, refresh=False)

            # Log parameter values to Tensorboard
            if batch_num_total % self._summary_interval == 0:
                if self._should_log_parameter_statistics:
                    self._parameter_and_gradient_statistics_to_tensorboard(batch_num_total, batch_grad_norm)
                if self._should_log_learning_rate:
                    self._learning_rates_to_tensorboard(batch_num_total)
                self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"], batch_num_total)
                self._metrics_to_tensorboard(batch_num_total,
                                             {"epoch_metrics/" + k: v for k, v in metrics.items()})

            if self._log_histograms_this_batch:
                self._histograms_to_tensorboard(batch_num_total, histogram_parameters)

            if self._log_batch_size_period:
                cur_batch = self._get_batch_size(batch)
                cumulative_batch_size += cur_batch
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_size/batches_this_epoch
                    logger.info(f"current batch size: {cur_batch} mean batch size: {average}")
                    self._tensorboard.add_train_scalar("current_batch_size", cur_batch, batch_num_total)
                    self._tensorboard.add_train_scalar("mean_batch_size", average, batch_num_total)

            # Save model if needed.
            if self._model_save_interval is not None and (
                    time.time() - last_save_time > self._model_save_interval
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                        '{0}.{1}'.format(epoch, time_to_str(int(last_save_time))), [], is_best=False
                )
        metrics = self._get_metrics(train_loss, batches_this_epoch, reset=True)
        metrics['cpu_memory_MB'] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics['gpu_'+str(gpu_num)+'_memory_MB'] = memory
        return metrics

In [None]:
# Adam over all model parameters with learning rate config.lr.
optimizer = optim.Adam(model.parameters(), lr=config.lr)

In [None]:
# Extra keyword arguments forwarded to CustomTrainer below: Trainer options plus mixup_ratio.
training_options = {
    # TODO: Add appropriate learning rate scheduler
    "should_log_parameter_statistics": True,
    "should_log_learning_rate": True,
    "num_epochs": config.epochs,
    # Weight of the auxiliary mixup loss (0 disables it).
    "mixup_ratio": config.mixup_ratio,
}

In [None]:
# Checkpoints go under DATA_ROOT/ckpts/RUN_ID, except on Colab or in testing mode,
# where serialization_dir is left as None.
SER_DIR = (DATA_ROOT / "ckpts" / RUN_ID
           if os.environ["IS_COLAB"] != "True" and not config.testing
           else None)

trainer = CustomTrainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    # Optionally append the augmented examples to the original training set.
    train_dataset=(train_ds + train_aug_ds) if config.use_augmented else train_ds,
    validation_dataset=val_ds,
    serialization_dir=SER_DIR,
    cuda_device=0 if USE_GPU else -1,
    **training_options,
)

In [None]:
# Train; returns the trainer's final metrics dict (recorded at the end of the notebook).
metrics = trainer.train()

In [None]:
# Display the training metrics.
metrics

# Evaluate

In [None]:
from scipy.special import expit
from collections import defaultdict

def dict_append(d: Dict[str, List], upd: Dict[str, Any]) -> Dict[str, List]:
    """Append each value of ``upd`` to the list stored under the same key in ``d``.

    ``d`` is mutated in place (it must be a ``defaultdict(list)`` or already hold a list
    for every key of ``upd``) and is also returned, as the annotation promises.
    """
    for k, v in upd.items():
        d[k].append(v)
    return d

def tonp(tsr):
    """Return ``tsr`` as a NumPy array: detached from autograd and copied to host memory."""
    host_tensor = tsr.detach().cpu()
    return host_tensor.numpy()
        
class Predictor:
    """Runs a model over a dataset and collects predictions plus per-example metadata."""
    def __init__(self, model: Model, iterator: DataIterator,
                 cuda_device: int=-1) -> None:
        self.model = model
        self.iterator = iterator
        self.cuda_device = cuda_device

    def _extract_data(self, batch) -> Dict[str, np.ndarray]:
        """Forward one batch and return, per example: sigmoid probabilities ("preds"),
        the fraction of OOV tokens ("oov_ratio") and the unpadded length ("lens")."""
        out_dict = self.model(**batch)
        lens = tonp(get_text_field_mask(batch["tokens"]).sum(1))
        # Id 1 is AllenNLP's default OOV token id in padded vocabulary namespaces.
        n_oov = tonp((batch["tokens"]["tokens"] == 1).sum(1))
        return {
                "preds": expit(tonp(out_dict["class_logits"])),
                # max(len, 1) guards empty sequences: 0/0 would give NaN and poison later means.
                "oov_ratio": n_oov / np.maximum(lens, 1),
                "lens": lens,
               }

    def _postprocess(self, predictions: Dict[str, List[np.ndarray]]) -> Dict[str, np.ndarray]:
        """Concatenate the per-batch arrays of each key along the example axis."""
        return {k: np.concatenate(v, axis=0) for k, v in predictions.items()}

    @gpu_mem_restore
    def predict(self, ds: Iterable[Instance]) -> Dict[str, np.ndarray]:
        """Predict over ``ds`` in dataset order (no shuffling, no gradients).

        Returns a dict of arrays with one row per instance (keys as in ``_extract_data``).
        The model's train/eval mode is restored afterwards.
        """
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        was_training = self.model.training
        self.model.eval()
        pred_generator_tqdm = Tqdm.tqdm(pred_generator,
                                        total=self.iterator.get_num_batches(ds))
        preds = defaultdict(list)
        try:
            with torch.no_grad():
                for batch in pred_generator_tqdm:
                    batch = nn_util.move_to_device(batch, self.cuda_device)
                    dict_append(preds, self._extract_data(batch))
        finally:
            # Don't leak eval mode to the caller (e.g. if training continues afterwards).
            self.model.train(was_training)
        return self._postprocess(preds)

In [None]:
from allennlp.data.iterators import BasicIterator
# Iterator used for inference; Predictor calls it with shuffle=False so outputs follow dataset order.
# index_with(vocab) makes it convert tokens to ids with the training vocabulary.
seq_iterator = BasicIterator(batch_size=64)
seq_iterator.index_with(vocab)

In [None]:
# Predict on train (used for threshold tuning below) and on test; "preds" is popped off,
# leaving per-example metadata (OOV ratio, lengths) in *_meta.
predictor = Predictor(model, seq_iterator, cuda_device=0 if USE_GPU else -1)
train_meta = predictor.predict(train_ds) 
train_preds = train_meta.pop("preds")
test_meta = predictor.predict(test_ds)
test_preds = test_meta.pop("preds")

In [None]:
# Ground-truth test labels, one column per entry of label_cols.
test_labels = pd.read_csv(DATA_ROOT / "test_proced.csv")[label_cols].values
# In testing mode keep only the first len(test_ds) rows so labels align with the (smaller) test set.
if config.testing:
    test_labels = test_labels[:len(test_ds), :]

In [None]:
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, confusion_matrix

Per-label evaluation

In [None]:
class Evaluator:
    """Per-label binary-classification metrics for multi-label predictions.

    Columns of the target/score arrays are assumed to follow the order of ``label_cols``.
    """
    def __init__(self, thres=0.5):
        """
        Parameters
        ----------
        thres : scalar or sequence of float
            Decision threshold (``score >= thres`` is positive). A scalar is shared by every
            label; otherwise one threshold per entry of ``label_cols``.
        """
        # np.ndim(...) == 0 accepts any scalar (int, float, NumPy scalar), not just ``float``.
        if np.ndim(thres) == 0:
            self.thres = np.ones(len(label_cols)) * thres
        else:
            self.thres = thres

    def _to_metric_dict(self, t: np.ndarray, y: np.ndarray, thres: float) -> Dict:
        """Metrics for one label given 0/1 targets ``t``, scores ``y`` and a threshold.

        tnr/fpr/fnr/tpr are confusion-matrix cell counts divided by the number of examples
        (fractions of all examples, not class-conditional rates). auc, precision and recall
        are NaN when undefined (single-class targets / no positive predictions / no positives).
        """
        y_bin = y >= thres
        # labels=[0, 1] keeps the matrix 2x2 even when one class is absent from t and y_bin.
        tn, fp, fn, tp = confusion_matrix(t, y_bin.astype(int), labels=[0, 1]).ravel()
        n = len(t)
        return {"auc": roc_auc_score(t, y) if len(np.unique(t)) > 1 else np.nan,
                "f1": f1_score(t, y_bin),
                "acc": accuracy_score(t, y_bin),
                "tnr": tn / n, "fpr": fp / n,
                "fnr": fn / n, "tpr": tp / n,
                "precision": tp / (tp + fp) if tp + fp > 0 else np.nan,
                "recall": tp / (tp + fn) if tp + fn > 0 else np.nan,
        }

    def _stats_per_quadrant(self, tgt, preds, metadata: Dict[str, np.ndarray]):
        """For each label and confusion quadrant (tp/fp/tn/fn), summarise every metadata
        array by mean/std/min/max over the examples in that quadrant (NaN when empty)."""
        out_data = {}
        for i, lbl in enumerate(label_cols):
            # get indices of each quadrant
            preds_bin = preds[:, i] >= self.thres[i]
            quads = {
                "tp": np.where((tgt[:, i] == 1) & preds_bin)[0],
                "fp": np.where((tgt[:, i] == 0) & preds_bin)[0],
                "tn": np.where((tgt[:, i] == 0) & ~preds_bin)[0],
                "fn": np.where((tgt[:, i] == 1) & ~preds_bin)[0],
            }
            # get stats for metadata
            out_data[lbl] = {}
            for q, qidxs in quads.items():
                out_data[lbl][q] = {}
                for k, full_data in metadata.items():
                    data = full_data[qidxs]
                    for stat in ("mean", "std", "min", "max"):
                        out_data[lbl][q][f"{k}_{stat}"] = (
                            getattr(data, stat)() if len(data) > 0 else np.nan
                        )
        return out_data

    @gpu_mem_restore
    def evaluate(self, tgt: np.ndarray, preds: np.ndarray,
                 trn_tgt: np.ndarray, trn_preds: np.ndarray,
                 metadata: Optional[Dict[str, np.ndarray]] = None) -> Dict:
        """
        Compute train and test metrics per label, plus test-set averages under "global".

        Parameters
        ----------
        tgt, preds : test targets (0/1) and scores, one column per label.
        trn_tgt, trn_preds : the same for the training set.
        metadata : data about the test inputs (e.g. length, OOV ratio), one value per example;
            summarised per quadrant under each test label's "quad_stats".

        Returns
        -------
        ``{"train": {label: metrics}, "test": {label: metrics, "global": averaged metrics}}``
        """
        if metadata is None:
            metadata = {}
        train_label_metrics = {}
        label_metrics = {}

        # get per-label stats
        for i, lbl in enumerate(label_cols):
            train_label_metrics[lbl] = self._to_metric_dict(trn_tgt[:, i],
                                                            trn_preds[:, i],
                                                            self.thres[i])
            label_metrics[lbl] = self._to_metric_dict(tgt[:, i], preds[:, i],
                                                      self.thres[i])
            print(f"========{lbl}=========")
            print(label_metrics[lbl])

        # get global stats (computed before quad_stats are attached, so only scalar metrics are averaged)
        label_metrics["global"] = {}
        for mtrc in label_metrics[label_cols[0]].keys():
            label_metrics["global"][mtrc] = \
                np.mean([label_metrics[col][mtrc] for col in label_cols])

        # get per-label-quadrant stats
        quad_stats = self._stats_per_quadrant(tgt, preds, metadata)
        if len(quad_stats) > 0:
            for c in label_cols:
                label_metrics[c]["quad_stats"] = quad_stats[c]

        metrics = {
            "train": train_label_metrics,
            "test": label_metrics,
        }
        return metrics

In [None]:
# Compute best threshold based on training data: for each label, pick the grid threshold
# with the highest training-set F1. Predictions are binarised with ``>=`` to match how
# Evaluator applies the threshold (this matters at x = 1.0, where sigmoid scores can saturate).
thres = np.zeros(len(label_cols))
for i, col in enumerate(label_cols):
    best_score = -1
    best_thres = -1
    for x in np.linspace(0, 1.0, num=99):
        scr = f1_score(train_labels[:, i], train_preds[:, i] >= x)
        if scr > best_score:
            best_thres = x
            best_score = scr
    thres[i] = best_thres

In [None]:
# Per-label thresholds selected above.
thres

In [None]:
# Evaluate with the tuned thresholds: test and train metrics per label, with the test
# metadata (OOV ratio, lengths) summarised per confusion quadrant.
evaluator = Evaluator(thres=thres)
label_metrics = evaluator.evaluate(
    test_labels, test_preds,
    train_labels, train_preds,
    metadata=test_meta,
)

In [None]:
# Display the nested metrics dict ("train" / "test").
label_metrics

# Record results (model weights are checkpointed by the trainer whenever `SER_DIR` is set)

In [None]:
if os.environ["IS_COLAB"] != "True":
    import sys
    # Add the path only once so re-running the cell does not keep growing sys.path.
    if "../lib" not in sys.path:
        sys.path.append("../lib")
    from record_experiments import record
else:
    import getpass
    from urllib.parse import quote_plus

    # Credentials are read from the MONGO_PASSWORD environment variable (or prompted for)
    # instead of being hard-coded, so they never end up in the saved notebook.
    PASSWORD = os.environ.get("MONGO_PASSWORD") or getpass.getpass("MongoDB password: ")

    import pymongo
    from bson.objectid import ObjectId
    import logging

    # Logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    # Attach the console handler once; re-running the cell would otherwise duplicate every log line.
    if not logger.handlers:
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        formatter = logging.Formatter('[%(levelname)s] %(asctime)s - %(name)s %(message)s')
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # The password must be percent-escaped to be embedded in a MongoDB connection URI.
    conn_str = f"mongodb+srv://root:{quote_plus(PASSWORD)}@cluster0-ptgoc.mongodb.net/test?retryWrites=true"

    client = pymongo.MongoClient(conn_str)
    db = client.experiments
    collection = db.logs

    def record(log: dict):
        """Insert ``log`` into the experiments collection and log the new document id."""
        res = collection.insert_one(log)
        logger.info(f"Inserted results at id {res.inserted_id}")
        return res

    def find(id_: Optional[str]=None, query: Optional[dict]=None):
        """Return the first document matching ``query`` (default: the document with ``_id == id_``)."""
        if query is None: query = {"_id": ObjectId(id_)}
        res = collection.find_one(query)
        return res

    def delete(id_: Optional[str]=None, query: Optional[dict]=None):
        """Delete all documents matching ``query`` (default: the document with ``_id == id_``)."""
        if query is None: query = {"_id": ObjectId(id_)}
        res = collection.delete_many(query)
        logger.info(f"Deleted {res.deleted_count} entries")
        return res

Record a summary of this run (skipped in testing mode)

In [None]:
from datetime import datetime
from pytz import timezone

def _to_bson_safe(obj):
    """Recursively convert NumPy scalars/arrays, which BSON cannot encode (e.g. the int64
    ``lens_min``/``lens_max`` quadrant stats), into built-in Python types."""
    if isinstance(obj, dict):
        return {k: _to_bson_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_bson_safe(v) for v in obj]
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj

if not config.testing:
    # One flat document: config values, execution time, trainer metrics and evaluation metrics.
    experiment_log = dict(config)
    # NOTE(review): pytz 'EST' is a fixed UTC-5 offset (no daylight saving) — confirm intended.
    tz = timezone('EST')
    experiment_log["execution_date"] = datetime.now(tz).strftime("%Y-%m-%d %H:%M %Z")
    experiment_log.update(metrics)
    experiment_log.update(label_metrics)
    experiment_log = _to_bson_safe(experiment_log)
    record(experiment_log)