<a href="https://colab.research.google.com/github/nirb28/llm/blob/main/distillation/data_prep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Language Model Distillation

<img src="https://arxiv.org/html/2402.13116v3/x2.png" width=600>

Model distillation is the process of using a large foundation model (a high-parameter LLM) to create annotated data for a specific task. That data is then used to fine-tune a lightweight language model on the same task, allowing the smaller model to perform comparably to the foundation model at a fraction of the cost, energy consumption, and time.
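To make the two stages concrete, here is a toy, self-contained sketch of the pipeline. The keyword-rule `teacher_label` and the word-count "student" are hypothetical stand-ins for Llama 3.1 405B and RoBERTa-base, not the notebook's actual models. The point is only the data flow: the teacher labels unlabeled text, and the student learns exclusively from those labels.

```python
from collections import Counter, defaultdict

def teacher_label(text: str) -> str:
    """Hypothetical stand-in for the expensive teacher LLM: a tiny keyword rule."""
    lowered = text.lower()
    if any(word in lowered for word in ("love", "beautiful", "happy")):
        return "positive"
    if any(word in lowered for word in ("forgot", "regret", "sick")):
        return "negative"
    return "neutral"

def annotate(texts):
    """Step 1: the teacher labels unlabeled texts, producing (text, label) training pairs."""
    return [(text, teacher_label(text)) for text in texts]

def train_student(pairs):
    """Step 2: a cheap student learns from the teacher's labels.

    Here the 'model' is just per-word label counts, a toy stand-in for fine-tuning RoBERTa.
    """
    word_counts = defaultdict(Counter)
    for text, label in pairs:
        for word in text.lower().split():
            word_counts[word][label] += 1
    return word_counts

def student_predict(model, text, default="neutral"):
    """Sum the label counts of every known word and return the most common label."""
    votes = Counter()
    for word in text.lower().split():
        votes.update(model.get(word, Counter()))
    return votes.most_common(1)[0][0] if votes else default

# Usage: the student never sees human labels, only the teacher's annotations
unlabeled = ["I love this beautiful day", "I forgot my keys again", "Going to the store"]
pairs = annotate(unlabeled)
student = train_student(pairs)
print(student_predict(student, "what a beautiful morning"))  # → positive
```

In the real notebook, step 1 is the LLM annotation script further down and step 2 is fine-tuning RoBERTa-base on the resulting CSV.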

Per the paper [A Survey on Knowledge Distillation of Large Language Models](https://arxiv.org/pdf/2402.13116), "This process is akin to transferring the ‘knowledge’ of a highly skilled teacher to a student, wherein the student (e.g., open-source LLM) learns to mimic the performance characteristics of the teacher (e.g., proprietary LLM)."

Most tasks that LLMs are applied to don't use the full capability of a full-size foundation model, so *why not distill your one specific application into its own model?*

**In this notebook we'll be:**
1. Using [Llama 3.1 405B](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B) to classify the sentiment of tweets, and
2. Using that dataset to train [Roberta-base](https://huggingface.co/FacebookAI/roberta-base), a 125 million parameter language model.

We end up with a model that performs with the same accuracy as the teacher at 0.03% of its size!
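The 0.03% figure follows directly from the parameter counts in the two model names. A quick check, assuming ~125M parameters for RoBERTa-base and ~405B for Llama 3.1 405B:

```python
# Approximate parameter counts implied by the model names
student_params = 125e6   # RoBERTa-base: ~125 million
teacher_params = 405e9   # Llama 3.1 405B: ~405 billion

ratio_percent = student_params / teacher_params * 100
size_factor = teacher_params / student_params

print(f"Student is {ratio_percent:.3f}% of the teacher's size")  # → Student is 0.031% of the teacher's size
print(f"Teacher is {size_factor:,.0f}x larger")                  # → Teacher is 3,240x larger
```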

---
# Teacher Model Data Annotation

The dataset, [mteb/tweet_sentiment_extraction](https://huggingface.co/datasets/mteb/tweet_sentiment_extraction), consists of tweets with labeled sentiments. We will use the tweet texts, together with prompting, to generate the "knowledge" annotations. This data becomes the training data for distillation.

If the data were not generated by an LLM teacher (i.e., if it were human-annotated), this would simply be ordinary fine-tuning/model training!

For the sake of demonstration, we assume this data is not already annotated.

### Importing Existing Dataset

In [None]:
# use hf datasets package to get the twitter sentiment data
from datasets import load_dataset
ds = load_dataset("mteb/tweet_sentiment_extraction")

In [None]:
ds['train'][0]
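Each row carries both a numeric `label` and a human-readable `label_text`. Rather than assuming which integer means which sentiment, a small helper (hypothetical, not part of the `datasets` library) can derive the mapping from the rows and fail loudly if it is ever inconsistent. Iterating a Hugging Face `Dataset` yields one dict per row, so `label_mapping(ds['train'])` should work directly; the sample rows below are made up for illustration, including their numeric labels.

```python
def label_mapping(rows):
    """Map each numeric `label` to its `label_text`, raising if a label is ever inconsistent."""
    mapping = {}
    for row in rows:
        label, text = row["label"], row["label_text"]
        if mapping.setdefault(label, text) != text:
            raise ValueError(f"label {label} maps to both {mapping[label]!r} and {text!r}")
    return dict(sorted(mapping.items()))

# Hypothetical rows in the dataset's schema (id, text, label, label_text); numbers are illustrative
sample_rows = [
    {"id": "r1", "text": "Such a beautiful morning", "label": 2, "label_text": "positive"},
    {"id": "r2", "text": "off to glue stuff onto poster", "label": 1, "label_text": "neutral"},
    {"id": "r3", "text": "wished didnt spend money last night", "label": 0, "label_text": "negative"},
]
print(label_mapping(sample_rows))
```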

### Creating CSV

We convert a subset to a local CSV for easy manipulation and sampling.

The resulting train/test subset is also available as a Hugging Face dataset: [AdamLucek/twittersentiment-llama-3.1-405B-labels](https://huggingface.co/datasets/AdamLucek/twittersentiment-llama-3.1-405B-labels)

In [None]:
import csv
import os
from common.utils import get_project_root  # project-local helper that returns the repo root

def convert_to_csv(ds, start_index, end_index, output_folder):
    """Write rows [start_index, end_index) of ds['train'] to a CSV file (end_index is exclusive)."""
    os.makedirs(output_folder, exist_ok=True)  # create the target folder if it does not exist yet
    output_file = os.path.join(output_folder, f"twittersentiment_{start_index}_{end_index}.csv")

    with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
        fieldnames = ['id', 'text', 'label', 'label_text']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

        writer.writeheader()
        # Clamp end_index so we never index past the end of the split
        for i in range(start_index, min(end_index, len(ds['train']))):
            row = ds['train'][i]
            writer.writerow(row)

    print(f"CSV file created: {output_file}")

# Usage
start_index = 5001
end_index = 6000
output_folder = os.path.join(str(get_project_root()), "data", "tweet_sentiment")
convert_to_csv(ds, start_index, end_index, output_folder)
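One detail of `convert_to_csv`: `range(start_index, end_index)` stops before `end_index`, so the call above writes rows 5001 through 5999 (999 rows), even though the file name ends in `_6000`. The self-contained sketch below shows the same `csv.DictWriter` pattern against an in-memory buffer; `rows_to_csv_text` is a hypothetical helper that works on a list of dicts (slicing a Hugging Face `Dataset` instead returns a dict of columns).

```python
import csv
import io

FIELDNAMES = ["id", "text", "label", "label_text"]

def rows_to_csv_text(rows, start_index, end_index):
    """Render rows[start_index:end_index] (end exclusive) as CSV text; return (text, row_count)."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=FIELDNAMES)
    writer.writeheader()
    selected = rows[start_index:end_index]  # slicing clamps end_index, like min(...) above
    writer.writerows(selected)
    return buffer.getvalue(), len(selected)

# Hypothetical in-memory rows standing in for ds['train']
rows = [{"id": str(i), "text": f"tweet {i}", "label": 1, "label_text": "neutral"} for i in range(10)]
text, count = rows_to_csv_text(rows, 5, 8)
print(count)                   # → 3 (rows 5, 6 and 7)
print(len(range(5001, 6000)))  # → 999 rows for the notebook's 5001/6000 bounds
```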

### Cleaning the CSV

A simple script that drops blank values, which may otherwise break processing later on.

In [None]:
import pandas as pd
from common.utils import get_project_root
# Load the CSV file
unclean_df = pd.read_csv(str(get_project_root())+'/data/tweet_sentiment/twittersentiment_5001_6000.csv')

# Remove rows where the 'text' column is missing or blank (whitespace only)
df_cleaned = unclean_df[unclean_df['text'].notna() & (unclean_df['text'].str.strip() != '')]

# Save the cleaned DataFrame to a new CSV file
df_cleaned.to_csv(str(get_project_root())+"/data/tweet_sentiment/twittersentiment_5001_6000_cleaned.csv", index=False)
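To see why both conditions in the filter matter, here is the same mask applied to a tiny hypothetical CSV. By default `read_csv` stores an empty cell as NaN (caught by `notna()`), while a whitespace-only tweet is not treated as missing, so it is the `str.strip() != ''` check that removes it.

```python
import io
import pandas as pd

# Hypothetical CSV: one real tweet, one empty cell, one whitespace-only tweet
raw_csv = 'id,text\na,great day\nb,\nc,"   "\n'
toy = pd.read_csv(io.StringIO(raw_csv))

# notna() drops the NaN from the empty cell; strip() != "" drops the whitespace-only text
mask = toy["text"].notna() & (toy["text"].str.strip() != "")
cleaned = toy[mask]
print(cleaned)
```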

---
# Setting Up Teacher LLM

Meta's [release blog](https://ai.meta.com/blog/meta-llama-3-1/) for Llama 3.1 405B explicitly states that *Our new model will enable the community to unlock new workflows, such as synthetic data generation and model distillation*, so I wanted to use it as an example of distillation on this task. Llama 3.1 405B therefore becomes the teacher model.

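The cheapest inference API I could find is [Fireworks.ai](https://fireworks.ai/), which, at the time of writing, charges `$3/$3` per 1M `Input/Output` tokens. The cells below, however, instantiate the teacher through LangChain's Azure OpenAI (`gpt-35-turbo`) and Groq (`llama3-70b-8192`) integrations instead of Fireworks; whichever of the two cells runs last defines `llm`.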

In [None]:
import os
from langchain_openai import AzureChatOpenAI
from dotenv import load_dotenv

# Read credentials from a local .env file
load_dotenv()
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")

# Teacher model served via Azure OpenAI
llm = AzureChatOpenAI(
    api_version="2024-02-01",
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_key=AZURE_OPENAI_API_KEY,
    model="gpt-35-turbo",
    temperature=0.7
)


In [None]:
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os

# Alternative teacher served via Groq; running this cell replaces the Azure `llm` above
load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama3-70b-8192",
    temperature=0.7
)


### Prompting

To give our teacher model the best chance of labeling correctly, we'll employ two techniques in our classification prompt:
1. **Chain-of-Thought Reasoning:** Having the language model write out its reasoning to "think" through the problem before giving an answer.
2. **Few-shot Prompting:** Providing representative examples that show the expected reasoning and output format, to better guide the LLM.

The examples are taken from rows 7101-7200, which will not be used for training or testing later on.
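The prompt in the next cell hard-codes its few-shot examples, with every literal brace doubled (`{{ }}`) because LangChain's `ChatPromptTemplate` parses templates with f-string syntax by default, where single braces mark variables such as `{text}`. If you want to swap examples in and out, a hypothetical helper like `build_few_shot_block` can render them in the same `Text:` / `JSON:` layout, using `json.dumps` to keep the JSON well-formed and escaping braces automatically.

```python
import json

def escape_braces(s: str) -> str:
    """Double literal braces so f-string-style template parsing leaves them alone."""
    return s.replace("{", "{{").replace("}", "}}")

def build_few_shot_block(examples):
    """Render (text, reason, label) triples in the prompt's `Text:` / `JSON:` layout."""
    lines = ["Examples:"]
    for text, reason, label in examples:
        answer = json.dumps({"reason": reason, "label": label})  # keeps the JSON well-formed
        lines.append(f"Text: {escape_braces(text)}")
        lines.append(f"JSON: {escape_braces(answer)}")
    return "\n".join(lines)

# Two of the examples used in the prompt below
examples = [
    ("Such a beautiful morning",
     "The text expresses appreciation for the morning, indicating a positive sentiment",
     "positive"),
    ("off to glue stuff onto poster",
     "The text is a simple statement of an action without any emotional context",
     "neutral"),
]
block = build_few_shot_block(examples)
print(block)           # braces appear doubled, as in tweet_sentiment_cot_prompt
print(block.format())  # str.format uses the same escaping rules and restores single braces
```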

In [None]:
tweet_sentiment_cot_prompt = """\
You are a highly qualified expert trained to annotate machine learning training data.
Your task is to briefly analyze the sentiment in the TEXT below from a social media manager's perspective and then label it with only one of the three labels:
positive, negative, neutral.
Base your label decision only on the TEXT and do not speculate, e.g., based on prior knowledge about the context.
You first reason step by step about the correct label and then return your label.
You ALWAYS respond once in the following JSON format with brackets: {{"reason": "...", "label": "..."}}

Examples:
Text: Mode: Home Office
JSON: {{"reason": "The text is a factual statement about a work mode without expressing any emotion or opinion", "label": "neutral"}}
Text: oh oh oh are you offering to send ducks! I love love love confit duck
JSON: {{"reason": "The text expresses enthusiasm and love for confit duck, indicating a positive sentiment", "label": "positive"}}
Text: off to glue stuff onto poster
JSON: {{"reason": "The text is a simple statement of an action without any emotional context", "label": "neutral"}}
Text: Beautiful Day..takn it down twitters tell ALL mothers Happy Mothers Day
JSON: {{"reason": "The text describes a beautiful day and expresses positive wishes for Mother's Day", "label": "positive"}}
Text: Likewise. However, what was the comment about originally?
JSON: {{"reason": "The text is a neutral inquiry without expressing any particular sentiment", "label": "neutral"}}
Text: wished didnt spend money last night
JSON: {{"reason": "The text expresses regret about spending money, indicating a negative sentiment", "label": "negative"}}
Text: yo wake your **** up and go to work go get that paper u aint sick dont lie
JSON: {{"reason": "The text is aggressive and accusatory, suggesting a negative sentiment", "label": "negative"}}
Text: Such a beautiful morning
JSON: {{"reason": "The text expresses appreciation for the morning, indicating a positive sentiment", "label": "positive"}}
Text: Nooo...i forgot my calculator for physics oh well class is allmost over :3
JSON: {{"reason": "The text expresses initial disappointment about forgetting a calculator, indicating a negative sentiment", "label": "negative"}}
"""

Next, we convert the prompt into an invokable chain via LangChain:

In [None]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import JsonOutputParser

prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system", tweet_sentiment_cot_prompt
        ),
        (
            "human", "Your TEXT to analyse: {text}"
        )
    ]
)

chain = prompt | llm | JsonOutputParser()
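As written, the chain only turns the model's reply into a dict; nothing checks that the reply contained JSON at all or that `label` is one of the three allowed values. A stricter parse step could look like the following stdlib sketch. `parse_annotation` and `ALLOWED_LABELS` are hypothetical names, and this is not how LangChain's `JsonOutputParser` is implemented; one way to use it would be `prompt | llm | StrOutputParser()` followed by `parse_annotation` on the resulting string.

```python
import json
import re

ALLOWED_LABELS = {"positive", "negative", "neutral"}

def parse_annotation(raw: str) -> dict:
    """Parse a teacher reply into {"reason": ..., "label": ...}; raise ValueError otherwise."""
    match = re.search(r"\{.*\}", raw, flags=re.DOTALL)  # first "{" through last "}"
    if match is None:
        raise ValueError("no JSON object in model output (possibly a refusal)")
    result = json.loads(match.group(0))  # JSONDecodeError is a subclass of ValueError
    if not isinstance(result, dict) or not {"reason", "label"}.issubset(result):
        raise ValueError(f"missing 'reason'/'label' keys: {result!r}")
    label = str(result["label"]).strip().lower()
    if label not in ALLOWED_LABELS:
        raise ValueError(f"unexpected label: {result['label']!r}")
    return {"reason": result["reason"], "label": label}

print(parse_annotation('Sure! {"reason": "Upbeat wording", "label": "Positive"}'))
```

Normalizing the label to lowercase here means `Positive` and `positive` end up in the same class when the student is trained.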

In [None]:
# Example
#response = chain.invoke("Want to get a Blackberry but can`t afford it. Just watching the telly and relaxing. Hard sesion tomorrow.")
response = chain.invoke("But i dont mind the long line when theres a super cutie in front of me. Too bad he`s wearing a **** bracelet with a girls name on it")

response

### Annotation Script

Now we use our model and prompt to generate the annotations and combine them with the CSV. The error handling is mostly needed because Llama 3.1 405B filters its output and declines to answer when presented with inappropriate content; a row that raises an exception is logged and skipped, so it does not appear in the output file.
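The annotation loop below prints failures and moves on, so skipped rows are easy to lose track of. A hypothetical variant, `annotate_rows`, keeps the same output columns but also returns the failed ids with their error messages for later review. It takes any `classify` callable, e.g. `lambda text: chain.invoke({"text": text})`; the `fake_classify` function here is a made-up stand-in that mimics a content-filter refusal.

```python
def annotate_rows(rows, classify):
    """Annotate each row via `classify`; collect failures instead of silently dropping them.

    `classify` takes a tweet's text and returns a dict with "reason" and "label" keys.
    """
    annotated, failed = [], []
    for row in rows:
        try:
            result = classify(row["text"])
            annotated.append({
                **row,
                "Llama_405B_reason": result["reason"],
                "Llama_405B_label_text": result["label"],
            })
        except Exception as exc:  # refusals, malformed JSON, missing keys, API errors
            failed.append((row.get("id", "unknown"), str(exc)))
    return annotated, failed

# Hypothetical classifier that "refuses" censored text, mimicking a content filter
def fake_classify(text):
    if "****" in text:
        raise ValueError("model declined to answer")
    return {"reason": "stub reason", "label": "neutral"}

rows = [{"id": "1", "text": "off to glue stuff onto poster"},
        {"id": "2", "text": "yo wake your **** up"}]
annotated, failed = annotate_rows(rows, fake_classify)
print(len(annotated), failed)  # → 1 [('2', 'model declined to answer')]
```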

**NOTE:** Using the Fireworks API, this costs roughly $10 when run over 5,000 examples. Run with caution! (The `llm` cells above point to Azure OpenAI or Groq instead, so actual costs may differ.)
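The $10 figure can be sanity-checked against the `$3/$3` per 1M token Fireworks pricing quoted earlier. The sketch below works backwards to the implied tokens per call and then estimates a smaller run; the per-call token counts in the last line are hypothetical, so measure your own prompt length before relying on the estimate.

```python
PRICE_PER_M_INPUT = 3.00   # USD per 1M input tokens (Fireworks price quoted earlier)
PRICE_PER_M_OUTPUT = 3.00  # USD per 1M output tokens

def estimate_cost(n_examples, input_tokens_each, output_tokens_each):
    """Estimated API cost in USD for annotating n_examples rows."""
    input_cost = n_examples * input_tokens_each / 1e6 * PRICE_PER_M_INPUT
    output_cost = n_examples * output_tokens_each / 1e6 * PRICE_PER_M_OUTPUT
    return input_cost + output_cost

# Working backwards from ~$10 over 5,000 examples. Both rates are $3,
# so this is the implied input + output tokens per call combined.
implied_tokens = 10 / 5000 / PRICE_PER_M_INPUT * 1e6
print(round(implied_tokens))  # → 667

# Hypothetical per-call token counts for the roughly 1,000-row subset used in this notebook
print(f"${estimate_cost(1000, 600, 60):.2f}")  # → $1.98
```

Because the long few-shot system prompt is re-sent on every call, input tokens dominate the cost, which is why a shorter prompt is the main lever for a cheaper run.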

In [None]:
import csv
import json
from common.utils import get_project_root  # project-local helper used for the paths below

def process_csv(input_file, output_file):
    i = 0
    with open(input_file, 'r', newline='', encoding='utf-8') as infile, \
         open(output_file, 'w', newline='', encoding='utf-8') as outfile:

        reader = csv.DictReader(infile)
        fieldnames = reader.fieldnames + ['Llama_405B_reason', 'Llama_405B_label_text']

        writer = csv.DictWriter(outfile, fieldnames=fieldnames)
        writer.writeheader()

        for row in reader:
            try:
                # Invoke the chain with the text from the current row
                response = chain.invoke({"text": row['text']})
                result = json.loads(response) if isinstance(response, str) else response

                # Add new fields to the row
                row['Llama_405B_reason'] = result['reason']
                row['Llama_405B_label_text'] = result['label']

                # Write the updated row to the output file immediately
                writer.writerow(row)

                # Flush the write buffer to ensure data is written to disk
                outfile.flush()

                i+=1
                print(f"{i} - Processed and saved row with id: {row['id']}")

            except Exception as e:
                # Error handling
                print(f"Error processing row with id {row.get('id', 'unknown')}: {str(e)}")
                continue

    print(f"Processing completed. Output saved to: {output_file}")

# Usage
input_file = str(get_project_root())+"/data/tweet_sentiment/twittersentiment_5001_6000_cleaned.csv"
output_file = str(get_project_root())+"/data/tweet_sentiment/test.csv"

process_csv(input_file, output_file)
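Since the source dataset already carries human labels in `label_text`, the annotated CSV can be checked against them before training the student. A minimal sketch, assuming the column names written by `process_csv`; `teacher_agreement` is a hypothetical helper and the toy rows are made up.

```python
import pandas as pd

def teacher_agreement(df: pd.DataFrame,
                      gold_col: str = "label_text",
                      teacher_col: str = "Llama_405B_label_text") -> float:
    """Fraction of rows where the teacher's label matches the dataset's original label."""
    gold = df[gold_col].str.strip().str.lower()
    teacher = df[teacher_col].str.strip().str.lower()
    return float((gold == teacher).mean())

# Hypothetical rows in the same layout as the annotated CSV
toy = pd.DataFrame({
    "label_text": ["positive", "neutral", "negative", "neutral"],
    "Llama_405B_label_text": ["positive", "neutral", "neutral", "Neutral"],
})
print(teacher_agreement(toy))  # → 0.75
```

On the real output you would call `teacher_agreement(pd.read_csv(output_file))`; `pd.crosstab(df["label_text"], df["Llama_405B_label_text"])` then shows which sentiments the teacher confuses.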