# Pronova LLM Train Model #
## Use this notebook to do the following ##
- Create Qdrant collections
- Chunk text files into smaller chunks
- Create embeddings for data chunks and queries
- Delete Qdrant vectors

In [None]:
# Load required libraries
import os
from qdrant_client import QdrantClient
from qdrant_client.http import models
from openai import OpenAI
from dotenv import load_dotenv
from IPython.display import Markdown, display

# Load environment variables from .env file
load_dotenv()

### Setup Qdrant connection ###

In [None]:
# Pull the Qdrant credentials out of the environment; abort when either one is absent
Qdrant_api_key = os.environ.get('Qdrant_API_KEY')
if not Qdrant_api_key:
    raise ValueError("No Qdrant API key found in environment variables")

Qdrant_url = os.environ.get('Qdrant_URL')
if not Qdrant_url:
    raise ValueError("No Qdrant URL found in environment variables")


# Build the client object that the rest of the notebook refers to as `Qclient`
try:
    Qclient = QdrantClient(url=Qdrant_url, api_key=Qdrant_api_key)
except Exception as e:
    # Report the failure, then re-raise so the notebook run stops here
    print(f"Failed to connect to Qdrant: {e}")
    raise
else:
    print("Successfully connected to Qdrant")

### Setup OpenAI connection ###

In [None]:
# Read the OpenAI API key from the environment and stop early when it is missing or empty
OpenAI_api_key = os.getenv('OPENAI_API_KEY')
if not OpenAI_api_key:
    raise ValueError("No OpenAI API key found in environment variables")

# NOTE(review): this assigns an attribute on the OpenAI *class*, not on a client
# instance. get_embedding() builds its client with `OpenAI()` and no arguments, so
# the client presumably picks the key up from the OPENAI_API_KEY environment
# variable instead — confirm whether this assignment has any effect.
OpenAI.api_key = OpenAI_api_key

### Creating a Qdrant Collection (function) ###

In [None]:
#Create collection
def create_qdrant_collection(collection_name, vector_size=1536):
    """Create a Qdrant collection that stores cosine-distance vectors.

    Args:
        collection_name: Name of the collection to create.
        vector_size: Dimensionality of the vectors the collection will hold.
            The default of 1536 is the size this notebook has always used;
            pass a different value when storing embeddings from a model with
            another output dimension.

    Raises:
        Exception: Whatever ``Qclient.create_collection`` raises is printed
            and then re-raised unchanged.
    """
    try:
        Qclient.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_size,
                distance=models.Distance.COSINE,
            ),
        )
        print(f"Collection '{collection_name}' created successfully")
    except Exception as e:
        print(f"Failed to create collection '{collection_name}': {e}")
        raise

In [None]:
# create_qdrant_collection("eating-problems")
# Qclient.get_collections()

### Get an OpenAI embedding from a text segment (Function) ###

In [None]:
# Function to get the embedding of a text
# Lazily created OpenAI client shared by every get_embedding() call.
# Re-running this cell resets it to None, so a fresh client is built afterwards.
_openai_client = None


def get_embedding(text, model="text-embedding-ada-002"):
    """Return the embedding vector for ``text``.

    The OpenAI client is created on first use and reused afterwards, instead
    of constructing a new client for every chunk that gets embedded.

    Args:
        text: The text to embed.
        model: Name of the embedding model to request. Defaults to the model
            this notebook has always used.

    Returns:
        The ``embedding`` attribute of the first item in the response data.
    """
    global _openai_client
    if _openai_client is None:
        _openai_client = OpenAI()

    response = _openai_client.embeddings.create(
        model=model,
        input=text
    )
    return response.data[0].embedding

### Turn a file into chunks (Function) ###

In [None]:
# Function to read a text file and chunk its content

## maybe change chunk size!
def chunk_text_from_file(file_path, chunk_size=400):
    """Read a UTF-8 text file and split its contents into fixed-size pieces.

    Args:
        file_path: Path of the text file to read.
        chunk_size: Maximum number of characters per piece.

    Returns:
        A list of consecutive substrings of the file contents, each
        ``chunk_size`` characters long except possibly the last one.
        An empty file yields an empty list.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        contents = handle.read()

    pieces = []
    for offset in range(0, len(contents), chunk_size):
        pieces.append(contents[offset:offset + chunk_size])
    return pieces

### Get embedding from a list of chunks (Function) ###

In [None]:
# Function to get embeddings for a list of text chunks
def get_embeddings_for_chunks(chunks):
    """Embed every chunk, one ``get_embedding`` call per chunk.

    Args:
        chunks: Iterable of text chunks.

    Returns:
        A list with one embedding per chunk, in the same order as ``chunks``.
    """
    return [get_embedding(piece) for piece in chunks]

### Get largest Qdrant ID (function) ###

In [None]:
def get_qdrant_next_id(collection_name):
    """Return the exact number of points stored in ``collection_name``.

    Callers use this value as the offset for the IDs of newly inserted points,
    so it must not under-count. The exact count API is used for that reason;
    the ``points_count`` field of the collection info is documented by Qdrant
    as approximate.

    NOTE(review): a point count only equals the largest ID in use while no
    points have been deleted individually. If points are ever deleted, new
    IDs derived from this value can collide with existing ones.

    Args:
        collection_name: Name of the collection to inspect.

    Returns:
        The number of points in the collection (0 for an empty collection).

    Raises:
        Exception: Whatever the Qdrant client raises is printed and then
            re-raised unchanged.
    """
    try:
        point_count = Qclient.count(collection_name=collection_name, exact=True).count
    except Exception as e:
        print(f"Failed to get next ID from collection '{collection_name}': {e}")
        raise
    # Fall back to 0 if the client reports no count at all
    return point_count or 0

### Upsert lists of documents to Qdrant (Function) ###
#### Mind the parameters ####

In [None]:
# function to upsert embeddings into Qdrant
# def upsert_embeddings(collection_name, embeddings, chunks):
def upsert_embeddings(collection_name, embeddings, chunks, filename, citation):
    """Store one point per (embedding, chunk) pair in a Qdrant collection.

    Point IDs continue after the value returned by ``get_qdrant_next_id``:
    the first new point gets that value plus one. Each point's payload holds
    the chunk text, the source file name and the four citation fields.

    Args:
        collection_name: Target collection.
        embeddings: List of embedding vectors.
        chunks: List of text chunks; ``chunks[i]`` belongs to ``embeddings[i]``.
        filename: Name of the file the chunks came from.
        citation: 4-tuple ``(topic, url, author, date)``.

    Raises:
        ValueError: If ``embeddings`` and ``chunks`` differ in length.
        Exception: Whatever ``Qclient.upsert`` raises is printed and then
            re-raised unchanged.
    """
    if len(embeddings) != len(chunks):
        raise ValueError(
            f"Got {len(embeddings)} embeddings but {len(chunks)} chunks for '{filename}'"
        )

    # An empty file yields no chunks; there is nothing to store, and sending an
    # empty batch would only risk a server-side rejection.
    if not embeddings:
        print(f"No embeddings to upsert for '{filename}'; skipping")
        return

    first_id = get_qdrant_next_id(collection_name) + 1
    topic, url, author, date = citation

    points = [
        {
            "id": first_id + offset,
            "vector": vector,
            "payload": {
                "text": chunk,   # Attach the chunk as payload
                "source_file": filename,  # Add source file
                "topic": topic,
                "url": url,
                "author": author,
                "date": date
            }
        }
        for offset, (vector, chunk) in enumerate(zip(embeddings, chunks))
    ]

    try:
        Qclient.upsert(
            collection_name=collection_name,
            points=points
        )
        print("Embeddings upserted successfully")
    except Exception as e:
        print(f"Failed to upsert embeddings: {e}")
        raise

### Delete all entries in a collection effectively resetting a model (Function) ###
**BE CAREFUL WITH THIS.** *(It is secured by a verification process)*

In [None]:
def reset_model(collection_name):
    """Delete a collection and recreate it empty, after interactive confirmation.

    The user must type exactly ``DELETE`` at the prompt; any other answer
    aborts without touching the collection.

    Args:
        collection_name: Name of the collection to wipe.

    Raises:
        Exception: Any error raised while deleting or recreating the
            collection is printed and then re-raised unchanged.
    """
    answer = input("Type 'DELETE' to confirm deletion of entries in " + collection_name + ". This cannot be undone. Type anything else to abort")

    # Guard clause: anything other than the exact keyword aborts
    if answer != "DELETE":
        print("Deletion aborted")
        return

    try:
        Qclient.delete_collection(collection_name=collection_name)
        print(f"Collection '{collection_name}' deleted successfully")
        # Recreate the collection so it exists again, now empty
        create_qdrant_collection(collection_name)
    except Exception as e:
        print(f"Failed to delete collection '{collection_name}': {e}")
        raise

### Process Files From Folder (function) ###

In [None]:
import json

def get_citation(filename, sources_file):
    """Look up the citation metadata recorded for ``filename``.

    ``sources_file`` is a JSON file holding an object keyed by file name; each
    value is an object that may contain ``Topic``, ``URL``, ``Author`` and
    ``Date``. Any field that is absent, including when the file name itself
    is not listed, is replaced by a ``"<Field> not found"`` placeholder.

    Args:
        filename: Key to look up in the sources file.
        sources_file: Path to the JSON sources file.

    Returns:
        A tuple ``(topic, url, author, date)``.
    """
    with open(sources_file, 'r', encoding='utf-8') as file:
        sources = json.load(file)

    citation = sources.get(filename, {})
    url = citation.get('URL', 'URL not found')
    author = citation.get('Author', 'Author not found')
    date = citation.get('Date', 'Date not found')
    # Placeholder previously read 'Date not found' (copy-paste slip)
    topic = citation.get('Topic', 'Topic not found')

    return topic, url, author, date

In [None]:
def process_files_from_folder(collection_name, file_path, sources_file, start=0, end=None):
    """Chunk, embed and upsert a range of files from a folder.

    Only regular files are considered, and they are processed in sorted name
    order so that ``start`` and ``end`` refer to the same files on every run
    (``os.listdir`` alone returns entries in arbitrary order).

    Args:
        collection_name: Qdrant collection that receives the points.
        file_path: Folder containing the text files.
        sources_file: JSON file with citation data, passed to ``get_citation``.
        start: 0-based index of the first file to process.
        end: 0-based index of the last file to process (inclusive). ``None``
            or a value past the last file means "through the last file".
    """
    # Regular files only: subdirectories cannot be opened as text files
    filenames = sorted(
        name for name in os.listdir(file_path)
        if os.path.isfile(os.path.join(file_path, name))
    )

    last_index = len(filenames) - 1
    if end is None or end > last_index:
        end = last_index

    # end is inclusive, hence the +1 on the slice bound
    for current_file, filename in enumerate(filenames[start:end + 1], start=start):

        citation = get_citation(filename, sources_file)
        topic, url, author, date = citation
        if topic is None or url is None or author is None or date is None:
            print("ERROR ERROR ERROR")

        curr_file_path = os.path.join(file_path, filename)
        print(f"Processing file {current_file} of max {end}: {filename}")

        # Chunk the file
        chunks = chunk_text_from_file(curr_file_path)
        # Get embeddings for the chunks
        embeddings = get_embeddings_for_chunks(chunks)

        # Upsert the embeddings into Qdrant
        upsert_embeddings(collection_name, embeddings, chunks, filename, citation)


In [None]:

# Uncomment to create the collection before the first ingestion run:
# # create_qdrant_collection("FullModel")
start_index = 0  # 0-based position in the folder listing at which processing starts
end_index = 22  # only referenced by the allergies call below; the other calls pass None

## These are all the different folders and sources. Run all of these to train!

# # process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD_allergies", 'scrapingDemo/sources_petMD_allergies.json', start_index, end_index)
# # process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD_behavior", 'scrapingDemo/sources_petMD_behavior.json', start_index, None)
# # process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD_care_healthy_living", 'scrapingDemo/sources_petMD_care_healthy_living.json', start_index, None)
# # process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD_nutrition", 'scrapingDemo/sources_petMD_nutrition.json', start_index, None)
# # process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD_procedures", 'scrapingDemo/sources_petMD_procedures.json', start_index, None)
# # process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD_symptoms", 'scrapingDemo/sources_petMD_symptoms.json', start_index, None)

# process_files_from_folder("FullModel", "scrapingDemo/ScrapedFiles_petMD", 'scrapingDemo/sources_petMD.json', 589, None)
