In [1]:
!pwd

/home/bournelab/qbi_hackathon


In [40]:
from pathlib import Path
import pandas as pd

In [3]:
clusters = Path("codnas_groups_pdb")

In [9]:
from biotite.structure.io.pdb import PDBFile
from biotite.structure import get_residue_starts
from biotite.sequence import ProteinSequence

In [26]:
def three_to_one(aa):
    try:
        return ProteinSequence.convert_letter_3to1(aa)
    except KeyError:
        return "X"

In [31]:
for cluster_dir in clusters.rglob("*"):
    cluster = cluster_dir.stem
    cluster_fasta = f"cluster_seqs/{cluster}_seqs.fasta"
    print(cluster_fasta)
    with open(f"cluster_seqs/{cluster}_seqs.fasta", "w") as f:
        for pdb_file in cluster_dir.glob("*.pdb"):
            if "fixed" in pdb_file.stem: continue
            parts = pdb_file.stem.split("-")
            pdb_f = PDBFile.read(str(pdb_file))
            pdb = pdb_f.get_structure(model=1)
            residues = "".join([three_to_one(aa) for aa in pdb.res_name[get_residue_starts(pdb)]])
            print(f">{parts[0]}_{parts[-1]} {parts[1] if len(pdb_file.stem) > 3 else ''}\n{residues}", file=f)
    print("done")
            

cluster_seqs/1EQT_A_seqs.fasta
done
cluster_seqs/2APO_B_seqs.fasta
done
cluster_seqs/3HWE_A_seqs.fasta
done
cluster_seqs/1ZOQ_C_seqs.fasta
done
cluster_seqs/2ELR_A_seqs.fasta
done
cluster_seqs/5B6Q_A_seqs.fasta
done
cluster_seqs/2MS8_A_seqs.fasta
done
cluster_seqs/2RUG_A_seqs.fasta
done
cluster_seqs/1X9V_A_seqs.fasta
done
cluster_seqs/4HI8_A_seqs.fasta
done
cluster_seqs/6OGK_A_seqs.fasta
done
cluster_seqs/2BJE_A_seqs.fasta
done
cluster_seqs/1X0F_A_seqs.fasta
done
cluster_seqs/2YT0_A_seqs.fasta
done
cluster_seqs/4YUD_A_seqs.fasta
done
cluster_seqs/6AC0_A_seqs.fasta
done
cluster_seqs/1OFF_A_seqs.fasta
done
cluster_seqs/2KP2_A_seqs.fasta
done
cluster_seqs/1HO9_A_seqs.fasta
done
cluster_seqs/1BHB_A_seqs.fasta
done
cluster_seqs/2L87_A_seqs.fasta
done
cluster_seqs/5URN_B_seqs.fasta
done
cluster_seqs/2RPQ_A_seqs.fasta
done
cluster_seqs/2L66_A_seqs.fasta
done
cluster_seqs/2Q44_A_seqs.fasta
done
cluster_seqs/1VPU_A_seqs.fasta
done
cluster_seqs/1F7M_A_seqs.fasta
done
cluster_seqs/6F4J_C_seqs.fas

In [81]:
from tqdm import tqdm

In [86]:
with open("full_data.csv", "w") as full_data:
    print("pdb_index,cluster_id,pdb_cluster_id,file_name,pdb,chain,model,causes", file=full_data)
    i = 0
    for cluster_seqs in tqdm(list(Path("cluster_seqs/full").glob("*.fasta"))):
        cluster_name = cluster_seqs.stem
        with cluster_seqs.open() as fasta:
            for j, line in enumerate(fasta):
                fields = line[1:].rstrip().split()
                pdb, chain = fields[0].split("_")
                if len(fields) > 1:
                    model = int(fields[1])
                else:
                    model = None
                seq = next(fasta)

                df = pd.read_csv(f"codnas_raw_data/{cluster_name[:6]}.csv", index_col=0)
    
                if "Mammoth_RMS" not in df.columns:
                    continue
                    
                df = df[pd.to_numeric(df['Mammoth_RMS'], errors='coerce').notnull()]
            
                df["Mammoth_RMS"] = df["Mammoth_RMS"].astype(float)
                
                if len(df[df.Mammoth_RMS>4])==0:
                    continue
            
                try:
                    df = df.groupby(pd.cut(df["Mammoth_RMS"], np.arange(0, df["Mammoth_RMS"].max(), df["Mammoth_RMS"].quantile(0.2)))).head(4)
                except ValueError:
                    continue
                
            
                try:
                    df = df.assign(
                        pdb1=df["PDB_1"].str[:4],
                        pdb2=df["PDB_2"].str[:4],
                        model1=df["PDB_1"].astype(str).apply(lambda s: s.split("-")[1].split("_")[0] if "-" in s else -1).astype(int),
                        model2=df["PDB_2"].astype(str).apply(lambda s: s.split("-")[1].split("_")[0] if "-" in s else -1).astype(int),
                        chain1=df["PDB_1"].str.split("_", expand=True)[1],
                        chain2=df["PDB_2"].str.split("_", expand=True)[1],
                    )
                except (TypeError, KeyError):
                    continue
    
                if model is None:
                    small_df = df[((df["pdb1"]==pdb)&(df["chain1"]==chain))|((df["pdb2"]==pdb)&(df["chain2"]==chain))]
                else:
                    small_df = df[((df["pdb1"]==pdb)&(df["chain1"]==chain)&(df["model1"]==model))|((df["pdb2"]==pdb)&(df["chain2"]==chain)&(df["model2"]==model))]

                causes = ";".join(set(small_df["Causes_DC"].str.split(",", expand=True).dropna().values.flatten()))

                if model is None:
                    file_name = f"codnas_groups_pdb/{cluster_name}/{pdb}--1-{chain}_fixed.pdb"
                else:
                    file_name = f"codnas_groups_pdb/{cluster_name}/{pdb}-{model}-{chain}_fixed.pdb"
                print(i, cluster_name, f"{cluster_name}_{j}", file_name, pdb, chain, model, causes, sep=", ", file=full_data)

                i += 1
            

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1428/1428 [26:36<00:00,  1.12s/it]


In [85]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [145]:
all_data_df = pd.read_csv("full_data.csv", sep=", ")

  all_data_df = pd.read_csv("full_data.csv", sep=", ")


In [146]:
all_data_df

Unnamed: 0,pdb_index,cluster_id,pdb_cluster_id,file_name,pdb,chain,model,causes
0,0,4JNE_A_seqs,4JNE_A_seqs_0,codnas_groups_pdb/4JNE_A_seqs/2VB5-6-A_fixed.pdb,2VB5,A,6,MUTATION
1,1,4JNE_A_seqs,4JNE_A_seqs_1,codnas_groups_pdb/4JNE_A_seqs/2VB5-17-A_fixed.pdb,2VB5,A,17,MUTATION
2,2,4JNE_A_seqs,4JNE_A_seqs_2,codnas_groups_pdb/4JNE_A_seqs/2VB5-20-A_fixed.pdb,2VB5,A,20,MUTATION
3,3,4JNE_A_seqs,4JNE_A_seqs_3,codnas_groups_pdb/4JNE_A_seqs/2VB5-15-A_fixed.pdb,2VB5,A,15,MUTATION
4,4,4JNE_A_seqs,4JNE_A_seqs_4,codnas_groups_pdb/4JNE_A_seqs/2KF2-11-A_fixed.pdb,2KF2,A,11,LIGAND;TEMPERATURE;LIGAND BIOLIP;PH;MUTATION
...,...,...,...,...,...,...,...,...
19197,19197,1AP0_A_seqs,1AP0_A_seqs_4,codnas_groups_pdb/1AP0_A_seqs/1AP0-6-A_fixed.pdb,1AP0,A,6,MUTATION
19198,19198,1AP0_A_seqs,1AP0_A_seqs_5,codnas_groups_pdb/1AP0_A_seqs/1AP0-3-A_fixed.pdb,1AP0,A,3,MUTATION
19199,19199,1AP0_A_seqs,1AP0_A_seqs_6,codnas_groups_pdb/1AP0_A_seqs/1AP0-1-A_fixed.pdb,1AP0,A,1,MUTATION
19200,19200,1AP0_A_seqs,1AP0_A_seqs_7,codnas_groups_pdb/1AP0_A_seqs/1AP0-16-A_fixed.pdb,1AP0,A,16,MUTATION


In [158]:
all_data_df = all_data_df.assign(**{f"embedding{x}":np.nan for x in range(1280)})
all_data_df.head()

Unnamed: 0,pdb_index,cluster_id,pdb_cluster_id,file_name,pdb,chain,model,causes,embedding0,embedding1,...,embedding1270,embedding1271,embedding1272,embedding1273,embedding1274,embedding1275,embedding1276,embedding1277,embedding1278,embedding1279
0,0,4JNE_A_seqs,4JNE_A_seqs_0,codnas_groups_pdb/4JNE_A_seqs/2VB5-6-A_fixed.pdb,2VB5,A,6,MUTATION,,,...,,,,,,,,,,
1,1,4JNE_A_seqs,4JNE_A_seqs_1,codnas_groups_pdb/4JNE_A_seqs/2VB5-17-A_fixed.pdb,2VB5,A,17,MUTATION,,,...,,,,,,,,,,
2,2,4JNE_A_seqs,4JNE_A_seqs_2,codnas_groups_pdb/4JNE_A_seqs/2VB5-20-A_fixed.pdb,2VB5,A,20,MUTATION,,,...,,,,,,,,,,
3,3,4JNE_A_seqs,4JNE_A_seqs_3,codnas_groups_pdb/4JNE_A_seqs/2VB5-15-A_fixed.pdb,2VB5,A,15,MUTATION,,,...,,,,,,,,,,
4,4,4JNE_A_seqs,4JNE_A_seqs_4,codnas_groups_pdb/4JNE_A_seqs/2KF2-11-A_fixed.pdb,2KF2,A,11,LIGAND;TEMPERATURE;LIGAND BIOLIP;PH;MUTATION,,,...,,,,,,,,,,


In [173]:
all_data_df["model"] = all_data_df["model"].fillna("None,")

In [184]:
all_data_df["model"] = all_data_df["model"].apply(lambda s: s[:-1] if s.endswith(",") else s).apply(lambda s: "nan" if s=="None" else s)

In [126]:
from tqdm import tqdm
for emb_file in tqdm(Path.cwd().glob("*.pt")):
    emb = torch.load(str(emb_file))["mean_representations"][33]
    print(emb)
    np.save(str(emb_file.with_suffix(".npy")), emb.numpy())

/home/bournelab/miniconda_qbi/envs/ConfDiff/bin/python


In [156]:
emb_cols = [f"embedding{x}" for x in range(1280)]

In [212]:
all_embs = []
for emb_file in tqdm(list(Path("embeddings/New_ESM1_Embeddings").glob("*.npy"))):
    #7L55_A_3_19200.npy
    #7JQC_F__19061.npy
    fields = emb_file.stem.split("_")
    pdb = fields[0]
    chain = fields[1]
    model = fields[2] if fields[2] != "" else "nan"
    emb = np.load(str(emb_file))
    all_embs.append([pdb,chain,model,*emb.tolist()])
    

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5407/5407 [00:00<00:00, 8527.72it/s]


In [214]:
all_embs = pd.DataFrame(all_embs, columns=["pdb", "chain", "model", *emb_cols])

In [241]:
all_embs.to_csv("all_embeddings.csv")

In [232]:
cluster_reps_emb_df = pd.merge(all_data_df.drop(columns=emb_cols), all_embs, on=["pdb", "chain", "model"])

In [239]:
cluster_reps_emb_df.drop_duplicates(subset=["cluser_id"])

KeyError: Index(['cluser_id'], dtype='object')

In [240]:
cluster_reps_emb_df

Unnamed: 0,pdb_index,cluster_id,pdb_cluster_id,file_name,pdb,chain,model,causes,embedding0,embedding1,...,embedding1270,embedding1271,embedding1272,embedding1273,embedding1274,embedding1275,embedding1276,embedding1277,embedding1278,embedding1279
0,7,1Y7J_A_seqs,1Y7J_A_seqs_0,codnas_groups_pdb/1Y7J_A_seqs/2KZA-2-A_fixed.pdb,2KZA,A,2,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,-0.090952,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
1,8,1Y7J_A_seqs,1Y7J_A_seqs_1,codnas_groups_pdb/1Y7J_A_seqs/2KZA-11-A_fixed.pdb,2KZA,A,11,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,-0.090952,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
2,9,1Y7J_A_seqs,1Y7J_A_seqs_2,codnas_groups_pdb/1Y7J_A_seqs/2KZA-10-A_fixed.pdb,2KZA,A,10,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,-0.090952,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
3,10,1Y7J_A_seqs,1Y7J_A_seqs_3,codnas_groups_pdb/1Y7J_A_seqs/2KZA-3-A_fixed.pdb,2KZA,A,3,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,-0.090952,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
4,11,1Y7J_A_seqs,1Y7J_A_seqs_4,codnas_groups_pdb/1Y7J_A_seqs/2KZA-6-A_fixed.pdb,2KZA,A,6,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,-0.090952,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5552,19053,6P9Y_P_seqs,6P9Y_P_seqs_2,codnas_groups_pdb/6P9Y_P_seqs/6LPB--1-P_fixed.pdb,6LPB,P,,DISORDER,-0.100420,0.085383,...,0.068483,0.165166,-0.056988,-0.094667,0.050656,0.074851,-0.004986,-0.044384,-0.096862,0.215431
5553,19056,6P9Y_P_seqs,6P9Y_P_seqs_5,codnas_groups_pdb/6P9Y_P_seqs/6P9Y--1-P_fixed.pdb,6P9Y,P,,LIGAND;PH;DISORDER,-0.082104,0.119497,...,0.082608,0.244709,-0.071787,-0.093098,0.082454,0.099108,-0.001207,-0.049472,-0.112675,0.278666
5554,19066,1AN1_I_seqs,1AN1_I_seqs_0,codnas_groups_pdb/1AN1_I_seqs/1AN1-1-I_fixed.pdb,1AN1,I,1,,-0.270084,0.032212,...,-0.083465,0.120330,0.035210,-0.018941,-0.313733,-0.125486,-0.113282,-0.110297,-0.135823,0.125113
5555,19117,1L4W_A_seqs,1L4W_A_seqs_12,codnas_groups_pdb/1L4W_A_seqs/1L4W-1-B_fixed.pdb,1L4W,B,1,,-0.186269,0.093108,...,-0.128212,0.101672,0.013218,0.046467,-0.113594,0.041483,-0.203715,-0.088847,-0.064474,0.253057


In [236]:
pd.merge(
    all_data_df.drop(columns=emb_cols), 
    cluster_reps_emb_df.drop(columns=["pdb_index", "pdb_cluster_id", "file_name", "pdb", "chain", "model"]),
    how="inner",
    on="cluster_id")

Unnamed: 0,pdb_index,cluster_id,pdb_cluster_id,file_name,pdb,chain,model,causes_x,causes_y,embedding0,...,embedding1270,embedding1271,embedding1272,embedding1273,embedding1274,embedding1275,embedding1276,embedding1277,embedding1278,embedding1279
0,7,1Y7J_A_seqs,1Y7J_A_seqs_0,codnas_groups_pdb/1Y7J_A_seqs/2KZA-2-A_fixed.pdb,2KZA,A,2,TEMPERATURE;PH;DISORDER;MUTATION,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
1,7,1Y7J_A_seqs,1Y7J_A_seqs_0,codnas_groups_pdb/1Y7J_A_seqs/2KZA-2-A_fixed.pdb,2KZA,A,2,TEMPERATURE;PH;DISORDER;MUTATION,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
2,7,1Y7J_A_seqs,1Y7J_A_seqs_0,codnas_groups_pdb/1Y7J_A_seqs/2KZA-2-A_fixed.pdb,2KZA,A,2,TEMPERATURE;PH;DISORDER;MUTATION,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
3,7,1Y7J_A_seqs,1Y7J_A_seqs_0,codnas_groups_pdb/1Y7J_A_seqs/2KZA-2-A_fixed.pdb,2KZA,A,2,TEMPERATURE;PH;DISORDER;MUTATION,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
4,7,1Y7J_A_seqs,1Y7J_A_seqs_0,codnas_groups_pdb/1Y7J_A_seqs/2KZA-2-A_fixed.pdb,2KZA,A,2,TEMPERATURE;PH;DISORDER;MUTATION,TEMPERATURE;PH;DISORDER;MUTATION,-0.143577,...,0.129885,0.049554,0.182815,-0.127217,-0.249767,-0.023298,0.016290,0.038949,0.096279,0.228959
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
130221,19190,1L4W_A_seqs,1L4W_A_seqs_85,codnas_groups_pdb/1L4W_A_seqs/1IDL-1-A_fixed.pdb,1IDL,A,1,,,-0.186269,...,-0.128212,0.101672,0.013218,0.046467,-0.113594,0.041483,-0.203715,-0.088847,-0.064474,0.253057
130222,19191,1L4W_A_seqs,1L4W_A_seqs_86,codnas_groups_pdb/1L4W_A_seqs/1IKC-1-A_fixed.pdb,1IKC,A,1,TEMPERATURE;PH,,-0.186269,...,-0.128212,0.101672,0.013218,0.046467,-0.113594,0.041483,-0.203715,-0.088847,-0.064474,0.253057
130223,19191,1L4W_A_seqs,1L4W_A_seqs_86,codnas_groups_pdb/1L4W_A_seqs/1IKC-1-A_fixed.pdb,1IKC,A,1,TEMPERATURE;PH,,-0.186269,...,-0.128212,0.101672,0.013218,0.046467,-0.113594,0.041483,-0.203715,-0.088847,-0.064474,0.253057
130224,19192,1L4W_A_seqs,1L4W_A_seqs_87,codnas_groups_pdb/1L4W_A_seqs/1HC9--1-A_fixed.pdb,1HC9,A,,,,-0.186269,...,-0.128212,0.101672,0.013218,0.046467,-0.113594,0.041483,-0.203715,-0.088847,-0.064474,0.253057


In [242]:
all_data_df.drop(columns=emb_cols).to_csv("all_input_data.csv")

In [163]:
"7JQC_F__19061".split("_")

['7JQC', 'F', '', '19061']

In [110]:
trainValClusters = pd.read_csv("trainValClusters_cluster.tsv", sep="\t", header=None, names=["group", "pdb"])

In [93]:
for name, grp in df.groupby("group"):
    print(name, grp.shape)

1A1X_A__3502 (14, 2)
1A5E_A_2_5446 (11, 2)
1A5R_A_1_32 (42, 2)
1A66_A_5_70 (9, 2)
1A6X_A_6_101 (17, 2)
1A8V_B__54 (11, 2)
1ACZ_C_3_145 (9, 2)
1AFO_A_2_8565 (30, 2)
1AHD_A_2_193 (7, 2)
1AJW_A__1568 (11, 2)
1AK7_A_7_208 (15, 2)
1AN1_I_1_219 (1, 2)
1AP7_A_1_482 (1, 2)
1AQT_A__18867 (4, 2)
1AX3_A_4_240 (9, 2)
1B4G_A_3_5359 (15, 2)
1B7F_B__14913 (46, 2)
1B8Q_B_7_3350 (9, 2)
1B9Q_A_8_295 (9, 2)
1BAL_A_4_4428 (26, 2)
1BET_A__16882 (10, 2)
1BHB_A_1_410 (15, 2)
1BIP_A_1_4057 (1, 2)
1BJX_A_1_446 (13, 2)
1BNP_A_1_483 (15, 2)
1BON_B_4_515 (8, 2)
1BSE_C_1_4863 (13, 2)
1BUI_C__13421 (10, 2)
1BWX_A_6_1757 (12, 2)
1BYV_A_9_464 (11, 2)
1C17_B__121 (13, 2)
1C3Y_A_11_529 (12, 2)
1C6W_A_18_8664 (11, 2)
1C89_A_25_550 (10, 2)
1C9F_A_1_1240 (1, 2)
1CDC_A__57 (13, 2)
1CEJ_A_7_560 (9, 2)
1CFG_A_1_1329 (15, 2)
1CJ5_A_6_586 (9, 2)
1CKY_A_2_610 (24, 2)
1CM9_A__649 (38, 2)
1CWW_A_19_701 (15, 2)
1CYU_A_1_2820 (2, 2)
1D2Z_A_1_716 (4, 2)
1D8J_A_6_720 (9, 2)
1D9N_A_2_729 (15, 2)
1DAV_A_9_744 (23, 2)
1DDF_A__14280 (1, 

In [117]:
from numbers import Number
def split_dataset_at_level(df: pd.DataFrame, level_key: str, level_name: str,
                           split_size: dict[str, float] = {"train":0.8, "validation":0.1, "test":0.1}) -> None:
    """Split a dataset into train/validation/test sets, saving the splits into new h5 groups with
    links back to the the main dataset.
    
    Parameters
    ----------
    cath_full_h5 : str
        Path to H5 file on HSDS enpoint
    superfamily : str
        Group prefix, can be empty ('') for h5 file
    sfam_df : pd.DataFrame
        The data frame to split. Each row must a single protein and the df M\must contain 2 columns: 
            (i) "cath_domain", the column of protein domain names, must match groups of the same name in this 'superfamily' group;
            (ii) level_key, custom variable name for the name of the cluster the protein domain belongs to
    level_key : str
        Name of the column that contains cluster names for use in pandas groupby
    level_name : str
        Name of the data split, e.g. "S35"
    split_size : Dict [split_name->split perecent]
        A dictionary containing the total number of splits
    """
    if isinstance(split_size, (list, tuple)):
        if len(split_size)==1:
            assert split_size[0]<1
            other_size = (1-split_size[0])/2
            split_size = dict(zip(["train", "validation", "test"],
                sorted([split_size[0], other_size, other_size])))
        else:
            assert sum(split_size) == 1
            split_size = {f"split{i}":s for i, s in enumerate(split_size)}
    elif isinstance(split_size, Number) and split_size<1:
        other_size = (1-split_size)/2
        split_size = {"train":split_size, "validation":other_size, "test":other_size}
    elif isinstance(split_size, dict) and sum(split_size.values()) == 1:
        #Correct
        pass
    else:
        raise RuntimeError("Invalid split_size. Must be dict {split_name:split_pct}, a list of split sizes (names automatically assinged), or a single number")

    start = 0
    subsets = []

    split_sizes = sorted(split_size.items(), key=lambda x: x[1], reverse=True)

    clusters = df.groupby(level_key)
    sorted_cluster_indices = list(sorted(clusters.indices.keys()))
    split_index_start = 0
    last_size = [None, None]

    for split_num, (split_name, split_pct) in enumerate(split_sizes):
        if split_num < len(split_sizes)-1:
            ideal_set1_size = int(clusters.ngroups*split_pct)

            while True:
                set1_clusters = sorted_cluster_indices[split_index_start:split_index_start+ideal_set1_size]
                set1 = [idx for cluster in set1_clusters for idx in clusters.get_group(cluster).index] #.indices[cluster]]
                size_pct = len(set1)/(len(df))
                print("size", len(set1), len(df), size_pct, ideal_set1_size)
                if size_pct in last_size:
                    break
                if size_pct > split_pct+.01:
                    ideal_set1_size -= 1
                elif size_pct < split_pct-.01:
                    ideal_set1_size += 1
                else:
                    break

                last_size[0] = last_size[1]
                last_size[1] = size_pct

            subset_idx = list(sorted(set1))
            subset = df[df.index.isin(set1)]["pdb"]

            #Reset index for next iteration to skip current domains
            split_index_start += ideal_set1_size
        else:
            set1_clusters = sorted_cluster_indices[split_index_start:]
            set1 = [idx for cluster in set1_clusters for idx in clusters.get_group(cluster).index] #.indices[cluster]]
            size_pct = len(set1)/(len(df))
            subset = df[df.index.isin(set1)]["pdb"]

        yield subset

In [120]:
train, valid, test = split_dataset_at_level(trainValClusters, "group", "dummy")

size 15394 19202 0.8016873242370587 1014
size 1713 19202 0.0892094573481929 126
size 1722 19202 0.08967815852515364 127
size 1735 19202 0.09035517133631914 128


In [121]:
train

0         2LC6_A_2_9771
1         1RY4_A_8_9772
2         2LC6_A_7_9773
3        1RY4_A_10_9774
4         2LC6_A_3_9775
              ...      
19197     2LB6_A_2_9766
19198     2LB6_A_6_9767
19199     2LB6_A_5_9768
19200    2LB6_A_10_9769
19201     2LB6_A_4_9770
Name: pdb, Length: 15394, dtype: object

In [243]:
train.to_csv("train.csv")

In [244]:
valid.to_csv("valid.csv")

In [245]:
test.to_csv("test.csv")

In [252]:
set(all_data_df.causes.str.lower().str.split(";", expand=True).values.flatten())

{None,
 'disorder',
 'ligand',
 'ligand biolip',
 'mutation',
 'oligomeric state',
 'oligomeric state author',
 'ph',
 'post-translational mod',
 'temperature',
 'without causes of dc'}

In [255]:
train = train.str.split("_", expand=True)[[0,1,2]].rename(columns={0:"pdb", 1:"chain", 2:"model"})

In [264]:
train_df = pd.merge(train, all_data_df.drop(columns=emb_cols), how="left", on=["pdb", "chain", "model"]).drop_duplicates(subset=["pdb", "chain", "model"])

In [263]:
a.drop_duplicates(subset=["pdb", "chain", "model"])

Unnamed: 0,pdb,chain,model,pdb_index,cluster_id,pdb_cluster_id,file_name,causes
0,2LC6,A,2,14024.0,2LC6_A_seqs,2LC6_A_seqs_0,codnas_groups_pdb/2LC6_A_seqs/2LC6-2-A_fixed.pdb,TEMPERATURE;MUTATION
1,1RY4,A,8,14025.0,2LC6_A_seqs,2LC6_A_seqs_1,codnas_groups_pdb/2LC6_A_seqs/1RY4-8-A_fixed.pdb,TEMPERATURE;MUTATION
2,2LC6,A,7,14026.0,2LC6_A_seqs,2LC6_A_seqs_2,codnas_groups_pdb/2LC6_A_seqs/2LC6-7-A_fixed.pdb,MUTATION
3,1RY4,A,10,14027.0,2LC6_A_seqs,2LC6_A_seqs_3,codnas_groups_pdb/2LC6_A_seqs/1RY4-10-A_fixed.pdb,TEMPERATURE;MUTATION
4,2LC6,A,3,14028.0,2LC6_A_seqs,2LC6_A_seqs_4,codnas_groups_pdb/2LC6_A_seqs/2LC6-3-A_fixed.pdb,TEMPERATURE;MUTATION
...,...,...,...,...,...,...,...,...
15753,2LB6,A,2,12943.0,2LB6_A_seqs,2LB6_A_seqs_12,codnas_groups_pdb/2LB6_A_seqs/2LB6-2-A_fixed.pdb,WITHOUT CAUSES OF DC
15754,2LB6,A,6,12944.0,2LB6_A_seqs,2LB6_A_seqs_13,codnas_groups_pdb/2LB6_A_seqs/2LB6-6-A_fixed.pdb,WITHOUT CAUSES OF DC
15755,2LB6,A,5,12945.0,2LB6_A_seqs,2LB6_A_seqs_14,codnas_groups_pdb/2LB6_A_seqs/2LB6-5-A_fixed.pdb,WITHOUT CAUSES OF DC
15756,2LB6,A,10,12946.0,2LB6_A_seqs,2LB6_A_seqs_15,codnas_groups_pdb/2LB6_A_seqs/2LB6-10-A_fixed.pdb,WITHOUT CAUSES OF DC


In [266]:
train_df.to_csv("train_df.csv")

In [267]:
valid_df = pd.merge(
    valid.str.split("_", expand=True)[[0,1,2]].rename(columns={0:"pdb", 1:"chain", 2:"model"}), 
    all_data_df.drop(columns=emb_cols), how="left", on=["pdb", "chain", "model"]).drop_duplicates(subset=["pdb", "chain", "model"])

In [269]:
valid_df.to_csv("valid_df.csv")

In [270]:
test_df = pd.merge(
    test.str.split("_", expand=True)[[0,1,2]].rename(columns={0:"pdb", 1:"chain", 2:"model"}), 
    all_data_df.drop(columns=emb_cols), how="left", on=["pdb", "chain", "model"]).drop_duplicates(subset=["pdb", "chain", "model"])

In [272]:
test_df.to_csv("test_df.csv")

In [None]:
embeddings_df = 