In [6]:
from pathlib import Path
import gzip
from datetime import datetime
import logging
import pickle
import numpy as np

# Print arrays with 4 digits of float precision, 20 leading/trailing items per
# axis before the middle is elided with "...", and lines up to 180 characters.
np.set_printoptions(precision=4, edgeitems=20, linewidth=180)

# Emit INFO-level (and more severe) log records from the cells below.
logging.basicConfig(level=logging.INFO)

from deepfold.data.search.parsers import parse_hmmsearch_sto, parse_fasta, parse_a3m, convert_stockholm_to_a3m, parse_hhr
from deepfold.data.search.input_features import create_sequence_features, create_msa_features, create_template_features
from deepfold.data.search.templates import TemplateHitFeaturizer
from deepfold.data.search.crfalign import parse_pir

ModuleNotFoundError: No module named 'deepfold.data.search.pipeline'

In [2]:
# Load the target FASTA and keep its first record as the query sequence.
fasta_str = Path("out/T1210.fasta").read_text()
sequences, descriptions = parse_fasta(fasta_str)
query_sequence = sequences[0]

# Bare last expression so the notebook echoes the query sequence.
query_sequence

'MLLLLTLLLFAGTVAADFQHNWQVGNEYTYLVRSRTLTSLGDLSDVHTGILIKALLTVQAKDSNVLAAKVWNGQYARVQQSMPDGWETEISDQMLELRDLPISGKPFQIRMKHGLIRDLIVDRDVPTWEVNILKSIVGQLQVDTQGENAVKVNSVQVPTDDEPYASFKAMEDSVGGKCEVLYDIAPLSDFVIHRSPELVPMPTLKGDGRHMEVIKIKNFDNCDQRINYHFGMTDNSRLEPGTNKNGKFFSRSSTSRIVISESLKHFTIQSSVTTSKMMVSPRLYDRQNGLVLSRMNLTLAKMEKTSKPLPMVDNPESTGNLVYIYNNPFSDVEERRVSKTAMNSNQIVSDNSLSSSEEKLKQDILNLRTDISSSSSSISSSEENDFWQPKPTLEDAPQNSLLPNFVGYKGKHIGKSGKVDVINAAKELIFQIANELEDASNIPVHATLEKFMILCNLMRTMNRKQISELESNMQISPNELKPNDKSQVIKQNTWTVFRDAITQTGTGPAFLTIKEWIERGTTKSMEAANIMSKLPKTVRTPTDSYIRSFFELLQNPKVSNEQFLNTAATLSFCEMIHNAQVNKRSIHNNYPVHTFGRLTSKHDNSLYDEYIPFLERELRKAHQEKDSPRIQTYIMALGMIGEPKILSVFEPYLEGKQQMTVFQRTLMVGSLGKLTETNPKLARSVLYKIYLNTMESHEVRCTAVFLLMKTNPPLSMLQRMAEFTKLDTNRQVNSAVKSTIQSLMKLKSPEWKDLAKKARSVNHLLTHHEYDYELSRGYIDEKILENQNIITHMILNYVGSEDSVIPRILYLTWYSSNGDIKVPSTKVLAMISSVKSFMELSLRSVKDRETIISAAEKIAEELKIVPEELVPLEGNLMINNKYALKFFPFDKHILDKLPTLISNYIEAVKEGKFMNVNMLDTYESVHSFPTETGLPFVYTFNVIKLTKTSGTVQAQINPDFAFIVNSNLRLTFSKNVQGRVGFVTPFEHRHFISGIDSNL

In [3]:
# Alternative hit source (hmmsearch Stockholm output), currently disabled:
# with open("out/T1207/hmm_output.sto", "r") as fp:
#     sto_str = fp.read()
# hits = parse_hmmsearch_sto(query_sequence, sto_str)

# Parse the .hhr hit file into a list of template hits.
hrr_str = Path("out/msas/pdb70_hits.hhr").read_text()
hits = parse_hhr(hrr_str)

# Report how many hits were parsed, then dump every attribute of the first one.
print(len(hits))
first_hit = hits[0]
for field, value in vars(first_hit).items():
    print(field, value)

367
index 1
name 1LSH_A LIPOVITELLIN (LV-1N, LV-1C); LIPOVITELLIN, VITELLOGENIN, LIPOPROTEIN, PLASMA APOLIPOPROTE APOLIPOPROTEIN B, APOB, MICROSOMAL TRIGLYCERIDE TRANSFER PR BOUNDARY LIPID, PHOSPHOLIPID STRUCTURE; HET: PLD, UPL; 1.9A {Ichthyomyzon unicuspis} SCOP: f.7.1.1, a.118.4.1
aligned_cols 866
sum_probs 704.8
query NWQVGNEYTYLVRSRTLTSLGDLSDVHTGILIKALLTVQAKDSNVLAAKVWNGQYARVQQSMPDGWETEISDQMLELRDLPISGKPFQIRMKHGLIRDLIVDRDVPTWEVNILKSIVGQLQVDTQGENAVKVNSVQVPTDDEPYASFKAMEDSVGGKCEVLYDIAPLSDFVIHRSPELVPMPTLKGDGRHMEVIKIKNFDNCDQRINYHFGMTDNSRLEPGTNKNGKFFSRSSTSRIVISESLKHFTIQSSVTTSKMMVSPRLY-DRQNGLVLSRMNLTLAKMEKTSKPLPMVDNPESTGNLVYIYNNPFSDVEERRVSKTAMNSNQIVSDNSLSSSEEKLKQDILNLRTDISSSSSSISSSEENDFWQPKPTLEDAPQNSLLPNFVGYKGKHIGKSGKVDVINAAKELIFQIANELEDASNIPVHATLEKFMILCNLMRTMNRKQISELESNMQISPNELKPNDKSQVIKQNTWTVFRDAITQTGTGPAFLTIKEWIERGTTKSMEAANIMSKLPKTVRTPTDSYIRSFFELLQNPKVSNEQFLNTAATLSFCEMIHNAQVNKRSIHNNYPVHTFGRLTSKHDNSLYDEYIPFLERELRKAHQEKDSPRIQTYIMALGMIGEPKILSVFEPYLEGK----QQMTVFQRTLMVGSLGKLTETNPKLARSVLYKIYLN

In [4]:
# Build the template featurizer that is passed to create_template_features in
# the following cells.
# NOTE(review): both paths are absolute, site-specific locations; the notebook
# only runs where this /gpfs layout exists.
featurizer = TemplateHitFeaturizer(
    max_template_hits=20,  # presumably caps the number of featurized templates — TODO confirm
    pdb_mmcif_dirpath=Path("/gpfs/database/casp16/pdb/mmCIF"),
    template_pdb_chain_ids=None,
    # pdb_release_dates={},
    pdb_obsolete_filepath=Path("/gpfs/database/casp16/pdb/obsolete.dat"),
    shuffle_top_k_prefiltered=None,
    verbose=True,
)

In [5]:
# Featurize the first 50 parsed hits as templates for the query sequence.
# max_release_date is given as an ISO "YYYY-MM-DD" string; presumably hits
# whose structure was released after it are rejected — confirm in
# deepfold.data.search.templates.
template_features = create_template_features(
    sequence=query_sequence,
    template_hits=hits[0:50],
    template_hit_featurizer=featurizer,
    max_release_date="2022-04-30",
)
print(template_features.keys())
print(len(template_features["template_aatype"]))
print(*template_features["template_domain_names"])

INFO:deepfold.data.search.templates:_prefilter_template_hits: query_pdb_id=None num_template_hits=50 num_accepted_hits=10 num_rejected_hits=40 rejection_messages_stats=Counter({'align_ratio <= min_align_ratio': 40})
ERROR:deepfold.data.search.templates:DateError: Date 2023-11-29 > max template date 2022-04-30; name=8D00_A Microtubule-associated protein TORTIFOLIA1; TOG, HEAT, Alpha-solenoid, Microtubule-binding, STRUCTURAL PROTEIN; 2.8A {Arabidopsis thaliana}
INFO:deepfold.data.search.templates:get_template_features: num_featurized_templates=9 template_domain_names=['1lsh_A', '6i7s_H', '1lsh_B', '4d50_A', '1b3u_B', '4p6z_G', '7egh_B', '7p3x_A', '6qh5_B'] errors: ['DateError: Date 2023-11-29 > max template date 2022-04-30; name=8D00_A Microtubule-associated protein TORTIFOLIA1; TOG, HEAT, Alpha-solenoid, Microtubule-binding, STRUCTURAL PROTEIN; 2.8A {Arabidopsis thaliana}']


dict_keys(['template_domain_names', 'template_sequence', 'template_aatype', 'template_all_atom_positions', 'template_all_atom_mask', 'template_sum_probs'])
9
1lsh_A 6i7s_H 1lsh_B 4d50_A 1b3u_B 4p6z_G 7egh_B 7p3x_A 6qh5_B


In [6]:
# Build the sequence features for the query; "T1210" is passed as the second
# argument (presumably the target/description label — TODO confirm).
sequence_features = create_sequence_features(query_sequence, "T1210")
# The first "sequence" entry is decoded back to str for printing.
seq = sequence_features["sequence"][0].decode()
print(seq)

MLLLLTLLLFAGTVAADFQHNWQVGNEYTYLVRSRTLTSLGDLSDVHTGILIKALLTVQAKDSNVLAAKVWNGQYARVQQSMPDGWETEISDQMLELRDLPISGKPFQIRMKHGLIRDLIVDRDVPTWEVNILKSIVGQLQVDTQGENAVKVNSVQVPTDDEPYASFKAMEDSVGGKCEVLYDIAPLSDFVIHRSPELVPMPTLKGDGRHMEVIKIKNFDNCDQRINYHFGMTDNSRLEPGTNKNGKFFSRSSTSRIVISESLKHFTIQSSVTTSKMMVSPRLYDRQNGLVLSRMNLTLAKMEKTSKPLPMVDNPESTGNLVYIYNNPFSDVEERRVSKTAMNSNQIVSDNSLSSSEEKLKQDILNLRTDISSSSSSISSSEENDFWQPKPTLEDAPQNSLLPNFVGYKGKHIGKSGKVDVINAAKELIFQIANELEDASNIPVHATLEKFMILCNLMRTMNRKQISELESNMQISPNELKPNDKSQVIKQNTWTVFRDAITQTGTGPAFLTIKEWIERGTTKSMEAANIMSKLPKTVRTPTDSYIRSFFELLQNPKVSNEQFLNTAATLSFCEMIHNAQVNKRSIHNNYPVHTFGRLTSKHDNSLYDEYIPFLERELRKAHQEKDSPRIQTYIMALGMIGEPKILSVFEPYLEGKQQMTVFQRTLMVGSLGKLTETNPKLARSVLYKIYLNTMESHEVRCTAVFLLMKTNPPLSMLQRMAEFTKLDTNRQVNSAVKSTIQSLMKLKSPEWKDLAKKARSVNHLLTHHEYDYELSRGYIDEKILENQNIITHMILNYVGSEDSVIPRILYLTWYSSNGDIKVPSTKVLAMISSVKSFMELSLRSVKDRETIISAAEKIAEELKIVPEELVPLEGNLMINNKYALKFFPFDKHILDKLPTLISNYIEAVKEGKFMNVNMLDTYESVHSFPTETGLPFVYTFNVIKLTKTSGTVQAQINPDFAFIVNSNLRLTFSKNVQGRVGFVTPFEHRHFISGIDSNLH

# CRF

In [7]:
# Keep the first 20 chain identifiers listed for this target.
lst_crf = Path("/gpfs/deepfold/casp/casp16/T1210/list.crf").read_text().strip().splitlines()[:20]

# Parse one PIR alignment per listed chain; hit indices are 1-based positions
# in the list.
crf_hits = []
for rank, chain in enumerate(lst_crf, 1):
    pir_text = Path(f"/gpfs/deepfold/casp/casp16/T1210/ali/crf/T1210-{chain}.pir").read_text()
    crf_hits.append(parse_pir(pir_text, index=rank))

print(*lst_crf)

1lsh_A 1lsh_B 6i7s_H 4d50_A 8eoj_B 7pp6_D 8d3d_A 4d50_A 5mpc_N 7zwh_E 8amz_N 6msh_U 7qcn_B 5vhf_U 4p6z_G 5nzr_B 5mu7_A 7vmq_A 6msb_U 7vmo_A


In [25]:
# Featurize the PIR-derived hits with the same featurizer and release-date
# cutoff as the HHR-derived hits, so the two feature sets can be compared.
crf_features = create_template_features(
    sequence=query_sequence,
    template_hits=crf_hits,
    template_hit_featurizer=featurizer,
    max_release_date="2022-04-30",
)
print(*crf_features["template_domain_names"])

INFO:deepfold.data.search.templates:_prefilter_template_hits: query_pdb_id=None num_template_hits=20 num_accepted_hits=20 num_rejected_hits=0 rejection_messages_stats=Counter()
ERROR:deepfold.data.search.templates:DateError: Date 2023-05-03 > max template date 2022-04-30; name=8eoj_B                                          
ERROR:deepfold.data.search.templates:DateError: Date 2022-07-06 > max template date 2022-04-30; name=8d3d_A                                          
ERROR:deepfold.data.search.templates:DateError: Date 2022-08-24 > max template date 2022-04-30; name=8amz_N                                          
ERROR:deepfold.data.search.templates:DateError: Date 2022-07-13 > max template date 2022-04-30; name=7zwh_E                                          
ERROR:deepfold.data.search.templates:DateError: Date 2022-08-10 > max template date 2022-04-30; name=7vmq_A                                          
ERROR:deepfold.data.search.templates:DateError: Date 2022-08-10 > max tem

1lsh_A 6i7s_H 7pp6_D 1lsh_B 5mpc_N 4d50_A 4d50_A 6msh_U 5vhf_U 4p6z_G 5nzr_B 6msb_U 5mu7_A


In [29]:
# Select a single template from each feature set by position: entry 5 of every
# CRF feature and entry 3 of every HHR-derived feature.
# NOTE(review): the indices are hard-coded and depend on the order in which
# templates were featurized; they appear to be chosen so that both refer to
# the same template, which the print of the two domain names checks.
c1 = {k: v[5] for k, v in crf_features.items()}
h1 = {k: v[3] for k, v in template_features.items()}
print(c1["template_domain_names"], h1["template_domain_names"])

4d50_A 4d50_A


In [30]:
# Reduce the last axis of the CRF template's aatype array to the index of its
# largest entry (presumably a one-hot residue encoding — TODO confirm).
aa1 = np.argmax(c1['template_aatype'], -1)
print(aa1)

[21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 ... 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21]


In [31]:
# Same reduction for the HHR-derived template: index of the largest entry along
# the last axis (presumably a one-hot residue encoding — TODO confirm).
aa2 = np.argmax(h1['template_aatype'], -1)
print(aa2)

[21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 ... 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21]


In [32]:
# Print the template sequence of the selected CRF template and of the selected
# HHR-derived template, one per line, for visual comparison.
print(c1["template_sequence"])
print(h1["template_sequence"])

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------G-----------------------------------------------------------------PLG--------------SMVTEQEVDAIGQTLVDPK-----QPLQARFRALFTLRGLGGPGAIAWISQAFDD------------DSALLKHELAYCLGQMQDARAIPMLVDVLQDTRQEPMVRHEAGEALGAI---GDPEVLEILKQYSSDP----VIEVAETCQLAVRRLEWLQQHGGEPAAGPYLS-------VDPAPPAEERDVGRLREALLD---ESRPLERYRAMFALRNAGGEEAALALAEGLHC---GSALFRHEVGYVLGQL---QHEAAVPQLAAALARCTENPMVRHECAEALGAIARP-ACLA-ALQAHADDPERVVRESCEVALDMYEHETGRAFQ--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [36]:
# Take index 1 along the second axis of each position array, keeping the first
# and last axes intact.
# NOTE(review): presumably this picks one atom type per residue (CA in the
# atom37 ordering?) and yields one xyz row per residue — confirm.
p1 = c1["template_all_atom_positions"][:, 1, :]
p2 = h1["template_all_atom_positions"][:, 1, :]

In [44]:
# Join the two coordinate arrays along the last axis so each row shows the CRF
# values followed by the HHR-derived values, and display rows 500-599.
np.concatenate([p1, p2], -1)[500:600]

array([[  4.412, -11.419,  18.193,   4.412, -11.419,  18.193],
       [  7.203,  -9.066,  17.284,   7.203,  -9.066,  17.284],
       [  7.654,  -7.747,  20.799,   7.654,  -7.747,  20.799],
       [  7.566, -11.328,  22.169,   7.566, -11.328,  22.169],
       [ 10.841, -11.959,  20.295,  10.841, -11.959,  20.295],
       [  9.996, -15.654,  19.999,   9.996, -15.654,  19.999],
       [ 11.298, -17.257,  16.804,  11.298, -17.257,  16.804],
       [  8.343, -19.692,  17.006,   8.343, -19.692,  17.006],
       [  6.17 , -16.883,  15.579,   6.17 , -16.883,  15.579],
       [  8.201, -16.585,  12.443,   8.201, -16.585,  12.443],
       [  6.194, -18.878,  10.14 ,   6.194, -18.878,  10.14 ],
       [  2.859, -17.179,  10.952,   2.859, -17.179,  10.952],
       [  4.391, -13.707,  10.5  ,   4.391, -13.707,  10.5  ],
       [  5.793, -14.757,   7.113,   5.793, -14.757,   7.113],
       [  2.35 , -16.024,   6.048,   2.35 , -16.024,   6.048],
       [  0.849, -12.62 ,   7.04 ,   0.849, -12.62 ,   

# Test

In [7]:
# NOTE: pickle.load can execute arbitrary code while deserializing; only open
# feature files that come from a trusted source.
with gzip.open("out/features_afm.pkz", "rb") as fp:
    feats = pickle.load(fp)

# Summarise every feature: dtype and shape for NumPy arrays, key only otherwise.
for name, value in feats.items():
    if not isinstance(value, np.ndarray):
        print(name)
        continue
    print(name, str(value.dtype), value.shape)

aatype int32 (369,)
residue_index int32 (369,)
seq_length int32 ()
msa int32 (4093, 369)
num_alignments int32 ()
template_aatype int64 (4, 369)
template_all_atom_mask float32 (4, 369, 37)
template_all_atom_positions float32 (4, 369, 37, 3)
asym_id int64 (369,)
sym_id int64 (369,)
entity_id int64 (369,)
deletion_matrix float32 (4093, 369)
deletion_mean float32 (369,)
all_atom_mask float32 (369, 37)
all_atom_positions float64 (369, 37, 3)
assembly_num_chains int64 ()
entity_mask int32 (369,)
num_templates int32 ()
cluster_bias_mask float64 (4093,)
bert_mask float32 (4093, 369)
seq_mask float32 (369,)
msa_mask float32 (4093, 369)


In [4]:
import random
from collections import defaultdict

def sample_balls(buckets):
    """Flatten all buckets into one list of bucket-tagged balls.

    Despite the name, nothing is drawn at random: every ball of every bucket is
    returned, in the iteration order of ``buckets`` and of each bucket's balls.

    Args:
        buckets: Mapping from bucket name to an iterable of ``(ball, color)`` pairs.

    Returns:
        A list of ``(ball, color, bucket_name)`` tuples.
    """
    return [
        (ball, color, bucket_name)
        for bucket_name, balls in buckets.items()
        for ball, color in balls
    ]

def group_balls_by_color(sampled_balls):
    """Index bucket-tagged balls by their color.

    Args:
        sampled_balls: Iterable of ``(ball, color, bucket)`` tuples.

    Returns:
        A ``defaultdict(list)`` mapping each color to the ``(ball, bucket)``
        pairs of that color, in input order; colors appear in order of first
        occurrence.
    """
    by_color = defaultdict(list)
    for entry in sampled_balls:
        ball, color, bucket = entry
        by_color[color].append((ball, bucket))
    return by_color

def form_groups(balls_by_color, L):
    """Greedily distribute balls over ``L`` groups, one ball per bucket per group.

    Colors are processed in the iteration order of ``balls_by_color``. Each ball
    is appended to the first (lowest-index) group that does not yet contain a
    ball from the same bucket. A ball whose bucket is already represented in
    every group is dropped.

    Args:
        balls_by_color: Mapping from color to a list of ``(ball, bucket)`` pairs.
        L: Number of groups to create.

    Returns:
        A list of ``L`` lists of ``(ball, color)`` tuples; groups that received
        no ball are returned as empty lists.
    """
    groups = [[] for _ in range(L)]
    # buckets_in_group[i] holds the buckets already represented in groups[i].
    buckets_in_group = [set() for _ in range(L)]

    for color, balls in balls_by_color.items():
        for ball, bucket in balls:
            for group, used_buckets in zip(groups, buckets_in_group):
                if bucket not in used_buckets:
                    group.append((ball, color))
                    used_buckets.add(bucket)
                    break
            else:
                # Every group already holds a ball from this bucket: skip it.
                pass

    return groups

# Example usage: three buckets, each a list of (ball, color) pairs.
buckets = {
    'bucket1': [('ball1', 'red'), ('ball2', 'blue'), ('ball3', 'green')],
    'bucket2': [('ball4', 'yellow'), ('ball5', 'red')],
    'bucket3': [('ball6', 'blue'), ('ball7', 'green'), ('ball8', 'yellow'), ('ball9', 'red')]
}

# Number of groups to form.
L = 4
# Flatten the buckets into (ball, color, bucket) tuples.
sampled_balls = sample_balls(buckets)
print("Sampled balls:", sampled_balls)

# Index the balls by color, then distribute them over the L groups.
balls_by_color = group_balls_by_color(sampled_balls)
groups = form_groups(balls_by_color, L)

# Report the groups with 1-based numbering.
print("Groups:")
for i, group in enumerate(groups):
    print(f"Group {i + 1}: {group}")


Sampled balls: [('ball1', 'red', 'bucket1'), ('ball2', 'blue', 'bucket1'), ('ball3', 'green', 'bucket1'), ('ball4', 'yellow', 'bucket2'), ('ball5', 'red', 'bucket2'), ('ball6', 'blue', 'bucket3'), ('ball7', 'green', 'bucket3'), ('ball8', 'yellow', 'bucket3'), ('ball9', 'red', 'bucket3')]
Groups:
Group 1: [('ball1', 'red'), ('ball5', 'red'), ('ball9', 'red')]
Group 2: [('ball2', 'blue'), ('ball6', 'blue'), ('ball4', 'yellow')]
Group 3: [('ball3', 'green'), ('ball7', 'green')]
Group 4: [('ball8', 'yellow')]
