<a target="_blank" href="https://colab.research.google.com/github/cohere-ai/notebooks/blob/main/notebooks/Fine_tuning_Models_for_the_Chat_Endpoint.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Fine-tuning Models for Cohere Chat

In this chapter, you will fine-tune a chatbot on custom conversational data to improve its performance at a specific task.  

If you want to change a chatbot's tone of voice or output format, you often need to train it further on additional data to get the best performance. This extra round of training is called fine-tuning.

We'll do the following steps:
- Step 1: Prepare and Validate the Dataset (`co.create_dataset`)
- Step 2: Fine-Tune the Model (`co.create_custom_model`)
- Step 3: Use the Fine-Tuned Model (`co.chat`)

_Read the [accompanying blog post here](https://txt.cohere.com/chat-finetuning-guide/)._

In [None]:
! pip install cohere jsonlines -q


In [None]:
import os
import json
import jsonlines
import cohere

# instantiate the Cohere client (replace 'COHERE_API_KEY' with your own API key)
co = cohere.Client('COHERE_API_KEY')

In [None]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

## Step 1: Prepare and Validate the Dataset


### Download the dataset

This notebook uses Grammarly's CoEdIT text-editing dataset, available [on Hugging Face](https://huggingface.co/datasets/grammarly/coedit).

In [None]:
# Download the dataset
! wget "https://huggingface.co/datasets/grammarly/coedit/resolve/main/train.jsonl"


### Get a subset of the dataset

In [None]:
# use the subset of the dataset focused on making text more coherent
phrase = "coherent"

# list to hold the matching examples
dataset_list = []

# keep only examples whose instruction prefix (the text before the first ':')
# mentions the phrase, e.g. "Make the text coherent: ..."
with jsonlines.open('train.jsonl') as f:
    for line in f.iter():
        if phrase in line['src'].split(":")[0]:
            dataset_list.append(line)

# split the data into training and test sets (first 800 examples for training)
dataset_list_train = dataset_list[:800]
dataset_list_test = dataset_list[800:]

print("Total number of examples:", len(dataset_list))
print("Number of examples in training set:", len(dataset_list_train))
print("Number of examples in the test set:", len(dataset_list_test))

Total number of examples: 927
Number of examples in training set: 800
Number of examples in the test set: 127
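The cell above filters on the instruction prefix and then takes the first 800 examples for training. If the source file happens to be ordered (for example, grouped by origin), a seeded shuffle before splitting may give a more representative test set. The sketch below factors both steps into small helpers; `select_by_instruction` and `split_examples` are hypothetical names, not part of any library, and the toy records are invented for illustration.

```python
import random

def select_by_instruction(records, phrase):
    """Keep records whose instruction prefix (text before the first ':') contains `phrase`."""
    selected = []
    for record in records:
        instruction, _, _ = record["src"].partition(":")
        if phrase in instruction:
            selected.append(record)
    return selected

def split_examples(records, n_train, seed=0):
    """Shuffle a copy of `records` with a fixed seed, then split into train/test lists."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:]

# Toy records in the same shape as the CoEdIT rows (content is made up).
toy = [
    {"src": "Make the text coherent: A. B.", "tgt": "A, and B."},
    {"src": "Fix grammar: he go home.", "tgt": "He goes home."},
    {"src": "Make the text more coherent: C. D.", "tgt": "C because D."},
]
coherent = select_by_instruction(toy, "coherent")
train, test = split_examples(coherent, n_train=1, seed=42)
print(len(coherent), len(train), len(test))  # 2 1 1
```

Because the seed is fixed, re-running the notebook produces the same split, which keeps comparisons between fine-tuning runs fair.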


### Preview the dataset

In [None]:
# print the first ten prompts and corresponding responses
for item in dataset_list_train[:10]:
    print(item["src"])
    print(item["tgt"])
    print("-"*50)

Make the text coherent: The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Bank, by developing and financing Bank's portfolio of and strengthening the bank's funding base.
The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Union Bank, by developing and financing its portfolio of and strengthening the bank's funding base.
--------------------------------------------------
Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.
It was not illegal under international law ; captured foreign sailors were released, while Confederates went to prison camps.
--------------------------------------------------
Make the text coherent: The Union blockade was a powerful weapon that eventually ruined the Southern econom

### Prepare the dataset for Cohere's Chat endpoint

In [None]:
# arranges the data to suit Cohere's format
def create_chat_ft_data(preamble, user_message, chatbot_message):
    formatted_data = {
        "messages": [
            {
                "role": "System",
                "content": preamble
            },
            {
                "role": "User",
                "content": user_message
            },
            {
                "role": "Chatbot",
                "content": chatbot_message
            }
        ]
    }

    return formatted_data

preamble = "You are a writing assistant that helps the user write coherent text."

# writes a list of examples to a JSONL file in the chat fine-tuning format
# (note: an existing file is left untouched, so delete it if you change the preamble)
def create_jsonl_from_list(file_name, dataset_segment, preamble):
    path = f'{file_name}.jsonl'
    if not os.path.isfile(path):
        with open(path, 'w') as file:
            for item in dataset_segment:
                user_message = item["src"]
                chatbot_message = item["tgt"]
                formatted_data = create_chat_ft_data(preamble, user_message, chatbot_message)
                file.write(json.dumps(formatted_data) + '\n')

# Create training jsonl file
file_name = "coedit_coherence_train"
create_jsonl_from_list(file_name, dataset_list_train, preamble)

# Print the first 3 items in the JSONL file
with jsonlines.open(f'{file_name}.jsonl') as f:
    for _, line in zip(range(3), f):
        print(line)

{'messages': [{'role': 'System', 'content': 'You are a writing assistant that helps the user write coherent text.'}, {'role': 'User', 'content': "Make the text coherent: The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Bank, by developing and financing Bank's portfolio of and strengthening the bank's funding base."}, {'role': 'Chatbot', 'content': "The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Union Bank, by developing and financing its portfolio of and strengthening the bank's funding base."}]}
{'messages': [{'role': 'System', 'content': 'You are a writing assistant that helps the user write coherent text.'}, {'role': 'User', 'content': 'Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.'}, {
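The upload step below validates the file on Cohere's side, but a quick local check can catch malformed lines before uploading. A minimal sketch; `check_chat_ft_line` and `check_chat_ft_file` are hypothetical helpers, and the rules they enforce (only the three roles used in `create_chat_ft_data`, non-empty content, a System message only at the start, at least one User and one Chatbot turn) are assumptions, not Cohere's documented validation rules.

```python
import json

ALLOWED_ROLES = {"System", "User", "Chatbot"}  # roles used by create_chat_ft_data

def check_chat_ft_line(line):
    """Return a list of problems found in one JSONL line (an empty list means it looks OK).

    These checks are our own assumptions about a well-formed example,
    not Cohere's official validation rules.
    """
    problems = []
    try:
        record = json.loads(line)
    except json.JSONDecodeError as err:
        return [f"invalid JSON: {err}"]
    messages = record.get("messages") if isinstance(record, dict) else None
    if not isinstance(messages, list) or not messages:
        return ["missing or empty 'messages' list"]
    for i, msg in enumerate(messages):
        if not isinstance(msg, dict):
            problems.append(f"message {i} is not an object")
            continue
        if msg.get("role") not in ALLOWED_ROLES:
            problems.append(f"message {i} has unexpected role {msg.get('role')!r}")
        content = msg.get("content")
        if not isinstance(content, str) or not content.strip():
            problems.append(f"message {i} has empty content")
    roles = [m.get("role") for m in messages if isinstance(m, dict)]
    if "System" in roles[1:]:
        problems.append("System message appears after the first turn")
    if "User" not in roles or "Chatbot" not in roles:
        problems.append("needs at least one User and one Chatbot message")
    return problems

def check_chat_ft_file(path):
    """Map 1-based line numbers to their problems for every line that fails a check."""
    report = {}
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if line.strip():
                problems = check_chat_ft_line(line)
                if problems:
                    report[lineno] = problems
    return report

# Usage, after the training file has been written (an empty dict means no problems):
# print(check_chat_ft_file("coedit_coherence_train.jsonl"))
```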

### Create a Dataset object

In [None]:
# create a Dataset object by uploading the training file
with open("coedit_coherence_train.jsonl", "rb") as data_file:
    dataset = co.create_dataset(name="coedit_coherence",
                                data=data_file,
                                dataset_type="chat-finetune-input")

# check the validation status of the dataset
print(dataset.await_validation())

uploading file, starting validation...
coedit-coherence-pm2ft1 was uploaded
...


cohere.Dataset {
	id: coedit-coherence-pm2ft1
	name: coedit_coherence
	dataset_type: chat-finetune-input
	validation_status: validated
	created_at: 2024-02-26 03:43:42.900073
	updated_at: 2024-02-26 03:43:42.900074
	schema: {"name":"cohere.chat_finetune_input","type":"record","fields":[{"name":"messages","type":{"type":"array","items":{"name":"cohere.message","type":"record","fields":[{"name":"role","type":"string"},{"name":"content","type":"string"}]}}},{"name":"is_eval","type":"boolean","default":false}]}
	download_urls: ['https://storage.googleapis.com/cohere-production-user-datasets/dataset-api-temp/d489c39a-e152-49da-9ddc-9801bd74d823/96d12a16-2dd4-46f7-9630-1fa9bb0b26ca/coedit-coherence-pm2ft1/000_coedit_coherence_train.avro?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=dataset%40cohere-production.iam.gserviceaccount.com%2F20240226%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240226T034353Z&X-Goog-Expires=28799&X-Goog-Signature=1544b0b76375589ee44b155ee361db9cf1f428368a60

## Step 2: Fine-Tune the Model

### Set hyperparameters and kick off fine-tuning
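Two of the hyperparameters set in the next cell, `early_stopping_patience` and `early_stopping_threshold`, control early stopping. Cohere's exact rule is not described in this guide, so the following is only a generic illustration of the common patience-based scheme: training stops once the evaluation loss has failed to improve by at least the threshold for `patience` consecutive evaluations. Whether Cohere treats the threshold as an absolute or a relative change is an assumption; this sketch uses an absolute decrease, and `should_stop_early` is a hypothetical helper.

```python
def should_stop_early(eval_losses, patience=6, threshold=0.01):
    """Generic patience-based early stopping (illustrative; not Cohere's implementation).

    Returns True once `patience` consecutive evaluations have failed to improve on the
    best loss seen so far by at least `threshold` (treated here as an absolute decrease).
    """
    best = float("inf")
    evals_without_improvement = 0
    for loss in eval_losses:
        if best - loss >= threshold:
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return True
    return False

print(should_stop_early([1.0, 0.8, 0.6, 0.5]))                    # False: still improving
print(should_stop_early([1.0, 0.995, 0.994, 0.993], patience=3))  # True: three tiny gains in a row
```

A larger patience tolerates noisy evaluation losses at the cost of extra training steps; a larger threshold stops sooner by ignoring small improvements.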

In [None]:
from cohere.responses.custom_model import HyperParametersInput

# define custom hyperparameters (optional)
hp = HyperParametersInput(
    early_stopping_patience=6,      # default: 6
    early_stopping_threshold=0.01,  # default: 0.01
    train_batch_size=16,            # default: 16
    train_epochs=1,                 # default: 1
    learning_rate=0.01,             # default: 0.01
)

# start fine-tuning using the dataset
co.create_custom_model(
    name="coedit-coherence-test-5",
    dataset=dataset,
    model_type="CHAT",
    hyperparameters=hp,
)

cohere.CustomModel {
	id: 3154a093-cdf2-4e52-8e35-3c3d18468e86
	name: coedit-coherence-test-5
	status: QUEUED
	model_type: CHAT
	created_at: 2024-02-26 03:49:45.452460+00:00
	completed_at: None
	base_model: medium
	model_id: 3154a093-cdf2-4e52-8e35-3c3d18468e86-ft
	hyperparameters: HyperParameters(early_stopping_patience=6, early_stopping_threshold=0.01, train_batch_size=16, train_steps=None, train_epochs=1, learning_rate=0.01)
	dataset_id: coedit-coherence-pm2ft1
	billing: FinetuneBilling(train_epochs=1, num_training_tokens=79729, unit_price=1e-06, total_cost=0.079729)
}
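The job is created in the `QUEUED` state, and training can take a while. Rather than re-running the status cell by hand, you could poll until the model reaches a terminal state. A minimal sketch: `wait_for_model` is a hypothetical helper, the default poll interval and timeout are arbitrary, and any status name other than `QUEUED` and `READY` (the two that appear in the outputs here) is an assumption.

```python
import time

def wait_for_model(fetch, poll_seconds=60, timeout_seconds=3 * 60 * 60,
                   terminal_statuses=("READY", "FAILED")):
    """Call `fetch()` until the returned object's `.status` is terminal or time runs out.

    `fetch` is any zero-argument callable returning an object with a `status`
    attribute. "FAILED" as a terminal status name is an assumption.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        model = fetch()
        if model.status in terminal_statuses:
            return model
        if time.monotonic() >= deadline:
            raise TimeoutError(f"model still {model.status!r} after {timeout_seconds} s")
        time.sleep(poll_seconds)

# Usage with the client from this notebook:
# ft = wait_for_model(lambda: co.get_custom_model_by_name("coedit-coherence-test-5"))
# print(ft.status)
```

Passing a callable instead of the client keeps the helper independent of the SDK version and easy to test with a stand-in object.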

### View model status

In [None]:
# get the CustomModel object
ft = co.get_custom_model_by_name('coedit-coherence-test-5')
# print the status
print(ft)

cohere.CustomModel {
	id: 3154a093-cdf2-4e52-8e35-3c3d18468e86
	name: coedit-coherence-test-5
	status: READY
	model_type: CHAT
	created_at: 2024-02-26 03:49:45.452460+00:00
	completed_at: None
	base_model: medium
	model_id: 3154a093-cdf2-4e52-8e35-3c3d18468e86-ft
	hyperparameters: HyperParameters(early_stopping_patience=6, early_stopping_threshold=0.01, train_batch_size=16, train_steps=50, train_epochs=1, learning_rate=0.01)
	dataset_id: coedit-coherence-pm2ft1
	billing: FinetuneBilling(train_epochs=1, num_training_tokens=79729, unit_price=1e-06, total_cost=0.079729)
}
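The numbers in this output can be sanity-checked by hand. Assuming one optimizer step per batch, 800 training examples with `train_batch_size=16` for one epoch give

$$
N_{\text{steps}} = \left\lceil \frac{800}{16} \right\rceil \times 1 = 50,
$$

which matches `train_steps=50` above. The billing line is consistent with cost being training tokens times unit price (assuming `unit_price` is per token and the figure covers the single epoch trained here):

$$
\text{cost} = 79{,}729 \times 10^{-6} = 0.079729
$$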


## Step 3: Use/Evaluate the Fine-Tuned Model



### With Test Data

In [None]:
for item in dataset_list_test[:3]:
    # User prompt
    user_message = item["src"]
    # Desired/target response from dataset
    tgt_message = item["tgt"]

    # Get default model response
    response_pretrained = co.chat(
        message=user_message,
        preamble_override=preamble,
    )

    # Get fine-tuned model response
    response_finetuned = co.chat(
        message=user_message,
        model=ft.model_id,
        preamble_override=preamble,
    )

    print(f"User: {user_message}", "\n-----")
    print(f"Desired response: {tgt_message}", "\n-----")
    print(f"Default model's response: {response_pretrained.text}", "\n-----")
    print(f"Fine-tuned model's response: {response_finetuned.text}")
    print("-" * 100, "\n\n")

User: Make the text more coherent: We do know that at the end of the Muromachi period it stopped appearing in written records. That Muromachi burned down many times, the last we know of in 1405. 
-----
Desired response: We do know that at the end of the Muromachi period it stopped appearing in written records and that it burned down many times, the last we know of in 1405. 
-----
Default model's response: Sure! Here is your text rewritten for better coherence:

The end of the Muromachi period saw the disappearance of written records about it. The period witnessed several occurrences of fire, with the last one taking place in 1405. 

Let me know if it makes more sense now! 
-----
Fine-tuned model's response: We do know that at the end of the Muromachi period it stopped appearing in written records because it burned down many times, the last we know of in 1405.
---------------------------------------------------------------------------------------------------- 


User: Make the text cohe
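Eyeballing three examples gives a first impression; comparing the two models over the whole test set needs an automatic score. As a simple stand-in (not a metric this guide uses), the sketch below uses the ratio from Python's `difflib.SequenceMatcher` as a rough character-level similarity between each response and its target. It rewards responses that stay close to the desired edit and penalizes chatty wrappers like "Sure! Here is your text...", but it is only a proxy for coherence. `similarity` and `mean_similarity` are hypothetical helper names, and the example sentence is invented.

```python
from difflib import SequenceMatcher

def similarity(response, target):
    """Character-level similarity in [0, 1] between a model response and the target text."""
    return SequenceMatcher(None, response.strip(), target.strip()).ratio()

def mean_similarity(pairs):
    """Average similarity over (response, target) pairs; returns 0.0 for an empty list."""
    scores = [similarity(r, t) for r, t in pairs]
    return sum(scores) / len(scores) if scores else 0.0

target = "She left Benaras because conditions back home were bad."
print(round(similarity(target, target), 2))                        # 1.0
print(similarity("Sure! Here is your text: " + target, target) < 1.0)  # True

# Usage over the test set (makes one API call per example):
# pairs = [(co.chat(message=item["src"], model=ft.model_id,
#                   preamble_override=preamble).text, item["tgt"])
#          for item in dataset_list_test]
# print(mean_similarity(pairs))
```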

### In the Chat Context

In [None]:
import uuid

# Create a conversation ID
conversation_id = str(uuid.uuid4())

print('Starting the chat. Type "quit" to end.\n')

while True:

    # User message
    message = input("User: ")

    # Typing "quit" ends the conversation
    if message.lower() == 'quit':
        print("Ending chat.")
        break

    # Chatbot response
    response = co.chat(message=message,
                       model=ft.model_id,
                       stream=True,
                       conversation_id=conversation_id,
                       return_chat_history=True)

    print("Chatbot: ", end='')

    for event in response:
        if event.event_type == "text-generation":
            print(event.text, end='')

    print(f"\n{'-'*50}\n")

Starting the chat. Type "quit" to end.

User: Hello
Chatbot: Hello.
--------------------------------------------------

User: Make this more coherent: Manuel now has to decide-will he let his best friend be happy with her Prince Charming. Or will he fight for the love that has kept him alive for the last 16 years?
Chatbot: Manuel now has to decide whether he let his best friend be happy with her Prince Charming or he fight for the love that has kept him alive for the last 16 years.
--------------------------------------------------

User: Help me with this one - She left Benaras. Conditions back home were bad.
Chatbot: She left Benaras because it was bad.
--------------------------------------------------

User: What's a good time to visit London
Chatbot: A good time to visit London is during the spring season, when the weather is mild and the days are long and light.
--------------------------------------------------

User: Could you help with this please: Make the text coherent: Crit
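The chat cell above only prints streamed text as it arrives. If you also want the complete reply as a string, for example to log the conversation or score it against a target, the same `event_type == "text-generation"` filter can accumulate the text. A minimal sketch; `collect_stream_text` is a hypothetical helper, and it assumes each event exposes `event_type` and (for text events) `text`, as the loop in the chat cell does.

```python
def collect_stream_text(events):
    """Concatenate the text of "text-generation" events from a streamed chat response.

    Mirrors the loop in the chat cell, but returns the full reply as a string
    instead of only printing it.
    """
    parts = []
    for event in events:
        if getattr(event, "event_type", None) == "text-generation":
            parts.append(event.text)
    return "".join(parts)

# Usage with the fine-tuned model from this notebook:
# reply = collect_stream_text(co.chat(message="Hello", model=ft.model_id, stream=True))
# print(reply)
```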