In [4]:
import warnings
warnings.filterwarnings('ignore')
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [48]:
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

In [18]:
BATCH_SIZE = 12
MAX_LEN = 120

PRE_TRAINED_MODEL_NAME = 'bert-base-cased'

In [7]:
DATA = Path('/home/sharif/Documents/Challenges/nlp-with-disaster-tweets/data')
train_df, test_df = pd.read_csv(DATA/'train.csv'), pd.read_csv(DATA/'test.csv')

In [75]:
class DS(Dataset):
    def __init__(self, texts, targets, tokenizer, max_len):
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self): return len(self.texts)
    
    def __getitem__(self, item):
        text = str(self.texts[item])
        target = self.targets[item]
        
        encoding = self.tokenizer.encode_plus(
          text,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          pad_to_max_length=True,
          return_attention_mask=True,
          return_tensors='pt',
          truncation=True
        )
        
        return {
          'text': text,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'targets': torch.tensor(target, dtype=torch.long)
        }

In [76]:
train_df, valid_df = train_test_split(train_df, test_size=0.15)

In [77]:
len(train_ds), len(valid_df)

(6471, 596)

In [78]:
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = DS(
        texts=df.text.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len,
        targets=df.target.to_numpy()
    )
    
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=4
    )

In [79]:
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)

In [87]:
class BertClassifier(pl.LightningModule):
    def __init__(self, train_df, valid_df, n_c=2):
        super().__init__()
        self.train_df, self.valid_df = train_df, valid_df
        
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.5)
        self.out = nn.Linear(self.bert.config.hidden_size, n_c)
    
    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        output = self.drop(pooled_output)
        return self.out(output)
    
    def step(self, batch):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        targets = batch["targets"]
        
        outputs = self(
          input_ids=input_ids,
          attention_mask=attention_mask
        )
        
        _, preds = torch.max(outputs, dim=1)
        return F.cross_entropy(outputs, targets)
    
    def training_step(self, batch, batch_idx):
        return self.step(batch)
    
    def validation_step(self, batch, batch_idx):
        return self.step(batch)
        
    def configure_optimizers(self): return AdamW(self.parameters(), lr=2e-5, correct_bias=False)
            
    def train_dataloader(self): return create_data_loader(self.train_df, tokenizer, MAX_LEN, BATCH_SIZE)
    def val_dataloader(self): return create_data_loader(self.valid_df, tokenizer, MAX_LEN, BATCH_SIZE)

In [88]:
classifier = BertClassifier(train_df, valid_df)

In [89]:
trainer = pl.Trainer(gpus=1, max_epochs=10)
trainer.fit(classifier)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type      | Params
-----------------------------------
0 | bert | BertModel | 108 M 
1 | drop | Dropout   | 0     
2 | out  | Linear    | 1 K   


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

KeyboardInterrupt: 