In [1]:
import pandas as pd
import pickle
import numpy as np
import re
import random
import time
import datetime as dt
from itertools import combinations
from itertools import accumulate
from collections import defaultdict

In [2]:
def match_label(pair, pos_ppi_dict, neg_ppi_dict):
    if pair in pos_ppi_dict.values():
        return(1)
    elif pair in neg_ppi_dict.keys():
        return(-1)
    else:
        return(0)

In [3]:
def match_group(pair, pos_ppi_dict, neg_ppi_dict):
    group_list = []
    for k, v in pos_ppi_dict.items():
        if pair in v:
            group_list.append(k)
    if len(group_list) > 1:
        return(group_list)
    elif len(group_list) == 1:
        return(group_list[0])
    elif pair in neg_ppi_dict.keys():
        return(neg_ppi_dict.get(pair))
    else:
        return(np.nan)

In [4]:
def match_spr_grp(group, super_grp_dict):
    if type(group) == list:
        return(super_grp_dict.get(group[0]))
    else:
        return(super_grp_dict.get(group))

In [5]:
def merge_groups(lsts):
    sets = [set(lst) for lst in lsts if lst]
    merged = True
    while merged:
        merged = False
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = True
                    common |= x
            results.append(common)
        sets = results
    return(sets)

In [6]:
def make_sprgrp_dict(labeled_fmat):
    mgroup_list = []
    for i in range(len(labeled_fmat)):
        group = labeled_fmat['group'][i]
        if type(group) == list:
            mgroup_list.append(group)
    merged = merge_groups(mgroup_list)
    super_grp_num = 1
    super_grp_dict = dict()
    for group in merged:
        for old_gnum in group:
            super_grp_dict[old_gnum] = super_grp_num
        super_grp_num += 1
    return(super_grp_dict)

In [7]:
def make_pos_dict(gs_file):
    print(f'[{dt.datetime.now()}] Generating grouped positive PPI labels from gold standard complexes ...')
    pos_ppi_dict = dict()
    group_no = 1
    dupes = []
    with open(gold_std_file, 'r') as f:
        ppis = f.read().splitlines() 
        for p in ppis:
            ogs = p.split(' ')
            fsets = [frozenset({i, j}) for i,j in list(combinations(ogs, 2))]
            pos_ppi_dict.update({group_no: fsets})
            group_no += 1
    print(f'[{dt.datetime.now()}] Finished generating positive PPI labels!')
    return(pos_ppi_dict)

In [8]:
def make_neg_dict(pos_dict):
    # get random prots from positive PPIs to make negative PPIs
    print(f'[{dt.datetime.now()}] Getting random proteins from positive PPIs to generate negative PPIs ...')
    random_prots = set()
    for group_no, fsets in pos_dict.items():
        prot_set = set()
        if len(fsets) > 1:
            for pair in fsets:
                prot_set.add(list(pair)[0])
                prot_set.add(list(pair)[1])
            neg_prots = random.sample(list(prot_set), 3)
            for p in neg_prots:
                random_prots.add(p)
    print(f'[{dt.datetime.now()}] Generating negative PPIs ...')
    neg_ppis = [frozenset({i, j}) for i,j in list(combinations(random_prots, 2))]

    # remove any overlap between neg & pos PPIs
    print(f'[{dt.datetime.now()}] Removing overlap between random negative PPIs & positive PPIs ...')
    t0 = time.time()
    all_pos_ppis = list(pos_dict.values())
    flat_pos_ppis = [pair for pair_list in all_pos_ppis for pair in pair_list]
    overlap = set(neg_ppis).intersection(set(flat_pos_ppis))
    overlap_count = len(overlap)
    for i in overlap:
        neg_ppis.remove(i)
    print(f'[{dt.datetime.now()}] # overlapping negative PPIs found & removed = {overlap_count}; total time: {time.time() - t0} seconds)')
    
    # get negative PPI splits
    num_groups = len(pos_dict)
    print(f'[{dt.datetime.now()}] Randomly splitting negative PPIs into {num_groups} groups ...')
    neg_cmplx_sizes = [random.randint(2, 30) for x in range(num_groups)]
    neg_ppi_grouped = [neg_ppis[x - y: x] for x, y in zip(
            accumulate(neg_cmplx_sizes), neg_cmplx_sizes)]
    print(f'[{dt.datetime.now()}] # of negative PPI groups =', len(neg_ppi_grouped))
    
    # get negative PPI groups
    print(f'[{dt.datetime.now()}] Generating grouped negative PPI labels ...')
    neg_ppi_dict = dict()
    group_sizes = []
    group_no = 1
    for group in neg_ppi_grouped:
        group_sizes.append(len(group))
        for pair in group:
            neg_ppi_dict.update({pair: group_no})
        group_no += 1
    
    print(f'[{dt.datetime.now()}] Finished generating negative PPI labels!')
    return(neg_ppi_dict)

In [9]:
def label_fmat(fmat_file, pos_dict, neg_dict):
    print(f'[{dt.datetime.now()}] Loading features from {fmat_file}...')
    with open(data_dir+fmat_file, 'rb') as handle:
        fmat = pickle.load(handle)

    print(f'[{dt.datetime.now()}] Formatting feature matrix ID columns & rows ...')
    fmat[['ID1','ID2']] = fmat['ID'].str.split(' ',expand=True)
    fmat = fmat[fmat['ID2'].notna()]
    
    t0 = time.time()
    # TODO: long step; potentially optimize ..?
    print(f'[{dt.datetime.now()}] Labeling feature matrix (takes awhile)...')
    fmat['label'] = [match_label(frozenset({i, j}), pos_dict, neg_dict) for i, j in zip(fmat['ID1'], fmat['ID2'])]
    fmat['group'] = [match_group(frozenset({i, j}), pos_dict, neg_dict) for i, j in zip(fmat['ID1'], fmat['ID2'])]
    print(f'[{dt.datetime.now()}] Total time to label {len(fmat)} rows: {time.time() - t0} seconds')
    
    num_pos = len(fmat[(fmat['label'] == 1)])
    num_neg = len(fmat[(fmat['label'] == -1)])
    print(f'[{dt.datetime.now()}] Total # positive PPIs = {num_pos}')
    print(f'[{dt.datetime.now()}] Total # negative PPIs = {num_neg}')
    return(fmat)

In [10]:
def label_fmat_supergrps(labeled_fmat):
    print(f'[{dt.datetime.now()}] Generating merged complex groups ...')
    sdict = make_sprgrp_dict(labeled_fmat)
    print(f'[{dt.datetime.now()}] Labeling non-redundant complex groups ...')
    labeled_fmat['super_group'] = [match_spr_grp(i, sdict) for i in labeled_fmat['group']]
    return(labeled_fmat)

In [11]:
def format_fmat(labeled_fmat, drop_overlap_groups=False, shuffle_feats=False, shuffle_rows=False):
    # get col names for labels, features
    print(f'[{dt.datetime.now()}] Reformatting columns ...')
    label_cols = ['ID', 'group', 'super_group', 'label']
    feature_cols = [c for c in labeled_fmat.columns.values.tolist() if c not in label_cols]
    # optionally shuffle feature order
    if shuffle_feats:
        print(f'[{dt.datetime.now()}] Shuffling feature columns ...')
        random.shuffle(feature_cols)
    # reorder columns
    fmat_fmt = labeled_fmat[label_cols + feature_cols]
    # optionally drop group_col with redundant PPIs
    # probably always want to do this tbh
    if drop_overlap_groups:
        print(f'[{dt.datetime.now()}] Dropping redundant complex groups ...')
        fmat_fmt = fmat_fmt.drop(['group'], axis=1)
    # optionally shuffle row order;
    # --> technically will be done later w/ sklearn.model_selection.GroupShuffleSplit
    # --> but it's here if you want to shuffle at this step for some reason
    if shuffle_rows:
        print(f'[{dt.datetime.now()}] Shuffling non-redundant protein complex super groups ...')
        grps = fmat_fmt['super_group'].unique()
        random.shuffle(grps)
        fmat_fmt = fmat_fmt.set_index('super_group').loc[grps].reset_index()
    print('Final feature matrix:')
    print(fmat_fmt.head())
    return(fmat_fmt)

In [12]:
def write_fmat_files(labeled_fmat, fmat_file, outfile=None):
    # format outfile path/name if none specified
    if not outfile:
        plist = fmat_file.split('/', 1)
        outpath = '/'.join(plist[:-1])+'/'
        outfile = outpath+'labeled_featmat'
    
    # write out matrices
    t0 = time.time()
    print(f'[{dt.datetime.now()}] Writing out matrices:')
    print(f"[{dt.datetime.now()}] \t► Full matrix (labeled + unlabeled) --> {outfile}")
    print(f"[{dt.datetime.now()}] \t► Positive & negative PPIs --> {outfile+'_traintest'}")
    print(f"[{dt.datetime.now()}] \t► Gold standard (positive) PPIs only --> {outfile+'_goldstd'}")
    
    # gold standard (positive/known) PPIs only
    t1 = time.time()
    print(f'[{dt.datetime.now()}] Extracting gold standard PPIs ...')
    goldstd = labeled_fmat[(labeled_fmat['label'] == 1)]
    goldstd.reset_index(drop=True, inplace=True)
    print(f"[{dt.datetime.now()}] Writing serialized gold standard matrix to {outfile+'_goldstd.pkl'} ... ")
    goldstd.to_pickle(outfile+'_goldstd.pkl')
    print(f"[{dt.datetime.now()}] Writing comma-separated gold standard matrix to {outfile+'_goldstd'} ... ")
    goldstd.to_csv(outfile+'_goldstd', index=False)
    print(f"[{dt.datetime.now()}] Total time to write gold standard feature matrix of shape {goldstd.shape}: {time.time() - t1} seconds")
    
    # positive + negative PPIs only
    t2 = time.time()
    print(f'[{dt.datetime.now()}] Extracting train/test rows ...')
    traintest = labeled_fmat[(labeled_fmat['label'] == 1) | (labeled_fmat['label'] == -1)]
    traintest.reset_index(drop=True, inplace=True)
    print(f"[{dt.datetime.now()}] Writing serialized train/test matrix to {outfile+'_traintest.pkl'} ... ")
    traintest.to_pickle(outfile+'_traintest.pkl')
    print(f"[{dt.datetime.now()}] Writing comma-separated train/test matrix to {outfile+'_traintest'} ... ")
    traintest.to_csv(outfile+'_traintest', index=False)
    print(f"[{dt.datetime.now()}] Total time to write train/test feature matrix of shape {traintest.shape}: {time.time() - t2} seconds")
    
    # all data, labeled & unlabeled
    t3 = time.time()
    print(f"[{dt.datetime.now()}] Writing full serialized matrix to {outfile+'.pkl'} ... ")
    labeled_fmat.to_pickle(outfile+'.pkl')
    print(f"[{dt.datetime.now()}] Writing full comma-separated matrix to {outfile} ... ")
    labeled_fmat.to_csv(outfile, index=False)
    print(f"[{dt.datetime.now()}] Total time to write full feature matrix of shape {labeled_fmat.shape}: {time.time() - t3} seconds")
    
    print(f"[{dt.datetime.now()}] ---------------------------------------------------------")
    print(f"[{dt.datetime.now()}] Total time to write all files: {time.time() - t0} seconds")
    print(f"[{dt.datetime.now()}] ---------------------------------------------------------")

## Input vars

In [13]:
fmat_file = '../ppi_ml/data/featmats/featmat.pkl'
outfile = '../ppi_ml/data/featmats/featmat_labeled'
gold_std_file = '../ppi_ml/data/gold_stds/all.gold.cmplx.noRibos.merged.txt'
seed = 13

## Run script

In [14]:
random.seed(seed)
pos_dict = make_pos_dict(gold_std_file)
neg_dict = make_neg_dict(pos_dict)

[2023-02-07 14:53:33.576621] Generating grouped positive PPI labels from gold standard complexes ...
[2023-02-07 14:53:33.640402] Finished generating positive PPI labels!
[2023-02-07 14:53:33.640601] Getting random proteins from positive PPIs to generate negative PPIs ...
[2023-02-07 14:53:33.652776] Generating negative PPIs ...
[2023-02-07 14:53:35.963242] Removing overlap between random negative PPIs & positive PPIs ...
[2023-02-07 14:56:12.146275] # overlapping negative PPIs found & removed = 6957; total time: 156.18288898468018 seconds)
[2023-02-07 14:56:12.146518] Randomly splitting negative PPIs into 1499 groups ...
[2023-02-07 14:56:12.150453] # of negative PPI groups = 1499
[2023-02-07 14:56:12.150475] Generating grouped negative PPI labels ...
[2023-02-07 14:56:12.157232] Finished generating negative PPI labels!


In [15]:
labeled_fmat = label_fmat(fmat_file, pos_dict, neg_dict)
labeled_fmat_final = label_fmat_supergrps(labeled_fmat)

[2023-02-07 14:56:12.252535] Loading features from ../ppi_ml/data/featmats/featmat.pkl...


NameError: name 'data_dir' is not defined

In [None]:
fmat_out = format_fmat(labeled_fmat_final, drop_overlap_groups=True, shuffle_feats=True)

In [None]:
write_fmat_files(fmat_out, fmat_file, outfile)

## Checks & balances below here

In [None]:
index_list = []
multi_count = 0
for i in range(len(fmat_out)):
    group = fmat_out['super_group'][i]
    if type(group) == list:
        index_list.append(i)
        multi_count += 1

In [None]:
multi_count

In [None]:
pos_check = random.choice(list(pos_dict.values()))
neg_check = random.choice(list(neg_dict.keys()))
print(pos_check)
print(neg_check)

In [None]:
len(fmat_out[fmat_out['label'] == 1])

In [None]:
len(fmat_out[fmat_out['label'] == -1])