In [4]:
file_path = '/data/luis/hgb/LP/benchmark/data/CellDrug/node.dat'
line_count = 0
print(f"Checking file: {file_path}")
try:
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_number, line in enumerate(f):
            line = line.strip() # Remove leading/trailing whitespace including newline
            if not line: # Skip empty lines
                print(f"Line {line_number + 1}: Is empty, skipping.")
                continue

            parts = line.split('\t')
            num_parts = len(parts)
            print(f"Line {line_number + 1}: '{line}' -> {num_parts} parts when split by tab.")

            if num_parts < 3:
                print(f"  ERROR: Line {line_number + 1} has fewer than 3 parts when split by tab!")

            line_count += 1
            if line_count >= 20: # Check first 20 non-empty lines
                print("Reached 20 non-empty lines, stopping check.")
                break
    if line_count == 0:
        print("File is empty or contains only empty lines.")

except FileNotFoundError:
    print(f"ERROR: File not found at {file_path}")
except Exception as e:
    print(f"An error occurred: {e}")

Checking file: /data/luis/hgb/LP/benchmark/data/CellDrug/node.dat
Line 1: '0	0	0' -> 3 parts when split by tab.
Line 2: '1	1	0' -> 3 parts when split by tab.
Line 3: '2	2	0' -> 3 parts when split by tab.
Line 4: '3	3	0' -> 3 parts when split by tab.
Line 5: '4	4	0' -> 3 parts when split by tab.
Line 6: '6	6	0' -> 3 parts when split by tab.
Line 7: '7	7	0' -> 3 parts when split by tab.
Line 8: '9	9	0' -> 3 parts when split by tab.
Line 9: '10	10	0' -> 3 parts when split by tab.
Line 10: '11	11	0' -> 3 parts when split by tab.
Line 11: '12	12	0' -> 3 parts when split by tab.
Line 12: '14	14	0' -> 3 parts when split by tab.
Line 13: '15	15	0' -> 3 parts when split by tab.
Line 14: '16	16	0' -> 3 parts when split by tab.
Line 15: '17	17	0' -> 3 parts when split by tab.
Line 16: '18	18	0' -> 3 parts when split by tab.
Line 17: '19	19	0' -> 3 parts when split by tab.
Line 18: '20	20	0' -> 3 parts when split by tab.
Line 19: '21	21	0' -> 3 parts when split by tab.
Line 20: '22	22	0' -> 3 part

In [6]:
import os

# --- Configuration: Please verify these paths ---
# Path to your custom dataset directory
dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/' # Make sure this is correct

original_node_dat_path = os.path.join(dataset_dir, 'node.dat')
corrected_node_dat_path = os.path.join(dataset_dir, 'node_corrected.dat')
# --- End Configuration ---

def create_minimal_3_column_node_dat(original_file, corrected_file):
    print(f"Processing original node.dat: {original_file}")
    print(f"A new 3-column node.dat will be saved to: {corrected_file}")
    print(f"The middle 'name' column will be an empty string.")

    lines_processed = 0
    lines_transformed_to_3_cols = 0
    lines_already_3_or_more_cols = 0
    problematic_lines = 0

    try:
        with open(original_file, 'r', encoding='utf-8') as infile, \
             open(corrected_file, 'w', encoding='utf-8') as outfile:
            
            for line_number, line_content in enumerate(infile):
                line_content_stripped = line_content.strip()
                if not line_content_stripped:  # Skip truly empty lines
                    outfile.write('\n') 
                    continue

                parts = line_content_stripped.split('\t')
                num_parts = len(parts)

                if num_parts == 2:
                    node_id_str, node_type_id_str = parts[0], parts[1]
                    
                    # Basic validation that parts look like IDs
                    if not node_id_str.strip().isdigit() or not node_type_id_str.strip().isdigit():
                        print(f"  WARNING: Line {line_number + 1} - 2-part line with non-digit ID/Type_ID: '{line_content_stripped}'. Writing as is (may cause errors later).")
                        outfile.write(line_content_stripped + '\n')
                        problematic_lines +=1
                        continue
                        
                    # Use an empty string for the placeholder node_name
                    placeholder_node_name = "" 
                    
                    new_line = f"{node_id_str}\t{placeholder_node_name}\t{node_type_id_str}\n"
                    outfile.write(new_line)
                    lines_transformed_to_3_cols += 1
                elif num_parts >= 3:
                    # Line already has 3 or more parts, assume it's correct or has features
                    outfile.write(line_content_stripped + '\n')
                    lines_already_3_or_more_cols += 1
                else:
                    # Line has 0 or 1 part after stripping, which is problematic
                    print(f"  WARNING: Line {line_number + 1} has unexpected format (parts={num_parts}): '{line_content_stripped}'. Writing as is (will likely cause errors).")
                    outfile.write(line_content_stripped + '\n')
                    problematic_lines +=1
                
                lines_processed += 1
        
        print("\nProcessing Summary:")
        print(f"Total lines processed (non-empty): {lines_processed}")
        print(f"Lines transformed (2 parts -> 3 parts with empty name): {lines_transformed_to_3_cols}")
        print(f"Lines unchanged (already >=3 parts): {lines_already_3_or_more_cols}")
        print(f"Problematic lines (fewer than 2 parts or non-digit IDs, written as is): {problematic_lines}")
        print(f"Corrected file saved to: {corrected_file}")
        return True

    except FileNotFoundError:
        print(f"ERROR: Original node.dat not found at {original_file}")
        return False
    except Exception as e:
        print(f"An error occurred during processing: {e}")
        return False

if __name__ == '__main__':
    if os.path.exists(original_node_dat_path):
        success = create_minimal_3_column_node_dat(original_node_dat_path, corrected_node_dat_path)
        if success:
            print("\nIMPORTANT: Please review the 'node_corrected.dat' file.")
            print("Each line should now have 3 tab-separated columns.")
            print("For lines that originally had 2 columns, the middle 'node_name' column should now be an EMPTY STRING.")
            print(f"If it looks correct, backup your original 'node.dat' and then rename/copy 'node_corrected.dat' to 'node.dat':")
            print(f"cd {os.path.dirname(original_node_dat_path)}") # Example command to navigate
            print(f"mv node.dat node_original.dat")
            print(f"mv node_corrected.dat node.dat")
        else:
            print("\nScript finished with errors. Please check messages above.")
    else:
        print(f"Error: Input node.dat file not found at {original_node_dat_path}")

Processing original node.dat: /data/luis/hgb/LP/benchmark/data/CellDrug/node.dat
A new 3-column node.dat will be saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/node_corrected.dat
The middle 'name' column will be an empty string.

Processing Summary:
Total lines processed (non-empty): 23815
Lines transformed (2 parts -> 3 parts with empty name): 23815
Lines unchanged (already >=3 parts): 0
Problematic lines (fewer than 2 parts or non-digit IDs, written as is): 0
Corrected file saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/node_corrected.dat

IMPORTANT: Please review the 'node_corrected.dat' file.
Each line should now have 3 tab-separated columns.
For lines that originally had 2 columns, the middle 'node_name' column should now be an EMPTY STRING.
If it looks correct, backup your original 'node.dat' and then rename/copy 'node_corrected.dat' to 'node.dat':
cd /data/luis/hgb/LP/benchmark/data/CellDrug
mv node.dat node_original.dat
mv node_corrected.dat node.dat


In [7]:
import os

# --- Configuration: Please verify this path ---
dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/' # Make sure this is correct
link_dat_path = os.path.join(dataset_dir, 'link.dat')
# --- End Configuration ---

def check_link_dat_format(file_path):
    print(f"Checking file: {file_path}")
    line_count = 0
    all_lines_ok = True
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_number, line_content in enumerate(f):
                line_content_stripped = line_content.strip()
                if not line_content_stripped: # Skip empty lines
                    print(f"Line {line_number + 1}: Is empty, skipping.")
                    continue

                parts = line_content_stripped.split('\t')
                num_parts = len(parts)
                
                current_line_ok = True
                if num_parts != 4:
                    print(f"  ERROR: Line {line_number + 1} has {num_parts} parts (expected 4) -> '{line_content_stripped}'")
                    all_lines_ok = False
                    current_line_ok = False
                else:
                    # Check if IDs and type are integers, and weight is float
                    try:
                        int(parts[0]) # head_node_id
                        int(parts[1]) # tail_node_id
                        int(parts[2]) # relation_type_id
                        float(parts[3]) # link_weight
                    except ValueError:
                        print(f"  ERROR: Line {line_number + 1} - type conversion failed for IDs/weight -> '{line_content_stripped}'")
                        all_lines_ok = False
                        current_line_ok = False
                
                if current_line_ok and line_number < 5 : # Print first few good lines
                     print(f"Line {line_number + 1}: OK -> '{line_content_stripped}'")

                line_count += 1
                if line_count >= 20 and not all_lines_ok: # Stop early if errors are found
                    print("Stopping check early due to errors.")
                    break
                elif line_count >= 50 and all_lines_ok: # Check more lines if no errors yet
                    print("Checked first 50 lines, all seem ok so far. Stopping check.")
                    break
        
        if line_count == 0:
            print("File is empty or contains only empty lines.")
        elif all_lines_ok and line_count > 0 :
             print(f"\nChecked {line_count} lines. All lines appear to have the correct 4-column tab-separated format and basic data types.")
        else:
            print(f"\nFound formatting issues in {file_path}. Please review the ERROR messages above.")

    except FileNotFoundError:
        print(f"ERROR: File not found at {file_path}")
    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == '__main__':
    check_link_dat_format(link_dat_path)
    # If you have a link.dat.test, you can check it too:
    # test_link_dat_path = os.path.join(dataset_dir, 'link.dat.test')
    # if os.path.exists(test_link_dat_path):
    #     print("\n--- Checking link.dat.test ---")
    #     check_link_dat_format(test_link_dat_path)

Checking file: /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat
Line 1: OK -> '0	20646	1	1.0'
Line 2: OK -> '0	20915	1	1.0'
Line 3: OK -> '0	20878	1	1.0'
Line 4: OK -> '0	20302	1	1.0'
Line 5: OK -> '0	20902	1	1.0'
Checked first 50 lines, all seem ok so far. Stopping check.

Checked 50 lines. All lines appear to have the correct 4-column tab-separated format and basic data types.


In [8]:
import os

dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/'
link_file_path = os.path.join(dataset_dir, 'link.dat')
max_allowed_id = 23814 # N-1, where N is self.nodes['total']

print(f"Checking link file: {link_file_path} for IDs > {max_allowed_id}")
problem_found = False
try:
    with open(link_file_path, 'r', encoding='utf-8') as f:
        for line_number, line_content in enumerate(f):
            line_content_stripped = line_content.strip()
            if not line_content_stripped:
                continue

            parts = line_content_stripped.split('\t')
            if len(parts) == 4:
                try:
                    head_id = int(parts[0])
                    tail_id = int(parts[1])

                    if head_id > max_allowed_id:
                        print(f"  ERROR: Line {line_number + 1} - Head ID {head_id} > {max_allowed_id}. Line: '{line_content_stripped}'")
                        problem_found = True
                    if tail_id > max_allowed_id:
                        print(f"  ERROR: Line {line_number + 1} - Tail ID {tail_id} > {max_allowed_id}. Line: '{line_content_stripped}'")
                        problem_found = True
                except ValueError:
                    print(f"  WARNING: Line {line_number + 1} - Could not parse IDs as integers: '{line_content_stripped}'")
                    problem_found = True # Treat as problem
            else:
                # This shouldn't happen if your link.dat format is correct from previous checks
                print(f"  WARNING: Line {line_number + 1} - Does not have 4 parts: '{line_content_stripped}'")
                problem_found = True 
    if not problem_found:
        print("No out-of-bounds node IDs found in link.dat based on max_allowed_id.")

except FileNotFoundError:
    print(f"ERROR: File not found at {link_file_path}")
except Exception as e:
    print(f"An error occurred: {e}")

Checking link file: /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat for IDs > 23814
  ERROR: Line 710867 - Tail ID 24219 > 23814. Line: '86	24219	2	1.0'
  ERROR: Line 710868 - Tail ID 25884 > 23814. Line: '123	25884	2	1.0'
  ERROR: Line 710869 - Tail ID 26616 > 23814. Line: '123	26616	2	1.0'
  ERROR: Line 710870 - Tail ID 26616 > 23814. Line: '123	26616	2	1.0'
  ERROR: Line 710871 - Tail ID 23861 > 23814. Line: '123	23861	2	1.0'
  ERROR: Line 710872 - Tail ID 25865 > 23814. Line: '123	25865	2	1.0'
  ERROR: Line 710873 - Tail ID 24806 > 23814. Line: '123	24806	2	1.0'
  ERROR: Line 710874 - Tail ID 25852 > 23814. Line: '123	25852	2	1.0'
  ERROR: Line 710875 - Tail ID 26037 > 23814. Line: '123	26037	2	1.0'
  ERROR: Line 710876 - Tail ID 25700 > 23814. Line: '123	25700	2	1.0'
  ERROR: Line 710877 - Tail ID 26765 > 23814. Line: '123	26765	2	1.0'
  ERROR: Line 710879 - Tail ID 23830 > 23814. Line: '129	23830	2	1.0'
  ERROR: Line 710881 - Tail ID 26238 > 23814. Line: '161	26238	2	1.0'
  ER

In [9]:
import os

# --- Configuration ---
dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/' # Your dataset directory
link_file_path = os.path.join(dataset_dir, 'link.dat')
# This value comes from the error message: "matrix dimension 23815", so max ID is 23815 - 1
max_allowed_node_id = 23814
# --- End Configuration ---

def count_problematic_links(file_path, max_id):
    print(f"Analyzing link file: {file_path}")
    print(f"Checking for node IDs > {max_id}")

    total_lines_in_file = 0           # Total lines read (including empty/malformed)
    valid_format_links = 0            # Links that have 4 columns and parseable IDs
    problematic_links_count = 0       # Links with out-of-bounds IDs
    lines_with_format_issues = 0      # Lines that don't have 4 columns or unparseable IDs
    first_n_problems_to_show = 20     # How many problem examples to print
    problems_shown_count = 0

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_number, line_content in enumerate(f):
                total_lines_in_file += 1
                line_content_stripped = line_content.strip()
                
                if not line_content_stripped:
                    if problems_shown_count < first_n_problems_to_show:
                        # print(f"  INFO: Line {line_number + 1} is empty.")
                        pass # Optionally report empty lines
                    continue

                parts = line_content_stripped.split('\t')
                
                if len(parts) != 4:
                    lines_with_format_issues += 1
                    if problems_shown_count < first_n_problems_to_show:
                        print(f"  FORMAT ERROR: Line {line_number + 1} has {len(parts)} parts (expected 4) -> '{line_content_stripped}'")
                        problems_shown_count += 1
                    continue # Skip to next line if format is wrong

                valid_format_links += 1 # If we reach here, it has 4 parts
                
                try:
                    head_id = int(parts[0])
                    tail_id = int(parts[1])
                    # int(parts[2]) # relation_type_id - can also validate if needed
                    # float(parts[3]) # link_weight - can also validate if needed
                    
                    is_problematic = False
                    if head_id > max_id:
                        if problems_shown_count < first_n_problems_to_show:
                            print(f"  ID ERROR: Line {line_number + 1} - Head ID {head_id} > {max_id}. Line: '{line_content_stripped}'")
                            problems_shown_count += 1
                        is_problematic = True
                    if tail_id > max_id:
                        if problems_shown_count < first_n_problems_to_show and not is_problematic: # Avoid double printing for same line if both are bad
                             print(f"  ID ERROR: Line {line_number + 1} - Tail ID {tail_id} > {max_id}. Line: '{line_content_stripped}'")
                             problems_shown_count +=1
                        elif problems_shown_count < first_n_problems_to_show and head_id <= max_id : # if head was fine print tail error
                             print(f"  ID ERROR: Line {line_number + 1} - Tail ID {tail_id} > {max_id} (Head ID {head_id} was OK). Line: '{line_content_stripped}'")
                             problems_shown_count +=1
                        is_problematic = True
                    
                    if is_problematic:
                        problematic_links_count += 1
                        
                except ValueError:
                    lines_with_format_issues += 1
                    valid_format_links -=1 # It had 4 parts, but IDs were not int
                    if problems_shown_count < first_n_problems_to_show:
                        print(f"  TYPE ERROR: Line {line_number + 1} - Could not parse IDs as integers: '{line_content_stripped}'")
                        problems_shown_count += 1
                    continue
        
        print("\n--- Analysis Summary ---")
        print(f"Total lines read from file: {total_lines_in_file}")
        print(f"Lines with potential link data (4 tab-separated columns): {valid_format_links}")
        print(f"Lines with formatting issues (not 4 columns or non-integer IDs where expected): {lines_with_format_issues}")
        print(f"Number of validly formatted links containing out-of-bounds node IDs (>{max_id}): {problematic_links_count}")
        
        if valid_format_links > 0:
            percentage_problematic = (problematic_links_count / valid_format_links) * 100
            print(f"Percentage of validly formatted links that are problematic: {percentage_problematic:.2f}%")
        else:
            print("No validly formatted links found to calculate percentage.")

    except FileNotFoundError:
        print(f"ERROR: File not found at {file_path}")
    except Exception as e:
        print(f"An error occurred: {e}")

if __name__ == '__main__':
    if os.path.exists(link_file_path):
        count_problematic_links(link_file_path, max_allowed_node_id)
        print("\nBased on this count, you can decide if filtering is acceptable or if re-generating your data is necessary.")
    else:
        print(f"Error: Input link.dat file not found at {link_file_path}")

Analyzing link file: /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat
Checking for node IDs > 23814
  ID ERROR: Line 710867 - Tail ID 24219 > 23814. Line: '86	24219	2	1.0'
  ID ERROR: Line 710868 - Tail ID 25884 > 23814. Line: '123	25884	2	1.0'
  ID ERROR: Line 710869 - Tail ID 26616 > 23814. Line: '123	26616	2	1.0'
  ID ERROR: Line 710870 - Tail ID 26616 > 23814. Line: '123	26616	2	1.0'
  ID ERROR: Line 710871 - Tail ID 23861 > 23814. Line: '123	23861	2	1.0'
  ID ERROR: Line 710872 - Tail ID 25865 > 23814. Line: '123	25865	2	1.0'
  ID ERROR: Line 710873 - Tail ID 24806 > 23814. Line: '123	24806	2	1.0'
  ID ERROR: Line 710874 - Tail ID 25852 > 23814. Line: '123	25852	2	1.0'
  ID ERROR: Line 710875 - Tail ID 26037 > 23814. Line: '123	26037	2	1.0'
  ID ERROR: Line 710876 - Tail ID 25700 > 23814. Line: '123	25700	2	1.0'
  ID ERROR: Line 710877 - Tail ID 26765 > 23814. Line: '123	26765	2	1.0'
  ID ERROR: Line 710879 - Tail ID 23830 > 23814. Line: '129	23830	2	1.0'
  ID ERROR: Line 710881

In [10]:
import os

# --- Configuration ---
dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/' # Your dataset directory
original_link_dat_path = os.path.join(dataset_dir, 'link.dat')
filtered_link_dat_path = os.path.join(dataset_dir, 'link_filtered.dat')
max_allowed_node_id = 23814 # This is self.nodes['total'] - 1 from the error
# --- End Configuration ---

def filter_links(original_file, filtered_file, max_id):
    print(f"Filtering link file: {original_file}")
    print(f"Keeping only links where both head and tail IDs are <= {max_id}")
    print(f"Filtered links will be saved to: {filtered_file}")

    lines_processed = 0
    lines_kept = 0
    lines_removed = 0
    
    try:
        with open(original_file, 'r', encoding='utf-8') as infile, \
             open(filtered_file, 'w', encoding='utf-8') as outfile:
            
            for line_number, line_content in enumerate(infile):
                line_content_stripped = line_content.strip()
                if not line_content_stripped:
                    outfile.write('\n')
                    continue
                
                lines_processed += 1
                parts = line_content_stripped.split('\t')
                
                if len(parts) == 4:
                    try:
                        head_id = int(parts[0])
                        tail_id = int(parts[1])
                        
                        if head_id <= max_id and tail_id <= max_id:
                            outfile.write(line_content_stripped + '\n')
                            lines_kept += 1
                        else:
                            lines_removed += 1
                            if lines_removed <= 20: # Print first few removed lines
                                print(f"  REMOVING Line {line_number + 1}: Contains out-of-bounds ID(s) -> '{line_content_stripped}'")
                    except ValueError:
                        print(f"  WARNING: Line {line_number + 1} - Could not parse IDs as integers. Skipping line: '{line_content_stripped}'")
                        lines_removed +=1 # Count as removed
                else:
                    print(f"  WARNING: Line {line_number + 1} - Does not have 4 parts. Skipping line: '{line_content_stripped}'")
                    lines_removed +=1 # Count as removed

        print("\nFiltering Summary:")
        print(f"Total lines processed (non-empty): {lines_processed}")
        print(f"Lines kept (valid IDs): {lines_kept}")
        print(f"Lines removed (invalid IDs or format): {lines_removed}")
        print(f"Filtered file saved to: {filtered_file}")
        return True

    except FileNotFoundError:
        print(f"ERROR: Original link.dat not found at {original_file}")
        return False
    except Exception as e:
        print(f"An error occurred during filtering: {e}")
        return False

if __name__ == '__main__':
    if os.path.exists(original_link_dat_path):
        success = filter_links(original_link_dat_path, filtered_link_dat_path, max_allowed_node_id)
        if success:
            print("\nIMPORTANT: Review 'link_filtered.dat'.")
            print(f"If it looks correct and you accept the data loss, backup your original 'link.dat' and rename/copy 'link_filtered.dat' to 'link.dat'.")
            print(f"cd {dataset_dir}")
            print(f"mv link.dat link_original.dat")
            print(f"mv link_filtered.dat link.dat")
    else:
        print(f"Error: Input link.dat file not found at {original_link_dat_path}")

Filtering link file: /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat
Keeping only links where both head and tail IDs are <= 23814
Filtered links will be saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/link_filtered.dat
  REMOVING Line 710867: Contains out-of-bounds ID(s) -> '86	24219	2	1.0'
  REMOVING Line 710868: Contains out-of-bounds ID(s) -> '123	25884	2	1.0'
  REMOVING Line 710869: Contains out-of-bounds ID(s) -> '123	26616	2	1.0'
  REMOVING Line 710870: Contains out-of-bounds ID(s) -> '123	26616	2	1.0'
  REMOVING Line 710871: Contains out-of-bounds ID(s) -> '123	23861	2	1.0'
  REMOVING Line 710872: Contains out-of-bounds ID(s) -> '123	25865	2	1.0'
  REMOVING Line 710873: Contains out-of-bounds ID(s) -> '123	24806	2	1.0'
  REMOVING Line 710874: Contains out-of-bounds ID(s) -> '123	25852	2	1.0'
  REMOVING Line 710875: Contains out-of-bounds ID(s) -> '123	26037	2	1.0'
  REMOVING Line 710876: Contains out-of-bounds ID(s) -> '123	25700	2	1.0'
  REMOVING Line 710877: Contains ou

In [16]:
import random
import os

dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/'
original_link_file = os.path.join(dataset_dir, 'link.dat') # Assuming you use the filtered one
new_train_link_file = os.path.join(dataset_dir, 'link.dat') # This will be the new training file
test_link_file = os.path.join(dataset_dir, 'link.dat.test')
target_relation_id = 0  # Assuming '0' is your cell-drug effective link
test_split_ratio = 0.1 # e.g., 10% for testing

all_links = []
target_links = []
other_links = []

with open(original_link_file, 'r') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) == 4 and int(parts[2]) == target_relation_id:
            target_links.append(line)
        else:
            other_links.append(line)

random.shuffle(target_links)
split_point = int(len(target_links) * test_split_ratio)
test_set_target_links = target_links[:split_point]
train_set_target_links = target_links[split_point:]

with open(new_train_link_file, 'w') as f:
    for line in other_links:
        f.write(line)
    for line in train_set_target_links:
        f.write(line)

with open(test_link_file, 'w') as f:
    for line in test_set_target_links:
        f.write(line)

print(f"Created {new_train_link_file} with {len(other_links) + len(train_set_target_links)} links.")
print(f"Created {test_link_file} with {len(test_set_target_links)} links.")
# Remember to backup your original link.dat before overwriting

Created /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat with 716544 links.
Created /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat.test with 478 links.


In [13]:
import os
import json

# --- Configuration ---
dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/' # Make sure this is correct
original_node_dat_path = os.path.join(dataset_dir, 'node.dat') # Your current node.dat
info_dat_path = os.path.join(dataset_dir, 'info.dat') # Needed to know valid type_ids

# Output files
reindexed_node_dat_path = os.path.join(dataset_dir, 'node_reindexed.dat')
id_map_path = os.path.join(dataset_dir, 'original_to_new_id_map.json')
# --- End Configuration ---

def reindex_nodes(original_file, info_file, reindexed_file, map_file):
    print(f"Re-indexing nodes from: {original_file}")
    
    node_data = [] # List to store (original_id_int, original_type_id_int, original_line_parts)
    valid_type_ids = set()

    # Try to load node type information from info.dat to validate type_ids
    try:
        with open(info_file, 'r', encoding='utf-8') as f:
            info_data = json.load(f)
        # Assuming info.dat structure like: {"node.dat": {"0": "name0", "1": "name1"}}
        if "node.dat" in info_data and isinstance(info_data["node.dat"], dict):
            valid_type_ids = set(info_data["node.dat"].keys())
            print(f"Valid node type IDs from info.dat: {valid_type_ids}")
        else:
            print(f"Warning: Could not parse node type IDs from {info_file}. Type ID validation will be limited.")
    except Exception as e:
        print(f"Warning: Could not load or parse {info_file}: {e}. Type ID validation will be limited.")

    print("Reading original node.dat...")
    try:
        with open(original_file, 'r', encoding='utf-8') as infile:
            for line_number, line_content in enumerate(infile):
                line_content_stripped = line_content.strip()
                if not line_content_stripped:
                    continue

                parts = line_content_stripped.split('\t')
                
                if len(parts) < 2:
                    print(f"  WARNING: Line {line_number + 1} has < 2 parts: '{line_content_stripped}'. Skipping.")
                    continue
                
                original_id_str = parts[0]
                # If original file was 2-col: parts[1] is type. If 3-col: parts[2] is type.
                # Assuming the problematic files had type as the last column before placeholder name.
                original_type_id_str = parts[-1] # Takes the last element, assuming it's the type

                try:
                    original_id_int = int(original_id_str)
                    original_type_id_int = int(original_type_id_str)
                    if valid_type_ids and original_type_id_str not in valid_type_ids:
                         print(f"  WARNING: Line {line_number + 1} - Type ID '{original_type_id_str}' not in info.dat valid types. Original: '{line_content_stripped}'")
                    node_data.append({'original_id': original_id_int, 'type_id': original_type_id_int, 'original_parts': parts})
                except ValueError:
                    print(f"  WARNING: Line {line_number + 1} - Cannot parse ID/Type as int: '{line_content_stripped}'. Skipping.")
                    continue
        print(f"Read {len(node_data)} valid node entries.")

    except FileNotFoundError:
        print(f"ERROR: Original node.dat not found at {original_file}")
        return False
    except Exception as e:
        print(f"An error occurred reading {original_file}: {e}")
        return False

    if not node_data:
        print("No node data to process.")
        return False

    # Sort nodes: primarily by type_id, secondarily by original_id
    print("Sorting node data...")
    node_data.sort(key=lambda x: (x['type_id'], x['original_id']))

    print("Assigning new global IDs and creating mapping...")
    original_to_new_id_map = {}
    current_new_global_id = 0
    
    with open(reindexed_file, 'w', encoding='utf-8') as outfile:
        for node_entry in node_data:
            original_id = node_entry['original_id']
            type_id = node_entry['type_id']
            original_parts = node_entry['original_parts']

            original_to_new_id_map[str(original_id)] = current_new_global_id # Store original (as string, like in files) to new mapping

            placeholder_name = "" # Using empty string as per your preference
            if len(original_parts) >= 2 and original_parts[0] != original_parts[-1]: # If there was an actual name in original middle column
                 # Check if the original middle column was different from original ID AND not the type ID
                 # This heuristic might need adjustment based on your *actual* original 3-column file structure.
                 # For now, if original was 2-col, parts[1] is type. if 3-col, parts[1] is name.
                 if len(original_parts) == 3: # original was ID / NAME / TYPE
                     placeholder_name = original_parts[1]
                 elif len(original_parts) == 2: # original was ID / TYPE
                     placeholder_name = "" # Keep it empty as decided

            outfile.write(f"{current_new_global_id}\t{placeholder_name}\t{type_id}\n")
            current_new_global_id += 1
            
    print(f"Re-indexed {current_new_global_id} nodes saved to: {reindexed_file}")

    try:
        with open(map_file, 'w', encoding='utf-8') as f_map:
            json.dump(original_to_new_id_map, f_map, indent=4)
        print(f"Original ID to New ID map saved to: {map_file}")
    except Exception as e:
        print(f"Error saving ID map file: {e}")
        return False
        
    return True

if __name__ == '__main__':
    if os.path.exists(original_node_dat_path) and os.path.exists(info_dat_path):
        success = reindex_nodes(original_node_dat_path, info_dat_path, reindexed_node_dat_path, id_map_path)
        if success:
            print("\nIMPORTANT STEPS NEXT:")
            print(f"1. Review '{reindexed_node_dat_path}'. It should have contiguous global IDs starting from 0, grouped by type.")
            print(f"2. If it looks correct, backup your original 'node.dat' and rename '{os.path.basename(reindexed_node_dat_path)}' to 'node.dat'.")
            print(f"   cd {dataset_dir}")
            print(f"   mv node.dat node_very_original.dat")
            print(f"   mv node_reindexed.dat node.dat")
            print(f"3. **CRITICAL:** You MUST now update your 'link.dat' (and 'link.dat.test') to use these new global IDs.")
            print(f"   Use the map saved in '{id_map_path}' for this. This script does NOT update link.dat.")
        else:
            print("\nScript finished with errors. Please check messages above.")
    else:
        if not os.path.exists(original_node_dat_path):
            print(f"Error: Input node.dat file not found at {original_node_dat_path}")
        if not os.path.exists(info_dat_path):
            print(f"Error: Input info.dat file not found at {info_dat_path}")

Re-indexing nodes from: /data/luis/hgb/LP/benchmark/data/CellDrug/node.dat
Valid node type IDs from info.dat: {'2', '1', '0'}
Reading original node.dat...
Read 23815 valid node entries.
Sorting node data...
Assigning new global IDs and creating mapping...
Re-indexed 23815 nodes saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/node_reindexed.dat
Original ID to New ID map saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/original_to_new_id_map.json

IMPORTANT STEPS NEXT:
1. Review '/data/luis/hgb/LP/benchmark/data/CellDrug/node_reindexed.dat'. It should have contiguous global IDs starting from 0, grouped by type.
2. If it looks correct, backup your original 'node.dat' and rename 'node_reindexed.dat' to 'node.dat'.
   cd /data/luis/hgb/LP/benchmark/data/CellDrug/
   mv node.dat node_very_original.dat
   mv node_reindexed.dat node.dat
3. **CRITICAL:** You MUST now update your 'link.dat' (and 'link.dat.test') to use these new global IDs.
   Use the map saved in '/data/luis/hgb/LP/bench

In [15]:
import os
import json

# --- Configuration ---
dataset_dir = '/data/luis/hgb/LP/benchmark/data/CellDrug/' # Your dataset directory

# IMPORTANT: This should be your link file that STILL USES THE ORIGINAL NODE IDs
# For example, 'link_original.dat' (if you backed up before filtering) or 
# 'link_filtered.dat' (if you filtered it but it still uses original IDs).
# If your current link.dat still has the old IDs, that's fine.
# DO NOT use a link.dat that you *think* might have new IDs already.
input_link_dat_path = os.path.join(dataset_dir, '/data/luis/hgb/LP/benchmark/data/CellDrug/link.dat') # Or 'link_filtered.dat' or 'link_before_reindex.dat' - THE ONE WITH OLD IDs

id_map_path = os.path.join(dataset_dir, 'original_to_new_id_map.json') # Generated by reindex_nodes.py
output_reindexed_link_dat_path = os.path.join(dataset_dir, 'link_reindexed.dat') # New output file
# --- End Configuration ---

def update_link_ids(original_link_file, id_map_file, reindexed_link_file):
    print(f"Loading ID map from: {id_map_file}")
    try:
        with open(id_map_file, 'r', encoding='utf-8') as f:
            id_map = json.load(f) # Loads original string IDs to new integer IDs
            # Example: id_map might look like {"old_id_str_0": new_id_int_0, "old_id_str_1": new_id_int_1, ...}
        print("ID map loaded successfully.")
    except FileNotFoundError:
        print(f"ERROR: ID map file not found at {id_map_file}. "
              "Please ensure 'reindex_nodes.py' was run and 'original_to_new_id_map.json' was created.")
        return False
    except Exception as e:
        print(f"ERROR: Could not load or parse ID map file {id_map_file}: {e}")
        return False

    print(f"\nUpdating links from: {original_link_file}")
    print(f"Re-indexed links will be saved to: {reindexed_link_file}")

    lines_processed = 0
    lines_updated = 0
    lines_skipped_due_missing_id_in_map = 0
    lines_with_format_issues = 0
    first_n_problems_to_show = 20
    problems_shown_count = 0

    try:
        with open(original_link_file, 'r', encoding='utf-8') as infile, \
             open(reindexed_link_file, 'w', encoding='utf-8') as outfile:
            
            for line_number, line_content in enumerate(infile):
                line_content_stripped = line_content.strip()
                if not line_content_stripped:
                    outfile.write('\n')
                    continue
                
                lines_processed += 1
                parts = line_content_stripped.split('\t')

                if len(parts) == 4:
                    original_head_id_str = parts[0]
                    original_tail_id_str = parts[1]
                    relation_type_id_str = parts[2]
                    link_weight_str = parts[3]

                    # Check if original IDs are in the map
                    if original_head_id_str in id_map and original_tail_id_str in id_map:
                        new_head_id = id_map[original_head_id_str]
                        new_tail_id = id_map[original_tail_id_str]
                        
                        outfile.write(f"{new_head_id}\t{new_tail_id}\t{relation_type_id_str}\t{link_weight_str}\n")
                        lines_updated += 1
                    else:
                        lines_skipped_due_missing_id_in_map += 1
                        if problems_shown_count < first_n_problems_to_show:
                            if original_head_id_str not in id_map:
                                print(f"  SKIPPING Line {line_number + 1}: Original HEAD ID '{original_head_id_str}' not found in map. Line: '{line_content_stripped}'")
                            if original_tail_id_str not in id_map:
                                print(f"  SKIPPING Line {line_number + 1}: Original TAIL ID '{original_tail_id_str}' not found in map. Line: '{line_content_stripped}'")
                            problems_shown_count +=1
                else:
                    lines_with_format_issues += 1
                    if problems_shown_count < first_n_problems_to_show:
                        print(f"  FORMAT WARNING: Line {line_number + 1} does not have 4 parts. Writing as is: '{line_content_stripped}'")
                        problems_shown_count += 1
                    outfile.write(line_content_stripped + '\n') # Write as is if format is wrong

        print("\nLink Re-indexing Summary:")
        print(f"Total lines processed from original link file: {lines_processed}")
        print(f"Lines successfully updated with new IDs: {lines_updated}")
        print(f"Lines skipped (an original ID was not found in the map): {lines_skipped_due_missing_id_in_map}")
        print(f"Lines with format issues (not 4 columns, written as is): {lines_with_format_issues}")
        print(f"Re-indexed link file saved to: {reindexed_link_file}")
        return True

    except FileNotFoundError:
        print(f"ERROR: Original link file not found at {original_link_file}")
        return False
    except Exception as e:
        print(f"An error occurred during link re-indexing: {e}")
        return False

if __name__ == '__main__':
    # Make sure the reindexed node.dat is already in place (renamed to node.dat)
    # And that the original_to_new_id_map.json exists.
    if not os.path.exists(os.path.join(dataset_dir, 'node.dat')):
        print(f"Warning: Expected re-indexed 'node.dat' not found at {os.path.join(dataset_dir, 'node.dat')}. "
              "Ensure you have run reindex_nodes.py and renamed its output.")
              
    if os.path.exists(input_link_dat_path) and os.path.exists(id_map_path):
        success = update_link_ids(input_link_dat_path, id_map_path, output_reindexed_link_dat_path)
        if success and lines_skipped_due_missing_id_in_map == 0 and lines_with_format_issues == 0:
            print(f"\nSUCCESS: Link re-indexing complete. All links were updated or had correct format.")
            print(f"Next steps:")
            print(f"1. Backup your current 'link.dat' (e.g., 'mv {input_link_dat_path} link_before_id_update.dat')")
            print(f"2. Rename '{os.path.basename(output_reindexed_link_dat_path)}' to 'link.dat'")
            print(f"   (e.g., 'mv {output_reindexed_link_dat_path} {os.path.join(dataset_dir, 'link.dat')}')")
            print(f"3. Do the same for 'link.dat.test' if you have one.")
            print(f"4. IMPORTANT: Delete the old dl_pickle file before running main.py again:")
            print(f"   rm /data/luis/hgb/LP/benchmark/methods/HetGNN/CellDrug-temp/CellDrug_dl_pickle.pkl") # Adjust path if needed
        elif success:
            print(f"\nPARTIAL SUCCESS: Link re-indexing completed, but some lines were skipped or had format issues.")
            print(f"Please review the warnings and the output file '{output_reindexed_link_dat_path}' carefully.")
            print(f"If the number of skipped/problematic lines is acceptable, you can proceed with renaming.")
        else:
            print("\nLink re-indexing script encountered an error.")
    else:
        if not os.path.exists(input_link_dat_path):
            print(f"Error: Input link file not found at {input_link_dat_path}. This should be your link file with OLD IDs.")
        if not os.path.exists(id_map_path):
            print(f"Error: ID map file ('original_to_new_id_map.json') not found at {id_map_path}.")

Loading ID map from: /data/luis/hgb/LP/benchmark/data/CellDrug/original_to_new_id_map.json
ID map loaded successfully.

Updating links from: /data/luis/hgb/LP/benchmark/data/CellDrug/link.dat
Re-indexed links will be saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/link_reindexed.dat

Link Re-indexing Summary:
Total lines processed from original link file: 717022
Lines successfully updated with new IDs: 717022
Lines skipped (an original ID was not found in the map): 0
Lines with format issues (not 4 columns, written as is): 0
Re-indexed link file saved to: /data/luis/hgb/LP/benchmark/data/CellDrug/link_reindexed.dat


NameError: name 'lines_skipped_due_missing_id_in_map' is not defined