In [1]:
import pandas as pd

# data = pd.read_json("experiments/results/baseline_no_ref/search_results.jsonl", lines=True)
# data.head()

In [2]:
import json

json_data = [json.loads(line) for line in open("experiments/results/baseline_no_ref/search_results.jsonl")]
print(json_data[:3])

[{'record': {'source_doi': '10.1146/annurev-astro-081811-125521', 'sent_original': 'Gonzalez, Zaritsky Zabludoff (2007) find an increasing stellar fraction with decreasing group/cluster mass, peaking below 10 14 M ⊙ , where stellar and gas masses are equal.', 'sent_no_cit': ' find an increasing stellar fraction with decreasing group/cluster mass, peaking below 10 14 M ⊙ , where stellar and gas masses are equal.', 'sent_idx': 506, 'citation_dois': ['10.1086/519729'], 'pubdate': 20120901, 'resolved_bibcodes': ['2007ApJ...666..147G'], 'sent_cit_masked': '[REF] find an increasing stellar fraction with decreasing group/cluster mass, peaking below 10 14 M ⊙ , where stellar and gas masses are equal.', 'expanded_query': None}, 'results': [{'pubdate': 19970501, 'citation_count': 953, 'text': 'A lower limit to the total cluster mass in stars is M_stars ~900 M_⊙ (likely a factor of <2 underestimate), and a lower limit to the recent star formation rate is ~10^-4 M_⊙ yr^-1.', 'doi': '10.1086/118389

In [None]:
from metrics import Metric


class RankFuser:
    """
    A class that produces a weighted sum of scores from multiple scoring functions,
    then uses those weights to rerank a set of results
    """

    def __init__(self, config: dict[str, float]):
        """
        Initializes the RankFuser with a configuration dictionary that maps scoring function names to their weights.

        Args:
            config (dict[str, float]): A dictionary where keys are scoring function names and values are their respective weights.
        """
        self.config = config
        self.metrics = [Metric.get_metric(name) for name in config.keys()]
        self.weights = list(config.values())

    def rerank(self, data: list[dict]) -> list[dict[str, pd.Series | pd.DataFrame]]:
        """
        Expects a list of dictionaries, each containing "record" and "results" keys.
        - "record": the dict representing the original query record
        - "results": ordered list of dicts representing the search results

        Returns:
            list[dict]: The data with reranked results
            dict:
            - "record": pd.Series of the original query record
            - "results": pd.DataFrame of the reranked search results
        """
        rows = []

        for row in data:
            query = pd.Series(row["record"])
            results = pd.DataFrame(row["results"])
            reranked_results = self._rerank_single(query, results)
            rows.append({"record": query, "results": reranked_results})

        return rows

    def _rerank_single(self, query: pd.Series, results: pd.DataFrame) -> pd.DataFrame:
        """
        Reranks the results DataFrame based on the weighted sum of scores from the configured metrics.

        Args:
            query (pd.Series): The input record for which results are being reranked.
            results (pd.DataFrame): The DataFrame containing results to be reranked.

        Returns:
            pd.DataFrame: The reranked results.
        """
        results_df = results.copy()
        results_df["weighted_score"] = 0

        # Compute metrics & build weighted score
        for metric, weight in zip(self.metrics, self.weights):
            scores = metric(query, results)
            results_df[metric.name] = scores
            results_df["weighted_score"] += scores * weight

        # Sort by the weighted score in descending order
        return results_df.sort_values("weighted_score", ascending=False).reset_index(drop=True)

# TODO: create a second metric and add it here, test it
rf = RankFuser(
    config={
        "recency": 0.5,
        "log_citations": 0.5,
    }
)
print(rf.metrics)
print(rf.weights)

[<metrics.Recency object at 0x56fd2ab10>, <metrics.LogCitations object at 0x56fc2ed10>]
[0.5, 0.5]


In [18]:
# single_results_df = pd.DataFrame(data.results.iloc[0].tolist())
# single_record = pd.Series(data.record.iloc[0])
# # single_results_df.head()

# type(single_results_df.pubdate)

In [19]:
reranked = rf.rerank(json_data[:3])

Got columns: Index(['pubdate', 'citation_count', 'text', 'doi', 'metric'], dtype='object')
Got columns: Index(['doi', 'pubdate', 'citation_count', 'text', 'metric'], dtype='object')
Got columns: Index(['doi', 'pubdate', 'citation_count', 'text', 'metric'], dtype='object')


In [20]:
print(type(reranked[0]['results']))
print(reranked[1]['results']['citation_count'].tolist())

<class 'pandas.core.frame.DataFrame'>
[5291, 15792, 15792, 16486, 16486, 16486, 10302, 10302, 10302, 805, 805, 805, 9171, 9171, 1229, 4165, 4165, 4717, 4717, 882, 882, 882, 1795, 387, 500, 1663, 1632, 1632, 2888, 1567, 985, 1343, 4058, 4058, 4058, 951, 951, 482, 2045, 3131, 859, 498, 713, 1837, 278, 2383, 861, 820, 1332, 1332, 307, 307, 1979, 2525, 2525, 748, 1589, 1589, 2476, 536, 1821, 2108, 2108, 468, 1218, 1218, 2370, 2370, 2370, 2370, 159, 678, 442, 442, 442, 3858, 1438, 1400, 1400, 173, 1514, 1514, 1514, 1514, 323, 1852, 146, 234, 2046, 982, 982, 1611, 313, 312, 2696, 1464, 1464, 631, 631, 354, 945, 162, 162, 2814, 344, 206, 400, 400, 400, 245, 245, 413, 899, 130, 960, 960, 2097, 430, 430, 262, 262, 262, 262, 669, 669, 3668, 259, 446, 569, 150, 988, 2504, 156, 839, 237, 237, 275, 844, 1095, 306, 306, 306, 306, 552, 1863, 769, 961, 307, 1361, 836, 478, 478, 478, 256, 662, 485, 919, 1347, 1347, 1388, 1388, 1388, 1388, 767, 767, 767, 339, 339, 2598, 2598, 923, 1238, 630, 740, 1644, 

In [21]:
reranked[0]['results'].head(10)

Unnamed: 0,pubdate,citation_count,text,doi,metric,weighted_score,recency,log_citations
0,20030701,7808,The spheroid IMF is less robustly determined b...,10.1086/376392,0.631763,3.321706,-2.319619,8.963032
1,20100901,1726,"A simple mass-quenching law, holding over a br...",10.1088/0004-637X/721/1/193,0.63568,3.177536,-1.099068,7.454141
2,20111201,648,"Quiescent, moderately, and highly star-forming...",10.1088/0004-637X/742/2/96,0.614475,2.957078,-0.561277,6.475433
3,20120901,366,For galaxies with stellar mass less than 1.0 ×...,10.1088/0004-637X/757/1/85,0.634023,2.952681,-0.0,5.905362
4,20120901,366,There is a stellar mass threshold of M_stellar...,10.1088/0004-637X/757/1/85,0.6308,2.952681,-0.0,5.905362
5,20030701,3693,Protosolar elemental and isotopic abundances a...,10.1086/375492,0.612405,2.947423,-2.319619,8.214465
6,20030701,3693,Photospheric abundances give mass fractions of...,10.1086/375492,0.637769,2.947423,-2.319619,8.214465
7,20041001,3157,A tight correlation (+/-0.1 dex) between stell...,10.1086/423264,0.622216,2.934819,-2.188057,8.057694
8,20041001,3157,The mass-metallicity relation is steep from 10...,10.1086/423264,0.63479,2.934819,-2.188057,8.057694
9,20081201,1663,In regions where the interstellar medium (ISM)...,10.1088/0004-6256/136/6/2782,0.609471,2.929327,-1.558325,7.41698


In [6]:
print(type(reranked.results.iloc[0]))
for i in range(len(reranked)):
    print([result["citation_count"] for result in reranked.results.iloc[i]])

AttributeError: 'list' object has no attribute 'results'

In [None]:
reranked.results.iloc[0][0].keys()