# In [2]:
import os
import time
import uuid

import numpy as np
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from sentence_transformers import SentenceTransformer

# Populate the process environment from a local .env file, which is expected
# to define QDRANT_URL and QDRANT_API_KEY.
load_dotenv()

# Qdrant connection settings; each is None when its variable is not set.
QDRANT_URL = os.environ.get("QDRANT_URL")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

# Collection that holds the stored memories.
COLLECTION_NAME = "freen_memory"

# sentence-transformers model used for embeddings (small, fast, 384-dim).
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"

class FreenMemory:
    """Text memory store backed by a Qdrant collection.

    Each memory is a short string that is embedded with the
    sentence-transformers model named by ``EMBED_MODEL_NAME`` and stored as a
    Qdrant point whose payload holds the text plus caller-supplied metadata.
    """

    # Vector size of all-MiniLM-L6-v2, used when the model does not report one.
    _FALLBACK_DIM = 384
    # Number of points requested per scroll() page while pruning.
    _SCROLL_PAGE_SIZE = 256

    def __init__(self, qdrant_url=None, api_key=None, collection=None):
        """Connect to Qdrant, load the embedding model, ensure the collection.

        Args:
            qdrant_url: Qdrant endpoint; ``None`` falls back to ``QDRANT_URL``.
            api_key: Qdrant API key; ``None`` falls back to ``QDRANT_API_KEY``.
            collection: Collection name; ``None`` falls back to
                ``COLLECTION_NAME``.

        Raises:
            ValueError: If the URL or the API key is missing or empty.
        """
        # Defaults are resolved here rather than in the signature so the
        # module-level settings are read at call time, not at import time.
        if qdrant_url is None:
            qdrant_url = QDRANT_URL
        if api_key is None:
            api_key = QDRANT_API_KEY
        # An explicit raise (not `assert`) so the check survives `python -O`.
        if not (qdrant_url and api_key):
            raise ValueError("Set QDRANT_URL and QDRANT_API_KEY in .env")

        self.client = QdrantClient(url=qdrant_url, api_key=api_key)
        self.collection = COLLECTION_NAME if collection is None else collection
        self.embed_model = SentenceTransformer(EMBED_MODEL_NAME)
        # Ask the model for its output size so the collection always matches
        # the embeddings, even if EMBED_MODEL_NAME is changed.
        self.dim = (
            self.embed_model.get_sentence_embedding_dimension()
            or self._FALLBACK_DIM
        )
        self._ensure_collection()

    def _ensure_collection(self):
        """Create the collection with cosine distance if it does not exist.

        Best effort: any failure is printed rather than raised, so that
        constructing the object does not crash on a transient error.
        """
        try:
            existing = {c.name for c in self.client.get_collections().collections}
            if self.collection in existing:
                return
            # Only needed on first-time creation, hence the local import.
            from qdrant_client.models import Distance, VectorParams

            self.client.create_collection(
                collection_name=self.collection,
                vectors_config=VectorParams(size=self.dim, distance=Distance.COSINE),
            )
        except Exception as e:
            print("Collection check/create resulted in:", e)

    def _embed(self, texts):
        """Encode a list of strings into L2-normalized numpy embeddings.

        Returns one embedding row per input text.
        """
        return self.embed_model.encode(
            texts, convert_to_numpy=True, normalize_embeddings=True
        )

    def add_memory(self, text, meta=None, id=None):
        """Add (upsert) a memory to Qdrant.

        Args:
            text: Short string (the memory).
            meta: Optional metadata dict, e.g.
                ``{"type": "chat", "timestamp": 123456789, "source": "voice"}``.
                It is copied, never modified; a ``timestamp`` (epoch seconds)
                is added to the stored payload when absent.
            id: Optional point id (unsigned int or UUID string). When ``None``
                a UUIDv5 is derived from the timestamp and the text, so storing
                the same text with the same timestamp overwrites one point.

        Returns:
            The id under which the point was stored.
        """
        # Work on a copy so the caller's dict is left untouched.
        payload_meta = dict(meta) if meta else {}
        if "timestamp" not in payload_meta:
            payload_meta["timestamp"] = int(time.time())

        vector = self._embed([text])[0]

        if id is None:
            # Qdrant accepts only unsigned integers or UUIDs as point ids, and
            # the builtin hash() is salted per process, so derive a stable UUID.
            id = str(
                uuid.uuid5(uuid.NAMESPACE_OID, f"{payload_meta['timestamp']}:{text}")
            )

        point = {
            "id": id,
            "vector": vector.tolist(),
            "payload": {"text": text, **payload_meta},
        }
        self.client.upsert(collection_name=self.collection, points=[point])
        return id

    def get_relevant_memories(self, query_text, top_k=3, min_score=None):
        """Return up to ``top_k`` memories most similar to ``query_text``.

        Args:
            query_text: Text to search for.
            top_k: Maximum number of results.
            min_score: Optional similarity threshold forwarded to Qdrant as
                ``score_threshold``; ``None`` disables the threshold.

        Returns:
            A list of dicts with keys ``"id"``, ``"score"`` and ``"payload"``.
        """
        query_vector = self._embed([query_text])[0].tolist()

        search = getattr(self.client, "search", None)
        if search is not None:
            hits = search(
                collection_name=self.collection,
                query_vector=query_vector,
                limit=top_k,
                score_threshold=min_score,
            )
        else:
            # Clients that no longer expose `search` provide `query_points`,
            # which wraps the hits in a response object.
            hits = self.client.query_points(
                collection_name=self.collection,
                query=query_vector,
                limit=top_k,
                score_threshold=min_score,
                with_payload=True,
            ).points

        return [
            {
                "id": hit.id,
                "score": getattr(hit, "score", None),
                "payload": getattr(hit, "payload", None) or {},
            }
            for hit in hits
        ]

    def delete_old_memories(self, older_than_seconds=60*60*24*30):
        """Delete memories whose payload ``timestamp`` is older than the cutoff.

        Points without a ``timestamp`` are treated as timestamp 0 and are
        therefore pruned as well. Best effort: failures are printed, not raised.

        Args:
            older_than_seconds: Age threshold in seconds (default 30 days).

        Returns:
            The number of points deleted (0 when nothing matched or on failure).
        """
        cutoff = int(time.time()) - older_than_seconds
        try:
            ids_to_delete = []
            offset = None
            # scroll() returns (records, next_offset); follow next_offset until
            # it is None so every page of the collection is inspected.
            while True:
                records, offset = self.client.scroll(
                    collection_name=self.collection,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=True,
                    with_vectors=False,
                )
                for record in records:
                    ts = (record.payload or {}).get("timestamp", 0)
                    if ts < cutoff:
                        ids_to_delete.append(record.id)
                if offset is None:
                    break

            # Delete only after scrolling is finished so that pagination is not
            # disturbed by removals.
            if ids_to_delete:
                self.client.delete(
                    collection_name=self.collection,
                    points_selector=ids_to_delete,
                )
                print(f"Deleted {len(ids_to_delete)} old memories.")
            return len(ids_to_delete)
        except Exception as e:
            print("Prune failed:", e)
            return 0
