<a target="_blank" href="https://colab.research.google.com/github/cohere-ai/notebooks/blob/main/notebooks/llmu/Fine_Tuning_for_Chat.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Fine-Tuning for Chat

Our ready-to-use large language models, such as [Command](https://cohere.com/models/command), are very good at producing responses to natural language prompts. However, there are many cases in which getting the best model performance requires performing an additional round of training on custom user data. Creating a custom model using this process is called **fine-tuning**.

Fine-tuning is recommended when you want to teach the model a new task, or leverage your company's unique knowledge base. Fine-tuning models is also helpful for generating a specific writing style or format, or leveraging a new data type.

In this notebook, you will fine-tune a chatbot on custom conversational data to improve its performance at a specific task.

_Read the [accompanying blog post here](https://docs.cohere.com/docs/fine-tuning-for-chat)._

## Overview

We'll do the following steps:
- **Step 1: Prepare the Dataset** - Download the dataset, select a subset, and prepare it for the Chat endpoint.
- **Step 2: Fine-Tune the Model** - Kick off a fine-tuning job, and confirm when the model has completed training.
- **Step 3: Use/Evaluate the Fine-Tuned Model** - Evaluate the fine-tuned model's performance on the test dataset, and confirm it is a competent participant in multi-turn conversations.

## Setup

We'll start by installing the tools we'll need and then importing them.

In [1]:
! pip install cohere jsonlines -q

Fill in your Cohere API key in the next cell. To do this, begin by [signing up for Cohere](https://os.cohere.ai/) (for free!) if you haven't yet. Then get your API key [here](https://dashboard.cohere.com/api-keys).

In [2]:
import cohere
import os
import json
import jsonlines

co = cohere.ClientV2("COHERE_API_KEY") # Get your free API key: https://dashboard.cohere.com/api-keys

## Step 1: Prepare and Validate the Dataset


### Download the dataset

We will work with the [CoEdIT dataset](https://huggingface.co/datasets/grammarly/coedit) of text editing examples (Raheja et al.). In each example, the user asks a writing assistant to rewrite text to suit a specific task (editing fluency, coherence, clarity, or style) and receives a response.

In [4]:
# Download the training split of the CoEdIT dataset as a JSONL file
! wget "https://huggingface.co/datasets/grammarly/coedit/resolve/main/train.jsonl"


### Get a subset of the dataset

Instead of using the full dataset, we will use the subset of examples whose instruction asks for coherent text: 927 conversations in total, which we split into 800 training examples and 127 test examples.

In [8]:
# Use the subset of the dataset focused on making text more coherent
phrase = "coherent"

# List that will hold the selected examples
dataset_list = []

# Keep an example if the phrase appears in its instruction (the text before the first colon)
with jsonlines.open('train.jsonl') as f:
    for line in f.iter():
        if phrase in line['src'].split(":")[0]:
            dataset_list.append(line)

# Split the data into training and test sets (first 800 examples for training, the rest for testing)
dataset_list_train = dataset_list[:800]
dataset_list_test = dataset_list[800:]

print("Total number of examples:", len(dataset_list))
print("Number of examples in training set:", len(dataset_list_train))
print("Number of examples in the test set:", len(dataset_list_test))

Total number of examples: 927
Number of examples in training set: 800
Number of examples in the test set: 127
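The filter above keeps a record when the word `coherent` appears in its instruction, i.e. the part of `src` before the first colon. As a small sketch of that behavior, here are two `src` strings copied from this notebook plus one hypothetical record from a different task (made up for contrast, not taken from the dataset). The filter matches both "Make the text coherent" and "Make the text more coherent", and ignores the word when it only appears in the text being edited:

```python
def instruction_of(src: str) -> str:
    """Return the instruction part of a CoEdIT `src` string: the text before the first colon."""
    return src.split(":")[0]


def is_coherence_example(src: str, phrase: str = "coherent") -> bool:
    """Same test as the filtering loop: does the phrase appear in the instruction?"""
    return phrase in instruction_of(src)


samples = [
    "Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.",
    "Make the text more coherent: We do know that at the end of the Muromachi period it stopped appearing in written records. That Muromachi burned down many times, the last we know of in 1405.",
    # Hypothetical record from another task: "coherent" appears only in the text, not the instruction
    "Fix grammar: The text should be coherent, but it are not.",
]

for src in samples:
    print(is_coherence_example(src), "|", instruction_of(src))
```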



### Preview the dataset

We will use the `src` and `tgt` fields from each example, which correspond to the user’s prompt and the writing assistant’s response, respectively.

In [12]:
# print the first ten prompts and corresponding responses
for item in dataset_list_train[:10]:
    print(item["src"])
    print(item["tgt"])
    print("-"*50)

Make the text coherent: The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Bank, by developing and financing Bank's portfolio of and strengthening the bank's funding base.
The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Union Bank, by developing and financing its portfolio of and strengthening the bank's funding base.
--------------------------------------------------
Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.
It was not illegal under international law ; captured foreign sailors were released, while Confederates went to prison camps.
--------------------------------------------------
Make the text coherent: The Union blockade was a powerful weapon that eventually ruined the Southern econom
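To see what a coherence edit actually changes, it can help to diff the text part of `src` against `tgt`. A minimal sketch using Python's standard `difflib` on the second example printed above; the helper `word_edits` is hypothetical, written only for this illustration:

```python
from difflib import SequenceMatcher


def word_edits(src: str, tgt: str):
    """Return (operation, old words, new words) edits that turn the text in `src` into `tgt`."""
    # Drop the instruction prefix (e.g. "Make the text coherent: ") before comparing
    _, _, text = src.partition(": ")
    old_words, new_words = text.split(), tgt.split()
    edits = []
    for op, i1, i2, j1, j2 in SequenceMatcher(a=old_words, b=new_words).get_opcodes():
        if op != "equal":
            edits.append((op, " ".join(old_words[i1:i2]), " ".join(new_words[j1:j2])))
    return edits


example_src = ("Make the text coherent: It was not illegal under international law ; "
               "captured foreign sailors were released. Confederates went to prison camps.")
example_tgt = ("It was not illegal under international law ; captured foreign sailors "
               "were released, while Confederates went to prison camps.")

print(word_edits(example_src, example_tgt))
```

For this pair the only change is joining the two sentences with "while", which is typical of the small, targeted edits the fine-tuned model should learn to make.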

### Prepare the dataset for Cohere's Chat endpoint

To format the dataset for the Chat endpoint, we create a `.jsonl` file in which each line is a JSON object representing one conversation, made up of a series of messages:
- A `System` message at the beginning that guides the whole conversation
- One or more pairs of `User` and `Chatbot` messages, representing the conversation that takes place between a human user and a chatbot
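Each example in this notebook has a single `User`/`Chatbot` pair, but the same structure extends to longer conversations. As a sketch, a record with several turns can be built from a list of `(user, chatbot)` pairs, keeping the `System` message first. The helper name `create_multiturn_ft_data` is hypothetical (not part of the Cohere SDK), and the greeting turn below is made up for illustration; the second turn is copied from the dataset preview:

```python
import json


def create_multiturn_ft_data(system_message, turns):
    """Build one fine-tuning record from a system message and a list of (user, chatbot) pairs."""
    messages = [{"role": "System", "content": system_message}]
    for user_message, chatbot_message in turns:
        messages.append({"role": "User", "content": user_message})
        messages.append({"role": "Chatbot", "content": chatbot_message})
    return {"messages": messages}


record = create_multiturn_ft_data(
    "You are a writing assistant that helps the user write coherent text.",
    [
        # Hypothetical greeting turn
        ("Hello", "Hi! How can I help with your writing today?"),
        # Turn taken from the CoEdIT preview above
        ("Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.",
         "It was not illegal under international law ; captured foreign sailors were released, while Confederates went to prison camps."),
    ],
)

# In the .jsonl file, each record is written as one JSON object per line
print(json.dumps(record))
```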

In [14]:
# arranges the data to suit Cohere's format
def create_chat_ft_data(system_message, user_message, chatbot_message):
    formatted_data = {
        "messages": [
            {
                "role": "System",
                "content": system_message
            },
            {
                "role": "User",
                "content": user_message
            },
            {
                "role": "Chatbot",
                "content": chatbot_message
            }
        ]
    }

    return formatted_data

system_message = "You are a writing assistant that helps the user write coherent text."

# Write a list of CoEdIT examples to a JSONL file, one conversation per line.
# Note: if the file already exists, it is left unchanged (delete it to regenerate).
def create_jsonl_from_list(file_name, dataset_segment, system_message):
    path = f'{file_name}.jsonl'
    if not os.path.isfile(path):
        with open(path, 'w') as file:
            for item in dataset_segment:
                user_message = item["src"]
                chatbot_message = item["tgt"]
                formatted_data = create_chat_ft_data(system_message, user_message, chatbot_message)
                file.write(json.dumps(formatted_data) + '\n')

# Create training jsonl file
file_name = "coedit_coherence_train"
create_jsonl_from_list(file_name, dataset_list_train, system_message)

# Print the first 3 items in the JSONL file
with jsonlines.open(f'{file_name}.jsonl') as f:
    for _, line in zip(range(3), f):
        print(line)

{'messages': [{'role': 'System', 'content': 'You are a writing assistant that helps the user write coherent text.'}, {'role': 'User', 'content': "Make the text coherent: The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Bank, by developing and financing Bank's portfolio of and strengthening the bank's funding base."}, {'role': 'Chatbot', 'content': "The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Union Bank, by developing and financing its portfolio of and strengthening the bank's funding base."}]}
{'messages': [{'role': 'System', 'content': 'You are a writing assistant that helps the user write coherent text.'}, {'role': 'User', 'content': 'Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.'}, {
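The Dashboard checks the file when you upload it, but a quick local check can catch problems before you get that far. This is a sketch built on our own assumptions about what "well-formed" means here (valid JSON, a non-empty `messages` list, only the `System`/`User`/`Chatbot` roles, `System` only in first position, non-empty content, at least one `Chatbot` message); it is not Cohere's validation logic. The `good` record reuses a pair from the dataset, and `bad` is a deliberately broken, made-up record:

```python
import json

# Roles used in the training file built above (assumption: the only ones we expect)
VALID_ROLES = {"System", "User", "Chatbot"}


def validate_chat_jsonl(lines):
    """Return (line_number, problem) pairs for lines of a chat fine-tuning JSONL file.

    These are local sanity checks chosen for this sketch, not Cohere's own validation rules.
    """
    problems = []
    for n, line in enumerate(lines, start=1):
        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            problems.append((n, f"invalid JSON: {err.msg}"))
            continue
        messages = record.get("messages") if isinstance(record, dict) else None
        if not isinstance(messages, list) or not messages:
            problems.append((n, "missing or empty 'messages' list"))
            continue
        for i, message in enumerate(messages):
            if not isinstance(message, dict):
                problems.append((n, f"message {i}: not a JSON object"))
                continue
            role, content = message.get("role"), message.get("content")
            if role not in VALID_ROLES:
                problems.append((n, f"message {i}: unknown role {role!r}"))
            elif role == "System" and i != 0:
                problems.append((n, f"message {i}: System message is not first"))
            if not isinstance(content, str) or not content.strip():
                problems.append((n, f"message {i}: empty content"))
        if not any(isinstance(m, dict) and m.get("role") == "Chatbot" for m in messages):
            problems.append((n, "no Chatbot message"))
    return problems


good = json.dumps({"messages": [
    {"role": "System", "content": "You are a writing assistant that helps the user write coherent text."},
    {"role": "User", "content": "Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps."},
    {"role": "Chatbot", "content": "It was not illegal under international law ; captured foreign sailors were released, while Confederates went to prison camps."},
]})
# Made-up broken record: wrong role name, empty reply, no Chatbot turn
bad = json.dumps({"messages": [
    {"role": "User", "content": "Make the text coherent: ..."},
    {"role": "Assistant", "content": ""},
]})

print(validate_chat_jsonl([good, bad]))
```

To check the training file itself, pass in the open file, e.g. `with open(f'{file_name}.jsonl') as f: print(validate_chat_jsonl(f))`. An empty list means none of these particular checks failed.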

## Step 2: Fine-Tune the Model

We kick off a fine-tuning job by navigating to the [fine-tuning tab of the Dashboard](https://dashboard.cohere.com/fine-tuning). Under "Chat", click on "Create a Chat model".

<img src="https://files.readme.io/48dad78-cohere_dashboard.png">

Next, upload the `.jsonl` file you just created as the training set by clicking on the "TRAINING SET" button. When ready, click on "Review data" to proceed to the next step.

<img src="https://files.readme.io/82e3691-image_2.png">

Then, you'll see a preview of how the model will ingest your data. If anything is wrong with the data, the page will also provide suggested changes to fix the training file. Otherwise, if everything looks good, you can proceed to the next step.

<img src="https://files.readme.io/fbce852-image_3.png">

Finally, you'll see an estimated cost of fine-tuning, followed by a page where you'll provide a nickname for your model. We used `coedit-coherence-ft` as the nickname for our model. This page also allows you to provide custom values for the hyperparameters used during training, but we'll keep them at the default values for now.

<img src="https://files.readme.io/801e93a-name_model.png">

Once you have filled in a name, click on "Start training" to kick off the fine-tuning process. This will navigate you to a page where you can monitor the status of the model. A model that has finished fine-tuning will show the status as `READY`.

<img src="https://files.readme.io/dd0d48b-ready_model.png">

## Step 3: Use/Evaluate the Fine-Tuned Model



Once the model has completed the fine-tuning process, it’s time to evaluate its performance. 


### With Test Data

When you're ready to use the fine-tuned model, navigate to the API tab. There, you'll see the model ID that you should use when calling `co.chat()`.

<img src="https://files.readme.io/82c726e-get_model_id.png">

In the following code, we supply the first example from the test dataset to both the pre-trained and fine-tuned models for comparison.

In [46]:
for item in dataset_list_test[:1]:
    # User prompt
    user_message = item["src"]
    # Desired/target response from dataset
    tgt_message = item["tgt"]

    # Send the same system message and prompt to both models
    messages = [{"role": "system", "content": system_message},
                {"role": "user", "content": user_message}]

    # Get default model response
    response_pretrained = co.chat(model="command-r-plus", messages=messages)

    # Get fine-tuned model response (replace with your own fine-tuned model ID from the API tab)
    response_finetuned = co.chat(model="4708865e-3870-42bf-99fa-ffe84e81fd5f-ft", messages=messages)

    print(f"User: {user_message}", "\n-----")
    print(f"Desired response: {tgt_message}", "\n-----")
    print(f"Default model's response: {response_pretrained.message.content[0].text}", "\n-----")
    print(f"Fine-tuned model's response: {response_finetuned.message.content[0].text}")

    print("-" * 100, "\n\n")

User: Make the text more coherent: We do know that at the end of the Muromachi period it stopped appearing in written records. That Muromachi burned down many times, the last we know of in 1405. 
-----
Desired response: We do know that at the end of the Muromachi period it stopped appearing in written records and that it burned down many times, the last we know of in 1405. 
-----
Default model's response: We do know that towards the end of the Muromachi period, it stopped appearing in written records. Muromachi burned down several times, the last major fire being in 1405. This could have contributed to the lack of written records and the subsequent mystery surrounding the topic. It is intriguing to speculate on the reasons for this disappearance and the potential impact on historical understanding. 
-----
Fine-tuned model's response: We do know that at the end of the Muromachi period it stopped appearing in written records, and that it burned down many times, the last we know of in 140

In this example, both models provide reasonable answers that improve on the user’s original text. However, the fine-tuned model’s response is more succinct and better matches the style of the fine-tuning data, while the default model adds two speculative sentences that the user did not ask for.
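Judging one example by eye is a good start; to compare models across the whole test set, you could compute a rough automatic score. A minimal sketch, assuming word-level `difflib` similarity to the target is an acceptable proxy (it is crude, not a standard text-editing metric), applied to the desired and default responses shown above. The fine-tuned response above is cut off, so it is left out here:

```python
from difflib import SequenceMatcher


def word_similarity(a: str, b: str) -> float:
    """Word-level similarity between two texts, from 0 (nothing shared) to 1 (identical)."""
    return SequenceMatcher(a=a.split(), b=b.split()).ratio()


def score_model(examples, generate):
    """Average similarity between generated responses and the dataset's target responses.

    `examples` are CoEdIT-style dicts with "src" and "tgt" keys; `generate` is any
    function that maps a prompt string to a response string.
    """
    scores = [word_similarity(generate(ex["src"]), ex["tgt"]) for ex in examples]
    return sum(scores) / len(scores)


# Desired and default-model responses from the example above
desired = ("We do know that at the end of the Muromachi period it stopped appearing in written "
           "records and that it burned down many times, the last we know of in 1405.")
default = ("We do know that towards the end of the Muromachi period, it stopped appearing in written "
           "records. Muromachi burned down several times, the last major fire being in 1405. This could "
           "have contributed to the lack of written records and the subsequent mystery surrounding the "
           "topic. It is intriguing to speculate on the reasons for this disappearance and the potential "
           "impact on historical understanding.")

print("Default model similarity to target:", round(word_similarity(default, desired), 2))
print("Word counts (default vs. desired):", len(default.split()), len(desired.split()))
```

To score a model on the test set, `generate` could wrap a `co.chat` call with the system and user messages and return `response.message.content[0].text`. Keep in mind that scoring all 127 test examples means 127 API calls per model.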

### In the Chat Context

The fine-tuned model can give good answers to individual prompts like the one above. Next, let's check whether it is also a competent participant in longer, multi-turn conversations.

In [47]:
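# Replace with the model ID of your own fine-tuned model (shown in the API tab)
model = "4708865e-3870-42bf-99fa-ffe84e81fd5f-ft"

def run_chat(user_message, messages=None):
    # Start a new conversation with the system message. Using None rather than a
    # mutable default ([]) keeps separate conversations from sharing one list.
    if messages is None:
        messages = [{"role": "system", "content": system_message}]

    # Add the user's turn and send the full history, so the model sees earlier turns
    messages.append({"role": "user", "content": user_message})
    response = co.chat(model=model, messages=messages)

    reply = response.message.content[0].text
    print(reply)

    # Append the chatbot's turn to the chat history
    messages.append({"role": "assistant", "content": reply})

    return messages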
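A multi-turn conversation only works if every call sends the earlier turns along with the new message. To check how the history grows without spending API calls, here is a sketch that swaps `co.chat` for a hypothetical stand-in (`fake_chat`, not part of any SDK) and uses the role/content dictionaries that the v2 Chat API accepts:

```python
def fake_chat(messages):
    """Hypothetical stand-in for co.chat: reports how many messages it received."""
    return f"(reply after seeing {len(messages)} messages)"


def run_chat_offline(user_message, messages=None,
                     system_message="You are a writing assistant that helps the user write coherent text."):
    # Start a new history with the system message on the first turn
    if messages is None:
        messages = [{"role": "system", "content": system_message}]
    messages.append({"role": "user", "content": user_message})
    # The whole history (system + all previous turns) is passed on every call
    reply = fake_chat(messages)
    messages.append({"role": "assistant", "content": reply})
    return messages


history = run_chat_offline("Hello")
history = run_chat_offline("I'm fine. Can I ask you for help with some tasks?", history)

for m in history:
    print(m["role"], "|", m["content"])
```

The stand-in sees 2 messages on the first turn and 4 on the second, which confirms that earlier turns are carried forward. Using `None` as the default argument (rather than `[]`) means each new conversation starts from a fresh list.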

In [1]:
messages = run_chat("Hello")

In [2]:
messages = run_chat("I'm fine. Can I ask you for help with some tasks?", messages)

In [3]:
messages = run_chat("Make this more coherent: Manuel now has to decide-will he let his best friend be happy with her Prince Charming. Or will he fight for the love that has kept him alive for the last 16 years?", messages)

In [37]:
messages = run_chat("Help me with this one - She left Benaras. Conditions back home were bad.", messages)

Here is a possible continuation: 

She left Benaras with a heavy heart. The conditions back home in her village were dire, with a severe drought having ravaged the land. Crops had failed, livestock had perished, and people were struggling to survive. She knew that returning home would be challenging, but she couldn't bear the thought of her family suffering while she remained in the relative comfort of the city.

As she boarded the train, she said a silent prayer for strength and resilience. The journey back was long and arduous, the parched landscape a stark reminder of the hardships that lay ahead. Upon arriving, she was greeted by the all-too-familiar sight of cracked earth and withered trees.

However, despite the bleak surroundings, her determination burned brightly. She rolled up her sleeves and set to work, helping her family however she could. They conserved water, shared what little food they had, and worked together to find creative solutions to their problems.

It was a diff

In [38]:
messages = run_chat("What's a good time to visit London", messages)

London is a great city to visit all year round! However, the best time to visit London depends on what you want to do and see during your trip. Here are some factors to consider when planning your visit:

1. Weather: London has a mild temperate climate, which means you can expect cool winters and mild summers. If you're looking for warmer weather, the best time to visit London is during the summer months (June, July, and August). Keep in mind that this is also the busiest tourist season, so you may experience higher accommodation prices and longer lines at popular attractions.

2. Tourist Season: If you want to avoid the peak tourist crowds, consider visiting London during the shoulder seasons (spring and autumn). The weather is still pleasant, and you'll find shorter lines and better accommodation deals. Spring (March to May) is particularly lovely as the city's parks and gardens come to life with blooming flowers.

3. Cultural Events: London has a packed calendar of cultural events a

In [39]:
messages = run_chat("Could you help with this please: Make the text coherent: Critically the album has not been as well received as other Browne recordings. It remains his only album to date to reach number 1 on the Billboard chart.", messages)

Critically, the album has not been as well-received as some of Browne's other recordings, despite it being his only album to reach number one on the Billboard chart.


Note that the fine-tuned model still responds to prompts like “Hello”, “I’m fine. Can I ask you for help with some tasks?”, and “What’s a good time to visit London” instead of strictly following the fine-tuning objective of editing text.

The model also handles context switching: it keeps up as the user moves from friendly greetings, to a request for writing help, to travel planning, and finally back to writing assistance. It also recognizes a coherence request that is buried in a longer message (e.g., “Could you help with this please: Make the text coherent: …”). Implicit requests are less reliable, though: for “Help me with this one”, the response shown above continues the story instead of making the two sentences coherent, so it can help to state the task explicitly.