<a href="https://colab.research.google.com/github/ben-ogden/musiccaps/blob/main/init-pinecone-index.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install Dependencies

In [None]:
# Install the libraries used in this notebook into the running kernel's
# environment (%pip, unlike !pip, targets the kernel's own interpreter).
# NOTE(review): every package is unpinned and --upgrade is used, so a fresh
# run may pick up breaking releases; pin versions (pkg==X.Y.Z) for
# reproducible results.
%pip install --upgrade jupyter ipywidgets sentence_transformers 'pinecone-client[grpc]' datasets torch

In [None]:
import torch

# Use the current CUDA device index when a GPU is present; otherwise leave the
# device unset (None) and let the model loader decide.
device = None
if torch.cuda.is_available():
    device = torch.cuda.current_device()

## Load and Preview the Dataset

In [None]:
from datasets import load_dataset

# Load the MusicCaps caption CSV from the Hugging Face Hub and convert it to a
# pandas DataFrame.
df = load_dataset(
    'google/MusicCaps', data_files='musiccaps-public.csv', split='train'
).to_pandas()

# The embedding step later in the notebook reads the 'caption' column; fail
# fast here if the dataset schema is not what this notebook expects.
assert 'caption' in df.columns, "expected a 'caption' column in MusicCaps"

# Preview a few rows rather than rendering the whole frame.
df.head()

## Initialize the Sentence Transformer

In [None]:
from sentence_transformers import SentenceTransformer

# Sentence-embedding model (fetched from the Hugging Face Hub) used to embed
# both the captions and the search queries, placed on the device chosen in
# the setup cell.
retriever = SentenceTransformer(
    model_name_or_path='flax-sentence-embeddings/all_datasets_v3_mpnet-base',
    device=device,
)
# show the model summary as the cell output
retriever

## Connect to Pinecone

In [None]:
from pinecone.grpc import PineconeGRPC as Pinecone
from pinecone import ServerlessSpec

import os
from getpass import getpass

# Read the Pinecone API key from the PINECONE_API_KEY environment variable
# instead of hardcoding it in the notebook, where it would be saved and shared
# with the file. If the variable is unset, prompt for it without echoing.
pc = Pinecone(
    api_key=os.environ.get('PINECONE_API_KEY') or getpass('Pinecone API key: ')
)

## Create Pinecone Index

In [None]:
index_name = 'music-caps-index'

# The index dimension must equal the size of the vectors the retriever
# produces. Derive it from the model instead of hardcoding 768, so swapping in
# a different model cannot create an index with a mismatched dimension.
embedding_dim = retriever.get_sentence_embedding_dimension()

# create the serverless index only if it does not already exist
if index_name not in pc.list_indexes().names():
    pc.create_index(
        index_name,
        dimension=embedding_dim,
        metric='cosine',
        spec=ServerlessSpec(cloud='aws', region='us-east-1'),
    )

# connect to the index (newly created or pre-existing)
index = pc.Index(index_name)

## Generate Embeddings and Populate Index

In [None]:
from tqdm.auto import tqdm

import time

# Embed the captions in batches of 128 and upsert each embedding together
# with its row's columns as metadata.
batch_size = 128

for i in tqdm(range(0, len(df), batch_size)):
    # end of this batch (the final batch may be shorter)
    i_end = min(i + batch_size, len(df))
    # extract batch
    batch = df.iloc[i:i_end]
    # generate caption embeddings for the batch
    emb = retriever.encode(batch['caption'].tolist()).tolist()
    # every column of each row is stored as that vector's metadata
    meta = batch.to_dict(orient='records')
    # IDs are the rows' positions in df, as strings
    ids = [str(idx) for idx in range(i, i_end)]
    # (id, values, metadata) tuples for upsert
    to_upsert = list(zip(ids, emb, meta))
    # the response is not needed; avoid binding `_`, which IPython uses for
    # the last cell output
    index.upsert(vectors=to_upsert)

# Pinecone is eventually consistent: freshly upserted vectors can take a
# moment to show up in the index stats. Poll for up to 60 s so the check below
# reflects the full upload instead of a partial count.
stats = index.describe_index_stats()
deadline = time.monotonic() + 60
while stats.total_vector_count < len(df) and time.monotonic() < deadline:
    time.sleep(2)
    stats = index.describe_index_stats()

# check that we have all vectors in the index
stats

## Query the Index

In [None]:
from pprint import pprint

def search_pinecone(query, top_k=3):
    """Query the Pinecone index for the records most similar to ``query``.

    Args:
        query: Free-text description of the music to search for.
        top_k: Number of nearest matches to return (default 3).

    Returns:
        The Pinecone query response; each match includes its stored metadata.
    """
    # embed the query with the same model that embedded the indexed captions
    xq = retriever.encode(query).tolist()
    # pass the vector by keyword: recent pinecone-client REST indexes reject
    # positional arguments to query(), and the keyword form also works for
    # the gRPC index
    return index.query(vector=xq, top_k=top_k, include_metadata=True)

In [None]:
# Describe the desired music in plain language and show the closest matches.
query = 'lively eastern european folk music with strings outdoors'
search_pinecone(query=query)

## Clean Up

In [None]:
# Delete the index so it stops consuming resources. Guarded so that
# re-running this cell after the index is already gone is a no-op instead of
# an error.
if index_name in pc.list_indexes().names():
    pc.delete_index(index_name)