In [1]:
import asyncio
import os
from typing import Any, List

import requests
from dotenv import load_dotenv
from llama_index.core.embeddings import BaseEmbedding

# Load environment variables (e.g. GROQ_API_KEY) from a local .env file so that
# the os.getenv() lookups further down can read them.
load_dotenv()

True

In [2]:
# Groq API configuration.
# The key is read from the environment; it is None when GROQ_API_KEY is unset.
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
# The endpoint can be overridden with a GROQ_API_URL environment variable so the
# notebook does not have to be edited to point at a different URL.
# NOTE(review): a recorded run of this notebook received HTTP 404
# ("Unknown request URL: POST /v1/embeddings") from the default below --
# confirm the correct embeddings endpoint against the provider's documentation.
GROQ_API_URL = os.getenv('GROQ_API_URL', "https://api.groq.com/v1/embeddings")

In [4]:
class GroqEmbeddingModel(BaseEmbedding):
    """Custom LlamaIndex embedding model that fetches vectors over HTTP.

    Each text is prefixed with ``instruction`` and POSTed as JSON to
    ``GROQ_API_URL``, authenticating with ``GROQ_API_KEY`` as a bearer token.

    NOTE(review): the request schema (``{"texts": [...]}``) and the response
    schema (``{"embeddings": [...]}``) are assumptions, not a verified API
    contract. A recorded run of this notebook got HTTP 404 from the configured
    URL -- confirm both the endpoint and the schema before relying on this.

    Args:
        instruction: Text prepended (followed by a single space) to every
            input before it is sent for embedding.
        timeout: Seconds to wait for the HTTP request before ``requests``
            raises a timeout error instead of hanging indefinitely.
        **kwargs: Forwarded unchanged to ``BaseEmbedding.__init__``.
    """

    def __init__(
        self,
        instruction: str = "Represent the text for embedding:",
        timeout: float = 60.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self._instruction = instruction
        self._timeout = timeout
        self._api_key = GROQ_API_KEY
        self._api_url = GROQ_API_URL

    def _request_embeddings(self, texts: List[str], noun: str) -> List[List[float]]:
        """POST ``texts`` to the API and return one embedding per text, in order.

        Args:
            texts: Raw texts; the instruction prefix is added here.
            noun: Word used in the HTTP failure message ("embedding" or
                "embeddings") so each public method keeps its original wording.

        Raises:
            ValueError: If no API key is configured, or the response does not
                contain exactly one embedding per input text.
            Exception: If the API answers with a non-200 status code.
        """
        if not texts:
            # Nothing to embed: skip the network round trip entirely.
            return []
        if not self._api_key:
            # Fail with a clear message instead of sending "Bearer None".
            raise ValueError(
                "GROQ_API_KEY is not set; define it in the environment or in the .env file."
            )

        headers = {
            'Authorization': f'Bearer {self._api_key}',
            'Content-Type': 'application/json'
        }
        payload = {
            "texts": [f"{self._instruction} {text}" for text in texts]
        }

        response = requests.post(
            self._api_url, headers=headers, json=payload, timeout=self._timeout
        )
        if response.status_code != 200:
            raise Exception(
                f"Failed to retrieve {noun} from Groq API: {response.status_code}, {response.text}"
            )

        embeddings = response.json().get('embeddings', [])
        if len(embeddings) != len(texts):
            # A short or missing list would otherwise surface as a bare
            # IndexError or as silently misaligned embeddings downstream.
            raise ValueError(
                f"Groq API returned {len(embeddings)} embeddings for {len(texts)} texts"
            )
        return embeddings

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get embedding for a single query (embedded exactly like a text)."""
        return self._get_text_embedding(query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get embedding for a single text via the Groq API."""
        return self._request_embeddings([text], "embedding")[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get embeddings for a list of texts via the Groq API in one request."""
        return self._request_embeddings(texts, "embeddings")

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Asynchronous version for getting query embeddings.

        The blocking HTTP call runs in the default executor so the event loop
        is not stalled while waiting for the API.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_query_embedding, query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronous version for getting text embeddings.

        The blocking HTTP call runs in the default executor so the event loop
        is not stalled while waiting for the API.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._get_text_embedding, text)



In [5]:
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader

# Read the files found under the "data" directory into LlamaIndex documents
documents = SimpleDirectoryReader(input_dir="data").load_data()

In [6]:
# Instantiate the custom embedding model with its default instruction prefix;
# the API key and URL come from the GROQ_API_KEY / GROQ_API_URL globals.
groq_embed_model = GroqEmbeddingModel()

In [7]:
# Build the vector index, embedding the documents with the custom Groq model
# and displaying progress bars while it runs
index = VectorStoreIndex.from_documents(
    documents,
    show_progress=True,
    embed_model=groq_embed_model,
)

  from .autonotebook import tqdm as notebook_tqdm
Parsing nodes: 100%|██████████| 23/23 [00:00<00:00, 173.67it/s]
Generating embeddings:   0%|          | 0/49 [00:00<?, ?it/s]

Exception: Failed to retrieve embeddings from Groq API: 404, {"error":{"message":"Unknown request URL: POST /v1/embeddings. Please check the URL for typos, or see the docs at https://console.groq.com/docs/","type":"invalid_request_error","code":"unknown_url"}}


In [None]:
# Import the query-time components from the `llama_index.core` namespace, the
# same package layout the earlier cells use (`llama_index.core.embeddings`,
# `llama_index.core`); the legacy top-level `llama_index.*` paths belong to the
# older package layout.
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor

# Retrieve the 4 most similar nodes from the index for each query.
retriever = VectorIndexRetriever(index=index, similarity_top_k=4)
# Post-filter the retrieved nodes using a similarity cutoff of 0.80.
postprocessor = SimilarityPostprocessor(similarity_cutoff=0.80)

# Combine retrieval and post-filtering into a single query engine.
query_engine = RetrieverQueryEngine(
    retriever=retriever,
    node_postprocessors=[postprocessor],
)

In [None]:
# Ask the query engine a question about the indexed documents
# (typo "yopu" fixed: this string is what gets embedded and matched).
response = query_engine.query("What is attention is all you need?")

In [None]:
# Imported from `llama_index.core`, matching the package layout used by the
# earlier cells of this notebook.
from llama_index.core.response.pprint_utils import pprint_response

# Pretty-print the response together with its source nodes, then print the
# plain response as well.
pprint_response(response, show_source=True)
print(response)

In [None]:
import os.path
# Imported from `llama_index.core`, matching the package layout used by the
# earlier cells of this notebook.
from llama_index.core import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    StorageContext,
    load_index_from_storage,
)

# Build the index once and persist it; later runs reload it from disk instead
# of re-embedding the documents.
# NOTE(review): unlike the earlier from_documents call, no embed_model is passed
# in this cell, so presumably the library's default embedding model is used
# rather than groq_embed_model -- confirm this is intended.
# NOTE: this cell rebinds `documents`, `index`, `query_engine` and `response`.

# check if storage already exists
PERSIST_DIR = "./storage"
if not os.path.exists(PERSIST_DIR):
    # load the documents and create the index
    documents = SimpleDirectoryReader("data").load_data()
    index = VectorStoreIndex.from_documents(documents)
    # store it for later
    index.storage_context.persist(persist_dir=PERSIST_DIR)
else:
    # load the existing index
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)

# either way we can now query the index
query_engine = index.as_query_engine()
response = query_engine.query("What are transformers?")
print(response)