In [1]:
from transformers import AutoModelForQuestionAnswering, AutoTokenizer, AutoConfig
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pandas as pd
from pathlib import Path
import os

In [2]:
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from apex.optimizers.fused_lamb import FusedLAMB as Lamb
from fairseq import criterions 
from functools import partial
from tokenizers import BertWordPieceTokenizer

import lineflow as lf
import lineflow.cross_validation as lfcv
import tqdm

In [3]:
# Load the training table from ./data (resolved against the current working directory).
file_dir = Path.cwd()/'data'
train_df = pd.read_csv(file_dir/'train.csv')

# Coerce every column that is later handed to the tokenizer to plain `str`.
# pandas reads empty CSV fields as float NaN, which the tokenizer cannot encode.
# `selected_text` is consumed by TweetDataset alongside `text` and `sentiment`,
# so it needs the same treatment (a missing value becomes the string 'nan',
# exactly as already happens for `text`).
train_df['text'] = train_df['text'].apply(str)
train_df['selected_text'] = train_df['selected_text'].apply(str)
train_df['sentiment'] = train_df['sentiment'].apply(str)

In [73]:
# Preview the first rows to sanity-check the loaded columns.
train_df.head()

Unnamed: 0,textID,text,selected_text,sentiment
0,cb774db0d1,"I`d have responded, if I were going","I`d have responded, if I were going",neutral
1,549e992a42,Sooo SAD I will miss you here in San Diego!!!,Sooo SAD,negative
2,088c60f138,my boss is bullying me...,bullying me,negative
3,9642c003ef,what interview! leave me alone,leave me alone,negative
4,358bd9e861,"Sons of ****, why couldn`t they put them on t...","Sons of ****,",negative


In [14]:
# Length every encoded (sentiment, tweet) pair is padded to in `preprocess`.
max_len = 128
# Presumably the DataLoader batch size -- not used in the cells shown here; confirm downstream.
bs = 64
# Presumably the number of training epochs -- not used in the cells shown here; confirm downstream.
epochs = 5
# WordPiece tokenizer built from the ELECTRA-base discriminator vocab file, lower-casing its input.
tokenizer = BertWordPieceTokenizer('input/electra-base-disc/vocab.txt', lowercase=True)

In [44]:
def preprocess(sentiment, tweet, selected, tokenizer, max_len):
    """Encode a (sentiment, tweet) pair and locate the selected span inside it.

    ``sentiment`` and ``tweet`` are encoded together as a sequence pair, while
    ``selected`` is encoded on its own without special tokens. The span's token
    ids are then searched for as a contiguous run inside the pair encoding.

    Args:
        sentiment: Sentiment label text, encoded as the first sequence.
        tweet: Tweet text, encoded as the second sequence.
        selected: Text whose tokens should be located in the tweet.
        tokenizer: Object exposing ``encode(text, pair, add_special_tokens=...)``
            whose result provides ``ids`` and ``pad`` (presumably a
            ``tokenizers`` tokenizer -- confirm against the caller).
        max_len: Length passed to the encoding's ``pad`` method.

    Returns:
        The pair encoding after ``pad(max_len)``, carrying these extra
        attributes: ``start_target`` / ``end_target`` (inclusive token indices
        of the first match, or ``None`` for both when the span is empty or is
        not found), and the raw ``tweet``, ``sentiment`` and ``selected``.
    """
    _input = tokenizer.encode(sentiment, tweet)
    _span = tokenizer.encode(selected, add_special_tokens=False)

    # Read the id lists once instead of on every loop iteration.
    input_ids = list(_input.ids)
    span_ids = list(_span.ids)
    len_span = len(span_ids)

    # Only search inside the tweet segment. Otherwise a selected text such as
    # "negative" would match the sentiment token at the front of the pair.
    # Assumes the tokenizer marks the second sequence with type id 1 (as a
    # BERT-style pair post-processor does); if no such marker is present the
    # whole encoding is searched, as before.
    type_ids = list(getattr(_input, 'type_ids', None) or [])
    search_from = type_ids.index(1) if 1 in type_ids else 0

    start_idx = None
    end_idx = None

    # An empty span (e.g. selected == "") has nothing to locate; leave the
    # targets as None rather than indexing into an empty id list.
    if len_span:
        for ind in range(search_from, len(input_ids) - len_span + 1):
            if input_ids[ind: ind + len_span] == span_ids:
                start_idx = ind
                end_idx = ind + len_span - 1  # inclusive end index
                break

    _input.start_target = start_idx
    _input.end_target = end_idx
    _input.tweet = tweet
    _input.sentiment = sentiment
    _input.selected = selected

    # NOTE(review): pad presumably only extends short encodings; inputs longer
    # than max_len would keep their original length -- confirm if truncation
    # is needed for batching.
    _input.pad(max_len)

    return _input

In [70]:
class TweetDataset(torch.utils.data.Dataset):
    """Map-style dataset producing token ids and span targets for one row.

    Args:
        dataset: Column-indexable table providing 'sentiment', 'text' and
            'selected_text' (presumably a pandas DataFrame -- confirm against
            the caller). ``len(dataset)`` is reported as the dataset size.
        tokenizer: Optional tokenizer forwarded to ``preprocess``. When
            omitted, the notebook-level ``tokenizer`` is used.
        max_len: Optional padding length forwarded to ``preprocess``. When
            omitted, the notebook-level ``max_len`` is used.
    """

    # Columns read for every example, in the order ``preprocess`` expects them.
    _COLUMNS = ('sentiment', 'text', 'selected_text')

    def __init__(self, dataset, tokenizer=None, max_len=None):
        self.dataset = dataset
        self._tokenizer = tokenizer
        self._max_len = max_len

    @staticmethod
    def _cell(column, idx):
        """Return the ``idx``-th value of ``column`` by position."""
        # Samplers hand out positions 0..len-1. Plain ``series[idx]`` is a
        # label lookup in pandas, which fails or returns the wrong row once
        # the frame's index is no longer 0..n-1 (after a split, shuffle or
        # filter). Use positional access when the column offers it.
        positional = getattr(column, 'iloc', None)
        return column[idx] if positional is None else positional[idx]

    def __getitem__(self, idx):
        """Return ``(token_ids, (start_target, end_target))`` for row ``idx``."""
        sentiment, tweet, selected = (
            self._cell(self.dataset[col], idx) for col in self._COLUMNS
        )
        # Fall back to the notebook globals only when no override was given.
        active_tokenizer = tokenizer if self._tokenizer is None else self._tokenizer
        active_max_len = max_len if self._max_len is None else self._max_len
        output = preprocess(sentiment, tweet, selected, active_tokenizer, active_max_len)
        return output.ids, (output.start_target, output.end_target)

    def __len__(self):
        """Number of rows in the wrapped table."""
        return len(self.dataset)