In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [2]:
import os
# Switch off the tokenizers library's own parallelism via its environment variable
# (presumably to avoid fork-related warnings once a DataLoader is used — TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Preprocessing testing
Download the pre-trained tokenizer from Hugging Face and inspect how it splits an example sentence into subword tokens.

In [3]:
# Load the tokenizer of the multilingual cased BERT checkpoint; ./cache is passed as the cache directory.
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased", cache_dir="./cache")

In [4]:
# Plain example sentence; it is replaced by the assignment in the next cell when the cells run in order.
sentence = "The International Court of Justice has its seat in The Hague"

In [5]:
# Harder example with abbreviations, brackets and numbers; this is the sentence tokenized below.
sentence = "R.H. Saunders ( St. Lawrence River ) ( 968 MW )"

In [6]:
# Tokenize the example sentence and also request the character offsets of every subword.
tokens = tokenizer(sentence, return_offsets_mapping=True)
subword_ids = tokens["input_ids"]
offsets = tokens["offset_mapping"]
# Compare the number of whitespace-separated words with the number of subword ids.
print(f"pre: {len(sentence.split())}, toks: {len(subword_ids)}")
print(offsets)
print(tokens.keys())

pre: 11, toks: 17
[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (5, 13), (14, 15), (16, 18), (18, 19), (20, 28), (29, 34), (35, 36), (37, 38), (39, 42), (43, 45), (46, 47), (0, 0)]
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping'])


In [7]:
# Map the ids back to their subword strings so the split can be inspected.
subwords = tokenizer.convert_ids_to_tokens(subword_ids)

print(subwords)

['[CLS]', 'R', '.', 'H', '.', 'Saunders', '(', 'St', '.', 'Lawrence', 'River', ')', '(', '968', 'MW', ')', '[SEP]']


In [8]:
# Scratch check: repeating a list zero times yields an empty list.
[1] * 0

[]

# Training

In [9]:
from transformers import TrainingArguments
import numpy as np
import os
import argparse


In [10]:
# Same setting as in the preprocessing section; repeated here (presumably so that the
# training section can be run on its own — TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Hyperparameters

# Dataset
---

Sentences in the data files are separated by empty lines, not by `.` characters.

In [11]:
import json
import gzip
from typing import List
from torch.nn.utils.rnn import pad_sequence

### Helper Functions

In [12]:
def split_label(label, num):
    """Expand one word-level BIO label into ``num`` sub-token labels.

    Args:
        label: Word-level tag such as ``"B-ORG"``, ``"I-ORG"`` or ``"O"``.
        num: Number of sub-tokens the word was split into.

    Returns:
        ``None`` when ``num`` is 0. Otherwise a list of labels: a tag starting
        with ``"B"`` keeps the ``B`` on the first sub-token and continues with
        the matching ``I`` tag; every other tag is simply repeated.
    """
    if num == 0:
        return None
    if label[0] == "B":
        return [label] + ["I" + label[1:]] * (num - 1)
    else:
        return [label] * num

def create_tokenized_labels(labels, original_ranges, token_ranges) -> List[str]:
    """Project word-level labels onto the sub-tokens produced by the tokenizer.

    Args:
        labels: One label per word.
        original_ranges: ``(start, end)`` character span of every word, in
            sentence order.
        token_ranges: ``(start, end)`` character span of every token. The
            first and the last entry are dropped because they belong to the
            start and end tokens.

    Returns:
        One label per remaining token. A word that owns no token (for example
        because the tokenizer dropped it or the sequence was truncated)
        contributes no label.

    Raises:
        IndexError: If ``labels`` has fewer entries than ``original_ranges``.
    """
    new_labels = []
    tok_ranges = token_ranges[1:-1]  # Remove start and end tokens
    tok_id = 0

    for label_id, (_start, end) in enumerate(original_ranges):
        current_label = labels[label_id]

        # Look at the *next* unconsumed token before taking it: a token belongs
        # to this word only if it ends at or before the word's end. Checking the
        # previously consumed token instead would hand the first token of the
        # following word to a word that produced no token at all.
        counter = 0
        while tok_id < len(tok_ranges) and tok_ranges[tok_id][1] <= end:
            tok_id += 1
            counter += 1

        new_token_labels = split_label(current_label, counter)
        if new_token_labels:
            new_labels.extend(new_token_labels)

    return new_labels

def recombine_to_original_labels(tok_labels, original_ranges, token_ranges, default_label="O"):
    """Collapse sub-token labels back to one label per original word.

    Each word receives the label of its first sub-token.

    Args:
        tok_labels: One label per token, without entries for the start and end
            tokens.
        original_ranges: ``(start, end)`` character span of every word, in
            sentence order.
        token_ranges: ``(start, end)`` character span of every token. The
            first and the last entry are dropped because they belong to the
            start and end tokens.
        default_label: Label given to a word that owns no token (dropped by
            the tokenizer or cut off by truncation).

    Returns:
        A list with exactly one label per entry of ``original_ranges``.
    """
    org_labels = []
    tok_ranges = token_ranges[1:-1]  # Remove start and end tokens
    # Never read past the shorter of the two token-level sequences.
    usable = min(len(tok_ranges), len(tok_labels))
    tok_id = 0

    for _start, end in original_ranges:
        first_tok = tok_id
        # A token belongs to this word if it ends at or before the word's end.
        while tok_id < usable and tok_ranges[tok_id][1] <= end:
            tok_id += 1

        if tok_id > first_tok:
            org_labels.append(tok_labels[first_tok])
        else:
            # No token for this word: fall back instead of failing on an empty list.
            org_labels.append(default_label)

    return org_labels

In [13]:
# Small hand-made example: four words, each split into two sub-tokens.
# tok_ranges additionally carries a (0,0) entry at both ends for the start and end tokens.
labels = ["B-ORG", "O", "B-ORG", "O"]
original_ranges = [       (0, 3),         (3, 5),         (6, 9),         (10, 14)]
tok_ranges =      [(0,0), (0, 1), (2, 3), (3, 4), (4, 5), (6, 7), (8, 9), (10, 11), (12, 14), (0,0)]

# Word labels -> sub-token labels
tok_labels = create_tokenized_labels(labels, original_ranges, tok_ranges)
print(tok_labels)

# Sub-token labels -> word labels; this should reproduce `labels`
reconverted_labels = recombine_to_original_labels(tok_labels, original_ranges, tok_ranges)
print(reconverted_labels)


['B-ORG', 'I-ORG', 'O', 'O', 'B-ORG', 'I-ORG', 'O', 'O']
['B-ORG', 'O', 'B-ORG', 'O']


## Dataset implementation

In [14]:
from tqdm import tqdm

In [15]:
class XTREMEDataset(torch.utils.data.Dataset):
    """Token-classification dataset read from a gzipped ``word label`` file.

    The file ``{train|dev}-{language}.tsv.gz`` inside ``path`` is expected to
    hold one ``word label`` pair per line, with sentences separated by empty
    lines. Every sentence is tokenized with the multilingual cased BERT
    tokenizer and the word labels are projected onto the sub-tokens.

    Args:
        path: Directory that contains the data files.
        train: Read the ``train`` file when true, the ``dev`` file otherwise.
        language: Language code used in the file name.
    """

    def __init__(self, path, train=True, language='en'):
        train_or_dev = 'train' if train else 'dev'
        file_name = f"{train_or_dev}-{language}.tsv.gz"
        path = os.path.join(path, file_name)

        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased", cache_dir="./cache")

        # Word-level data, one entry per sentence
        self.sentences = []
        self.sent_labels = []
        self.ranges = []

        # Sub-token-level data, index-aligned with the lists above
        self.tokens = []
        self.token_labels = []
        self.token_ranges = []

        # Convert between label strings and their numerical representation.
        # Ids 0-2 are reserved for the special tokens; real labels are numbered
        # in order of first appearance in the file.
        self.label_to_num = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2}
        self.num_to_label = {num: label for label, num in self.label_to_num.items()}

        self.import_data(path)
        self.tokenize_data()

    def tokenize_data(self, max_len=512):
        """Tokenize every imported sentence and build its sub-token labels.

        Args:
            max_len: Maximum number of tokens per sentence (special tokens
                included); longer sentences are truncated by the tokenizer.
        """
        for i, (sent, labels, ranges) in enumerate(zip(self.sentences, self.sent_labels, self.ranges)):
            # Words are joined with single spaces, matching the character
            # ranges computed in import_data.
            sent_string = " ".join(sent)
            tokens = self.tokenizer(sent_string, return_offsets_mapping=True, max_length=max_len, truncation=True)
            tok_ranges = tokens["offset_mapping"]
            tok_labels = create_tokenized_labels(labels, ranges, tok_ranges)

            # Every token except the two special ones should have a label.
            if len(tok_labels) != len(tok_ranges) - 2:
                print(f"ERROR IN LEN IN INDEX {i} ({len(tok_labels)}) : ({len(tok_ranges) - 2})")

            self.tokens.append(tokens)
            self.token_ranges.append(tok_ranges)
            self.token_labels.append(tok_labels)

    def import_data(self, path):
        """Read sentences, labels and word character ranges from ``path``.

        Args:
            path: Path of the gzipped data file.

        Raises:
            ValueError: If a non-empty line does not consist of exactly two
                whitespace-separated fields.
        """
        cur_sent = []
        cur_labels = []
        cur_range = []
        prev_idx = 0

        with gzip.open(path, 'r') as file:
            for line in file:
                fields = line.decode().split()

                # An empty line closes the current sentence.
                if not fields:
                    # Skip consecutive empty lines instead of storing empty sentences.
                    if cur_sent:
                        self.sentences.append(cur_sent)
                        self.sent_labels.append(cur_labels)
                        self.ranges.append(cur_range)
                    cur_sent = []
                    cur_labels = []
                    cur_range = []
                    prev_idx = 0
                    continue

                word, label = fields

                # Create unique numbered labels
                if label not in self.label_to_num:
                    new_num = len(self.label_to_num)
                    self.label_to_num[label] = new_num
                    self.num_to_label[new_num] = label

                # Build the word ranges for the sentence
                word_len = len(word)
                cur_range.append((prev_idx, prev_idx + word_len))
                prev_idx += word_len + 1  # add 1 for the space from later concatenation

                # Labels are stored as strings; they are converted to ids at collate time.
                cur_sent.append(word)
                cur_labels.append(label)

        # Keep the last sentence when the file does not end with an empty line.
        if cur_sent:
            self.sentences.append(cur_sent)
            self.sent_labels.append(cur_labels)
            self.ranges.append(cur_range)

    @staticmethod
    def tokens_to_input_format(tokens):
        """Placeholder for a conversion that has not been written yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("tokens_to_input_format is not implemented")

    def __getitem__(self, idx):
        """Return the tokenizer output of sentence ``idx`` plus its label data."""
        data = {**self.tokens[idx]}
        data["sentence"] = self.sentences[idx]
        data["token_labels"] = self.token_labels[idx]
        data["original_labels"] = self.sent_labels[idx]
        data["label_to_num"] = self.label_to_num
        data["num_to_label"] = self.num_to_label

        return data

    def __len__(self):
        """Return the number of imported sentences."""
        return len(self.sentences)


            
class CollateFunctor:
    """Turn a list of dataset samples into one padded batch of tensors.

    Args:
        tokenizer: Object whose ``pad_token_id`` is used to pad ``input_ids``.
        max_len: Stored on the instance; sequences are not cut here.
        device: Device the batch tensors are moved to.
        label_pad_id: Label id written to padded positions. The default 0 is
            the id of the ``[PAD]`` label.
    """

    def __init__(self, tokenizer, max_len, device, label_pad_id=0):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.device = device
        self.label_pad_id = label_pad_id

    def batch_to_device(self, batch):
        """Return a copy of ``batch`` with every tensor value moved to ``self.device``."""
        new_batch = {}
        for key, value in batch.items():
            if torch.is_tensor(value):
                new_batch[key] = value.to(self.device)
            else:
                new_batch[key] = value
        return new_batch

    def __call__(self, batch):
        """Collate ``batch`` into padded ``input_ids``, ``token_type_ids``,
        ``attention_mask`` and ``labels`` tensors.

        Each sample must provide ``input_ids``, ``token_type_ids``,
        ``attention_mask``, ``token_labels`` (label strings without the two
        special tokens) and ``label_to_num``.
        """
        input_ids = []
        token_type_ids = []
        attention_mask = []
        labels = []

        for sample in batch:
            cur_ids = sample["input_ids"]
            label_to_num = sample["label_to_num"]
            inner_labels = [label_to_num[label] for label in sample["token_labels"]]

            # Keep exactly one label per non-special token so that the label
            # row always has the same length as the input row, even when the
            # sample's token labels do not line up with its tokens.
            n_inner = max(len(cur_ids) - 2, 0)
            inner_labels = inner_labels[:n_inner]
            inner_labels += [self.label_pad_id] * (n_inner - len(inner_labels))

            # Add start and end token labels (ids of "[CLS]" and "[SEP]")
            row_labels = [1, *inner_labels, 2]

            input_ids.append(torch.tensor(cur_ids))
            token_type_ids.append(torch.tensor(sample["token_type_ids"]))
            attention_mask.append(torch.tensor(sample["attention_mask"]))
            labels.append(torch.tensor(row_labels))

        # Pad sequences to ensure uniform length; padding to the batch maximum
        # is done here only, never per sample.
        inputs = {
            "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id),
            "token_type_ids": pad_sequence(token_type_ids, batch_first=True, padding_value=0),
            "attention_mask": pad_sequence(attention_mask, batch_first=True, padding_value=0),
            "labels": pad_sequence(labels, batch_first=True, padding_value=self.label_pad_id),
        }

        return self.batch_to_device(inputs)

# Training

In [16]:
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

In [17]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, device, val_loader=None):
    """Run one pass of gradient updates over ``train_loader``.

    Args:
        model: Model that is called as ``model(**batch)`` and whose output
            exposes a ``loss`` attribute.
        train_loader: Iterable of batches (keyword dictionaries).
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Learning-rate scheduler stepped once per batch.
        device: Not used inside this function.
        val_loader: Not used inside this function.
    """
    model.train()
    batches = tqdm(train_loader, desc="Training")

    for batch in batches:
        # Start every step from cleared gradients.
        optimizer.zero_grad()

        # Forward and backward pass
        batch_loss = model(**batch).loss
        batch_loss.backward()

        # Update the weights, then advance the learning-rate schedule
        optimizer.step()
        lr_scheduler.step()

        # Show the latest loss and learning rate in the progress bar
        current_lr = optimizer.param_groups[0]['lr']
        batches.set_postfix(loss=batch_loss.item(), lr=current_lr)

@torch.no_grad()
def evaluate(model, val_loader, device):
    """Compute token-level accuracy of ``model`` over ``val_loader``.

    Only positions whose attention mask is 1 are scored, and the first and the
    last such position of every sequence (the [CLS] and [SEP] tokens) are
    excluded.

    Args:
        model: Model called as ``model(**batch)``; its output must expose
            ``logits``.
        val_loader: Iterable of batches with ``attention_mask`` and ``labels``.
        device: Not used inside this function.

    Returns:
        The fraction of scored positions whose arg-max prediction equals the
        label, or 0.0 when there is nothing to score.
    """
    model.eval()
    total_correct, total_samples = 0, 0
    for batch in tqdm(val_loader):
        outputs = model(**batch)

        # Work on a copy so the caller's attention mask is left untouched.
        eval_mask = batch['attention_mask'].clone().bool()

        # Mask the [CLS] and [SEP] token as well
        for row in eval_mask:  # each row is a view into eval_mask
            active = torch.nonzero(row).flatten()
            if active.numel() == 0:
                continue  # fully padded row: nothing to unmask
            row[active[0]] = False
            row[active[-1]] = False

        predictions = outputs.logits.argmax(dim=-1)
        active_labels = batch['labels'][eval_mask]

        total_correct += (predictions[eval_mask] == active_labels).sum().item()
        total_samples += active_labels.shape[0]

    # An empty loader or fully masked data must not end in a division by zero.
    accuracy = total_correct / total_samples if total_samples else 0.0
    print(f"total correct: {total_correct}, total samples: {total_samples}: ({accuracy})")
    return accuracy

# Hyperparameters

In [18]:
def _str_to_bool(value):
    """Convert a command-line string such as ``true``/``False``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(argv=None):
    """Build the hyperparameter namespace for training.

    Args:
        argv: Optional list of command-line style arguments. When omitted, an
            empty list is parsed, so all defaults are returned; the notebook
            kernel's own command line is never read.

    Returns:
        An ``argparse.Namespace`` with ``model``, ``batch_size``, ``epochs``,
        ``lr``, ``freeze``, ``seed``, ``warmup_steps`` and
        ``gradient_clipping``.
    """
    parser = argparse.ArgumentParser(description='Train a token classification model on the XTREME NER dataset')
    parser.add_argument('--model', type=str, default='bert-base-multilingual-cased', help='The model to use')
    parser.add_argument('--batch_size', type=int, default=8, help='The batch size')
    parser.add_argument('--epochs', type=int, default=3, help='The number of epochs to train')
    parser.add_argument('--lr', type=float, default=2e-5, help='The learning rate')
    # type=bool would turn every non-empty string, including "False", into True.
    parser.add_argument('--freeze', type=_str_to_bool, default=True, help='If to freeze the earlier BERT weights')
    parser.add_argument('--seed', type=int, default=42, help='The random seed')
    parser.add_argument('--warmup_steps', type=int, default=50, help='The number of warmup steps')
    parser.add_argument('--gradient_clipping', type=float, default=10.0, help='The gradient clipping value')
    return parser.parse_args([] if argv is None else argv)

args = parse_arguments()

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python, NumPy and PyTorch so that weight initialisation and shuffling are repeatable.
transformers.set_seed(args.seed)

# NOTE: XTREMEDataset tokenizes with the bert-base-multilingual-cased tokenizer, so
# args.model has to be a checkpoint that shares that vocabulary.
tokenizer = AutoTokenizer.from_pretrained(args.model, cache_dir="./cache")
# num_labels=10 covers the three special labels ([PAD], [CLS], [SEP]) plus the
# labels found in the data (presumably seven BIO tags — TODO confirm against the file).
model = AutoModelForTokenClassification.from_pretrained(args.model,
                                                        cache_dir="./cache",
                                                        num_labels=10).to(device)

# Freeze all layers of the pre-trained BERT model; only the classification head stays trainable
if args.freeze:
    for param in model.bert.parameters():
        param.requires_grad = False

freeze_state = "frozen" if args.freeze else "trainable"
print(f"{'='*40}\n"
      f"Training model: {args.model} for {args.epochs} epochs\n"
      f" * learning rate is {args.lr}\n"
      f" * batch size is {args.batch_size}\n"
      f" * BERT weights are {freeze_state}\n"
      f" * device is {device}")

collate = CollateFunctor(tokenizer, 512, device)

dataset = XTREMEDataset("./data")
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          drop_last=True,
                                          collate_fn=collate)

# Hand only the parameters that still require gradients to the optimizer.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad], lr=args.lr
)
lr_scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=len(data_loader) * args.epochs
)

for epoch in range(args.epochs):
    train_epoch(model, data_loader, optimizer, lr_scheduler, device)
    # The same loader is used for training and scoring, so this is accuracy on the
    # training data. TODO: score on the dev split once both splits share one label mapping.
    accuracy = evaluate(model, data_loader, device)
    print(f"Epoch {epoch + 1}: training accuracy = {accuracy:.2%}\n")

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


 * learning rate is 2e-05 * batch size is 8
 * BERT weights are frozen
 * device is cuda


Training: 100%|██████████| 2499/2499 [00:25<00:00, 97.14it/s, loss=1.16, lr=1.34e-5]  
100%|██████████| 2499/2499 [00:23<00:00, 106.21it/s]


total correct: 89367, total samples: 218859: (0.4083313914438063)
Epoch 1: validation accuracy = 40.83%



Training: 100%|██████████| 2499/2499 [00:25<00:00, 99.65it/s, loss=0.707, lr=6.71e-6] 
100%|██████████| 2499/2499 [00:23<00:00, 107.11it/s]


total correct: 115092, total samples: 218907: (0.525757513464622)
Epoch 2: validation accuracy = 52.58%



Training: 100%|██████████| 2499/2499 [00:24<00:00, 100.17it/s, loss=0.719, lr=0]      
100%|██████████| 2499/2499 [00:23<00:00, 107.14it/s]

total correct: 122153, total samples: 218888: (0.558061657103176)
Epoch 3: validation accuracy = 55.81%






In [20]:
# NOTE(review): `data_loader` is the loader built over XTREMEDataset("./data") in the
# training cell, i.e. the training split, so this number is not a held-out test score.
test_accuracy = evaluate(model, data_loader, device)
print(f"Test accuracy = {test_accuracy:.2%}")

100%|██████████| 2499/2499 [00:23<00:00, 106.56it/s]

total correct: 122124, total samples: 218844: (0.5580413445193837)
Test accuracy = 55.80%



