In [374]:
import os
from itertools import compress
from pathlib import Path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForQuestionAnswering, AutoModel, AutoConfig, get_linear_schedule_with_warmup
from transformers.optimization import AdamW

import utils

In [409]:
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from apex.optimizers.fused_lamb import FusedLAMB as Lamb
from fairseq import criterions 
from functools import partial
from tokenizers import BertWordPieceTokenizer
from sklearn.model_selection import train_test_split

import lineflow as lf
import lineflow.cross_validation as lfcv
from tqdm import tqdm

In [None]:
from torch.optim.lr_scheduler import wraps

In [412]:
file_dir = Path.cwd() / 'data'
df = pd.read_csv(file_dir / 'train.csv')
# Coerce the text columns to str so non-string cells (e.g. NaN -> 'nan') are safe to tokenize.
df = df.assign(**{col: df[col].apply(str) for col in ('text', 'sentiment', 'selected_text')})

In [4]:
# Peek at the loaded training data; `train_df` only exists inside run(), so preview `df` instead.
df.head()

Unnamed: 0,textID,text,selected_text,sentiment
0,cb774db0d1,"I`d have responded, if I were going","I`d have responded, if I were going",neutral
1,549e992a42,Sooo SAD I will miss you here in San Diego!!!,Sooo SAD,negative
2,088c60f138,my boss is bullying me...,bullying me,negative
3,9642c003ef,what interview! leave me alone,leave me alone,negative
4,358bd9e861,"Sons of ****, why couldn`t they put them on t...","Sons of ****,",negative


In [455]:
# Length every encoding is padded to (see preprocess) and the training batch size.
max_len, bs = 128, 64
# Lower-cased WordPiece tokenizer built from the local ELECTRA-base vocab file.
tokenizer = BertWordPieceTokenizer('input/electra-base-disc/vocab.txt', lowercase=True)

In [356]:
class Mish(nn.Module):
    """Mish activation, applied element-wise: ``x * tanh(softplus(x))``."""

    def forward(self, x):
        return torch.tanh(F.softplus(x)).mul(x)

In [325]:
def preprocess(sentiment, tweet, selected, tokenizer, max_len):
    """Tokenize a (sentiment, tweet) pair and locate ``selected`` as token targets.

    Args:
        sentiment: Sentiment label, encoded as the first sequence of the pair.
        tweet: Tweet text, encoded as the second sequence of the pair.
        selected: Substring of ``tweet`` that the model should extract.
        tokenizer: Tokenizer whose ``encode(a, b)`` / ``encode(a, add_special_tokens=False)``
            return encodings exposing ``ids``, ``type_ids``, ``offsets`` and ``pad``
            (e.g. ``BertWordPieceTokenizer``).
        max_len: Length the pair encoding is padded to.

    Returns:
        The pair encoding, padded to ``max_len``, with extra attributes
        ``start_target`` / ``end_target`` (inclusive token indices of the span
        within the full pair sequence) and ``tweet`` / ``sentiment`` / ``selected``.

    Raises:
        ValueError: If no tweet token overlaps ``selected``.
    """
    _input = tokenizer.encode(sentiment, tweet)
    _span = tokenizer.encode(selected, add_special_tokens=False)

    span_ids = _span.ids
    len_span = len(span_ids)
    # Index of the first tweet token (type_id 1); everything before it is the
    # special tokens plus the sentiment. Derived rather than hard-coded so a
    # multi-token sentiment still lines up.
    tweet_start = _input.type_ids.index(1)

    start_idx = None
    end_idx = None

    # Search only the tweet segment, so a span equal to the sentiment word
    # (e.g. selected == "negative") is not matched against the sentiment token.
    if len_span:
        for ind in range(tweet_start, len(_input.ids) - len_span + 1):
            if _input.ids[ind: ind + len_span] == span_ids:
                start_idx = ind
                end_idx = ind + len_span - 1
                break

    # WordPiece can tokenize the span differently in isolation than in context
    # (e.g. a span starting mid-word); fall back to character-level alignment.
    if start_idx is None:
        char_targets = [0] * len(tweet)
        idx0 = tweet.find(selected)
        # str.find signals "absent" with -1 (never None); a negative start would
        # otherwise wrap around and mark the wrong characters.
        if idx0 != -1:
            for ct in range(idx0, idx0 + len(selected)):
                char_targets[ct] = 1

        # Offsets of the tweet tokens (type_id 1), dropping the trailing [SEP].
        tweet_offsets = list(compress(_input.offsets, _input.type_ids))[0:-1]
        target_idx = [
            j for j, (offset1, offset2) in enumerate(tweet_offsets)
            if sum(char_targets[offset1: offset2]) > 0
        ]
        if not target_idx:
            raise ValueError(
                f"selected text {selected!r} could not be aligned to any token of tweet {tweet!r}"
            )
        start_idx = target_idx[0] + tweet_start
        end_idx = target_idx[-1] + tweet_start

    _input.start_target = start_idx
    _input.end_target = end_idx
    _input.tweet = tweet
    _input.sentiment = sentiment
    _input.selected = selected

    _input.pad(max_len)

    return _input

In [326]:
class TweetDataset(torch.utils.data.Dataset):
    """Map-style dataset turning rows of a tweet DataFrame into model-ready tensors.

    Each item runs ``preprocess`` on the row's ``sentiment``, ``text`` and
    ``selected_text`` values.

    Args:
        dataset: DataFrame with ``sentiment``, ``text`` and ``selected_text`` columns.
        tokenizer: Tokenizer passed to ``preprocess``; when None, the
            notebook-level ``tokenizer`` is used (looked up at item time).
        max_len: Padding length passed to ``preprocess``; when None, the
            notebook-level ``max_len`` is used (looked up at item time).
    """

    def __init__(self, dataset, tokenizer=None, max_len=None):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, idx):
        # Fall back to the module-level globals of the same names (no locals shadow them here).
        tok = tokenizer if self.tokenizer is None else self.tokenizer
        length = max_len if self.max_len is None else self.max_len

        # Positional lookup, consistent with __len__, so a frame whose index was
        # not reset after a split still yields row `idx`.
        sentiment, tweet, selected = (
            self.dataset[col].iloc[idx] for col in ['sentiment', 'text', 'selected_text']
        )
        _input = preprocess(sentiment, tweet, selected, tok, length)

        return {
            'ids': torch.LongTensor(_input.ids),
            'mask': torch.LongTensor(_input.attention_mask),
            'token_type_ids': torch.LongTensor(_input.type_ids),
            # Shape (1,) per item, so batches are (batch, 1).
            'targets_start': torch.LongTensor([_input.start_target]),
            'targets_end': torch.LongTensor([_input.end_target]),
            'orig_tweet': _input.tweet,
            'orig_selected': _input.selected,
            'sentiment': _input.sentiment,
            'offsets': torch.LongTensor(_input.offsets)
        }

    def __len__(self):
        return len(self.dataset)

In [19]:
# Build the config in this cell so it does not depend on a later cell having run first.
# output_hidden_states=True is required: SpanModel unpacks `(_, hidden_states)` from the backbone.
config = AutoConfig.from_pretrained('google/electra-base-discriminator')
config.output_hidden_states = True
model = AutoModel.from_pretrained('google/electra-base-discriminator', config = config)
# m = AutoModel.from_pretrained('google/electra-large-discriminator')

INFO:transformers.modeling_utils:loading weights file https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/pytorch_model.bin from cache at /home/jack/.cache/torch/transformers/3c8e97e5021532563898ceb491dbfbc068ab4cb9eaa31f555990b9993e3228b4.b7514d01ce5acfe02313470cce3175018852a5e8cbcb8784268ab87dc21daf4c
INFO:transformers.modeling_utils:Weights from pretrained model not used in ElectraModel: ['electra.embeddings_project.weight', 'electra.embeddings_project.bias']


In [13]:
# Backbone config, overriding output_hidden_states so every layer's hidden states are returned.
config = AutoConfig.from_pretrained(
    'google/electra-base-discriminator',
    output_hidden_states=True,
)

INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/google/electra-base-discriminator/config.json from cache at /home/jack/.cache/torch/transformers/9236d197566a7f1be2b2151f5afcc5a8e17f31e1e23c52f3cdf2340019986e78.88ba6e8e7d5a7936e86d6f2551fe19c236dc57c24da163907cd0544e9933f6ee
INFO:transformers.configuration_utils:Model config ElectraConfig {
  "_num_labels": 2,
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "embedding_size": 768,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    

In [80]:
class SpanModel(nn.Module):
    """Span-extraction head predicting a start and an end logit for every token.

    Concatenates the backbone's last two hidden-state layers, applies dropout,
    and projects to two logits per token.

    Args:
        backbone: Encoder called as ``backbone(input_ids, attention_mask=..., token_type_ids=...)``
            that must return ``(sequence_output, hidden_states)`` (i.e. configured with
            ``output_hidden_states=True``). Defaults to the notebook-level ``model``.
        hidden_size: Width of one hidden-state layer. Defaults to
            ``backbone.config.hidden_size`` when available, else 768.
    """

    def __init__(self, backbone=None, hidden_size=None):
        super().__init__()
        self.model = backbone if backbone is not None else model
        if hidden_size is None:
            hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", 768)
        self.drop_out = nn.Dropout(0.1)
        # Two concatenated layers in, (start, end) logits out.
        self.qa_outputs = nn.Linear(hidden_size * 2, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return ``(start_logits, end_logits)``, one unnormalized score per token.

        Assuming hidden states are (batch, seq_len, hidden_size), each output is
        (batch, seq_len). Feed them directly to ``nn.CrossEntropyLoss`` with integer
        target positions; an argmax is only needed when decoding predictions.
        """
        _, hidden_states = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        out = torch.cat((hidden_states[-1], hidden_states[-2]), dim=-1)
        out = self.drop_out(out)
        logits = self.qa_outputs(out)

        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

In [81]:
def calc_loss(start_logits, end_logits, start_positions, end_positions, loss_fn = None):
    """Sum of the start-position and end-position losses for span extraction.

    Args:
        start_logits, end_logits: Per-token logits, typically (batch, seq_len).
        start_positions, end_positions: Integer target token indices, shape (batch,)
            or (batch, 1); flattened before use because CrossEntropyLoss needs 1-D
            class targets for 2-D logits (TweetDataset yields shape (1,) per item).
        loss_fn: Criterion applied to each head; defaults to ``nn.CrossEntropyLoss()``.

    Returns:
        Scalar tensor ``loss_fn(start) + loss_fn(end)``.
    """
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()

    start_loss = loss_fn(start_logits, start_positions.reshape(-1))
    end_loss = loss_fn(end_logits, end_positions.reshape(-1))
    return start_loss + end_loss

In [390]:
def train(dataloader, model, optimizer, device, scheduler=None):
    """Run one training epoch, showing running loss and Jaccard in a tqdm bar.

    Args:
        dataloader: Yields batches of dicts as produced by ``TweetDataset``.
        model: Called as ``model(ids, mask, token_type_ids)`` and returning
            ``(start_logits, end_logits)``.
        optimizer: Stepped once per batch.
        device: Device the tensor inputs are moved to.
        scheduler: Optional LR scheduler, stepped once per batch.

    Returns:
        The epoch's running average loss (``utils.AverageMeter.avg``).
    """
    model.train()
    losses = utils.AverageMeter()
    jaccards = utils.AverageMeter()

    prog = tqdm(dataloader, total=len(dataloader))

    for d in prog:
        ids = d["ids"].to(device)
        token_type_ids = d["token_type_ids"].to(device)
        mask = d["mask"].to(device)
        targets_start = d["targets_start"].to(device)
        targets_end = d["targets_end"].to(device)
        sentiment = d["sentiment"]
        orig_selected = d["orig_selected"]
        orig_tweet = d["orig_tweet"]
        # numpy offsets for string slicing in calculate_jaccard_score, as in eval_fn.
        offsets = d["offsets"].numpy()

        model.zero_grad()
        outputs_start, outputs_end = model(ids, mask, token_type_ids)

        loss = calc_loss(outputs_start, outputs_end, targets_start, targets_end)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Decode the argmax span per example and score it against the gold text.
        probs_start = torch.softmax(outputs_start, dim=1).detach().cpu().numpy()
        probs_end = torch.softmax(outputs_end, dim=1).detach().cpu().numpy()
        jaccard_scores = []
        for px, tweet in enumerate(orig_tweet):
            jaccard_score, _ = calculate_jaccard_score(
                original_tweet=tweet,
                target_string=orig_selected[px],
                sentiment_val=sentiment[px],
                idx_start=np.argmax(probs_start[px, :]),
                idx_end=np.argmax(probs_end[px, :]),
                offsets=offsets[px]
            )
            jaccard_scores.append(jaccard_score)

        jaccards.update(np.mean(jaccard_scores), ids.size(0))
        losses.update(loss.item(), ids.size(0))
        prog.set_postfix(loss=losses.avg, jaccard=jaccards.avg)

    return losses.avg

In [391]:
def calculate_jaccard_score(
    original_tweet, 
    target_string, 
    sentiment_val, 
    idx_start, 
    idx_end, 
    offsets,
    verbose=False):
    """Decode a predicted token span back to text and score it with ``utils.jaccard``.

    Args:
        original_tweet: Tweet that ``offsets`` index into.
        target_string: Gold selected text.
        sentiment_val: Sentiment label; ``"neutral"`` tweets are scored using the whole tweet.
        idx_start, idx_end: Predicted inclusive token indices; an end before the
            start is clamped to the start.
        offsets: Per-token ``(char_start, char_end)`` pairs into ``original_tweet``.
        verbose: Unused; kept for interface compatibility.

    Returns:
        ``(jaccard, predicted_text)``.
    """
    last = max(idx_start, idx_end)

    pieces = []
    for ix in range(idx_start, last + 1):
        char_lo, char_hi = offsets[ix][0], offsets[ix][1]
        pieces.append(original_tweet[char_lo:char_hi])
        # Re-insert a space where the source text had a gap before the next token.
        has_next = ix + 1 < len(offsets)
        if has_next and char_hi < offsets[ix + 1][0]:
            pieces.append(" ")
    prediction = "".join(pieces)

    # Neutral and single-word tweets predict the entire tweet.
    if sentiment_val == "neutral" or len(original_tweet.split()) < 2:
        prediction = original_tweet

    score = utils.jaccard(target_string.strip(), prediction.strip())
    return score, prediction


def eval_fn(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` without gradients and return the mean Jaccard.

    Args:
        data_loader: Yields batches of dicts as produced by ``TweetDataset``.
        model: Called as ``model(ids, mask, token_type_ids)`` and returning
            ``(start_logits, end_logits)``.
        device: Device the tensor inputs are moved to.

    Returns:
        Average Jaccard over the loader, as accumulated by ``utils.AverageMeter``
        with batch-size weights.
    """
    model.eval()
    losses = utils.AverageMeter()
    jaccards = utils.AverageMeter()

    with torch.no_grad():
        tk0 = tqdm(data_loader, total=len(data_loader))
        for d in tk0:
            ids = d["ids"].to(device, dtype=torch.long)
            token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
            mask = d["mask"].to(device, dtype=torch.long)
            targets_start = d["targets_start"].to(device, dtype=torch.long)
            targets_end = d["targets_end"].to(device, dtype=torch.long)
            sentiment = d["sentiment"]
            orig_selected = d["orig_selected"]
            orig_tweet = d["orig_tweet"]
            offsets = d["offsets"].numpy()

            # Positional call, matching SpanModel.forward(input_ids, attention_mask, token_type_ids)
            # and the call in train().
            outputs_start, outputs_end = model(ids, mask, token_type_ids)
            # Same loss as training.
            loss = calc_loss(outputs_start, outputs_end, targets_start, targets_end)

            probs_start = torch.softmax(outputs_start, dim=1).cpu().numpy()
            probs_end = torch.softmax(outputs_end, dim=1).cpu().numpy()
            jaccard_scores = []
            for px, tweet in enumerate(orig_tweet):
                jaccard_score, _ = calculate_jaccard_score(
                    original_tweet=tweet,
                    target_string=orig_selected[px],
                    sentiment_val=sentiment[px],
                    idx_start=np.argmax(probs_start[px, :]),
                    idx_end=np.argmax(probs_end[px, :]),
                    offsets=offsets[px]
                )
                jaccard_scores.append(jaccard_score)

            jaccards.update(np.mean(jaccard_scores), ids.size(0))
            losses.update(loss.item(), ids.size(0))
            tk0.set_postfix(loss=losses.avg, jaccard=jaccards.avg)

    print(f"Jaccard = {jaccards.avg}")
    return jaccards.avg

In [453]:
def run(epochs, fold=0, span_model=None):
    """Train a span model on an 80/20 split of ``df``, early-stopping on validation Jaccard.

    Args:
        epochs: Maximum number of epochs.
        fold: Tag used in the checkpoint filename ``model_{fold}.bin``.
        span_model: Model to train; defaults to a new ``SpanModel()`` (which wraps
            the notebook-level backbone). train()/eval_fn() need a model returning
            ``(start_logits, end_logits)``, which the bare backbone does not.
    """
    train_df, val_df = train_test_split(df, test_size = 0.2, random_state = 42)
    train_df = train_df.reset_index(drop=True)
    val_df = val_df.reset_index(drop=True)

    train_ds = TweetDataset(train_df)
    val_ds = TweetDataset(val_df)

    # Shuffle training batches every epoch; keep validation order fixed.
    train_dl = DataLoader(train_ds, batch_size = bs, shuffle = True, num_workers = 4)
    valid_data_loader = DataLoader(val_ds, batch_size = 16, num_workers = 2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if span_model is None:
        span_model = SpanModel()
    span_model.to(device)

    # One scheduler step per batch, including the final partial batch.
    num_train_steps = len(train_dl) * epochs
    param_optimizer = list(span_model.named_parameters())

    # No weight decay on biases and LayerNorm parameters.
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]

    optimizer = AdamW(optimizer_parameters, lr=3e-5)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=0, 
        num_training_steps=num_train_steps
    )

    # Early stopping with patience 2: stop once 2 epochs pass without improvement.
    es = utils.EarlyStopping(patience=2, mode="max")

    for _ in range(epochs):
        train(train_dl, span_model, optimizer, device, scheduler=scheduler)
        jaccard = eval_fn(valid_data_loader, span_model, device)
        print(f"Jaccard Score = {jaccard}")
        es(jaccard, span_model, model_path=f"model_{fold}.bin")
        if es.early_stop:
            print("Early stopping")
            break