Merge the two triplet files and shuffle the combined entries

In [None]:
import json
import random

# Merge inputs (two Train-split triplet files) followed by the merged output path.
file1, file2, output_file = (
    "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/ctg-triplets.jsonl",
    "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/triplet_UMLS_no_numbers.jsonl",
    "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets.jsonl",
)

# Function to merge two JSONL files and randomize the order
def merge_and_randomize_jsonl(file1_path, file2_path, output_path, seed=None):
    """Concatenate the records of two JSONL files, shuffle them, and write the result.

    Args:
        file1_path: Path of the first JSONL input file.
        file2_path: Path of the second JSONL input file.
        output_path: Path the shuffled records are written to (overwritten).
        seed: Optional seed for a private ``random.Random`` so the shuffle is
            reproducible. When ``None`` the module-level ``random`` generator
            is used, as before.

    Blank lines in the inputs are skipped. I/O and JSON errors are reported with
    a printed message instead of being raised. Callers should therefore check
    for the success message before relying on ``output_path``.
    """
    try:
        merged_entries = []
        for path in (file1_path, file2_path):
            with open(path, 'r', encoding='utf-8') as infile:
                # Skip blank lines (e.g. a trailing newline). Otherwise
                # json.loads('') raises and the whole merge is abandoned.
                merged_entries.extend(json.loads(line) for line in infile if line.strip())

        # Randomize the order of entries
        rng = random if seed is None else random.Random(seed)
        rng.shuffle(merged_entries)

        # Write the shuffled records back out, one JSON object per line
        with open(output_path, 'w', encoding='utf-8') as outfile:
            outfile.writelines(json.dumps(entry) + '\n' for entry in merged_entries)

        print(f"Files merged and randomized. Output saved to {output_path}.")
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError covers JSONDecodeError and
        # UnicodeDecodeError. Anything else is a real bug and propagates.
        print(f"Error merging and randomizing JSONL files: {e}")

# Merge the two Train-split inputs into the combined, shuffled training file.
merge_and_randomize_jsonl(
    file1_path=file1,
    file2_path=file2,
    output_path=output_file,
)


Remove duplicate query-pos pairs, ignoring categories

In [None]:
import json

import os

# File path
input_file_path = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets.jsonl'


def dedupe_query_pos_pairs(triplets):
    """Return the triplets that contribute at least one unseen query-pos pair.

    Pairs are order-insensitive, so ``(a, b)`` equals ``(b, a)``. The triplet's
    ``category`` is ignored. Every pair of a kept triplet is recorded, so a later
    triplet whose pairs have all been seen is dropped. A triplet with an empty
    ``pos`` list contributes no pairs and is dropped.

    Args:
        triplets: Iterable of dicts with a ``'query'`` value and a ``'pos'`` list.

    Returns:
        List of the kept triplets, in input order.
    """
    seen_pairs = set()
    kept = []
    for triplet in triplets:
        query = triplet['query']
        # Sort each pair so a swapped query/pos counts as the same pair.
        pairs = {tuple(sorted((query, pos))) for pos in triplet['pos']}
        if not pairs <= seen_pairs:
            kept.append(triplet)
            # Record ALL pairs of the kept triplet. Stopping at the first new
            # pair would let later repeats of the other pairs slip through.
            seen_pairs |= pairs
    return kept


# Guarded so the helper can be imported without touching the data file.
# A notebook cell still runs this because its __name__ is "__main__".
if __name__ == "__main__":
    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        filtered_triplets = dedupe_query_pos_pairs(
            json.loads(line) for line in input_file if line.strip()
        )

    # Write to a temporary file, then swap it in. A failure mid-write then
    # cannot leave the only copy of the data truncated.
    tmp_path = input_file_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as output_file:
        output_file.writelines(json.dumps(triplet) + '\n' for triplet in filtered_triplets)
    os.replace(tmp_path, input_file_path)

    print(f"Duplicates removed from {input_file_path}, only unique query-pos pairs retained.")
    print(f"Kept {len(filtered_triplets)} triplets.")


Create an SY test file and remove the SY entries from the training file

In [None]:
import json

import os

# File paths
input_file_path = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/triplet_UMLS_no_numbers.jsonl'
output_file_path = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Test/SY_triplets.jsonl'


def split_by_category(triplets, category):
    """Partition triplets by their ``'category'`` value.

    Args:
        triplets: Iterable of dicts. A triplet without a ``'category'`` key never matches.
        category: The category value to pull out.

    Returns:
        ``(matching, remaining)`` lists, each in input order.
    """
    matching, remaining = [], []
    for triplet in triplets:
        (matching if triplet.get('category') == category else remaining).append(triplet)
    return matching, remaining


def _write_jsonl_atomic(path, records):
    """Write records as JSON lines to a temp file, then swap it into ``path``."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as out:
        out.writelines(json.dumps(record) + '\n' for record in records)
    os.replace(tmp_path, path)


# NOTE(review): the merge cell copies this same input file into all_triplets.jsonl.
# If that cell already ran, the SY entries are still in all_triplets.jsonl after
# this one. Run this split before merging, or filter all_triplets.jsonl as well.
# Guarded so the helpers can be imported; a notebook cell's __name__ is "__main__".
if __name__ == "__main__":
    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        filtered_triplets, remaining_triplets = split_by_category(
            (json.loads(line) for line in input_file if line.strip()), 'SY'
        )

    if not filtered_triplets:
        # A re-run finds no SY entries because the first run removed them.
        # Skip writing so the existing SY test file is not replaced by an empty one.
        print(f"No 'SY' entries found in {input_file_path}; nothing written")
    else:
        _write_jsonl_atomic(output_file_path, filtered_triplets)
        _write_jsonl_atomic(input_file_path, remaining_triplets)
        print(f"Filtered entries with category 'SY' saved to {output_file_path}")
        print("Input file updated to exclude 'SY' category entries")


Create a test file by sampling entries out of the training triplets

In [None]:
import json
import random

import os

# File paths
input_file_path = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets.jsonl'
output_file_path = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Test/test_all_triplets.jsonl'

# Number of entries to sample
sample_size = 400000
# Optional seed for a reproducible split; None keeps using the global RNG.
seed = None


def split_random_sample(lines, sample_size, rng=random):
    """Randomly move ``sample_size`` lines into a sample and keep the rest.

    Sampling is by position, so duplicate lines are handled correctly. Exactly
    ``sample_size`` lines are sampled, and only the sampled copies are removed
    from the remainder.

    Args:
        lines: Sequence of lines (without trailing newlines).
        sample_size: Number of lines to sample.
        rng: Object providing ``sample``, i.e. the ``random`` module or a ``random.Random``.

    Returns:
        ``(sample, remaining)``. ``sample`` is in random order; ``remaining``
        keeps the input order.

    Raises:
        ValueError: If ``sample_size`` exceeds ``len(lines)``.
    """
    if sample_size > len(lines):
        raise ValueError(f"Sample size ({sample_size}) exceeds the total number of entries ({len(lines)})")
    sampled_idx = rng.sample(range(len(lines)), sample_size)
    sampled_set = set(sampled_idx)  # O(1) membership for the remainder scan
    sample = [lines[i] for i in sampled_idx]
    remaining = [line for i, line in enumerate(lines) if i not in sampled_set]
    return sample, remaining


def _write_lines_atomic(path, lines):
    """Write newline-terminated lines to a temp file, then swap it into ``path``."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as out:
        out.writelines(line + '\n' for line in lines)
    os.replace(tmp_path, path)


# Guarded so the helpers can be imported; a notebook cell's __name__ is "__main__".
if __name__ == "__main__":
    with open(input_file_path, 'r', encoding='utf-8') as input_file:
        # Drop blank lines so they are neither sampled nor counted as entries.
        triplets = [line.strip() for line in input_file if line.strip()]

    rng = random if seed is None else random.Random(seed)
    random_sample, remaining_triplets = split_random_sample(triplets, sample_size, rng)

    # Save the sampled entries, then overwrite the input with the remainder
    _write_lines_atomic(output_file_path, random_sample)
    _write_lines_atomic(input_file_path, remaining_triplets)

    print(f"Randomly sampled {sample_size} entries saved to {output_file_path}")
    print("Input file updated to exclude the sampled entries")

Preview the merged training file

In [None]:
import json

# Path of the JSONL file to preview
input_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets.jsonl'

# Number of leading entries to print
preview_count = 200


def preview_jsonl(path, n):
    """Count the non-blank lines of a JSONL file and parse its first ``n`` records.

    The file is streamed, so at most ``n`` parsed records are held in memory.

    Args:
        path: Path of the JSONL file.
        n: Maximum number of leading records to parse and return.

    Returns:
        ``(count, head)``. ``count`` is the number of non-blank lines and
        ``head`` holds up to ``n`` parsed records in file order.
    """
    count = 0
    head = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # blank lines are not entries (json.loads would fail on them)
            if count < n:
                head.append(json.loads(line))
            count += 1
    return count, head


# Guarded so the helper can be imported; a notebook cell's __name__ is "__main__".
if __name__ == "__main__":
    entry_count, head = preview_jsonl(input_file, preview_count)

    # Print the number of entries in the file
    print(f'Number of entries in the file: {entry_count}')

    # Print the first `preview_count` entries
    for entry in head:
        print(entry)

Number of entries per category

In [None]:
import json
from collections import Counter

# Path of the training triplets to summarize
file_path = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets.jsonl'

# Tally triplets per category in one pass; a triplet with no 'category' key
# is counted under 'unknown'.
with open(file_path, 'r') as file:
    category_counts = Counter(
        json.loads(line.strip()).get('category', 'unknown') for line in file
    )

# One line per category, in order of first appearance in the file
for name, total in category_counts.items():
    print(f"Category: {name}, Count: {total}")
