#### Author: Sazan Mahbub (UID: 118214443)

### Reference: 
#### Provided sample code:
1. https://github.com/jwkirchenbauer/CMSC828A-Spring2023
2. https://github.com/jwkirchenbauer/CMSC828A-Spring2023/blob/main/hw1/hw1_starter_code.ipynb
3. https://github.com/jwkirchenbauer/CMSC828A-Spring2023/blob/57b00bd009b7452da1dde17bd7134bf1997d1aac/hw1/task_sampler.py

#### Hugging Face official documentation and examples:
1. https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/token_classification.ipynb#scrollTo=tvcpN89gwN8_
2. https://huggingface.co/docs/transformers/v4.26.1/en/model_doc/bert#transformers.BertForTokenClassification.forward
3. https://huggingface.co/docs/transformers/training#training-loop


In [None]:
# ! pip install datasets transformers seqeval evaluate

In [None]:
import os
# Cache Hugging Face downloads on scratch storage, allow tokenizer parallelism,
# and switch off Weights & Biases logging for the Trainer runs below.
os.environ.update({
    'HF_HOME': "/vulcanscratch/smahbub/.cache/huggingface",
    "TOKENIZERS_PARALLELISM": "true",
    "WANDB_DISABLED": "true",
})

In [None]:
import numpy as np
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding, AutoTokenizer

from datasets import load_dataset, DatasetDict


class TaskSampler():
    """
    Sample ``(task_name, batch)`` pairs from several dataloaders with weighted task sampling.

    At each step a task name is drawn with ``np.random.choice`` according to the current
    task weights, and the next batch is taken from that task's dataloader. A dataloader
    that runs out is restarted automatically, so every task can be sampled indefinitely.
    Dynamic task weights can be computed externally and applied with ``set_task_weights``,
    or this class can be extended with extra methods/state for a more complex scheme.
    The implementation is not distributed-aware; multi-GPU use would need extra work.

    Args:
        dataloader_dict (dict[str, DataLoader]): Dataloaders (any re-iterable objects)
            to sample from, keyed by task name.
        task_weights (list[float], optional): One non-negative weight per task, in the key
            order of ``dataloader_dict``, summing to 1. If None, uniform weights are used.
        max_iters (int, optional): Number of batches yielded per pass over the sampler.
            If None, iteration never stops.

    Raises:
        ValueError: If ``dataloader_dict`` is empty, or ``task_weights`` has the wrong
            length, contains negative values, or does not sum to 1 (within tolerance).
    """

    # Absolute tolerance for "weights sum to 1"; no looser than what np.random.choice
    # accepts for its ``p`` argument, so validated weights are always usable there.
    _WEIGHT_SUM_ATOL = 1e-8

    def __init__(self,
                 *,
                 dataloader_dict: dict[str, DataLoader],
                 task_weights=None,
                 max_iters=None):

        assert dataloader_dict is not None, "Dataloader dictionary must be provided."
        if not dataloader_dict:
            # Uniform weights would otherwise divide by zero.
            raise ValueError("At least one dataloader must be provided.")

        self.dataloader_dict = dataloader_dict
        self.task_names = list(dataloader_dict.keys())
        self.dataloader_iterators = self._initialize_iterators()
        if task_weights is None:
            task_weights = self._get_uniform_weights()
        else:
            self._validate_task_weights(task_weights)
        self.task_weights = task_weights
        self.max_iters = max_iters if max_iters is not None else float("inf")
        # Initialised here so next() also works before iter() has been called.
        self.current_iter = 0

    # Initialization methods
    def _get_uniform_weights(self):
        return [1 / len(self.task_names) for _ in self.task_names]

    def _initialize_iterators(self):
        return {name: iter(dataloader) for name, dataloader in self.dataloader_dict.items()}

    def _validate_task_weights(self, task_weights):
        """Raise ValueError unless ``task_weights`` is a valid probability vector over the tasks."""
        if len(task_weights) != len(self.task_names):
            raise ValueError(
                f"Expected {len(self.task_names)} task weights, got {len(task_weights)}.")
        if any(weight < 0 for weight in task_weights):
            raise ValueError("Task weights must be non-negative.")
        # Exact float equality is too strict: e.g. sum([0.1] * 10) == 0.9999999999999999.
        if not np.isclose(sum(task_weights), 1.0, rtol=0.0, atol=self._WEIGHT_SUM_ATOL):
            raise ValueError("Task weights must sum to 1.")

    # Weight getter and setter methods (NOTE can use these to dynamically set weights)
    def set_task_weights(self, task_weights):
        """Replace the sampling weights, validating the *new* weights first."""
        self._validate_task_weights(task_weights)
        self.task_weights = task_weights

    def get_task_weights(self):
        return self.task_weights

    # Sampling logic
    def _sample_task(self):
        return np.random.choice(self.task_names, p=self.task_weights)

    def _sample_batch(self, task):
        """Return the next batch for ``task``, restarting its dataloader when it is exhausted.

        Raises:
            KeyError: If ``task`` is not one of the sampler's tasks.
            RuntimeError: If the task's dataloader yields no batches at all.
        """
        try:
            iterator = self.dataloader_iterators[task]
        except KeyError as e:
            raise KeyError("Task not in dataset dictionary.") from e
        try:
            return next(iterator)
        except StopIteration:
            print(f"Restarting iterator for {task}")
        iterator = iter(self.dataloader_dict[task])
        self.dataloader_iterators[task] = iterator
        try:
            return next(iterator)
        except StopIteration:
            # Letting StopIteration escape from __next__ would silently end the caller's loop.
            raise RuntimeError(f"Dataloader for task {task!r} yielded no batches.") from None

    # Iterable interface
    def __iter__(self):
        # Each new pass (e.g. each epoch) yields up to max_iters batches again.
        self.current_iter = 0
        return self

    def __next__(self):
        if self.current_iter >= self.max_iters:
            raise StopIteration
        self.current_iter += 1
        task = self._sample_task()
        batch = self._sample_batch(task)
        return task, batch


In [None]:
import transformers

# Record the installed transformers version for reproducibility of the runs below.
print(transformers.__version__)

In [None]:
# task = "ner" # Should be one of "ner", "pos" or "chunk"
# model_checkpoint = "distilbert-base-uncased" 
# Pretrained checkpoint for the tokenizer and both task models. The multitask model
# below accesses the backbone as `.bert`, so this must be a BERT-architecture checkpoint.
model_checkpoint = "bert-base-cased" 

In [None]:
from datasets import load_dataset, load_metric

In [None]:
# datasets = load_dataset("conll2003")
# English WikiNEuRal splits, keyed by the split names used in this notebook.
# For quick debugging, append `.shuffle(seed=42).select(range(100))` to a split.
hub_split_for = {'train': 'train_en', 'val': 'val_en', 'test': 'test_en'}
datasets = {}
for split_name, hub_split in hub_split_for.items():
    datasets[split_name] = load_dataset("Babelscape/wikineural", split=hub_split)

In [None]:
# Display the split-name -> Dataset mapping.
datasets

In [None]:
# Display the first raw training example.
datasets['train'][0]

In [None]:
# Not the last line of the cell, so this feature lookup is not displayed.
datasets["train"].features["ner_tags"]

# Integer tag id -> tag name, in the order hard-coded for this dataset.
tag_set = dict(enumerate(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']))


def _add_named_tags(example):
    """Add a human-readable 'ner_tags_named' column next to the integer 'ner_tags'."""
    return {'ner_tags_named': [tag_set[tag_id] for tag_id in example['ner_tags']]}


for split_name, split_dataset in datasets.items():
    datasets[split_name] = split_dataset.map(_add_named_tags)
    print(split_name, ':', datasets[split_name])

In [None]:
# label_list = datasets["train_en"].features[f"ner_tags"].feature.names
# Tag names indexed by integer tag id (same order as the id -> name mapping above).
# Used for num_labels of the NER model and to decode predictions in compute_metrics_ner.
label_list = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
label_list

In [None]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10, *, shuffle=False, seed=None):
    """Render ``num_examples`` rows of ``dataset`` as an HTML table in the notebook.

    By default the first ``num_examples`` rows are shown (deterministic). Pass
    ``shuffle=True`` to show distinct randomly chosen rows instead.

    Args:
        dataset: A Hugging Face ``Dataset`` (supports ``len``, list indexing and ``features``).
        num_examples: Number of rows to display.
        shuffle: If True, sample distinct random row indices instead of the first rows.
        seed: Optional seed for the random pick; only used when ``shuffle`` is True.

    Raises:
        ValueError: If ``num_examples`` exceeds the number of rows in ``dataset``.
    """
    if num_examples > len(dataset):
        raise ValueError("Can't pick more elements than there are in the dataset.")
    if shuffle:
        # Sampling without replacement guarantees distinct rows.
        picks = random.Random(seed).sample(range(len(dataset)), num_examples)
    else:
        picks = list(range(num_examples))

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        # Decode integer class ids back to their string names for readability.
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [None]:
# Display the first 10 training rows as a table.
show_random_elements(datasets["train"])

In [None]:
from transformers import AutoTokenizer
    
# Tokenizer matching the chosen checkpoint; its word_ids() output drives label alignment below.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
import transformers
# word_ids(), used to align NER labels with sub-word tokens, needs a fast tokenizer.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [None]:
# Demo: tokenize a plain sentence into PyTorch tensors.
tokenizer("Hello, this is one sentence!", return_tensors='pt')

In [None]:
# Demo: tokenize an input that is already split into words (the dataset's format).
tokenizer(["Hello", ",", "this", "is", "one", "sentence", "split", "into", "words", "."], is_split_into_words=True, return_tensors='pt')

In [None]:
# Pick one training example (index 4) to walk through tokenization and label alignment.
example = datasets["train"][4]
print(example["tokens"])

In [None]:
# Tokenize the example's words and show the resulting sub-word tokens next to the words.
tokenized_input = tokenizer(example["tokens"], is_split_into_words=True, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"][0])
print(example["tokens"])
print(tokens)

In [None]:
# print(task)
# Number of word-level tags vs. the (1, num_tokens) shape of the tokenized ids; they differ,
# which is why labels must be re-aligned to tokens.
len(example[f"ner_tags_named"]), tokenized_input["input_ids"].shape#, example[f"ner_tags_named"]

In [None]:
# Source word index for each token; None marks special tokens (e.g. [CLS]/[SEP]).
print(tokenized_input.word_ids())

In [None]:
# Naive alignment: every token inherits its word's tag; special tokens get -100 (ignored by the loss).
word_ids = tokenized_input.word_ids()
aligned_labels = [-100 if i is None else example[f"ner_tags"][i] for i in word_ids]
print(len(aligned_labels), tokenized_input["input_ids"].shape)
# aligned_labels

In [None]:
# If True, every sub-token of a word gets the word's tag; if False, only the first sub-token
# does and the remaining sub-tokens get -100 (see tokenize_and_align_labels).
label_all_tokens = True

In [None]:
import torch

def _align_labels_with_word_ids(word_ids, word_labels, label_all_tokens):
    """Map one example's word-level tags onto its tokens.

    Special tokens (word id None) get -100 so the loss ignores them. The first sub-token
    of each word gets that word's tag. Later sub-tokens of the same word also get the tag
    when ``label_all_tokens`` is true, and -100 otherwise.
    """
    label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(word_labels[word_idx])
        else:
            label_ids.append(word_labels[word_idx] if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids


def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split examples and align their NER tags to tokens.

    Intended for ``Dataset.map(..., batched=True)``. Reads the notebook globals
    ``tokenizer`` and ``label_all_tokens``. Sequences are truncated/padded to 512 tokens.

    Args:
        examples: Batch dict with at least "tokens" (list of word lists) and
            "ner_tags" (list of per-word integer tag lists).

    Returns:
        The tokenizer's encoding with an added integer "labels" tensor (-100 where the
        loss should ignore a position), plus placeholder columns described below.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True,
                                 padding='max_length', max_length=512, return_tensors='pt')

    labels = [
        _align_labels_with_word_ids(tokenized_inputs.word_ids(batch_index=i), word_labels, label_all_tokens)
        for i, word_labels in enumerate(examples["ner_tags"])
    ]
    # Build the integer tensor directly rather than via a float32 torch.Tensor and .long().
    tokenized_inputs["labels"] = torch.tensor(labels, dtype=torch.long)

    # Placeholder data to avoid collate issues: overwrite the raw columns that map() would
    # otherwise keep (variable-length word/tag lists) with fixed-size integer tensors so the
    # default DataLoader collate can stack every column. The training loop only reads
    # input_ids, attention_mask and labels.
    tokenized_inputs["tokens"] = tokenized_inputs['input_ids']
    tokenized_inputs["lang"] = tokenized_inputs['input_ids']
    tokenized_inputs["ner_tags"] = tokenized_inputs['input_ids']
    tokenized_inputs["ner_tags_named"] = tokenized_inputs['input_ids']
    return tokenized_inputs

In [None]:
# ### show some examples
# temp_data = tokenize_and_align_labels(datasets['train'][:5])
# for key in temp_data:
#     print(key, ':\n', temp_data[key].shape)

# for i in range(len(temp_data['input_ids'])):
#     print(i, ':', tokenizer.convert_ids_to_tokens(temp_data['input_ids'][i]))

In [None]:
# Tokenize every split, align NER tags to tokens, and return torch tensors on indexing.
ner_dataset_raw = datasets
ner_dataset = {}
for split_name, raw_split in ner_dataset_raw.items():
    tokenized_split = raw_split.map(tokenize_and_align_labels, batched=True)
    tokenized_split.set_format("torch")
    ner_dataset[split_name] = tokenized_split

In [None]:
# for key in ner_dataset['train'][:1]:
#     print(key, ':', ner_dataset['train'][:1][key])

# ner_dataset['train'][:1]

In [None]:
# show_random_elements(ner_dataset['train'])#datasets['train'].map(tokenize_and_align_labels, batched=True))

In [None]:
#### Dataset finalize

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset, load_metric
import torch
import numpy as np 
import evaluate 
# Load the MultiNLI dataset
dataset_nli = load_dataset("multi_nli")
# Use the full training split and the matched validation split (no genre filtering here;
# the commented-out lines below show a genre-filtered, subsampled debug variant).
nli_dataset = {}
nli_dataset['train'] = dataset_nli["train"]
nli_dataset['val'] = dataset_nli["validation_matched"]#.shuffle(seed=42).select(range(100)) ## subsampling for debug only
# nli_dataset['train'] = dataset_nli["train"].filter(lambda example: example["genre"] == "travel").shuffle(seed=42).select(range(1000))
# nli_dataset['val'] = dataset_nli["train"].filter(lambda example: example["genre"] == "telephone").shuffle(seed=42).select(range(100))

# Accuracy for NLI classification; seqeval (entity-level scores) for NER.
metric_nli = evaluate.load("accuracy")
metric_ner = evaluate.load("seqeval")
# metric_ner = load_metric("seqeval") #old code. future warning

def tokenize_dataset(dataset):
    """Tokenize the premise/hypothesis pairs of an NLI split, truncated/padded to 512 tokens."""
    def encode_pairs(batch):
        return tokenizer(batch["premise"], batch["hypothesis"], truncation=True, return_tensors='pt',
                         padding='max_length', max_length=512)

    return dataset.map(encode_pairs, batched=True)

# Tokenize both NLI splits and make indexing return torch tensors.
for key in nli_dataset:
    nli_dataset[key] = tokenize_dataset(nli_dataset[key])
    nli_dataset[key].set_format("torch")

def compute_metrics_nli(eval_pred):
    """Trainer metric hook: accuracy of the argmax class over the NLI logits."""
    logits, references = eval_pred
    return metric_nli.compute(predictions=np.argmax(logits, axis=-1), references=references)

def compute_metrics_ner(eval_pred):
    """Trainer metric hook: seqeval precision/recall/F1/accuracy for NER.

    Positions whose reference label is -100 (special tokens, padding, or ignored sub-tokens)
    are dropped before scoring. Tag ids are decoded with the notebook-level ``label_list``.
    """
    logits, labels = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)

    true_predictions, true_labels = [], []
    for pred_row, label_row in zip(predicted_ids, labels):
        kept_pairs = [(p, l) for p, l in zip(pred_row, label_row) if l != -100]
        true_predictions.append([label_list[p] for p, _ in kept_pairs])
        true_labels.append([label_list[l] for _, l in kept_pairs])

    results = metric_ner.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [None]:
# show_random_elements(nli_dataset['train'])

In [None]:
# Load the BERT models for NLI and NER

from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification

# Standalone single-task models. Below they are only inspected (backbone modules and
# forward() signatures); the multitask model loads its own copies.
model_nli = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=3)
model_ner = AutoModelForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))

# Collator passed to the NER Trainer.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
# model_ner.bert.embeddings, model_ner.bert.encoder

In [None]:
# labels = [label_list[i] for i in example[f"ner_tags"]]
# metric_ner.compute(predictions=[labels], references=[labels])

# Inspect the backbone pieces (embeddings, encoder) that the multitask model shares between tasks.
model_ner.bert.embeddings, model_ner.bert.encoder

In [None]:
# predictions, labels, _ = trainer_ner.predict(ner_dataset["val"])
# predictions = np.argmax(predictions, axis=2)

# # Remove ignored index (special tokens)
# true_predictions = [
#     [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
#     for prediction, label in zip(predictions, labels)
# ]
# true_labels = [
#     [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
#     for prediction, label in zip(predictions, labels)
# ]

# results = metric_ner.compute(predictions=true_predictions, references=true_labels)
# results

In [None]:
# # train_dataset_nli[0].keys()
# example = ner_dataset['train'][0]
# type(example['input_ids'])
# # example["premise"] + ' ' + example["hypothesis"]

In [None]:
# Number of NER tag classes; should match num_labels of the multitask model's NER head.
len(label_list)

In [None]:
from torch import nn
class custom_nli_ner_model(nn.Module):
    """NLI (sequence classification) and NER (token classification) heads sharing one BERT backbone.

    Both sub-models are loaded from ``checkpoint``. The NLI model's embeddings and encoder
    are then replaced by the NER model's, so both tasks train the same backbone while
    keeping their own task-specific layers.

    Args:
        checkpoint: Model id/path to load; defaults to the notebook-level ``model_checkpoint``.
            The backbone must be exposed as ``.bert``.
        num_nli_labels: Number of NLI classes.
        num_ner_labels: Number of NER tag classes.
    """

    def __init__(self, checkpoint=None, num_nli_labels=3, num_ner_labels=9):
        super().__init__()
        if checkpoint is None:
            checkpoint = model_checkpoint
        self.model_nli = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=num_nli_labels)
        self.model_ner = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=num_ner_labels)
        # Share the backbone: the NLI model reuses the NER model's embedding and encoder modules.
        self.model_nli.bert.embeddings = self.model_ner.bert.embeddings
        self.model_nli.bert.encoder = self.model_ner.bert.encoder
        print('self.model_ner.device:', self.model_ner.device)
        # Sharing means the very same module objects, so check identity explicitly.
        assert self.model_ner.bert.embeddings is self.model_nli.bert.embeddings, 'The models should have the same embeddings'
        assert self.model_ner.bert.encoder is self.model_nli.bert.encoder, 'The models should have the same encoder'

    def forward(self, task, **kwargs):
        """Run the sub-model for ``task``, forwarding all keyword arguments unchanged.

        Args:
            task: 'nli' or 'ner'.
            **kwargs: Model inputs such as input_ids, attention_mask and labels.

        Returns:
            The selected sub-model's output (includes ``loss`` when labels are passed).

        Raises:
            ValueError: If ``task`` is neither 'nli' nor 'ner'.
        """
        if task == 'nli':
            return self.model_nli(**kwargs)
        if task == 'ner':
            return self.model_ner(**kwargs)
        # A bare `raise` here would fail with "No active exception to reraise".
        raise ValueError(f"Unknown task {task!r}; expected 'nli' or 'ner'.")

In [None]:
# model_nli(**ner_dataset["train"][0])

# Instantiate the shared-backbone multitask model and move it to the GPU.
multitask_model = custom_nli_ner_model()
multitask_model = multitask_model.cuda()

In [None]:
# Smoke test: one forward pass of the multitask model on the first 3 training examples of a task.
task = 'nli'
# The label column differs per task: NER preprocessing writes 'labels', MultiNLI uses 'label'.
label_keys = {'ner': 'labels', 'nli': 'label'}
task_datasets = {'ner': ner_dataset, 'nli': nli_dataset}
sample = task_datasets[task]["train"][:3]
print(sample.keys())
print()
# No backward pass here, so disable autograd. Otherwise the global `out` would keep this
# pass's whole activation graph alive on the GPU for the rest of the session.
with torch.no_grad():
    out = multitask_model(task=task,
                          input_ids=sample['input_ids'].cuda(),
                          attention_mask=sample['attention_mask'].cuda(),
                          labels=sample[label_keys[task]].cuda())

for key in out:
    print(key, ":", out[key], out[key].shape)

print(out.keys())

In [None]:
from torch.utils.data import DataLoader

batch_size = 8
worker_num = 0

# Training loaders shuffle; validation loaders keep dataset order and use a 4x larger batch.
nli_dataloaders = {
    'train': DataLoader(nli_dataset["train"], shuffle=True, batch_size=batch_size, num_workers=worker_num),
    'val': DataLoader(nli_dataset["val"], batch_size=batch_size * 4, num_workers=worker_num),
}

ner_dataloaders = {
    'train': DataLoader(ner_dataset["train"], shuffle=True, batch_size=batch_size, num_workers=worker_num),
    'val': DataLoader(ner_dataset["val"], batch_size=batch_size * 4, num_workers=worker_num),
}

In [None]:


# Peek at the first six batches of each training loader to sanity-check tensor shapes.
for loader_position, loader in enumerate((nli_dataloaders['train'], ner_dataloaders['train'])):
    if loader_position > 0:
        print()
    # zip stops as soon as range(6) is exhausted, so exactly six batches are fetched.
    for i, data in zip(range(6), loader):
        print(i, data['input_ids'].shape)

In [None]:
# Inspect the NLI validation labels.
nli_dataset['val']['label']

In [None]:
## checking the parameters of the forward function
import inspect

# List the arguments each model's forward() accepts (e.g. to confirm the 'labels' keyword).
print(inspect.getfullargspec(model_nli.forward))
print(inspect.getfullargspec(model_ner.forward))

In [None]:
# NOTE(review): this loads 'nli_ner_fixed_weights_.pth' (trailing underscore), whereas the
# training loop below saves to, and later cells reload, 'nli_ner_fixed_weights.pth'.
# Confirm this is the intended starting checkpoint and not a typo.
multitask_model.load_state_dict(torch.load('nli_ner_fixed_weights_.pth'))
multitask_model.cuda()

### NER trainer
# The Trainers wrap the multitask model's sub-models, which share one backbone, so
# evaluating either one reads the shared encoder trained by the multitask loop.

training_args_ner = TrainingArguments(
    f"ner-finetuned",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size*4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
)

trainer_ner = Trainer(
    model=multitask_model.model_ner,
    args=training_args_ner,
    train_dataset=ner_dataset["train"],
    eval_dataset=ner_dataset["val"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_ner
)

### NLI trainer

training_args_nli = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size*4,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
)

# No data_collator/tokenizer given: the Trainer falls back to its default collator.
trainer_nli = Trainer(
    model=multitask_model.model_nli,
    args=training_args_nli,
    train_dataset=nli_dataset['train'], #train_dataset,
    eval_dataset=nli_dataset['val'],
    compute_metrics=compute_metrics_nli,
)


#### training single task models
# trainer_ner.train()
# trainer_nli.train()

# eval_results_nli = trainer_nli.evaluate(nli_dataset['val'])
# eval_results_ner = trainer_ner.evaluate(ner_dataset['val'])

# eval_results_nli, eval_results_ner, 

In [None]:
# trainer_ner.predict(ner_dataset['val'])

In [None]:
### training multitask simultaneous

from collections import Counter
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm

### task sampler dataloader
# Each step samples a task ('nli' or 'ner') with equal probability and draws one batch from it.
dataloader_dict = {'nli': nli_dataloaders['train'], 'ner': ner_dataloaders['train']}
# max_iters fixes the number of steps per epoch, independent of either dataset's size.
task_sampler = TaskSampler(dataloader_dict=dataloader_dict, max_iters=10_000, task_weights=[.5, .5])


### training misc.
optimizer = AdamW(multitask_model.parameters(), lr=2e-5)
num_epochs = 5
# Total optimizer steps across all epochs; used for the linear LR schedule and the progress bar.
num_training_steps = num_epochs * task_sampler.max_iters
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
progress_bar = tqdm(range(num_training_steps))

# Record of which task was sampled at every step (summarised after training).
freq_list = [] 

# Most recent per-task losses; -1 means "not observed yet". They are only consumed by the
# commented-out dynamic weighting below.
loss_prev = {}
loss_prev['nli'] = -1 # L(t-1)
loss_prev['ner'] = -1 # L(t-1)
loss_current = {}
loss_current['nli'] = -1 # L(t)
loss_current['ner'] = -1 # L(t)
# NumTasks = 2 ## T in equation
# Temperature for the (disabled) softmax over loss ratios.
_sigma_ = 1 ### ?

### training loop 
for epoch in range(num_epochs):
    
    multitask_model.train()
    
    # Iterating the sampler again restarts its step counter, giving max_iters steps per epoch.
    for batch_index, batch in enumerate(task_sampler):
#         break
        task, sample = batch
        freq_list += [task]
        # print(batch_index, task, sample)
        
        # Route the batch to the matching head; the label column name depends on the task.
        outputs = multitask_model(task=task,
                            input_ids=sample['input_ids'].cuda(), 
                            attention_mask=sample['attention_mask'].cuda(), 
                            labels=sample[label_keys[task]].cuda()
                            )
        loss = outputs.loss
        loss_current[task] = loss.item()
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)
        
        # Disabled dynamic task weighting: a softmax (temperature _sigma_) over each task's
        # loss ratio L(t)/L(t-1). This resembles Dynamic Weight Averaging (hedged reading).
        # Because loss_prev is overwritten every step, the ratio for the task not sampled
        # at this step is always 1.
#         ## weight update
#         if loss_prev['nli'] >= 0 and loss_prev['ner'] >= 0:
#             exp_r_nli = np.exp((loss_current['nli']/loss_prev['nli']) / _sigma_) 
#             exp_r_ner = np.exp((loss_current['ner']/loss_prev['ner']) / _sigma_) 
#             w_nli = exp_r_nli/(exp_r_nli+exp_r_ner)
#             task_sampler.set_task_weights(task_weights=[w_nli, 1-w_nli])
#             print('set_task_weights updated:', task_sampler.task_weights) 
        
        for key in loss_current:
            loss_prev[key] = loss_current[key]
        
        # print(outputs)

    
    multitask_model.eval()
    
    # Overwrites the same checkpoint file at the end of every epoch.
    torch.save(multitask_model.state_dict(), 'nli_ner_fixed_weights.pth')
    
    # Per-epoch validation through the single-task Trainers wrapping the shared sub-models.
    eval_results_nli = trainer_nli.evaluate(nli_dataset['val']); print('epoch:', epoch, '> eval_results_nli:', eval_results_nli)
    eval_results_ner = trainer_ner.evaluate(ner_dataset['val']); print('epoch:', epoch, '> eval_results_ner:', eval_results_ner)
    


# How many steps went to each task over the whole run.
print('Sample frequency stat:')
cntr = Counter(freq_list)
cntr['nli'], cntr['ner']

In [None]:
# torch.save(multitask_model.state_dict(), 'nli_ner.pth')

# Rebuild the multitask model (this first loads pretrained weights) and then overwrite them
# with the checkpoint saved by the training loop.
multitask_model = custom_nli_ner_model()
multitask_model.load_state_dict(torch.load('nli_ner_fixed_weights.pth'))
multitask_model = multitask_model.cuda()
# multitask_model.eval() 

In [None]:
# Re-create the Trainers so they wrap the sub-models of the freshly reloaded multitask model,
# then evaluate both tasks on their validation splits.
trainer_ner = Trainer(
    model=multitask_model.model_ner,
    args=training_args_ner,
    train_dataset=ner_dataset["train"],
    eval_dataset=ner_dataset["val"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_ner
)
trainer_nli = Trainer(
    model=multitask_model.model_nli,
    args=training_args_nli,
    train_dataset=nli_dataset['train'], #train_dataset,
    eval_dataset=nli_dataset['val'],
    compute_metrics=compute_metrics_nli,
)
eval_results_nli = trainer_nli.evaluate(nli_dataset['val']); print('eval_results_nli:', eval_results_nli)
eval_results_ner = trainer_ner.evaluate(ner_dataset['val']); print('eval_results_ner:', eval_results_ner)


In [None]:
# trainer_nli wraps multitask_model.model_nli, so presumably only the NLI model (with the
# shared backbone, but without the NER head) is written to ./fixed_weight. Verify if the NER head is needed.
trainer_nli.save_model('./fixed_weight')

In [None]:
# !zip -r fixed_weight.zip fixed_weight