In [1]:
import torch
import numpy as np
from transformers import BertTokenizer
import pandas as pd

from lit_nlp import notebook
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
datapath='bbc-text-mini.csv'
df= pd.read_csv(datapath)

In [3]:
LABELS = {
    'business':0,
    'entertainment':1,
    'sport':2,
    'tech': 3,
    'politics':4
}

In [4]:
class Dataset(lit_dataset.Dataset):
    def __init__(self, df):
        self.labels = [LABELS[label] for label in df['category']]
        self.texts = [tokenizer(text, padding='max_length',max_length=512, truncation=True,return_tensors='pt') for text in df['text']]
    # end
    
    def classes(self):
        return self.labels
    # end
    
    def __len__(self):
        return len(self.labels)
    # end
    
    def get_batch_labels(self, idx):
        return np.array(self.labels[idx])
    # end
    
    def get_batch_texts(self, idx):
        return self.texts[idx]
    # end
    
    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y
    # end
    
    def spec(self) -> lit_types.Spec:
        return {
            'text': lit_types.TextSegment(),
            'label': lit_types.CategoryLabel(vocab=self.labels)
        }
    # end

In [5]:
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
                                     [int(.8*len(df)), int(.9*len(df))])

In [6]:
print(len(df_train),len(df_val), len(df_test))

79 10 10


In [7]:
# start to build the model
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module):
    def __init__(self, dropout=0.5):
        super (BertClassifier, self).__init__()
        
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 5)
        self.relu = nn.ReLU()
    # end
    
    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        
        return final_layer
    # end
# end class



class BertTrainer(lit_model.Model):
    def __init__(self, model, train_dataset, val_dataset,learning_rate=1e-6):
        self.model = model
        self.labels = [LABELS[label] for label in df['category']]
        self.train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True)
        self.val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2)
        
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        self.criterion = nn.CrossEntropyLoss()
        
        self.optimizer = Adam(model.parameters(), lr=learning_rate)
    # end
    
    
    def input_spec(self) -> lit_types.Spec:
        return {
            "text": lit_types.TextSegment(),
            "label": lit_types.CategoryLabel(vocab=self.labels, required=False)
        }

    def output_spec(self) -> lit_types.Spec:
        return {
            "tokens": lit_types.Tokens(),
            "probas": lit_types.MulticlassPreds(parent="label", vocab=self.labels),
            "cls_emb": lit_types.Embeddings()
        }
    # end
    
    
    def train(self, epochs, tokenizer):
        device = self.device
        
        if self.use_cuda:
            model = self.model.cuda()
            criterion = self.criterion.cuda()
        # end

        for epoch_num in range(epochs):

            total_acc_train = 0
            total_loss_train = 0

            for train_input, train_label in tqdm(self.train_dataloader):

                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = model(input_id, mask)
                
                batch_loss = criterion(output, train_label)
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc

                model.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
            
#             total_acc_val = 0
#             total_loss_val = 0
            
#             for val_inputs, val_label in tqdm(self.train_dataloader):
#                 current_loss_val, current_acc_val = self.predict_minibatch(val_inputs, val_label, tokenizer)
#             # end
            
#             total_acc_val += current_loss_val
#             total_loss_val += current_loss_val
            
#             print(
#             f'Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train_data): .3f} \
#             | Train Accuracy: {total_acc_train / len(train_data): .3f} \
#             | Val Loss: {total_loss_val / len(val_data): .3f} \
#             | Val Accuracy: {total_acc_val / len(val_data): .3f}')
        # end print

    def max_minibatch_size(self):
        return 32
    
    def predict_minibatch(self, val_input, val_label, tokenizer, model):
        device = self.device
        model = self.model
        criterion = self.criterion
        
        if self.use_cuda:
            model = self.model.cuda()
            criterion = self.criterion.cuda()
        # end
        
        with torch.no_grad():
            val_label = val_label.to(device)
            mask = val_input['attention_mask'].to(device)
            input_id = val_input['input_ids'].squeeze(1).to(device)

            output = model(input_id, mask)

            batch_loss = criterion(output, val_label)
            current_loss_val = batch_loss.item()
            acc = (output.argmax(dim=1) == val_label).sum().item()
             
            output_yield = {}
            output_yield['tokens'] = [tokenizer.decode(dimension23.reshape(-1)) for dimension23 in val_input['input_ids']]
            output_yield["probas"] = torch.nn.functional.softmax(output.logits, dim=-1),
            output_yield['cls_emb'] = output.hidden_states[-1][:, 0]
            yield output_yield
        # end
        
        return current_loss_val, acc

    # end
    
    

# end class

In [8]:
from torch.optim import Adam
from tqdm import tqdm

                  
EPOCHS = 1
model = BertClassifier()
train_dataset, val_dataset = Dataset(df_train), Dataset(df_val)
LR = 1e-6

trainer = BertTrainer(model, train_dataset, val_dataset)
trainer.train(EPOCHS, tokenizer)  

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 40/40 [00:03<00:00, 10.35it/s]


In [9]:
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2)

for val_inputs, val_label in tqdm(val_dataloader):
    trainer.predict_minibatch(val_inputs, val_label, tokenizer, model)
# end

100%|██████████| 5/5 [00:00<00:00, 4729.71it/s]


In [10]:
models = {"SR-2-PR": trainer}
datasets = {'SR-2-PR': val_dataset}

In [11]:
widget = notebook.LitWidget(models, datasets, height=1024)

INFO:absl:
 (    (           
 )\ ) )\ )  *   ) 
(()/((()/(` )  /( 
 /(_))/(_))( )(_))
(_)) (_)) (_(_()) 
| |  |_ _||_   _| 
| |__ | |   | |   
|____|___|  |_|   


INFO:absl:Starting LIT server...
INFO:absl:CachingModelWrapper 'SR-2-PR': no cache path specified, not loading.
INFO:absl:Warm-start of model 'SR-2-PR' on dataset '_union_empty'
INFO:absl:CachingModelWrapper 'SR-2-PR': misses (dataset=_union_empty): []
INFO:absl:CachingModelWrapper 'SR-2-PR': 0 misses out of 0 inputs
INFO:absl:Prepared 0 inputs for model
INFO:absl:Received 0 predictions from model
INFO:absl:Requested types: ['LitType']
INFO:absl:Will return keys: {'probas', 'tokens', 'cls_emb'}
INFO:absl:CachingModelWrapper 'SR-2-PR': no cache path specified, not saving.


In [12]:
widget.render()