### Chromadb setup

Stockage des embeddings : 
- Enregistrement des vecteurs et de leurs métadonnées dans la base vectorielle ChromaDB :
    - Une collection pour les données d’entraînement
    - Une collection pour les données de test.

In [1]:
import chromadb
import os
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

In [2]:
# load embeddings
embeddings = np.load("../data/embedding/embeddings.npy")

embeddings.shape

(14292, 384)

In [3]:
# Load original df
df = pd.read_csv("../data/processed/data.csv")
df = df.reset_index(drop=True)


prepare data

In [4]:
# Get train/test indices
train_idx, test_idx = train_test_split(
    range(len(df)),
    test_size=0.2, 
    random_state=42, 
    stratify=df['airline_sentiment']
)


# Add train data

train_df = df.iloc[train_idx].reset_index(drop=True)
train_embeddings = embeddings[train_idx]

# Add test data

test_df = df.iloc[test_idx].reset_index(drop=True)
test_embeddings = embeddings[test_idx]

print(f"Train: {len(train_df)}, Test: {len(test_df)}")


Train: 11433, Test: 2859


inserting

In [None]:

CHROMA_PATH = "../chromadb"
BATCH_SIZE = 5000
os.makedirs(CHROMA_PATH, exist_ok=True)
client = chromadb.PersistentClient(CHROMA_PATH)

def upsert_to_collection(collection_name, df, embeddings, id_prefix, text_col='clean_text'):
    """Generic function to batch insert data into a Chroma collection."""
    collection = client.get_or_create_collection(name=collection_name)
    
    if collection.count() > 0:
        return collection

    for batch_start in range(0, len(df), BATCH_SIZE):
        batch_end = min(batch_start + BATCH_SIZE, len(df))
        
        batch_df = df.iloc[batch_start:batch_end]
        batch_embs = embeddings[batch_start:batch_end]

        try:
            collection.add(
                embeddings=batch_embs.tolist(),
                documents=batch_df[text_col].tolist(),
                metadatas=[
                    {
                        "label": str(row['airline_sentiment']),
                        "airline": str(row['airline'])
                    } for _, row in batch_df.iterrows()
                ],
                ids=[f"{id_prefix}_{i}" for i in range(batch_start, batch_end)]
            )
        except Exception as e:
            print(f"error inserting into {collection_name}: {e}")
            raise
            
    return collection

train_collection = upsert_to_collection(
    "airline_sentiment_train", train_df, train_embeddings, "train", text_col='clean_text'
)

test_collection = upsert_to_collection(
    "airline_sentiment_test", test_df, test_embeddings, "test", text_col='text'
)

print(f"Train collection: {train_collection.count()} documents")
print(f"Test collection:  {test_collection.count()} documents")

Data already exists in airline_sentiment_train. Skipping...
Data already exists in airline_sentiment_test. Skipping...
Train collection: 11418 documents
Test collection:  2855 documents


In [None]:
print(client.list_collections())

collection = client.get_collection("airline_sentiment_train")

print(collection.count()) 


[Collection(name=airline_sentiment_train), Collection(name=airline_sentiment_test)]
11418


In [7]:
collection.get(include=["documents", "metadatas"], limit=5)


{'ids': ['train_0', 'train_1', 'train_2', 'train_3', 'train_4'],
 'embeddings': None,
 'documents': ['you are the Official airlines of DivadaPouch aka ThePoopQueen',
  'just Cancelled Flighted my flight and told me to call to rebook. Been on hold for 48 minutes at 4 am and still waiting',
  "Hi, Virgin! I'm on hold for 40-50 minutes -- are there any earlier flights from LA to NYC tonight; earlier than 11:50pm?",
  'any ways to get through the 50 minute wait to book a flight?',
  "but you guys switched me and didn't inform me of the chathes"],
 'uris': None,
 'included': ['documents', 'metadatas'],
 'data': None,
 'metadatas': [{'label': 'positive', 'airline': 'Southwest'},
  {'airline': 'US Airways', 'label': 'negative'},
  {'label': 'negative', 'airline': 'Virgin America'},
  {'airline': 'American', 'label': 'negative'},
  {'airline': 'US Airways', 'label': 'negative'}]}

In [None]:
train_collection = client.get_collection("airline_sentiment_train")

# Query similar tweets
results = train_collection.query(
    query_texts=["Flight was delayed for hours"],
    n_results=5,
    include=["documents", "metadatas", "distances"]
)

print("Similar tweets:")
for doc, metadata, distance in zip(
    results['documents'][0], 
    results['metadatas'][0], 
    results['distances'][0]
):
    print(f"\nText: {doc}")
    print(f"Label: {metadata['label']}")
    print(f"Distance: {distance:.4f}")

Similar tweets:

Text: Gerne :)
Label: neutral
Distance: 8.2260

Text: thanks for the response. I know it's not your fault... But Im in ORD in T5 and hungry if you want to stop by
Label: negative
Distance: 9.1247

Text: Please try it yourself - call 1-800-433-7300 and see what happens... then you'll understand. allrepresentativesbusy nooption
Label: negative
Distance: 9.2172

Text: ok!!! That's super helpful. Thank you. I'll reach out if I have any other questions.
Label: positive
Distance: 9.2285

Text: are u paying incedentals? noworstairline
Label: neutral
Distance: 9.3978


: 