# Fine-tuning Models for the Chat Endpoint

In this chapter, you will fine-tune a chat model on custom conversational data to improve its performance on a specific task: rewriting text to make it more coherent.

_Read the [accompanying blog post here]()._

In [None]:
! pip install cohere hnswlib unstructured jsonlines -q


In [None]:
import os
import json
import jsonlines
import cohere

# instantiate the Cohere client
co = cohere.Client("COHERE_API_KEY")
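
A key pasted directly into a notebook is easy to leak when the notebook is shared. The sketch below shows one alternative, assuming the key is stored in an environment variable named `COHERE_API_KEY`. The variable name and the `load_api_key` helper are illustrative and are not part of the Cohere SDK.

```python
import os

def load_api_key(env_var="COHERE_API_KEY"):
    """Return the API key stored in the given environment variable."""
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"Set the {env_var} environment variable first.")
    return key

# Usage (assumes the variable is set in the notebook's environment):
# co = cohere.Client(load_api_key())
```

Failing early with a clear message is usually easier to debug than an authentication error returned by the first API call.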

In [None]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

## Step 1: Prepare and Validate the Dataset



### Download the dataset

Find the [dataset here](https://huggingface.co/datasets/grammarly/coedit).

In [None]:
# Download the dataset
! wget "https://huggingface.co/datasets/grammarly/coedit/resolve/main/train.jsonl"


### Get a subset of the dataset

In [None]:
# Use the subset of the dataset focused on making text more coherent
phrase = "coherent"

# List that will hold the matching examples
dataset_list = []

# Keep only examples whose instruction (the text before the first colon) contains the phrase
with jsonlines.open('train.jsonl') as f:
    for line in f.iter():
        if phrase in line['src'].split(":")[0]:
            dataset_list.append(line)

# Split the data into training and test sets
dataset_list_train = dataset_list[:800]
dataset_list_test = dataset_list[800:]

print("Total number of examples:", len(dataset_list))
print("Number of examples in training set:", len(dataset_list_train))
print("Number of examples in the test set:", len(dataset_list_test))

Total number of examples: 927
Number of examples in training set: 800
Number of examples in the test set: 127
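
The cell above keeps a record when the phrase appears in the part of `src` before the first colon, then takes the first 800 matches for training. The sketch below reproduces that filter on a few made-up records and adds a seeded shuffle before the split. The shuffle is an optional variation, not what this notebook does, and the helper names are hypothetical.

```python
import random

def filter_by_instruction(records, phrase):
    """Keep records whose instruction (text before the first colon) contains phrase."""
    return [r for r in records if phrase in r["src"].split(":")[0]]

def shuffled_split(records, n_train, seed=0):
    """Shuffle a copy of records with a fixed seed, then split at n_train."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    return shuffled[:n_train], shuffled[n_train:]

# Made-up records in the same {"src", "tgt"} shape as train.jsonl
records = [
    {"src": "Make the text coherent: A. B.", "tgt": "A and B."},
    {"src": "Fix grammar: He go home.", "tgt": "He goes home."},
    {"src": "Make the text more coherent: C. D.", "tgt": "C, so D."},
]

subset = filter_by_instruction(records, "coherent")
train, test = shuffled_split(subset, n_train=1, seed=42)
print(len(subset), len(train), len(test))  # prints: 2 1 1
```

Note that if `src` contains no colon, `split(":")[0]` returns the whole string, so the phrase would then be matched against the full text rather than only the instruction.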


### Preview the dataset

In [None]:
# print the first ten prompts and corresponding responses
for item in dataset_list_train[:10]:
    print(item["src"])
    print(item["tgt"])
    print("-"*50)

Make the text coherent: The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Bank, by developing and financing Bank's portfolio of and strengthening the bank's funding base.
The Bank's main strategy is to further expand its network and increase its lending activities with particular focus on the SME sector. The EBRD helps Union Bank, by developing and financing its portfolio of and strengthening the bank's funding base.
--------------------------------------------------
Make the text coherent: It was not illegal under international law ; captured foreign sailors were released. Confederates went to prison camps.
It was not illegal under international law ; captured foreign sailors were released, while Confederates went to prison camps.
--------------------------------------------------
Make the text coherent: The Union blockade was a powerful weapon that eventually ruined the Southern econom

### Process the dataset for Cohere's Chat endpoint

In [None]:
# arranges the data to suit Cohere's format
def create_chat_ft_data(preamble, user_message, chatbot_message):
    formatted_data = {
        "messages": [
            {
                "role": "System",
                "content": preamble
            },
            {
                "role": "User",
                "content": user_message
            },
            {
                "role": "Chatbot",
                "content": chatbot_message
            }
        ]
    }

    return formatted_data

preamble = "You are a writing assistant that helps the user write coherent text."

# Creates a JSONL file from a list of examples (skipped if the file already exists)
def create_jsonl_from_list(file_name, dataset_segment, preamble):
    path = f'{file_name}.jsonl'
    if not os.path.isfile(path):
        # The with block closes the file automatically
        with open(path, 'w') as file:
            for item in dataset_segment:
                user_message = item["src"]
                chatbot_message = item["tgt"]
                formatted_data = create_chat_ft_data(preamble, user_message, chatbot_message)
                file.write(json.dumps(formatted_data) + '\n')

# Create training jsonl file
create_jsonl_from_list("coedit_coherence_train", dataset_list_train, preamble)
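
Before uploading, it can help to check the file locally. The sketch below verifies that each line parses as JSON and has the System, User, Chatbot message layout produced by `create_chat_ft_data` above. The `find_invalid_lines` helper is hypothetical, and these checks reflect only the layout used in this notebook, not the full set of rules the platform applies during validation.

```python
import json

# Role order produced by create_chat_ft_data in this notebook
EXPECTED_ROLES = ["System", "User", "Chatbot"]

def find_invalid_lines(lines):
    """Return (line_number, reason) pairs for lines that do not match the expected shape."""
    problems = []
    for number, raw in enumerate(lines, start=1):
        try:
            record = json.loads(raw)
        except json.JSONDecodeError:
            problems.append((number, "not valid JSON"))
            continue
        messages = record.get("messages") if isinstance(record, dict) else None
        if not isinstance(messages, list):
            problems.append((number, "missing 'messages' list"))
            continue
        roles = [m.get("role") if isinstance(m, dict) else None for m in messages]
        if roles != EXPECTED_ROLES:
            problems.append((number, f"unexpected roles: {roles}"))
        elif not all(isinstance(m.get("content"), str) and m["content"].strip()
                     for m in messages):
            problems.append((number, "empty or non-string content"))
    return problems

# Usage (an empty list means no problems were found):
# with open("coedit_coherence_train.jsonl") as f:
#     print(find_invalid_lines(f))
```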

### Create a Dataset object

In [None]:
# create a Dataset object
dataset = co.create_dataset(name="coedit_coherence",
                            data=open("coedit_coherence_train.jsonl", "rb"),
                            dataset_type="chat-finetune-input")

# check the validation status of the dataset
print(dataset.await_validation())

uploading file, starting validation...
coedit-coherence-53gwpx was uploaded
...


cohere.Dataset {
	id: coedit-coherence-53gwpx
	name: coedit_coherence
	dataset_type: chat-finetune-input
	validation_status: validated
	created_at: 2024-01-11 20:58:03.776464
	updated_at: 2024-01-11 20:58:03.776464
	download_urls: ['https://storage.googleapis.com/cohere-user/dataset-api-temp/505106e3-9bfe-4335-b597-be1e2803a3ce/d53941a0-e52d-4c73-9261-577caf7f1f32/coedit-coherence-53gwpx/000_coedit_coherence_train.avro?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=dataset%40cohere-production.iam.gserviceaccount.com%2F20240111%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240111T205814Z&X-Goog-Expires=14399&X-Goog-Signature=4544d5b711464840ef2de306dff81ca924a052c5b98098c4e85efacf3c68af6f8504974ed252e23b76160cef46aae7527d155dccf42072746f0ce62eef4a1149d88bf10be8338361eab570e3fa07f4b4bb6a8e02848631ca117fdf5905c14b1b9efeff34561b14232869cef255f010496c798327d53b7820c21cd91ce9cf3c8cc396431f8033f2210c2e6da3e0be3d5ef29d27aad6b29c66ce70737feffac9ecb7cba5b80dfdf4b93b8c28041d1fda5dd03b72a312

## Step 2: Fine-Tune the Model

### Set hyperparameters and kick off fine-tuning

In [None]:
from cohere.responses.custom_model import HyperParametersInput

# define custom hyperparameters (optional)
hp = HyperParametersInput(
    early_stopping_patience=6,      # default: 6
    early_stopping_threshold=0.01,  # default: 0.01
    train_batch_size=16,            # default: 16
    train_epochs=1,                 # default: 1
    learning_rate=0.01,             # default: 0.01
)

# start fine-tuning using the dataset
co.create_custom_model(
    name="coedit-coherence-test4",
    dataset=dataset,
    model_type="CHAT",
    hyperparameters=hp,
)

cohere.CustomModel {
	id: 11fd5a56-8985-4920-85d6-89b6115802d2
	name: coedit-coherence-test4
	status: QUEUED
	model_type: CHAT
	created_at: 2024-01-11 20:58:23.674889+00:00
	completed_at: None
	base_model: medium
	model_id: 11fd5a56-8985-4920-85d6-89b6115802d2-ft
	hyperparameters: HyperParameters(early_stopping_patience=6, early_stopping_threshold=0.01, train_batch_size=16, train_steps=None, train_epochs=1, learning_rate=0.01)
	dataset_id: coedit-coherence-53gwpx
	billing: FinetuneBilling(train_epochs=1, num_training_tokens=79729, unit_price=1e-06, total_cost=0.079729)
}
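
The billing figures in this output are consistent with multiplying the number of training tokens by the number of epochs and the unit price. That relationship is inferred from this single output rather than taken from a documented pricing rule, so treat the sketch below as a rough estimate only. The `estimate_finetune_cost` helper is hypothetical.

```python
def estimate_finetune_cost(num_training_tokens, train_epochs, unit_price):
    """Estimated cost, assuming cost = tokens x epochs x price per token."""
    return num_training_tokens * train_epochs * unit_price

# Values copied from the FinetuneBilling output above
cost = estimate_finetune_cost(num_training_tokens=79729, train_epochs=1, unit_price=1e-06)
print(f"{cost:.6f}")  # prints: 0.079729
```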

### View model status

In [None]:
# get the CustomModel object
ft = co.get_custom_model_by_name('coedit-coherence-test4')
# print the status
print(ft)

cohere.CustomModel {
	id: 11fd5a56-8985-4920-85d6-89b6115802d2
	name: coedit-coherence-test4
	status: READY
	model_type: CHAT
	created_at: 2024-01-11 20:58:23.674889+00:00
	completed_at: None
	base_model: medium
	model_id: 11fd5a56-8985-4920-85d6-89b6115802d2-ft
	hyperparameters: HyperParameters(early_stopping_patience=6, early_stopping_threshold=0.01, train_batch_size=16, train_steps=50, train_epochs=1, learning_rate=0.01)
	dataset_id: coedit-coherence-53gwpx
	billing: FinetuneBilling(train_epochs=1, num_training_tokens=79729, unit_price=1e-06, total_cost=0.079729)
}
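
The status moves from `QUEUED` to `READY` while training runs, so the cell above may need to be rerun several times. A minimal polling sketch follows. It takes any function that returns the current status string, which keeps it independent of the SDK. The `wait_until_ready` helper, the `FAILED` status name, and the `.status` attribute used in the usage comment are assumptions and should be checked against the SDK version in use.

```python
import time

def wait_until_ready(get_status, poll_seconds=30, timeout_seconds=3600,
                     terminal_failures=("FAILED",)):
    """Call get_status() until it returns "READY"; raise on failure or timeout."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status == "READY":
            return status
        if status in terminal_failures:
            raise RuntimeError(f"Fine-tuning ended with status {status}")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Still {status} after {timeout_seconds} seconds")
        time.sleep(poll_seconds)

# Usage (assumes the CustomModel object exposes a .status attribute):
# wait_until_ready(lambda: co.get_custom_model_by_name('coedit-coherence-test4').status)
```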


## Step 3: Evaluate the Fine-Tuned Model with Test Data



In [None]:
for item in dataset_list_test[:10]:
    # User prompt
    user_message = item["src"]
    # Desired/target response from dataset
    tgt_message = item["tgt"]

    # Get default model response
    response_pretrained = co.chat(
        message=user_message,
        preamble_override=preamble,
    )

    # Get fine-tuned model response
    response_finetuned = co.chat(
        message=user_message,
        model=ft.model_id,
        preamble_override=preamble,
    )

    print(f"User: {user_message}","\n-----")
    print(f"Default response: {response_pretrained.text}","\n-----")
    print(f"Fine-tuned response: {response_finetuned.text}","\n-----")
    print(f"Desired response: {tgt_message}")

    print("-"*50,"\n\n")

User: Make the text more coherent: We do know that at the end of the Muromachi period it stopped appearing in written records. That Muromachi burned down many times, the last we know of in 1405. 
-----
Default response: Although the specific reasons are not entirely clear, the absence of records indicating the presence of the entity after the end of the Muromachi period suggests that it may have ceased to exist or at least played a diminished role in societal affairs. The frequent burning of the Muromachi district, including the last significant fire in 1405, likely contributed to the diminishing significance of the entity in written records. 

It's worth noting that the lack of documentation does not necessarily mean that the entity ceased to exist entirely but could indicate a decrease in its cultural or societal influence. Additionally, there may have been other factors beyond the fires that contributed to the decrease in records, such as political or social changes. 

To obtain a m
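
Reading the responses side by side works for ten examples, but a rough number makes it easier to compare the two models across the whole test set. One simple option, sketched below, is the character-level similarity ratio from Python's `difflib`. This is a crude proxy chosen for illustration and is not the evaluation method used in this notebook. The helper names and the example strings are made up.

```python
from difflib import SequenceMatcher

def similarity(candidate, target):
    """Character-level similarity ratio in [0, 1]; 1.0 means identical strings."""
    return SequenceMatcher(None, candidate, target).ratio()

def compare_to_target(default_text, finetuned_text, target_text):
    """Score both responses against the desired response."""
    return {
        "default": similarity(default_text, target_text),
        "finetuned": similarity(finetuned_text, target_text),
    }

# Made-up strings standing in for model responses and the dataset target
target = "She left because conditions back home were bad."
finetuned = "She left because conditions at home were bad."
default = ("It sounds like she decided to leave. There may have been many "
           "reasons for this, such as difficult conditions.")

scores = compare_to_target(default, finetuned, target)
print(scores["finetuned"] > scores["default"])  # prints: True
```

In the evaluation loop, the same function could be called with `response_pretrained.text`, `response_finetuned.text`, and `tgt_message`, then averaged over `dataset_list_test`.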

## Step 4: Evaluate the Fine-Tuned Model in the Chat Context

In [None]:
# Create a conversation ID
import uuid
conversation_id = str(uuid.uuid4())

print('Starting the chat. Type "quit" to end.\n')

while True:

    # User message
    message = input("User: ")

    # Typing "quit" ends the conversation
    if message.lower() == 'quit':
        print("Ending chat.")
        break

    # Chatbot response
    response = co.chat(message=message,
                       model=ft.model_id,
                       stream=True,
                       conversation_id=conversation_id,
                       return_chat_history=True)

    print("Chatbot: ", end='')

    for event in response:
        if event.event_type == "text-generation":
            print(event.text, end='')

    print(f"\n{'-'*50}\n")

Starting the chat. Type "quit" to end.

User: Hello
Chatbot: Hello, how can I help you today?
--------------------------------------------------

User: I'm fine.  Can I ask you for help with some tasks?
Chatbot: I am happy to help you with any tasks that you need help with.  I can also provide information about any topic that you would like to know more about.
--------------------------------------------------

User: Make this more coherent: Manuel now has to decide-will he let his best friend be happy with her Prince Charming. Or will he fight for the love that has kept him alive for the last 16 years?
Chatbot: Manuel now has to decide-will he let his best friend be happy with her Prince Charming, or will he fight for the love that has kept him alive for the last 16 years?
--------------------------------------------------

User: Help me with this one - She left Benaras. Conditions back home were bad.
Chatbot: She left Benaras because conditions back home were bad.
-------------------
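
The chat loop prints each streamed chunk as it arrives and does not keep the full reply. When the complete text is needed afterwards, for example to log or score it, the chunks can be joined instead. The sketch below uses stand-in event objects so it runs without an API call. The `collect_stream_text` helper and the fake events are hypothetical, and only the `event_type` and `text` attributes used in the loop above are assumed.

```python
from types import SimpleNamespace

def collect_stream_text(events):
    """Join the text of all "text-generation" events into one string."""
    return "".join(e.text for e in events if e.event_type == "text-generation")

# Stand-in events that mimic the attributes read in the chat loop
fake_events = [
    SimpleNamespace(event_type="stream-start"),
    SimpleNamespace(event_type="text-generation", text="She left Benaras "),
    SimpleNamespace(event_type="text-generation", text="because conditions back home were bad."),
    SimpleNamespace(event_type="stream-end"),
]

print(collect_stream_text(fake_events))
# prints: She left Benaras because conditions back home were bad.
```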