In [1]:
import re
import numpy as np

In [2]:
import re
import numpy as np

def check_string(s):
    """Return True when ``s`` is exactly ``C`` followed by seven digits.

    Args:
        s: The string to test (for example ``"C0011849"``).

    Returns:
        bool: True only if the whole string has the form ``C`` + 7 digits.
    """
    # fullmatch instead of match + "$": "$" also matches just before a trailing
    # newline, so "C0011849\n" used to be accepted as a valid identifier.
    return re.fullmatch(r"C\d{7}", s) is not None

def check_string_DB(s):
    """Return True when ``s`` is exactly ``DB`` followed by five digits.

    Args:
        s: The string to test (for example ``"DB00316"``).

    Returns:
        bool: True only if the whole string has the form ``DB`` + 5 digits.
    """
    # fullmatch instead of match + "$": "$" also matches just before a trailing
    # newline, so "DB00316\n" used to be accepted as a valid identifier.
    return re.fullmatch(r"DB\d{5}", s) is not None


def extract_diseases_from_file(filename):
    """Collect the disease identifiers mentioned in a tab-separated triple file.

    Each line is stripped and split on tabs. The first field is checked with
    ``check_string``; if it does not match, the third field is checked instead.
    At most one identifier is recorded per line (the first field wins).

    Args:
        filename: Path of the text file to read.

    Returns:
        tuple: ``(diseases, lines)`` where ``diseases`` is the list of matching
        identifiers in file order (duplicates kept) and ``lines`` is the list
        of raw lines exactly as read, including their line endings.
    """
    with open(filename, 'r') as file:
        lines = file.readlines()

    diseases = []

    for line in lines:
        parts = line.strip().split('\t')
        if check_string(parts[0]):
            diseases.append(parts[0])
        # Blank or truncated lines have fewer than three fields; skip the tail
        # lookup for them instead of failing with IndexError.
        elif len(parts) > 2 and check_string(parts[2]):
            diseases.append(parts[2])

    return diseases, lines

def split_train_test_val(filename, random_diseases_set=None, test_percentage=0.15, val_percentage=0.15):
    """Split a tab-separated triple file into train, validation and test lines.

    Lines are expected to hold ``head<TAB>relation<TAB>tail``. A line that
    mentions a held-out disease in its head or tail is handled as follows:

    * relation ``THERAPY`` or ``TREAT``: the line goes to the test split;
    * relation ``COUNTERFACTUAL_POSITIVE`` / ``COUNTERFACTUAL_NEGATIVE`` with a
      ``DB#####`` identifier on either end: the line is dropped and counted;
    * anything else: the line stays in the training pool.

    Lines that mention no held-out disease stay in the training pool, as do
    lines with fewer than three fields. A random ``val_percentage`` share of
    the training pool is then moved to the validation split.

    Args:
        filename: Path of the input triple file.
        random_diseases_set: Collection of disease identifiers to hold out. When
            None, ``test_percentage`` of the unique diseases found in the file
            is drawn at random without replacement.
        test_percentage: Fraction of unique diseases to hold out; only used
            when ``random_diseases_set`` is None.
        val_percentage: Fraction of the training pool moved to validation.

    Returns:
        tuple: ``(train_lines, validation_lines, test_lines,
        counterfactual_lines_ct)``. The three lists hold raw lines, and the
        counter is the number of dropped counterfactual lines, so the four
        sizes add up to the number of lines in the input file.
    """
    # Step 1: Extract diseases and lines
    diseases, lines = extract_diseases_from_file(filename)

    # Step 2: Randomly select test diseases if a set is not provided
    if random_diseases_set is None:
        # Sort before sampling: iteration order of a set of strings changes
        # between interpreter runs, which made the draw irreproducible even
        # with a fixed NumPy seed.
        unique_diseases = sorted(set(diseases))
        num_test_diseases = int(len(unique_diseases) * test_percentage)
        random_diseases_set = set(
            np.random.choice(unique_diseases, size=num_test_diseases, replace=False).tolist()
        )
    elif not isinstance(random_diseases_set, (set, frozenset)):
        # Membership is tested twice per line; make sure it is a hash lookup.
        random_diseases_set = set(random_diseases_set)

    test_lines = []
    train_lines = []
    counterfactual_lines_ct = 0

    # Step 3: Split lines into test and train based on the held-out diseases
    for line in lines:
        parts = line.strip().split('\t')
        if len(parts) < 3:
            # Blank/truncated line: it cannot reference a relation or tail, so
            # keep it in the training pool (preserves the line-count invariant)
            # instead of failing with IndexError.
            train_lines.append(line)
            continue

        head, relation, tail = parts[0], parts[1], parts[2]
        if head in random_diseases_set or tail in random_diseases_set:
            if relation in {'THERAPY', 'TREAT'}:
                test_lines.append(line)
            elif relation in {'COUNTERFACTUAL_POSITIVE', 'COUNTERFACTUAL_NEGATIVE'}:
                if check_string_DB(head) or check_string_DB(tail):
                    # Remove augmented lines close to ground truth
                    counterfactual_lines_ct += 1
                else:
                    # Counterfactual lines without a DB identifier stay in train
                    train_lines.append(line)
            else:
                train_lines.append(line)
        else:
            train_lines.append(line)

    # Step 4: Split train lines into train and validation sets
    num_val_lines = int(len(train_lines) * val_percentage)
    val_indices = np.random.choice(len(train_lines), size=num_val_lines, replace=False).tolist()
    # A set gives O(1) membership; testing `i in <ndarray>` scans the whole
    # array for every training line.
    val_index_set = set(val_indices)

    validation_lines = [train_lines[i] for i in val_indices]
    train_lines = [line for i, line in enumerate(train_lines) if i not in val_index_set]

    return train_lines, validation_lines, test_lines, counterfactual_lines_ct

def write_lines_to_file(lines, filename):
    """Write the given strings to ``filename``, replacing any existing content.

    Args:
        lines: Iterable of strings. They are written back to back exactly as
            given; no line terminator is appended.
        filename: Path of the file to create or overwrite.
    """
    with open(filename, 'w') as sink:
        sink.writelines(lines)

In [3]:
def main(random_diseases_set=None, medium='nonaug'):
    """Split ``train_{medium}.txt`` and write the three resulting line files.

    The test, validation and train splits are written (in that order) to
    ``test_lines_{medium}.txt``, ``val_lines_{medium}.txt`` and
    ``train_lines_{medium}.txt``. Afterwards the input is read again and the
    split sizes are printed and checked against its line count.

    Args:
        random_diseases_set: Optional collection of disease identifiers to hold
            out; passed straight through to ``split_train_test_val``.
        medium: Tag inserted into the input and output file names.

    Raises:
        AssertionError: If the split sizes plus the dropped counterfactual
            lines do not add up to the number of input lines.
    """
    input_file = f'train_{medium}.txt'
    outputs = {
        'test': f'test_lines_{medium}.txt',
        'val': f'val_lines_{medium}.txt',
        'train': f'train_lines_{medium}.txt',
    }

    # Split the train, validation, and test lines, allowing for a provided random diseases set
    train_lines, val_lines, test_lines, counterfactual_lines_ct = split_train_test_val(
        input_file, random_diseases_set
    )

    # Persist each split to its own file
    splits = {'test': test_lines, 'val': val_lines, 'train': train_lines}
    for split_name in ('test', 'val', 'train'):
        write_lines_to_file(splits[split_name], outputs[split_name])

    # Check if split is consistent: re-read the input for an independent line count
    with open(input_file, 'r') as source:
        total_lines = len(source.readlines())

    for count in (total_lines, len(train_lines), len(val_lines), len(test_lines), counterfactual_lines_ct):
        print(count)

    accounted_for = len(test_lines) + len(train_lines) + len(val_lines) + counterfactual_lines_ct
    assert accounted_for == total_lines, "Mismatch in line counts!"

if __name__ == '__main__':
    # Load the predefined held-out disease identifiers: one entry per line of
    # 'random_diseases_set.txt', surrounding whitespace removed.
    with open('random_diseases_set.txt', 'r') as file:
        random_diseases_set = {entry.strip() for entry in file}

    # Run the split for the 'propagation' data with that fixed held-out set.
    main(random_diseases_set, medium='propagation')


3751235
3059659
539939
10204
141433


In [1]:
#### NEGATIVE SAMPLING ####

import random

def generate_negative_triples(positive_triples, max_stalled_rounds=100):
    """Create negative triples by corrupting the head or tail of positive ones.

    For every positive ``(head, relation, tail)`` either the head is replaced by
    a head drawn from the positives, or the tail by a tail drawn from the
    positives (each with probability one half). Replacements are drawn from the
    full head/tail lists, so frequent entities are drawn more often. A candidate
    is kept only if it is not a positive triple and has not been generated yet.
    Passes over the positives are repeated until there are as many negatives as
    positives.

    Args:
        positive_triples: List of hashable 3-tuples ``(head, relation, tail)``.
        max_stalled_rounds: Number of consecutive passes that may add no new
            negative before the search is abandoned. Without this limit the
            function never returns when the input cannot yield enough distinct
            negatives (for example a single positive triple).

    Returns:
        list: Distinct negative triples, in arbitrary order. Normally its length
        equals ``len(positive_triples)``; it is shorter only when the search was
        abandoned, in which case a ``RuntimeWarning`` is issued.
    """
    import warnings

    target_count = len(positive_triples)

    # Extract heads, relations, and tails from positive triples
    heads = [triple[0] for triple in positive_triples]
    relations = [triple[1] for triple in positive_triples]
    tails = [triple[2] for triple in positive_triples]

    # Set of positive triples for quick lookup
    positive_set = set(positive_triples)
    negative_triples = set()
    stalled_rounds = 0

    while len(negative_triples) < target_count:
        count_before_round = len(negative_triples)

        for head, relation, tail in zip(heads, relations, tails):
            # Decide randomly whether to change the head or the tail
            if random.random() > 0.5:
                candidate = (random.choice(heads), relation, tail)
            else:
                candidate = (head, relation, random.choice(tails))

            # Keep it only if it is not a known positive (the set drops repeats)
            if candidate not in positive_set:
                negative_triples.add(candidate)
                if len(negative_triples) >= target_count:
                    break

        if len(negative_triples) > count_before_round:
            stalled_rounds = 0
            continue

        # A whole pass produced nothing new; give up after too many in a row
        # rather than spinning forever on an exhausted candidate space.
        stalled_rounds += 1
        if stalled_rounds >= max_stalled_rounds:
            warnings.warn(
                f"Only {len(negative_triples)} of {target_count} negative triples "
                f"could be generated; no new candidate was found in "
                f"{stalled_rounds} consecutive passes.",
                RuntimeWarning,
                stacklevel=2,
            )
            break

    return list(negative_triples)

def read_positive_triples(file_path):
    """Read ``head<TAB>relation<TAB>tail`` triples from a text file.

    Args:
        file_path: Path of the file to read, one triple per line.

    Returns:
        list: ``(head, relation, tail)`` tuples of strings in file order.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields; the message names the file and line number.
    """
    positive_triples = []
    with open(file_path, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. an extra newline at the end of the file)
                # carry no triple; previously they aborted the whole read.
                continue

            fields = stripped.split('\t')
            if len(fields) != 3:
                raise ValueError(
                    f"{file_path}:{line_number}: expected 3 tab-separated fields, "
                    f"got {len(fields)}"
                )
            positive_triples.append(tuple(fields))
    return positive_triples

def write_triples_to_file(output_path, triples):
    """Write triples to ``output_path`` as tab-separated lines.

    Args:
        output_path: Path of the file to create or overwrite.
        triples: Iterable of indexable items; elements 0, 1 and 2 of each item
            are written on one line, separated by tabs and ended by a newline.
    """
    with open(output_path, 'w') as sink:
        sink.writelines(
            f"{entry[0]}\t{entry[1]}\t{entry[2]}\n" for entry in triples
        )

def generate_and_save_triples(input_file, output_file):
    """Read positive triples, add generated negatives, shuffle, and save.

    Args:
        input_file: Path of the tab-separated file holding the positive triples.
        output_file: Path that receives the shuffled mix of positive and
            negative triples, one tab-separated triple per line.

    NOTE(review): the output has no label column, so positives and negatives
    cannot be told apart from the written file alone — confirm that downstream
    code recovers the labels some other way.
    """
    # Load the positives and derive negatives from them
    positives = read_positive_triples(input_file)
    negatives = generate_negative_triples(positives)

    # Mix both groups in random order before writing them out
    mixed = [*positives, *negatives]
    random.shuffle(mixed)
    write_triples_to_file(output_file, mixed)


In [8]:
# Build the mixed positive/negative file for the held-out test triples of one run.
# `mode` selects both the directory and the file-name tag.
mode = 'propagation'
# NOTE(review): main() writes test_lines_{medium}.txt into the working directory,
# while this cell reads it from '{mode}_train_set/' — confirm the file is moved
# there (or adjust one of the paths) before running.
input_file = f'{mode}_train_set/test_lines_{mode}.txt'
# Despite the '_neg' suffix, the output holds the positives together with the
# generated negatives, in shuffled order.
output_file = f'{mode}_train_set/test_lines_{mode}_neg.txt'

generate_and_save_triples(input_file, output_file)
