In [None]:
# Debugging aid: prints the running kernel's connection details so that an
# external console can attach to it. Not needed for the analysis itself.
%connect_info

In [None]:
# Render matplotlib figures inline, below the cell that produces them.
# NOTE(review): no plotting call is visible in this notebook - confirm it is still needed.
%matplotlib inline

In [None]:
import sys
# Keep only argv[0], discarding whatever arguments the kernel process was started with.
# NOTE(review): presumably required so the flag helpers imported from `utils`
# do not try to parse the kernel's own command line - confirm.
sys.argv = sys.argv[:1]

In [None]:
import os
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.data import BucketIterator

from sklearn.utils import class_weight
from sklearn.metrics import confusion_matrix, classification_report

from torchtext.data.iterator import BucketIterator

from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite.contrib.handlers import ProgressBar

from models import CNNClassifier
from preprocess import load_tokenized_data, SentenceDataset
from utils import build_model_name, convert_flags_to_dict, define_cnn_flags

from transformers import AutoModel

In [None]:
# Single seed shared by NumPy, PyTorch (CPU and CUDA) and the data loading call.
SEED = 42

# Ask cuDNN for deterministic kernels and switch off its autotuner, so repeated
# runs select the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed every random number generator used by this notebook.
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Identifier of the pre-trained encoder. It is passed to `AutoModel.from_pretrained`
# and is also part of the tokenized data file name, so the two always match.
bert_type = 'distilbert-base-multilingual-cased'

In [None]:
# Read the train / validation / dev splits from the pickle that matches
# `bert_type`, restricted to the listed language codes.
(x_train, y_train), (x_val, y_val), (x_dev, y_dev) = load_tokenized_data(
    seed=SEED,
    language_codes=['DE', 'GA', 'HI', 'PT', 'ZH'],
    datafile='{}/data/{}.tokenized.pkl'.format(os.getcwd(), bert_type))

In [8]:
def _sentence_length(example):
    """Sort key for bucketing: the length of the example's ``sentence`` field."""
    return len(example.sentence)


def _make_iterator(features, labels):
    """Wrap one data split in a ``BucketIterator`` using the shared settings.

    Every split uses batches of 32, no shuffling, length-based bucketing and
    tensors placed on the CPU.
    """
    return BucketIterator(
        dataset=SentenceDataset(data=(features, labels)),
        batch_size=32,
        sort_key=_sentence_length,
        shuffle=False,
        device=torch.device("cpu"))


train_iterator = _make_iterator(x_train, y_train)
valid_iterator = _make_iterator(x_val, y_val)
# The held-out "dev" split plays the role of the test set here.
test_iterator = _make_iterator(x_dev, y_dev)

In [9]:
# Load the pre-trained encoder selected by `bert_type`.
transformer = AutoModel.from_pretrained(bert_type)

# Freeze the encoder: it is only used as a fixed feature extractor, so none of
# its parameters should receive gradients.
transformer.requires_grad_(False)

# `.to` returns the module itself, so leaving it as the last expression also
# displays the architecture as the cell output.
transformer.to(torch.device("cpu"))
# tokenizer = AutoTokenizer.from_pretrained(bert_type)

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(119547, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0): TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(

In [10]:
# Hyper-parameters consumed by `CNNClassifier`.
config = {
    # Output channels of every convolution.
    'nfilters': 128,
    # One Conv1d is created per kernel width listed here.
    'kernels': [1, 2, 3, 4, 5],
    # Window (and stride) of the max-pooling applied over the filter axis.
    'pool_stride': 3,
    # Dropout probability.
    'dropout': 0.2,
    # 'sigmoid' selects torch.sigmoid; any other value selects softmax.
    'output_activation': 'sigmoid',
    # Input channels of the convolutions: the encoder's embedding width.
    'emb_dim': transformer.embeddings.word_embeddings.embedding_dim,

}

In [11]:
class CNNClassifier(nn.Module):
    """Per-position CNN classifier over a sequence of embedding vectors.

    One ``Conv1d`` is built for every width in ``config["kernels"]``. Each
    feature map is padded back to the input length, max-pooled along the filter
    axis, concatenated with the others and projected to two outputs per
    position.

    Args:
        config: mapping with the keys ``"emb_dim"`` (input channels),
            ``"nfilters"`` (output channels per convolution), ``"kernels"``
            (iterable of kernel widths), ``"pool_stride"`` (pooling window),
            ``"dropout"`` (drop probability) and ``"output_activation"``
            (``'sigmoid'`` or anything else for softmax over the class axis).
    """

    def __init__(self, config):
        super().__init__()

        self.convolutions = nn.ModuleList([
            nn.Conv1d(
                in_channels=config["emb_dim"],
                out_channels=config["nfilters"],
                kernel_size=kernel_size,
                stride=1) for kernel_size in config["kernels"]])

        self.pool_stride = config["pool_stride"]

        self.dropout = nn.Dropout(config["dropout"])
        # Pooling shrinks each block of `nfilters` features to
        # `nfilters // pool_stride`; there is one such block per kernel width.
        self.fully_connected = nn.Linear(
            (config["nfilters"] // config["pool_stride"]) * len(config["kernels"]), 2)

        if config["output_activation"] == 'sigmoid':
            self.output_activation = torch.sigmoid  # pylint: disable=no-member
        else:
            # Normalise over the class axis explicitly. A bare `F.softmax`
            # without `dim` falls back to dim 0 for 3-D inputs, i.e. it would
            # normalise across the batch instead of across the classes.
            self.output_activation = nn.Softmax(dim=-1)

    def forward(self, x):
        """Classify every position of ``x``.

        Args:
            x: float tensor of shape ``(batch, emb_dim, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, 2)`` holding the activated
            outputs for each position.
        """
        seq_len = x.shape[-1]

        pooled = []
        for conv in self.convolutions:
            # Conv1d refuses inputs shorter than its kernel, so right-pad short
            # sequences with zeros up to the kernel width.
            shortfall = conv.kernel_size[0] - seq_len
            conv_input = F.pad(x, (0, shortfall)) if shortfall > 0 else x

            # (batch, nfilters, L) -> (batch, L, nfilters)
            feature_map = F.relu(conv(conv_input)).transpose(1, 2)
            # Unpadded convolutions shorten the sequence; zero-pad the end so
            # every feature map lines up with the original positions again.
            feature_map = F.pad(
                feature_map, (0, 0, 0, seq_len - feature_map.shape[1]))
            # Pools along the last (filter) axis, not along the sequence.
            pooled.append(F.max_pool1d(feature_map, self.pool_stride))

        x = torch.cat(pooled, dim=2)  # pylint: disable=no-member
        # Regularise the features rather than the logits: dropping a logit
        # would pin that output to 0.5 under sigmoid during training.
        x = self.dropout(x)
        x = self.fully_connected(x)
        return self.output_activation(x)

In [12]:
# Build the classifier from `config` and place it on the GPU when one is
# available. The encoder and the iterators are pinned to the CPU, so features
# must be moved to the model's device before they are fed to it.
model = CNNClassifier(config)
model.to(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))   # pylint: disable=no-member

CNNClassifier(
  (convolutions): ModuleList(
    (0): Conv1d(768, 128, kernel_size=(1,), stride=(1,))
    (1): Conv1d(768, 128, kernel_size=(2,), stride=(1,))
    (2): Conv1d(768, 128, kernel_size=(3,), stride=(1,))
    (3): Conv1d(768, 128, kernel_size=(4,), stride=(1,))
    (4): Conv1d(768, 128, kernel_size=(5,), stride=(1,))
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (fully_connected): Linear(in_features=210, out_features=2, bias=True)
)

In [13]:
# Binary cross-entropy computed on the classifier's activated outputs.
criterion = nn.BCELoss()
# Adam over the classifier's parameters only, with L2 weight decay.
optimizer = torch.optim.Adam(params=model.parameters(), weight_decay=1e-3, lr=1e-4)

In [61]:
def process_function(engine, batch):
    """Run one optimisation step of the classifier on ``batch``.

    The frozen transformer turns the token ids into contextual embeddings, the
    classifier predicts on them, and the loss is back-propagated through the
    classifier only.

    Args:
        engine: the ignite ``Engine`` driving training (unused).
        batch: object exposing ``sentence``, ``mask`` and ``labels``.

    Returns:
        The loss of this step as a Python float.
    """
    x, m, y = batch.sentence, batch.mask, batch.labels
    # The model may live on the GPU while the encoder and data are on the CPU.
    device = next(model.parameters()).device

    # The encoder is frozen, so no autograd graph is needed for it.
    with torch.no_grad():
        # (batch, seq, emb) -> (batch, emb, seq), the layout Conv1d expects.
        x = transformer(x, attention_mask=m)[0].transpose(1, 2)
    x = x.to(device)

    model.train()
    optimizer.zero_grad()
    y_pred = model(x)

    # BCELoss needs float targets shaped like the predictions, so one-hot encode
    # the labels along a new trailing class axis. The original called an
    # undefined `to_categorical` here, which raised NameError in a fresh kernel.
    # NOTE(review): assumes `labels` holds integer class ids in
    # [0, num_classes) for every position, padding included - confirm.
    target = F.one_hot(torch.as_tensor(y).long(), num_classes=y_pred.shape[-1])
    target = target.to(device=device, dtype=y_pred.dtype)

    loss = criterion(y_pred, target)
    loss.backward()
    optimizer.step()
    return loss.item()


def eval_function(engine, batch):
    """Predict on ``batch`` without updating any weights.

    Args:
        engine: the ignite ``Engine`` driving evaluation (unused).
        batch: object exposing ``sentence``, ``mask`` and ``labels``.

    Returns:
        ``(y_pred, target)`` where ``target`` is the one-hot, float encoding of
        the labels with the same shape, dtype and device as ``y_pred``, so that
        the BCE loss and accuracy metrics can consume the pair directly.
    """
    x, m, y = batch.sentence, batch.mask, batch.labels
    # The model may live on the GPU while the encoder and data are on the CPU.
    device = next(model.parameters()).device

    model.eval()
    with torch.no_grad():
        # (batch, seq, emb) -> (batch, emb, seq), the layout Conv1d expects.
        x = transformer(x, attention_mask=m)[0].transpose(1, 2).to(device)
        y_pred = model(x)
        # Encode the targets exactly as the training step does; raw labels do
        # not have the shape BCELoss requires.
        # NOTE(review): assumes `labels` holds integer class ids in
        # [0, num_classes) for every position, padding included - confirm.
        target = F.one_hot(torch.as_tensor(y).long(), num_classes=y_pred.shape[-1])
        return y_pred, target.to(device=device, dtype=y_pred.dtype)

In [62]:
# Two evaluators share the same step function so that training and validation
# metrics are tracked independently of each other.
validation_evaluator = Engine(eval_function)
train_evaluator = Engine(eval_function)
# The trainer performs the optimisation steps.
trainer = Engine(process_function)

# Smooth the per-step loss returned by the trainer and expose it as 'loss'.
RunningAverage(output_transform=lambda step_loss: step_loss).attach(trainer, 'loss')

In [63]:
def thresholded_output_transform(output):
    """Turn predicted probabilities into hard 0/1 decisions for the accuracy metric.

    Args:
        output: ``(y_pred, y)`` pair as returned by the evaluation step.

    Returns:
        ``(rounded y_pred, y)``; the targets are passed through untouched.
    """
    probabilities, targets = output
    return probabilities.round(), targets


# Both evaluators report the same two metrics: thresholded accuracy and the
# criterion averaged over the evaluated data (stored under 'bce').
for metric_target in (train_evaluator, validation_evaluator):
    Accuracy(output_transform=thresholded_output_transform).attach(metric_target, 'accuracy')
    Loss(criterion).attach(metric_target, 'bce')

# Progress bar for the training loop, showing the running 'loss' metric.
pbar = ProgressBar(persist=True, bar_format="")
pbar.attach(trainer, ['loss'])

In [64]:

def score_function(engine):
    """Score used for early stopping: the negated 'bce' metric of ``engine``.

    The sign is flipped because a higher score has to mean a better model.
    """
    return -engine.state.metrics['bce']

# Stop the trainer once the validation score has failed to improve for five
# consecutive validation runs. The check fires each time the validation
# evaluator finishes.
handler = EarlyStopping(
    trainer=trainer,
    score_function=score_function,
    patience=5)
validation_evaluator.add_event_handler(Events.COMPLETED, handler)


<ignite.engine.engine.RemovableEventHandle at 0x116829cd0>

In [65]:
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
    """After every epoch, re-evaluate on the training data and log the metrics.

    Args:
        engine: the trainer; only its current epoch number is read.
    """
    train_evaluator.run(train_iterator)
    results = train_evaluator.state.metrics
    message = (
        "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
        .format(engine.state.epoch, results['accuracy'], results['bce']))
    pbar.log_message(message)

def log_validation_results(engine):
    """Evaluate on the validation data, log the metrics and reset the bar counters.

    Running the validation evaluator is also what triggers the early-stopping
    check attached to it.

    Args:
        engine: the trainer; only its current epoch number is read.
    """
    validation_evaluator.run(valid_iterator)
    results = validation_evaluator.state.metrics
    message = (
        "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
        .format(engine.state.epoch, results['accuracy'], results['bce']))
    pbar.log_message(message)
    # Zero the `n` and `last_print_n` attributes on the progress bar object.
    pbar.n = 0
    pbar.last_print_n = 0

In [66]:
# Validation runs at the end of every epoch.
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

# Save the classifier's state dict after every epoch, keeping the two most
# recent files under /tmp/models.
checkpointer = ModelCheckpoint(
    dirname='/tmp/models',
    filename_prefix='textcnn',
    n_saved=2,
    create_dir=True,
    save_as_state_dict=True)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'textcnn': model})


<ignite.engine.engine.RemovableEventHandle at 0x17e723c90>

In [67]:
# Train for at most 20 epochs; the early-stopping handler can end the run sooner.
trainer.run(train_iterator, max_epochs=20)

HBox(children=(FloatProgress(value=0.0, max=706.0), HTML(value='')))

Current run is terminating due to exception: .
Engine run is terminating due to exception: .


KeyboardInterrupt: 