In [None]:
from flask import Flask, render_template, request
from functools import lru_cache
import math
import os
from dotenv import load_dotenv

In [None]:
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Searcher

In [None]:
# Load variables from a local .env file (if present), then fill in defaults only
# for keys that are still unset. The previous unconditional assignments silently
# clobbered anything .env or the shell provided, making load_dotenv() pointless.
load_dotenv()
os.environ.setdefault('INDEX_NAME', 'saf.train.2bits')
os.environ.setdefault('INDEX_ROOT', 'experiments/notebook/indexes')
os.environ.setdefault('PORT', '8890')
# NOTE(review): presumably makes ColBERT's torch-extension loading verbose — confirm.
os.environ.setdefault('COLBERT_LOAD_TORCH_EXTENSION_VERBOSE', 'True')

In [None]:
# Read the index location back out of the environment configured above.
INDEX_NAME, INDEX_ROOT = (os.environ.get(key) for key in ("INDEX_NAME", "INDEX_ROOT"))
# Flask application that hosts the search API routes.
app = Flask(__name__)

In [None]:
import pandas as pd
def create_entry(row):
    """Format one training row as an 'Input: {...} Output: {...}' passage string.

    Args:
        row: Mapping (e.g. a DataFrame row) with keys ``question``, ``student``,
            ``reference``, ``label``, ``score`` and ``feedback``.

    Returns:
        The passage text. Every field is passed through ``str()``, so missing
        values (e.g. a NaN ``feedback``) render as ``nan`` instead of raising
        TypeError as they did when ``label``/``feedback`` were concatenated raw.

    NOTE(review): field values are not JSON-escaped, so a value containing a
    double quote yields text that is not valid JSON. The exact format is kept
    because the on-disk index was presumably built from this text.
    """
    fields = {
        key: str(row[key])
        for key in ("question", "student", "reference", "label", "score", "feedback")
    }
    # Doubled braces are literal; only the template is parsed by .format, so
    # braces inside field values are copied through verbatim.
    template = (
        'Input: {{"question": "{question}", "student_answer": "{student}", '
        '"reference_answer": "{reference}"}} '
        'Output: {{"label": "{label}", "numeric_score": {score}, "feedback": "{feedback}"}}'
    )
    return template.format(**fields)

# Load the training split and render every row into a passage string.
df = pd.read_csv("data/train.csv")
df["Entry"] = df.apply(create_entry, axis=1)
# `entries` and `collection` alias the same list; list position follows CSV row order.
collection = entries = df["Entry"].tolist()

In [None]:
# Number of /api/search requests served so far.
counter = {"api": 0}
# NOTE(review): `config` is not passed to Searcher below; kept because later
# cells may reference it — confirm whether it is still needed.
config = RunConfig(experiment='notebook')
# Load the prebuilt ColBERT index from the same INDEX_NAME / INDEX_ROOT defined in
# the config cell (rather than repeating the literals), and attach the in-memory
# collection so hit pids can be mapped back to passage text.
searcher = Searcher(index=INDEX_NAME, index_root=INDEX_ROOT, collection=collection)

In [None]:
@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Search the ColBERT index and return the top-k hits with softmax probabilities.

    Results are memoised by ``lru_cache`` on the raw ``(query, k)`` arguments, so
    the returned dict is shared between calls and must not be mutated.

    Args:
        query: Search text passed to ``searcher.search``.
        k: Number of hits to return, as anything ``int()`` accepts (e.g. a raw
            query-string value). ``None`` means 10; the value is clamped to [0, 100].

    Returns:
        ``{"query": query, "topk": hits}`` where each hit is a dict with keys
        ``text``, ``pid``, ``rank``, ``score`` and ``prob``, ordered by descending
        score, then ascending pid. ``prob`` values sum to 1 over the returned hits.

    Raises:
        ValueError: if ``k`` is not ``None`` and cannot be parsed by ``int()``.
    """
    print(f"Query={query}")
    if k is None:
        k = 10
    # Clamp into [0, 100]; a negative k would otherwise slice from the end of the list.
    k = max(0, min(int(k), 100))
    # Always retrieve 100 candidates and truncate, presumably so a small k is a
    # prefix of the k=100 ranking.
    pids, ranks, scores = searcher.search(query, k=100)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]
    print("pidis: ", pids)
    passages = [searcher.collection[pid] for pid in pids]
    # Numerically stable softmax: shifting by the max keeps math.exp from
    # overflowing on large scores; the normaliser is computed once, not per item.
    max_score = max(scores, default=0.0)
    exps = [math.exp(score - max_score) for score in scores]
    total = sum(exps)
    probs = [e / total for e in exps]
    topk = [
        {'text': text, 'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
        for text, pid, rank, score, prob in zip(passages, pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda hit: (-hit['score'], hit['pid']))
    return {"query": query, "topk": topk}

In [None]:
@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle ``GET /api/search?query=...&k=...`` and return the search hits as JSON.

    Increments the shared request counter on every call. Responds with HTTP 400
    when ``query`` is missing/empty or ``k`` is present but not an integer;
    otherwise returns the dict produced by ``api_search_query``.
    """
    # No request.method check here: the route already restricts methods, and
    # Flask also routes its automatic HEAD requests to this view (the old
    # else-branch wrongly answered those with 405).
    counter["api"] += 1
    print("API request count:", counter["api"])
    query = request.args.get("query")
    if not query:
        return {"error": "missing required parameter 'query'"}, 400
    k = request.args.get("k")
    if k is not None:
        # Validate up front so a bad k is a client error, not a 500 from int().
        try:
            int(k)
        except ValueError:
            return {"error": "parameter 'k' must be an integer"}, 400
    return api_search_query(query, k)