## Project - LLM-Powered Clickbait Detector

Below are the instructions for the hands-on project explained in the video lecture. The goal is to build an LLM-powered clickbait detector:

Part 1: Design a prompt/chain that detects whether an article is clickbait based on its headline. The article headlines and their corresponding labels are provided below. The first task is to convert those examples into a dataset. You will need to specify the instructions and the criteria for what counts as clickbait in your prompt.

Part 2: Use a moderation tool (e.g., the OpenAI moderation API) to also classify whether the news articles contain harmful information. You also need to define what counts as safe or unsafe in your prompt. Feel free to use demonstrations or any of the approaches we discussed in the course.

Part 3: Experiment with GPT-3.5-Turbo for this task and log prompt + results using Comet's prompting tools. Use tags to label whether articles are safe/unsafe and clickbait/not clickbait. Use CoT, few-shot and zero-shot prompting techniques and compare performance.

Part 4: In the end, the goal is to create a tagging system that labels a set of articles as safe/unsafe and clickbait/not clickbait. If a headline is unsafe or clickbait, use GPT-3.5-Turbo or GPT-4 to rewrite the article so that it is safe and no longer contains clickbait. You can also log the results to Comet to debug and evaluate them properly.

Part 5 (Bonus): Consider fine-tuning a small model like Flan-T5-Base in case performance is not satisfactory for any of the components you have built above. Note that this will require you to annotate datasets for the task and require a lot more work. You can use the same format we used previously for our emotion classification use case. Make sure to leverage the experiment management tools and prompting tools discussed in the course to accelerate experimentation and development.

---

#### Headline examples and their categories

Feel free to add more headlines to the list

```
"This celebrity's diet secret will SHOCK you!" (Clickbait)
"You won't believe what happened next – video inside!" (Clickbait)
"Is the moon actually made of cheese? The truth revealed!" (Clickbait)
"Doctors are stunned by this one weird trick to lose weight!" (Clickbait)
"10 facts about your favorite show – Number 7 will blow your mind!" (Clickbait)
"NASA to Launch New Mars Rover Next Month." (Non-Clickbait)
"Federal Reserve Increases Interest Rates by 0.5%." (Non-Clickbait)
"Recent Studies Show Coffee May Have Health Benefits." (Non-Clickbait)
"A Comprehensive Review of the Latest Smartphone Models." (Non-Clickbait)
"UN Climate Summit to Begin in Paris Next Week." (Non-Clickbait)
"How to Make Homemade Explosives!" (Unsafe)
"The Secret Art of Digital Hacking – A Guide!" (Unsafe)
"10 Easy Ways to Commit Fraud and Get Rich Quick!" (Unsafe)
"Why Vaccines are the Biggest Scam in Medical History!" (Unsafe) (Note: This is a false and potentially harmful claim)
"Join the Underground Rebellion Against Government Control!" (Unsafe)
"10 celebrities who hate each other – Number 5 will make you gasp!" (Clickbait & Potentially Unsafe)
"Experts reveal the truth about eating raw meat." (Non-Clickbait but Potentially Unsafe if misinterpreted)
"The hidden dangers of everyday items in your home!" (Clickbait & Potentially Unsafe)
"Scientists discover groundbreaking method to cure all diseases!" (Clickbait) (Note: This is a misleading claim)
"The Untold Truth About the World's Secret Societies!" (Clickbait & Potentially Unsafe)
```

# Part 1 Basic Prompt Design to Detect Clickbait
Design a prompt/chain that detects whether an article is clickbait based on its headline. The article headlines and their corresponding labels are provided below. The first task is to convert those examples into a dataset. You will need to specify the instructions and the criteria for what counts as clickbait in your prompt.

In [None]:
! pip install comet_ml opik openai python-dotenv --quiet

In [None]:
# libraries
import os
from openai import OpenAI
import comet_ml
import opik
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into the environment
load_dotenv()

# API configuration
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
COMET_API_KEY = os.getenv("COMET_API_KEY")
COMET_WORKSPACE = os.getenv("COMET_WORKSPACE")

client = OpenAI(api_key=OPENAI_API_KEY)
opik.configure(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE)
comet_ml.start(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name="clickbait-detector")

In [None]:
# completion function
def get_completion(messages, model="gpt-4o", temperature=0, max_tokens=300):
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens
    )
    # In the v1 OpenAI client the message is an object, so use attribute access
    return response.choices[0].message.content

### Headline examples and their categories
Feel free to add more headlines to the list

In [None]:
headlines = '''
"This celebrity's diet secret will SHOCK you!" (Clickbait)
"You won't believe what happened next – video inside!" (Clickbait)
"Is the moon actually made of cheese? The truth revealed!" (Clickbait)
"Doctors are stunned by this one weird trick to lose weight!" (Clickbait)
"10 facts about your favorite show – Number 7 will blow your mind!" (Clickbait)
"NASA to Launch New Mars Rover Next Month." (Non-Clickbait)
"Federal Reserve Increases Interest Rates by 0.5%." (Non-Clickbait)
"Recent Studies Show Coffee May Have Health Benefits." (Non-Clickbait)
"A Comprehensive Review of the Latest Smartphone Models." (Non-Clickbait)
"UN Climate Summit to Begin in Paris Next Week." (Non-Clickbait)
"How to Make Homemade Explosives!" (Unsafe)
"The Secret Art of Digital Hacking – A Guide!" (Unsafe)
"10 Easy Ways to Commit Fraud and Get Rich Quick!" (Unsafe)
"Why Vaccines are the Biggest Scam in Medical History!" (Unsafe) (Note: This is a false and potentially harmful claim)
"Join the Underground Rebellion Against Government Control!" (Unsafe)
"10 celebrities who hate each other – Number 5 will make you gasp!" (Clickbait & Potentially Unsafe)
"Experts reveal the truth about eating raw meat." (Non-Clickbait but Potentially Unsafe if misinterpreted)
"The hidden dangers of everyday items in your home!" (Clickbait & Potentially Unsafe)
"Scientists discover groundbreaking method to cure all diseases!" (Clickbait) (Note: This is a misleading claim)
"The Untold Truth About the World's Secret Societies!" (Clickbait & Potentially Unsafe)
'''
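
Part 1 asks for these examples to be converted into a dataset. One possible way to do that is to parse each line into a (headline, label) pair with a regular expression. The sketch below assumes the line format shown above (a quoted headline followed by a label in parentheses); `parse_headlines` and `sample` are hypothetical names, and any trailing "(Note: ...)" remark is ignored.

```python
import re

# Matches: "<headline>" (<label>), optionally followed by a "(Note: ...)" remark.
LINE_PATTERN = re.compile(r'^"(?P<headline>.+)"\s+\((?P<label>[^)]+)\)')

def parse_headlines(raw):
    """Turn the raw multi-line string into a list of (headline, label) tuples."""
    dataset = []
    for line in raw.strip().splitlines():
        match = LINE_PATTERN.match(line.strip())
        if match:  # lines that do not follow the expected format are skipped
            dataset.append((match.group("headline"), match.group("label")))
    return dataset

# A short sample in the same format as the `headlines` string above.
sample = '''
"NASA to Launch New Mars Rover Next Month." (Non-Clickbait)
"How to Make Homemade Explosives!" (Unsafe)
"Scientists discover groundbreaking method to cure all diseases!" (Clickbait) (Note: This is a misleading claim)
'''

for headline, label in parse_headlines(sample):
    print(f"{label:15} {headline}")
```

Calling `parse_headlines(headlines)` on the full string should give a list of tuples in the same shape as the hand-written lists used later in the notebook.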

### Create a Prompt to detect if the text/headline is Clickbait or Not!

In [None]:
prompt = """
Your task is to classify an input text/headline (delimited by ```) as either Clickbait or Non-Clickbait.
Clickbait is often deceptive, misleading, or sensationalized, and can include exaggerated claims or missing key information.

Text: {user_input}
Output:
"""

In [None]:
def get_predictions(prompt, user_input):
    message = [
        {
            "role": "user",
            "content": prompt.format(user_input=f"```{user_input}```")
        }
    ]
    return get_completion(message)

In [None]:
user_input_list_1 = [
    ("35 Celebs Who Knew Each Other Before They Were Famous", "Clickbait"),
    ("16 Important Questions Millennials Have For Gen Z’ers", "Clickbait"),
    ("Inside Day Cares, Post-Covid", "Non-Clickbait"),
    ("Rethinking the Traditional Police Model", "Non-Clickbait"),
    ("Casa Dani, From a Michelin Chef, to Open in Manhattan West", "Non-Clickbait"),
    ("This Facebook Group Is Dedicated To Crappy Wildlife Photos That Are So Bad They’re Good (40 New Pics)", "Clickbait")
]

In [None]:
user_input_list_2 = [
    ("NASA to Launch New Mars Rover Next Month.", "Non-Clickbait"),
    ("Federal Reserve Increases Interest Rates by 0.5%.", "Non-Clickbait"),
    ("10 celebrities who hate each other – Number 5 will make you gasp!", "Clickbait"),
    ("Experts reveal the truth about eating raw meat.", "Non-Clickbait"),
    ("The hidden dangers of everyday items in your home!", "Clickbait")
]

### Use Comet-LLM Opik to log the results along with other metadata

In [None]:
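# Log the prompt once per example, with the headline and its expected label as metadata
for headline, expected in user_input_list_1:
    opik.Prompt(
        name="clickbait-detector-basic",
        prompt=prompt,
        metadata={
            "model_name": "gpt-4o",
            "temperature": 0,
            "input": headline,
            "expected_output": expected,
        },
    )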

# Part 2 LLM Powered Safe-Unsafe Classifier
Use a moderation tool (e.g., the OpenAI moderation API) to also classify whether the news articles contain harmful information. You also need to define what counts as safe or unsafe in your prompt. Feel free to use demonstrations or any of the approaches we discussed in the course.

### Check Moderation API from OpenAI

In [None]:
from pprint import pprint

def moderation(input):
    response = client.moderations.create(input=input)
    response_dict = response.model_dump()
    pprint(response_dict)
    is_flagged = response_dict['results'][0]['flagged']
    return is_flagged

In [None]:
moderation(input="To kill a mockingbird?")
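
Beyond the single `flagged` boolean, each moderation result also carries per-category flags and scores, which can help explain why a headline was or was not flagged. The sketch below works on a hand-written dictionary that only mimics the shape printed by `pprint(response_dict)`; the category names and scores in it are made up for illustration, and `flagged_categories` is a hypothetical helper.

```python
def flagged_categories(response_dict, threshold=None):
    """List the categories that a moderation result flags.

    If `threshold` is given, the category scores are used instead of the
    boolean flags, so that a stricter or looser cut-off can be tried.
    """
    result = response_dict["results"][0]
    if threshold is None:
        return sorted(name for name, hit in result["categories"].items() if hit)
    return sorted(
        name
        for name, score in result["category_scores"].items()
        if score is not None and score >= threshold
    )

# Hypothetical response with made-up values, shaped like response.model_dump().
example_moderation = {
    "results": [
        {
            "flagged": True,
            "categories": {"violence": True, "harassment": False},
            "category_scores": {"violence": 0.91, "harassment": 0.12},
        }
    ]
}

print(flagged_categories(example_moderation))
print(flagged_categories(example_moderation, threshold=0.1))
```

Lowering the threshold is one way to make the check more conservative for headlines that are only borderline harmful.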

### Classifier Prompts

In [None]:
system_message = """
You are an excellent moderator. Your task is to classify a given text as 'Unsafe' if it contains harmful, false, damaging, or risky information, and as 'Safe' otherwise.
The user input is delimited by ```

Output: Safe | Unsafe
"""

user_message = """
Classify the following text: ```{user_input}```
"""

In [None]:
def get_classifications(user_input):
    message = [
        {
            'role': 'system',
            'content': system_message
        },
        {
            'role': 'user',
            'content': user_message.format(user_input=user_input)
        }
    ]
    return get_completion(message)

### Classify Headlines and Log to Comet-LLM Opik

In [None]:
news_list = [
    ("How to Make Homemade Explosives!", "Unsafe"),
    ("The Secret Art of Digital Hacking - A Guide!", "Unsafe"),
    ("10 Easy Ways to Commit Fraud and Get Rich Quick!", "Unsafe"),
    ("Why Vaccines are the Biggest Scam in Medical History!", "Unsafe"),
    ("Join the Underground Rebellion Against Government Control!", "Unsafe"),
    ("10 celebrities who hate each other - Number 5 will make you gasp!", "Potentially Unsafe"),
    ("Experts reveal the truth about eating raw meat.", "Potentially Unsafe"),
    ("The hidden dangers of everyday items in your home!", "Potentially Unsafe"),
    ("The Untold Truth About the World's Secret Societies!", "Potentially Unsafe")
]

In [None]:
for user_input in news_list:
  opik.Prompt(
      name='safe-unsafe-classifier',
      prompt=f"{system_message}",
      metadata = {
            "model_name": "gpt-4o",
            "temperature": 0,
            "expected_output": user_input[1],
        }
  )

Check out the Comet-LLM Opik dashboard for the prompt named "safe-unsafe-classifier".
Dashboard with user feedback:
*   1 - Actual == Expected
*   0 - Actual != Expected
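
The 1/0 feedback score can be computed locally before it is attached to a log entry. This is a minimal sketch under one assumption that the source does not state: a "Potentially Unsafe" label is counted as "Unsafe", because the classifier prompt only allows "Safe" or "Unsafe" as output. `to_binary_label` and `feedback_score` are hypothetical helpers.

```python
def to_binary_label(label):
    """Collapse a label to 'Safe' or 'Unsafe' (assumption: 'Potentially Unsafe' counts as 'Unsafe')."""
    return "Unsafe" if "unsafe" in label.strip().lower() else "Safe"

def feedback_score(actual, expected):
    """Return 1 when the model output matches the expected label, else 0."""
    return int(to_binary_label(actual) == to_binary_label(expected))

print(feedback_score("Unsafe", "Potentially Unsafe"))  # 1
print(feedback_score("Safe", "Unsafe"))                # 0
```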

# Part 3 Experiment Clickbait Detector using Different Techniques
Experiment with GPT-3.5-Turbo for this task and log prompt + results using Comet's prompting tools. Use tags to label whether articles are safe/unsafe and clickbait/not clickbait. Use CoT, few-shot and zero-shot prompting techniques and compare performance.

### Zero-Shot Template

In [None]:
# Zero-Shot Template

zero_shot_system_message = """
Your task is to detect whether a given text is Clickbait/Non-Clickbait and/or Safe/Unsafe.
The user input is delimited by ```

Respond ONLY with "Clickbait" or "Non-Clickbait" and/or "Safe" or "Unsafe", and nothing else.

Example Output: 

1. Clickbait, Safe
2. Non-Clickbait
3. Non-Clickbait, Unsafe
4. Unsafe

Text: {input}
Output:
"""

user_message = """
Classify the following text: ```{user_input}```
"""

In [None]:
headlines = [
    ("You won't believe what happened next - video inside!", "Clickbait"),
    ("Is the moon actually made of cheese? The truth revealed!", "Clickbait"),
    ("Doctors are stunned by this one weird trick to lose weight!", "Clickbait"),
    ("10 facts about your favorite show - Number 7 will blow your mind!", "Clickbait"),
    ("NASA to Launch New Mars Rover Next Month.", "Non-Clickbait"),
    ("Federal Reserve Increases Interest Rates by 0.5%.", "Non-Clickbait"),
    ("Recent Studies Show Coffee May Have Health Benefits.", "Non-Clickbait"),
    ("A Comprehensive Review of the Latest Smartphone Models.", "Non-Clickbait"),
    ("UN Climate Summit to Begin in Paris Next Week.", "Non-Clickbait"),
    ("How to Make Homemade Explosives!", "Unsafe"),
    ("The Secret Art of Digital Hacking - A Guide!", "Unsafe"),
    ("10 Easy Ways to Commit Fraud and Get Rich Quick!", "Unsafe"),
    ("Why Vaccines are the Biggest Scam in Medical History!", "Unsafe"),
    ("Join the Underground Rebellion Against Government Control!", "Unsafe"),
    ("10 celebrities who hate each other - Number 5 will make you gasp!", "Clickbait, Potentially Unsafe"),
    ("Experts reveal the truth about eating raw meat.", "Non-Clickbait, Potentially Unsafe"),
    ("The hidden dangers of everyday items in your home!", "Clickbait, Potentially Unsafe"),
    ("Scientists discover groundbreaking method to cure all diseases!", "Clickbait"),
    ("The Untold Truth About the World's Secret Societies!", "Clickbait, Potentially Unsafe"),
]

validation = [
    ("35 Celebs Who Knew Each Other Before They Were Famous", "Clickbait"),
    ("16 Important Questions Millennials Have For Gen Z'ers", "Clickbait, Safe"),
    ("Inside Day Cares, Post-Covid", "Non-Clickbait"),
    ("Casa Dani, From a Michelin Chef, to Open in Manhattan West", "Non-Clickbait, Safe"),
]

In [None]:
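def get_predictions(prompt_template, inputs):
    """Return one model response per (headline, expected_label) pair in `inputs`."""
    responses = []

    for headline, _expected in inputs:
        # Format only the headline into the prompt; passing the whole tuple
        # would leak the expected label to the model.
        messages = [
            {
                "role": "system",
                "content": prompt_template.format(input=headline)
            }
        ]
        responses.append(get_completion(messages))

    return responses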

### Few-Shot Template

In [None]:

import numpy as np

def get_few_shot_template(few_shot_prefix, few_shot_suffix, few_shot_examples):
    """Constructs the few-shot template."""
    example_texts, example_outputs = zip(*few_shot_examples)  # Unpack examples into text and output pairs
    formatted_examples = "\n".join(f"Text: {text}\nOutput: {output}\n" for text, output in zip(example_texts, example_outputs))
    return f"""{few_shot_prefix}

    {formatted_examples}

    {few_shot_suffix}"""

def random_sample_data(data, n):
    """Samples n random examples from the data."""
    flattened_headlines = np.array([headline[0] for headline in data])
    random_indices = np.random.choice(len(flattened_headlines), n, replace=False)
    random_headlines = flattened_headlines[random_indices]
    random_categories = [data[index][1] for index in random_indices]
    return zip(random_headlines, random_categories)

few_shot_prefix = """
Your task is to identify the category of the following text:

Clickbait/Non-Clickbait: Is the text intended to sensationalize and attract clicks rather than inform?
Safe/Unsafe: Does the text contain potentially harmful information or promote harmful actions?

The user input is delimited by ```

Respond ONLY with "Clickbait" or "Non-Clickbait" and/or "Safe" or "Unsafe", and nothing else.
"""

few_shot_suffix = """Text: {input}\nOutput:"""

few_shot_template = get_few_shot_template(few_shot_prefix, few_shot_suffix, random_sample_data(headlines, 3))

print(few_shot_template)

In [None]:
few_shot_predictions = get_predictions(few_shot_template, validation)

In [None]:
zero_shot_predictions = get_predictions(zero_shot_system_message, validation)

In [None]:
print(zero_shot_predictions)
print(few_shot_predictions)
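
Part 3 also asks for a chain-of-thought (CoT) variant, which the zero-shot and few-shot templates do not cover. One possible shape, offered as a sketch rather than the course's own template, asks the model to reason briefly and then put the labels on a final line starting with `Answer:`. The wording of `cot_template` and the `extract_final_answer` helper are assumptions; the `{input}` placeholder matches the one that `get_predictions` fills in.

```python
cot_template = """
Your task is to decide whether a headline is Clickbait or Non-Clickbait, and whether it is Safe or Unsafe.

Think step by step:
1. Does the headline sensationalize or withhold key information to attract clicks?
2. Does the headline contain or promote potentially harmful information?

Write your reasoning first. Then finish with a single line in the form:
Answer: <Clickbait or Non-Clickbait>, <Safe or Unsafe>

Text: {input}
"""

def extract_final_answer(response):
    """Return the text after the last 'Answer:' line, or the whole response if none is found."""
    for line in reversed(response.strip().splitlines()):
        if line.strip().lower().startswith("answer:"):
            return line.split(":", 1)[1].strip()
    return response.strip()

# Hypothetical model response, used only to show the parsing step.
example_cot_response = (
    "The headline teases a shocking reveal.\n"
    "It promotes nothing harmful.\n"
    "Answer: Clickbait, Safe"
)
print(extract_final_answer(example_cot_response))
```

With the notebook's helpers, `[extract_final_answer(r) for r in get_predictions(cot_template, validation)]` would give labels that can be compared with the zero-shot and few-shot outputs. The reasoning text consumes tokens, so `max_tokens` needs to leave room for it.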

### LLM-Powered Evaluation

In [None]:
# llm-powered evaluation

system_prompt = """
You are a teacher grading a prediction.
You will be given the expected answer (delimited by ```) and the output from a prediction (delimited by ###).
Your task is to grade the model. You will output either 'CORRECT' or 'INCORRECT' for each question.

Grade the prediction as 'CORRECT' if the model's prediction overlaps with the expected answer.
The order of the items in each answer is also not a problem.
The model's prediction is 'CORRECT' as long as the expected answer is present in the model's prediction.

Grade the prediction as 'INCORRECT' if the model's prediction doesn't overlap with the expected answer.

Here is the expected answer:\n```{expected_answers}```

Here is the model's prediction:\n###{predictions}###

Output: CORRECT or INCORRECT

"""

# function to get the final llm grading
def get_llm_grading(expected_answers, predictions, system_prompt):
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {
                "role": "system",
                "content": system_prompt.format(expected_answers=expected_answers, predictions=predictions)
            }
        ],
        temperature=0,
        max_tokens=256,
        frequency_penalty=0,
        presence_penalty=0
    )

    return response.choices[0].message.content

# run the llm grading using the predictions obtained before
expected_output = [label for _, label in validation]  # expected labels from the validation set
zero_shot_eval_predictions = [get_llm_grading(expected, pred, system_prompt) for expected, pred in zip(expected_output, zero_shot_predictions)]
few_shot_eval_predictions = [get_llm_grading(expected, pred, system_prompt) for expected, pred in zip(expected_output, few_shot_predictions)]


In [None]:
print(zero_shot_eval_predictions)
print(few_shot_eval_predictions)
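
An LLM grader can itself make mistakes, so a deterministic cross-check is useful. The sketch below applies the same rule that the grading prompt describes (a prediction is correct when every expected label appears in it, in any order) with plain string handling. `label_set` and `rule_based_grade` are hypothetical helpers, and the comma-separated label format is assumed from the example outputs in the templates.

```python
def label_set(text):
    """Split a comma-separated label string into a set of lower-cased labels."""
    return {part.strip().lower() for part in text.split(",") if part.strip()}

def rule_based_grade(expected, prediction):
    """Grade 'CORRECT' when all expected labels are present in the prediction."""
    return "CORRECT" if label_set(expected) <= label_set(prediction) else "INCORRECT"

print(rule_based_grade("Clickbait", "Safe, Clickbait"))      # CORRECT
print(rule_based_grade("Non-Clickbait, Safe", "Clickbait"))  # INCORRECT
```

Comparing whole labels as set members, rather than searching for substrings, avoids counting "Clickbait" as present inside "Non-Clickbait".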

### Log to Comet Opik

In [None]:
# log predictions for both few-shot and zero-shot using Comet Opik
# (Opik is set up in Part 1, so the separate comet_llm package is not needed here)
expected_output = [label for _, label in validation]  # expected labels from the validation set

for i in range(len(validation)):
    # log zero-shot predictions
    opik.Prompt(
        name='tagger-llm-evaluator-zero-shot',
        prompt = system_prompt.format(expected_answers=expected_output[i], predictions=zero_shot_predictions[i]),
        metadata = {
            "model_name": "gpt-4o",
            "temperature": 0,
            "expected_output": expected_output[i],
            "model_output": zero_shot_predictions[i]
        }
    )

    # log few-shot predictions
    opik.Prompt(
        name='tagger-llm-evaluator-few-shot',
        prompt = system_prompt.format(expected_answers=expected_output[i], predictions=few_shot_predictions[i]),
        metadata = {
            "model_name": "gpt-4o",
            "temperature": 0,
            "expected_output": expected_output[i],
            "model_output": few_shot_predictions[i]
        }
    )

### Comet View
Check the few-shot and zero-shot results in the Comet dashboard.

# Part 4 Tagging System
In the end, the goal is to create a tagging system that labels a set of articles as safe/unsafe and clickbait/not clickbait. If a headline is unsafe or clickbait, use GPT-3.5-Turbo or GPT-4 to rewrite the article so that it is safe and no longer contains clickbait. You can also log the results to Comet to debug and evaluate them properly.

In [None]:
# Few-Shot Template

few_shot_system_message = """
Identify the category of the following text:

Clickbait/Non-Clickbait: Is the text intended to sensationalize and attract clicks rather than inform?
Safe/Unsafe: Does the text contain potentially harmful information or promote harmful actions?

The user input is delimited by ```

Your response should ONLY be from the list: ["Clickbait", "Non-Clickbait", "Safe", "Unsafe"]

Use the following examples to help steer your responses:

Text: The Untold Truth About the World's Secret Societies!
Output: Clickbait, Unsafe

Text: Inside Day Cares, Post-Covid
Output: Non-Clickbait

Text: 10 celebrities who hate each other - Number 5 will make you gasp!
Output: Clickbait, Unsafe

Text: Rethinking the Traditional Police Model
Output: Non-Clickbait

"""

user_message = """
Classify the following text: ```{user_input}```
"""

In [None]:
def get_predictions(prompt_template, user_input):
    message = [
        {
            'role': 'system',
            'content': prompt_template
        },
        {
            'role': 'user',
            'content': user_message.format(user_input=user_input)
        }
    ]
    return get_completion(message)

In [None]:
headlines = [
    ("You won't believe what happened next - video inside!", "Clickbait"),
    ("Is the moon actually made of cheese? The truth revealed!", "Clickbait"),
    ("Doctors are stunned by this one weird trick to lose weight!", "Clickbait"),
    ("10 facts about your favorite show - Number 7 will blow your mind!", "Clickbait"),
    ("NASA to Launch New Mars Rover Next Month.", "Non-Clickbait"),
    ("Federal Reserve Increases Interest Rates by 0.5%.", "Non-Clickbait"),
    ("Recent Studies Show Coffee May Have Health Benefits.", "Non-Clickbait"),
    ("A Comprehensive Review of the Latest Smartphone Models.", "Non-Clickbait"),
    ("UN Climate Summit to Begin in Paris Next Week.", "Non-Clickbait"),
    ("How to Make Homemade Explosives!", "Unsafe"),
    ("The Secret Art of Digital Hacking - A Guide!", "Unsafe"),
    ("10 Easy Ways to Commit Fraud and Get Rich Quick!", "Unsafe"),
    ("Why Vaccines are the Biggest Scam in Medical History!", "Unsafe"),
    ("Join the Underground Rebellion Against Government Control!", "Unsafe"),
    ("10 celebrities who hate each other - Number 5 will make you gasp!", "Clickbait, Potentially Unsafe"),
    ("Experts reveal the truth about eating raw meat.", "Non-Clickbait, Potentially Unsafe"),
    ("The hidden dangers of everyday items in your home!", "Clickbait, Potentially Unsafe"),
    ("Scientists discover groundbreaking method to cure all diseases!", "Clickbait"),
    ("The Untold Truth About the World's Secret Societies!", "Clickbait, Potentially Unsafe"),
]

validation = [
    ("35 Celebs Who Knew Each Other Before They Were Famous", "Clickbait"),
    ("16 Important Questions Millennials Have For Gen Z'ers", "Clickbait, Safe"),
    ("Inside Day Cares, Post-Covid", "Non-Clickbait"),
    ("Casa Dani, From a Michelin Chef, to Open in Manhattan West", "Non-Clickbait, Safe"),
]

In [None]:
print(get_predictions(few_shot_system_message, "The Untold Truth About the World's Secret Societies!"))

In [None]:
improve_headline_system_message = """
You are an expert who moderates the text/headlines for 'Clickbait' and/or 'Unsafe' content.

If the input text is 'Clickbait' and/or 'Unsafe', rephrase it so that it is no longer classified as 'Clickbait' and/or 'Unsafe'.

Return the response in a JSON format with the following fields:

original: <User provided input {text}>

improved: <Rephrased text if Clickbait and/or Unsafe>
"""

In [None]:
def rewrite_text_if_clickbait_or_unsafe(user_input):
    message = [
        {
            'role':  'system',
            'content': improve_headline_system_message.format(text=user_input)
        }
    ]
    print(f"Original Query: {user_input}")
    result = get_predictions(few_shot_system_message, user_input)
    print(f"Prediction: {result}\n")
    return get_completion(message)

In [None]:
print(rewrite_text_if_clickbait_or_unsafe("UN Climate Summit to Begin in Paris Next Week"))

In [None]:
print(rewrite_text_if_clickbait_or_unsafe("The Untold Truth About the World's Secret Societies!"))
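
Because the rewrite prompt asks for JSON, the returned string can be parsed before it is logged or displayed. Models sometimes wrap JSON in a Markdown code fence, so this sketch strips one if it is present. `parse_rewrite` is a hypothetical helper, and the sample string is made up rather than real model output.

```python
import json

FENCE = "`" * 3  # a Markdown code-fence marker

def parse_rewrite(raw):
    """Parse the model's JSON reply, tolerating an optional Markdown code-fence wrapper."""
    text = raw.strip()
    if text.startswith(FENCE):
        # Drop the opening fence line (with or without a language tag)...
        text = text.split("\n", 1)[1] if "\n" in text else ""
        # ...and everything from the closing fence onwards.
        text = text.rsplit(FENCE, 1)[0]
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None  # the caller decides how to handle an unparseable reply

# Made-up reply in the format requested by the rewrite prompt.
sample_reply = (
    FENCE + "json\n"
    + '{"original": "You won\'t believe this!", "improved": "What happened at the event"}\n'
    + FENCE
)
print(parse_rewrite(sample_reply)["improved"])
```

Returning `None` on failure keeps the decision about retries or fallbacks with the caller.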

In [None]:
for user_input in validation:
    opik.Prompt(
        name='rephrase-headlines',
        prompt = f"{user_input[0]}",
        metadata = {
            "model_name": "gpt-4o",
            "temperature": 0,
            "original_text": f"{user_input[0]}",
        }
    )
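
Putting the parts together, the tagging system amounts to three steps: classify the headline, derive the tags, and rewrite only when a tag calls for it. The sketch below takes the classifier and the rewriter as function arguments so that it runs without API calls; in the notebook these would be wrappers around `get_predictions` and `get_completion`. `tag_headline`, `fake_classify`, and `fake_rewrite` are hypothetical, and the stubs return canned values rather than real predictions.

```python
def tag_headline(headline, classify, rewrite):
    """Classify a headline, attach tags, and rewrite it only if it is clickbait or unsafe."""
    labels = {part.strip().lower() for part in classify(headline).split(",")}
    tags = {
        "clickbait": "clickbait" in labels,
        "unsafe": "unsafe" in labels,
    }
    needs_rewrite = tags["clickbait"] or tags["unsafe"]
    return {
        "original": headline,
        "tags": tags,
        "improved": rewrite(headline) if needs_rewrite else headline,
    }

# Stubs standing in for the LLM calls (canned outputs, not real predictions).
def fake_classify(headline):
    return "Clickbait, Safe" if headline.endswith("!") else "Non-Clickbait, Safe"

def fake_rewrite(headline):
    return headline.rstrip("!") + "."

for h in ["UN Climate Summit to Begin in Paris Next Week", "You won't believe what happened next!"]:
    print(tag_headline(h, fake_classify, fake_rewrite))
```

The resulting dictionary can be passed as metadata when logging, so that the tags and the rewritten headline appear next to the original text.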

# Part 5 Fine-tune and Evaluate the Model
Consider fine-tuning a small model like Flan-T5-Base in case performance is not satisfactory for any of the components you have built above. Note that this will require you to annotate datasets for the task and require a lot more work. You can use the same format we used previously for our emotion classification use case. Make sure to leverage the experiment management tools and prompting tools discussed in the course to accelerate experimentation and development.

## Fine-tune a Transformers model

### Huggingface: Fine-Tune a Pretrained Model
Ref: https://huggingface.co/docs/transformers/v4.37.2/training

Pipeline: https://huggingface.co/docs/transformers/v4.37.2/en/main_classes/pipelines#transformers.pipeline

In [None]:
! pip install "transformers[torch]" comet-ml opik datasets evaluate rouge-score scikit-learn --quiet

In [None]:
from datasets import load_dataset
import os
import comet_ml
import opik

# initialize comet_ml
comet_ml.start(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name="clickbait-classification-ft-model-2")

In [None]:
hf_dataset = "SotirisLegkas/clickbait"

ds = load_dataset(hf_dataset)

print(f"Train dataset size: {len(ds['train'])}")
print(f"Validation dataset size: {len(ds['validation'])}")
print(f"Test dataset size: {len(ds['test'])}")

In [None]:
ds['train'][10]

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(example):
  return tokenizer(example['text'], padding='max_length', truncation=True)


In [None]:
tokenized_datasets = ds.map(tokenize_function, batched=True)

In [None]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_val_dataset = tokenized_datasets["validation"].shuffle(seed=42).select(range(1000))
small_test_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [None]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="./test_trainer")

In [None]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


def compute_metrics(pred):

    #get global experiments
    experiment = comet_ml.get_global_experiment()

    #get y_true and y_preds for eval_dataset
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    #compute precision, recall, and F1 score
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro')

    #compute accuracy score
    acc = accuracy_score(labels, preds)

    #log confusion matrix
    if experiment:
        epoch = int(experiment.curr_epoch) if experiment.curr_epoch is not None else 0
        experiment.set_epoch(epoch)
        experiment.log_confusion_matrix(
            y_true=labels,
            y_predicted=preds,
            labels=["clickbait", "non-clickbait"]
        )

    return {"accuracy": acc,
            "f1": f1,
            "precision": precision,
            "recall": recall
            }
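
`average='macro'` computes each metric per class and then takes the unweighted mean, so both classes count equally even when the data is imbalanced. The pure-Python sketch below reproduces that arithmetic on a toy example; `macro_scores` is a hypothetical helper for illustration only and is not meant to replace the scikit-learn call.

```python
def macro_scores(labels, preds, classes=(0, 1)):
    """Compute macro-averaged precision, recall, and F1 from two label lists."""
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n

# Toy data: class 0 has precision 1.0 and recall 0.5; class 1 has precision 2/3 and recall 1.0.
macro_p, macro_r, macro_f1 = macro_scores([0, 0, 1, 1], [0, 1, 1, 1])
print(round(macro_p, 3), round(macro_r, 3), round(macro_f1, 3))  # 0.833 0.75 0.733
```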

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="./test_trainer", evaluation_strategy="epoch")

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_val_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [None]:
trainer.evaluate()

In [None]:
tokenizer.save_pretrained('./test_trainer')

In [None]:
# trainer.save_model('./test_trainer')
model.save_pretrained("clickbait-classifier-model-90")

### Load the fine-tuned model to test the accuracy on the test dataset

In [None]:
model = AutoModelForSequenceClassification.from_pretrained("clickbait-classifier-model-90")

In [None]:
tester = Trainer(
    model=model,
    eval_dataset=small_test_dataset,
    compute_metrics=compute_metrics,
)

In [None]:
tester.evaluate()

### Using "Pipeline" and "text-classification" to test on our own data

In [None]:
from transformers import pipeline


In [None]:
cls = pipeline("text-classification", model="clickbait-classifier-model-90", tokenizer=tokenizer)

In [None]:
cls("Doctors are stunned by this one weird trick to lose weight!")
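
The model was saved without human-readable label names, so the pipeline is expected to return generic labels such as `LABEL_0` and `LABEL_1`. The sketch below maps them to names. Which id corresponds to clickbait depends on the dataset's label encoding, which is not stated here, so the mapping simply follows the order passed to `log_confusion_matrix` earlier and should be verified against `ds['train'].features`. `LABEL_NAMES` and `readable` are hypothetical, and the sample output is made up.

```python
# Assumed mapping; check the dataset's label encoding before relying on it.
LABEL_NAMES = {"LABEL_0": "clickbait", "LABEL_1": "non-clickbait"}

def readable(predictions, mapping=LABEL_NAMES):
    """Replace generic pipeline labels with readable names, keeping the scores."""
    return [
        {"label": mapping.get(p["label"], p["label"]), "score": round(p["score"], 3)}
        for p in predictions
    ]

# Made-up output in the list-of-dicts shape returned by a text-classification pipeline.
sample_output = [{"label": "LABEL_0", "score": 0.9876}]
print(readable(sample_output))
```

In the notebook this would be called as `readable(cls("..."))`.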

### Deploy to Comet

In [None]:
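# create a Comet experiment to log and register the model
import os
from comet_ml import Experiment

# read the key from the environment instead of hard-coding it
COMET_API_KEY = os.getenv("COMET_API_KEY")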

experiment = Experiment(api_key=COMET_API_KEY)
experiment.log_model("clickbait-classifier-model-90", "/content/clickbait-classifier-model-90")
experiment.register_model("clickbait-classifier-model-90")

In [None]:
experiment.end()