# Credits

This notebook is based on the [original fine-tuning example notebook](https://github.com/ShawhinT/YouTube-Blog/blob/main/LLMs/fine-tuning/ft-example.ipynb).

# Import Dependencies

In [1]:
from datasets import load_dataset, DatasetDict, Dataset

from transformers import (
    AutoTokenizer,
    AutoConfig, 
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer)

from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import evaluate
import torch
import numpy as np

# Import Base Model

In [2]:
# Pretrained checkpoint that the classifier is built on
model_checkpoint = 'distilbert-base-uncased'

# Class index -> label name, and the inverse mapping derived from it
id2label = {0: "Negative", 1: "Positive"}
label2id = {name: idx for idx, name in id2label.items()}

# Build a sequence-classification model from the checkpoint, one output per label
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.weight', 'pre_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Display the module tree; the layer names shown (e.g. q_lin) are what LoRA's target_modules refers to
model

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

# Load dataset

In [4]:
# Load the dataset identified by "shawhin/imdb-truncated" (presumably a Hugging Face Hub repo id)
dataset = load_dataset("shawhin/imdb-truncated")
# Show the available splits, their columns and row counts
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['label', 'text'],
        num_rows: 1000
    })
})

In [5]:
# Peek at the first two training examples, separated by a blank line
for count, example in enumerate(dataset["train"]):
    print(example)

    # The second example (index 1) is the last one shown
    if count >= 1:
        break

    print()

{'label': 1, 'text': '. . . or type on a computer keyboard, they\'d probably give this eponymous film a rating of "10." After all, no elephants are shown being killed during the movie; it is not even implied that any are hurt. To the contrary, the master of ELEPHANT WALK, John Wiley (Peter Finch), complains that he cannot shoot any of the pachyderms--no matter how menacing--without a permit from the government (and his tone suggests such permits are not within the realm of probability). Furthermore, the elements conspire--in the form of an unusual drought and a human cholera epidemic--to leave the Wiley plantation house vulnerable to total destruction by the Elephant People (as the natives dub them) to close the story. If you happen to see the current release EARTH, you\'ll detect the Elephant People are faring less well today.'}

{'label': 1, 'text': "During 1933 this film had many cuts taken from it because it was very over the top for the story content and the fact that Lily Powers,

# Preprocess Data

In [6]:
# Create the tokenizer that belongs to the base checkpoint
# NOTE(review): add_prefix_space is usually a byte-level BPE option; confirm it has any effect for this checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# Add a pad token if none exists; the embedding matrix is then resized to cover the new token id
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

In [7]:
# Create tokenize function
def tokenize_function(examples, max_length=512):
    """Tokenize a batch of examples; intended for ``Dataset.map(..., batched=True)``.

    Args:
        examples: Mapping whose "text" entry holds the raw strings of the batch.
        max_length: Maximum number of tokens kept per example. The default of
            512 matches the size of the model's position embedding table.

    Returns:
        The tokenizer output for the batch (e.g. ``input_ids`` and
        ``attention_mask``), truncated but NOT padded. Padding is left to the
        data collator so each batch is only padded to its own longest example.
    """
    # Truncate from the left, so the END of an over-long text is what is kept.
    # Note: this sets an attribute on the shared tokenizer, so it also applies
    # to any later truncating call made with the same tokenizer object.
    tokenizer.truncation_side = "left"

    # No return_tensors here: the sequences in a batch have different lengths,
    # so a rectangular array cannot be built without padding, and Dataset.map
    # stores the values as lists anyway.
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
    )

In [8]:
# Tokenize the training and validation splits in batches; the tokenizer's
# output columns are added next to the original 'label' and 'text' columns
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'input_ids', 'attention_mask'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['label', 'text', 'input_ids', 'attention_mask'],
        num_rows: 1000
    })
})

In [9]:
# Peek at the first two tokenized training examples, separated by a blank line
for count, example in enumerate(tokenized_dataset["train"]):
    print(example)

    # The second example (index 1) is the last one shown
    if count >= 1:
        break

    print()

{'label': 1, 'text': '. . . or type on a computer keyboard, they\'d probably give this eponymous film a rating of "10." After all, no elephants are shown being killed during the movie; it is not even implied that any are hurt. To the contrary, the master of ELEPHANT WALK, John Wiley (Peter Finch), complains that he cannot shoot any of the pachyderms--no matter how menacing--without a permit from the government (and his tone suggests such permits are not within the realm of probability). Furthermore, the elements conspire--in the form of an unusual drought and a human cholera epidemic--to leave the Wiley plantation house vulnerable to total destruction by the Elephant People (as the natives dub them) to close the story. If you happen to see the current release EARTH, you\'ll detect the Elephant People are faring less well today.', 'input_ids': [101, 1012, 1012, 1012, 2030, 2828, 2006, 1037, 3274, 9019, 1010, 2027, 1005, 1040, 2763, 2507, 2023, 15248, 2143, 1037, 5790, 1997, 1000, 2184, 

In [10]:
# Create data collator: pads the examples of each batch to a common length
# at batch time, using the tokenizer's padding settings
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Evaluation metrics

In [11]:
# Load the accuracy evaluation metric
accuracy = evaluate.load("accuracy")

# Define an evaluation function to pass into trainer later
def compute_metrics(p):
    """Compute classification accuracy for one evaluation pass.

    Args:
        p: A ``(predictions, labels)`` pair, where ``predictions`` holds the
            per-class scores (one row per example) and ``labels`` the true
            class ids.

    Returns:
        The dict produced by the accuracy metric, i.e. ``{"accuracy": <float>}``.
    """
    predictions, labels = p
    # Highest-scoring class per example
    predicted_classes = np.argmax(predictions, axis=1)

    # accuracy.compute() already returns {"accuracy": value}. Wrapping it in
    # another {"accuracy": ...} produced a nested dict, so the metric was
    # reported as "{'accuracy': 0.878}" instead of a plain number.
    return accuracy.compute(predictions=predicted_classes, references=labels)

# Apply Untrained Model to Text

In [12]:
# define list of examples
text_list = [
    "It was good.",
    "Not a fan, don't recommed.",
    "Better than the first one.",
    "This is not worth watching even once.",
    "This one is a pass.",
]

# Make sure dropout is disabled so the predictions are deterministic
model.eval()

print("Untrained model predictions:")
print("----------------------------")
# Inference only: no gradients are needed, so skip autograd bookkeeping
with torch.no_grad():
    for text in text_list:
        # Tokenize text into a (1, sequence_length) tensor of token ids
        inputs = tokenizer.encode(text, return_tensors="pt")
        # Compute logits
        logits = model(inputs).logits
        # Convert logits to a class id; reduce over the class dimension
        # explicitly rather than over the flattened tensor
        predictions = torch.argmax(logits, dim=-1)

        print(text + " - " + id2label[predictions.item()])

Untrained model predictions:
----------------------------
It was good. - Negative
Not a fan, don't recommed. - Negative
Better than the first one. - Negative
This is not worth watching even once. - Negative
This one is a pass. - Negative


# Fine-tuning with LoRA

In [13]:
# LoRA adapter configuration
peft_config = LoraConfig(
    task_type="SEQ_CLS",  # Sequence classification task
    target_modules=["q_lin"],  # Apply LoRA to the attention query projections only
    r=4,  # Rank of the low-rank update matrices
    lora_alpha=32,  # Scaling factor applied to the low-rank update
    lora_dropout=0.01,  # Dropout probability inside the LoRA layers
)

In [14]:
# Wrap the base model with the adapters described by peft_config.
# NOTE: this rebinds `model`, so re-running this cell would pass the already
# wrapped model back into get_peft_model; restart and run from the top instead.
model = get_peft_model(model, peft_config)
# Report how many parameters are trainable compared with the total
model.print_trainable_parameters()

trainable params: 628,994 || all params: 67,584,004 || trainable%: 0.9306847223789819


## Hyperparameters

In [15]:
# Values a reader is most likely to tune
lr = 1e-3  # Learning rate: size of each optimization step
batch_size = 4  # Number of examples processed per optimization step (per device)
num_epochs = 5  # Number of full passes over the training data

training_args = TrainingArguments(
    # Checkpoints are written under a directory named after the base checkpoint
    output_dir=f"{model_checkpoint}-lora-text-classification",
    # Optimization
    num_train_epochs=num_epochs,
    learning_rate=lr,
    weight_decay=0.01,
    # Batching (same size for training and evaluation)
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate and checkpoint once per epoch, then reload the best checkpoint
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

## Train

In [16]:
# Assemble the trainer from the pieces defined in the cells above
trainer = Trainer(
    # What to train and how
    model=model,  # The PEFT-wrapped model
    args=training_args,  # Hyperparameters and checkpointing settings
    # Data
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    # Batching: the collator pads each batch to the length of its longest example
    tokenizer=tokenizer,
    data_collator=data_collator,
    # Metric reported at every evaluation
    compute_metrics=compute_metrics,
)

# Run fine-tuning; the returned training summary is shown as the cell output
trainer.train()


You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.355744,{'accuracy': 0.878}
2,0.434000,0.593929,{'accuracy': 0.862}
3,0.434000,0.568574,{'accuracy': 0.881}
4,0.182900,0.664931,{'accuracy': 0.881}
5,0.182900,0.688521,{'accuracy': 0.88}


Checkpoint destination directory distilbert-base-uncased-lora-text-classification/checkpoint-250 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=1250, training_loss=0.25873106002807617, metrics={'train_runtime': 8348.8409, 'train_samples_per_second': 0.599, 'train_steps_per_second': 0.15, 'total_flos': 556608875967840.0, 'train_loss': 0.25873106002807617, 'epoch': 5.0})

In [17]:
model.to("cpu") # Moving to cpu
# Switch to inference mode explicitly (disables dropout, including LoRA dropout)
# instead of relying on whatever mode training left the model in
model.eval()

print("Trained model predictions:")
print("--------------------------")
# Inference only: no gradients are needed, so skip autograd bookkeeping
with torch.no_grad():
    for text in text_list:
        # Tokenize text into a (1, sequence_length) tensor on the same device as the model
        inputs = tokenizer.encode(text, return_tensors="pt").to("cpu") # Moving to cpu

        logits = model(inputs).logits
        # Pick the highest-scoring class along the class dimension
        predictions = torch.argmax(logits, dim=-1)

        print(text + " - " + id2label[predictions.item()])


Trained model predictions:
--------------------------
It was good. - Positive
Not a fan, don't recommed. - Negative
Better than the first one. - Positive
This is not worth watching even once. - Positive
This one is a pass. - Positive
