In [None]:
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd
from functools import partial
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField, LabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

In [None]:
class Config(dict):
    """Dictionary whose keyword entries are also readable as attributes.

    Every keyword passed to the constructor is stored as a dict item and
    mirrored as an instance attribute, so ``cfg["lr"]`` and ``cfg.lr`` agree
    right after construction.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror each entry as an attribute to allow dotted access.
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [None]:
# for papermill
testing = True  # when True, the dataset reader keeps only the first 10000 rows of each CSV
seed = 1  # passed to torch.manual_seed
computational_batch_size = 16  # divisor used to derive the number of batches per optimizer update
batch_size = 16  # batch size handed to the BucketIterator
lr = 5e-5  # learning rate for the Adam optimizer
epochs = 1  # number of training epochs given to the trainer
embed_dim = 200  # token embedding dimension
hidden_sz = 64  # GRU hidden size (the encoder output is twice this, being bidirectional)
dataset = "jigsaw"  # sub-directory of ../data that holds the CSV files
n_classes = 6  # output size of the classifier projection
max_seq_len = 512
run_id = "jigsaw_rep_test"  # names the checkpoint directory; a timestamp is used when None
download_data = False  # when True, the next cell runs `aws s3 cp` for the CSV files

In [None]:
import subprocess
if download_data:
    # Download the raw CSVs into the directory the dataset reader loads from.
    data_dir = Path("../data") / dataset
    data_dir.mkdir(parents=True, exist_ok=True)
    for fname in ["train.csv", "test_proced.csv"]:
        # Pass the argv list directly (no shell=True): with shell=True on POSIX only
        # the first list element is executed as the command, so "s3 cp ..." was dropped.
        # `aws s3 cp` also needs an explicit destination path.
        subprocess.run(
            ["aws", "s3", "cp", f"s3://nnfornlp/data/jigsaw/{fname}", str(data_dir / fname)],
            check=True,
        )

In [None]:
# TODO: Can we make this play better with papermill?
# Bundle the papermill parameters defined above into a single Config object so the
# rest of the notebook reads them as `config.<name>`.
# Note: `download_data` is not copied into the config.
config = Config(
    testing=testing,
    seed=seed,
    computational_batch_size=computational_batch_size,
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
    embed_dim=embed_dim,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    run_id=run_id,
)

In [None]:
from allennlp.common.checks import ConfigurationError

In [None]:
import datetime
now = datetime.datetime.now()
# Use the configured run id when one is given; otherwise fall back to a
# month_day_hour:minute:second timestamp of the current time.
RUN_ID = config.run_id if config.run_id is not None else now.strftime("%m_%d_%H:%M:%S")

In [None]:
# True when CUDA is available; later cells use it to move the model to the GPU
# and to choose the trainer's cuda_device.
USE_GPU = torch.cuda.is_available()

In [None]:
# Directory containing this dataset's CSV files; checkpoints are also written below it.
DATA_ROOT = Path("../data") / config.dataset

Set random seed manually to replicate results

In [None]:
# Seed every RNG the pipeline may draw from, not only torch: Python's `random`
# and numpy are independent generators (batch shuffling presumably goes through
# one of them -- verify against the installed AllenNLP version).
import random

random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)

# Load Data

In [None]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader, StanfordSentimentTreeBankDatasetReader

### Prepare dataset

In [None]:
# Maps a dataset name to the object registered for it via the `register` decorator.
reader_registry = {}

In [None]:
def register(name: str):
    """Return a decorator that stores its target in ``reader_registry`` under ``name``.

    The decorated object is returned unchanged, so the decorator can be stacked
    on a class or function without altering it.
    """
    def decorator(target: Callable):
        reader_registry[name] = target
        return target

    return decorator

In [None]:
# Names of the label columns; the model concatenates the labels in this order.
label_cols = ["toxic", "severe_toxic", "obscene",
              "threat", "insult", "identity_hate"]

from enum import IntEnum
# Enum mapping each upper-cased label name to its position in `label_cols`.
ColIdx = IntEnum('ColIdx', [(x.upper(), i) for i, x in enumerate(label_cols)])

In [None]:
@register("jigsaw")
class JigsawDatasetReader(DatasetReader):
    """Reads a Jigsaw toxic-comment CSV file into AllenNLP instances.

    Each row must provide a ``comment_text`` column and one integer column per
    label (``toxic``, ``severe_toxic``, ``obscene``, ``threat``, ``insult``,
    ``identity_hate``).

    Parameters
    ----------
    tokenizer : callable mapping a string to a list of token strings
        (defaults to whitespace splitting).
    token_indexers : indexers used by the ``tokens`` TextField; defaults to a
        single ``SingleIdTokenIndexer`` under the key ``"tokens"``.
    max_seq_len : maximum number of tokens kept per comment; ``None`` disables
        truncation.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token],
                         toxic: int, severe_toxic: int, obscene: int,
                         threat: int, insult: int, identity_hate: int) -> Instance:
        """Build one Instance with a ``tokens`` TextField and one LabelField per label."""
        # Apply the configured length limit; previously max_seq_len was stored
        # but never used, so long comments were passed through untruncated.
        if self.max_seq_len is not None:
            tokens = tokens[:self.max_seq_len]

        fields = {"tokens": TextField(tokens, self.token_indexers)}

        label_values = (
            ("toxic", toxic),
            ("severe_toxic", severe_toxic),
            ("obscene", obscene),
            ("threat", threat),
            ("insult", insult),
            ("identity_hate", identity_hate),
        )
        for name, value in label_values:
            # Labels are already 0/1 integers, so vocabulary indexing is skipped;
            # int() also accepts numpy integer scalars coming from pandas.
            fields[name] = LabelField(label=int(value), skip_indexing=True)

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one Instance per CSV row (only the first 10000 rows when config.testing)."""
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(10000)
        for _, row in df.iterrows():
            yield self.text_to_instance(
                [Token(x) for x in self.tokenizer(row["comment_text"])],
                row["toxic"], row["severe_toxic"], row["obscene"],
                row["threat"], row["insult"], row["identity_hate"],
            )

### Prepare token handlers

In [None]:
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer

In [None]:
# Index each token by a single id, keeping the original casing.
token_indexer = SingleIdTokenIndexer(
    lowercase_tokens=False,  # don't lowercase by default
)
# Whitespace tokenizer: splits the raw comment text on runs of whitespace.
tokenizer = lambda x: x.split()

In [None]:
# Dataset reader using the whitespace tokenizer and single-id indexer defined above.
reader = JigsawDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer}
)

In [None]:
# Read the train and test CSV files from DATA_ROOT, in that order.
train_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train.csv", "test_proced.csv"])
# No validation split is used; the trainer is given validation_dataset=None.
val_ds = None

In [None]:
len(train_ds)

### Prepare vocabulary

In [None]:
# Build the vocabulary from the training instances only.
vocab = Vocabulary.from_instances(train_ds)

### Prepare iterator

In [None]:
from allennlp.data.iterators import BucketIterator

In [None]:
# TODO: Allow for customization
iterator = BucketIterator(batch_size=config.batch_size, 
                          biggest_batch_first=True,
                          sorting_keys=[("tokens", "num_tokens")],
                         )
iterator.index_with(vocab)

### Read sample

In [None]:
# Pull the first batch from the training iterator; later cells inspect it and
# use it for the sanity checks.
batch = next(iter(iterator(train_ds)))

In [None]:
batch

In [None]:
batch["tokens"]["tokens"]

In [None]:
batch["tokens"]["tokens"].shape

# Prepare Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.nn.util import get_text_field_mask

In [None]:
class Attention(nn.Module):
    """Additive attention pooling over a sequence of vectors.

    Scores each position with ``tanh(W x + b) . v``, normalises the
    exponentiated scores over axis ``dim`` and returns the weighted sum of the
    inputs over axis 1.

    Parameters
    ----------
    inp_sz : size of the last axis of the input.
    dim : axis of the score matrix over which the weights are normalised.
    eps : small constant added to the normaliser to avoid division by zero
        (e.g. when every position of a row is masked out).
    """
    def __init__(self, inp_sz, dim=1, eps=1e-9):
        super().__init__()
        self.inp_sz, self.dim, self.eps = inp_sz, dim, eps
        self.l1 = nn.Linear(inp_sz, inp_sz)
        nn.init.xavier_uniform_(self.l1.weight.data)
        nn.init.zeros_(self.l1.bias.data)

        # Context vector the projected inputs are scored against.
        vw = torch.zeros(inp_sz, 1)
        nn.init.xavier_uniform_(vw)
        self.vw = nn.Parameter(vw)

    def forward(self, x, mask=None):
        """Pool ``x`` into one vector per batch element.

        Parameters
        ----------
        x : tensor indexed as (batch, position, feature), per the einsum below.
        mask : optional (batch, position) tensor that is non-zero / True at
            positions that should receive attention and zero / False at
            positions to ignore (padding).

        Returns
        -------
        (pooled, weights) : the attention-weighted sum over axis 1 and the
        attention weights themselves.
        """
        e = torch.tanh(self.l1(x))
        # Contract the feature axis against vw (k has size 1 and is summed out).
        e = torch.einsum("bij,jk->bi", [e, self.vw])
        a = torch.exp(e)

        # Zero the weights of ignored positions. The caller passes a mask that is
        # True for real tokens, so the positions to fill are where mask == 0;
        # filling where the mask is True would discard the real tokens instead.
        if mask is not None:
            a = a.masked_fill(mask == 0, 0)

        a = a / (torch.sum(a, dim=self.dim, keepdim=True) + self.eps)

        weighted_input = x * a.unsqueeze(-1)
        return torch.sum(weighted_input, dim=1), a

In [None]:
class BiGRUAttentionEncoder(Seq2VecEncoder):
    """Bidirectional GRU followed by attention pooling.

    Parameters
    ----------
    embed_sz : size of the input embeddings.
    hidden_sz : GRU hidden size per direction (the output size is twice this).
    num_layers : number of stacked GRU layers.
    """
    def __init__(self, embed_sz: int, hidden_sz: int, num_layers=2):
        super().__init__()
        self.embed_sz = embed_sz
        self.hidden_sz = hidden_sz
        # batch_first=True: the embedder produces (batch, seq, embed) tensors and
        # the attention below pools over axis 1. With the nn.GRU default
        # (batch_first=False) the recurrence would run along the batch axis.
        self.gru = nn.GRU(self.embed_sz, self.hidden_sz,
                          num_layers=num_layers, bidirectional=True,
                          batch_first=True)
        self.attention = Attention(self.hidden_sz * 2, dim=1)

    @overrides
    def get_input_dim(self) -> int:
        """Size of the last axis expected on the input."""
        return self.embed_sz

    @overrides
    def get_output_dim(self) -> int:
        """Size of the pooled vector (forward and backward states concatenated)."""
        return self.hidden_sz * 2

    @overrides
    def forward(self, x: torch.Tensor,
                mask: Optional[torch.Tensor]=None) -> torch.Tensor:
        """Encode ``x`` of shape (batch, seq, embed_sz) into (batch, 2 * hidden_sz).

        ``mask`` is forwarded unchanged to the attention layer.
        """
        # None -> start from a zero initial hidden state.
        x, _ = self.gru(x, None)
        x, _ = self.attention(x, mask=mask)
        return x

In [None]:
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy, Metric

def prod(x: Iterable):
    """Return the product of all items in ``x`` (1 for an empty iterable)."""
    result = 1
    for factor in x:
        result = result * factor
    return result

class MultilabelAccuracy(Metric):
    """Element-wise accuracy for multilabel predictions given as logits.

    A prediction is positive when ``sigmoid(logit) >= thres``. The metric
    counts how many individual (example, label) entries match the target.

    Parameters
    ----------
    thres : probability threshold above which a label is predicted positive.
    """
    def __init__(self, thres=0.5):
        # Honour the argument (it used to be overwritten by a hard-coded 0.5).
        self.thres = thres
        self.correct_count = 0
        self.total_count = 0

    def __call__(self, logits: torch.FloatTensor,
                 t: torch.LongTensor) -> float:
        """Accumulate counts for one batch and return that batch's accuracy."""
        # The inputs are raw logits, so map them to probabilities before
        # comparing against the probability threshold.
        probs = torch.sigmoid(logits.detach().float()).cpu().numpy()
        t = t.detach().cpu().numpy()
        cc = int(((probs >= self.thres) == t).sum())
        tc = int(prod(probs.shape))
        self.correct_count += cc
        self.total_count += tc
        # An empty batch contributes nothing; avoid dividing by zero.
        return cc / tc if tc else 0.0

    def get_metric(self, reset: bool=False):
        """Return the accuracy accumulated so far (0.0 if nothing was seen yet)."""
        acc = self.correct_count / self.total_count if self.total_count else 0.0
        if reset:
            self.reset()
        return acc

    @overrides
    def reset(self):
        """Clear the accumulated counts."""
        self.correct_count = 0
        self.total_count = 0
    
class MultilabelCrossEntropyLoss(nn.Module):
    """Mean binary cross-entropy computed directly from logits.

    Uses the overflow-safe form ``max(x, 0) - x * t + log(1 + exp(-|x|))``
    and averages it over every element.
    """
    def forward(self, lgt, tgt: torch.LongTensor):
        targets = tgt.float()
        positive_part = torch.clamp(lgt, min=0)
        # exp is only ever taken of a non-positive number, so it cannot overflow.
        log_term = torch.log(1 + torch.exp(-torch.abs(lgt)))
        per_element = positive_part - lgt * targets + log_term
        return per_element.mean()

In [None]:
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder

class BaselineModel(Model):
    """Embed -> encode -> linear projection classifier.

    Parameters
    ----------
    word_embeddings : embedder applied to the ``tokens`` field.
    encoder : Seq2VecEncoder turning the embedded sequence into one vector.
    out_sz : number of output logits.
    multilabel : when True, uses the multilabel loss and accuracy metrics.
    """
    def __init__(self, word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 out_sz: int=config.n_classes,
                 multilabel: bool=True):
        # Uses the notebook-level `vocab`.
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.multilabel = multilabel
        # Always defined so forward/get_metrics do not fail in the multiclass branch.
        self.per_label_acc = {}
        # TODO: Handle multiclass case
        if self.multilabel:
            self.accuracy = MultilabelAccuracy()
            self.per_label_acc = {c: MultilabelAccuracy() for c in label_cols}
            self.loss = MultilabelCrossEntropyLoss()
        else:
            self.loss = nn.CrossEntropyLoss()
            self.accuracy = CategoricalAccuracy()

    def forward(self, tokens: Dict[str, torch.Tensor],
                **labels: torch.Tensor) -> Dict[str, Any]:
        """Compute class logits and, when labels are supplied, metrics and loss.

        Returns a dict with ``class_logits``; when label tensors are passed it
        also contains ``accuracy``, one ``<label>_acc`` entry per tracked label
        and ``loss``.
        """
        # Boolean mask, True at real (non-padding) token positions.
        mask = get_text_field_mask(tokens) == 1
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        if len(labels) > 0:
            # This is grossly inefficient...
            # Assemble the per-label tensors into one (batch, n_labels) target,
            # in label_cols order.
            label = torch.stack([labels[c] for c in label_cols], dim=1)
            output["accuracy"] = self.accuracy(class_logits, label)
            for i, c in enumerate(label_cols):
                # Per-label metrics only exist in the multilabel configuration.
                if c in self.per_label_acc:
                    output[f"{c}_acc"] = self.per_label_acc[c](class_logits[:, i],
                                                              labels[c])
            output["loss"] = self.loss(class_logits, label)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the overall accuracy plus one ``<label>_acc`` per tracked label.

        Passing ``reset`` to every metric keeps the per-label counters from
        accumulating across epochs.
        """
        metrics = {"accuracy": self.accuracy.get_metric(reset)}
        for c, metric in self.per_label_acc.items():
            metrics[f"{c}_acc"] = metric.get_metric(reset)
        return metrics

### Prepare embeddings

In [None]:
embedding_weights = None

In [None]:
# def load_fasttext(word_index):    
#     EMBEDDING_FILE = '../input/embeddings/wiki-news-300d-1M/wiki-news-300d-1M.vec'
#     def get_coefs(word,*arr): return word, np.asarray(arr, dtype='float32')
#     embeddings_index = dict(get_coefs(*o.split(" ")) for o in open(EMBEDDING_FILE) if len(o)>100)

#     all_embs = np.stack(embeddings_index.values())
#     emb_mean,emb_std = all_embs.mean(), all_embs.std()
#     embed_size = all_embs.shape[1]

#     # word_index = tokenizer.word_index
#     embedding_matrix = np.random.normal(emb_mean, emb_std, (config.max_features, config.embed_size))
#     for word, i in word_index.items():
#         if i >= config.max_features: continue
#         embedding_vector = embeddings_index.get(word)
#         if embedding_vector is not None: embedding_matrix[i] = embedding_vector

#     return embedding_matrix

In [None]:
# Embedding table sized to the "tokens" vocabulary namespace; `embedding_weights`
# is None above, so no pretrained weights are passed in.
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=config.embed_dim,
                            weight=embedding_weights)
# The key "tokens" matches the indexer key used by the dataset reader.
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
encoder = BiGRUAttentionEncoder(
    config.embed_dim, 
    config.hidden_sz,
)

In [None]:
# Assemble the classifier from the embedder and encoder built above.
model = BaselineModel(
    word_embeddings, 
    encoder, 
    out_sz=config.n_classes,
)

In [None]:
# Move the model to the GPU when one is available; otherwise leave it on the CPU.
if USE_GPU: model.cuda()
else: model

### Basic sanity checks

In [None]:
# Check the embedding weights for NaNs. `.cpu()` is required because the model
# may have been moved to the GPU, and CUDA tensors cannot be converted to numpy.
np.isnan(list(model.word_embeddings.parameters())[0].detach().cpu().numpy()).any()

In [None]:
# NaN check per encoder parameter; `.cpu()` keeps this working when the model is on the GPU.
[np.isnan(x.detach().cpu().numpy()).any() for x in list(model.encoder.parameters())]

In [None]:
# Inf check per encoder parameter; `.cpu()` keeps this working when the model is on the GPU.
[np.isinf(x.detach().cpu().numpy()).any() for x in list(model.encoder.parameters())]

In [None]:
# Step through the same stages as BaselineModel.forward by hand to check that
# each one runs on the sample batch.
# NOTE(review): `batch` comes straight from the iterator; if the model was moved
# to the GPU above, the batch presumably still lives on the CPU and this cell
# would fail -- confirm and move the batch to the model's device if so.
tokens = batch["tokens"]
labels = batch  # not used below

mask = get_text_field_mask(tokens) == 1
embeddings = model.word_embeddings(tokens)
state = model.encoder(embeddings, mask)
class_logits = model.projection(state)

In [None]:
# Full forward pass on the sample batch; the batch dict supplies both the
# tokens and the label tensors, so the output includes a loss.
loss = model(**batch)["loss"]

In [None]:
loss

In [None]:
loss.backward()

In [None]:
[x.grad for x in list(model.encoder.parameters())]

# Train

In [None]:
from allennlp.training import trainer as _trainer
from allennlp.training.trainer import *
import math
logger = _trainer.logger

# Number of consecutive batches whose gradients are accumulated before each
# optimizer step. Clamped to at least 1 so the modulo checks below never divide
# by zero when batch_size < computational_batch_size.
N_BATCHES_PER_UPDATE = max(1, config.batch_size // config.computational_batch_size)

class CustomTrainer(Trainer):
    """Trainer whose epoch loop accumulates gradients over N_BATCHES_PER_UPDATE batches."""

    def _train_epoch(self, epoch: int) -> Dict[str, float]:
        """
        Trains one epoch and returns metrics. Copied from source, with the
        section marked "Custom" changed to support gradient accumulation:
        every batch contributes ``loss / N_BATCHES_PER_UPDATE`` to the
        gradients, and the optimizer steps once per N_BATCHES_PER_UPDATE batches.
        """
        logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
        peak_cpu_usage = peak_memory_mb()
        logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
        gpu_usage = []
        for gpu, memory in gpu_memory_mb().items():
            gpu_usage.append((gpu, memory))
            logger.info(f"GPU {gpu} memory usage MB: {memory}")

        train_loss = 0.0
        # Set the model to "train" mode.
        self.model.train()

        # Get tqdm for the training batches
        train_generator = self.iterator(self.train_data,
                                        num_epochs=1,
                                        shuffle=self.shuffle)
        num_training_batches = self.iterator.get_num_batches(self.train_data)
        self._last_log = time.time()
        last_save_time = time.time()

        batches_this_epoch = 0
        if self._batch_num_total is None:
            self._batch_num_total = 0

        if self._histogram_interval is not None:
            histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())

        logger.info("Training")
        train_generator_tqdm = Tqdm.tqdm(train_generator,
                                         total=num_training_batches)
        cumulative_batch_size = 0
        for batch in train_generator_tqdm:
            batches_this_epoch += 1
            self._batch_num_total += 1
            batch_num_total = self._batch_num_total

            self._log_histograms_this_batch = self._histogram_interval is not None and (
                    batch_num_total % self._histogram_interval == 0)

            ###########
            # Custom  #
            ###########
            # Clear gradients only at the start of an accumulation window, so the
            # gradients of the batches inside the window add up. (Clearing after
            # the step instead would hide them from the gradient logging below.)
            if (batches_this_epoch - 1) % N_BATCHES_PER_UPDATE == 0:
                self.optimizer.zero_grad()

            loss = self.batch_loss(batch, for_training=True)
            if torch.isnan(loss):
                raise ValueError("nan loss encountered")
            train_loss += loss.item()
            # Back-propagate every batch (previously the non-update batches were
            # skipped before backward(), so their loss never reached the
            # gradients). Scaling keeps the accumulated gradient an average.
            (loss / N_BATCHES_PER_UPDATE).backward()
            # wait to update
            if (batches_this_epoch % N_BATCHES_PER_UPDATE) != 0: continue
            ###############
            # End Custom  #
            ###############

            batch_grad_norm = self.rescale_gradients()

            # This does nothing if batch_num_total is None or you are using an
            # LRScheduler which doesn't update per batch.
            if self._learning_rate_scheduler:
                self._learning_rate_scheduler.step_batch(batch_num_total)

            if self._log_histograms_this_batch:
                # get the magnitude of parameter updates for logging
                # We need a copy of current parameters to compute magnitude of updates,
                # and copy them to CPU so large models won't go OOM on the GPU.
                param_updates = {name: param.detach().cpu().clone()
                                 for name, param in self.model.named_parameters()}
                self.optimizer.step()
                for name, param in self.model.named_parameters():
                    param_updates[name].sub_(param.detach().cpu())
                    update_norm = torch.norm(param_updates[name].view(-1, ))
                    param_norm = torch.norm(param.view(-1, )).cpu()
                    self._tensorboard.add_train_scalar("gradient_update/" + name,
                                                       update_norm / (param_norm + 1e-7),
                                                       batch_num_total)
            else:
                self.optimizer.step()

            # Update the description with the latest metrics
            metrics = self._get_metrics(train_loss, batches_this_epoch)
            description = self._description_from_metrics(metrics)

            train_generator_tqdm.set_description(description, refresh=False)

            # Log parameter values to Tensorboard
            if batch_num_total % self._summary_interval == 0:
                if self._should_log_parameter_statistics:
                    self._parameter_and_gradient_statistics_to_tensorboard(batch_num_total, batch_grad_norm)
                if self._should_log_learning_rate:
                    self._learning_rates_to_tensorboard(batch_num_total)
                self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"], batch_num_total)
                self._metrics_to_tensorboard(batch_num_total,
                                             {"epoch_metrics/" + k: v for k, v in metrics.items()})

            if self._log_histograms_this_batch:
                self._histograms_to_tensorboard(batch_num_total, histogram_parameters)

            if self._log_batch_size_period:
                cur_batch = self._get_batch_size(batch)
                cumulative_batch_size += cur_batch
                if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
                    average = cumulative_batch_size/batches_this_epoch
                    logger.info(f"current batch size: {cur_batch} mean batch size: {average}")
                    self._tensorboard.add_train_scalar("current_batch_size", cur_batch, batch_num_total)
                    self._tensorboard.add_train_scalar("mean_batch_size", average, batch_num_total)

            # Save model if needed.
            if self._model_save_interval is not None and (
                    time.time() - last_save_time > self._model_save_interval
            ):
                last_save_time = time.time()
                self._save_checkpoint(
                        '{0}.{1}'.format(epoch, time_to_str(int(last_save_time))), [], is_best=False
                )

        # Apply gradients left over from an incomplete accumulation window at the
        # end of the epoch, so the trailing batches are not silently dropped.
        if batches_this_epoch % N_BATCHES_PER_UPDATE != 0:
            self.rescale_gradients()
            self.optimizer.step()

        metrics = self._get_metrics(train_loss, batches_this_epoch, reset=True)
        metrics['cpu_memory_MB'] = peak_cpu_usage
        for (gpu_num, memory) in gpu_usage:
            metrics['gpu_'+str(gpu_num)+'_memory_MB'] = memory
        return metrics

In [None]:
# Adam over all model parameters with the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=config.lr)

In [None]:
# Extra keyword arguments forwarded to the trainer constructor below.
training_options = {
    # TODO: Add appropriate learning rate scheduler
    "should_log_parameter_statistics": True,
    "should_log_learning_rate": True,
    "num_epochs": config.epochs,
}

In [None]:
# Serialization directory for this run: <DATA_ROOT>/ckpts/<RUN_ID>.
SER_DIR = DATA_ROOT / "ckpts" / RUN_ID

# No validation data is supplied (val_ds is None); device 0 is used when a GPU
# is available, otherwise -1 selects the CPU.
trainer = CustomTrainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=SER_DIR,
    cuda_device=0 if USE_GPU else -1,
    **training_options,
)

In [None]:
metrics = trainer.train()

In [None]:
metrics

# Record results and save weights

In [None]:
import sys
sys.path.append("../lib")

In [None]:
import record_experiments

Record summary

In [None]:
# Only record real runs: merge the configuration and the training metrics into
# one flat dict and hand it to the project's record_experiments helper.
if not config.testing:
    experiment_log = dict(config)
    experiment_log.update(metrics)
    record_experiments.record(experiment_log)

Sync TensorBoard outputs and training logs to S3

(Remove weights since they take up too much space)

In [None]:
!rm {SER_DIR / "*.th"}

In [None]:
!ls {SER_DIR}

In [None]:
!aws s3 sync {SER_DIR} s3://nnfornlp/ckpts/{RUN_ID}