## <center>How to adapt a pre-trained model to do sentiment classification <br> using SpaRTA</center>

### Setup 

Specify the pre-trained base model and the output directory where the adapter will be saved.

In [1]:
import os

In [2]:
# pre-trained base model 
model_id = 'google/gemma-2b' 

# dir path for saving SpaRTA adapter
home_dir = os.environ['HOME']
save_dir = os.path.join(home_dir, 'sparta_examples/output/classification_model/')
os.makedirs(save_dir, exist_ok=True)  # make sure the directory exists before files are saved to it

In [None]:
print(save_dir)

### Dataset

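We load the SST-2 dataset from the GLUE benchmark. It consists of sentences from movie reviews, each labeled as positive or negative. We drop the unused `idx` column, rename `label` to `labels` (the argument name the model's forward pass expects), and delete the `test` split, whose labels are not public.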

In [3]:
from datasets import load_dataset

In [4]:
# load task (classification) dataset
dataset = load_dataset('glue', 'sst2')
dataset = dataset.remove_columns('idx')
dataset = dataset.rename_column('label', 'labels')
del dataset['test']

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'labels'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'labels'],
        num_rows: 872
    })
})

SST-2 text is pre-tokenized, so it contains spurious spaces before punctuation and inside contractions (e.g. `do n't`, `woman 's`). We repair these "broken" spaces with a regular expression.

In [6]:
dataset['train'][1,8,11,19]['sentence']

['contains no wit , only labored gags ',
 "a depressed fifteen-year-old 's suicidal poetry ",
 "for those moviegoers who complain that ` they do n't make movies like they used to anymore ",
 "swimming is above all about a young woman 's face , and by casting an actress whose face projects that woman 's doubts and yearnings , it succeeds . "]

In [7]:
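import re 

# Spurious spaces to remove (the captured text is kept, the extra space is dropped):
#  - before punctuation or a contraction/possessive: " ," -> ",", " 's " -> "'s ", " n't " -> "n't "
#  - after an opening parenthesis: "( " -> "("
#  - in plural possessives: "s ' " -> "s' "
#  - between a dollar sign and a number: " $ 5" -> " $5"
bad_spaces = re.compile(r" ([.,)!:;%]|'(?:s|t|re|ve|m|ll|d) |n't )|(\() |(s) (' )|( \$) (\d)")

def repair_spaces(example):
    # unmatched groups are replaced by '', so only the alternative that matched contributes
    example['sentence'] = bad_spaces.sub(r"\1\2\3\4\5\6", example['sentence']).strip()
    return example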

In [8]:
dataset = dataset.map(repair_spaces)

In [9]:
dataset['train'][1,8,11,19]['sentence']

['contains no wit, only labored gags',
 "a depressed fifteen-year-old's suicidal poetry",
 "for those moviegoers who complain that ` they don't make movies like they used to anymore",
 "swimming is above all about a young woman's face, and by casting an actress whose face projects that woman's doubts and yearnings, it succeeds."]
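To see what each alternative in `bad_spaces` does, here is a self-contained check on made-up inputs, one per alternative. `repair_text` is a hypothetical string-level version of `repair_spaces`, using the same pattern and substitution.

```python
import re

# Same pattern and substitution as the notebook cell above.
bad_spaces = re.compile(r" ([.,)!:;%]|'(?:s|t|re|ve|m|ll|d) |n't )|(\() |(s) (' )|( \$) (\d)")


def repair_text(text: str) -> str:
    """Hypothetical helper: apply the space repair to a single string."""
    return bad_spaces.sub(r"\1\2\3\4\5\6", text).strip()


# Made-up inputs, one per alternative in the pattern.
examples = {
    "no wit , only gags": "no wit, only gags",      # space before punctuation
    "a woman 's face": "a woman's face",              # space inside a possessive
    "they do n't make": "they don't make",            # space before n't
    "a ( brief ) cameo": "a (brief) cameo",           # spaces inside parentheses
    "the actors ' work": "the actors' work",          # plural possessive
    "costs $ 5 million": "costs $5 million",          # space between $ and a digit
    "it succeeds . ": "it succeeds.",                 # trailing space removed by strip()
}

for raw, expected in examples.items():
    print(f"{raw!r} -> {repair_text(raw)!r}")
```

Opening quotes written as a backtick, as in the third sentence above, are not covered by the pattern and stay as they are.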

Let's have a look at the labels: this is a binary classification task.

In [10]:
dataset['train'][10]

{'sentence': 'goes to absurd lengths', 'labels': 0}

In [11]:
dataset['train'].features['labels']

ClassLabel(names=['negative', 'positive'])

In [12]:
id2label = {
    class_id: class_label
    for class_id, class_label in enumerate(
        dataset['train'].features['labels'].names
    )
}
id2label

{0: 'negative', 1: 'positive'}

In [13]:
num_classes = dataset['train'].features['labels'].num_classes 
num_classes

2
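Before training it is worth checking the class balance, since accuracy is only meaningful relative to the majority-class baseline (always predicting the most frequent label). A minimal sketch with a hypothetical helper `label_report`, run here on a toy label list with the 428 negative / 444 positive split of the SST-2 validation set (the split implied by the confusion matrix reported at the end of this notebook):

```python
from collections import Counter


def label_report(labels, id2label):
    """Return per-class counts (by label name) and the majority-class accuracy baseline."""
    counts = Counter(labels)
    total = sum(counts.values())
    named = {id2label[class_id]: counts.get(class_id, 0) for class_id in id2label}
    # accuracy of a classifier that always predicts the most frequent class
    baseline = max(counts.values()) / total
    return named, baseline


id2label = {0: 'negative', 1: 'positive'}
toy_labels = [0] * 428 + [1] * 444  # same class split as the SST-2 validation set
counts, baseline = label_report(toy_labels, id2label)
print(counts)
print(f"majority-class baseline: {baseline:.4f}")
```

In the notebook itself one would pass `dataset['validation']['labels']` and the `id2label` dict built above.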

### Load pre-trained (base) model

As the base model for our adapter we use Gemma 2B, a pre-trained decoder-only transformer, with a sequence classification head on top.

In [14]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [15]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, 
    num_labels=num_classes,
    id2label=id2label,
    )

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token
model.config.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = 'left' # makes sure last token is an actual text token, not a padding token
tokenizer.truncation_side = 'left'  # when truncating, drop tokens from the start and keep the end of the text


GemmaForSequenceClassification LOAD REPORT from: google/gemma-2b
Key          | Status  | 
-------------+---------+-
score.weight | MISSING | 

Notes:
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
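The `padding_side = 'left'` setting above rests on how decoder-only models classify a sequence: the classification head reads the hidden state of the last non-padding token (roughly what Hugging Face's decoder-only `...ForSequenceClassification` classes do). The pure-Python sketch below uses made-up token ids and a hypothetical `PAD_ID` to show that with left padding this token always sits at the final position, whereas with right padding the final position holds a pad.

```python
PAD_ID = 1  # hypothetical pad id (the notebook reuses the EOS token for padding)


def pad(ids, length, side):
    """Pad a list of token ids to `length` on the given side ('left' or 'right')."""
    padding = [PAD_ID] * (length - len(ids))
    return padding + ids if side == 'left' else ids + padding


def last_real_index(ids):
    """Position of the last non-padding token, the one a last-token classifier reads."""
    return max(i for i, tok in enumerate(ids) if tok != PAD_ID)


sentence = [2, 651, 7344, 4592]  # made-up token ids for one short sentence
for side in ('left', 'right'):
    ids = pad(sentence, 6, side)
    print(side, ids,
          '| last real token at', last_real_index(ids),
          '| final position is real:', ids[-1] != PAD_ID)
```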


In [16]:
model.num_labels

2

In [17]:
model.config.id2label

{0: 'negative', 1: 'positive'}

Since the classification head is randomly initialized, we *save* its initial weights so that we have a record of all the parameters of the adapter's base model before fine-tuning begins.

In [18]:
print('saving weight initialization for classification head')
head_init = model.score.weight.detach().clone()  # independent copy, unaffected by later updates to the head
head_init

saving weight initialization for classification head


tensor([[-0.0267, -0.0111,  0.0249,  ..., -0.0620, -0.0033,  0.0023],
        [-0.0193,  0.0204,  0.0153,  ...,  0.0135,  0.0015,  0.0271]],
       dtype=torch.bfloat16)

In [19]:
import torch

In [20]:
torch.save(head_init, 
           os.path.join(save_dir,'head_init.pt'))         

In [21]:
import safetensors.torch  # the torch helpers (save_file, load_file) live in this submodule

In [22]:
safetensors.torch.save_file(
    {'head': head_init.contiguous()}, 
    os.path.join(save_dir,'head_init.safetensors'))

In [None]:
# checking head weights were saved correctly

In [23]:
safetensors.torch.load_file(
    os.path.join(save_dir, 'head_init.safetensors'), 
    device='cpu')

{'head': tensor([[-0.0267, -0.0111,  0.0249,  ..., -0.0620, -0.0033,  0.0023],
         [-0.0193,  0.0204,  0.0153,  ...,  0.0135,  0.0015,  0.0271]],
        dtype=torch.bfloat16)}

In [24]:
torch.load(
    os.path.join(save_dir, 'head_init.pt'), 
    map_location='cpu', 
    weights_only=True)

tensor([[-0.0267, -0.0111,  0.0249,  ..., -0.0620, -0.0033,  0.0023],
        [-0.0193,  0.0204,  0.0153,  ...,  0.0135,  0.0015,  0.0271]],
       dtype=torch.bfloat16)

### Create SpaRTA adapter

Instead of full fine-tuning, we wrap the model with SpaRTA to significantly reduce the number of trainable parameters. We choose a sparsity of 99%, so only about 1% of the model parameters will be updated during training. This should also help prevent overfitting on our small training dataset, as we will see.

In [25]:
from peft_sparta import SpaRTA

In [26]:
model = SpaRTA(model, 0.99).to('cuda')

In [27]:
model

SparseModel(sparsity=0.99, frozen_modules=['embed_tokens', 'self_attn.q', 'self_attn.k', 'mlp', 'norm'], dropout_rate=0.0,
  GemmaForSequenceClassification(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNor

In [28]:
model.num_trainable_parameters()

Num trainable parameters: 25,067,043 (1.00021%)
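The idea of updating only a small fraction of the weights can be sketched in pure Python. This is a minimal illustration, not peft_sparta's implementation: it assumes a random subset of entries is chosen once before training and uses plain SGD (the notebook's `Trainer` uses its default AdamW-style optimizer). The helpers `select_trainable` and `sparse_sgd_step` are hypothetical. With `sparsity=0.99`, 10 of 1,000 toy weights are trainable, and the accumulated changes form a sparse set of deltas.

```python
import random


def select_trainable(num_params, sparsity, seed=0):
    """Hypothetical helper: pick a random subset of parameter indices to train."""
    k = round((1.0 - sparsity) * num_params)
    rng = random.Random(seed)
    return sorted(rng.sample(range(num_params), k))


def sparse_sgd_step(weights, grads, trainable, deltas, lr):
    """Update only the selected entries and accumulate their change in `deltas`."""
    for i in trainable:
        step = -lr * grads[i]
        weights[i] += step
        deltas[i] = deltas.get(i, 0.0) + step


base = [0.1 * (i % 7) for i in range(1000)]  # toy "pre-trained" weights
weights = list(base)
trainable = select_trainable(len(base), sparsity=0.99)
deltas = {}
grads = [1.0] * len(base)  # toy constant gradient
for _ in range(3):
    sparse_sgd_step(weights, grads, trainable, deltas, lr=0.01)

changed = sum(1 for w, b in zip(weights, base) if w != b)
print(len(trainable), changed, len(deltas))
```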


### Training SpaRTA

We use the Hugging Face `Trainer` to fine-tune our SpaRTA adapter.

In [29]:
from transformers import Trainer, TrainingArguments

We pre-tokenize the dataset once before training for efficiency, padding or truncating every sentence to 64 tokens.

In [30]:
def tokenize(examples):
    return tokenizer(examples['sentence'], truncation=True, padding='max_length', max_length=64)

In [31]:
tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns('sentence')
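The `tokenize` function relies on the tokenizer settings from earlier (`padding_side = 'left'`, `truncation_side = 'left'`). The hypothetical helper `encode_fixed_length` below sketches roughly what `padding='max_length'` with `truncation=True` then produces for a list of token ids. It ignores special tokens such as BOS, which the real tokenizer accounts for.

```python
def encode_fixed_length(ids, max_length, pad_id, padding_side='left', truncation_side='left'):
    """Hypothetical helper: truncate and pad token ids to exactly `max_length`.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens, 0 for padding.
    """
    if len(ids) > max_length:
        # 'left' truncation drops tokens from the start and keeps the end of the text
        ids = ids[-max_length:] if truncation_side == 'left' else ids[:max_length]
    n_pad = max_length - len(ids)
    mask = [1] * len(ids)
    if padding_side == 'left':
        return [pad_id] * n_pad + ids, [0] * n_pad + mask
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


print(encode_fixed_length([5, 6, 7], max_length=5, pad_id=0))           # short: padded on the left
print(encode_fixed_length(list(range(10, 18)), max_length=5, pad_id=0))  # long: start truncated
```

Padding every example to the same length also lets the default data collator stack batches without further padding.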

We define accuracy as the metric for evaluating our classification model on the validation split.

In [32]:
def classification_accuracy(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    return {'accuracy': (preds == labels).mean()}

In [33]:
sft_config = TrainingArguments(
    output_dir=save_dir,
    num_train_epochs=1,
    learning_rate=2e-5,
    per_device_train_batch_size=40,
    gradient_accumulation_steps=1,
    weight_decay=0.02,
    logging_strategy="steps",
    logging_steps=100,
    logging_first_step=True,
    batch_eval_metrics=False,
    eval_strategy="steps",
    eval_on_start=True,
    eval_steps=100,
    label_names=["labels"],
    per_device_eval_batch_size=40,
    gradient_checkpointing=False,
    save_strategy="no",
    load_best_model_at_end=False,
    report_to="none",
    push_to_hub=False,
    remove_unused_columns=False,
    )

In [None]:
# https://huggingface.co/docs/transformers/v5.1.0/en/main_classes/trainer#transformers.TrainingArguments
# sft_config

In [34]:
trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    args=sft_config,
    compute_metrics=classification_accuracy,
    )

In [35]:
trainer.evaluate()

{'eval_loss': 1.5605647563934326,
 'eval_model_preparation_time': 0.0032,
 'eval_accuracy': 0.5091743119266054,
 'eval_runtime': 2.111,
 'eval_samples_per_second': 413.072,
 'eval_steps_per_second': 10.422}

In [36]:
trainer.train()

| Step | Training Loss | Validation Loss | Model Preparation Time | Accuracy |
|---|---|---|---|---|
| 0 | No log | 1.560565 | 0.0032 | 0.509174 |
| 100 | 0.537524 | 0.184857 | 0.0032 | 0.942661 |
| 200 | 0.208059 | 0.158157 | 0.0032 | 0.955275 |
| 300 | 0.181113 | 0.142214 | 0.0032 | 0.955275 |
| 400 | 0.163779 | 0.141742 | 0.0032 | 0.959862 |
| 500 | 0.156539 | 0.138036 | 0.0032 | 0.955275 |
| 600 | 0.144659 | 0.130729 | 0.0032 | 0.955275 |
| 700 | 0.138549 | 0.132415 | 0.0032 | 0.961009 |
| 800 | 0.133481 | 0.139374 | 0.0032 | 0.955275 |
| 900 | 0.139922 | 0.138483 | 0.0032 | 0.956422 |


TrainOutput(global_step=1684, training_loss=0.16799183752644373, metrics={'train_runtime': 335.2536, 'train_samples_per_second': 200.89, 'train_steps_per_second': 5.023, 'total_flos': 0.0, 'train_loss': 0.16799183752644373, 'epoch': 1.0})
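The numbers in `TrainOutput` follow from the training settings. A short worked check, assuming a single GPU (so the effective batch is `per_device_train_batch_size * gradient_accumulation_steps`) and that the last, partial batch is kept:

```python
import math

num_train_examples = 67349   # size of the train split (see the DatasetDict above)
per_device_batch = 40        # per_device_train_batch_size
num_devices = 1              # assumption: single GPU

# with gradient_accumulation_steps=1, every batch is one optimizer step
steps_per_epoch = math.ceil(num_train_examples / (per_device_batch * num_devices))
print(steps_per_epoch)  # matches global_step in TrainOutput

train_runtime = 335.2536     # seconds, from TrainOutput
print(f"{num_train_examples / train_runtime:.2f} samples/s")
print(f"{steps_per_epoch / train_runtime:.3f} steps/s")
```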

In [37]:
trainer.evaluate()

{'eval_loss': 0.1213647648692131,
 'eval_model_preparation_time': 0.0032,
 'eval_accuracy': 0.9701834862385321,
 'eval_runtime': 1.2715,
 'eval_samples_per_second': 685.79,
 'eval_steps_per_second': 17.302,
 'epoch': 1.0}

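Note how the *accuracy* of the model on the validation split goes from 50.9% before training, which equals the majority-class baseline (444 of the 872 validation sentences are positive), to 97.0% after one epoch.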

### Inference (classify) 

Let's test the SpaRTA-adapted model on a single sentence.

In [38]:
input_text = "It was a great movie" 
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=128).to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
pred_class_id = logits.argmax().item()
print('class id prediction:', pred_class_id)
dataset['train'].features['labels'].int2str(pred_class_id)

class id prediction: 1


'positive'

### Save the SpaRTA adapter to disk for later use

In [39]:
os.listdir(save_dir)

['head_init.safetensors', 'head_init.pt']

In [None]:
model.save(save_dir, merged=False)

In [42]:
os.listdir(save_dir)

['head_init.safetensors',
 'config.json',
 'sparse_deltas.safetensors',
 'head_init.pt']
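The saved files suggest that the adapter is stored as sparse deltas (`sparse_deltas.safetensors`) kept separately from the base weights. Assuming that reloading adds these deltas to the base model's weights, a randomly initialized head can only be reconstructed if its initial values were recorded, which is presumably why `head_init` was saved earlier. A toy sketch; `apply_sparse_deltas` and the `{index: delta}` format are hypothetical, not peft_sparta's file layout:

```python
import random


def apply_sparse_deltas(base, deltas):
    """Hypothetical helper: return base weights with a sparse {index: delta} update added."""
    merged = list(base)
    for i, d in deltas.items():
        merged[i] += d
    return merged


rng = random.Random(0)
head_init = [rng.gauss(0.0, 0.02) for _ in range(8)]  # toy random head initialization
deltas = {1: 0.05, 6: -0.03}                           # toy fine-tuning update on 2 of 8 entries
trained = apply_sparse_deltas(head_init, deltas)

# Reloading with the saved initialization reproduces the trained head...
print(apply_sparse_deltas(head_init, deltas) == trained)
# ...while a fresh random initialization gives different weights.
fresh_rng = random.Random(1)
fresh_init = [fresh_rng.gauss(0.0, 0.02) for _ in range(8)]
print(apply_sparse_deltas(fresh_init, deltas) == trained)
```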

### Reload the SpaRTA adapter 

You can use the `SpaRTAforSequenceClassification` class to load the saved adapter and do inference (classification) using the `classify` and `decide_class` methods.

In [43]:
del model

In [44]:
from peft_sparta import SpaRTAforSequenceClassification

In [45]:
model = SpaRTAforSequenceClassification(
   adapter = save_dir,
   device = 'cuda')


In [46]:
print(model)

(SpaRTA)ModelForSeqClassification(
	adapter = '/u/jriosal/sparta_examples/output/classification_model/'
	model = 'google/gemma-2b'
	id2label = {0: 'negative', 1: 'positive'}
)


Let's test the adapted model on a couple of sentences.

In [47]:
sentences = ["The movie was great", 
             "I hate that movie"]

In [48]:
model.classify(sentences)

tensor([[0.0032, 0.9968],
        [0.9986, 0.0014]], device='cuda:0')

In [52]:
model.decide_class(sentences)

['positive', 'negative']
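Each row returned by `classify` sums to 1, so it appears to contain per-class probabilities, and `decide_class` returns the most probable label. Assuming the probabilities come from a softmax over the logits (the source does not say so explicitly), the two steps can be sketched with made-up logits and hypothetical helpers:

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decide(prob_rows, id2label):
    """Map each row of class probabilities to the label with the highest probability."""
    return [id2label[max(range(len(row)), key=row.__getitem__)] for row in prob_rows]


id2label = {0: 'negative', 1: 'positive'}
toy_logits = [[-2.9, 2.84], [3.3, -3.3]]  # made-up logits for two sentences
probs = [softmax(row) for row in toy_logits]
print([[round(p, 4) for p in row] for row in probs])
print(decide(probs, id2label))
```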

The adapted model predicts the correct sentiment for both sentences.

### Evaluation 

We can also evaluate the performance of the adapter on a labeled dataset with the `evaluate` method.

In [53]:
eval_sentences = dataset['validation']['sentence']
eval_labels = dataset['validation']['labels']
model.evaluate(eval_sentences, eval_labels, batch_size=128)

{'loss': 0.1213341415475268,
 'accuracy': 0.9690366972477065,
 'confusion matrix': tensor([[416,  12],
         [ 15, 429]], dtype=torch.int32),
 'balanced accuracy': 0.9690893888473511,
 'MCC': 0.9380825757980347,
 'F1-score': 0.9694915413856506}
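All of these metrics can be derived from the confusion matrix. As a sanity check, the stdlib sketch below recomputes them, assuming rows are true labels and columns are predictions, and that the F1-score refers to the positive class; both assumptions are consistent with the reported values, which it reproduces to within float32 rounding.

```python
import math

# Confusion matrix from the `evaluate` output above
# (assumed: rows = true labels, columns = predicted labels; 0 = negative, 1 = positive).
cm = [[416, 12],
      [15, 429]]

tn, fp = cm[0]
fn, tp = cm[1]
total = tn + fp + fn + tp

accuracy = (tp + tn) / total
recall_pos = tp / (tp + fn)   # sensitivity
recall_neg = tn / (tn + fp)   # specificity
balanced_accuracy = (recall_pos + recall_neg) / 2
precision_pos = tp / (tp + fp)
f1 = 2 * precision_pos * recall_pos / (precision_pos + recall_pos)
mcc = (tp * tn - fp * fn) / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

for name, value in [('accuracy', accuracy), ('balanced accuracy', balanced_accuracy),
                    ('F1-score', f1), ('MCC', mcc)]:
    print(f'{name}: {value:.4f}')
```

Balanced accuracy and MCC stay informative even when the classes are imbalanced; here the validation split is nearly balanced, so they are close to plain accuracy.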