In [1]:
import transformers
import tokenizers
import torch
from torch import nn
import numpy as np
import pandas as pd
import typing, os, string, gc, time
from nltk.corpus import indian
import nltk
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AdamW


In [2]:
print(f"Transformers version: {transformers.__version__}")
print(f"Tokenizers version: {tokenizers.__version__}")

Transformers version: 3.0.2
Tokenizers version: 0.8.1.rc1


### Text classification

In [3]:
def process_dataframe(path):
    
    df = pd.read_csv(path, sep='\t', encoding='utf-8', header=None)
    df.columns = ['label_text', 'text']
    df.label_text = pd.Categorical(df.label_text)
    df['label'] = df.label_text.cat.codes
    #print(df.label_text.cat.categories)
    print(f"Number of examples: {len(df)}")
    return df

In [4]:
train_df = process_dataframe('bbc-hindi/hindi-train.csv')
valid_df = process_dataframe('bbc-hindi/hindi-test.csv')

Number of examples: 3468
Number of examples: 867


In [5]:
train_df.head()

Unnamed: 0,label_text,text,label
0,india,मेट्रो की इस लाइन के चलने से दक्षिणी दिल्ली से...,3
1,pakistan,नेटिजन यानि इंटरनेट पर सक्रिय नागरिक अब ट्विटर...,9
2,news,इसमें एक फ़्लाइट एटेनडेंट की मदद की गुहार है औ...,8
3,india,"प्रतीक खुलेपन का, आज़ाद ख्याली का और भीड़ से अ...",3
4,india,ख़ासकर पिछले 10 साल तक प्रधानमंत्री रहे मनमोहन...,3


In [6]:
valid_df.head()

Unnamed: 0,label_text,text,label
0,india,बुधवार को राज्य सभा में विपक्ष के सवालों के जव...,3
1,india,लखनऊ स्थित पत्रकार समीरात्मज मिश्र को बुलंदशहर...,3
2,india,लगभग 1300 हेक्टेयर ज़मीन का अधिग्रहण किया जा च...,3
3,international,हालांकि उनके अंगरक्षकों को बमों को जाम करने वा...,5
4,india,आयोग का कहना है कि इस तरह के परीक्षण से महिलाओ...,3


In [28]:
class BBCHindiDataset:
    
    def __init__(self, data, tokenizer, base_model_type, batch_size):
        
        self.tokenizer = tokenizer
        self.base_model_type = base_model_type
        self.batch_size = batch_size
        data = [data[i: i+batch_size] for i in range(0, len(data), batch_size)]
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __iter__(self):
        
        
        for batch in self.data:
            
            batch = batch.dropna()
            texts = list(batch.text)
            labels = list(batch.label)
            
            encoded_input = self.tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors='pt')
            input_ids = encoded_input['input_ids']
            attention_mask = encoded_input['attention_mask']
            label = torch.tensor(labels, dtype=torch.long)
            if self.base_model_type == 'bert':
                token_type_ids = encoded_input['token_type_ids']
                yield {
                    'input_ids':input_ids,
                    'attention_mask':attention_mask,
                    'token_type_ids':token_type_ids,
                    'label':label
                }
            else:
                
                yield {
                    'input_ids':input_ids,
                    'attention_mask':attention_mask,
                    'label':label
                }

In [41]:
tokenizer = AutoTokenizer.from_pretrained('hi-lm-distilbert/')
train_dataset = BBCHindiDataset(train_df, tokenizer, 'bert', 32)
valid_dataset = BBCHindiDataset(valid_df, tokenizer, 'bert', 32)

In [42]:
a = next(iter(valid_dataset))

In [21]:
device = torch.device('cuda')
base_model = AutoModel.from_pretrained('hi-lm-distilbert/').to(device)
for param in base_model.parameters():
    param.requires_grad = False

In [29]:
class RoBERTaTextClassifier(nn.Module):
    
    def __init__(self, base_model):
        
        super().__init__()
        
        self.base_model = base_model
        self.fc1 = nn.Linear(768, 100)
        self.fc2 = nn.Linear(100, 14)
    
    def forward(self, input_ids, attention_mask):
        
        with torch.no_grad():
            sequence_output = self.base_model(input_ids, attention_mask)[0]
        
        # sequence_output = [batch_size, seq_len, 768]
        
        mean_output = sequence_output.mean(dim=1)
        # [bs, 768]
        
        out = self.fc2(self.fc1(mean_output))
        # out = [bs, 14]
        
        return out

In [30]:
model = RoBERTaTextClassifier(base_model).to(device)

In [31]:
optimizer = AdamW(model.parameters(), lr=3e-4)

In [25]:
def train(model, optimizer, train_dataset):
    
    print("Starting Training")
    train_loss = 0.
    train_acc = 0.
    model.train()
    
    for bi, batch in enumerate(train_dataset):

        if bi % 20 == 0:
            print(f"Starting batch: {bi}")

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        preds = model(input_ids, attention_mask)
        loss = F.cross_entropy(preds, labels)
        
        train_loss += loss.item()
        train_acc += (torch.argmax(preds,dim=1)==labels).float().mean().item()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
    return train_loss/len(train_dataset), train_acc/len(train_dataset)
    
    

In [26]:
def validate(model, valid_dataset):
    
    print("Starting validation")
    valid_loss = 0.
    valid_acc = 0.
    model.eval()
    
    for bi, batch in enumerate(valid_dataset):

        if bi % 20 == 0:
            print(f"Starting batch: {bi}")

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        
        with torch.no_grad():
            
            preds = model(input_ids, attention_mask)
            loss = F.cross_entropy(preds, labels)
        
            valid_loss += loss.item()
            valid_acc += (torch.argmax(preds,dim=1)==labels).float().mean().item()
        
        
    return valid_loss/len(valid_dataset), valid_acc/len(valid_dataset)
    
    

In [27]:
def epoch_time(start_time, end_time):
    '''
    Helper function to record epoch time.
    '''
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [32]:
train_losses = []
valid_losses = []
valid_accs = []
train_accs = []
epochs = 5
for epoch in range(epochs):
    print(f"Epoch {epoch+1}")
    
    start_time = time.time()
    
    train_loss, train_acc = train(model, optimizer, train_dataset)
    valid_loss, valid_acc = validate(model, valid_dataset)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    valid_accs.append(valid_acc)
    train_accs.append(train_acc)
    
    print(f"Epoch train loss : {train_loss}| Time: {epoch_mins}m {epoch_secs}s")
    print(f"Epoch valid loss: {valid_loss}")
    print(f"Epoch train accuracy: {train_acc}")
    print(f"Epoch valid accuracy: {valid_acc}")
    print("====================================================================================")

Epoch 1
Starting Training
Starting batch: 0
Starting batch: 20
Starting batch: 40
Starting batch: 60
Starting batch: 80
Starting batch: 100
Starting validation
Starting batch: 0
Starting batch: 20
Epoch train loss : 1.3605404337611766| Time: 5m 11s
Epoch valid loss: 1.014705736722265
Epoch train accuracy: 0.5916413878082135
Epoch valid accuracy: 0.6975566425493785
Epoch 2
Starting Training
Starting batch: 0
Starting batch: 20
Starting batch: 40
Starting batch: 60
Starting batch: 80
Starting batch: 100
Starting validation
Starting batch: 0
Starting batch: 20
Epoch train loss : 0.8979853246736964| Time: 6m 2s
Epoch valid loss: 0.8425668829253742
Epoch train accuracy: 0.7340436025497017
Epoch valid accuracy: 0.7519081213644573
Epoch 3
Starting Training
Starting batch: 0
Starting batch: 20
Starting batch: 40
Starting batch: 60
Starting batch: 80
Starting batch: 100
Starting validation
Starting batch: 0
Starting batch: 20
Epoch train loss : 0.7756573757447234| Time: 6m 16s
Epoch valid loss: