In [1]:
from torch.utils.data import DataLoader
from transformers import Trainer
import torch
import torch.nn.functional as F
import torch, time, itertools, json
from datasets import load_dataset
from transformers import (AutoTokenizer,
                          AutoModelForSequenceClassification,
                          TrainingArguments, Trainer)
from peft import PrefixTuningConfig, TaskType, get_peft_model
from evaluate import load as load_metric
from torch.distributions import Normal
from sklearn.model_selection import train_test_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
metric = load_metric("accuracy")  

In [2]:
from transformers import TrainerCallback, get_scheduler

class Phase2LRScheduler(TrainerCallback):
    def __init__(self, phase1_epochs: int, phase2_lr: float, warmup_ratio: float):
        super().__init__()
        self.phase1_epochs = phase1_epochs
        self.phase2_lr     = phase2_lr
        self.warmup_ratio  = warmup_ratio
        self.trainer       = None

    def on_epoch_begin(self, args, state, control, **kwargs):
        if int(state.epoch) == self.phase1_epochs and self.trainer:
            for g in self.trainer.optimizer.param_groups:
                g["lr"] = self.phase2_lr

            steps_per_epoch = len(self.trainer.get_train_dataloader())
            remaining_epochs = args.num_train_epochs - self.phase1_epochs
            total_steps_phase2 = remaining_epochs * steps_per_epoch
            num_warmup = int(self.warmup_ratio * total_steps_phase2)

            self.trainer.lr_scheduler = get_scheduler(
                name="cosine",
                optimizer=self.trainer.optimizer,
                num_warmup_steps=num_warmup,
                num_training_steps=total_steps_phase2
            )

            print(f"\n>>> Phase 2 START: reset LR to {self.phase2_lr} and restart cosine scheduler "
                  f"({total_steps_phase2} steps, {num_warmup} warmup)\n")
        return control

In [7]:
from transformers import TrainingArguments
from datasets import load_dataset
from transformers import DataCollatorWithPadding
from transformers import AutoConfig
from peft import PrefixTuningConfig, get_peft_model
def prepare_for_training(
        dataset_name, 
        model_name,
        text_column_name,
        label_column_name,
        num_virtual_tokens=20, 
        max_length=128):
    if dataset_name == "sst2":
        raw = load_dataset("glue", dataset_name)
    else:
        raw = load_dataset(dataset_name, trust_remote_code=True)

    tok = AutoTokenizer.from_pretrained(model_name)
    if model_name != "roberta-base":
        tok.pad_token = tok.eos_token
    def prep(x):
        t = tok(x[text_column_name], padding="max_length", truncation=True, max_length=max_length)
        t["labels"]=x[label_column_name]; 
        return t
    data_collator = DataCollatorWithPadding(tokenizer=tok, return_tensors="pt")
    neseccary_cols = ["input_ids","attention_mask","labels"]
    ds = raw["train"].shuffle(seed=42).map(
        prep, 
        batched=True, 
        remove_columns=set(raw["train"].features.keys()).difference(set(neseccary_cols))
    )
    ds.set_format("torch",columns=neseccary_cols)

    split = ds.train_test_split(0.1, shuffle=True)
    train_main = split["train"]
    val_ds  = split["test"]
    split2 = val_ds.train_test_split(test_size=100, shuffle=True)
    val_ds = split2["train"]
    rl_subset = split2["test"]
    test_ds = raw.get('validation') or raw.get("test")
    test_ds = test_ds.map(
        prep, 
        batched=True, 
        remove_columns=set(raw["train"].features.keys()).difference(set(neseccary_cols))
    )
    test_ds.set_format("torch",columns=neseccary_cols)

    num_labels = len(raw["train"].features[label_column_name].names)
    
    if model_name != "roberta-base":
        config = AutoConfig.from_pretrained(
            "tiiuae/falcon-rw-1b",
            hidden_size=768, 
            num_hidden_layers=6, 
            num_attention_heads=12,
            num_key_value_heads=2
        )

        base = AutoModelForSequenceClassification.from_config(
            config,
            torch_dtype=torch.bfloat16,
        )
        tok.pad_token = tok.eos_token
        base.config.pad_token_id = tok.pad_token_id
    else:
        base = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels,
            torch_dtype=torch.bfloat16,
        )
    if model_name == "roberta-base":
        for p in base.roberta.parameters(): 
            p.requires_grad=False
    else:
        for p in base.parameters(): 
            p.requires_grad=False


    cfg = PrefixTuningConfig(
        task_type="SEQ_CLS",
        num_virtual_tokens=num_virtual_tokens,
        prefix_projection=False
    )
    model = get_peft_model(base,cfg)
    return train_main, val_ds, rl_subset, test_ds, model, data_collator, tok

In [8]:

class TwoPhaseTrainer(Trainer):
    def __init__(
        self,
        *args,
        phase1_epochs: int = 3,
        alpha: float = 1.0,
        beta: float  = 0.3, 
        gamma: float = 0.05,  
        rl_subset_size: int = 32,
        rl_dataset=None,
        k_negatives: int = 1,
        sigma: float = 0.00002,
        **kwargs
    ):
        self.phase1_epochs  = phase1_epochs
        self.alpha          = alpha
        self.beta           = beta
        self.gamma          = gamma
        self.rl_subset_size = rl_subset_size
        self.rl_dataset     = rl_dataset
        self.k_negatives    = k_negatives
        self.sigma          = sigma
        super().__init__(*args, **kwargs)

        subset = list(range(min(len(self.rl_dataset), self.rl_subset_size)))
        self._rl_loader = DataLoader(
            self.rl_dataset.select(subset),
            batch_size=self.args.per_device_train_batch_size,
            collate_fn=self.data_collator
        )
        self.prefix_params = []
        for name, param in self.model.named_parameters():
            if "prompt_encoder" in name:
                self.prefix_params.append((name, param))

    def mutate_prefix_for_contrastive(self):
        negs = []
        for _ in range(self.k_negatives):
            mutated = {}
            for name, param in self.prefix_params:
                mask = (torch.rand_like(param) > 0.1).float()
                mutated[name] = param * mask
            negs.append(mutated)
        return negs

    def compute_contrastive_loss(self, logits_pos, logits_negs, temp=1.0):
        pos_norm = F.normalize(logits_pos, dim=-1)
        neg_norms = [F.normalize(n, dim=-1) for n in logits_negs]
        sim_pos = (pos_norm * pos_norm).sum(-1) / temp
        sim_negs = torch.stack([(pos_norm*neg).sum(-1) for neg in neg_norms], dim=1) / temp
        loss = -torch.log(sim_pos.exp() / (sim_pos.exp() + sim_negs.exp().sum(1))).mean()
        return loss

    @torch.no_grad()
    def compute_reward(self, model):
        was_train = model.training
        model.eval()
        correct = total = 0
        for batch in self._rl_loader:
            batch = {k: v.to(model.device) for k,v in batch.items()}
            out = model(**batch).logits.argmax(-1)
            correct += (out == batch["labels"]).sum().item()
            total += out.size(0)
        if was_train: model.train()
        return correct/total

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs   = model(**inputs)
        loss_mle  = outputs.loss
        if (self.state.epoch < self.phase1_epochs) or not model.training:
            return (loss_mle, outputs) if return_outputs else loss_mle

        logits_pos = outputs.logits
        original = {n: p.data.clone() for n,p in self.prefix_params}

        neg_prefixes = self.mutate_prefix_for_contrastive()
        logits_negs = []
        for neg in neg_prefixes:
            with torch.no_grad():
                for name, param in self.prefix_params:
                    param.data.copy_(neg[name].data)
            logits_negs.append(model(**inputs).logits.detach())

        with torch.no_grad():
            for name, param in self.prefix_params:
                param.data.copy_(original[name])

        loss_contrast = self.compute_contrastive_loss(logits_pos, logits_negs)


        if self.state.global_step % 100 == 0:
            log_probs = []

            for name, param in self.prefix_params:
                eps   = torch.randn_like(param) * self.sigma
                noisy = param + eps
                dist  = Normal(loc=param, scale=self.sigma)
                log_probs.append(dist.log_prob(noisy).sum())
                with torch.no_grad():
                    param.data.copy_(noisy.data)

            reward = self.compute_reward(model)


            with torch.no_grad():
                for name, param in self.prefix_params:
                    param.data.copy_(original[name])

            total_log_prob = torch.stack(log_probs).mean()
            loss_rl = - reward * total_log_prob
        else:
            loss_rl = 0.0

        loss = self.alpha*loss_mle + self.beta*loss_contrast + self.gamma*loss_rl
        return (loss, outputs) if return_outputs else loss

In [35]:
def prepare_trainer(model, 
                    train_ds, 
                    val_ds, 
                    rl_subset, 
                    data_collator, 
                    tok, 
                    alpha=1.0, 
                    beta=0.1, 
                    gamma=0.0002,
                    rl_subset_size=32,
                    k_negatives=7,
                    phase1_epochs=3,
                    learning_rate=1e-3,
                    warmup_ratio=0.1,
                    per_device_train_batch_size=16,
                    per_device_eval_batch_size=32,
                    num_train_epochs=5,
                    lr_scheduler_type="cosine",
                    sigma=0.1,
                    ):

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = logits.argmax(-1)
        return metric.compute(predictions=preds, references=labels)

    training_args = TrainingArguments(
        output_dir="two_phase",
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=per_device_eval_batch_size,
        learning_rate=learning_rate,
        lr_scheduler_type=lr_scheduler_type,
        warmup_ratio=warmup_ratio,
        eval_strategy="steps",
        eval_steps = 1000,
        metric_for_best_model="accuracy",
        fp16=True,
        report_to="none",
    )
    lr_callback = Phase2LRScheduler(
        phase1_epochs=phase1_epochs, 
        phase2_lr=learning_rate, 
        warmup_ratio=warmup_ratio
    )

    trainer = TwoPhaseTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        rl_dataset=rl_subset,
        phase1_epochs=phase1_epochs,
        callbacks=[lr_callback],
        alpha=alpha, beta=beta, gamma=gamma,
        rl_subset_size=rl_subset_size,
        k_negatives=k_negatives,
        compute_metrics=compute_metrics,
        data_collator=data_collator,
        sigma=sigma
    )

    lr_callback.trainer = trainer
    return trainer



In [20]:
train_ds, val_ds, rl_subset, test_ds, model, data_collator, tok \
    = prepare_for_training("trec", "tiiuae/falcon-rw-1b", "text", "coarse_label", num_virtual_tokens=30)

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [21]:
trainer_trec = prepare_trainer(model, train_ds, val_ds, rl_subset, data_collator, tok, phase1_epochs=15, num_train_epochs=20, per_device_train_batch_size=2, per_device_eval_batch_size=2)

No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [22]:
trainer_trec.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.7524,1.655549,0.320628
2,1.5896,1.30203,0.58296
3,1.0915,0.78859,0.751121
4,0.777,0.64371,0.820628
5,0.6915,0.545119,0.852018
6,0.549,0.479054,0.867713
7,0.4869,0.445813,0.881166
8,0.4417,0.415156,0.88565
9,0.4265,0.408092,0.899103
10,0.3656,0.367892,0.912556



>>> Phase 2 START: reset LR to 0.001 and restart cosine scheduler (12265 steps, 1226 warmup)



TrainOutput(global_step=49060, training_loss=0.4011252948715888, metrics={'train_runtime': 2387.9889, 'train_samples_per_second': 41.089, 'train_steps_per_second': 20.544, 'total_flos': 6454345983528960.0, 'train_loss': 0.4011252948715888, 'epoch': 20.0})

In [23]:
trainer_trec.evaluate(test_ds)

{'eval_loss': 0.2305237352848053,
 'eval_accuracy': 0.961,
 'eval_runtime': 2.7524,
 'eval_samples_per_second': 181.658,
 'eval_steps_per_second': 90.829,
 'epoch': 20.0}

In [32]:
train_ds, val_ds, rl_subset, test_ds, model, data_collator, tok \
    = prepare_for_training("ag_news", "tiiuae/falcon-rw-1b", "text", "label", num_virtual_tokens=30)

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [37]:
trainer_ag_news = prepare_trainer(model, train_ds, val_ds, rl_subset, data_collator, tok, k_negatives=4, num_train_epochs=2, phase1_epochs=1)


No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [38]:
trainer_ag_news.train()


Step,Training Loss,Validation Loss,Accuracy
1000,1.3851,1.374372,0.347731
2000,0.5939,0.392951,0.880924
3000,0.3703,0.336402,0.895966
4000,0.3555,0.317509,0.9
5000,0.3397,0.30345,0.905462
6000,0.339,0.296863,0.904706
7000,0.055,0.288166,0.918908
8000,-0.3535,0.279136,0.92084
9000,-0.394,0.283998,0.920756
10000,-0.3825,0.277851,0.922605



>>> Phase 2 START: reset LR to 0.001 and restart cosine scheduler (6750 steps, 675 warmup)



TrainOutput(global_step=13500, training_loss=0.11256539535522461, metrics={'train_runtime': 1032.3108, 'train_samples_per_second': 209.239, 'train_steps_per_second': 13.077, 'total_flos': 1.4208252125184e+16, 'train_loss': 0.11256539535522461, 'epoch': 2.0})

In [None]:
trainer_ag_news.evaluate(test_ds)

{'eval_loss': 0.279926061630249,
 'eval_accuracy': 0.935921052631579,
 'eval_runtime': 6.0118,
 'eval_samples_per_second': 1264.182,
 'eval_steps_per_second': 39.589,
 'epoch': 2.0}

In [39]:
train, val, rl, test, model, data_collator, tok = \
    prepare_for_training(
        "sst2",
        "tiiuae/falcon-rw-1b", 
        "sentence", 
        "label", 
        num_virtual_tokens=25,
        )



You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [40]:
trainer_sst2 = prepare_trainer(
    model, 
    train, 
    val, 
    rl, 
    data_collator, 
    tok,
    per_device_eval_batch_size=16,
    per_device_train_batch_size=16,
    num_train_epochs=5,
    phase1_epochs=3,
    )
trainer_sst2.train()
trainer_sst2.evaluate(test)

No label_names provided for model class `PeftModelForSequenceClassification`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss,Validation Loss,Accuracy
1000,0.6827,0.679614,0.556745
2000,0.6833,0.663135,0.564883
3000,0.4913,0.395994,0.833911
4000,0.4077,0.33954,0.861794
5000,0.3757,0.310263,0.874604
6000,0.3549,0.297387,0.880784
7000,0.337,0.282504,0.887717
8000,0.323,0.270703,0.889676
9000,0.3142,0.265847,0.890882
10000,0.3128,0.262738,0.894499



>>> Phase 2 START: reset LR to 0.001 and restart cosine scheduler (7578 steps, 757 warmup)



{'eval_loss': 0.22311674058437347,
 'eval_accuracy': 0.9428440366972477,
 'eval_runtime': 1.1099,
 'eval_samples_per_second': 785.688,
 'eval_steps_per_second': 49.556,
 'epoch': 5.0}