Create pseudo-random negatives

In [None]:
import json
from collections import defaultdict
import random
from typing import Dict, List, Set
import multiprocessing as mp
from itertools import islice
import os

# --- Negative-sampling budget -------------------------------------------------
NUM_NEG_SAMPLES_SUFFIX = 2  # per positive: negatives sharing a suffix with it
NUM_NEG_SAMPLES_PREFIX = 2  # per positive: negatives sharing a prefix with it

# Target total per entry (suffix + prefix picks, topped up with random ones).
# A later cell in this notebook trims every entry to exactly 7 negatives and
# drops entries that have fewer, so this value leaves room above that target.
TOTAL_NUM_NEGATIVES = 8

def chunk_file(filename: str, chunk_size: int = 1000):
    """
    Lazily yield the lines of ``filename`` in lists of up to ``chunk_size``.

    Lines keep their trailing newline. The final list may be shorter, and an
    empty file yields nothing. Because this is a generator, the file is opened
    on the first iteration and closed once iteration finishes.
    """
    with open(filename, 'r') as handle:
        # iter(callable, sentinel) stops as soon as a read returns [] (EOF).
        yield from iter(lambda: list(islice(handle, chunk_size)), [])

def safe_get_last_word(text: str) -> str:
    """
    Return the lower-cased final whitespace-separated word of ``text``.

    Non-string input, an empty string, or an all-whitespace string yields "".
    """
    if not isinstance(text, str) or not text:
        return ""
    # rsplit(None, 1) splits only the last word off; it returns [] for "   ".
    pieces = text.rsplit(None, 1)
    if not pieces:
        return ""
    return pieces[-1].lower()

def get_valid_candidates(candidates: List[str], pos: str, pos_last_word: str) -> List[str]:
    """
    Keep the suffix-based candidates that may serve as negatives for ``pos``.

    A candidate is dropped when it is ``pos`` itself, or when its lower-cased
    last word equals ``pos_last_word``. Callers pass
    ``safe_get_last_word(pos)`` here, so "brain injury" is rejected as a
    negative for "acute kidney injury". Input order is preserved.
    """
    kept = []
    for cand in candidates:
        if cand == pos:
            continue
        if safe_get_last_word(cand) == pos_last_word:
            continue
        kept.append(cand)
    return kept

def get_valid_candidates_prefix(
    candidates: List[str], pos: str, prefix_length: int
) -> List[str]:
    """
    Keep the prefix-based candidates that may serve as negatives for ``pos``.

    A candidate survives when all of these hold:
      - it is not ``pos`` itself;
      - it is at least ``prefix_length`` characters long and shares the first
        ``prefix_length`` characters with ``pos``;
      - if both strings continue past the prefix, they diverge on the very
        next character (so the shared prefix is exactly ``prefix_length``).

    Returns [] when ``pos`` is shorter than ``prefix_length``. Input order is
    preserved.
    """
    if len(pos) < prefix_length:
        return []

    head = pos[:prefix_length]
    follow = pos[prefix_length] if len(pos) > prefix_length else None

    def _accept(cand: str) -> bool:
        if cand == pos or len(cand) < prefix_length or cand[:prefix_length] != head:
            return False
        # Both continue past the prefix with the same character -> too similar.
        same_next = (
            follow is not None
            and len(cand) > prefix_length
            and cand[prefix_length] == follow
        )
        return not same_next

    return [cand for cand in candidates if _accept(cand)]

# Suffix lengths tried for suffix-based negatives, longest (most similar) first.
_SUFFIX_LENGTHS = (5, 4, 3)
# Prefix lengths tried for prefix-based negatives, longest (most similar) first.
_PREFIX_LENGTHS = (4, 3, 2, 1)


def _parse_entry(line: str):
    """
    Parse one input line into an entry; return None for a blank line.

    The line is parsed as strict JSON first. Only if that fails is it tried as
    a Python literal (e.g. a dict dumped with single quotes), and finally with
    the legacy ``'`` -> ``"`` substitution. Swapping quotes up front (the old
    behaviour) broke valid JSON whose strings contain apostrophes, such as
    "Crohn's disease".
    """
    text = line.strip()
    if not text:
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    import ast  # only needed on the non-JSON fallback path
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        pass
    # Last resort: the original quote-swapping heuristic (raises if it fails too).
    return json.loads(text.replace("'", '"'))


def _suffix_of(term: str, n: int) -> str:
    """Return the last ``n`` characters of ``term``."""
    return term[-n:]


def _prefix_of(term: str, n: int) -> str:
    """Return the first ``n`` characters of ``term``."""
    return term[:n]


def _bucket_by(terms, lengths, key):
    """
    Group ``terms`` by ``key(term, n)`` for every ``n`` in ``lengths``.

    Returns ``{n: {affix: [terms...]}}``. Terms shorter than ``n`` are left
    out of the length-``n`` buckets, and input order is kept inside each bucket.
    """
    buckets = {n: defaultdict(list) for n in lengths}
    for term in terms:
        for n in lengths:
            if len(term) >= n:
                buckets[n][key(term, n)].append(term)
    return buckets


def _suffix_filter(bucket, pos, _n):
    """Apply the suffix rule (different last word, not ``pos``) to one bucket."""
    return get_valid_candidates(bucket, pos, safe_get_last_word(pos))


def _tiered_pool(pos, lengths, buckets, key, keep, needed, excluded):
    """
    Collect distinct candidates for ``pos``, falling back to shorter affixes.

    For each length ``n`` in ``lengths`` (longest first), the bucket sharing
    ``key(pos, n)`` is filtered with ``keep(bucket, pos, n)``. Shorter lengths
    are consulted only while fewer than ``needed`` distinct candidates have
    been found. Terms in ``excluded`` are never returned.
    """
    pool = {}  # dict used as an insertion-ordered set
    for n in lengths:
        if len(pool) >= needed:
            break
        if len(pos) < n:
            continue
        for cand in keep(buckets[n].get(key(pos, n), []), pos, n):
            if cand not in excluded:
                pool[cand] = None
    return list(pool)


def _excluded_terms(entry: dict) -> set:
    """Terms that must never be a negative for ``entry``: its positives and query."""
    excluded = {pos for pos in entry.get('pos', []) if pos}
    query = entry.get('query')
    if isinstance(query, str) and query:
        excluded.add(query)
    return excluded


def process_chunk(chunk: List[str]) -> List[dict]:
    """
    Attach negatives to every entry in a chunk of input lines.

    Steps:
      1) Parse every non-blank line. Index all distinct, non-empty 'pos' terms
         of the chunk by suffix (lengths 5/4/3) and by prefix (4/3/2/1).
      2) For each entry, and for each of its positives, pick:
         a) up to NUM_NEG_SAMPLES_SUFFIX suffix-similar negatives (longest
            shared suffix first, with a different last word), and
         b) up to NUM_NEG_SAMPLES_PREFIX prefix-similar negatives (longest
            shared prefix first, diverging on the next character).
      3) Deduplicate those picks. Then add random positives from the chunk
         until the entry has TOTAL_NUM_NEGATIVES negatives, or the pool runs
         out.

    Negatives are drawn only from positives of the same chunk. They never equal
    one of the entry's own positives or its 'query', since those would be
    false negatives. Entries with several positives can end up with more than
    TOTAL_NUM_NEGATIVES similarity-based negatives; no cap is applied here.

    Args:
        chunk: raw lines of the input file, one JSON object per line.

    Returns:
        The parsed entries, in input order, each with a 'neg' list added.
    """
    entries = [entry for entry in map(_parse_entry, chunk) if entry is not None]

    # Distinct non-empty positives in first-seen order: the universe every
    # negative (similarity-based or random) is drawn from.
    all_positives = list(dict.fromkeys(
        pos for entry in entries for pos in entry.get('pos', []) if pos
    ))
    suffix_buckets = _bucket_by(all_positives, _SUFFIX_LENGTHS, _suffix_of)
    prefix_buckets = _bucket_by(all_positives, _PREFIX_LENGTHS, _prefix_of)

    for entry in entries:
        pos_list = [pos for pos in entry.get('pos', []) if pos]
        excluded = _excluded_terms(entry)

        # (a) suffix-based negatives, up to NUM_NEG_SAMPLES_SUFFIX per positive
        suffix_negs = []
        for pos in pos_list:
            pool = _tiered_pool(
                pos, _SUFFIX_LENGTHS, suffix_buckets, _suffix_of,
                _suffix_filter, NUM_NEG_SAMPLES_SUFFIX, excluded,
            )
            suffix_negs.extend(
                random.sample(pool, min(NUM_NEG_SAMPLES_SUFFIX, len(pool)))
            )

        # (b) prefix-based negatives, up to NUM_NEG_SAMPLES_PREFIX per positive
        prefix_negs = []
        for pos in pos_list:
            pool = _tiered_pool(
                pos, _PREFIX_LENGTHS, prefix_buckets, _prefix_of,
                get_valid_candidates_prefix, NUM_NEG_SAMPLES_PREFIX, excluded,
            )
            prefix_negs.extend(
                random.sample(pool, min(NUM_NEG_SAMPLES_PREFIX, len(pool)))
            )

        # (c) Deduplicate *before* counting, so the random top-up actually
        # reaches TOTAL_NUM_NEGATIVES. Deduplicating afterwards left entries
        # short of the target.
        negatives = list(dict.fromkeys(suffix_negs + prefix_negs))
        num_random_needed = TOTAL_NUM_NEGATIVES - len(negatives)
        if num_random_needed > 0:
            taken = excluded.union(negatives)  # set: O(1) membership tests
            fillers = [term for term in all_positives if term not in taken]
            negatives.extend(
                random.sample(fillers, min(num_random_needed, len(fillers)))
            )

        entry['neg'] = negatives

    return entries

def main(input_file: str = None, output_file: str = None, chunk_size: int = None):
    """
    Add negatives to every entry of a triplet JSONL file, using all CPU cores.

    The input is read in chunks (``chunk_file``), and each chunk is processed by
    ``process_chunk`` in a worker process. Results are written in input order
    to ``output_file + ".tmp"``, which is moved over ``output_file`` only after
    every chunk has succeeded. On failure the temporary file is removed and no
    truncated output is left behind.

    Args:
        input_file: source JSONL path (defaults to the project training file).
        output_file: destination JSONL path (defaults to the project output file).
        chunk_size: lines per chunk. When None it is derived from the input size.
    """
    if input_file is None:
        input_file = "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets.jsonl"
    if output_file is None:
        output_file = "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets_4types_7neg.jsonl"

    if chunk_size is None:
        # Heuristic kept from the original. It evaluates to 1000 lines unless
        # the file is larger than roughly 1 GB per CPU core; it never exceeds 10000.
        file_size = os.path.getsize(input_file)
        chunk_size = max(1000, min(10000, file_size // (mp.cpu_count() * 1000000)))

    tmp_file = output_file + ".tmp"
    try:
        with mp.Pool(mp.cpu_count()) as pool, open(tmp_file, 'w') as outfile:
            # imap (not imap_unordered) keeps the output in input order, so runs
            # are reproducible line-for-line. Writing as results arrive avoids
            # holding every entry in memory.
            for chunk_result in pool.imap(process_chunk, chunk_file(input_file, chunk_size)):
                for entry in chunk_result:
                    json.dump(entry, outfile)
                    outfile.write('\n')
    except BaseException:
        # Don't leave a half-written temporary file behind.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
    os.replace(tmp_file, output_file)

    print(f"Updated entries written to {output_file}")

if __name__ == "__main__":
    main()


Set the number of negatives to 7

In [None]:
import json
import os

# File path to the output file
file_path = "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets_4types_7neg.jsonl"

# Every kept entry ends up with exactly this many negatives.
TARGET_NUM_NEGATIVES = 7

# Counters for tracking
deleted_entries_count = 0
updated_entries = []

# Drop entries with fewer than TARGET_NUM_NEGATIVES negatives; truncate the
# rest to exactly TARGET_NUM_NEGATIVES.
with open(file_path, 'r') as file:
    for line in file:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing empty line
        entry = json.loads(line)
        negatives = entry.get('neg', [])

        if len(negatives) < TARGET_NUM_NEGATIVES:
            deleted_entries_count += 1
            continue

        entry['neg'] = negatives[:TARGET_NUM_NEGATIVES]
        updated_entries.append(entry)

# Write to a temporary file and atomically swap it in, so an interruption
# mid-write cannot destroy the original data (this cell rewrites its input).
tmp_path = file_path + ".tmp"
with open(tmp_path, 'w') as file:
    for entry in updated_entries:
        file.write(json.dumps(entry) + '\n')
os.replace(tmp_path, file_path)

# Print the total number of deleted entries
print(f"Total number of deleted entries: {deleted_entries_count}")


View JSONL

In [None]:
import json
from itertools import islice

# Load the JSONL data from the file
input_file = '/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets_4types_7neg.jsonl'

# How many leading entries to print.
NUM_PREVIEW = 200

# Stream the file: keep only the first NUM_PREVIEW lines in memory and just
# count the remaining ones, instead of loading everything with readlines().
with open(input_file, 'r') as f:
    head = list(islice(f, NUM_PREVIEW))
    num_entries = len(head) + sum(1 for _ in f)

# Print the length of the file
print(f'Number of entries in the file: {num_entries}')

# Print the head of the file (first NUM_PREVIEW entries)
for line in head:
    print(json.loads(line))

Top 200 most recurrent terms

In [None]:
import os
import json
from collections import Counter

# Path to the JSONL file
file_path = "/home/jh537/Clinical_Trial_Embending/Clinical_Trial_data/Clinical_Trial_Triplet_v3/Train/all_triplets_4types_7neg.jsonl"

# Number of most frequent terms to report.
TOP_N = 200

# Check if the file exists
if not os.path.exists(file_path):
    print(f"File not found: {file_path}")
else:
    # Counter to store term frequencies
    term_counter = Counter()

    # Reading the JSONL file
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            data = json.loads(line)
            # Count the query (when present) and every positive term. The query
            # used to be read with data['query'], which raised KeyError on
            # entries without one.
            if 'query' in data:
                term_counter[data['query']] += 1
            term_counter.update(data.get('pos', []))

    # Get the top TOP_N most frequent unique terms
    top_terms = term_counter.most_common(TOP_N)

    # Display the results
    print(f"Top {TOP_N} most frequent unique terms in 'query' or 'pos':")
    for term, count in top_terms:
        print(f"{term}: {count}")
