In [1]:
import json
import os, csv
from datasets import Dataset, load_from_disk, concatenate_datasets, load_dataset

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('/scratch/gpfs/JHA/mb5157/tokenizers/biomedbert_fast_tokenizer')

  from .autonotebook import tqdm as notebook_tqdm


# Postprocessing

### 1. Unite dataset chunks after combining tokens

You need this one if you launched combine_tails.py in multiple parallel jobs. If you didn't that still should work correctly.

In [2]:
def unite_output(dataset_path, original_dataset_size, chunk_size, llm_name, start_idx=0):
    """
    original_dataset_size -- how many records in the original dataset;
    it's also in the filename in the folder -- the largest end index
    """
    output_path = os.path.join(dataset_path, f'{llm_name}_all')
    end_idx = start_idx

    datasets = []
    while end_idx < original_dataset_size:
        end_idx = min(start_idx + chunk_size, original_dataset_size)
        path = os.path.join(dataset_path, f"{llm_name}_{start_idx}-{end_idx}")
        dataset = load_from_disk(path)            
        print(f'loaded from {path}')
        datasets.append(dataset)
        start_idx += chunk_size
    
    united_dataset = concatenate_datasets(datasets)
    united_dataset.save_to_disk(output_path)
    print('saved to', output_path)
    return output_path

#### Set your own variables in the cell below

They are used to buld the correct path to the data

In [6]:
prefix = '../..'
# specify path to trained graphmert model
model_path = 'outputs/test/span/rel29/bs256_lr_0.0001' 
top_k = 10
dataset_path = os.path.join(prefix, model_path, f'predictions/top_{top_k}')
llm_name = 'qwen3-1.7b' # llm used to combine tokens
original_dataset_size = 100

saved_to_path = None # to notice if the function unite_output didn't work after re-run
saved_to_path = unite_output(dataset_path, original_dataset_size, chunk_size=10000, llm_name = llm_name, start_idx=0)
united_dataset = load_from_disk(saved_to_path)

loaded from ../../outputs/test/span/rel29/bs256_lr_0.0001/predictions/top_10/qwen3-1.7b_0-100


Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 19917.86 examples/s]

saved to ../../outputs/test/span/rel29/bs256_lr_0.0001/predictions/top_10/qwen3-1.7b_all





In [None]:
# if you don't have chunks, you can simly load your dataset here
# saved_to_path = ...
# united_dataset = load_from_disk(saved_to_path)

### 2. Remove hallicinated tails: not from the prediction

Use filter_tails func

In [4]:
def filter_tails(example):
    """aggressive filtering: remove all tails with a token not in GLM prediction"""
    preds = set(example['predictions'].split())
    kept, removed, removed_missing = [], [], []
    for tail in example['tails']:
        tl = tail.lower()
        toks = tokenizer.tokenize(tl)
        if all(tok in preds for tok in toks):
            kept.append(tl)
        else:
            removed.append(tl)
            removed_missing.append([tok for tok in toks if tok not in preds])
    return {
        'tails': kept,
        'removed_tails': removed, # the actual tail strings that were filtered out
        'removed_missing_tokens': removed_missing, # tokens that triggered the removal
    }


In [7]:
clean_dataset = united_dataset.map(
    filter_tails, num_proc=20, desc="filter tails"
)

# ---- totals ----
kept_total    = sum(len(x) for x in clean_dataset['tails'])
removed_total = sum(len(x) for x in clean_dataset['removed_tails'])
total_before  = kept_total + removed_total

print(f"tails kept: {kept_total}")
print(f"tails removed: {removed_total}")
print(f"tails total (before filtering): {total_before}")
print(f'records in the dataset: {len(clean_dataset)}')

clean_output_path = f'{saved_to_path}_clean'
clean_dataset.select_columns(['id', 'input_ids', 'head', 'relation', 'predictions', 'tails']).save_to_disk(clean_output_path)
print(f'saved to {clean_output_path}')

tails kept: 0
tails removed: 0
tails total (before filtering): 0
records in the dataset: 100


Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 9965.32 examples/s] 

saved to outputs/test/span/rel29/bs256_lr_0.0001/predictions/top_10/qwen3-32b_subset_100_0-100_clean





### Functions below are experimental

In [None]:
def filter_tails_relaxed(example):
    """relaxed filtering: allows tail tokens that are from the sequence"""
    preds = set(example['predictions'].split())
    kept, removed, removed_missing = [], [], []
    for tail in example['tails']:
        tl = tail.lower()
        toks = set(tokenizer.tokenize(tl))
        # len_toks_before = len(toks)
        toks -= preds
        if len(toks) == 0:
            kept.append(tl)
        else:
            seq = set(tokenizer.decode(example['input_ids'], skip_special_tokens=True))
            toks -= seq
            if len(toks) == 0: #or len_toks_before
                kept.append(tl)
            else:
                removed.append(tl)
                removed_missing.append(list(toks))
    return {
        'tails': kept,
        'removed_tails': removed, # the actual tail strings that were filtered out
        'removed_missing_tokens': removed_missing, # tokens that triggered the removal
    }


def filter_tails_loose(example):
    """
    loose filtering:
    allows tail tokens that are from the sequence
    and one token from LLM for tails longer than 3 tokens
    """
    preds = set(example['predictions'].split())
    kept, removed, removed_missing = [], [], []
    for tail in example['tails']:
        tl = tail.lower()
        toks = set(tokenizer.tokenize(tl))
        len_toks_before = len(toks)
        toks -= preds
        if len(toks) == 0:
            kept.append(tl)
        else:
            seq = set(tokenizer.decode(example['input_ids'], skip_special_tokens=True))
            toks -= seq
            if len(toks) == 0 or (len(toks) == 1 and len_toks_before > 2):
                kept.append(tl)
            else:
                removed.append(tl)
                removed_missing.append(list(toks))
    return {
        'tails': kept,
        'removed_tails': removed, # the actual tail strings that were filtered out
        'removed_missing_tokens': removed_missing, # tokens that triggered the removal
    }


def write_to_log(dataset, log_path):
    with open(log_path, 'w', newline='', encoding='utf-8') as f:
        w = csv.writer(f)
        header = ['row_index', 'id', 'removed_tail', 'missing_tokens']
        w.writerow(header)
        has_id = 'id' in dataset.column_names
        for i, row in enumerate(dataset):
            if not i % 10_000:
                print(f'at line {i}')
            rid = row['id'] if has_id else ''
            for tail, missing in zip(row['removed_tails'], row['removed_missing_tokens']):
                w.writerow([i, rid, tail, ' '.join(missing)])
    print(f"saved to {log_path}")

#### relaxed cleaning

In [74]:
clean_dataset_relaxed = united_dataset.map(
    filter_tails_relaxed, num_proc=20, desc="filter tails"
)

# ---- totals ----
kept_total    = sum(len(x) for x in clean_dataset_relaxed['tails'])
removed_total = sum(len(x) for x in clean_dataset_relaxed['removed_tails'])
total_before  = kept_total + removed_total

print(f"tails kept: {kept_total}")
print(f"tails removed: {removed_total}")
print(f"tails total (before filtering): {total_before}")

clean_output_path = f'{saved_to_path}_clean_relaxed'
clean_dataset_relaxed.select_columns(['id', 'input_ids', 'head', 'relation', 'predictions', 'tails']).save_to_disk(clean_output_path)
print(f'saved to {clean_output_path}')

filter tails (num_proc=20): 100%|██████████| 1341137/1341137 [01:04<00:00, 20647.17 examples/s]


tails kept: 1544558
tails removed: 215530
tails total (before filtering): 1760088


Saving the dataset (2/2 shards): 100%|██████████| 1341137/1341137 [00:01<00:00, 718713.04 examples/s]

saved to ../outputs/350abstracts/qwen32/score0.55/span/rel29/bs256_lr_0.0004/predictions/top_20/qwen3-32b_all_clean_relaxed





In [None]:
# log_path = os.path.join(dataset_path, f'removed_tails_relaxed.csv')
# write_to_log(clean_dataset_relaxed, log_path)

#### loose cleaning

In [30]:
clean_dataset_loose = united_dataset.map(
    filter_tails_loose, num_proc=20, desc="filter tails"
)

kept_total    = sum(len(x) for x in clean_dataset_loose['tails'])
removed_total = sum(len(x) for x in clean_dataset_loose['removed_tails'])
total_before  = kept_total + removed_total

print(f"tails kept: {kept_total}")
print(f"tails removed: {removed_total}")
print(f"tails total (before filtering): {total_before}")

clean_output_path = f'{saved_to_path}_clean_loose'
clean_dataset_loose.select_columns(['id', 'input_ids', 'head', 'relation', 'predictions', 'tails']).save_to_disk(clean_output_path)
print(f'saved to {clean_output_path}')

filter tails (num_proc=20): 100%|██████████| 1517900/1517900 [01:22<00:00, 18486.29 examples/s]


tails kept: 4012197
tails removed: 202083
tails total (before filtering): 4214280


Saving the dataset (3/3 shards): 100%|██████████| 1517900/1517900 [00:02<00:00, 672660.80 examples/s]

saved to ../outputs/350abstracts/qwen32/score0.55/ablations/seed_kg_size/llm_kg/span/rel23/bs256_lr_0.0004/predictions/top_20/qwen3-32b_all_clean_loose





In [None]:
log_path = os.path.join(dataset_path, f'removed_tails_loose.csv')
write_to_log(clean_dataset_loose, log_path)