In [3]:
from torch.utils.data import Dataset
from torch import nn
from transformers import (AutoModelForSequenceClassification,
                          AutoTokenizer,
                          Trainer)
import torch
from datasets import Dataset as HFDataset
from datasets import concatenate_datasets

class ActiveLearningModel(nn.Module):
    """
        Adapter for our LLM to be able to communicate with the teacher
        dataset: every forward pass runs the wrapped model and hands the
        resulting logits to ``teacher.update``.

        Args:
            model_id: identifier given to
                ``AutoModelForSequenceClassification.from_pretrained``;
                only used when ``full_model`` is None.
            tok: tokenizer, stored unchanged as ``self.tok``.
            teacher: object exposing ``update(logits)`` (the teacher ds).
            full_model: optional, already constructed model to wrap instead
                of loading one from ``model_id``.
    """
    def __init__(self, model_id, tok, teacher, full_model=None):
        super().__init__()
        if full_model is None:  ## if you want to input a model-id to load from HF
            self.model = AutoModelForSequenceClassification.from_pretrained(model_id)
        else:  ## if you want to input a full model with all its params
            self.model = full_model
        self.tok = tok
        self.teacher = teacher  ## teacher ds

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                labels=None):
        """
            Run the wrapped model and report its logits to the teacher.

            Args:
                input_ids: forwarded to the wrapped model.
                attention_mask: forwarded to the wrapped model.
                labels: accepted for signature compatibility but not used;
                    the wrapped model is called without labels.

            Returns:
                The wrapped model's output object, unchanged (its ``logits``
                still carry the autograd graph for the caller's loss).
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)

        ## send to teacher model -- detached, so that whatever the teacher
        ## stores does not keep this batch's autograd graph alive
        self.teacher.update(output.logits.detach())

        return output


class MyDataset(Dataset):
    """
        Map-style dataset: tokenizes one sample per lookup and returns the
        resulting tensors already moved to the configured device.
    """

    def __init__(self, data, tokenizer, max_length, device) -> None:
        super().__init__()
        self.data = data
        self.initial_data = data  ## used for introducing of randomness
        self.tokenizer, self.max_length, self.device = tokenizer, max_length, device

    def __len__(self):
        """Number of samples currently held in ``self.data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """
            Tokenize the sample at ``idx``.

            The text is read from the sample's "tweet" entry and padded /
            truncated to ``self.max_length``. The optional "labels" entry
            defaults to an empty list when absent.

            Returns:
                dict with "input_ids", "attention_mask" and "labels"
                tensors, all placed on ``self.device``.
        """
        record = self.data[idx]

        encoded = self.tokenizer(
            record["tweet"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        def _place(tensor):
            return tensor.to(self.device)

        # The tokenizer output has a leading batch axis of size 1; drop it.
        item = {
            field: _place(encoded[field].squeeze(0))
            for field in ("input_ids", "attention_mask")
        }
        item["labels"] = _place(torch.tensor(record.get("labels", [])))
        return item

class TeacherDataset(MyDataset):
    """
        Teacher dataset can be updated using the given student model
        updates.

        ``update`` accumulates one predictive-entropy value per scored
        sample in ``self.allHs``; ``new_iter`` then shrinks ``self.data`` to
        the samples with the highest entropy (optionally mixed with random
        ones) and clears the buffer.

        Position ``i`` of ``self.allHs`` is used as index ``i`` into
        ``self.data``, so ``update`` must receive the logits in dataset
        order.
        NOTE(review): if the training loop shuffles batches, the selected
        indices will not correspond to the scored samples -- confirm the
        sampler used by the caller.

        Args:
            data, tokenizer, max_length, device: forwarded to ``MyDataset``.
            T: minimum dataset size; once a round leaves fewer samples than
                this, the dataset is emptied to signal a halt.
    """
    def __init__(self, data, tokenizer, max_length, device, T):
        # self.data = self.load_data()
        super().__init__(data, tokenizer, max_length, device)
        self.T = T
        self.allHs = torch.tensor([]).to(self.device)

    def update(self, past_logits, automatic_new_iter=False):
        """
            Append the entropy of each row of ``past_logits`` to ``allHs``.

            Args:
                past_logits: logits whose last axis enumerates the classes.
                automatic_new_iter: when True, call ``new_iter()`` as soon
                    as the buffer holds one entropy per dataset sample.
        """
        # Detach: the buffer must never hold on to the autograd graph.
        logits = past_logits.detach()
        # Normalise over the class axis (the last one), not over the batch,
        # and go through log-softmax so that vanishing probabilities
        # contribute 0 instead of NaN (0 * log 0).
        log_Ps = torch.log_softmax(logits.float(), dim=-1)
        nHs = (-log_Ps.exp() * log_Ps).sum(dim=-1).reshape((-1,))
        nHs = nHs.to(device=self.allHs.device, dtype=self.allHs.dtype)
        # Ignore logits that would overflow the buffer, e.g. forward passes
        # on test data once every training sample has been scored.
        if len(self.allHs) + len(nHs) <= len(self.data):  ## test data
            self.allHs = torch.cat([self.allHs, nHs])
            if automatic_new_iter and (self.allHs.size()[0] == len(self.data)):
                self.new_iter()

    def new_iter(self, add_randomness=False):
        """
            Run one teacher round: replace ``self.data`` by a subset of
            itself, reset ``allHs`` and halt if fewer than ``T`` samples
            remain.

            Args:
                add_randomness: if False, keep the half of the scored
                    samples with the highest entropy. If True, keep the top
                    quarter by entropy plus an equally sized random draw
                    from the whole current dataset (the two parts may
                    overlap).
        """
        n_scored = len(self.allHs)
        by_entropy = self.allHs.argsort()  # ascending entropy

        if not add_randomness:
            n_keep = n_scored // 2
            # Explicit start index: ``[-0:]`` would select everything.
            selected_idx = by_entropy[n_scored - n_keep:]  ## take the sample above the entropy median
            print("> TEACHER ROUND, from", len(self.data), "samples to", len(selected_idx))
            self.data = HFDataset.from_dict(self.data[selected_idx.tolist()])

        else:
            CHUNK_SIZE = n_scored // 4
            high_H_idxs = by_entropy[n_scored - CHUNK_SIZE:]  ## 1/4 of high entropy samples
            high_H_samples = HFDataset.from_dict(self.data[high_H_idxs.tolist()])
            random_idxs = torch.randperm(len(self.data))[:CHUNK_SIZE]
            random_init_samples = HFDataset.from_dict(self.data[random_idxs.tolist()])
            print("> TEACHER ROUND, from", len(self.data), "samples to", len(high_H_idxs) + len(random_init_samples))
            self.data = concatenate_datasets([random_init_samples, high_H_samples])

        self.allHs = torch.tensor([]).to(self.device)

        if len(self.data) < self.T:
            print(f"Threshold was set at T = {self.T}, {len(self.data)} remaining datapoints, halting.")
            self.data = HFDataset.from_dict({})  ## empty dataset ===> halt


class RunArgs:
    """
        Plain settings holder for one run: every constructor argument is
        stored unchanged as an instance attribute of the same name.
    """

    def __init__(
        self,
        n_samples: int = 200_000,
        device: str = "cuda:0",
        T: int = 10_000,
        test_ratio: float = 0.1,
        max_length: int = 130,
        batch_size: int = 64,
        learning_rate: float = 2e-5,
        weight_decay: float = 0.01,
        SAVE_DIR: str = "models/",
        BASE_MODEL: str = "vinai/bertweet-base",
        warmup_prcnt: float = .3,
    ) -> None:
        settings = {
            "n_samples": n_samples,
            "device": device,
            "T": T,
            "test_ratio": test_ratio,
            "max_length": max_length,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "SAVE_DIR": SAVE_DIR,
            "BASE_MODEL": BASE_MODEL,
            "warmup_prcnt": warmup_prcnt,
        }
        # Copy each setting onto the instance, in declaration order.
        for field, value in settings.items():
            setattr(self, field, value)


class CustomTrainer(Trainer):
    """
        Trainer whose loss is the cross-entropy between the model's logits
        and the integer class labels found in the batch.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
            Compute the cross-entropy loss for one batch.

            Args:
                model: callable taking ``input_ids`` and ``attention_mask``
                    and returning an object with a ``get("logits")``.
                inputs: batch mapping with "input_ids", "attention_mask"
                    and "labels".
                return_outputs: when True, return ``(loss, outputs)``.
                num_items_in_batch: accepted and ignored, so the override
                    keeps working when the base Trainer passes this keyword
                    (newer transformers releases do).

            Raises:
                ValueError: if the batch carries no "labels" entry.
        """
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError("CustomTrainer.compute_loss needs a 'labels' entry in the batch")
        # forward pass (labels are deliberately not given to the model)
        outputs = model(input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"])
        logits = outputs.get("logits")
        loss_fct = nn.CrossEntropyLoss()
        # CrossEntropyLoss expects integer (long) class indices.
        loss = loss_fct(logits, labels.long())
        return (loss, outputs) if return_outputs else loss


In [4]:
import torch
# Select the compute device for the run. The name `mps_device` is kept even
# for the fallback because the launch cell at the bottom reads it.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Smoke test: allocate a tensor on the MPS device and show it.
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")
    # Fall back to CUDA or CPU instead of leaving `mps_device` undefined,
    # which would raise a NameError later on machines without MPS.
    mps_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tensor([1.], device='mps:0')


In [5]:
## with randomness ::

from utils import create_datasets
from transformers import TrainingArguments, AdamW, get_linear_schedule_with_warmup
import evaluate
import numpy as np
import math


def run_with_args(args):
    """
    Run the teacher/student loop configured by ``args`` (a ``RunArgs``).

    Each round trains the model for one epoch on the teacher dataset,
    evaluates on the held-out split, lets the teacher shrink its data with
    ``new_iter(add_randomness=True)`` and saves the wrapped model under
    ``<SAVE_DIR>/epoch_<round>``. The loop ends once the teacher dataset is
    empty.

    NOTE(review): the teacher indexes its entropy buffer by dataset
    position, which is only valid if training batches arrive in dataset
    order -- confirm the sampler used by the Trainer.
    """
    import os  # local import: only needed to build the per-round save path

    BERT_MODEL = args.BASE_MODEL
    DIR = args.SAVE_DIR
    tok = AutoTokenizer.from_pretrained(BERT_MODEL)
    # Only borrow EOS as padding when the tokenizer has no pad token of its
    # own; overriding an existing one would fill padded positions with an
    # id the model does not treat as padding (presumably unintended).
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    ds = create_datasets(sub_sampling=args.n_samples)  ## small dataset for testing
    ds = ds.train_test_split(test_size=args.test_ratio)
    ds_train, ds_test = ds["train"], ds["test"]

    teacher = TeacherDataset(ds_train,
                             tokenizer=tok,
                             max_length=args.max_length,
                             device=args.device,
                             T=args.T)

    test_data = MyDataset(ds_test, tokenizer=tok, max_length=args.max_length, device=args.device)

    model = ActiveLearningModel(BERT_MODEL, tok=tok, teacher=teacher)
    accuracy = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Accuracy of the arg-max class against the reference labels."""
        predictions, labels = eval_pred
        predictions = np.argmax(predictions, axis=1)
        return accuracy.compute(predictions=predictions, references=labels)

    # Schedule length: twice the number of steps of the first epoch.
    # Presumably because every randomised teacher round keeps two quarter
    # chunks (half the data), so the rounds sum to about 2x -- TODO confirm.
    total_steps = 2 * math.ceil(len(ds_train) / args.batch_size)
    warmup_steps = int(args.warmup_prcnt * total_steps)
    # weight_decay is taken from the run configuration; it was previously
    # configured in RunArgs but never handed to the optimizer.
    optimizer = AdamW(model.parameters(),
                      lr=args.learning_rate,
                      weight_decay=args.weight_decay,
                      no_deprecation_warning=True)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_training_steps=total_steps, num_warmup_steps=warmup_steps)

    round_idx = 1  # avoids shadowing the builtin `iter`
    while len(teacher) != 0:
        ## train for 1 epoch = 1 round of the DataLoader
        training_args = TrainingArguments(
            output_dir=DIR,
            per_device_train_batch_size=args.batch_size,
            per_device_eval_batch_size=args.batch_size,
            num_train_epochs=1,  ## only 1 epoch
            evaluation_strategy="epoch",
            save_strategy="epoch",
            remove_unused_columns=False,
        )

        # A fresh Trainer per round, sharing one optimizer and scheduler so
        # the learning-rate schedule continues across rounds.
        trainer = CustomTrainer(
            model=model,
            args=training_args,
            train_dataset=teacher,
            eval_dataset=test_data,
            tokenizer=tok,
            optimizers=(optimizer, scheduler),
            compute_metrics=compute_metrics
        )

        trainer.train()
        print("FINISHED EPOCH 1 ==> updating")
        teacher.new_iter(add_randomness=True)  ## teacher round
        # os.path.join keeps the path valid whether or not SAVE_DIR ends
        # with a separator.
        model.model.save_pretrained(os.path.join(DIR, f"epoch_{round_idx}"))
        round_idx += 1

In [None]:
# Launch one run. Settings not listed here keep the RunArgs defaults.
## default BERTweet Model
args = RunArgs(
    n_samples=150_000,
    device=mps_device,
    T=10_000,
    max_length=130,
    batch_size=64,
)
run_with_args(args)
