In [7]:
from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
from dotenv import load_dotenv
from fastapi import FastAPI, Query, Path, HTTPException
import logging
from typing import List, Optional
import os
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import random 
import json 

import numpy as np
import gridfs
from sentence_transformers import SentenceTransformer
import faiss
from sklearn.decomposition import PCA
import pickle
from sklearn.preprocessing import normalize


In [8]:
# Connect to MongoDB, pull every stored paper embedding, and build an exact FAISS index over them.
# Also loads the sentence-embedding model and the fitted PCA used to project queries into the
# same reduced space as the stored embeddings.
load_dotenv()
uri = os.getenv("MONGODB_URI")
if not uri:
    # MongoClient(None) silently falls back to localhost:27017, which would hide a missing or
    # misnamed variable behind an unrelated connection error later on.
    raise RuntimeError("MONGODB_URI is not set; add it to your .env file or environment.")

try:
    client = MongoClient(uri, server_api=ServerApi('1'))
    # MongoClient connects lazily, so constructing it never fails on a bad host or bad
    # credentials. An explicit ping makes connection errors surface inside this try block.
    client.admin.command('ping')
    db = client.ica_conf
    papers_collection = db['papers']
    embeddings_collection = db['embeddings']
except Exception as e:
    print("Error connecting to MongoDB:", e)
    raise

# Load the embeddings directly from MongoDB (only paper_id and embedding are projected).
print("Loading embeddings from MongoDB...")
embedding_docs = list(embeddings_collection.find({}, {"_id": 0, "paper_id": 1, "embedding": 1}))
# Duplicate paper_ids collapse to the last document seen.
paper_embeddings = {doc["paper_id"]: doc["embedding"] for doc in embedding_docs}
print("Loading embeddings finished!")
if not paper_embeddings:
    raise ValueError("No documents found in the 'embeddings' collection; cannot build the FAISS index.")

# FAISS row i corresponds to paper_ids[i]. Keep this list to map search hits back to papers.
paper_ids = list(paper_embeddings.keys())
embeddings = np.array([paper_embeddings[pid] for pid in paper_ids], dtype=np.float32)
dimension = embeddings.shape[1]

# Exact (brute-force) Euclidean search.
# NOTE(review): queries are L2-normalized downstream. L2 ranking matches cosine ranking only if
# the stored embeddings are unit-length too; confirm how they were written.
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)

# Load the SentenceTransformer model
model = SentenceTransformer('all-MiniLM-L6-v2')

# SECURITY: pickle.load can execute arbitrary code. Only load a pca_model.pkl you produced yourself.
with open("pca_model.pkl", "rb") as f:
    pca = pickle.load(f)

Loading embeddings from MongoDB...
Loading embeddings finished!


In [9]:
# Embed a free-text query into the same space as the FAISS index:
# SentenceTransformer encoding -> fitted PCA reduction -> unit L2 norm -> float32 for FAISS.
query = 'history'
query_raw = model.encode(query).reshape(1, -1)  # 2-D (1, model_dim): PCA/normalize expect a batch
query_reduced = pca.transform(query_raw)
query_embedding = normalize(query_reduced, norm='l2').astype(np.float32)  # FAISS works on float32

# Fail fast if the PCA output width does not match the index built from the stored embeddings.
assert query_embedding.shape[1] == dimension, (
    f"PCA output dim {query_embedding.shape[1]} != FAISS index dim {dimension}"
)

In [10]:
# Inspect the PCA-reduced, L2-normalized query vector (kept as the last expression so Jupyter renders it).
query_embedding

array([[-0.18035112,  0.36617496,  0.00382748, -0.03999193, -0.10932939,
        -0.00104222, -0.19386088,  0.11938665,  0.15355717,  0.29081877,
         0.00515902,  0.11793398, -0.12257402,  0.17042483, -0.01794571,
         0.12195163,  0.22689171,  0.06965355, -0.07631791,  0.10628783,
         0.2116692 ,  0.0211395 ,  0.04406381,  0.15793029,  0.0431853 ,
        -0.17344472, -0.05742695, -0.03860859, -0.02712459,  0.13316845,
         0.14382991, -0.17390026,  0.06540496, -0.07261153,  0.12883144,
        -0.10479668,  0.16409655, -0.0279471 ,  0.06509209,  0.0456898 ,
        -0.04816658,  0.01565317, -0.10316963,  0.1280073 ,  0.08689662,
        -0.19221146, -0.18517006,  0.00834335,  0.07112796,  0.18116358,
        -0.0854278 ,  0.00265226, -0.04267308, -0.04016757,  0.06526354,
        -0.05516535, -0.01972004,  0.06714266, -0.0751992 ,  0.09940738,
         0.02050727, -0.05002487,  0.01648288,  0.06966137, -0.05388484,
         0.0555284 ,  0.01986847, -0.01773694,  0.0

In [11]:
# Sanity check: the reduced query must have as many columns as the FAISS index has dimensions.
reduced_shape = query_embedding.shape
print("Query embedding shape after PCA and normalization:", reduced_shape)


Query embedding shape after PCA and normalization: (1, 80)


In [12]:
# Width of the vectors stored in the FAISS index; it must equal the query embedding's column count.
index_dim = dimension
print("FAISS index dimension:", index_dim)


FAISS index dimension: 80
