# params

In [None]:
# Root folder of the benchmark release on the mounted Google Drive.
# Eval prompt templates are read from `<file_root_path>/prompts/`.
file_root_path = "/content/drive/Shareddrives/Curie/benchmarks/public_release" #@param

# mount Google Drive (benchmark prompt files are read from it)

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so paths under `file_root_path` are readable.
drive.mount('/content/drive')

Mounted at /content/drive


# install

In [None]:
# json5: lenient JSON parsing of ground truth and model responses.
!pip install json5

In [None]:
# bert-score: provides `bert_score.score`, used by get_bert_score.
!pip install bert-score

In [None]:
# NOTE(review): `rouge/requirements.txt` is a path relative to the current
# working directory and must exist there — confirm. `rouge-score` provides
# `rouge_score.rouge_scorer`.
!pip install -r rouge/requirements.txt
!pip install rouge-score
# Needed for `from Bio import Align`. NOTE(review): the Biopython distribution
# on PyPI is named `biopython`; confirm that `Bio` pulls it in.
!pip install Bio
!pip install Levenshtein
# NOTE(review): `google.generativeai` and `six` are imported below but not
# installed here; presumably they are preinstalled in Colab — confirm.

# imports

In [None]:
import json5
import re
import functools
from typing import Any, Tuple, Union, Dict
from Bio import Align
import glob
from bert_score import score
from rouge_score import rouge_scorer
import numpy as np
import ast
import os
import google.generativeai as genai
import Levenshtein

# eval functions

## LLMSim eval

In [None]:
from google.colab import userdata

### LLMSim util

In [None]:
@functools.lru_cache(maxsize=1)
def get_model(model_name: str = 'gemini-1.5-pro-latest'):
  """Returns the Gemini model used as the LLM judge.

  The result is memoized (only the most recent `model_name` is kept), so
  repeated calls reuse a single `GenerativeModel` instance.

  Args:
    model_name: Gemini model identifier.
  """
  model = genai.GenerativeModel(model_name=model_name)
  return model


def llm_output(client: Any, prompt: str) -> str:
  """Sends `prompt` to the default judge model and returns its text reply.

  Args:
    client: Ignored. Callers pass None to select the external Gemini API.
    prompt: Full prompt text.

  Returns:
    The `.text` of the model response.
  """
  del client  # Unused: every request goes through get_model().
  response = get_model().generate_content(prompt)
  return response.text

def _parse_llm_verdict(output: str) -> dict[str, Any]:
  """Parses one LLM verdict, salvaging truncated JSON and unwrapping lists."""
  try:
    output_json = json5.loads(output)
  except Exception as e:  # pylint: disable=broad-except
    print("Skipping incomplete last item in output: ", e)
    output_json = []
    inds = [m.start() for m in re.finditer(r",\s*\{", output)]
    if inds:
      try:
        # Cut at the start of the last (truncated) item and close the list.
        output_json = json5.loads(output[: inds[-1]] + "]")
      except Exception as e2:  # pylint: disable=broad-except
        print("Could not salvage LLM output: ", e2)
  if isinstance(output_json, list):
    # Edge case: the model wrapped its verdict object in a list.
    output_json = output_json[0] if output_json else {}
  if not isinstance(output_json, dict):
    # A non-object verdict cannot carry a match; treat it as "no match".
    output_json = {}
  return output_json


def model_eval_json(
    record_id: str | None,
    json_ground_truth: list[dict[str, Any]],
    json_model_response: list[dict[str, Any]],
    eval_prompt: str,
    client: Any,
) -> dict[str, Any]:
  """Evaluate with json ground truth and model response.

  For every ground-truth item, the prompt template loaded from `eval_prompt`
  is filled with that item ({{json_ground_truth}}) and the full model-response
  list ({{json_extracted_list}}). The LLM then reports which model-response
  entry, if any, matches, by its `json_extracted_index`.

  Args:
    record_id: Record or paper id, copied into the result.
    json_ground_truth: Parsed ground-truth items.
    json_model_response: Parsed model-response items. NOTE: each item is
      mutated in place to carry its 0-based position as `json_extracted_index`.
    eval_prompt: Path of the eval prompt template file.
    client: LLM client forwarded to `llm_output`.

  Returns:
    A dict with `record_id`, `ground_truth_length`, `model_response_length`
    and `response_json` (one verdict dict per ground-truth item). A verdict
    keeps `json_extracted_index` only if it is a valid position in
    `json_model_response`, in which case `json_extracted` holds that item.
  """
  genai.configure(api_key=userdata.get('GEMINI_API_KEY'))
  eval_list = []
  if json_ground_truth:
    # Tag every model-response item with its position so the LLM can cite it
    # by index instead of hallucinating one. This is loop-invariant, so it is
    # done once rather than once per ground-truth item.
    for k, model_response_item in enumerate(json_model_response):
      model_response_item["json_extracted_index"] = k
    prompt_template = load_matsci_prompt(eval_prompt)
    extracted_list_json = json5.dumps(json_model_response, indent=2)

    for j, ground_truth_item in enumerate(json_ground_truth):
      prompt = prompt_template.replace(
          "{{json_ground_truth}}", json5.dumps(ground_truth_item, indent=2)
      ).replace("{{json_extracted_list}}", extracted_list_json)
      output_json = _parse_llm_verdict(llm_output(client=client, prompt=prompt))
      output_json["json_ground_truth_index"] = j
      output_json["json_ground_truth"] = ground_truth_item
      output_json["json_extracted"] = {}
      # A missing index means the LLM found no good match, so the match stays
      # empty. Otherwise keep it only if it is a valid 0-based position; index
      # 0 is the first response item and must be accepted.
      if "json_extracted_index" in output_json:
        raw_index = str(output_json["json_extracted_index"])
        if raw_index.isdecimal() and int(raw_index) < len(json_model_response):
          output_json["json_extracted"] = json_model_response[int(raw_index)]
        else:
          del output_json["json_extracted_index"]
      eval_list.append(output_json)
  return {
      "record_id": record_id,
      "ground_truth_length": len(json_ground_truth),
      "model_response_length": len(json_model_response),
      "response_json": eval_list,
  }

### LLMSim dft

In [None]:
def get_lmsim_score_dft(prediction, reference):
  """Scores a DFT prediction against its reference with the LLM judge.

  The reference is passed as the ground truth and the prediction as the model
  response.
  """
  # client=None selects the external Gemini API.
  return dft_domain_expert_model_based_eval(
      ground_truth=reference, model_response=prediction, client=None
  )


# Eval prompt template paths for the DFT LLM judge (metadata / structure).
_METADATA_EVAL_PROMPT_FILENAME = (
    file_root_path + "/prompts/dft_metadata_eval_output_1_shot.txt"
)
_STRUCTURE_EVAL_PROMPT_FILENAME = (
    file_root_path + "/prompts/dft_structure_eval_output_1_shot.txt"
)


def get_dft_model_response_field(
    model_output_value: dict[str, Any], field_name: str
) -> list[Any] | str:
  """Returns the model response for a given field from the inference output.

     This applies to json responses from the dft chained inference output.

  Args:
    model_output_value: The model output response (from one two) as a dict.
    field_name: The name of the field to extract. This should be one of
      "structure_metadata", "dft_metadata", or "code".
  """
  if field_name in ["structure_metadata"]:
    if field_name in model_output_value:
      print(model_output_value)
      return model_output_value["structure_metadata"]
    else:
      return ""

  elif field_name in ["dft_metadata"]:
    if field_name in model_output_value:
      return model_output_value["dft_metadata"]
    else:
      return ""

  elif field_name == "code":
    if field_name in model_output_value:
      code = "\n".join(
          [x["code_element"] for x in model_output_value["code_elements"]]
      )
      code += "\n" + model_output_value["execution_code"]
      return code
    else:
      return ""

  raise ValueError(f"Unknown field name: {field_name}")


def _extract_parse_raw_args(
    paper: str, prefix: str, class_name: str, label: str, verbose: int
) -> list[str]:
  """Returns the raw `class_name.parse_raw(...)` argument text per `prefix` var."""
  extracted = []
  for part in paper.split(prefix)[1:]:
    whole_part = prefix + part
    if verbose > 1:
      print("PART:\n", whole_part)
    if "parse_raw(" not in whole_part:
      continue
    left, right, *_ = whole_part.split("parse_raw(")
    if class_name not in left:
      continue
    # Prefer the closing `')` of a quoted literal; fall back to a bare `)`.
    end_marker = ")"
    if "')" in right:
      end_marker = "')"
    elif "'\n)" in right:
      end_marker = "'\n)"
    raw_arg = right.split(end_marker)[0].strip()
    if verbose > 0:
      print(label, raw_arg)
    extracted.append(raw_arg)
  return extracted


def _parse_annotated_literal(raw_arg: str) -> Any:
  """Evaluates the single-quoted literal in `raw_arg`; returns text on failure."""
  quoted_parts = raw_arg.split("'")
  # Without a single-quoted literal, use the whole argument text instead of
  # failing with IndexError.
  literal = quoted_parts[1] if len(quoted_parts) > 1 else raw_arg
  clean_json = literal.replace("NaN", '"NaN"')
  try:
    return ast.literal_eval(clean_json)
  except Exception:  # pylint: disable=broad-exception-caught
    return clean_json


def get_annotated_structure_metadata_and_dft_params(
    gt_paper_code: str, verbose: int = 0
) -> dict[str, list[Any]]:
  """Returns structure metadata and dft params from the ground truth code.

  Finds `structure_metadata_* = StructureMetadata.parse_raw('...')` and
  `dft_params_* = DFTParameters.parse_raw('...')` assignments. For each, the
  quoted literal (with NaN turned into "NaN") is evaluated with
  `ast.literal_eval`; if that fails, the cleaned text is kept as a string.

  Args:
    gt_paper_code: The ground truth code from .py file as a string.
    verbose: The verbosity level (>0 prints extracted values, >1 also prints
      each candidate code part).

  Returns:
    {"structures_metadata": [...], "dft_params": [...]} whose entries are
    parsed literals (typically dicts) or cleaned strings.
  """
  structures = _extract_parse_raw_args(
      gt_paper_code,
      "structure_metadata_",
      "StructureMetadata",
      "Extracted structure:\n",
      verbose,
  )
  dft_params = _extract_parse_raw_args(
      gt_paper_code,
      "dft_params_",
      "DFTParameters",
      "Extracted dft_param:\n",
      verbose,
  )
  return {
      "structures_metadata": [_parse_annotated_literal(s) for s in structures],
      "dft_params": [_parse_annotated_literal(p) for p in dft_params],
  }


def get_material_composition_from_struc(
    structure: str | dict[str, str],
) -> str | None:
  """Returns material composition from the structure metadata dict or string."""
  if isinstance(structure, dict):
    if "composition" in structure:
      return structure["composition"]
  elif isinstance(structure, str):
    if "composition" in structure:
      material = structure.split(r"\"composition\":")[1].split(",")[0]
      return material
  return None


def get_json_from_str(input_str: str) -> dict[str, Any] | None:
  """Returns the json object from the ground truth input string."""
  output_val = input_str.replace("NaN", '"NaN"')
  output_val = output_val.replace("true", '"1.0"')
  output_val = output_val.replace("false", '"NaN"')
  try:
    output_val = ast.literal_eval(output_val)
  except ValueError:
    return None
  return output_val


def parse_ground_truth_dft(ground_truth: str, client: Any) -> list[dict[str, Any]]:
  """Parses a DFT ground-truth string into a list of JSON objects.

  Newlines and backslashes are removed before parsing. A list of JSON strings
  is decoded element-wise. If local parsing fails for any reason, the LLM is
  asked to extract the list, and its output is parsed instead; errors from
  that second parse propagate.
  """
  try:
    parsed = json5.loads(ground_truth.replace("\n", "").replace("\\", ""))
    if parsed and isinstance(parsed[0], str):
      parsed = [json5.loads(entry) for entry in parsed]
    return parsed
  except Exception:  # pylint: disable=broad-except
    print("***using llm to parse")

  extraction_prompt = (
      "Extract ground truth json list from the following text.\n"
      + ground_truth
      + "\nMake sure to remove all backslashes for escape characters. Output"
      " the json list ONLY, without any explanation, prefix or suffix:\n"
  )
  llm_text = llm_output(client, prompt=extraction_prompt)
  print("***llm_ground_truth:\n", llm_text)
  return json5.loads(
      llm_text.replace("\n", "").replace('\\"', "").replace("\\", "")
  )


def parse_model_response_dft(
    model_response: str, client: Any, use_llm=False
) -> list[dict[str, Any]]:
  """Parses a DFT model response string into JSON.

  Newlines and Markdown code fences are stripped first. If the text does not
  parse, the last (presumably truncated) item is dropped and parsing is
  retried. If that also fails, the response is re-extracted once with the LLM
  (`use_llm=True`). A failure after the LLM attempt yields an empty list.

  Args:
    model_response: Raw model response text.
    client: LLM client forwarded to `llm_output`.
    use_llm: Whether to ask the LLM to extract the JSON list before parsing.

  Returns:
    The parsed JSON value (normally a list of dicts; `{}` for an empty
    response), or `[]` if nothing could be parsed even with the LLM.
  """

  def remove_prefix_suffix(text):
    return (
        text.replace("\n", "")
        .removeprefix("```json")
        .removesuffix("```json")
        .removeprefix("```")
        .removesuffix("```")
        .removeprefix("`")
        .removesuffix("`")
    )

  if use_llm:
    model_response = llm_output(
        client,
        prompt="Extract model_response json list from the following text.\n"
        + model_response
        + '\nMake sure all None values are converted to "NaN". Output the json'
        " list"
        " ONLY, without any explanation, prefix or suffix:\n",
    )
  response_text = remove_prefix_suffix(model_response)
  try:
    try:
      formatted_text = re.sub(r"(?<=\w)'(?=\w|\s)", "\\'", response_text)
      if not formatted_text:
        formatted_text = "{}"
        print("Response_text is empty")
      return json5.loads(formatted_text)
    except Exception as e:  # pylint: disable=broad-except
      print("Skipping incomplete last item: ", e)
      print("***", response_text)
      # Cut at the last ", {" (start of the truncated item) and close the
      # list; raises IndexError when there is no such item boundary.
      ind = [m.start() for m in re.finditer(r",\s*\{", response_text)][-1]
      return json5.loads(response_text[:ind] + "]")
  except Exception:  # pylint: disable=broad-except
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit still propagate.
    if use_llm:
      return []
    return parse_model_response_dft(model_response, client, use_llm=True)


def dft_model_eval_paper(
    record_id: str | None,
    ground_truth: str,
    model_response: str,
    eval_prompt: str,
    client: Any,
) -> dict[str, Any]:
  """Parses one paper's DFT ground truth and model response, then LLM-evaluates them.

  Args:
    record_id: record id or paper id.
    ground_truth: ground truth list in str type.
    model_response: model response list in str type.
    eval_prompt: eval prompt file path.
    client: llm client.

  Returns:
    The result of `model_eval_json` for the parsed inputs.
  """
  # Keyword arguments are evaluated in order: ground truth is parsed first.
  return model_eval_json(
      record_id=record_id,
      json_ground_truth=parse_ground_truth_dft(ground_truth, client),
      json_model_response=parse_model_response_dft(model_response, client),
      eval_prompt=eval_prompt,
      client=client,
  )


def dft_metadata_domain_expert_model_based_eval(
    ground_truth: str,
    model_response: str,
    client: Any | None = None,
    eval_prompt: str = _METADATA_EVAL_PROMPT_FILENAME,
    verbose: bool = True,
) -> dict[str, Any]:
  """Runs the DFT model-based eval, defaulting to the metadata eval prompt.

  Thin wrapper around `dft_domain_expert_model_based_eval` with the same
  arguments and return value.
  """
  return dft_domain_expert_model_based_eval(
      ground_truth, model_response, client, eval_prompt, verbose
  )


def dft_structure_domain_expert_model_based_eval(
    ground_truth: str,
    model_response: str,
    client: Any | None = None,
    eval_prompt: str = _STRUCTURE_EVAL_PROMPT_FILENAME,
    verbose: bool = True,
) -> dict[str, Any]:
  """Runs the DFT model-based eval, defaulting to the structure eval prompt.

  Thin wrapper around `dft_domain_expert_model_based_eval` with the same
  arguments and return value.
  """
  return dft_domain_expert_model_based_eval(
      ground_truth, model_response, client, eval_prompt, verbose
  )


def dft_domain_expert_model_based_eval(
    ground_truth: str,
    model_response: str,
    client: Any | None = None,
    eval_prompt: str = _METADATA_EVAL_PROMPT_FILENAME,
    verbose: bool = True,
) -> dict[str, Any]:
  """Runs the LLM-judged DFT eval and summarizes it as precision/recall/F1.

  Args:
    ground_truth: ground truth list in str type.
    model_response: model response list in str type.
    client: llm client.
    eval_prompt: eval prompt file path.
    verbose: whether to print progress and eval results.

  Returns:
    The summary dict produced by `eval_overall_result`.
  """

  def log(*message_parts):
    if verbose:
      print(*message_parts)

  log("Model eval started...")
  paper_eval = dft_model_eval_paper(
      record_id=None,
      ground_truth=ground_truth,
      model_response=model_response,
      eval_prompt=eval_prompt,
      client=client,
  )
  log("Model eval finished.")
  summary = eval_overall_result(paper_eval, verbose=verbose)
  log("Eval results:\n", summary)
  return summary

### LLMSim mpve

In [None]:
def get_lmsim_score_mpve(prediction, reference):
  """Scores a material-property prediction against its reference with the LLM judge.

  The reference is passed as the ground truth and the prediction as the model
  response.
  """
  # client=None selects the external Gemini API.
  return mpve_domain_expert_model_based_eval(
      ground_truth=reference, model_response=prediction, client=None
  )


# Eval prompt template path (on the mounted Drive) for the MPVE LLM judge.
_EVAL_PROMPT_FILENAME = (
    file_root_path + "/prompts/mat_eval_output_1_shot.txt"
)


def parse_ground_truth_mpve(ground_truth: str) -> list[dict[str, Any]]:
  """Parses the MPVE ground-truth list and strips fields irrelevant to eval.

  Items are ordered by their "index" field; items without one sort last, and
  the sort is stable. Then "index", "paper_id" and "synonyms" are removed so
  the LLM judge is not distracted by them (hallucination prevention).

  Args:
    ground_truth: Ground-truth JSON list as a string.

  Returns:
    The parsed, sorted and pruned list of ground-truth dicts.
  """
  json_ground_truth = json5.loads(ground_truth.replace("\n", ""))
  # The fields below are treated as optional, so do not assume "index" exists
  # when sorting either (x["index"] raised KeyError).
  json_ground_truth.sort(key=lambda item: item.get("index", float("inf")))
  for item in json_ground_truth:
    for field in ("index", "paper_id", "synonyms"):
      item.pop(field, None)
  return json_ground_truth


def parse_model_response_mpve(model_response: str) -> list[dict[str, Any]]:
  """Parses a material-property model response into JSON.

  Newlines, surrounding whitespace and Markdown code fences are stripped
  before parsing. If the text is truncated, the last incomplete item is
  dropped. If nothing can be parsed, an empty list is returned instead of
  raising. This matches `parse_model_response_dft`'s final fallback.

  Args:
    model_response: Raw model response text.

  Returns:
    The parsed JSON value (normally a list of dicts; `{}` for an empty
    response), or `[]` if the response cannot be parsed.
  """
  response_text = (
      model_response.replace("\n", "")
      .strip()
      .removeprefix("```json")
      .removesuffix("```json")
      .removeprefix("```")
      .removesuffix("```")
      .removeprefix("`")
      .removesuffix("`")
  )
  formatted_text = re.sub(r"(?<=\w)'(?=\w|\s)", "\\'", response_text)
  if not formatted_text:
    formatted_text = "{}"
    print("Response_text is empty")
  try:
    return json5.loads(formatted_text)
  except Exception as e:  # pylint: disable=broad-except
    print("Skipping incomplete last item: ", e)

  # Cut at the last ", {" (start of the truncated item) and close the list.
  item_starts = [m.start() for m in re.finditer(r",\s*\{", response_text)]
  if item_starts:
    try:
      return json5.loads(response_text[: item_starts[-1]] + "]")
    except Exception as e:  # pylint: disable=broad-except
      print("Could not parse model response: ", e)
  return []


def mpv_model_eval_paper(
    record_id: str | None,
    ground_truth: str,
    model_response: str,
    eval_prompt: str,
    client: Any,
) -> dict[str, Any]:
  """Parses one paper's material-property ground truth and response, then LLM-evaluates them.

  Args:
    record_id: record id or paper id.
    ground_truth: ground truth list in str type.
    model_response: model response list in str type.
    eval_prompt: eval prompt file path.
    client: llm client.

  Returns:
    The result of `model_eval_json` for the parsed inputs.
  """
  # Keyword arguments are evaluated in order: ground truth is parsed first.
  return model_eval_json(
      record_id=record_id,
      json_ground_truth=parse_ground_truth_mpve(ground_truth),
      json_model_response=parse_model_response_mpve(model_response),
      eval_prompt=eval_prompt,
      client=client,
  )


def filter_ground_truth_properties(ground_truth: str) -> list[dict[str, Any]]:
  """Keeps only ground-truth items for band gap or refractive index.

  An item is kept when its lower-cased "property_name" contains any of the
  accepted aliases as a substring.

  Args:
    ground_truth: ground truth list in str type.

  Returns:
    filtered ground truth list, in the original order.
  """
  accepted_aliases = (
      "bandgap",
      "band gap",
      "gap energy",
      "energy gap",
      "refractive_index",
      "refractive index",
      "index of refraction",
      "n-value",
      "n value",
  )
  return [
      entry
      for entry in json5.loads(ground_truth)
      if any(alias in entry["property_name"].lower() for alias in accepted_aliases)
  ]


def mpve_domain_expert_model_based_eval(
    ground_truth: str,
    model_response: str,
    client: Any | None = None,
    eval_prompt: str = _EVAL_PROMPT_FILENAME,
    verbose: bool = True,
) -> dict[str, Any]:
  """Runs the LLM-judged material-property eval as precision/recall/F1.

  Args:
    ground_truth: ground truth list in str type.
    model_response: model response list in str type.
    client: llm client.
    eval_prompt: eval prompt file path.
    verbose: whether to print progress and eval results.

  Returns:
    The summary dict produced by `eval_overall_result`.
  """

  def log(*message_parts):
    if verbose:
      print(*message_parts)

  log("Model eval started...")
  paper_eval = mpv_model_eval_paper(
      record_id=None,
      ground_truth=ground_truth,
      model_response=model_response,
      eval_prompt=eval_prompt,
      client=client,
  )
  log("Model eval finished.")
  summary = eval_overall_result(paper_eval, verbose=verbose)
  log("Eval results:\n", summary)
  return summary

def load_matsci_prompt(filepath: str) -> str:
  """Loads a matsci prompt template from disk.

  Args:
    filepath: filepath of prompt.

  Returns:
    The file contents with leading/trailing whitespace stripped.
  """
  # Decode as UTF-8 explicitly: the platform default encoding is not
  # guaranteed to be UTF-8, and prompt text may contain non-ASCII characters.
  with open(filepath, "r", encoding="utf-8") as file:
    return file.read().strip()





def eval_overall_result(
    eval_output_item: dict[str, Any], verbose: bool = False
) -> dict[str, Any]:
  """Summarizes per-item LLM verdicts as precision, recall and F1.

  A verdict counts as a match when it contains "json_extracted_index".
  Precision and recall are capped at 1.0 and are NaN when their denominator
  is zero. F1 is 0.0 when precision + recall is 0.

  Args:
    eval_output_item: eval output item (as returned by model_eval_json).
    verbose: whether to print the model eval original output.

  Returns:
    Dict with num_match, num_ground_truth, num_model_response, precision,
    recall and f1.
  """
  num_match = len([
      verdict
      for verdict in eval_output_item["response_json"]
      if "json_extracted_index" in verdict
  ])
  if verbose:
    print("Model eval original output:\n", eval_output_item)
  num_gt = eval_output_item["ground_truth_length"]
  num_response = eval_output_item["model_response_length"]

  def capped_ratio(numerator, denominator):
    return min(numerator / denominator if denominator else np.nan, 1.0)

  precision = capped_ratio(num_match, num_response)
  recall = capped_ratio(num_match, num_gt)
  if precision + recall:
    f1 = 2.0 * precision * recall / (precision + recall)
  else:
    f1 = 0.0
  return {
      "num_match": num_match,
      "num_ground_truth": num_gt,
      "num_model_response": num_response,
      "precision": precision,
      "recall": recall,
      "f1": f1,
  }

## bert eval

In [None]:
def get_bert_score(prediction, reference):
  """Returns BERTScore precision, recall and F1 for a single English pair."""
  precision, recall, f1 = score(
      [prediction], [reference], lang="en", verbose=False
  )
  return {
      "bert_precision": precision.item(),
      "bert_recall": recall.item(),
      "bert_f1": f1.item(),
  }

## rouge eval

In [None]:
import six
import collections

class AggregateScore(collections.namedtuple("AggregateScore", "low mid high")):
  """Confidence interval for a score: (low bound, mid value, high bound)."""

class BootstrapAggregator:
  """Aggregates scores to provide confidence intervals.

  Sample usage:
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'])
    aggregator = BootstrapAggregator()
    aggregator.add_scores(scorer.score("one two three", "one two"))
    aggregator.add_scores(scorer.score("one two five six", "seven eight"))
    result = aggregator.aggregate()
    print(result)
    {'rougeL': AggregateScore(
         low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
         mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
         high=Score(precision=1.0, recall=0.66, fmeasure=0.80)),
     'rouge1': AggregateScore(
         low=Score(precision=0.0, recall=0.0, fmeasure=0.0),
         mid=Score(precision=0.5, recall=0.33, fmeasure=0.40),
         high=Score(precision=1.0, recall=0.66, fmeasure=0.80))}
  """

  def __init__(self, confidence_interval=0.95, n_samples=1000):
    """Initializes a BootstrapAggregator object.

    Args:
      confidence_interval: Confidence interval to compute on the mean as a
        decimal.
      n_samples: Number of samples to use for bootstrap resampling.

    Raises:
      ValueError: If invalid argument is given.
    """
    if confidence_interval < 0 or confidence_interval > 1:
      raise ValueError("confidence_interval must be in range [0, 1]")
    if n_samples <= 0:
      raise ValueError("n_samples must be positive")

    self._n_samples = n_samples
    self._confidence_interval = confidence_interval
    self._scores = collections.defaultdict(list)

  def add_scores(self, scores):
    """Adds a sample for future aggregation.

    Args:
      scores: Dict mapping score_type strings to a namedtuple object/class
        representing a score.
    """
    # Plain dict.items(): the `six` Python 2 shim is unnecessary on Python 3.
    for score_type, score_value in scores.items():
      self._scores[score_type].append(score_value)

  def aggregate(self):
    """Aggregates scores previously added using add_scores.

    Returns:
      A dict mapping score_type to AggregateScore objects.
    """
    result = {}
    for score_type, scores in self._scores.items():
      # Stack scores into a 2-d matrix of (sample, measure).
      score_matrix = np.vstack(tuple(scores))
      # Percentiles are returned as (interval, measure).
      percentiles = self._bootstrap_resample(score_matrix)
      # Rebuild each interval (low, mid, high) as the caller's score type.
      intervals = tuple(
          scores[0].__class__(*percentiles[j, :]) for j in range(3)
      )
      result[score_type] = AggregateScore(
          low=intervals[0], mid=intervals[1], high=intervals[2])
    return result

  def _bootstrap_resample(self, matrix):
    """Performs bootstrap resampling on a matrix of scores.

    Args:
      matrix: A 2-d matrix of (sample, measure).

    Returns:
      A 2-d matrix of (bounds, measure). There are three bounds: low (row 0),
      mid (row 1) and high (row 2). Mid is the median (50th percentile) of the
      bootstrap sample means. Low and high are set by
      self._confidence_interval; the default 0.95 gives the 2.5th and 97.5th
      percentiles of the bootstrap means.
    """
    # Matrix of (bootstrap sample, measure). Uses the global numpy RNG.
    sample_mean = np.zeros((self._n_samples, matrix.shape[1]))
    for i in range(self._n_samples):
      sample_idx = np.random.choice(
          np.arange(matrix.shape[0]), size=matrix.shape[0])
      sample = matrix[sample_idx, :]
      sample_mean[i, :] = np.mean(sample, axis=0)

    # Take percentiles on the estimate of the mean using bootstrap samples.
    # Final result is a (bounds, measure) matrix.
    percentile_delta = (1 - self._confidence_interval) / 2
    q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta])
    return np.percentile(sample_mean, q, axis=0)

def _prepare_summary_rouge(summary):
  # Make sure the summary is not bytes-type
  # Add newlines between sentences so that rougeLsum is computed correctly.
  summary = summary.replace(" . ", " .\n")
  return summary

def get_rouge_score(prediction, reference):
  """Returns ROUGE-1/2/Lsum F-measures, scaled to 0-100, for one text pair.

  Args:
    prediction: Model-generated text.
    reference: Reference text.

  Returns:
    Dict mapping 'rouge1', 'rouge2' and 'rougeLsum' to F-measure * 100.
  """
  score_keys = ['rouge1', 'rouge2', 'rougeLsum']
  scorer = rouge_scorer.RougeScorer(score_keys)
  # Only a single pair is scored, so the averaged F-measure is simply the
  # pair's F-measure.
  pair_scores = scorer.score(
      target=_prepare_summary_rouge(reference),
      prediction=_prepare_summary_rouge(prediction),
  )
  return {key: pair_scores[key].fmeasure * 100 for key in score_keys}

## biogr eval

In [None]:
def center_of_bbox(json_coords: dict[str, float]) -> dict[str, float]:
  """Returns the lat/lng midpoint of a box given by S/N/E/W degrees.

  NOTE: Longitudes are averaged directly, so a box that wraps across the
  180-degree meridian (e.g. 180 -> -179) gets a wrong center. The data does
  not contain such boxes.
  """
  return {
      "lat": np.mean([json_coords["S"], json_coords["N"]]),
      "lng": np.mean([json_coords["E"], json_coords["W"]]),
  }


def compute_distance(
    lat1_deg: float,
    lng1_deg: float,
    lat2_deg: float,
    lng2_deg: float,
    radius_m: float = 6371000.0,
) -> float:
  """Computes the distance between two points on a sphere in meters.

  Uses the haversine formula for the geodesic distance on a sphere
  (https://en.wikipedia.org/wiki/Haversine_formula). Inputs may be scalars or
  broadcastable numpy arrays.

  Args:
    lat1_deg: Latitude of the first point in degrees.
    lng1_deg: Longitude of the first point in degrees.
    lat2_deg: Latitude of the second point in degrees.
    lng2_deg: Longitude of the second point in degrees.
    radius_m: Sphere radius in meters; defaults to 6371 km (mean Earth radius).

  Returns:
    The distance between the two points in meters.
  """
  lat1 = np.deg2rad(lat1_deg)
  lng1 = np.deg2rad(lng1_deg)
  lat2 = np.deg2rad(lat2_deg)
  lng2 = np.deg2rad(lng2_deg)

  sin_half_dlat = np.sin((lat2 - lat1) * 0.5)
  sin_half_dlng = np.sin((lng2 - lng1) * 0.5)
  hav = (
      sin_half_dlat * sin_half_dlat
      + np.cos(lat1) * np.cos(lat2) * sin_half_dlng * sin_half_dlng
  )
  # Clamp to [0, 1] so both square roots are defined despite rounding. np.clip
  # also works element-wise, unlike the scalar-only `if hav > 1.0`.
  hav = np.clip(hav, 0.0, 1.0)
  central_angle = 2.0 * np.arctan2(np.sqrt(hav), np.sqrt(1.0 - hav))
  return radius_m * central_angle


def compute_center_error_km(
    center_prediction: dict[str, float], center_ground_truth: dict[str, float]
) -> float:
  """Returns the distance in km between two {"lat", "lng"} points."""
  distance_m = compute_distance(
      center_prediction["lat"],
      center_prediction["lng"],
      center_ground_truth["lat"],
      center_ground_truth["lng"],
  )
  return 0.001 * distance_m


def compute_box_size_km(coords: dict[str, float]) -> float:
  """Returns half of the SW-to-NE diagonal of a bounding box, in km.

  This is the radius of the circle that circumscribes the box (passes through
  its corners), with the diagonal measured by `compute_distance`.
  """
  diagonal_m = compute_distance(
      coords["S"], coords["W"], coords["N"], coords["E"]
  )
  return 0.001 * 0.5 * diagonal_m


def compute_distance_metrics(
    prediction: dict[str, float], ground_truth: dict[str, float]
) -> dict[str, float]:
  """Computes distance metrics between the prediction and ground truth.

  normalized_distance_error: distance between the predicted and ground-truth
  box centers divided by the ground-truth box radius (half-diagonal).
  relative_box_size: predicted box radius divided by ground-truth box radius.

  Args:
    prediction: A dictionary with the prediction coordinates (W/E/S/N).
    ground_truth: A dictionary with the ground truth coordinates (W/E/S/N).

  Returns:
    Dict with "normalized_distance_error" and "relative_box_size".
  """
  gt_center = center_of_bbox(ground_truth)
  pred_center = center_of_bbox(prediction)
  center_error_km = compute_center_error_km(pred_center, gt_center)

  gt_radius_km = compute_box_size_km(ground_truth)
  normalized_distance_error = center_error_km / gt_radius_km
  relative_box_size = compute_box_size_km(prediction) / gt_radius_km

  return {
      "normalized_distance_error": normalized_distance_error,
      "relative_box_size": relative_box_size,
  }

def coords_to_box(coords: dict[str, float]) -> np.ndarray:
  """Returns the box as an array ordered [W, S, E, N] (x_min, y_min, x_max, y_max)."""
  return np.array([coords[key] for key in ("W", "S", "E", "N")])


def parse_biodiversity_response(model_response: str) -> dict[str, float]:
  """Parses a model response string into a dictionary of coordinates.

  If the response contains braces, the text from the first "{" to the last
  "}" is parsed as JSON5. Otherwise each of `"W":`, `"E":`, `"S":`, `"N":` must
  appear, and the first token after its last occurrence (up to a comma) is
  parsed as a float.

  Args:
    model_response: The model response string to be parsed.

  Returns:
    A dictionary containing the W, E, S, N values.

  Raises:
    ValueError: If the model response string cannot be parsed into either of the
      supported formats.
  """
  if "{" in model_response and "}" in model_response:
    cleaned_response = model_response[
        model_response.find("{") : model_response.rfind("}") + 1
    ]
    return json5.loads(cleaned_response)
  keys = ("W", "E", "S", "N")
  # Look for the quoted keys that are parsed below, not bare letters, which
  # occur in almost any prose.
  if all(f'"{key}":' in model_response for key in keys):
    coords = {}
    for key in keys:
      tokens = model_response.split(f'"{key}":')[-1].split(",")[0].split()
      if not tokens:
        # Empty value: raise the documented ValueError, not IndexError.
        raise ValueError(f"Can not parse value for {key!r} in model response")
      coords[key] = float(tokens[0])
    return coords
  raise ValueError("Can not parse model response")


def bb_intersection_over_union(box_a: np.ndarray, box_b: np.ndarray) -> float:
  """Calculates the Intersection over Union (IoU) between two bounding boxes.

  Boxes are [x_min, y_min, x_max, y_max].

  Args:
    box_a: A list of coordinates representing the first bounding box.
    box_b: A list of coordinates representing the second bounding box.

  Returns:
    The IoU value, a float between 0 and 1 for well-formed boxes.
    0 indicates no overlap and 1 indicates perfect overlap. 0.0 is also
    returned when the union area is not positive (e.g. two zero-area boxes),
    where IoU is undefined.
  """
  inter_width = min(box_a[2], box_b[2]) - max(box_a[0], box_b[0])
  inter_height = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1])
  if inter_width < 0 or inter_height < 0:
    inter_area = 0.0
  else:
    inter_area = inter_width * inter_height

  area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
  area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
  union_area = area_a + area_b - inter_area
  if union_area <= 0:
    # Degenerate boxes: avoid ZeroDivisionError and report no overlap.
    return 0.0
  return inter_area / float(union_area)


def biodiversity_georeferencing_eval(
    model_response: str, ground_truth: str, verbosity: int = 0
) -> dict[str, Union[float, str]]:
  """Computes IOU and distance metrics between ground truth and predicted boxes.

  Args:
    model_response: The model response, parsed with
      `parse_biodiversity_response`.
    ground_truth: A JSON string with W/E/S/N bounding box coordinates.
    verbosity: Prints unparseable model responses when > 0.

  Returns:
    {"iou", "normalized_distance_error", "relative_box_size"} on success, or
    {"iou": "Can not parse model response"} if the response cannot be turned
    into a W/S/E/N box.
  """
  # Load in ground truth coordinates.
  ground_truth_coords = json5.loads(ground_truth)
  ground_truth_box = coords_to_box(ground_truth_coords)

  try:
    predicted_coords = parse_biodiversity_response(model_response)
    # Build the box inside the try: parseable JSON that lacks a W/S/E/N key
    # (or is not an object) is also an unparseable response, not a crash.
    predicted_box = coords_to_box(predicted_coords)
  except Exception:  # pylint: disable=broad-except
    if verbosity > 0:
      print("Failed to extract coords from model response: ", model_response)
    return {"iou": "Can not parse model response"}

  iou = bb_intersection_over_union(ground_truth_box, predicted_box)
  distance_metrics = compute_distance_metrics(
      predicted_coords, ground_truth_coords
  )
  return {
      "iou": iou,
      "normalized_distance_error": distance_metrics[
          "normalized_distance_error"
      ],
      "relative_box_size": distance_metrics["relative_box_size"],
  }

## pdb eval

In [None]:
def best_sequence_alignment_counts(
    sequence_1: str, sequence_2: str
) -> Dict[str, Union[str, int, float]]:
  """Calculates the counts of gaps, identities, and mismatches in the best alignment of two sequences.

  Empty sequences are replaced by a single space before aligning, so both
  inputs seen by the aligner and by the Levenshtein distance are non-empty.

  Args:
    sequence_1: The first sequence to be aligned.
    sequence_2: The second sequence to be aligned.

  Returns:
    A dictionary containing:
      - "n_gaps": Number of gap characters introduced in the alignment.
      - "n_identities": Number of positions where the characters are identical.
      - "n_mismatches": Number of positions where the characters are different.
      - "normalized_levenshtein_distance": Levenshtein distance divided by the
          length of the longer sequence. Levenshtein distance is the minimum
          number of edits needed to transform one sequence into another.
      - "identity_ratio": n_identities divided by the length of the alignment,
          or the string "Zero length alignment" if the alignment is empty.
  """
  # Substitute a single space for empty inputs (presumably so the aligner
  # always receives a non-empty sequence).
  sequence_1 = sequence_1 if sequence_1 else " "
  sequence_2 = sequence_2 if sequence_2 else " "
  aligner = Align.PairwiseAligner()
  best_alignment = aligner.align(sequence_1, sequence_2)[0]
  # counts() walks the whole alignment on every call; compute it once.
  counts = best_alignment.counts()

  # Both sequences are non-empty here, so the denominator is at least 1.
  max_length = max(len(sequence_1), len(sequence_2))
  normalized_distance = (
      Levenshtein.distance(sequence_1, sequence_2) / max_length
  )

  # best_alignment[0] is the aligned first row; its length is the alignment
  # length.
  aligned_row = best_alignment[0]
  if aligned_row:
    identity_ratio = counts.identities / len(aligned_row)
  else:
    identity_ratio = "Zero length alignment"

  return {
      "n_gaps": counts.gaps,
      "n_identities": counts.identities,
      "n_mismatches": counts.mismatches,
      "normalized_levenshtein_distance": normalized_distance,
      "identity_ratio": identity_ratio,
  }

def extract_code_block(response_str: str) -> tuple[str, str]:
  """Extracts the first ```python code block and its first function name.

  Only ```python blocks are supported; other language fences are ignored.

  Args:
    response_str: The model response string.

  Returns:
    A tuple (code_block, function_name).
      - code_block is the text between the opening ```python fence and the
        next ``` fence. If the closing fence is missing (e.g. the response
        was truncated), the block ends at the end of the first line inside
        the block that contains "return", or at the end of the response if
        there is no such line.
      - function_name is the name of the first `def` inside code_block, or ""
        if the block defines no function.
    Returns ("", "") if the response contains no ```python block.
  """
  opening_fence = "```python"
  start_idx = response_str.find(opening_fence)
  if start_idx == -1:
    return "", ""
  body_start = start_idx + len(opening_fence)

  end_idx = response_str.find("```", body_start)
  if end_idx == -1:
    # Sometimes the model echoes the (long) input string and the response is
    # cut off before the closing fence. Fall back to ending the block at the
    # end of the first `return` line *inside* the block (never in prose that
    # precedes the fence).
    return_idx = response_str.find("return", body_start)
    if return_idx == -1:
      end_idx = len(response_str)
    else:
      end_idx = response_str.find("\n", return_idx)
      if end_idx == -1:
        end_idx = len(response_str)

  code_block = response_str[body_start:end_idx]

  # Take the function name from the code itself, not from surrounding prose.
  function_name = ""
  if "def " in code_block:
    function_name = code_block.split("def ", 1)[1].split("(", 1)[0].strip()
  return code_block, function_name

def pdb_execute_code_eval(
    model_response: str,
    input_data_prompt: str,
) -> str:
  """Runs the code block from the model response on the input PDB data.

  Args:
    model_response: A string containing the model's response, expected to
      contain a ```python code block that defines a function taking the PDB
      data string.
    input_data_prompt: The prompt text including the input data used for running
      inference on the model. This is used to reconstruct the input pdb string
      to the model and pass it to the code block generated in the model
      response.

  Returns:
    The output of the model-generated function as a string ("" if it returned
    None). Returns "" when there is no code block, the function is not found,
    or the code fails to compile or run.
  """
  code_block, function_name = extract_code_block(model_response)
  if not code_block and not function_name:
    return ""
  try:
    # SECURITY NOTE: this executes model-generated code without any
    # sandboxing. Only run it on a disposable runtime.
    # A single dict serves as globals AND locals. With separate dicts,
    # top-level imports/helpers of the generated code would land in the
    # locals dict and be invisible (NameError) inside its functions.
    namespace: dict[str, Any] = {}
    exec(code_block, namespace)  # pylint: disable=exec-used
    # The PDB records are taken to start at the first "ATOM" in the prompt.
    pdb_data_string = "ATOM" + input_data_prompt.split("ATOM", 1)[1]
    if function_name not in namespace:
      print("Function name not found in local namespace: ", function_name)
      return ""
    pred_output = namespace[function_name](pdb_data_string)
  except SyntaxError as e:
    print("SyntaxError: ", e)
    return ""
  except Exception as e:  # pylint: disable=broad-exception-caught
    # Generated code can fail in arbitrary ways; treat any failure as an
    # unusable response instead of aborting the whole metric.
    print(f"{type(e).__name__}: ", e)
    return ""

  if pred_output is None:
    return ""
  # Callers run string methods on the result, so always hand back a str.
  return pred_output if isinstance(pred_output, str) else str(pred_output)

def _pdb_failure_result(reason: str) -> dict[str, str]:
  """Returns the pdb metric dict with every metric set to `reason`."""
  return {
      key: reason
      for key in (
          "n_gaps",
          "n_identities",
          "n_mismatches",
          "normalized_levenshtein_distance",
          "identity_ratio",
          "rouge1",
          "rouge2",
          "rougeLsum",
      )
  }


def pdb_reconstruction_eval(
    model_response: str,
    ground_truth_json_str: str,
    input_data_prompt: str,
) -> dict[str, Union[str, int, float]]:
  """Evaluates the alignment of a predicted sequence.

  If the response contains a ```python block, that code is executed on the
  input PDB data (via `pdb_execute_code_eval`) and its output is scored
  instead of the raw response.

  Args:
    model_response: A string containing the model's response, also expected to
      have the predicted protein sequence on the line following a '>' line.
    ground_truth_json_str: A json string containing the ground truth protein
      sequence data under the "sequence" key.
    input_data_prompt: The prompt text including the input data used for running
      inference on the model. This is used to reconstruct the input pdb string
      to the model and pass it to the code block generated in the model
      response.

  Returns:
    The metrics from `best_sequence_alignment_counts` for the line following
    the first '>' line. Returns a dict with every metric set to an error string
    if the ground truth or the model response cannot be parsed.
  """
  try:
    targets = json5.loads(ground_truth_json_str)["sequence"]
  except (ValueError, KeyError, TypeError):
    # Malformed JSON, a missing "sequence" key, or a non-dict payload.
    return _pdb_failure_result("Can not parse ground truth")

  if "```python" in model_response:
    model_response = pdb_execute_code_eval(
        model_response=model_response, input_data_prompt=input_data_prompt
    )
    # Guard against a non-string result from the executed code.
    model_response = "" if model_response is None else str(model_response)

  # Decode the literal six-character escape text "\u003E"/"\u003e" back into
  # the FASTA header marker ">". (The Python literal "\u003E" already *is*
  # ">", so replacing it would be a no-op.)
  model_response = model_response.replace("\\u003E", ">").replace(
      "\\u003e", ">"
  )
  lines = model_response.splitlines()
  # Only the single line right after the first '>' header line is scored.
  for header, sequence in zip(lines, lines[1:]):
    if header.startswith(">"):
      return best_sequence_alignment_counts(targets, sequence)
  return _pdb_failure_result("Can not parse model response")


# eval helper functions

In [None]:
from logging import raiseExceptions
def _extract_parse_raw_args(
    paper: str, prefix: str, class_name: str, label: str, verbose: int
) -> list[str]:
  """Returns the raw argument text of `class_name.parse_raw(...)` calls after `prefix`."""
  extracted = []
  # Each chunk starts at an occurrence of `prefix` and runs to the next one.
  for chunk in paper.split(prefix)[1:]:
    part = prefix + chunk
    if verbose > 1:
      print("PART:\n", part)
    if "parse_raw(" not in part:
      continue
    left, right, *_ = part.split("parse_raw(")
    if class_name not in left:
      continue
    # Prefer the quote-aware terminators so that a ")" inside the payload
    # does not truncate it.
    if "')" in right:
      terminator = "')"
    elif "'\n)" in right:
      terminator = "'\n)"
    else:
      terminator = ")"
    raw_arg = right.split(terminator)[0].strip()
    if verbose > 0:
      print(f"Extracted {label}:\n", raw_arg)
    extracted.append(raw_arg)
  return extracted


def _literal_eval_payloads(raw_args: list[str]) -> list[Any]:
  """Parses each extracted argument into a Python object where possible."""
  parsed = []
  for raw_arg in raw_args:
    if "'" in raw_arg:
      # Single-quoted string literal: keep the text after the first quote.
      payload = raw_arg.split("'")[1]
    else:
      # Double-quoted or unquoted argument: strip surrounding double quotes
      # rather than failing.
      payload = raw_arg.strip('"')
    # Bare NaN is not a Python literal; quote it so literal_eval accepts it.
    payload = payload.replace("NaN", '"NaN"')
    try:
      parsed.append(ast.literal_eval(payload))
    except Exception:  # pylint: disable=broad-exception-caught
      # Not a Python literal (e.g. JSON true/false/null): keep the string.
      parsed.append(payload)
  return parsed


def get_annotated_structure_metadata_and_dft_params(
    gt_paper_code: str, verbose: int = 0
) -> dict[str, list[Any]]:
  """Returns structure metadata and dft params from the ground truth code.

  Looks for `structure_metadata_* = StructureMetadata.parse_raw('...')` and
  `dft_params_* = DFTParameters.parse_raw('...')` style statements and parses
  the quoted payloads.

  Args:
    gt_paper_code: The ground truth code from .py file as a string.
    verbose: The verbosity level (> 0 prints extracted payloads, > 1 also
      prints every scanned chunk).

  Returns:
    A dict with keys "structures_metadata" and "dft_params". Each is a list
    of payloads: the `ast.literal_eval` result when the payload is a Python
    literal, otherwise the cleaned payload string.
  """
  structures = _extract_parse_raw_args(
      gt_paper_code, "structure_metadata_", "StructureMetadata", "structure",
      verbose,
  )
  dft_params = _extract_parse_raw_args(
      gt_paper_code, "dft_params_", "DFTParameters", "dft_param", verbose
  )
  return {
      "structures_metadata": _literal_eval_payloads(structures),
      "dft_params": _literal_eval_payloads(dft_params),
  }


def preprocess_ground_truth(
    ground_truth: str, task_name: str, prompt: str
) -> str:
  """Normalizes a ground-truth JSON string before it is scored.

  The identifier fields "record_id", "arxiv_id" and "paper_id" are removed from
  the top-level object, or from every dict element of a top-level list. For
  the "dft" task with the "extract_dft_metadata_1_shot" or
  "extract_structure_data_1_shot" prompt, only the DFT parameters or the
  structure metadata annotated in the ground-truth code are kept.

  Returns:
    The processed ground truth converted with `str()`.
  """
  id_keys = ("record_id", "arxiv_id", "paper_id")
  parsed = json5.loads(ground_truth)
  if isinstance(parsed, dict):
    records = [parsed]
  elif isinstance(parsed, list):
    records = [entry for entry in parsed if isinstance(entry, dict)]
  else:
    records = []
  for record in records:
    for key in id_keys:
      record.pop(key, None)
  without_ids = json5.dumps(parsed)

  # dft prompts score only one section of the annotated ground-truth code.
  dft_sections = {
      "extract_dft_metadata_1_shot": "dft_params",
      "extract_structure_data_1_shot": "structures_metadata",
  }
  if task_name == "dft" and prompt in dft_sections:
    annotated = get_annotated_structure_metadata_and_dft_params(without_ids)
    return str(annotated[dft_sections[prompt]])
  return str(without_ids)


def read_task_ground_truth_and_response(
    ground_truth_path: str,
    model_response_path: str,
) -> Tuple[str, str, str, str]:
  """Reads in the ground truth and response for all tasks.

  Args:
    ground_truth_path: Path to the ground-truth file; its raw text is returned.
    model_response_path: Path to the JSON file of a successful response,
      expected to live in a "success" directory. The same file name in a
      sibling "failure" directory is checked for an exception message.

  Returns:
    A tuple (ground_truth_info, model_response, exception_message, inf_prompt).
    model_response and inf_prompt are "" when the success file does not exist;
    exception_message is "" when the failure file does not exist. inf_prompt
    is only read when "pdb" appears in ground_truth_path.

  Raises:
    ValueError: If a response file lacks the expected field.
    Any error raised while reading or parsing the files is printed and then
    re-raised unchanged (original type and traceback).
  """
  try:
    with open(ground_truth_path, "r", encoding="utf-8") as f:
      ground_truth_info = f.read()

    model_response = ""
    inf_prompt = ""
    if os.path.exists(model_response_path):
      with open(model_response_path, "r", encoding="utf-8") as f:
        full_model_response = json5.loads(f.read())
      if "response_text" not in full_model_response:
        raise ValueError(
            f"ERROR: The succeeded response for {model_response_path} does not contain response_text."
        )
      model_response = full_model_response["response_text"]
      # The pdb metric re-runs model-generated code on the input data that is
      # embedded in the prompt, so the prompt text is required there.
      if 'pdb' in ground_truth_path:
        if 'prompt_text' not in full_model_response:
          raise ValueError(
              f"ERROR: The succeeded response for {model_response_path} does not contain prompt_text."
          )
        inf_prompt = full_model_response["prompt_text"]

    # Swap only the "success" status directory for "failure", leaving any
    # "success" substring elsewhere in the path untouched.
    response_dir, response_file = os.path.split(model_response_path)
    parent_dir, status_dir = os.path.split(response_dir)
    if status_dir == "success":
      failed_model_response = os.path.join(parent_dir, "failure", response_file)
    else:
      failed_model_response = model_response_path.replace("success", "failure")

    exception_message = ""
    if os.path.exists(failed_model_response):
      with open(failed_model_response, "r", encoding="utf-8") as f:
        full_model_response = json5.loads(f.read())
      if "exception_message" in full_model_response:
        exception_message = full_model_response["exception_message"]
      elif "command-r-plus" in failed_model_response and "response_text" in full_model_response:
        # command-r-plus failure files store the message as response_text.
        exception_message = full_model_response["response_text"]
      else:
        raise ValueError(
            f"ERROR: The failure response for {failed_model_response} does not contain exception message."
        )

    return ground_truth_info, model_response, exception_message, inf_prompt
  except Exception as e:
    print(f"ERROR: {e}")
    # Re-raise as-is so callers keep the original exception type/traceback.
    raise

## static configs

In [None]:
# Metrics computed for every task/prompt. Each is called as
# metric(model_response, ground_truth) and returns a dict of metric values.
_SHARED_METRICS = [get_rouge_score, get_bert_score]
# Backward-compatible alias for the original (misspelled) name.
_SHARED_METRCS = _SHARED_METRICS

# Task-specific metrics used when `runs_full_metric` is True, keyed by
# task -> prompt. Includes the LLMSim scorers for "mpve" and "dft".
# Lists (not sets) keep the evaluation order deterministic.
_FULL_ADDITIONAL_METRICS = {
    "pdb": {
        "reconstruct_protein_amino_acid_sequence_0_shot": [
            pdb_reconstruction_eval,
        ],
    },
    "mpve": {
        "mat_paper_to_property_1_shot": [
            get_lmsim_score_mpve,
        ],
        "mat_paper_to_property_1_shot_exclude_trivia": [
            get_lmsim_score_mpve,
        ],
        "mat_paper_to_property_1_shot_bandgap_refractive": [
            get_lmsim_score_mpve,
        ],
    },
    "dft": {
        "extract_structure_data_1_shot": [
            get_lmsim_score_dft,
        ],
        "extract_dft_metadata_1_shot": [
            get_lmsim_score_dft,
        ],
    },
    "biogr": {
        "georeference_image_0_shot": [
            biodiversity_georeferencing_eval,
        ],
    },
}

# Task-specific metrics used when `runs_full_metric` is False (no LLMSim).
_PRIMARY_ADDITIONAL_METRICS = {
    "pdb": {
        "reconstruct_protein_amino_acid_sequence_0_shot": [
            pdb_reconstruction_eval,
        ],
    },
    "biogr": {
        "georeference_image_0_shot": [
            biodiversity_georeferencing_eval,
        ],
    },
}

# Model names whose outputs are evaluated; each is a directory name under
# inference_outputs/<task>/<prompt>/.
_LLM_LIST = ["command-r-plus", "longllama", "mixtral-gcp",
             "gemini-1.5-flash-latest", "gemini-1.0-pro", "gemini-1.5-pro-latest",
             "gpt-4o", "claude-3-opus-20240229", 'gemini-2.0-flash-latest']

# NOTE(review): presumably the models to skip for the image-based "biogr"
# task; it is not referenced by the evaluation cells below — confirm usage.
_BIOGR_EXCLUDE_LLM = ["command-r-plus", "longllama", "mixtral-gcp"]

# Tasks and prompts to evaluate. The evaluation loops only iterate the prompt
# keys; the per-prompt dicts are empty.
_TASK_EVAL_CONFIGS = {
    "hfe": {
        "extract_hamiltonian_0_shot": {
        }
    },
    "hfd": {
        "derivation_prompt": {
        }
    },
    "qecc_65": {
        "describe_code_in_paper": {
        }
    },
    "pdb": {
        "reconstruct_protein_amino_acid_sequence_0_shot": {
        }
    },
    "mpve": {
        "mat_paper_to_property_1_shot": {
        },
        "mat_paper_to_property_1_shot_bandgap_refractive": {
        },
        "mat_paper_to_property_1_shot_exclude_trivia": {
        },
    },
    "dft": {
        "write_code_for_paper_0_shot": {
        },
        "extract_structure_data_1_shot": {
        },
        "extract_dft_metadata_1_shot": {
        },
    },
    "geo": {
        "extract_dataset_from_geo_papers_0_shot": {
        }
    },
    "biogr": {
        "georeference_image_0_shot": {
        }
    },
}
# Record ids to evaluate for each task. Each id names the ground-truth file
# data/<task>/ground_truth/<id>.json and the response file
# inference_outputs/<task>/<prompt>/<llm_name>/success/<id>.json, both under
# file_root_path.
all_ids_per_task = {'pdb': ['1A12', '1A33', '1AIL', '1AOA', '1AQA', '1AZS', '1BL8', '1BM8', '1BXO', '1CC5', '1CDK', '1CGF', '1CLL', '1CTF', '1DIN', '1DXX', '1E9H', '1EST', '1F9V', '1G6X', '1GOF', '1GP1', '1HCG', '1HHO', '1HNF', '1HNV', '1IAV', '1IGM', '1IGT', '1JM1', '1KXQ', '1M03', '1M17', '1M8Q', '1MBG', '1MBO', '1MHC', '1NKO', '1POH', '1PRC', '1R09', '1RCP', '1RGS', '1SBT', '1SU4', '1TIT', '1TNK', '1UBQ', '2A99', '2ACE', '2AYN', '2J1N', '2POR', '2R6G', '3ADN', '3C7E', '3LCK', '3R2E', '4CPA', '5CPA', '5HZN', '7B3N', '7L1E', '7V8O'],
                    'biogr': ['10212153_1', '260729_2', '531730_1', '556058_2', '563682_1', '564487_1', '575983_2', '578304_2', '583537_1', '585031_1', '587419_1', '590257_1', '591038_1', '592166_1', '592526_1', '592805_4', '594665_1', 'S0048969724009641_1', 'S1470160X22006951_1', 'a_decade_of_submersible_observations_1', 'a_new_species_of_river_1', 'a_preliminary_investigation_of_the_3', 'a_simple_genetic_method_to_1', 'a_small_warm_tributary_provides_1', 'abundance_of_longbilled_curlews_on_1a', 'an_overview_of_marine_biodiversity_3', 'an_overview_of_marine_biodiversity_8', 'assessment_of_ambystomatid_salamander_populations_1', 'assessment_of_potential_recovery_viability_2', 'availability_of_supplemental_corn_for_1', 'barriers_to_gene_flow_in_1', 'baseline_assessments_for_coral_reef_1', 'bat_predation_by_spiders_1', 'biotic_assemblages_of_gelatinous_zooplankton_3', 'bird_monitoring_at_effigy_mounds_1', 'birds_of_the_kilbuck_and_1', 'breeding_population_size_of_the_1', 'ceratonova_shasta_infection_in_lower_1', 'characterization_of_a_developing_recreational_1', 'chewing_lice_of_swan_geese_1', 'comparison_of_endoparasite_abundance_and_1', 'comparisons_of_walleye_fecundity_before_1', 'conservation_genetics_of_the_endangered_1', 'cooccurrence_of_ecologically_similar_species_1', 'deep_vs_shallow_gps_tags_1', 'density_of_axis_deer_in_1', 'despite_regional_variation_gymnorhinus_cyanocephalus_1', 'distribution_abundance_and_breeding_activities_1', 'distribution_and_abundance_of_least_3', 'distribution_morphology_and_karyotype_of_1', 'diurnal_human_activity_and_introduced_1', 'diving_patterns_and_foraging_locations_1', 'dna_barcoding_the_native_flowering_5', 'documentation_of_a_probable_spawning_1', 'ece310733_1', 'energy_density_of_three_prosopium_1', 'environment_affects_sucker_catch_rate_1', 'evaluating_spatial_coverage_of_the_1', 'evening_bats_captured_in_a_1', 'expansion_of_smallmouth_bass_distribution_1', 'extreme_wildlife_declines_and_concurrent_1', 
'fecal_genotyping_to_estimate_small_1', 'first_record_of_paronatrema_vaginicola_1', 'fish_predation_by_semiaquatic_spiders_1', 'foraging_ecology_of_southern_sea_1', 'four_centuries_of_change_in_1', 'global_conservation_priorities_for_marine_4', 'habitat_suitability_assessment_for_tule_4', 'habitat_use_and_reproductive_success_1', 'hawaiian_hoary_bat_lasiurus_cinereus_1', 'hiding_in_plain_sight_federally_1', 'high_similarity_in_winter_diet_1', 'impacts_of_the_czu_lightning_1', 'incidental_take_of_giant_sea_7', 'incorporating_expanded_sampling_into_an_1', 'inventory_of_eelgrass_zostera_marina_1', 'jwmg22383_1', 'larval_and_juvenile_longfin_smelt_1', 'leveraging_angler_effort_to_inform_1', 'longterm_occupancy_monitoring_reveals_value_1', 'machine_learning_to_understand_patterns_1', 'macrohabitat_suitability_model_for_the_1', 'macroscale_effects_of_the_monument_1', 'marine_biodiversity_in_the_atlantic_4', 'microhabitat_characteristics_and_management_of_1', 'monitoring_fiveneedle_pine_on_bureau_1', 'monitoring_nesting_waterbirds_for_the_1', 'monitoring_questing_winter_tick_abundance_1', 'movement_patterns_of_two_bat_1', 'natal_contributions_of_kokanee_salmon_1', 'occurrence_of_a_reproducing_wild_3', 'occurrence_of_batrachochytrium_dendrobatidis_in_1', 'onceiconic_pismo_clams_persist_in_1', 'patterns_of_florida_bonneted_bat_2', 'population_and_spatial_dynamics_of_1', 'population_density_and_habitat_selection_2', 'population_genomic_surveys_for_six_1', 'postfire_survival_of_the_threatened_4', 'rangewide_genetic_analysis_of_an_1', 'rapid_population_decline_in_mckays_1', 'recovering_the_lost_potential_of_5', 'relative_influence_of_environmental_factors_1', 'rescuing_and_monitoring_white_sturgeon_2', 'revealing_biases_in_insect_observations_5', 'review_of_considerations_for_restoration_1', 'road_and_highway_undercrossings_as_1', 'roseate_tern_breeding_dispersal_and_1', 's41598022209644_1', 's41598023276709_1', 's4159802334533_1', 
'sampling_duration_and_season_recommendations_1', 'sea_level_rise_vulnerability_assessment_1', 'seacliff_bedstraw_galium_buxifolium_patterns_4', 'seasonal_and_spatial_distribution_of_1', 'spatial_relationships_and_mesoscale_habitat_1', 'spatial_variation_in_density_of_1', 'status_and_distribution_of_arroyo_2', 'status_assessment_of_the_endangered_1', 'status_of_landbirds_in_the_1', 'striped_bass_movement_in_a_1', 'syntopy_in_california_redlegged_and_1', 'testing_a_singlevisit_sampling_approach_1', 'the_biodiversity_of_the_mediterranean_5', 'the_first_occurrence_of_the_3', 'the_importance_of_subarctic_intertidal_1', 'the_lion_in_west_africa_1', 'time_series_modeling_of_rainfall_1', 'trace_elements_in_blood_of_1', 'travel_management_planning_for_wildlife_1', 'trends_in_amphibian_occupancy_in_1', 'tricolored_blackbird_survey_methods_1', 'tule_elk_selection_of_surface_1', 'unintended_consequences_of_species_translocations_1', 'us_atlantic_and_gulf_of_1', 'use_of_aerial_distance_sampling_1', 'utilizing_the_timetoevent_framework_to_1', 'validating_a_nonlethal_method_of_1', 'western_purple_martin_progne_subis_1'],
                    'geo': ['00000', '14614a88b3e44e601c5cf8f71b5e07ca989beb0b', '213d2232a49507f81b4e17e50de7675c88fbc672', '33b0925f7681f3199a5d075324e7f3c5e33f2c76', '41f20bb04729a55ca9c2eaf579adf3ed5729044b', '54a9885771350f38135f30f43ef874e0a30be07b', '5c37c2aa2e108e17e37c6db29a4e5afe6a811119', '7dc47696eb876d85a3dfc6884f61fa8832d5e5e8', '83a1a10e3a2416e1d93bc3dbb482db4ccb707eda', '850ca33e8c1853c1735da63073ec3910bce91ddc', '9bdabc37e4af91c4fb53e205502204b510e3b972', 'a09b49e5f2c6b818e479bd29343eae9005f8ca26', 'ab6d648944f306fa1e2d275115b94d36478d9d2a', 'b90358f971e19a60c305acff2867b89dd197fdf6', 'c3a3a5a24206a9b38d9f4727f78cc8f323e398b2', 'e57ae1987bce88add50696843c8979456ce55561', 'e88aa0bccedc5f07bfee8f2db7a85351e65ec24a', 'e900993457d4d256cbfbe8a7527b6745f130a98e', 'e9c8932d5fcdf067821f8bf24b7462e5c7f73054'], 'hfd': ['1010.1819', '1106.6060', '1208.0116', '1310.2674', '1508.00296', '1812.04213', '2004.04168', '2008.08998', '2012.04554', '2108.02159', '2110.11330', '2111.01152', '2112.07523', '2308.03843', '2308.07488'], 'hfe': ['1010.1819', '1112.4222', '1202.4956', '1206.0608', '1208.0116', '1212.5363', '1401.2167', '1506.01488', '1507.06420', '1510.06887', '1512.02398', '1601.00996', '1812.04213', '1908.05417', '2007.15166', '2008.08998', '2102.13507', '2108.02159', '2111.09813', '2112.07523', '2206.10024', '2208.07620', '2209.15374', '2210.06674', '2210.08025', '2210.14517', '2302.04864', '2303.09821', '2303.18025', '2306.02127', '2306.12486', '2307.03793', '2307.04307', '2307.07531', '2307.11810', '2308.01997', '2308.03843', '2311.13191'], 'qecc_65': ['1501.07779', '1502.05267', '1503.06237', '1503.08800', '1505.02576', '1602.00008', '1603.04442', '1604.07925', '1703.02973', '1707.02308', '1708.08474', '1709.04471', '1709.08658', '1710.04631', '1712.07666', '1801.05897', '1802.07419', '1805.01474', '1809.09801', '1903.03937', '1906.11394', '1907.09528', '1910.10746', '2003.02717', '2007.09154', '2007.12152', '2008.09495', '2009.03921', '2010.06628', 
'2106.02649', '2107.02194', '2110.11510', '2112.01446', '2201.07802', '2203.00103', '2203.16534', '2209.11405', '2210.10808', '2210.16957', '2212.09935', '2303.02432', '2303.04798', '2306.11621', '2309.16503', '2311.07679', '2311.08653', '2311.13040', '2312.04522', '2402.07476', 'cond-mat_0010440', 'cond-mat_0607736', 'cond-mat_9707273', 'cs_0509062', 'quant-ph_0008040', 'quant-ph_0210097', 'quant-ph_0502086', 'quant-ph_0605138', 'quant-ph_0701020', 'quant-ph_0702075', 'quant-ph_9703002', 'quant-ph_9705052', 'quant-ph_9711021', 'quant-ph_9711049', 'quant-ph_9810055', 'quant-ph_9906114'], 'dft': ['2023_09_22_01b9cdba467fd7882e42g', '2023_09_22_07b4d66e23971ccb85c0g', '2023_09_22_0ce1b5ea9a8637db5435g', '2023_09_22_109a6cd5d015ce89e7f3g', '2023_09_22_12cb692c6b82606615a6g', '2023_09_22_13bcf90c3ef43f1413deg', '2023_09_22_14ab0a44fc8fc33fa338g', '2023_09_22_14ddd903c0b77b5e50c2g', '2023_09_22_182e0132c513bc81a414g', '2023_09_22_1a3a4803f9d9cec16d38g', '2023_09_22_1af7e6342ebcf1ce3ea5g', '2023_09_22_1b00ed3a142a7a1b2582g', '2023_09_22_1def59bea80d6e9f67ccg', '2023_09_22_22c19cfc32d2575f9a52g', '2023_09_22_24d7d9ed97e042af9f29g', '2023_09_22_2d35eebe2e85cecf4103g', '2023_09_22_2d44dda253969d6ce7f6g', '2023_09_22_2e5b7c3f50b6a643e33ag', '2023_09_22_2e9b1b9bffe7fd47b18fg', '2023_09_22_39b7d4444fd6ff852b12g', '2023_09_22_3ae1fab1e33569c30b8dg', '2023_09_22_430c3ddffa99af6a2545g', '2023_09_22_433b6bb3bfc2391f7300g', '2023_09_22_48fb54662c601e035a91g', '2023_09_22_49d6cc9c5e7a469afee0g', '2023_09_22_4eeeb89f0a6f52e58610g', '2023_09_22_4ef39bb4116ac1dad9e9g', '2023_09_22_5237e17b3b341fecc9d9g', '2023_09_22_54050a76c33463a8157fg', '2023_09_22_55f975d508230ef05caeg', '2023_09_22_63a752c620bbc784200cg', '2023_09_22_67e874a3c664208e2d2fg', '2023_09_22_6ac063e0fb85cfc10dd0g', '2023_09_22_70326f83ce0dcec87b50g', '2023_09_22_7433a6c7542334063731g', '2023_09_22_77a765fcab6029c666b4g', '2023_09_22_799ee1c298d190145c70g', '2023_09_22_7c5bbe7e076779b790ccg', 
'2023_09_22_7c76b066edb1f6f53739g', '2023_09_22_7dab45b11a3ede362147g', '2023_09_22_8181c4d6fa78a2ea82dag', '2023_09_22_81b3dfeb3597db5200c5g', '2023_09_22_84af4abb781aeead403eg', '2023_09_22_87d405f182ae3706ea0cg', '2023_09_22_8b46d7b3e561e7f28495g', '2023_09_22_8df0b56e310badc55de3g', '2023_09_22_900c617369212d6bc72fg', '2023_09_22_910aca2a500d3bf9bf47g', '2023_09_22_980121c407cbdaa46afdg', '2023_09_22_984f5b905c02b6f21733g', '2023_09_22_995805c76f676dddab4fg', '2023_09_22_9a007c3865721f379b39g', '2023_09_22_9a591ebf98377fd0ebe2g', '2023_09_22_9e2bc88db643c6ba8aa0g', '2023_09_22_a0271d2dc7b0f2506498g', '2023_09_22_b0501f9057db320b8ad9g', '2023_09_22_b2865949a80ad08a2835g', '2023_09_22_bba019fc933fc84ad347g', '2023_09_22_cb81fee8faa69f4d7078g', '2023_09_22_cc792b66f9a5779f9798g', '2023_09_22_cff3389f103b8f7971d0g', '2023_09_22_d84d81c022c4b2981048g', '2023_09_22_d90e94cbd96e4b6ddb8bg', '2023_09_22_dd9f0f77c116dc99583ag', '2023_09_22_ddfb75e0fb765dc682bbg', '2023_09_22_e06d11a6e698afe5f2d7g', '2023_09_22_e32d0198a1f3dddb5ba2g', '2023_09_22_e69f3d7ce6c4ff487115g', '2023_09_22_e8d1bc2fb9f3dce5f341g', '2023_09_22_edae82c7fbe0c4062118g', '2023_09_22_efbe854b8da1545fbe9bg', '2023_09_22_f752dc9d5ac72657e3f5g', '2023_09_22_f7714e3a468c91c6f56ag', '2023_09_22_f8875eb68affb0a6cb2bg'], 'mpve': ['10222315', '11093908', '11181068', '12841719', '135893324', '137261967', '137362119', '15804005', '17645319', '2837337', '317542', '53384093', '53519111', '55005437', '6183251', '68518', '97574650']}

# get full eval results



## calculate metrics

In [None]:
# Set this to True to also compute the task-specific LLMSim scores
# (get_lmsim_score_mpve / get_lmsim_score_dft in _FULL_ADDITIONAL_METRICS).
# When False, only _PRIMARY_ADDITIONAL_METRICS plus the shared ROUGE and
# BERTScore metrics are computed.
runs_full_metric = False #@param

In [None]:
# Evaluates every (task, prompt, llm_name, record_id) combination and stores
# the merged metric dicts in results_json[task][prompt][llm_name][record_id].
# Records whose files cannot be read, or that have no usable response, keep an
# empty dict.
results_json = {}
for task, task_prompts in _TASK_EVAL_CONFIGS.items():
  results_json[task] = {}
  for prompt in task_prompts:
    results_json[task][prompt] = {}
    # The metric selection depends only on (task, prompt), so it is computed
    # once per prompt instead of once per record.
    metric_table = (
        _FULL_ADDITIONAL_METRICS if runs_full_metric
        else _PRIMARY_ADDITIONAL_METRICS
    )
    additional_metrics = list(metric_table.get(task, {}).get(prompt, []))
    all_metrics = additional_metrics + _SHARED_METRCS
    for llm_name in _LLM_LIST:
      llm_results = {}
      results_json[task][prompt][llm_name] = llm_results
      for record_id in all_ids_per_task[task]:
        print(f"task: {task}, prompt: {prompt}, llm_name: {llm_name}, record_id: {record_id}")
        record_results = {}
        llm_results[record_id] = record_results
        gt_path = os.path.join(file_root_path, "data", task, "ground_truth", record_id + ".json")
        model_response_path = os.path.join(
            file_root_path, "inference_outputs", task, prompt, llm_name,
            "success", record_id + ".json",
        )
        try:
          ground_truth_info, model_response, exception_message, inf_prompt = (
              read_task_ground_truth_and_response(gt_path, model_response_path)
          )
          ground_truth_info = preprocess_ground_truth(ground_truth_info, task, prompt)
        except Exception as e:  # pylint: disable=broad-exception-caught
          print(f"ERROR: {e}")
          continue
        # pdb metrics also need the inference prompt, which carries the input
        # PDB data the model-generated code is run on.
        if not model_response or (task == 'pdb' and not inf_prompt):
          continue
        for metric in all_metrics:
          try:
            if task == 'pdb' and metric in additional_metrics:
              res = metric(model_response, ground_truth_info, inf_prompt)
            else:
              res = metric(model_response, ground_truth_info)
            print(res)
            record_results.update(res)
          except Exception as e:  # pylint: disable=broad-exception-caught
            # One failing metric should not stop the others for this record.
            print("##### ERROR #####")
            print(e)
            print(f"##### skipped {task}, {prompt}, {llm_name}, {record_id} #####")


In [None]:
# Display the collected metrics, nested as task -> prompt -> llm_name -> record_id.
results_json

# evaluate one instance

In [None]:
# Selects the single record evaluated by the next cell.
llm_name = 'gemini-2.0-flash-latest' #@param
prompt = 'mat_paper_to_property_1_shot' #@param
task = 'mpve' #@param
record_id = '6183251' #@param

In [None]:
# Evaluates the single record selected above and prints its metrics.
# Unlike the batch loop, errors are not swallowed: they propagate with their
# original exception type and traceback so they can be debugged.
result = {}
gt_path = os.path.join(file_root_path, "data", task, "ground_truth", record_id + ".json")
model_response_path = os.path.join(file_root_path, "inference_outputs", task, prompt, llm_name, "success", record_id + ".json")
ground_truth_info, model_response, exception_message, inf_prompt = read_task_ground_truth_and_response(gt_path, model_response_path)
ground_truth_info = preprocess_ground_truth(ground_truth_info, task, prompt)
# pdb metrics also need the inference prompt (it carries the input PDB data).
if model_response and (task != 'pdb' or inf_prompt):
  metric_table = _FULL_ADDITIONAL_METRICS if runs_full_metric else _PRIMARY_ADDITIONAL_METRICS
  additional_metrics = list(metric_table.get(task, {}).get(prompt, []))
  for metric in additional_metrics + _SHARED_METRCS:
    if task == 'pdb' and metric in additional_metrics:
      res = metric(model_response, ground_truth_info, inf_prompt)
    else:
      res = metric(model_response, ground_truth_info)
    result.update(res)
else:
  print(f"No usable model response for {task}/{prompt}/{llm_name}/{record_id}; nothing was scored.")
  if exception_message:
    print("Failure message: ", exception_message)

print(f'###### eval finished ##########\nresults: \n{json5.dumps(result, indent=4)}')