In [1]:
import functools, gzip, itertools, json, math, os
import Bio.SeqIO, numpy as np, pandas as pd, snakemake, tqdm.contrib.concurrent

In [2]:
@functools.cache
def slurm_ntasks():
    """Return the number of worker processes to use for parallel maps.

    Reads the ``SLURM_NTASKS`` environment variable and converts it to an
    ``int``. When the variable is unset or does not parse as an integer,
    falls back to 4. The result is memoised by ``functools.cache``, so the
    environment is only consulted on the first call.
    """
    try:
        return int(os.environ['SLURM_NTASKS'])
    except (KeyError, ValueError):
        # KeyError: variable not set; ValueError: value is not an integer.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return 4  # nproc --all?

def tqdm_map(function, iterable):
    """Apply ``function`` to each item of ``iterable`` through tqdm's ``process_map``.

    The pool is sized with ``slurm_ntasks()`` and items are dispatched in
    chunks of 10. Returns whatever ``process_map`` returns.
    """
    pool_options = {'max_workers': slurm_ntasks(), 'chunksize': 10}
    return tqdm.contrib.concurrent.process_map(function, iterable, **pool_options)

def read_summary_confidences(path, cols=('fraction_disordered', 'has_clash', 'iptm', 'ptm', 'ranking_score', 'chain_iptm', 'chain_pair_iptm', 'chain_pair_pae_min', 'chain_ptm')):
    """Read selected fields from a gzip-compressed summary-confidences JSON file.

    Args:
        path: Path to a ``*.json.gz`` file whose top-level JSON value is an
            object containing every key listed in ``cols``.
        cols: Keys to extract, in the order they should appear in the result.
            The default is an immutable tuple so it cannot be accidentally
            shared and mutated between calls.

    Returns:
        A ``pd.Series`` indexed by ``cols`` holding the corresponding JSON
        values (scalars or nested lists, exactly as parsed).

    Raises:
        KeyError: If a requested key is missing from the JSON object.
    """
    cols = list(cols)
    # json.load accepts a binary handle and detects the encoding itself.
    with gzip.open(path, 'rb') as handle:
        js = json.load(handle)
    return pd.Series([js[col] for col in cols], index=cols)

def read_predictions(path='alphafold3_predictions', ids=None, include_all_models=True, seed=1):
    """Collect per-prediction file paths and summary confidences into one table.

    Args:
        path: Directory holding one sub-directory per prediction id.
        ids: Prediction ids to load. When ``None``, ids are discovered by
            globbing ``{id}/{id}_model.cif.gz`` under ``path``.
        include_all_models: When true, also load the summary confidences of
            samples 0-4 for the given ``seed``; their columns are prefixed
            ``model0_`` ... ``model4_``.
        seed: Seed number used in the ``seed-{seed}_sample-{k}`` directory names.

    Returns:
        A DataFrame with one row per id: the id, the model and summary paths,
        the top-level summary fields, and (optionally) the per-sample paths
        followed by the per-sample summary fields.
    """
    if ids is None:
        ids, = snakemake.io.glob_wildcards(os.path.join(path, '{id}/{id}_model.cif.gz'))

    table = pd.DataFrame({'id': ids})
    table['model'] = table['id'].map(
        lambda pred_id: os.path.join(path, f'{pred_id}/{pred_id}_model.cif.gz'))
    table['summary_confidences'] = table['id'].map(
        lambda pred_id: os.path.join(path, f'{pred_id}/{pred_id}_summary_confidences.json.gz'))
    top_level = pd.DataFrame.from_records(
        tqdm_map(read_summary_confidences, table['summary_confidences']))
    table = pd.concat([table, top_level], axis=1)

    if not include_all_models:
        return table

    sample_numbers = range(5)

    # First add all five path columns, so they sit together ahead of the
    # per-sample confidence columns in the final frame.
    for k in sample_numbers:
        table[f'model{k}_summary_confidences'] = table['id'].map(
            lambda pred_id, k=k: os.path.join(
                path, f'{pred_id}/seed-{seed}_sample-{k}/summary_confidences.json.gz'))

    # Then read each sample's file in turn and prefix its columns.
    per_sample_frames = []
    for k in sample_numbers:
        records = tqdm_map(read_summary_confidences, table[f'model{k}_summary_confidences'])
        per_sample_frames.append(pd.DataFrame.from_records(records).add_prefix(f'model{k}_'))

    all_samples = pd.concat(per_sample_frames, axis=1)
    return pd.concat([table, all_samples], axis=1)


def chain_pair_iptm_01(s):
    """Return element ``[0][1]`` of a chain-pair ipTM matrix.

    Example use:
        df_summary_['chain_pair_iptm'].map(af2genomics.alphafold3.chain_pair_iptm_01)

    Args:
        s: The matrix, either as JSON text (``str``/``bytes``/``bytearray``)
            or as an already-parsed nested sequence / array. Accepting both
            keeps this helper consistent with ``chain_pair_iptm_triu``.

    Returns:
        The value at row 0, column 1.
    """
    if isinstance(s, (str, bytes, bytearray)):
        s = json.loads(s)
    return s[0][1]

def chain_pair_iptm_triu(s):
    """Return the strict upper triangle of a square matrix as a flat array.

    Args:
        s: The matrix as JSON text (``str``/``bytes``/``bytearray``), a nested
            list/tuple, or anything ``np.asarray`` can turn into a 2-D array.

    Returns:
        A 1-D ``np.ndarray`` with the entries above the main diagonal
        (``k=1``), in row-major order.

    Raises:
        ValueError: If the input is not 2-dimensional (raised by
            ``np.triu_indices_from``).
    """
    if isinstance(s, (str, bytes, bytearray)):
        s = json.loads(s)
    # asarray returns ndarrays unchanged and converts lists, tuples and other
    # array-likes, so every input takes the same path below.
    arr = np.asarray(s)
    tri = np.triu_indices_from(arr, k=1)
    return arr[tri]

def parse_pair_iptms(df_summary):
    """Expand per-prediction chain-pair ipTM matrices into one row per chain pair.

    Adds to ``df_summary`` (in place) an ``ids`` column holding all 2-element
    combinations of the ``_``-separated parts of ``id``, and for each of
    model0..model4 a ``model{k}_chain_pair_iptm_triu`` column holding the
    strict upper triangle of ``model{k}_chain_pair_iptm``. Those columns are
    then exploded together so that each output row is one pair.

    Returns:
        The exploded frame with a fresh integer index and two extra columns,
        ``pair_iptm_mean`` and ``pair_iptm_max``, taken across the five
        per-model pair values.
    """
    triu_columns = [f'model{k}_chain_pair_iptm_triu' for k in range(5)]

    # Pairs are enumerated in the same order in which the upper triangle is
    # flattened, so the exploded columns line up element by element.
    df_summary['ids'] = df_summary['id'].map(
        lambda joined: list(itertools.combinations(joined.split('_'), 2)))
    for k, triu_column in enumerate(triu_columns):
        source_column = f'model{k}_chain_pair_iptm'
        df_summary[triu_column] = df_summary[source_column].map(chain_pair_iptm_triu)

    exploded = df_summary.explode(['ids'] + triu_columns)
    df_pairs = exploded.reset_index(drop=True)

    per_model_values = df_pairs[triu_columns]
    df_pairs['pair_iptm_mean'] = per_model_values.mean(axis=1)
    df_pairs['pair_iptm_max'] = per_model_values.max(axis=1)
    return df_pairs

In [3]:
# Directory containing one sub-directory per AlphaFold3 prediction id.
path = '/cluster/project/beltrao/jjaenes/25.06.03_batch-infer/projects-25.04/af3_mgen_pairs/alphafold3_predictions'
# Load paths and summary confidences (top-level plus samples 0-4) for every id found under `path`.
predictions = read_predictions(path)
# Bare expression: displays the frame as the notebook cell output.
predictions

  0%|          | 0/15304 [00:00<?, ?it/s]

  0%|          | 0/15304 [00:00<?, ?it/s]

  0%|          | 0/15304 [00:00<?, ?it/s]

  0%|          | 0/15304 [00:00<?, ?it/s]

  0%|          | 0/15304 [00:00<?, ?it/s]

  0%|          | 0/15304 [00:00<?, ?it/s]

Unnamed: 0,id,model,summary_confidences,fraction_disordered,has_clash,iptm,ptm,ranking_score,chain_iptm,chain_pair_iptm,...,model3_chain_ptm,model4_fraction_disordered,model4_has_clash,model4_iptm,model4_ptm,model4_ranking_score,model4_chain_iptm,model4_chain_pair_iptm,model4_chain_pair_pae_min,model4_chain_ptm
0,aac71280.2_aac71285.1_aac71638.1_aac72465.1_aa...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.10,0.0,0.25,0.34,0.32,"[0.08, 0.09, 0.12, 0.08, 0.08, 0.08, 0.07, 0.13]","[[0.79, 0.09, 0.08, 0.09, 0.06, 0.08, 0.06, 0....",...,"[0.79, 0.7, 0.69, 0.66, 0.7, 0.62, 0.27, 0.7]",0.10,0.0,0.24,0.33,0.31,"[0.08, 0.09, 0.12, 0.08, 0.07, 0.08, 0.07, 0.12]","[[0.79, 0.09, 0.08, 0.08, 0.08, 0.08, 0.05, 0....","[[0.76, 31.23, 30.63, 27.58, 29.1, 30.08, 31.1...","[0.79, 0.7, 0.69, 0.66, 0.7, 0.61, 0.26, 0.7]"
1,aac71305.1_aac71380.1_aac71562.1_aac71438.1,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.11,0.0,0.07,0.27,0.17,"[0.04, 0.04, 0.04, 0.04]","[[0.64, 0.04, 0.05, 0.04], [0.04, 0.54, 0.03, ...",...,"[0.65, 0.53, 0.63, 0.75]",0.10,0.0,0.07,0.27,0.16,"[0.04, 0.04, 0.04, 0.04]","[[0.66, 0.04, 0.04, 0.04], [0.04, 0.54, 0.03, ...","[[0.76, 30.38, 30.24, 30.07], [29.38, 0.76, 31...","[0.66, 0.54, 0.63, 0.76]"
2,aac71648.1_aac71311.1_aac71324.2_aac71410.1,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.08,0.0,0.18,0.69,0.32,"[0.11, 0.11, 0.09, 0.17]","[[0.64, 0.12, 0.05, 0.16], [0.12, 0.49, 0.05, ...",...,"[0.64, 0.48, 0.79, 0.83]",0.05,0.0,0.17,0.69,0.30,"[0.11, 0.11, 0.08, 0.16]","[[0.64, 0.12, 0.04, 0.16], [0.12, 0.5, 0.05, 0...","[[0.76, 21.77, 30.75, 29.86], [21.93, 0.76, 31...","[0.64, 0.5, 0.79, 0.83]"
3,aac72448.1_aac72470.2_aac71635.1,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.02,0.0,0.18,0.57,0.27,"[0.08, 0.14, 0.14]","[[0.73, 0.07, 0.08], [0.07, 0.85, 0.2], [0.08,...",...,"[0.73, 0.86, 0.91]",0.02,0.0,0.17,0.57,0.26,"[0.08, 0.13, 0.14]","[[0.74, 0.07, 0.08], [0.07, 0.86, 0.19], [0.08...","[[0.76, 28.52, 30.09], [28.36, 0.76, 23.24], [...","[0.74, 0.86, 0.91]"
4,aac71371.1_aac71299.1_aac71328.1_aac71648.1_aa...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.04,0.0,0.28,0.47,0.34,"[0.08, 0.08, 0.1, 0.08, 0.18, 0.08, 0.1, 0.09,...","[[0.6, 0.07, 0.05, 0.07, 0.17, 0.05, 0.07, 0.0...",...,"[0.6, 0.52, 0.67, 0.62, 0.78, 0.79, 0.42, 0.43...",0.04,0.0,0.28,0.47,0.34,"[0.08, 0.08, 0.1, 0.08, 0.17, 0.08, 0.1, 0.09,...","[[0.59, 0.06, 0.05, 0.07, 0.17, 0.05, 0.07, 0....","[[0.76, 26.51, 31.14, 24.68, 28.08, 31.21, 30....","[0.59, 0.51, 0.67, 0.62, 0.78, 0.79, 0.41, 0.4..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15299,aac71317.1_aac71425.1_aac71637.1_aac71506.1_aa...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.07,0.0,0.18,0.39,0.26,"[0.1, 0.06, 0.08, 0.05, 0.07, 0.08, 0.05]","[[0.84, 0.08, 0.18, 0.08, 0.08, 0.09, 0.08], [...",...,"[0.84, 0.75, 0.77, 0.21, 0.43, 0.68, 0.44]",0.07,0.0,0.18,0.39,0.26,"[0.1, 0.06, 0.08, 0.05, 0.07, 0.08, 0.05]","[[0.84, 0.08, 0.18, 0.08, 0.08, 0.09, 0.08], [...","[[0.76, 29.78, 20.16, 30.58, 30.21, 29.7, 29.3...","[0.84, 0.76, 0.77, 0.24, 0.43, 0.67, 0.43]"
15300,aac71219.1_aac71240.1_aac71584.1_aac71274.1_aa...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.07,0.0,0.23,0.31,0.28,"[0.1, 0.08, 0.08, 0.07, 0.08, 0.08, 0.08, 0.07...","[[0.53, 0.11, 0.11, 0.09, 0.1, 0.1, 0.11, 0.1,...",...,"[0.53, 0.66, 0.65, 0.49, 0.66, 0.78, 0.31, 0.5...",0.03,0.0,0.23,0.31,0.26,"[0.1, 0.08, 0.08, 0.07, 0.08, 0.08, 0.08, 0.08...","[[0.53, 0.12, 0.1, 0.09, 0.11, 0.1, 0.11, 0.1,...","[[0.76, 29.04, 30.25, 30.98, 31.12, 31.01, 30....","[0.53, 0.64, 0.66, 0.49, 0.66, 0.78, 0.31, 0.5..."
15301,aac71311.1_aac71522.2_aac71510.1_aac71634.2,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.04,0.0,0.16,0.42,0.23,"[0.1, 0.09, 0.12, 0.07]","[[0.59, 0.08, 0.18, 0.05], [0.08, 0.85, 0.1, 0...",...,"[0.49, 0.84, 0.79, 0.33]",0.03,0.0,0.16,0.42,0.23,"[0.1, 0.09, 0.13, 0.09]","[[0.57, 0.08, 0.17, 0.06], [0.08, 0.85, 0.1, 0...","[[0.76, 29.96, 21.08, 26.04], [29.81, 0.76, 29...","[0.57, 0.85, 0.79, 0.32]"
15302,aac71472.1_aac72465.1_aac71294.1_aac71558.1_aa...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,/cluster/project/beltrao/jjaenes/25.06.03_batc...,0.03,0.0,0.24,0.35,0.27,"[0.07, 0.07, 0.06, 0.06, 0.11, 0.07, 0.06, 0.1...","[[0.69, 0.07, 0.05, 0.06, 0.1, 0.07, 0.06, 0.1...",...,"[0.67, 0.63, 0.29, 0.65, 0.75, 0.64, 0.67, 0.6...",0.03,0.0,0.23,0.35,0.27,"[0.07, 0.07, 0.07, 0.07, 0.12, 0.07, 0.06, 0.1...","[[0.68, 0.07, 0.05, 0.06, 0.09, 0.07, 0.06, 0....","[[0.76, 30.24, 31.19, 30.5, 30.38, 31.38, 28.2...","[0.68, 0.63, 0.25, 0.65, 0.75, 0.62, 0.67, 0.6..."


In [4]:
# One row per chain pair; keep only the pair tuple and its mean ipTM.
# .copy() makes this an independent frame, so the column assignment below
# does not write into a slice of the exploded table (SettingWithCopy).
pairs = parse_pair_iptms(predictions)[['ids', 'pair_iptm_mean']].copy()
# Split each (id1, id2) tuple into two columns; pass the index explicitly so
# the assignment aligns row by row regardless of how `pairs` is indexed.
pairs[['pair_id1', 'pair_id2']] = pd.DataFrame(pairs['ids'].tolist(), index=pairs.index)
pairs = pairs[['pair_id1', 'pair_id2', 'pair_iptm_mean']]
pairs

Unnamed: 0,pair_id1,pair_id2,pair_iptm_mean
0,aac71280.2,aac71285.1,0.09
1,aac71280.2,aac71638.1,0.08
2,aac71280.2,aac72465.1,0.094
3,aac71280.2,aac71306.1,0.064
4,aac71280.2,aac71492.1,0.078
...,...,...,...
296120,aac71303.1,aac71635.1,0.09
296121,aac71376.1,aac71628.1,0.066
296122,aac71376.1,aac71635.1,0.076
296123,aac71628.1,aac71635.1,0.088


In [5]:
def read_proteins(path='/cluster/work/beltrao/jjaenes/25.04.02_batch-infer-projects/af3_mgen/_25.04.04/mgen_proteins.txt'):
    """Read a protein FASTA file into a table of ids, locus tags and sequences.

    Each record's description is split on whitespace and scanned for the
    ``[protein_id=...]`` and ``[locus_tag=...]`` tokens.

    Args:
        path: FASTA file to read. Defaults to the location the notebook was
            written against.

    Returns:
        A DataFrame with columns ``protein_id``, ``af3_id`` (lower-cased
        ``protein_id``), ``locus_tag``, ``seq`` and ``seq_len``. A field that
        is absent from a record's description is ``None``.
    """
    def get_field_(description, key):
        # str.lstrip/rstrip take a *set of characters*, not a prefix/suffix, so
        # lstrip('[protein_id=') would also eat leading value characters such
        # as 'p', 'r', 'o', 't', ... Use removeprefix/removesuffix instead.
        prefix = f'[{key}='
        for rec in description.split():
            if rec.startswith(prefix):
                return rec.removeprefix(prefix).removesuffix(']')
        return None

    def read_(file, stop=None):
        # `stop` caps the number of records read (None reads all of them).
        with open(file, 'r') as handle:
            for record in itertools.islice(Bio.SeqIO.parse(handle, 'fasta'), stop):
                yield (
                    get_field_(record.description, 'protein_id'),
                    get_field_(record.description, 'locus_tag'),
                    str(record.seq),
                )

    df_len_ = pd.DataFrame.from_records(read_(path), columns=['protein_id', 'locus_tag', 'seq'])
    df_len_['af3_id'] = df_len_['protein_id'].str.lower()
    df_len_ = df_len_[['protein_id', 'af3_id', 'locus_tag', 'seq']]
    df_len_['seq_len'] = df_len_['seq'].str.len()
    return df_len_

# Parse the protein table once and reuse it for both members of each pair
# (calling read_proteins() twice re-reads the whole FASTA file).
protein_lengths_ = read_proteins()[['af3_id', 'seq_len']]
# Default inner merges: a pair is kept only if both ids are found in the protein table.
pairs = pairs\
    .merge(protein_lengths_.rename({'af3_id': 'pair_id1', 'seq_len': 'seq1_len'}, axis=1), on='pair_id1')\
    .merge(protein_lengths_.rename({'af3_id': 'pair_id2', 'seq_len': 'seq2_len'}, axis=1), on='pair_id2')
# Combined length of the two sequences in the pair.
pairs['pair_len'] = pairs[['seq1_len', 'seq2_len']].sum(axis=1)
pairs

Unnamed: 0,pair_id1,pair_id2,pair_iptm_mean,seq1_len,seq2_len,pair_len
0,aac71280.2,aac71285.1,0.09,303,516,819
1,aac71280.2,aac71638.1,0.08,303,329,632
2,aac71280.2,aac72465.1,0.094,303,231,534
3,aac71280.2,aac71306.1,0.064,303,155,458
4,aac71280.2,aac71492.1,0.078,303,336,639
...,...,...,...,...,...,...
296120,aac71303.1,aac71635.1,0.09,311,458,769
296121,aac71376.1,aac71628.1,0.066,138,279,417
296122,aac71376.1,aac71635.1,0.076,138,458,596
296123,aac71628.1,aac71635.1,0.088,279,458,737


In [6]:
class SizeCorrection:
    """Remove a size-dependent baseline from scores.

    The baseline is a straight line in ``sqrt(tokens)``, fitted to a centred
    rolling minimum of the scores after sorting the observations by
    ``tokens``. ``transform`` subtracts that line from new scores.

    Attributes set by ``__init__``:
        data: The sorted observations that have a defined rolling minimum
            (columns ``tokens``, ``scores``, ``tokens_sqrt``, ``scores_min``).
        predict_baseline: ``np.poly1d`` mapping ``sqrt(tokens)`` to the
            fitted baseline score.
    """

    def __init__(self, tokens, scores, window=51):
        """Fit the baseline.

        Args:
            tokens: Sizes of the observations (non-negative numbers).
            scores: Scores of the observations, aligned with ``tokens``.
            window: Width of the centred rolling-minimum window. Rows closer
                than half a window to either end have no rolling minimum and
                are excluded from the fit.

        Raises:
            ValueError: If no row has a defined rolling minimum (for example
                fewer than ``window`` observations), so nothing can be fitted.
        """
        data_ = pd.DataFrame({'tokens': tokens, 'scores': scores}).sort_values('tokens').reset_index(drop=True)
        data_['tokens_sqrt'] = data_['tokens'].map(math.sqrt)
        data_['scores_min'] = data_['scores'].rolling(window=window, center=True).min()
        self.data = data_.dropna(subset=['scores_min']).reset_index(drop=True)

        # Diagnostics: number of non-missing scores, then number of rows used for the fit.
        print(int(data_['scores'].notna().sum()))
        print(len(self.data))

        if self.data.empty:
            raise ValueError(
                f'no rows left to fit: need at least window={window} observations '
                f'with non-missing scores, got {len(data_)} rows'
            )

        self.predict_baseline = np.poly1d(np.polyfit(x=self.data.tokens_sqrt, y=self.data.scores_min, deg=1))
        print(self.predict_baseline)

    def _with_baseline(self, tokens, scores):
        # Shared by transform() and transform_plot(): tabulate the inputs with
        # the fitted baseline and the baseline-subtracted score.
        data_ = pd.DataFrame({'tokens': tokens, 'scores': scores})
        data_['tokens_sqrt'] = data_['tokens'].map(math.sqrt)
        data_['baseline'] = self.predict_baseline(data_.tokens_sqrt)
        data_['scores_corrected'] = data_['scores'] - data_['baseline']
        return data_

    def transform(self, tokens, scores):
        """Return ``scores`` minus the fitted baseline at ``tokens``, as a Series."""
        return self._with_baseline(tokens, scores)['scores_corrected']

    def transform_plot(self, tokens, scores):
        """Scatter ``scores`` against ``tokens`` and overlay the fitted baseline in red."""
        # `sns` is not imported anywhere in the visible part of this file, so
        # import it here to avoid a NameError when the method is called.
        import seaborn as sns

        data_ = self._with_baseline(tokens, scores)
        sns.scatterplot(data=data_, x='tokens', y='scores')
        sns.lineplot(data=data_, x='tokens', y='baseline', color='tab:red')

# Fit the size baseline on all pairs: combined pair length vs. mean pair ipTM.
size_correction_ = SizeCorrection(pairs.pair_len, pairs.pair_iptm_mean)
# Subtract the fitted baseline from each pair's mean ipTM.
pairs['pair_iptm_mean_corrected'] = size_correction_.transform(pairs.pair_len, pairs.pair_iptm_mean)
# Bare expression: displays the frame as the notebook cell output.
pairs

296125
296075
 
0.00437 x - 0.03651


Unnamed: 0,pair_id1,pair_id2,pair_iptm_mean,seq1_len,seq2_len,pair_len,pair_iptm_mean_corrected
0,aac71280.2,aac71285.1,0.09,303,516,819,0.00144
1,aac71280.2,aac71638.1,0.08,303,329,632,0.006643
2,aac71280.2,aac72465.1,0.094,303,231,534,0.02952
3,aac71280.2,aac71306.1,0.064,303,155,458,0.006982
4,aac71280.2,aac71492.1,0.078,303,336,639,0.004036
...,...,...,...,...,...,...,...
296120,aac71303.1,aac71635.1,0.09,311,458,769,0.005318
296121,aac71376.1,aac71628.1,0.066,138,279,417,0.013266
296122,aac71376.1,aac71635.1,0.076,138,458,596,0.005818
296123,aac71628.1,aac71635.1,0.088,279,458,737,0.005866
