<a target="_blank" href="https://colab.research.google.com/github/cohere-ai/notebooks/blob/main/notebooks/llmu/Introduction_Text_Embeddings.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Introduction to Text Embeddings

Text embeddings turn text into numbers that capture its meaning and context. In this notebook, you'll learn how to put them into practice using Cohere's [Embed endpoint](https://docs.cohere.com/reference/embed). You'll calculate embeddings for a dataset of sentences and plot them in two dimensions to see that similar sentences are mapped to nearby points. You'll also explore how to use embeddings for semantic search and clustering.

## Setup

We'll start by installing the tools we'll need and then importing them.

In [1]:
! pip install cohere altair -q

In [2]:
import cohere
import pandas as pd
import numpy as np
import altair as alt
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans

In [3]:
co = cohere.ClientV2("COHERE_API_KEY") # Get your free API key: https://dashboard.cohere.com/api-keys

## Introduction to Text Embeddings

In this section, we build intuition for what text embeddings represent.

### Step 1: Prepare the Dataset

We'll work with a subset of the Airline Travel Information System (ATIS) dataset ([source](https://aclanthology.org/H90-1021/)), created based on customer inquiries related to flight bookings, flight departures, arrivals, delays, and cancellations. In the next code cell, we create and preview a dataframe `df` containing 91 queries.

In [3]:
# Load the dataset to a dataframe
df_orig = pd.read_csv('https://raw.githubusercontent.com/cohere-ai/notebooks/main/notebooks/data/atis_intents_train.csv', names=['intent','query'])

# Take a small sample for illustration purposes
sample_classes = ['atis_airfare', 'atis_airline', 'atis_ground_service']
df = df_orig.sample(frac=0.1, random_state=30)
df = df[df.intent.isin(sample_classes)]
df_orig = df_orig.drop(df.index)
df.reset_index(drop=True,inplace=True)

# Remove unnecessary column
intents = df['intent']  # save the labels for later use
df.drop(columns=['intent'], inplace=True)
df.head()

| | query |
|---|---|
| 0 | which airlines fly from boston to washington ... |
| 1 | show me the airlines that fly between toronto... |
| 2 | show me round trip first class tickets from n... |
| 3 | i'd like the lowest fare from denver to pitts... |
| 4 | show me a list of ground transportation at bo... |


In [4]:
for i in df.head(10)["query"]:
    print(i)

 which airlines fly from boston to washington dc via other cities
 show me the airlines that fly between toronto and denver
 show me round trip first class tickets from new york to miami
 i'd like the lowest fare from denver to pittsburgh
 show me a list of ground transportation at boston airport
 show me boston ground transportation
 of all airlines which airline has the most arrivals in atlanta
 what ground transportation is available in boston
 i would like your rates between atlanta and boston on september third
 which airlines fly between boston and pittsburgh


### Step 2: Turn Text into Embeddings

Next, we define a function `get_embeddings()` that calls the Embed endpoint to generate the text embeddings. We store the embeddings in the `query_embeds` column of the dataframe `df`.

In [5]:
# Get text embeddings
def get_embeddings(texts, model="embed-english-v3.0", input_type="search_document"):
    output = co.embed(
        texts=texts, 
        model=model, 
        input_type=input_type, 
        embedding_types=["float"]
    )
    return output.embeddings.float

In [6]:
# Embed the dataset
df['query_embeds'] = get_embeddings(df['query'].tolist())
df.head()

| | query | query_embeds |
|---|---|---|
| 0 | which airlines fly from boston to washington ... | [0.02609253, 0.012168884, -0.008903503, 0.0114... |
| 1 | show me the airlines that fly between toronto... | [0.013801575, 0.017181396, -0.014984131, -0.00... |
| 2 | show me round trip first class tickets from n... | [0.02053833, -0.038482666, 0.061523438, 0.0099... |
| 3 | i'd like the lowest fare from denver to pitts... | [0.0016708374, 0.015625, -0.029022217, 0.03759... |
| 4 | show me a list of ground transportation at bo... | [0.037628174, -0.007888794, -0.0024662018, -0.... |
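
Here all 91 queries are sent in one request. For larger datasets, a common pattern is to split the texts into batches and embed each batch in turn. The sketch below assumes the endpoint caps the number of texts per request; the default of 96 is an assumption to confirm against the Embed endpoint documentation. The names `embed_in_batches` and `fake_embed` are hypothetical and not part of the Cohere SDK.

```python
from typing import Callable, List, Sequence

def embed_in_batches(
    texts: Sequence[str],
    embed_fn: Callable[[List[str]], List[List[float]]],
    batch_size: int = 96,  # assumed cap; confirm against the endpoint docs
) -> List[List[float]]:
    """Embed texts batch by batch, preserving the input order."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    embeddings: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = list(texts[start:start + batch_size])
        embeddings.extend(embed_fn(batch))
    return embeddings

# Stand-in for get_embeddings() so the sketch runs without an API key:
# it returns a 2-number "embedding" per text (character count, word count).
def fake_embed(batch: List[str]) -> List[List[float]]:
    return [[float(len(t)), float(len(t.split()))] for t in batch]

texts = [f"query number {i}" for i in range(200)]
vectors = embed_in_batches(texts, fake_embed)
print(len(vectors), vectors[0])
```

With the function defined in this notebook, the equivalent call would be `embed_in_batches(df['query'].tolist(), get_embeddings)`.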


### Step 3: Visualize Embeddings on a Heatmap

Let's get some visual intuition about the embeddings by plotting these numbers on a heatmap. Each embedding from `embed-english-v3.0` has 1,024 dimensions, too many to plot directly, so we first compress them to a much lower number, say 10. We can do this via a technique called [Principal Component Analysis (PCA)](https://en.wikipedia.org/wiki/Principal_component_analysis), which reduces the number of dimensions in an embedding while retaining as much information as possible.

In [7]:
# Function to return the principal components
def get_pc(arr, n):
    pca = PCA(n_components=n)
    embeds_transform = pca.fit_transform(arr)
    return embeds_transform

In [8]:
# Reduce embeddings to 10 principal components to aid visualization
embeds = np.array(df['query_embeds'].tolist())
embeds_pc = get_pc(embeds, 10)
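
To see what "retaining as much information as possible" means, the sketch below runs PCA by hand with NumPy's SVD on synthetic data (not the ATIS embeddings) and reports the share of variance each component keeps. The array sizes and the helper name `pca_svd` are illustrative assumptions; scikit-learn's `PCA` exposes the same quantity as `explained_variance_ratio_`.

```python
import numpy as np

def pca_svd(arr: np.ndarray, n: int):
    """Project arr onto its first n principal components via SVD."""
    centered = arr - arr.mean(axis=0)           # PCA operates on mean-centered data
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:n].T             # coordinates along the top-n directions
    variance_ratio = (s ** 2) / np.sum(s ** 2)  # share of total variance per component
    return projected, variance_ratio[:n]

# Synthetic stand-in for embeddings: 50 points in 20 dimensions whose
# variation is concentrated in 3 underlying directions plus a little noise.
rng = np.random.default_rng(0)
latent = rng.normal(size=(50, 3))
mixing = rng.normal(size=(3, 20))
toy_embeds = latent @ mixing + 0.01 * rng.normal(size=(50, 20))

toy_pc, ratio = pca_svd(toy_embeds, 3)
print(toy_pc.shape)   # (50, 3)
print(ratio.sum())    # close to 1: three components keep almost all the variance
```

Real embeddings rarely compress this cleanly, so checking the retained variance is a quick way to judge how much a low-dimensional plot leaves out.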

Next, we visualize the embeddings of the first 9 data points.

In [9]:
# Set sample size to visualize
sample = 9

# Reshape the data for visualization purposes
source = pd.DataFrame(embeds_pc)[:sample]
source = pd.concat([source, df['query'][:sample]], axis=1)  # keep only the sampled queries
source = source.melt(id_vars=['query'])

# Configure the plot
chart = alt.Chart(source).mark_rect().encode(
    x=alt.X('variable:N', title="Embedding"),
    y=alt.Y('query:N', title='',axis=alt.Axis(labelLimit=500)),
    color=alt.Color('value:Q', title="Value", scale=alt.Scale(
                range=["#917EF3", "#000000"]))
)

result = chart.configure(background='#ffffff'
        ).properties(
        width=700,
        height=400,
        title='Embeddings with 10 dimensions'
       ).configure_axis(
      labelFontSize=15,
      titleFontSize=12)

# Show the plot
result

Notice the 3 inquiries about ground transportation in Boston: their embedding patterns are very similar to one another and distinct from the rest.

### Step 4: Visualize Embeddings on a 2D Plot

We can investigate this further by compressing the embeddings to two dimensions and plotting them on a scatter plot. We expect texts with similar meanings to sit close together and texts with different meanings to sit farther apart.

In [10]:
# Function to generate the 2D plot
def generate_chart(df,xcol,ycol,lbl='on',color='basic',title=''):
    chart = alt.Chart(df).mark_circle(size=500).encode(
        x=
        alt.X(xcol,
              scale=alt.Scale(zero=False),
              axis=alt.Axis(labels=False, ticks=False, domain=False)
             ),
        y=
        alt.Y(ycol,
              scale=alt.Scale(zero=False),
              axis=alt.Axis(labels=False, ticks=False, domain=False)
             ),
        color= alt.value('#333293') if color == 'basic' else color,
        tooltip=['query']
    )
    
    if lbl == 'on':
        text = chart.mark_text(align='left', baseline='middle',dx=15, size=13,color='black').encode(text='query', color= alt.value('black'))
    else:
        text = chart.mark_text(align='left', baseline='middle',dx=10).encode()
        
    result = (chart + text).configure(background="#FDF7F0").properties(
        width=800,
        height=500,
        title=title
    ).configure_legend(orient='bottom', titleFontSize=18,labelFontSize=18)
    
    return result

In [11]:
# Reduce embeddings to 2 principal components to aid visualization
embeds_pc2 = get_pc(embeds, 2)

# Add the principal components to dataframe
df_pc2 = pd.concat([df, pd.DataFrame(embeds_pc2)], axis=1)

# Plot the 2D embeddings on a chart
df_pc2.columns = df_pc2.columns.astype(str)
generate_chart(df_pc2.iloc[:sample],'0','1',title='2D Embeddings')

Here texts of similar meaning are located close together. We see inquiries about tickets on the left, inquiries about airlines somewhere around the middle, and inquiries about ground transportation on the top right.

## Introduction to Semantic Search

In this section, you'll learn how to use embeddings to build a search capability that surfaces relevant information based on the semantic meaning of a query. We'll work with the same 9 data points plotted above.

_Read the accompanying [blog post here](https://txt.cohere.ai/introduction-to-semantic-search/)._

### Step 1: Embed the Search Query

We'll use the query below that inquires about ground transportation without using the words "ground transportation" explicitly. Ideally, the corresponding data points that are surfaced as most relevant are the three dealing with ground transportation options.

In [12]:
# Define new query
new_query = "How can I find a taxi or a bus when the plane lands?"

We embed the query using the same `get_embeddings()` function as before, but now we set `input_type="search_query"` because we're embedding a search query that we want to compare to the embedded documents.

In [13]:
# Get embeddings of the new query
new_query_embeds = get_embeddings([new_query], input_type="search_query")[0]

### Step 2: Compare to Embedded Documents

We define and use a function `get_similarity()` that employs cosine similarity to determine how similar the documents are to our query. 

In [14]:
# Calculate cosine similarity between the search query and existing queries
def get_similarity(target, candidates):
    # Turn list into array
    candidates = np.array(candidates)
    target = np.expand_dims(np.array(target),axis=0)

    # Calculate cosine similarity
    sim = cosine_similarity(target, candidates)
    sim = np.squeeze(sim).tolist()
    sort_index = np.argsort(sim)[::-1]
    sort_score = [sim[i] for i in sort_index]
    similarity_scores = list(zip(sort_index, sort_score))  # a list, so it can be iterated more than once

    # Return similarity scores
    return similarity_scores

# Get the similarity between the search query and existing queries
similarity = get_similarity(new_query_embeds, embeds[:sample])
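
Cosine similarity is the dot product of two vectors divided by the product of their lengths, so it measures the angle between them and ignores their magnitude. Here is a minimal pure-Python sketch on made-up 3-dimensional vectors (real embeddings have many more dimensions; the vectors and the function name `cosine` are illustrative):

```python
import math

def cosine(a, b):
    """Cosine similarity: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

query = [1.0, 2.0, 0.0]
same_direction = [2.0, 4.0, 0.0]   # query scaled by 2
orthogonal = [0.0, 0.0, 3.0]       # shares no component with query
opposite = [-1.0, -2.0, 0.0]       # query negated

sim_same = cosine(query, same_direction)
sim_orth = cosine(query, orthogonal)
sim_opp = cosine(query, opposite)
print(round(sim_same, 6), round(sim_orth, 6), round(sim_opp, 6))  # 1.0 0.0 -1.0
```

Scaling a vector leaves the score unchanged, which is why cosine similarity compares the direction of two embeddings and not their size.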

We'll then view the documents in decreasing order of similarity.

In [15]:
# View the documents ranked by similarity
print('Query:')
print(new_query,'\n')

print('Most Similar Documents:')
for idx, sim in similarity:
    print(f'Similarity: {sim:.2f};', df.iloc[idx]['query'])

Query:
How can I find a taxi or a bus when the plane lands? 

Most Similar Documents:
Similarity: 0.37;  show me a list of ground transportation at boston airport
Similarity: 0.36;  what ground transportation is available in boston
Similarity: 0.33;  show me boston ground transportation
Similarity: 0.27;  show me the airlines that fly between toronto and denver
Similarity: 0.25;  which airlines fly from boston to washington dc via other cities
Similarity: 0.24;  of all airlines which airline has the most arrivals in atlanta
Similarity: 0.18;  i'd like the lowest fare from denver to pittsburgh
Similarity: 0.17;  show me round trip first class tickets from new york to miami
Similarity: 0.17;  i would like your rates between atlanta and boston on september third


The top three most similar documents are the inquiries about ground transportation. Notice that the query didn't explicitly mention "ground transportation," yet the embeddings captured the underlying meaning: the words "taxi" and "bus" place the query close to the documents about ground transportation.
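
The cell above prints every document. In a search application you would typically return only the best few matches, optionally dropping those below a similarity cutoff. A small sketch using the standard library, with the scores from the output above as input (the function name `top_k` and the cutoff value are illustrative choices, not part of the original notebook):

```python
import heapq

def top_k(scored, k=3, min_score=None):
    """Return the k highest-scoring (index, score) pairs, best first."""
    if min_score is not None:
        scored = [(i, s) for i, s in scored if s >= min_score]
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])

# (document index, similarity) pairs, taken from the output above
scores = [(0, 0.25), (1, 0.27), (2, 0.17), (3, 0.18), (4, 0.37),
          (5, 0.33), (6, 0.24), (7, 0.36), (8, 0.17)]

best = top_k(scores, k=3)
print(best)  # [(4, 0.37), (7, 0.36), (5, 0.33)]

# With a cutoff, fewer than k results may come back
confident = top_k(scores, k=3, min_score=0.35)
print(confident)  # [(4, 0.37), (7, 0.36)]
```

Because `get_similarity()` yields (index, score) pairs, its result could be passed to `top_k` directly.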

### Step 3: Visualize the Results in a 2D Plot

To prepare for plotting, we append the query to the documents and use the `get_pc()` function we defined earlier to reduce all the embeddings to 2 dimensions together.

In [16]:
# Create new dataframe and append new query
df_sem = df.copy()
df_sem.loc[len(df_sem.index)] = [new_query, new_query_embeds]

# Reduce embeddings dimension to 2
embeds_sem = np.array(df_sem['query_embeds'].tolist())
embeds_sem_pc2 = get_pc(embeds_sem, 2)

# Add the principal components to dataframe
df_sem_pc2 = pd.concat([df_sem, pd.DataFrame(embeds_sem_pc2)], axis=1)

We then can plot the query on a 2D scatter plot with the nine documents.

In [17]:
# Create column for representing chart legend
df_sem_pc2['Source'] = 'Existing'
df_sem_pc2.at[len(df_sem_pc2)-1, 'Source'] = "New"

# Plot on a chart
df_sem_pc2.columns = df_sem_pc2.columns.astype(str)
selection = list(range(sample)) + [-1]
generate_chart(df_sem_pc2.iloc[selection],'0','1',color='Source',title='Semantic Search')

We see that the query is located closest to the inquiries about ground transportation.

## Clustering Using Embeddings

In this section, you'll learn how to use embeddings to group similar documents into clusters and discover emerging patterns in them. We'll cluster all the queries in `df` and view the results on the same 9 data points as before.

_Read the accompanying [blog post here](https://docs.cohere.com/docs/clustering-with-embeddings)._

### Step 1: Embed the Text for Clustering

We embed the documents using the same `get_embeddings()` function as before, but now we set `input_type="clustering"` because we'll use the embeddings for clustering.

In [18]:
# Embed the text for clustering
df['clustering_embeds'] = get_embeddings(df['query'].tolist(), input_type="clustering")
embeds = np.array(df['clustering_embeds'].tolist())

### Step 2: Cluster the Embeddings

We use the K-means algorithm to cluster these data points. Since our dataset is small, we'll set the number of clusters to 2. In real applications, this number is typically larger.

We store the cluster assignments in the `"cluster"` column of the `df_clust` DataFrame.

In [19]:
# Pick the number of clusters
n_clusters = 2

# Cluster the embeddings
kmeans_model = KMeans(n_clusters=n_clusters, n_init='auto', random_state=0)
classes = kmeans_model.fit_predict(embeds).tolist()

# Store the cluster assignments
df_clust = df_pc2.copy()
df_clust['cluster'] = list(map(str, classes))

# Preview the cluster assignments
df_clust.head()

| | query | query_embeds | 0 | 1 | cluster |
|---|---|---|---|---|---|
| 0 | which airlines fly from boston to washington ... | [0.02609253, 0.012168884, -0.008903503, 0.0114... | 0.073794 | -0.252909 | 0 |
| 1 | show me the airlines that fly between toronto... | [0.013801575, 0.017181396, -0.014984131, -0.00... | 0.00505 | -0.316227 | 0 |
| 2 | show me round trip first class tickets from n... | [0.02053833, -0.038482666, 0.061523438, 0.0099... | 0.195839 | 0.081132 | 0 |
| 3 | i'd like the lowest fare from denver to pitts... | [0.0016708374, 0.015625, -0.029022217, 0.03759... | 0.111741 | 0.206098 | 0 |
| 4 | show me a list of ground transportation at bo... | [0.037628174, -0.007888794, -0.0024662018, -0.... | -0.313675 | -0.229535 | 1 |
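
K-means alternates between two steps: assign each point to its nearest centroid, then move each centroid to the mean of its assigned points. The sketch below shows those two steps in NumPy on synthetic 2D points (not the ATIS embeddings). It is a simplified illustration with a naive initialization and a hypothetical function name, `kmeans_sketch`; scikit-learn's `KMeans` adds better initialization and convergence checks.

```python
import numpy as np

def kmeans_sketch(points: np.ndarray, k: int, n_iter: int = 10, seed: int = 0):
    """Plain Lloyd's algorithm: assign to nearest centroid, then re-average."""
    rng = np.random.default_rng(seed)
    # Naive initialization: pick k distinct points as starting centroids
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(n_iter):
        # Distance from every point to every centroid, shape (n_points, k)
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its points (keep it if it has none)
        for j in range(k):
            if np.any(labels == j):
                centroids[j] = points[labels == j].mean(axis=0)
    return labels, centroids

# Two well-separated synthetic groups of 20 points each
rng = np.random.default_rng(1)
blob_a = rng.normal(loc=(0.0, 0.0), scale=0.3, size=(20, 2))
blob_b = rng.normal(loc=(5.0, 5.0), scale=0.3, size=(20, 2))
toy_points = np.vstack([blob_a, blob_b])

toy_labels, toy_centroids = kmeans_sketch(toy_points, k=2)
print(toy_labels)
```

The cluster numbers themselves are arbitrary: which group ends up as cluster 0 or 1 depends on the initialization, so only the grouping is meaningful.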


### Step 3: Visualize the Results in a 2D Plot

With 2 clusters, the algorithm produces a sensible split: one cluster related to airline information and one related to ground service information.

In [20]:
# Plot on a chart
df_clust.columns = df_clust.columns.astype(str)
generate_chart(df_clust.iloc[:sample],'0','1',lbl='on',color='cluster',title='Clustering with 2 Clusters')