In [1]:
import evaluate
import torch
import torch.nn as nn
from datasets import load_dataset
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import (AutoModel, AutoTokenizer, DataCollatorWithPadding,
                          PretrainedConfig, PreTrainedModel, Trainer,
                          TrainingArguments)
from transformers.modeling_outputs import SequenceClassifierOutput

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
device

'mps'

In [2]:
class E5DataLoader:
    def __init__(self, tokenizer, data_file):
        self.tokenizer = tokenizer
        dataset = load_dataset("csv", data_files=data_file, split='train')
        dataset = dataset.class_encode_column('label')
        dataset = dataset.train_test_split(test_size=0.2, stratify_by_column='label')
        self.train_dataset, self.eval_dataset = dataset['train'], dataset['test']
        self.train_dataset.set_transform(self._transform)
        self.eval_dataset.set_transform(self._transform)


    def _transform(self, examples):
        docs = [f'passage: {doc}' for doc in examples['description']]
        queries = [f'query: {query}' for query in examples['comment']]

        assert len(docs) == len(queries)
        assert len(queries) == len(examples['label'])

        query_batch_dict = self.tokenizer(queries,
                                    max_length=512,
                                    truncation=True,
                                    )

        doc_batch_dict = self.tokenizer(docs,
                                    max_length=512,
                                    truncation=True,
                                    )

        merged_batch_dict = {f'q_{k}': v for k, v in query_batch_dict.items()}
        for k, v in doc_batch_dict.items():
            k = f'd_{k}'
            merged_batch_dict[k] = v

        merged_batch_dict['label'] = examples['label']

        return merged_batch_dict

In [3]:
class E5NNTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super(E5NNTrainer, self).__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

In [4]:
class E5NNConfig(PretrainedConfig):
    model_type = 'E5NN'

    def __init__(self, num_labels=2, **kwargs):
        super().__init__(**kwargs)
        self.num_labels = num_labels


class E5NN(PreTrainedModel):
    config_class = E5NNConfig

    def __init__(self, config):
        super(E5NN, self).__init__(config)
        self.num_labels = config.num_labels
        self.e5 = AutoModel.from_pretrained('intfloat/multilingual-e5-large')
        self.linear = nn.Linear(1024*2, config.num_labels)

        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, **kwargs):
        e5_outputs = self.e5(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        features = e5_outputs.pooler_output
        features = torch.cat((features[:len(input_ids)//2], features[len(input_ids)//2:]), dim=-1)
        logits = self.linear(features)
        prob = torch.softmax(logits, dim=-1).to(device)
        loss = self.cross_entropy(prob, labels)
        return SequenceClassifierOutput(loss=loss, logits=logits)


In [5]:
class E5NNCollator(DataCollatorWithPadding):

    def __call__(self, examples):
        q_prefix, d_prefix = 'q_', 'd_'

        queries = [{k[len(q_prefix):]: v for k, v in example.items() if q_prefix in k}
                        for example in examples]

        docs = [{k[len(d_prefix):]: v for k, v in example.items() if d_prefix in k}
                        for example in examples]

        batch_collated = self.tokenizer.pad(
            queries + docs,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors
        )

        batch_collated['labels'] = torch.tensor([example['label'] for example in examples])

        return batch_collated

In [6]:
tokenizer = AutoTokenizer.from_pretrained('intfloat/multilingual-e5-large')
data_file = '../../data/temp_data.csv'
data_loader = E5DataLoader(tokenizer,data_file=data_file)
train_data = data_loader.train_dataset
eval_data = data_loader.eval_dataset

loftq_config = LoftQConfig(loftq_bits=8)
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS if device == 'cuda' else None,
                         init_lora_weights="loftq" if device == 'cuda' else dict(),
                         loftq_config=loftq_config,
                         target_modules=[
                           'query',
                           'key'
                         ],
                         inference_mode=False,
                         r=8,
                         lora_alpha=32,
                         lora_dropout=0.1
                         )


config = E5NNConfig()
model = E5NN(config)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
print(model)

data_collator = E5NNCollator(
    tokenizer=tokenizer,
    max_length=512
)

def compute_metrics(eval_pred):
  logits = torch.tensor(eval_pred.predictions, device=device)
  labels = torch.tensor(eval_pred.label_ids, device=device, dtype=torch.int32)
  probs = torch.softmax(logits, dim=-1).to(device)
  predictions = probs.argmax(dim=-1).to(device)
  metrics = evaluate.combine(["accuracy", "f1", "precision", "recall"])

  return metrics.compute(predictions=predictions, references=labels)


training_args = TrainingArguments(
    output_dir='saved_models/e5nn',
    evaluation_strategy='steps',
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    save_strategy='steps',
    save_steps=0.2,
    logging_steps=0.2,
    load_best_model_at_end=True,
    remove_unused_columns=False,
    label_names=['labels']
)

trainer = E5NNTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainable params: 786,432 || all params: 560,680,962 || trainable%: 0.14026372452432226
PeftModel(
  (base_model): LoraModel(
    (model): E5NN(
      (e5): XLMRobertaModel(
        (embeddings): XLMRobertaEmbeddings(
          (word_embeddings): Embedding(250002, 1024, padding_idx=1)
          (position_embeddings): Embedding(514, 1024, padding_idx=1)
          (token_type_embeddings): Embedding(1, 1024)
          (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): XLMRobertaEncoder(
          (layer): ModuleList(
            (0-23): 24 x XLMRobertaLayer(
              (attention): XLMRobertaAttention(
                (self): XLMRobertaSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (default): Dropout(p=0.1, inplace=False)
    

In [7]:
trainer.train()

  0%|          | 0/37 [00:00<?, ?it/s]

You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 0.599, 'learning_rate': 1.5675675675675676e-05, 'epoch': 0.22}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.5616773366928101, 'eval_accuracy': 0.9459459459459459, 'eval_f1': 0.9722222222222222, 'eval_precision': 0.9459459459459459, 'eval_recall': 1.0, 'eval_runtime': 12.8871, 'eval_samples_per_second': 2.871, 'eval_steps_per_second': 0.776, 'epoch': 0.22}




{'loss': 0.5966, 'learning_rate': 1.1351351351351352e-05, 'epoch': 0.43}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.5606305599212646, 'eval_accuracy': 0.9459459459459459, 'eval_f1': 0.9722222222222222, 'eval_precision': 0.9459459459459459, 'eval_recall': 1.0, 'eval_runtime': 15.6901, 'eval_samples_per_second': 2.358, 'eval_steps_per_second': 0.637, 'epoch': 0.43}




{'loss': 0.5904, 'learning_rate': 7.027027027027028e-06, 'epoch': 0.65}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.5593210458755493, 'eval_accuracy': 0.9459459459459459, 'eval_f1': 0.9722222222222222, 'eval_precision': 0.9459459459459459, 'eval_recall': 1.0, 'eval_runtime': 12.0806, 'eval_samples_per_second': 3.063, 'eval_steps_per_second': 0.828, 'epoch': 0.65}




{'loss': 0.5906, 'learning_rate': 2.702702702702703e-06, 'epoch': 0.86}


  0%|          | 0/10 [00:00<?, ?it/s]

{'eval_loss': 0.5590757727622986, 'eval_accuracy': 0.9459459459459459, 'eval_f1': 0.9722222222222222, 'eval_precision': 0.9459459459459459, 'eval_recall': 1.0, 'eval_runtime': 13.4099, 'eval_samples_per_second': 2.759, 'eval_steps_per_second': 0.746, 'epoch': 0.86}




{'train_runtime': 154.2584, 'train_samples_per_second': 0.946, 'train_steps_per_second': 0.24, 'train_loss': 0.5962875340435956, 'epoch': 1.0}


TrainOutput(global_step=37, training_loss=0.5962875340435956, metrics={'train_runtime': 154.2584, 'train_samples_per_second': 0.946, 'train_steps_per_second': 0.24, 'train_loss': 0.5962875340435956, 'epoch': 1.0})