This experiment focuses on the retriever part of the RAG system. The retriever returns a single document, both because there is little data and for the sake of simplicity, and the experiment measures how often that document belongs to the correct country.

In [21]:
import os
import json
import uuid
from pathlib import Path
import numpy as np

import mlflow
import requests
import chromadb

from tools import read_user_questions

# Embedding service: the endpoint the notebook POSTs text to and the model it requests.
EMBEDDING_MODEL = 'mxbai-embed-large'
EMBEDDING_URL = 'http://localhost:11434/api/embeddings'
# Distance function configured on the ChromaDB collection (via 'hnsw:space').
SIMILARITY_METRIC = 'cosine'

# MLflow tracking server and the experiment under which runs are recorded.
EXPERIMENT_NAME = 'retriever-country'
MLFLOW_SERVER_URL = 'http://localhost:5000'

# All inputs and the vector store live in <parent of the working directory>/data.
_DATA_DIR = Path().resolve().parent / 'data'
CHROMA_DB_PATH = _DATA_DIR / 'chroma-db'
USER_QUESTIONS_DIR = _DATA_DIR / 'user_questions'
NO_SALE_DOCUMENTS = _DATA_DIR / 'no_sale_countries.json'

# Fresh identifier each time this cell runs; reused for the collection name and the MLflow run name.
RUN_ID = f'{uuid.uuid4()}'

# Overrides LOGNAME for this process.
# NOTE(review): presumably so MLflow attributes the run to this user - confirm.
os.environ.update(LOGNAME='Michal Racko')

Load the user-question dataset, keeping track of which country each question belongs to, and load the documents that will be indexed.

In [22]:
# Map each country name (title-cased, e.g. 'Germany') to the questions read from
# its '<country>.txt' file, so every question keeps its ground-truth country.
user_questions = {
    country.title(): read_user_questions(USER_QUESTIONS_DIR / f'{country}.txt')
    for country in ('germany', 'italy', 'spain', 'sweden')
}
# Pin the encoding: without it open() falls back to the platform default, which can
# mis-decode non-ASCII text in the JSON file (JSON interchange is UTF-8).
with open(NO_SALE_DOCUMENTS, encoding='utf-8') as f:
    document_data = json.load(f)

{'Germany': ['What are the import regulations for home accessories in Germany?',
  'How much is the import tariff for home furniture into Germany?',
  'Are there any specific restrictions on importing electrical home accessories to Germany?',
  'What are the customs duties for importing textiles and fabrics for home decor into Germany?',
  'What documentation is required for importing home accessories to Germany?',
  'What is the VAT rate for home accessories in Germany?',
  'How does the VAT refund process work for home accessories in Germany?',
  'Are there any reduced VAT rates for specific types of home goods in Germany?',
  'What is the corporate tax rate for home goods companies in Germany?',
  'Are there any tax incentives for importing eco-friendly home accessories into Germany?',
  'What are the product safety standards for electrical home accessories in Germany?',
  'What certification is required for selling home appliances in Germany?',
  'Are there any specific labeling re

Let's embed each document and store it in ChromaDB, using the document's geography as its ID.

In [18]:
chroma_client = chromadb.PersistentClient(path=str(CHROMA_DB_PATH))
# get_or_create + upsert keep this cell idempotent: re-running it in the same kernel
# (same RUN_ID) reuses the collection instead of failing because it already exists.
collection = chroma_client.get_or_create_collection(
    name=f'company-documents-{RUN_ID}',
    metadata={'hnsw:space': SIMILARITY_METRIC}
)
for data in document_data['documents']:
    # Request the embedding of the full document text from the embedding service.
    response = requests.post(
        EMBEDDING_URL,
        json={
            'model': EMBEDDING_MODEL,
            'prompt': data['text']
        }
    )
    if response.status_code != 200:
        # Fail loudly, with the status code, rather than indexing a partial collection.
        raise RuntimeError(
            f'Embedding request failed with HTTP {response.status_code}: {response.text}'
        )
    # This notebook has always read the misspelled key 'meatadata'; also accept the
    # corrected spelling so the cell keeps working once the data file is fixed.
    metadata = data['meatadata'] if 'meatadata' in data else data['metadata']
    # The geography doubles as the document ID, which is what the evaluation
    # compares against the question's country. Everything is passed as an
    # explicit one-element batch so ids/embeddings/documents line up.
    collection.upsert(
        ids=[metadata['geography']],
        embeddings=[response.json()['embedding']],
        documents=[data['text']]
    )

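Now we can evaluate the retriever: for each question we retrieve the single nearest document and check whether it belongs to the question's country. The accuracy is the fraction of questions answered with the correct country.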

In [19]:
# One boolean per question: did the top-1 retrieved document match the question's country?
is_correct = []
for country, questions in user_questions.items():
    for question in questions:
        response = requests.post(
            EMBEDDING_URL,
            json={
                'model': EMBEDDING_MODEL,
                'prompt': question
            }
        )
        # Same failure handling as when indexing documents: without this check a
        # failed request would surface later as an opaque KeyError on 'embedding'.
        if response.status_code != 200:
            raise RuntimeError(
                f'Embedding request failed with HTTP {response.status_code}: {response.text}'
            )
        results = collection.query(
            query_embeddings=[response.json()['embedding']],
            n_results=1
        )
        # results['ids'] holds one list of IDs per query embedding; the document ID
        # is its geography, so compare it case-insensitively with the country.
        is_correct.append(results['ids'][0][0].lower() == country.lower())
if not is_correct:
    # An empty evaluation set would otherwise yield a silent NaN accuracy.
    raise ValueError('No user questions were evaluated; accuracy is undefined.')
correct_class = np.array(is_correct)
accuracy = correct_class.mean()
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 86.75%


Finally, push the run parameters and the accuracy metric to the MLflow server.

In [20]:
# Point the MLflow client at the tracking server and select the experiment.
mlflow.set_tracking_uri(MLFLOW_SERVER_URL)
mlflow.set_experiment(EXPERIMENT_NAME)

# Parameters that identify this retriever configuration.
run_params = {
    'model': EMBEDDING_MODEL,
    'similarity_metric': SIMILARITY_METRIC,
    'run_id': RUN_ID,
}

# The run is named after RUN_ID, the same identifier used in the ChromaDB collection name.
with mlflow.start_run(run_name=RUN_ID):
    mlflow.log_params(run_params)
    mlflow.log_metric('accuracy', accuracy)