In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from datasets import Dataset, DatasetDict, load_dataset
from transformers import Trainer, TrainingArguments, DefaultDataCollator, PretrainedConfig, PreTrainedModel, DataCollatorWithPadding
from transformers.modeling_outputs import SequenceClassifierOutput
from peft import LoraConfig, LoftQConfig, TaskType, get_peft_model
import evaluate
import pandas as pd
import glob
import os

# Pick the compute backend in priority order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device

In [None]:
class E5DataLoader:
    """Load the labelled CSV and expose tokenized train/eval splits.

    The CSV is read with ``load_dataset("csv", ...)``, its ``label`` column is
    class-encoded, and the rows are split into train/eval parts stratified on
    ``label``. Tokenization is attached to both splits via ``set_transform``.

    Args:
        tokenizer: Callable tokenizer, invoked as
            ``tokenizer(queries, text_pair=docs, max_length=..., truncation=True)``.
        data_files: Path (or paths) of the CSV file(s) to load.
        test_size: Share of the rows that goes to the eval split.
        max_length: Truncation length passed to the tokenizer.
        seed: Seed for the train/eval split. ``None`` (default) keeps the
            split random on each run; pass an int for a reproducible split.

    Attributes:
        train_dataset: Training split with the tokenizing transform attached.
        eval_dataset: Evaluation split with the tokenizing transform attached.
    """

    def __init__(self, tokenizer, data_files="../../data/temp_data.csv",
                 test_size=0.2, max_length=512, seed=None):
        self.tokenizer = tokenizer
        self.max_length = max_length

        dataset = load_dataset("csv", data_files=data_files, split='train')
        # Stratified splitting is done on the class-encoded label column.
        dataset = dataset.class_encode_column('label')
        dataset = dataset.train_test_split(
            test_size=test_size,
            stratify_by_column='label',
            seed=seed,
        )
        self.train_dataset, self.eval_dataset = dataset['train'], dataset['test']
        self.train_dataset.set_transform(self._transform)
        self.eval_dataset.set_transform(self._transform)

    def _transform(self, examples):
        """Tokenize a batch of rows as (query, passage) pairs.

        Args:
            examples: Mapping with list-valued ``'description'``, ``'comment'``
                and ``'label'`` entries.

        Returns:
            The tokenizer output with the batch labels stored under ``'label'``.
        """
        # Every text gets the 'query: ' / 'passage: ' prefix before tokenization.
        docs = [f'passage: {doc}' for doc in examples['description']]
        queries = [f'query: {query}' for query in examples['comment']]

        batch_dict = self.tokenizer(
            queries,
            text_pair=docs,
            max_length=self.max_length,
            truncation=True,
        )
        batch_dict['label'] = examples['label']
        return batch_dict

In [None]:
class E5NNTrainer(Trainer):
    """Trainer that uses the loss the model itself reports in its output."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on ``inputs`` and return its ``loss``.

        Args:
            model: Model to call; its output must expose a ``loss`` attribute.
            inputs: Keyword arguments forwarded to ``model``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss.
            num_items_in_batch: Accepted so the override stays callable by
                Trainer versions that pass this keyword; it is not used
                because the model computes its own loss.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple when ``return_outputs``
            is true.
        """
        outputs = model(**inputs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

In [None]:
class E5NNConfig(PretrainedConfig):
    """Configuration for the E5 classifier; stores the number of target labels."""

    model_type = 'E5_NN'

    def __init__(self, num_labels: int = 2, **kwargs) -> None:
        """Initialise the base config from ``kwargs``, then record ``num_labels``."""
        super(E5NNConfig, self).__init__(**kwargs)
        # Assigned after the base initialiser has run, as in the original order.
        self.num_labels = num_labels


class E5NN(PreTrainedModel):
    """multilingual-E5 encoder followed by a linear classification head.

    Attributes:
        e5: The ``intfloat/multilingual-e5-large`` encoder.
        linear: Maps the encoder's pooled output to ``num_labels`` logits.
        cross_entropy: Loss applied to the logits when labels are supplied.
    """

    config_class = E5NNConfig

    def __init__(self, config):
        super(E5NN, self).__init__(config)
        self.num_labels = config.num_labels

        # Only request 8-bit weights when CUDA is present; the notebook also
        # runs on mps/cpu, where the adapter setup does not use quantization.
        quantization_kwargs = {'load_in_8bit': True} if torch.cuda.is_available() else {}
        self.e5 = AutoModel.from_pretrained(
            'intfloat/multilingual-e5-large',
            **quantization_kwargs,
        )

        # The head is an ordinary floating-point layer so it can be trained;
        # its input width follows the encoder instead of a fixed 1024.
        self.linear = nn.Linear(self.e5.config.hidden_size, config.num_labels)
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None, **kwargs):
        """Encode the inputs and classify the pooled representation.

        Args:
            input_ids: Token ids for the encoder.
            attention_mask: Attention mask for the encoder.
            labels: Optional class indices; when given, the cross-entropy loss
                is computed.
            **kwargs: Accepted and ignored, so wrappers may pass extra keywords.

        Returns:
            ``SequenceClassifierOutput`` holding raw ``logits`` and ``loss``
            (``None`` when no labels were supplied).
        """
        e5_outputs = self.e5(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # NOTE(review): the E5 model card pools with an attention-masked mean of
        # last_hidden_state; pooler_output is kept here - confirm which is intended.
        features = e5_outputs.pooler_output

        # The encoder may emit a different float dtype than the head holds
        # (e.g. half precision when quantized), so align before the matmul.
        head_dtype = next(self.linear.parameters()).dtype
        logits = self.linear(features.to(head_dtype))

        loss = None
        if labels is not None:
            # CrossEntropyLoss applies log-softmax itself, so it takes raw logits.
            loss = self.cross_entropy(logits, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits)


In [None]:
class E5NNCollator(DataCollatorWithPadding):
    """Pad tokenized examples into a batch and rename ``label`` to ``labels``."""

    def __call__(self, examples):
        """Collate a list of examples into one padded batch.

        Args:
            examples: Sequence of mappings, each with ``'input_ids'``,
                ``'attention_mask'`` and ``'label'`` entries.

        Returns:
            The result of ``tokenizer.pad`` over the gathered ``input_ids``,
            ``attention_mask`` and integer ``labels``.
        """
        batch_dict = {
            'input_ids': [example['input_ids'] for example in examples],
            'attention_mask': [example['attention_mask'] for example in examples],
            # Stored under 'labels', the keyword the model's forward() accepts.
            'labels': [int(example['label']) for example in examples],
        }

        # max_length is forwarded too, so a collator configured with
        # padding='max_length' pads to the configured length.
        return self.tokenizer.pad(
            batch_dict,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

In [None]:
# Tokenizer and the train/eval splits built from the CSV.
tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
data_loader = E5DataLoader(tokenizer)
train_data = data_loader.train_dataset
eval_data = data_loader.eval_dataset

# LoftQ initialisation is only requested on CUDA; elsewhere LoRA falls back to
# its default initialisation (init_lora_weights=True).
# NOTE(review): the PEFT LoftQ guide loads the base model unquantized and lets
# LoftQ quantize it; E5NN loads the encoder in 8-bit - confirm this combination.
use_loftq = device == 'cuda'
loftq_config = LoftQConfig(loftq_bits=8)
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS if use_loftq else None,
    init_lora_weights="loftq" if use_loftq else True,
    # Only hand over the LoftQ settings when they are actually used.
    loftq_config=loftq_config if use_loftq else {},
    target_modules=[
        'query',
        'key',
    ],
    # The classification head is named 'linear'; it must be listed here or
    # PEFT freezes it together with the rest of the base model.
    modules_to_save=['linear'],
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Build the classifier and wrap it with the LoRA adapters.
config = E5NNConfig()
model = E5NN(config)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

data_collator = E5NNCollator(
    tokenizer=tokenizer,
    max_length=512
)

def compute_metrics(eval_pred):
    """Compute accuracy, F1, precision and recall for one evaluation pass.

    Args:
        eval_pred: Object exposing ``predictions`` (per-class logits) and
            ``label_ids`` (reference class indices).

    Returns:
        The dictionary produced by the combined ``evaluate`` metrics.
    """
    # The argmax of the logits equals the argmax of their softmax, so neither
    # normalisation nor a copy to the accelerator is needed.
    logits = torch.as_tensor(eval_pred.predictions)
    predictions = logits.argmax(dim=-1).tolist()
    references = torch.as_tensor(eval_pred.label_ids).to(torch.int32).tolist()

    # Load the metric bundle once and reuse it on later evaluation rounds.
    metrics = getattr(compute_metrics, '_metrics', None)
    if metrics is None:
        metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])
        compute_metrics._metrics = metrics

    return metrics.compute(predictions=predictions, references=references)


# Training configuration; values are grouped by concern.
training_args = TrainingArguments(
    output_dir='saved_models/e5nn',
    # Optimisation
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluation / checkpoint cadence; fractional step values are read by
    # TrainingArguments as a ratio of the total number of training steps.
    evaluation_strategy='steps',
    save_strategy='steps',
    save_steps=0.2,
    logging_steps=0.2,
    load_best_model_at_end=True,
    # Presumably kept off so the datasets' on-the-fly transform still sees the
    # raw text columns - TODO confirm.
    remove_unused_columns=False,
)

# Wire the model, data splits, collator and metrics into the custom trainer.
trainer = E5NNTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Start fine-tuning with the trainer built in the previous cell.
trainer.train()