In [19]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, SequentialSampler, RandomSampler

from transformers import DistilBertTokenizerFast, DistilBertModel
from transformers import AdamW, get_linear_schedule_with_warmup
import pytorch_lightning as pl

In [28]:
import numpy as np
import json

In [21]:
# Hyper params
batch_size = 16
epochs = 3
learning_rate = 2e-5
seed = 42

# Seed the RNGs before any model is built or any batch is shuffled, so the classifier-head
# initialisation, dropout masks and RandomSampler order are reproducible on Restart & Run All.
np.random.seed(seed)
torch.manual_seed(seed)

# Prepare dataset

In [22]:
def prepare_dataset(PATH, convert2bio=False, encoding='utf-8'):
    """Read a ``token<TAB>tag`` file into a list of sentences.

    Sentences are separated by blank lines (``'\\n\\n'``). Each non-blank line holds a token
    and its tag separated by a tab.

    Args:
        PATH: path of the file to read.
        convert2bio: if True, convert IO-style tags to BIO. A non-'O' tag that differs from
            the previous token's tag gets a ``B-`` prefix. A repeat of the previous tag gets
            ``I-``. 'O' is left untouched. Two adjacent entities of the same type cannot be
            distinguished in IO input and are therefore merged into one entity.
        encoding: text encoding used to open the file (explicit, so results do not depend
            on the platform's default encoding).

    Returns:
        List of sentences, each a list of ``(token, tag)`` tuples. Blank chunks, e.g. the
        one produced by a trailing blank line at the end of the file, are skipped rather
        than returned as empty sentences.
    """
    # Context manager so the file handle is closed deterministically.
    with open(PATH, 'r', encoding=encoding) as handle:
        chunks = handle.read().split('\n\n')

    sents = []
    for chunk in chunks:
        sent = []
        previous_tag = None
        for line in chunk.split('\n'):
            # Skip blank / whitespace-only lines (they have no tag column).
            if not line.strip():
                continue
            token = line.split('\t')
            word, tag = token[0], token[1]
            if convert2bio and tag != 'O':
                prefix = 'I-' if tag == previous_tag else 'B-'
                sent.append((word, prefix + tag))
            else:
                sent.append((word, tag))
            previous_tag = tag
        if sent:
            sents.append(sent)
    return sents

In [23]:
def convert_schema(samples, mapping=None):
    """Split ``(token, tag)`` sentences into parallel token lists and tag-id lists.

    Args:
        samples: list of sentences, each a sequence of ``(token, tag)`` pairs, as returned
            by ``prepare_dataset``.
        mapping: dict from tag string to integer id. When None, the notebook-level
            ``tag_index`` is looked up at call time, which preserves the original behaviour.

    Returns:
        ``(sentences, tags)``. ``sentences[i]`` is the list of tokens of ``samples[i]``, and
        ``tags[i]`` is the list of their integer tag ids.

    Raises:
        KeyError: if a tag does not appear in ``mapping``.
    """
    if mapping is None:
        mapping = tag_index
    sentences = [[pair[0] for pair in sample] for sample in samples]
    tags = [[mapping[pair[1]] for pair in sample] for sample in samples]
    return sentences, tags

In [26]:
# Load the three splits (train, valid, test, in that order) with IO tags converted to BIO.
train_set, valid_set, test_set = (
    prepare_dataset(path, convert2bio=True)
    for path in (
        'data/ontonotes5/ner_train.txt',
        'data/ontonotes5/ner_valid.txt',
        'data/ontonotes5/ner_test.txt',
    )
)

# The tag vocabulary spans every split so no split contains an unseen label.
# Index 0 is reserved for the placeholder tag '_'; the real tags follow in sorted order.
samples = train_set + test_set + valid_set
schema = ['_', *sorted({tag for sentence in samples for (_, tag) in sentence})]
index_tag = dict(enumerate(schema))
tag_index = {tag: i for i, tag in index_tag.items()}

In [7]:
# Turn each split into parallel (token lists, tag-id lists).
(train_sentences, train_tag), (valid_sentences, valid_tag), (test_sentences, test_tag) = (
    convert_schema(split) for split in (train_set, valid_set, test_set)
)

# Prepare torch dataset

In [8]:
# Fast tokenizer: OntonoteDataset relies on its word_ids() to align labels with sub-words.
tokenizer = DistilBertTokenizerFast.from_pretrained(
    pretrained_model_name_or_path='distilbert-base-multilingual-cased'
)

In [9]:
class OntonoteDataset(Dataset):
    """Token-classification dataset that aligns word-level labels to sub-word tokens.

    Args:
        sentences: list of sentences, each a list of word strings.
        labels: list of per-word integer label lists, parallel to ``sentences``.
        tokenizer: callable tokenizer whose output supports ``word_ids()`` and item access
            (e.g. ``DistilBertTokenizerFast``).
        max_len: every example is padded/truncated to this many tokens.
        label_all_tokens: if True (the default), every sub-word piece of a word carries
            that word's label. If False, only the first piece is labelled and the remaining
            pieces get -100.
    """

    def __init__(self, sentences, labels, tokenizer, max_len=128, label_all_tokens=True):
        super().__init__()
        self.sentences = sentences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.label_all_tokens = label_all_tokens

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
        """Tokenize sentence ``index`` and return ``input_ids``, ``attention_mask``, ``labels``.

        Each value is the tokenizer's single-row tensor with the batch dimension squeezed
        away. With ``padding='max_length'`` plus truncation its length is ``max_len``.
        Positions that map to no word (special tokens, padding) get label -100, the
        default ``ignore_index`` of ``nn.CrossEntropyLoss``, so they add nothing to the loss.
        """
        sentence = self.sentences[index]
        label = self.labels[index]

        tokenized_input = self.tokenizer(
            sentence, is_split_into_words=True,
            padding='max_length', truncation=True, return_tensors='pt',
            max_length=self.max_len
        )

        labels = []
        previous_word_idx = None
        for word_idx in tokenized_input.word_ids():
            if word_idx is None:
                labels.append(-100)
            elif word_idx != previous_word_idx or self.label_all_tokens:
                # First piece of a word, or every piece when label_all_tokens is set.
                labels.append(label[word_idx])
            else:
                labels.append(-100)
            previous_word_idx = word_idx

        return {
            'input_ids': tokenized_input['input_ids'].squeeze(0),
            'attention_mask': tokenized_input['attention_mask'].squeeze(0),
            'labels': torch.tensor(labels, dtype=torch.long)
        }

In [10]:
# Build the PyTorch Lightning data module
class OntonoteDataModule(pl.LightningDataModule):
    """LightningDataModule exposing the train/valid/test splits as ``OntonoteDataset``s.

    Args:
        train_sentence, valid_sentence, test_sentence: per-split lists of word lists.
        train_tag, valid_tag, test_tag: per-split lists of per-word label-id lists.
        tokenizer: tokenizer forwarded to every ``OntonoteDataset``.
        batch_size: batch size used by all three dataloaders.
        max_length: token length every example is padded/truncated to.
    """

    def __init__(
        self, train_sentence, valid_sentence, test_sentence,
        train_tag, valid_tag, test_tag,
        tokenizer, batch_size=32, max_length=128
    ):
        super().__init__()

        self.train_sentence = train_sentence
        self.train_tag = train_tag
        self.valid_sentence = valid_sentence
        self.valid_tag = valid_tag
        self.test_sentence = test_sentence
        self.test_tag = test_tag

        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_length = max_length

    def _make_dataset(self, sentences, tags):
        # One place that builds a split, so tokenizer and max_length are applied uniformly.
        return OntonoteDataset(sentences, tags, self.tokenizer, max_len=self.max_length)

    def setup(self, stage=None):
        """Build all three datasets (rebuilt on every call, whatever ``stage`` is)."""
        self.train_dataset = self._make_dataset(self.train_sentence, self.train_tag)
        self.valid_dataset = self._make_dataset(self.valid_sentence, self.valid_tag)
        self.test_dataset = self._make_dataset(self.test_sentence, self.test_tag)

    def _sequential_loader(self, dataset):
        # Deterministic order for evaluation splits.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            sampler=SequentialSampler(dataset)
        )

    def train_dataloader(self):
        # Use the instance batch size (not a notebook global) and shuffle every epoch.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            sampler=RandomSampler(self.train_dataset)
        )

    def val_dataloader(self):
        return self._sequential_loader(self.valid_dataset)

    def test_dataloader(self):
        return self._sequential_loader(self.test_dataset)

In [11]:
data_module = OntonoteDataModule(
    train_sentence=train_sentences,
    valid_sentence=valid_sentences,
    test_sentence=test_sentences,
    train_tag=train_tag,
    valid_tag=valid_tag,
    test_tag=test_tag,
    tokenizer=tokenizer,
    batch_size=batch_size,
)

# Build the datasets eagerly so the dataloaders can be sized before trainer.fit runs.
data_module.setup()

# Build Model

In [12]:
class NerClassifier(pl.LightningModule):
    """DistilBERT encoder with a small MLP head that predicts one tag per token.

    Args:
        n_classes: number of output tags (size of the tag schema).
        learning_rate: AdamW learning rate.
        num_training_steps: total optimizer steps for the linear decay schedule. If None,
            it is derived in ``configure_optimizers`` as
            ``trainer.max_epochs * len(trainer.datamodule.train_dataloader())``, which
            requires a datamodule passed to ``trainer.fit``.
        pretrained_name: Hugging Face checkpoint the encoder is loaded from.
    """

    def __init__(self, n_classes=38, learning_rate=2e-5, num_training_steps=None,
                 pretrained_name='distilbert-base-multilingual-cased'):
        super().__init__()
        self.bert_model = DistilBertModel.from_pretrained(pretrained_name)
        # Read the encoder width from its config rather than hard-coding 768.
        hidden_size = self.bert_model.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, 256),
            nn.GELU(),
            nn.Linear(256, n_classes)
        )

        self.n_classes = n_classes
        self.learning_rate = learning_rate
        self.num_training_steps = num_training_steps
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        """Return per-token logits; the last dimension has size ``n_classes``."""
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(outputs[0])

    def _compute_loss(self, batch):
        """Cross-entropy over attended tokens; padded positions are set to ``ignore_index``."""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        logits = self(input_ids, attention_mask)
        flat_logits = logits.reshape(-1, self.n_classes)
        flat_labels = labels.reshape(-1)

        if attention_mask is not None:
            # Positions with attention_mask != 1 contribute no loss.
            ignored = torch.full_like(flat_labels, self.criterion.ignore_index)
            flat_labels = torch.where(attention_mask.reshape(-1) == 1, flat_labels, ignored)

        return self.criterion(flat_logits, flat_labels)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)
        # Logged so validation progress is actually visible; Lightning aggregates it per epoch.
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimise this module's own parameters, not a notebook-global `model`.
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)

        num_training_steps = self.num_training_steps
        if num_training_steps is None:
            # Size the schedule from the attached trainer/datamodule instead of globals.
            num_training_steps = (
                self.trainer.max_epochs * len(self.trainer.datamodule.train_dataloader())
            )
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
        )

        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(
                scheduler=scheduler,
                interval='step'
            )
        )


In [13]:
# Size the output layer from the tag schema built above rather than the hard-coded
# default (38), so a schema change cannot silently mismatch the classifier head.
model = NerClassifier(n_classes=len(schema))

Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
# Train on the GPU when one is visible and fall back to CPU otherwise,
# instead of requesting a device that may not exist.
trainer = pl.Trainer(
    max_epochs=epochs,
    gpus=1 if torch.cuda.is_available() else 0,
    progress_bar_refresh_rate=30
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [15]:
# Passing the datamodule lets Lightning pull the train/val dataloaders from it.
trainer.fit(model=model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type             | Params
------------------------------------------------
0 | bert_model | DistilBertModel  | 134 M 
1 | classifier | Sequential       | 206 K 
2 | criterion  | CrossEntropyLoss | 0     
------------------------------------------------
134 M     Trainable params
0         Non-trainable params
134 M     Total params
539.763   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [18]:
# NOTE: the file name keeps its existing 'ontotnote' spelling so current loaders still find it.
weights_path = 'weights/ontotnote_model.pth'
# torch.save does not create missing directories; make sure 'weights/' exists first.
os.makedirs(os.path.dirname(weights_path), exist_ok=True)
torch.save(model.state_dict(), weights_path)