### Testing approach for protein-complex aware train / test split for cross-validation

In [39]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import random
import itertools

def generate_fast_dummy_gene_complex_data(
    n_genes=100,
    n_complexes=20,
    n_pairs=1000,
    pos_fraction=0.3,
    random_state=42
):
    random.seed(random_state)
    np.random.seed(random_state)

    # Step 1: Generate genes and assign to complexes
    genes = [f"GENE{i}" for i in range(n_genes)]
    complexes = [f"Complex_{i}" for i in range(n_complexes)]

    gene_to_complex = {
        gene: random.choice(complexes)
        for gene in genes
    }

    complex_map = pd.DataFrame({
        "Gene": list(gene_to_complex.keys()),
        "Complex": list(gene_to_complex.values())
    })

    # Step 2: Build all gene pairs
    complex_to_genes = complex_map.groupby("Complex")["Gene"].apply(list)

    # Positive pairs: same complex
    positive_pairs = []
    for genes_in_complex in complex_to_genes:
        pairs = list(itertools.combinations(sorted(genes_in_complex), 2))
        positive_pairs.extend(pairs)

    # Sample positive pairs
    n_pos = int(n_pairs * pos_fraction)
    pos_sample = random.sample(positive_pairs, min(n_pos, len(positive_pairs)))

    # Negative pairs: different complexes
    all_pairs = list(itertools.combinations(sorted(genes), 2))
    positive_set = set(pos_sample)
    negative_pairs = [pair for pair in all_pairs if gene_to_complex[pair[0]] != gene_to_complex[pair[1]]]
    n_neg = n_pairs - len(pos_sample)
    neg_sample = random.sample(negative_pairs, min(n_neg, len(negative_pairs)))

    # Build final DataFrame
    data = pos_sample + neg_sample
    labels = [1] * len(pos_sample) + [0] * len(neg_sample)

    df_pairs = pd.DataFrame(data, columns=["Gene_A", "Gene_B"])
    df_pairs["Label"] = labels

    return df_pairs, complex_map

def assign_complex_to_row(df, complex_map):
    """Map each gene pair to a shared complex, or unique ID if they differ."""
    # Merge complex info
    df = df.copy()
    complex_map = complex_map.rename(columns={'Gene': 'Gene_A', 'Complex': 'Complex_A'})
    df = df.merge(complex_map, on='Gene_A', how='left')

    complex_map = complex_map.rename(columns={'Gene_A': 'Gene_B', 'Complex_A': 'Complex_B'})
    df = df.merge(complex_map, on='Gene_B', how='left')

    def determine_group(row):
        if pd.notna(row['Complex_A']) and row['Complex_A'] == row['Complex_B']:
            return row['Complex_A']
        else:
            return f"UNRELATED_{hash(frozenset([row['Complex_A'], row['Complex_B']]))}"

    df['Complex_Group'] = df.apply(determine_group, axis=1)
    return df

def stratified_split_by_complex(df, test_size=0.2, random_state=42):
    """
    Ensure no overlapping complexes between train and test.
    """
    df = df.copy()
    unique_complexes = df['Complex_Group'].unique()

    # Stratify on whether the complex group label is 'in complex' or not
    # You could instead stratify based on counts or labels depending on balance
    stratify_labels = df.drop_duplicates('Complex_Group')[['Complex_Group', 'Label']].set_index('Complex_Group').reindex(unique_complexes)['Label'].fillna(0)

    train_complexes, test_complexes = train_test_split(
        unique_complexes,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify_labels
    )

    train_df = df[df['Complex_Group'].isin(train_complexes)].reset_index(drop=True)
    test_df = df[df['Complex_Group'].isin(test_complexes)].reset_index(drop=True)

    return train_df, test_df

def extract_involved_complexes(row):
    complexes = set()
    if pd.notna(row['Complex_A']):
        complexes.add(row['Complex_A'])
    if pd.notna(row['Complex_B']):
        complexes.add(row['Complex_B'])
    return complexes

def stratified_split_by_complex_no_leak(df, test_size=0.2, random_state=42):
    df = df.copy()

    # Collect complex sets per row
    df['Complexes_Set'] = df.apply(extract_involved_complexes, axis=1)

    # Flatten to all unique complexes
    all_complexes = set().union(*df['Complexes_Set'])

    # Stratify based on presence of positive labels per complex (approx)
    complex_label_map = {}
    for complex_id in all_complexes:
        rows_with_complex = df[df['Complexes_Set'].apply(lambda x: complex_id in x)]
        complex_label_map[complex_id] = int(rows_with_complex['Label'].mean() > 0)

    complexes = list(complex_label_map.keys())
    labels = [complex_label_map[c] for c in complexes]

    # Split the complexes
    train_complexes, test_complexes = train_test_split(
        complexes,
        test_size=test_size,
        random_state=random_state,
        stratify=labels
    )

    train_complexes = set(train_complexes)
    test_complexes = set(test_complexes)

    # Filter rows where all involved complexes are in the same group
    train_df = df[df['Complexes_Set'].apply(lambda s: s.issubset(train_complexes))].reset_index(drop=True)
    test_df = df[df['Complexes_Set'].apply(lambda s: s.issubset(test_complexes))].reset_index(drop=True)

    return train_df, test_df


In [3]:
df = pd.DataFrame({
        'Gene_A': ['BRCA1', 'TP53', 'BRCA2', 'RAD51', 'MTOR', 'PIK3CA'],
        'Gene_B': ['BRCA2', 'MDM2', 'PALB2', 'RAD52', 'AKT1', 'AKT2'],
        'Label':   [1, 1, 1, 0, 1, 0]})

complex_map = pd.DataFrame({
        'Gene': ['BRCA1', 'BRCA2', 'PALB2', 'TP53', 'MDM2', 'RAD51', 'RAD52', 'MTOR', 'AKT1', 'PIK3CA'],
        'Complex': ['Complex1', 'Complex1', 'Complex1', 'Complex2', 'Complex2', 'Complex3', 'Complex4', 'Complex5', 'Complex5', 'Complex6']
    })


In [4]:
df

Unnamed: 0,Gene_A,Gene_B,Label
0,BRCA1,BRCA2,1
1,TP53,MDM2,1
2,BRCA2,PALB2,1
3,RAD51,RAD52,0
4,MTOR,AKT1,1
5,PIK3CA,AKT2,0


In [5]:
complex_map

Unnamed: 0,Gene,Complex
0,BRCA1,Complex1
1,BRCA2,Complex1
2,PALB2,Complex1
3,TP53,Complex2
4,MDM2,Complex2
5,RAD51,Complex3
6,RAD52,Complex4
7,MTOR,Complex5
8,AKT1,Complex5
9,PIK3CA,Complex6


In [18]:
df, complex_map = generate_fast_dummy_gene_complex_data()

In [19]:
df

Unnamed: 0,Gene_A,Gene_B,Label
0,GENE44,GENE64,1
1,GENE23,GENE75,1
2,GENE46,GENE61,1
3,GENE35,GENE6,1
4,GENE64,GENE96,1
...,...,...,...
995,GENE6,GENE64,0
996,GENE36,GENE97,0
997,GENE50,GENE57,0
998,GENE18,GENE65,0


In [20]:
complex_map

Unnamed: 0,Gene,Complex
0,GENE0,Complex_3
1,GENE1,Complex_0
2,GENE2,Complex_8
3,GENE3,Complex_7
4,GENE4,Complex_7
...,...,...
95,GENE95,Complex_12
96,GENE96,Complex_14
97,GENE97,Complex_4
98,GENE98,Complex_8


In [21]:
df_with_complex = assign_complex_to_row(df, complex_map)


In [22]:
df_with_complex

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group
0,GENE44,GENE64,1,Complex_14,Complex_14,Complex_14
1,GENE23,GENE75,1,Complex_7,Complex_7,Complex_7
2,GENE46,GENE61,1,Complex_3,Complex_3,Complex_3
3,GENE35,GENE6,1,Complex_3,Complex_3,Complex_3
4,GENE64,GENE96,1,Complex_14,Complex_14,Complex_14
...,...,...,...,...,...,...
995,GENE6,GENE64,0,Complex_3,Complex_14,UNRELATED_-8900013985987075014
996,GENE36,GENE97,0,Complex_2,Complex_4,UNRELATED_-5499943997916966653
997,GENE50,GENE57,0,Complex_9,Complex_7,UNRELATED_-88597226039519464
998,GENE18,GENE65,0,Complex_0,Complex_11,UNRELATED_3652095148036549621


In [26]:
train_df, test_df = stratified_split_by_complex(df_with_complex)

In [27]:
train_df

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group
0,GENE44,GENE64,1,Complex_14,Complex_14,Complex_14
1,GENE23,GENE75,1,Complex_7,Complex_7,Complex_7
2,GENE46,GENE61,1,Complex_3,Complex_3,Complex_3
3,GENE35,GENE6,1,Complex_3,Complex_3,Complex_3
4,GENE64,GENE96,1,Complex_14,Complex_14,Complex_14
...,...,...,...,...,...,...
775,GENE6,GENE64,0,Complex_3,Complex_14,UNRELATED_-8900013985987075014
776,GENE36,GENE97,0,Complex_2,Complex_4,UNRELATED_-5499943997916966653
777,GENE50,GENE57,0,Complex_9,Complex_7,UNRELATED_-88597226039519464
778,GENE18,GENE65,0,Complex_0,Complex_11,UNRELATED_3652095148036549621


In [29]:
test_df

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group
0,GENE5,GENE99,1,Complex_4,Complex_4,Complex_4
1,GENE31,GENE42,1,Complex_8,Complex_8,Complex_8
2,GENE70,GENE79,1,Complex_8,Complex_8,Complex_8
3,GENE2,GENE88,1,Complex_8,Complex_8,Complex_8
4,GENE14,GENE69,1,Complex_6,Complex_6,Complex_6
...,...,...,...,...,...,...
215,GENE40,GENE89,0,Complex_11,Complex_2,UNRELATED_-1112816051940215057
216,GENE1,GENE31,0,Complex_0,Complex_8,UNRELATED_7055619691820458489
217,GENE46,GENE91,0,Complex_3,Complex_18,UNRELATED_-6345802722480175068
218,GENE10,GENE27,0,Complex_13,Complex_0,UNRELATED_2807033227285835388


In [30]:
## perhaps want to change this so that even the 'UNRELATED' complexes are not in the test set 

In [32]:
train, test = stratified_split_by_complex_no_leak(df_with_complex)

In [33]:
train

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group,Complexes_Set
0,GENE44,GENE64,1,Complex_14,Complex_14,Complex_14,{Complex_14}
1,GENE23,GENE75,1,Complex_7,Complex_7,Complex_7,{Complex_7}
2,GENE46,GENE61,1,Complex_3,Complex_3,Complex_3,{Complex_3}
3,GENE35,GENE6,1,Complex_3,Complex_3,Complex_3,{Complex_3}
4,GENE64,GENE96,1,Complex_14,Complex_14,Complex_14,{Complex_14}
...,...,...,...,...,...,...,...
695,GENE13,GENE61,0,Complex_2,Complex_3,UNRELATED_1810200194389755326,"{Complex_2, Complex_3}"
696,GENE6,GENE64,0,Complex_3,Complex_14,UNRELATED_-8900013985987075014,"{Complex_3, Complex_14}"
697,GENE36,GENE97,0,Complex_2,Complex_4,UNRELATED_-5499943997916966653,"{Complex_2, Complex_4}"
698,GENE50,GENE57,0,Complex_9,Complex_7,UNRELATED_-88597226039519464,"{Complex_7, Complex_9}"


In [35]:
train[train['Complex_A']=='Complex_6']

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group,Complexes_Set


In [36]:
train[train['Complex_B']=='Complex_6']

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group,Complexes_Set


In [34]:
test

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group,Complexes_Set
0,GENE10,GENE22,1,Complex_13,Complex_13,Complex_13,{Complex_13}
1,GENE14,GENE69,1,Complex_6,Complex_6,Complex_6,{Complex_6}
2,GENE54,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
3,GENE28,GENE73,1,Complex_5,Complex_5,Complex_5,{Complex_5}
4,GENE17,GENE72,1,Complex_19,Complex_19,Complex_19,{Complex_19}
5,GENE17,GENE51,1,Complex_19,Complex_19,Complex_19,{Complex_19}
6,GENE20,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}
7,GENE20,GENE33,1,Complex_6,Complex_6,Complex_6,{Complex_6}
8,GENE66,GENE76,1,Complex_5,Complex_5,Complex_5,{Complex_5}
9,GENE28,GENE76,1,Complex_5,Complex_5,Complex_5,{Complex_5}


In [37]:
test[test['Complex_A']=='Complex_6']

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group,Complexes_Set
1,GENE14,GENE69,1,Complex_6,Complex_6,Complex_6,{Complex_6}
2,GENE54,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
6,GENE20,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}
7,GENE20,GENE33,1,Complex_6,Complex_6,Complex_6,{Complex_6}
10,GENE33,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
11,GENE14,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
15,GENE54,GENE90,1,Complex_6,Complex_6,Complex_6,{Complex_6}
17,GENE14,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}
18,GENE90,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
22,GENE33,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}


In [38]:
test[test['Complex_B']=='Complex_6']

Unnamed: 0,Gene_A,Gene_B,Label,Complex_A,Complex_B,Complex_Group,Complexes_Set
1,GENE14,GENE69,1,Complex_6,Complex_6,Complex_6,{Complex_6}
2,GENE54,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
6,GENE20,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}
7,GENE20,GENE33,1,Complex_6,Complex_6,Complex_6,{Complex_6}
10,GENE33,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
11,GENE14,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
15,GENE54,GENE90,1,Complex_6,Complex_6,Complex_6,{Complex_6}
17,GENE14,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}
18,GENE90,GENE93,1,Complex_6,Complex_6,Complex_6,{Complex_6}
22,GENE33,GENE54,1,Complex_6,Complex_6,Complex_6,{Complex_6}
