In [None]:
!pip install lightning
!pip install --force-reinstall tensorflow-io

In [None]:
import os
from datasets import load_dataset
import datasets
from transformers import AutoTokenizer

# Module-level tokenizer that tokenize_and_align_labels uses to pre-tokenize every
# dataset split. NOTE(review): it is fixed to "xlm-roberta-base"; keep it in sync
# with the encoder checkpoint the models load and with the data module's
# tokenizer_checkpoint. Files are cached under $CACHE_DIR when that variable is set.
TOKENIZER = AutoTokenizer.from_pretrained("xlm-roberta-base", cache_dir=os.getenv("CACHE_DIR"))

# Align label with subtokens generated through tokenizer
def align_labels_with_tokens(labels, word_ids):
    """Expand word-level NER label ids to sub-token level.

    Args:
        labels: one label id per word of the sentence.
        word_ids: for each sub-token, the index of the word it belongs to, or
            ``None`` for special tokens.

    Returns:
        A list with one entry per sub-token:
        * special tokens (``None``) get ``-100`` so the loss ignores them;
        * the first sub-token of a word keeps that word's label;
        * further sub-tokens of the same word turn an odd (``B-XXX``) id into
          the next even (``I-XXX``) id and keep even ids unchanged.
    """
    aligned = []
    previous_word = None
    for wid in word_ids:
        if wid is None:
            aligned.append(-100)
        elif wid != previous_word:
            aligned.append(labels[wid])
        else:
            tag = labels[wid]
            aligned.append(tag + 1 if tag % 2 == 1 else tag)
        previous_word = wid
    return aligned

def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and attach sub-token labels.

    Meant for ``Dataset.map(..., batched=True)``: ``examples["tokens"]`` holds
    one word list per sentence and ``examples["ner_tags"]`` the matching label
    lists. Sentences are truncated by the tokenizer. The returned encoding gains
    a ``"labels"`` entry built with ``align_labels_with_tokens``.
    """
    encoding = TOKENIZER(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    encoding["labels"] = [
        align_labels_with_tokens(word_labels, encoding.word_ids(batch_index))
        for batch_index, word_labels in enumerate(examples["ner_tags"])
    ]
    return encoding

def remap_to_wikiann_labels(examples: dict, num_labels: int = 7) -> dict:
    """Map label ids outside the kept tag range to ``0`` (``O``).

    Taken from the SLICER repository: label indices larger than 6 are remapped
    to 0. With the default ``num_labels=7`` every id ``>= 7`` (``B-MISC`` /
    ``I-MISC`` in ``label_dict``) becomes ``O``.

    Args:
        examples: a batched ``datasets`` example dict with a ``"ner_tags"``
            entry holding one list of label ids per sentence.
        num_labels: ids ``< num_labels`` are kept; all others become 0.

    Returns:
        ``examples`` itself, with ``"ner_tags"`` replaced (other keys untouched).
    """
    examples["ner_tags"] = [
        [tag if tag < num_labels else 0 for tag in instance]
        for instance in examples["ner_tags"]
    ]
    return examples

class MonolingualNERDataSet:
    """One split of a Hugging Face NER dataset, prepared for token classification.

    The split is loaded (optionally one config per language, concatenated),
    its labels are passed through ``remap_to_wikiann_labels``, it is tokenized
    with ``tokenize_and_align_labels``, and only ``input_ids``,
    ``attention_mask`` and ``labels`` are kept, formatted as torch tensors.

    Attributes:
        dataset: the raw loaded split.
        dataset_name: the dataset name passed to ``load_dataset``.
        tokenized_datasets: the prepared split.
    """

    def __init__(self,
                 name="conll2003",
                 split="train",
                 languages=None) -> None:
        super().__init__()
        cache_dir = os.getenv("CACHE_DIR")
        if languages:
            # Load one config per language and merge them into a single split.
            self.dataset = datasets.concatenate_datasets(
                [load_dataset(name, lang, cache_dir=cache_dir, split=split) for lang in languages]
            )
        else:
            self.dataset = load_dataset(name, cache_dir=cache_dir, split=split)
        self.dataset_name = name

        remapped = self.dataset.map(remap_to_wikiann_labels, batched=True)
        # Drop the raw columns in the same pass that adds the tokenizer outputs,
        # rather than running an extra identity map just to remove them.
        self.tokenized_datasets = remapped.map(
            tokenize_and_align_labels,
            batched=True,
            remove_columns=self.dataset.column_names,
        )
        self.tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

class MultilingualNERDataSet:
    """One split per language of a multi-config NER dataset (default MasakhaNER).

    Each language's split is prepared independently: labels are passed through
    ``remap_to_wikiann_labels``, tokenized with ``tokenize_and_align_labels``,
    reduced to ``input_ids`` / ``attention_mask`` / ``labels`` and formatted as
    torch tensors.

    Attributes:
        dataset_names: the language codes, in load order.
        datasets: the raw split for each language.
        tokenized_datasets: the prepared split for each language (same order).
    """

    def __init__(self,
                 name="masakhaner",
                 languages=("amh", "hau", "ibo", "kin", "lug", "luo", "pcm", "swa", "wol", "yor"),
                 split="test") -> None:
        super().__init__()
        cache_dir = os.getenv("CACHE_DIR")
        # Materialize once: a tuple default avoids the shared-mutable-default
        # pitfall, and a generator argument would otherwise be consumed here.
        self.dataset_names = list(languages)
        self.datasets = [
            load_dataset(name, lang, split=split, cache_dir=cache_dir) for lang in self.dataset_names
        ]
        self.tokenized_datasets = [self._prepare(dataset) for dataset in self.datasets]

    @staticmethod
    def _prepare(dataset):
        """Remap labels, tokenize, keep only model inputs, and return torch tensors."""
        tokenized = (
            dataset.map(remap_to_wikiann_labels, batched=True)
            # Remove this dataset's own raw columns (not those of the first one).
            .map(tokenize_and_align_labels, batched=True, remove_columns=dataset.column_names)
        )
        tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        return tokenized



In [None]:
import lightning.pytorch as pl
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification


_ENGLISH = [
    "en"
]

_MASAKHANER_LANGS = ["amh", "hau", "ibo", "kin", "lug", "luo", "pcm", "swa", "wol", "yor"]

class SlicerLightningDataModule(pl.LightningDataModule):
    """NER data: CoNLL-2003 or English WikiANN for training, MasakhaNER for testing.

    ``setup`` must first be called by hand with ``"train_conll"`` or
    ``"train_wikiann"`` so that the train/validation splits and
    ``test_datasets_names`` exist before the LightningModule and Trainer are
    built. Lightning later calls ``setup`` again with its own stage names
    (``"fit"``, ``"test"``, ...); those calls reuse what is already loaded.

    Args:
        batch_size: batch size of every dataloader.
        num_workers: stored and saved as a hyperparameter, but currently NOT
            passed to the dataloaders.
        tokenizer_checkpoint: checkpoint of the tokenizer used by the padding
            collator. NOTE(review): the datasets themselves are tokenized with
            the module-level TOKENIZER, so keep the two in sync.
    """

    def __init__(self, batch_size=16, num_workers=2, tokenizer_checkpoint="xlm-roberta-base"):
        super().__init__()
        self.model_checkpoint = tokenizer_checkpoint
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.tokenizer = None
        self.data_collator = None
        self.train_dataset = None
        self.validation_dataset = None
        self.test_datasets_names = None
        self.test_datasets = None
        self.save_hyperparameters()

    def setup(self, stage="train_conll"):
        """Load the data needed for ``stage``.

        ``"train_conll"`` / ``"train_wikiann"`` (re)load the matching train and
        validation splits; any other stage leaves them as they are. The
        tokenizer, collator and MasakhaNER test sets are built on the first call
        and reused by later calls instead of being reloaded and re-tokenized.
        """
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint)
            self.data_collator = DataCollatorForTokenClassification(tokenizer=self.tokenizer, return_tensors="pt")

        if stage == "train_conll":
            self.train_dataset, self.validation_dataset = self._load_train_and_validation("conll2003")
        elif stage == "train_wikiann":
            self.train_dataset, self.validation_dataset = self._load_train_and_validation(
                "wikiann", languages=_ENGLISH
            )

        if self.test_datasets is None:
            masakhaner = MultilingualNERDataSet("masakhaner", _MASAKHANER_LANGS, split="test")
            self.test_datasets_names = masakhaner.dataset_names
            self.test_datasets = masakhaner.tokenized_datasets

    @staticmethod
    def _load_train_and_validation(name, languages=None):
        """Return the prepared (train, validation) splits of dataset ``name``."""
        train = MonolingualNERDataSet(name=name, split="train", languages=languages)
        validation = MonolingualNERDataSet(name=name, split="validation", languages=languages)
        return train.tokenized_datasets, validation.tokenized_datasets

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.data_collator,
        )

    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=self.data_collator,
        )

    def test_dataloader(self):
        """Return one dataloader per MasakhaNER language, keyed by language code.

        SlicerLightningModule.test_step maps ``dataloader_idx`` back to a name
        via ``test_datasets_names``, so the dict keeps that same order.
        """
        return {
            name: DataLoader(
                dataset,
                batch_size=self.batch_size,
                shuffle=False,
                collate_fn=self.data_collator,
            )
            for name, dataset in zip(self.test_datasets_names, self.test_datasets)
        }
    # HINT: Evaluating on multiple dataloaders (https://lightning.ai/docs/pytorch/LTS/guides/data.html)




In [None]:
import torch
from torch import nn
from transformers import AutoModel


class SlicerRobertaNER(nn.Module):
    """XLM-R encoder with a sliced (SLICER) token-classification head.

    Each token's hidden vector of width ``hidden_size`` is cut into
    ``num_slices`` contiguous slices of ``slice_size`` features. Every slice is
    classified on its own with a ``(slice_size, n_labels)`` weight block, and
    all slices share one bias vector. ``forward`` therefore returns one row of
    logits per (token, slice) pair.

    Raises:
        ValueError: if ``num_slices * slice_size != hidden_size``.
    """

    def __init__(self,
                 n_labels = 9,
                 slice_size = 4,
                 hidden_size = 768,
                 num_slices = 192):
        super().__init__()
        if num_slices * slice_size != hidden_size:
            raise ValueError(
                f"num_slices * slice_size must equal hidden_size, got "
                f"{num_slices} * {slice_size} != {hidden_size}"
            )
        self.n_labels = n_labels
        self.hidden_size = hidden_size
        self.slice_size = slice_size
        self.num_slices = num_slices

        self.base = AutoModel.from_pretrained("xlm-roberta-base")

        # Start from a standard nn.Linear initialisation and reshape its
        # (n_labels, hidden_size) weight into one (slice_size, n_labels) block per slice.
        linear = nn.Linear(self.hidden_size, self.n_labels)
        classifier_weights = linear.weight.data.T.reshape((self.num_slices, self.slice_size, self.n_labels))
        self.classifier_weights = nn.Parameter(classifier_weights, requires_grad=True)
        self.classifier_bias = nn.Parameter(linear.bias, requires_grad=True)

    def forward(self, x):
        """Return per-slice logits of shape ``(batch * seq_len * num_slices, n_labels)``.

        Args:
            x: keyword arguments for the encoder (e.g. ``input_ids`` and
               ``attention_mask``).

        Row ``t * num_slices + s`` holds slice ``s``'s logits for the ``t``-th
        token in flattened (batch, sequence) order; SlicerLightningModule
        repeats every label ``num_slices`` times to match this layout.
        """
        hidden_states = self.base(**x)[0]
        batch_size, sequence_length, _ = hidden_states.shape

        sliced = hidden_states.reshape((batch_size * sequence_length, self.num_slices, self.slice_size))

        # For every token n and slice d: out[n, d, l] = sum_k sliced[n, d, k] * W[d, k, l].
        # The weight Parameter moves with the module, so no manual .to(device) is
        # done here (assigning a moved plain tensor back to a registered
        # Parameter attribute would raise a TypeError).
        per_slice = torch.einsum("ndk,dkl->ndl", sliced, self.classifier_weights)
        return per_slice.reshape((-1, self.n_labels)) + self.classifier_bias



class RobertaNERAdvanced(nn.Module):
    """XLM-R encoder, a dense projection to ``hidden_size``, then a linear classifier.

    NOTE(review): there is no non-linearity (or dropout) between ``dense`` and
    ``classification_head``, so the two layers compose to a single linear map.
    """

    def __init__(self,
                 n_labels=9,
                 hidden_size=384):
        super().__init__()
        self.n_labels = n_labels
        self.hidden_size = hidden_size

        # Registration order (base, dense, classification_head) is kept so
        # state_dict keys and parameter order stay the same.
        self.base = AutoModel.from_pretrained("xlm-roberta-base")
        # 768 input features: expected to match the encoder's output width.
        self.dense = nn.Linear(768, hidden_size)
        self.classification_head = nn.Linear(hidden_size, n_labels)

    def forward(self, x):
        """Return token logits for encoder kwargs ``x``.

        The output is presumably (batch, seq_len, n_labels); it is the encoder's
        first output passed through ``dense`` and then ``classification_head``.
        """
        encoder_states = self.base(**x)[0]
        projected = self.dense(encoder_states)
        return self.classification_head(projected)


class RobertaNER(nn.Module):
    """XLM-R encoder followed by a single linear token-classification layer."""

    def __init__(self,
                 n_labels=9):
        super().__init__()
        self.n_labels = n_labels

        # base is registered before classification_head, as before, so
        # state_dict keys and parameter order are unchanged.
        self.base = AutoModel.from_pretrained("xlm-roberta-base")
        # 768 input features: expected to match the encoder's output width.
        self.classification_head = nn.Linear(768, n_labels)

    def forward(self, x):
        """Apply the linear head to the encoder's first output for kwargs ``x``."""
        return self.classification_head(self.base(**x)[0])

In [None]:
import lightning.pytorch as pl
import torch
import torch.nn as nn
import torchmetrics.functional
from torch import argmax
from torch.nn import CrossEntropyLoss
from torch.nn.functional import cross_entropy
from transformers import AutoModel
from torchmetrics.functional.classification import f1_score, accuracy, precision, recall
from torch.nn.functional import cross_entropy


# Label id -> BIO tag name, used to build the per-class metric keys. Only ids
# 0-6 survive remap_to_wikiann_labels; 7 and 8 (MISC) are mapped to 0.
label_dict = dict(enumerate(
    ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
))


def lookup_table(label):
    """Return the tag name of label id ``label`` (raises KeyError for unknown ids)."""
    tag_name = label_dict[label]
    return tag_name

class SlicerLightningModule(pl.LightningModule):
    """Trains/evaluates SlicerRobertaNER (``is_slicer=True``) or RobertaNER.

    Args:
        test_datasets_names: names of the test dataloaders, in the order the
            data module returns them; used to prefix test metric keys.
        n_labels: number of label ids (7 once remap_to_wikiann_labels is applied).
        hidden_size: encoder width; split into ``hidden_size // slice_size`` slices.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        is_slicer: use the sliced classification head instead of a single linear head.
        slice_size: number of hidden features per slice.
    """

    def __init__(self,
                 test_datasets_names,
                 n_labels = 7,
                 hidden_size = 768,
                 learning_rate = 2e-5,
                 weight_decay = 0.05,
                 is_slicer = True,
                 slice_size = 4
                 ):
        super().__init__()
        self.n_labels = n_labels
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.is_slicer = is_slicer
        self.slice_size = slice_size
        self.num_slices = self.hidden_size // self.slice_size
        self.test_datasets_names = test_datasets_names
        # Save HP to checkpoints
        self.save_hyperparameters()

        if self.is_slicer:
            self.model = SlicerRobertaNER(n_labels=self.n_labels, slice_size=slice_size,
                                          hidden_size=self.hidden_size, num_slices=self.num_slices)
        else:
            self.model = RobertaNER(n_labels=self.n_labels)

    def __default_step(self, batch, batch_idx):
        """Shared forward pass and loss; returns flattened ``(logits, labels, loss)``."""
        inputs = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}

        # Merge batch and sequence (and, for SLICER, slice) dims into one row axis.
        logits = self.model(inputs).reshape((-1, self.n_labels))
        labels = batch['labels'].to(logits.device).reshape((-1,))

        if self.is_slicer:
            # The sliced head emits num_slices consecutive rows per token, so each
            # token label is repeated num_slices times to line up with them.
            labels = labels.repeat_interleave(self.num_slices)

        loss = cross_entropy(logits, labels, ignore_index=-100)
        return logits, labels, loss

    def training_step(self, batch, batch_idx):
        _, _, loss = self.__default_step(batch, batch_idx)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        logits, labels, loss = self.__default_step(batch, batch_idx)
        micro_f1, f1_per_class, prec, rec, acc = self.computeMetrics(logits, labels)

        self.log("val_f1", micro_f1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.logMetrics(loss, f1_per_class, rec, prec, acc, "val")

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx defaults to 0 because Lightning only passes it when
        # there is more than one test dataloader.
        logits, labels, loss = self.__default_step(batch, batch_idx)
        micro_f1, f1_per_class, prec, rec, acc = self.computeMetrics(logits, labels)

        dataset_name = self.test_datasets_names[dataloader_idx]
        self.log(f"{dataset_name}_test_f1", micro_f1, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.logMetrics(loss, f1_per_class, rec, prec, acc, f"{dataset_name}_test")

    def computeMetrics(self, logits, labels):
        """Token-level metrics, ignoring positions labelled -100.

        Returns ``(micro_f1, f1, precision, recall, accuracy)``; the last four
        are per-class vectors of length ``n_labels``.
        NOTE(review): these are token-level scores, not entity-level (seqeval) F1,
        and for single-label multiclass data micro F1 equals token accuracy.
        """
        # Argmax once instead of letting each metric call re-derive the predictions.
        preds = logits.argmax(dim=-1)
        shared = dict(task='multiclass', num_classes=self.n_labels, ignore_index=-100,
                      multidim_average='global')
        micro_f1 = f1_score(preds=preds, target=labels, average='micro', **shared)
        f1 = f1_score(preds=preds, target=labels, average='none', **shared)
        p = precision(preds=preds, target=labels, average='none', **shared)
        r = recall(preds=preds, target=labels, average='none', **shared)
        a = accuracy(preds=preds, target=labels, average='none', **shared)
        return micro_f1, f1, p, r, a

    def logMetrics(self, loss, f1, recall, precision, accuracy, stage : str):
        """Log ``{stage}_loss`` plus recall/precision/F1/accuracy for every class.

        NOTE(review): logged with on_epoch=True, so the epoch value is presumably
        the mean of per-batch scores rather than a corpus-level score.
        """
        metrics = {f"{stage}_loss": loss}
        for i in range(self.n_labels):
            tag = lookup_table(i)
            metrics[f"{tag}_{stage}_recall"] = recall[i]
            metrics[f"{tag}_{stage}_precision"] = precision[i]
            metrics[f"{tag}_{stage}_f1"] = f1[i]
            metrics[f"{tag}_{stage}_accuracy"] = accuracy[i]
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(),
                                 lr=self.learning_rate,
                                 weight_decay=self.weight_decay)



In [None]:

from argparse import ArgumentParser, Namespace

import wandb
from lightning import seed_everything
import torch
from lightning.pytorch.loggers import WandbLogger
import lightning.pytorch as pl

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'


def create_experiment_name(args):
    """Build the W&B run name, e.g. ``NER_Slicer-Datasetwikiann-BatchSize4-Slices2``.

    Args:
        args: namespace with ``slicer``, ``dataset``, ``batch_size`` and ``slice_size``.

    Note: the ``-Slices`` field holds ``args.slice_size`` (features per slice),
    not the number of slices; the label is kept so existing run names stay stable.
    """
    base_name = 'Slicer' if args.slicer else 'Vanilla'
    return (
        f"NER_{base_name}"
        f"-Dataset{args.dataset}"
        f"-BatchSize{args.batch_size}"
        f"-Slices{args.slice_size}"
    )

def main(hparams):
    """Train on ``hparams.dataset``, then test on MasakhaNER, logging to W&B.

    Reads these ``hparams`` attributes: random_seed, dataset ("conll" or
    "wikiann"), offline, batch_size, num_workers, slicer, slice_size,
    accelerator, fast_dev_run, deterministic.
    """
    # Random seeding for reproducibility
    seed_everything(seed=hparams.random_seed, workers=True)

    # float precision
    torch.set_float32_matmul_precision('high')

    experiment_name = create_experiment_name(hparams)
    wandb_logger = WandbLogger(
        project="MLNLP_Slicer",
        name=experiment_name,
        checkpoint_name=experiment_name,
        offline=hparams.offline,
        log_model=True,
    )

    datamodule = module = trainer = None
    try:
        datamodule = SlicerLightningDataModule(
            batch_size=hparams.batch_size,
            num_workers=hparams.num_workers)

        # Must run before the module is built: it fills test_datasets_names,
        # which the LightningModule needs to name its test metrics.
        datamodule.setup(stage="train_" + hparams.dataset)

        module = SlicerLightningModule(
            test_datasets_names=datamodule.test_datasets_names,
            n_labels=7,  # label ids 0-6 kept by remap_to_wikiann_labels
            learning_rate=2e-5,
            weight_decay=0.05,
            is_slicer=hparams.slicer,
            slice_size=hparams.slice_size,
        )

        # Trainer (https://lightning.ai/docs/pytorch/stable/common/trainer.html)
        trainer = pl.Trainer(accelerator=hparams.accelerator,
                             max_epochs=10 if hparams.dataset == "conll" else 5,
                             accumulate_grad_batches=8,
                             logger=wandb_logger,
                             enable_progress_bar=True,
                             fast_dev_run=hparams.fast_dev_run,
                             deterministic=hparams.deterministic,
                             )

        # Fit the model (and evaluate on validation data as defined)
        trainer.fit(module, datamodule=datamodule)

        if not hparams.fast_dev_run:
            trainer.test(datamodule=datamodule, ckpt_path='last')
    finally:
        # Runs even if training fails: drop references so GPU memory can be
        # released, and close the active W&B run so a later main() call in the
        # same kernel does not find a still-open run.
        del datamodule, module, trainer
        torch.cuda.empty_cache()
        wandb.finish()




In [None]:

if __name__ == "__main__":
    # Credentials come from the WANDB_API_KEY environment variable (or ~/.netrc,
    # or an interactive prompt); never hard-code API keys in the notebook.
    wandb.login()

    search_space = {
        "dataset": ["wikiann"],
        "slice_size": [1, 2, 8],
    }

    for dataset in search_space["dataset"]:
        # `slice_size` rather than `slice`, so the builtin is not shadowed in
        # the notebook's global namespace after the loop.
        for slice_size in search_space["slice_size"]:
            hparams = Namespace(
                accelerator="gpu",
                batch_size=4,
                dataset=dataset,
                deterministic=True,
                fast_dev_run=False,
                slice_size=slice_size,
                offline=False,
                random_seed=42,
                num_workers=4,
                slicer=True,
            )
            main(hparams)