In [17]:
import os
import json
import shutil
import sys
import re

In [18]:
def find_duplicates(data):
    """
    Collect every 'id' value that occurs on more than one item.

    Items whose 'id' is absent or None are ignored. The returned list
    keeps the order in which each repeated ID was first encountered.
    """
    occurrences = {}
    for record in data:
        key = record.get("id")
        if key is None:
            # Nothing to tally for records without an ID.
            continue
        if key in occurrences:
            occurrences[key] += 1
        else:
            occurrences[key] = 1

    # Keep only the IDs that were tallied at least twice.
    repeated = []
    for key, seen in occurrences.items():
        if seen > 1:
            repeated.append(key)
    return repeated

def find_missing_ids(data):
    """
    Report gaps in the numeric suffixes of the items' 'id' values.

    For each item, the trailing run of digits of a string 'id' is parsed as
    an integer (e.g. "hover_train_full_docs-000850" -> 850); an integer 'id'
    is used as its own number. Items whose 'id' is absent, None, of any
    other type, or a string that does not end in digits are skipped instead
    of raising, matching how find_duplicates tolerates missing IDs.

    Returns a sorted list of the integers that are missing from the
    continuous range between the smallest and largest number found, or an
    empty list when no number could be extracted.
    """
    numbers = set()
    for item in data:
        raw_id = item.get("id")

        # bool is a subclass of int; True/False are not meaningful IDs here.
        if isinstance(raw_id, int) and not isinstance(raw_id, bool):
            numbers.add(raw_id)
            continue

        # re.search would raise TypeError on None or other non-string values.
        if not isinstance(raw_id, str):
            continue

        m = re.search(r"(\d+)$", raw_id)
        if m:
            numbers.add(int(m.group(1)))

    if not numbers:
        return []

    min_n, max_n = min(numbers), max(numbers)
    return [n for n in range(min_n, max_n + 1) if n not in numbers]


In [19]:
# Directory tree that is walked, and the root under which its mirror is written.
input_root = "inference_outputs/converted_outputs"
output_root = "inference_outputs/converted_outputs_trimmed"

# Mode switch for the processing loop: when True, each JSON file is only
# checked for duplicate/missing IDs and is not trimmed or copied; when False,
# JSON files are trimmed (or copied) into output_root.
look_for_duplicates = True

# Number of leading records to keep, keyed by filename prefix.
# None means the file is copied as-is instead of being trimmed.
trim_counts = dict(
    covid_fact=1000,
    hover_train=2000,
    politi_hop=None,
)

In [20]:
# Walk input_root and mirror its directory structure under output_root.
# Non-JSON files are always copied unchanged. JSON files are either checked
# for duplicate/missing IDs (look_for_duplicates is True) or trimmed to the
# first N records, where N comes from the trim_counts entry whose key is a
# prefix of the filename.
# NOTE(review): in check mode the output directories are still created and
# non-JSON files are still copied, but no JSON file is written to output_root
# — confirm this is intended.
for root, dirs, files in os.walk(input_root):
    rel = os.path.relpath(root, input_root)
    dst_dir = os.path.join(output_root, rel)
    os.makedirs(dst_dir, exist_ok=True)

    for fn in files:
        src_path = os.path.join(root, fn)
        dst_path = os.path.join(dst_dir, fn)

        if not fn.endswith('.json'):
            shutil.copy(src_path, dst_path)
            continue

        if look_for_duplicates:
            # Read as UTF-8 explicitly; open() otherwise falls back to the
            # platform locale encoding, which may not be UTF-8.
            with open(src_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            duplicates = find_duplicates(data)
            if duplicates:
                print(f"Duplicate IDs found in {fn}:")
                for id_ in duplicates:
                    print(f"  - {id_}")

            missing_ids = find_missing_ids(data)
            if missing_ids:
                print(f"Missing IDs found in {fn}:")
                for id_ in missing_ids:
                    print(f"  - {id_}")
            continue

        # First trim_counts key that prefixes the filename; None if no key
        # matches, in which case .get(None) also yields None below.
        key = next((k for k in trim_counts if fn.startswith(k)), None)
        N   = trim_counts.get(key)

        if N is None:
            # No limit configured for this file: copy it untouched.
            shutil.copy(src_path, dst_path)
            continue

        with open(src_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        trimmed = data[:N]
        with open(dst_path, 'w', encoding='utf-8') as f:
            json.dump(trimmed, f, indent=2)

Duplicate IDs found in hover_train_mistral_7b_cot.json:
  - hover_train_full_docs-000850
Missing IDs found in hover_train_mistral_7b_no_cot.json:
  - 1805
  - 1806
