<a href="https://colab.research.google.com/github/cipB14/Questify/blob/main/RAG_metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
from sentence_transformers import SentenceTransformer, util
import random
import math
import pandas as pd
import matplotlib.pyplot as plt

# Evaluation corpus: domain name -> (query, passages relevant to that query).
# evaluate_domain_with_labels treats every passage listed for a domain as relevant
# and every entry of `common_distractors` as non-relevant (binary relevance).
# Each domain lists exactly five passages, so with top-5 retrieval precision@5 and
# recall@5 take the same value.
query_domains = {
    "DSA": ("graph traversal", [
        "Stacks are linear data structures which follow the Last In First Out principle.",
        "A binary search tree is a type of binary tree where left children are less and right children are greater.",
        "Dynamic programming is used to optimize recursive problems.",
        "Graphs are collections of nodes connected by edges, useful for traversal algorithms.",
        "Hash tables allow fast access to data using keys."
    ]),
    "Operating Systems": ("process scheduling", [
        "Process scheduling determines which process runs on the CPU next.",
        "Deadlock occurs when processes wait indefinitely for resources held by each other.",
        "Virtual memory allows execution of processes that may not be completely in memory.",
        "Semaphores are used for process synchronization to avoid race conditions.",
        "Paging is a memory management scheme that eliminates the need for contiguous allocation."
    ]),
    "Software Engineering": ("agile methodology", [
        "Agile promotes iterative development and frequent feedback.",
        "Scrum is a popular agile framework using sprints and stand-up meetings.",
        "Requirement gathering is essential for successful software design.",
        "Software testing ensures correctness and quality of the system.",
        "Version control helps track and manage code changes in teams."
    ]),
    "Computer Networks": ("TCP congestion control", [
        "TCP uses congestion control to avoid overwhelming the network.",
        "OSI model has seven layers for network communication.",
        "Routing algorithms determine the optimal path for packet delivery.",
        "Switches and routers operate at different layers to forward data.",
        "IP addressing uniquely identifies each device on a network."
    ]),
    "Database Management Systems": ("ACID properties", [
        "ACID ensures reliable processing of database transactions.",
        "Normalization removes redundancy and improves integrity.",
        "Joins combine rows from two or more tables in SQL queries.",
        "Indexes improve the speed of data retrieval operations.",
        "SQL is used to manage and query relational databases."
    ]),
    "Java Programming": ("Java OOP", [
        "Java is an object-oriented programming language.",
        "Encapsulation hides the internal state of an object.",
        "Inheritance allows one class to inherit properties and methods of another class.",
        "Polymorphism enables a method to behave differently based on the object calling it.",
        "Abstraction simplifies complex systems by providing only necessary details to the user."
    ]),
    "Compiler Design": ("lexical analysis", [
        "A compiler translates high-level programming languages into machine code.",
        "Lexical analysis is the first phase of a compiler.",
        "Syntax trees are used to represent the structure of source code.",
        "Semantic analysis ensures the meaning of the program is correct.",
        "Optimization improves the efficiency of the generated code."
    ]),
    "Machine Learning": ("supervised learning", [
        "Supervised learning uses labeled data to train models.",
        "Linear regression predicts a continuous output variable.",
        "Classification assigns labels to input data based on patterns.",
        "Overfitting occurs when a model learns the noise in the training data.",
        "Cross-validation helps in assessing the model's generalizability."
    ]),
    "Distributed Systems": ("fault tolerance", [
        "Distributed systems are composed of multiple independent components working together.",
        "Fault tolerance ensures that the system continues to operate even when some parts fail.",
        "Replication stores copies of data across different nodes to prevent data loss.",
        "Consistency ensures that all copies of data are the same across nodes.",
        "Sharding splits data across different nodes for improved scalability."
    ]),
    "Computer Architecture": ("CPU pipelines", [
        "A CPU pipeline allows multiple instructions to be processed simultaneously at different stages.",
        "Pipelining improves the throughput of a CPU by executing multiple instructions in parallel.",
        "Superscalar architecture allows multiple instructions to be executed per clock cycle.",
        "Branch prediction guesses the outcome of a branch instruction to minimize delays.",
        "Cache memory stores frequently accessed data for faster retrieval."
    ])
}

# Shared distractor passages appended to every domain's corpus and always labelled
# non-relevant by evaluate_domain_with_labels.
# NOTE(review): several of these are on-topic for some domain and closely paraphrase
# its relevant passages, e.g. "Agile focuses on fast delivery ..." (Software
# Engineering / "agile methodology"), "Software testing includes ..." and "Git is a
# version control system ..." (Software Engineering), "TCP/IP is the foundation ..."
# (Computer Networks), "Object-oriented programming uses classes ..." (Java OOP).
# They still count as misses when retrieved, i.e. they act as hard negatives --
# confirm that is intended, since it lowers precision/recall for those domains.
common_distractors = [
    "Cloud computing provides scalable resources over the internet.",
    "JavaScript is used to make web pages interactive.",
    "REST APIs allow communication between client and server.",
    "HTML and CSS define the structure and style of web documents.",
    "Object-oriented programming uses classes and objects to model software.",
    "Containers help in deploying consistent environments.",
    "Software testing includes unit and integration testing.",
    "Git is a version control system for tracking changes in source code.",
    "Microservices architecture breaks systems into smaller independent services.",
    "TCP/IP is the foundation of internet communication protocols.",
    "Blockchain enables decentralized and immutable transaction ledgers.",
    "Python is a high-level programming language used in various domains.",
    "Agile focuses on fast delivery through iterative development."
]

# Pad each passage to a random length of 256-512 whitespace-separated words.
# Seed the module-level RNG once so the drawn lengths are reproducible across runs.
random.seed(42)


def expand_text(doc, min_len=256, max_len=512):
    """Pad *doc* by cycling its words up to a random length in [min_len, max_len].

    A target length is drawn with ``random.randint(min_len, max_len)`` (one draw per
    call) and the document's whitespace-separated words are repeated in order until
    exactly that many words are produced.  A document already longer than the target
    is truncated to it.  Lengths are counted in whitespace words, not model tokens.

    Args:
        doc: Source text; split on whitespace.
        min_len: Minimum number of words in the result (inclusive).
        max_len: Maximum number of words in the result (inclusive).

    Returns:
        The padded text joined with single spaces, or ``''`` if *doc* contains no
        words.

    Raises:
        ValueError: If not ``0 <= min_len <= max_len``.
    """
    if not 0 <= min_len <= max_len:
        raise ValueError(
            f"expected 0 <= min_len <= max_len, got min_len={min_len}, max_len={max_len}"
        )
    words = doc.split()
    if not words:
        # Nothing to repeat; without this guard no amount of cycling reaches min_len.
        return ''
    # Draw the length first so the whole [min_len, max_len] range is reachable.
    target = random.randint(min_len, max_len)
    repeats = -(-target // len(words))  # ceiling division: enough copies to cover target
    return ' '.join((words * repeats)[:target])

# Sentence-BERT encoder shared by every evaluation cell below.
# NOTE(review): all-MiniLM-L6-v2 is published with max_seq_length=256 word pieces,
# so text past that point is likely truncated before embedding -- confirm via
# `embedder.max_seq_length`.
embedder = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')



In [40]:
def _dcg(relevances):
    """Return the discounted cumulative gain of a ranked gain list (log2 rank discount)."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances))


def _ndcg(retrieved, all_relevances, k):
    """Return nDCG@k for the ranked gains *retrieved*.

    The ideal DCG uses the best possible ordering of *all_relevances* truncated to
    *k*, so relevant items that could never fit in the top-k do not inflate it.
    Returns 0.0 when there is nothing relevant at all.
    """
    ideal_dcg = _dcg(sorted(all_relevances, reverse=True)[:k])
    return _dcg(retrieved[:k]) / ideal_dcg if ideal_dcg > 0 else 0.0


def _reciprocal_rank(relevances):
    """Return 1/rank of the first relevant (== 1) item, or 0.0 if none is relevant."""
    for rank, rel in enumerate(relevances, start=1):
        if rel == 1:
            return 1 / rank
    return 0.0


def evaluate_domain_with_labels(domain_name, query, relevant_docs, top_k=5):
    """Retrieve the top-k passages for *query* and score the ranking.

    The corpus is *relevant_docs* followed by the shared ``common_distractors``, each
    padded with ``expand_text``.  Every relevant doc has gain 1 and every distractor
    gain 0.  Query and documents are embedded with ``embedder`` and ranked by cosine
    similarity (ties broken in corpus order).

    Args:
        domain_name: Domain label.  Kept for interface compatibility; relevance is
            determined by whether a passage came from *relevant_docs*.
        query: Query string to embed.
        relevant_docs: Passages that count as relevant for *query*.
        top_k: Number of ranked passages to evaluate (default 5).

    Returns:
        ``(precision, recall, f1, avg_similarity, ndcg_score, mrr_score)``:
        precision@k = hits / top_k; recall@k = hits / len(relevant_docs) (0.0 if
        there are none); f1 = harmonic mean of those two; avg_similarity = mean
        cosine score of the retrieved passages; ndcg_score = nDCG@k with binary
        gains; mrr_score = reciprocal rank of the first hit for this single query
        (average it over queries to obtain MRR).

    Raises:
        ValueError: If ``top_k < 1``.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    relevant_expanded = [expand_text(doc) for doc in relevant_docs]
    distractor_expanded = [expand_text(doc) for doc in common_distractors]

    # Binary gains aligned with all_docs: 1 = domain passage, 0 = distractor.
    all_docs = relevant_expanded + distractor_expanded
    gains = [1] * len(relevant_expanded) + [0] * len(distractor_expanded)

    doc_embeddings = embedder.encode(all_docs, convert_to_tensor=True)
    query_embedding = embedder.encode(query, convert_to_tensor=True)
    # cos_sim yields a (1, n_docs) tensor; move the single row to NumPy once.
    scores = util.cos_sim(query_embedding, doc_embeddings)[0].cpu().numpy()

    top_results = np.argsort(-scores, kind="stable")[:top_k]
    binary_relevance = [gains[i] for i in top_results]

    hits = sum(binary_relevance)
    precision = hits / top_k
    recall = hits / len(relevant_expanded) if relevant_expanded else 0.0
    # Harmonic mean of the precision/recall reported above (not sklearn's f1 over
    # an all-positive prediction vector, which always implies recall == 1).
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    avg_similarity = float(scores[top_results].mean())

    ndcg_score = _ndcg(binary_relevance, gains, top_k)
    mrr_score = _reciprocal_rank(binary_relevance)

    # Return the metrics for tabulation and visualization
    return precision, recall, f1, avg_similarity, ndcg_score, mrr_score


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Evaluate every domain, tabulate the metrics, and plot one bar chart per domain.

# Metric column names, in the order evaluate_domain_with_labels returns them.
metric_columns = ["Precision", "Recall", "F1 Score", "Avg Similarity", "nDCG@5", "MRR"]

# One row per domain: [domain, *metrics].
results = []
for domain, (query, docs) in query_domains.items():
    results.append([domain, *evaluate_domain_with_labels(domain, query, docs)])

df = pd.DataFrame(results, columns=["Domain", *metric_columns])

# Coerce defensively: any non-numeric value becomes NaN instead of breaking mean()/plots.
df[metric_columns] = df[metric_columns].apply(pd.to_numeric, errors='coerce')

# Display tabulated results
print(df)

# Column-wise mean across domains; for "MRR" this is the actual mean reciprocal rank.
avg_metrics = df[metric_columns].mean(axis=0)
print("\nAverage of each metric:")
print(avg_metrics)

# plt.subplots(0, ...) raises, so only plot when at least one domain was evaluated.
if not df.empty:
    # squeeze=False keeps `axes` 2-D even for a single domain, where the default
    # would return a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(len(df), 1, figsize=(10, len(df) * 5), squeeze=False)
    bar_colors = ['blue', 'green', 'purple', 'orange', 'red', 'cyan']
    for ax, (_, row) in zip(axes[:, 0], df.iterrows()):
        scores = row[metric_columns].astype(float)
        ax.bar(metric_columns, scores.to_numpy(), color=bar_colors)
        ax.set_title(f'{row["Domain"]} Metrics')
        ax.set_ylabel('Score')
        # Every metric is <= 1, but cosine similarity can be negative: extend the
        # lower limit when needed so such a bar stays visible.
        ax.set_ylim(min(0.0, scores.min()), 1)
        ax.tick_params(axis='x', labelrotation=45)

    # Adjust the layout to prevent overlap, then render.
    plt.tight_layout()
    plt.show()
