In [1]:
import torch
from torch import nn
import numpy as np
import re

In [3]:
of_basic_features = torch.load('current_implementation/test_outputs/basic_features.pt', map_location='cpu')
of_extra_msa_feat = torch.load('current_implementation/test_outputs/extra_msa_feat.pt', map_location='cpu')

### target_feat
This is a feature of size [$N_{res}$ , 21] consisting of the “aatype” feature.
- One-hot representation of the input amino acid sequence (20 amino
acids + unknown).
This feature is padded in openfold in data_transforms: make_msa_feat(), so it has shape [$N_{res}$, 22]
For now, we simulate this with an additional letter Z in the beginning.

In [4]:
sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASK'


def calculate_target_feat(sequence):
    aa_codes = "ZARNDCQEGHILKMFPSTWYVX"
    sequence_inds = torch.tensor([aa_codes.index(a) for a in sequence])
    encoding = nn.functional.one_hot(sequence_inds, num_classes=22)
    return encoding




### residue_index
This is a feature of size [$N_{res}$] consisting of the "residue_index" feature
- index into the original amino acid sequence, in our case just [0, ..., $N_{res}$]

In [5]:
def calculate_residue_index(sequence):
    return torch.arange(len(sequence))


### msa_feat
This is a feature of size [$N_{clust}$, $N_{res}$, 49] constructed by concatenating "cluster_msa", "cluster_has_deletion", "cluster_deletion_value", "cluster_deletion_mean", "cluster_profile".
- This feature seems to be sampled randomly during recycling, this isn't implemented yet
- "cluster_msa" of shape [$N_{clust}$, $N_{res}$, 23] is a one-hot representation of the cluster centers.
- "cluster_has_deletion" of shape [$N_{clust}$, $N_{res}$, 1] is a binary feature indicating if there is a deletion to the left of the residue in the MSA cluster centres
- "cluster_deletion_value"...

In [7]:
from current_implementation.structure_module import residue_constants

def load_a3m_file(file_name):
    with open(file_name, 'r') as f:
        lines = f.readlines()
    description_line_indices = [i for i,l in enumerate(lines) if l.startswith('>')]
    descriptions = [lines[i].strip() for i in description_line_indices]
    seqs = [lines[i+1].strip() for i in description_line_indices]
    unique_seqs = []

    deletion_count_matrix = []
    for seq in seqs:
        deletion_count_list = []
        deletion_counter = 0
        for letter in seq:
            if letter.islower():
                deletion_counter += 1
            else:
                deletion_count_list.append(deletion_counter)
                deletion_counter = 0
        seq_without_deletion = re.sub('[a-z]', '', seq)
        if seq_without_deletion in unique_seqs:
            continue
        
        unique_seqs.append(seq_without_deletion)
        deletion_count_matrix.append(deletion_count_list)

    aa_codes = "ARNDCQEGHILKMFPSTWYVX-"
    unique_seqs = torch.stack([
            torch.tensor([aa_codes.index(a) for a in seq]) 
        for seq in unique_seqs])
    unique_seqs_one_hot = nn.functional.one_hot(unique_seqs, num_classes=22)
    amino_acid_distribution = unique_seqs_one_hot.float().mean(dim=0)
    deletion_count_matrix = torch.tensor(deletion_count_matrix)
    
    return { 'msa_aatype': unique_seqs, 'msa_deletion_count': deletion_count_matrix, 'amino_acid_distribution': amino_acid_distribution}

# features = load_a3m_file('kilian/alignments_hhr/test_tautomerase/test_tautomerase.a3m')

        


In [8]:
def split_clusters(features, max_msa_clusters=512, basic_seed=0):
    sequences = features['msa_aatype']
    num_seqs = sequences.shape[0]
    max_msa_clusters = min(max_msa_clusters, num_seqs)
    random_generator = torch.Generator(device=sequences.device)
    random_generator.manual_seed(basic_seed)
    shuffled = torch.randperm(num_seqs-1, generator=random_generator)+1
    shuffled = torch.cat((torch.tensor([0]), shuffled), dim=0)

    resulting_features = {key: value.clone() for key, value in features.items()}
    MSA_FEATURE_NAMES = ['msa_aatype', 'msa_deletion_count']
    for key in MSA_FEATURE_NAMES:
        if key not in features:
            continue
        feature = features[key]
        resulting_features = {
            **resulting_features, 
            key: feature[shuffled[:max_msa_clusters]],
            f'extra_{key}': feature[shuffled[max_msa_clusters:]] 
        }
    return resulting_features


In [9]:
import torch.distributions as distributions
def mask_clusters(features, mask_probability=0.15, basic_seed=0):
    features = {key: value.clone() for key, value in features.items()}
    N_clust, N_res = features['msa_aatype'].shape
    N_aa_categories = 23 # 20 Amino Acids, Unknown Amino Acid, Gap, masked_msa_token
    odds = {
        'uniform_replacement': 0.1,
        'replacement_from_distribution': 0.1,
        'no_replacement': 0.1,
        'masked_out': 0.7
    }
    uniform_category = torch.tensor([1/20] * 20 + [0, 0]) * odds['uniform_replacement']
    replacement_from_distribution = features['amino_acid_distribution'] * odds['replacement_from_distribution']
    no_replacement = nn.functional.one_hot(features['msa_aatype'], num_classes=22) * odds['no_replacement']
    masked_out = torch.tensor([odds['masked_out']]).expand((N_clust, N_res, 1))

    transition_categories_without_mask = uniform_category + replacement_from_distribution + no_replacement
    transition_categories = torch.cat((transition_categories_without_mask, masked_out), dim=-1)
    unk_in_msa = features['msa_aatype'] == 20


    gen = torch.Generator(device=features['msa_aatype'].device)
    gen.manual_seed(basic_seed+1)
    mask = torch.rand(features['msa_aatype'].shape, generator=gen) < mask_probability

    torch.manual_seed(basic_seed+1)
    replacement = distributions.Categorical(probs=transition_categories.reshape(-1, N_aa_categories)).sample()
    replacement = replacement.reshape(N_clust, N_res)
    features['true_msa_aatype'] = features['msa_aatype'].clone()
    features['msa_aatype'][mask] = replacement[mask]

    return features



    

In [10]:
def cluster_assignment(features):
    N_clust, N_res = features['msa_aatype'].shape
    N_extra, _ = features['extra_msa_aatype'].shape

    msa_broadcast = features['msa_aatype'].reshape((N_clust, N_res, 1))
    extra_broadcast = features['extra_msa_aatype'].T.reshape((1, N_res, N_extra))
    mask = torch.logical_and(msa_broadcast != 22, msa_broadcast != 21)
    agreement = torch.logical_and(msa_broadcast == extra_broadcast, mask).sum(dim=1)
    assignment = torch.argmax(agreement, dim=0)
    features['extra_cluster_assignment'] = assignment
           
    return features


In [45]:
def summarize_clusters(features):
    N_clust, N_res = features['msa_aatype'].shape
    N_extra, _ = features['extra_msa_aatype'].shape
    assignment = features['extra_cluster_assignment']
    assignment_counts = (torch.arange(N_clust).reshape(N_clust, 1) == assignment.reshape(1, N_extra)).sum(dim=1) + 1

    def cluster_average(feature, extra_feature):
        unsqueezed_shape = [-1] + [1] * (extra_feature.ndim-1)
        broadcast_assignment = assignment.reshape(unsqueezed_shape).expand(extra_feature.shape)
        result = torch.scatter_add(feature, 0, broadcast_assignment, extra_feature)
        return result / assignment_counts.reshape(unsqueezed_shape)

    cluster_deletion_mean = cluster_average(features['msa_deletion_count'], features['extra_msa_deletion_count'])
    cluster_deletion_mean = 2/torch.pi * torch.arctan(cluster_deletion_mean / 3)

    msa_one_hot = nn.functional.one_hot(features['msa_aatype'], num_classes=23)
    extra_msa_one_hot = nn.functional.one_hot(features['extra_msa_aatype'], num_classes=23)
    cluster_profile = cluster_average(msa_one_hot, extra_msa_one_hot)

    features['cluster_deletion_mean'] = cluster_deletion_mean
    features['cluster_profile'] = cluster_profile
    return features




    

In [39]:
def crop_extra_msa(features, max_extra_msa_count=5120, basic_seed=0):
    N_extra = features['extra_msa_aatype'].shape[0]
    gen = torch.Generator(features['extra_msa_aatype'].device)
    gen.manual_seed(basic_seed+2)
    inds_to_select = torch.randperm(N_extra, generator=gen)[:max_extra_msa_count]
    for k,v in features.items():
        if 'extra_' in k:
            features[k] = features[k][inds_to_select]
    return features


In [40]:
def calculate_msa_feat(features):
    N_clust, N_res = features['msa_aatype'].shape

    cluster_msa = nn.functional.one_hot(features['msa_aatype'], num_classes=23)

    cluster_has_deletion = (features['msa_deletion_count'] > 0).float()
    cluster_has_deletion = cluster_has_deletion.reshape(N_clust, N_res, 1)

    cluster_deletion_value = 2/torch.pi * torch.arctan(features['msa_deletion_count'] / 3)
    cluster_deletion_value = cluster_deletion_value.reshape(N_clust, N_res, 1)

    cluster_deletion_mean = features['cluster_deletion_mean']
    cluster_deletion_mean = cluster_deletion_mean.reshape(N_clust, N_res, 1)
    cluster_profile = features['cluster_profile']

    msa_feat = torch.cat((cluster_msa, cluster_has_deletion, cluster_deletion_value, cluster_profile, cluster_deletion_mean), dim=-1)
    return msa_feat


    

In [41]:
def calculate_extra_msa_feat(features):
    N_extra, N_res = features['extra_msa_aatype'].shape

    extra_msa = nn.functional.one_hot(features['extra_msa_aatype'], num_classes=23)
    extra_msa_has_deletion = (features['extra_msa_deletion_count'] > 0).float()
    extra_msa_has_deletion = extra_msa_has_deletion.reshape(N_extra, N_res, 1)
    extra_msa_deletion_value = 2/torch.pi * torch.arctan(features['extra_msa_deletion_count']/3)
    extra_msa_deletion_value = extra_msa_deletion_value.reshape(N_extra, N_res, 1)

    return torch.cat((extra_msa, extra_msa_has_deletion, extra_msa_deletion_value), dim=-1)


In [59]:
seq_data = load_a3m_file('current_implementation/alignments_hhr/test_tautomerase/test_tautomerase.a3m')
msa_feats = []
extra_msa_feats = []
for i in range(4):
    print(f'Iteration {i}')
    features = split_clusters(seq_data, basic_seed=i)
    other_features = torch.load('solutions/feature_extraction/control_values/clusters_selected.pt')
    for key, param in features.items():
        if key != 'amino_acid_distribution':
            if 'aatype' in key:
                other_features[key] = other_features[key].argmax(dim=-1)
            assert torch.allclose(param.float(), other_features[key].float()), f'Error with {key} after splitting.'
    features = mask_clusters(features, basic_seed=i)
    other_features = torch.load('solutions/feature_extraction/control_values/clusters_masked.pt')
    for key, param in features.items():
        if key != 'amino_acid_distribution':
            if 'aatype' in key:
                other_features[key] = other_features[key].argmax(dim=-1)
            print(key)
            assert torch.allclose(param.float(), other_features[key].float()), f'Error with {key} after masking.'
    features = cluster_assignment(features)
    other_features = torch.load('solutions/feature_extraction/control_values/clusters_assigned.pt')
    for key, param in features.items():
        if key != 'amino_acid_distribution':
            if 'aatype' in key:
                other_features[key] = other_features[key].argmax(dim=-1)

            if key == 'extra_cluster_assignment':
                key = 'cluster_assignment'
            assert torch.allclose(param.float(), other_features[key].float()), f'Error with {key} after assigning.'
    features = summarize_clusters(features)
    other_features = torch.load('solutions/feature_extraction/control_values/clusters_summarized.pt')
    for key, param in features.items():
        if key != 'amino_acid_distribution':
            if 'aatype' in key:
                other_features[key] = other_features[key].argmax(dim=-1)
            if key == 'extra_cluster_assignment':
                key = 'cluster_assignment'
            assert torch.allclose(param.float(), other_features[key].float()), f'Error with {key} after summarizing.'
    features = crop_extra_msa(features, basic_seed=i)
    other_features = torch.load('solutions/feature_extraction/control_values/extra_msa_cropped.pt')
    for key, param in features.items():
        if key != 'amino_acid_distribution':
            if 'aatype' in key:
                other_features[key] = other_features[key].argmax(dim=-1)

            if key == 'extra_cluster_assignment':
                key = 'cluster_assignment'
            assert torch.allclose(param.float(), other_features[key].float()), f'Error with {key} after cropping.'
    msa_feat = calculate_msa_feat(features)
    extra_msa_feat = calculate_extra_msa_feat(features)

    msa_feats.append(msa_feat)
    extra_msa_feats.append(extra_msa_feat)

target_feat = calculate_target_feat(sequence)
residue_index = calculate_residue_index(sequence)

msa_feat = torch.stack(msa_feats, dim=-1)
extra_msa_feat = torch.stack(extra_msa_feats, dim=-1)
target_feat = target_feat.view(target_feat.shape+(1,)).broadcast_to(target_feat.shape+(4,))
residue_index = residue_index.view(residue_index.shape+(1,)).broadcast_to(residue_index.shape+(4,))



Iteration 0
msa_aatype
msa_deletion_count
extra_msa_aatype
extra_msa_deletion_count
true_msa_aatype


AssertionError: Error with extra_msa_aatype after cropping.

In [60]:
print('msa_feat:')
print((msa_feat - of_basic_features['msa_feat']).abs().mean())
print('target_feat:')
print((target_feat - of_basic_features['target_feat']).abs().mean())
print('residue index:')
print((residue_index - of_basic_features['residue_index']).float().abs().mean())
print('extra_msa_feat:')
print((extra_msa_feat[...,0] - of_extra_msa_feat).abs().mean())

my_batch = {
    'msa_feat': msa_feat.float(),
    'extra_msa_feat': extra_msa_feat.float(),
    'target_feat': target_feat.float(),
    'residue_index': residue_index.float()
}

# torch.save(my_batch, 'kilian/my_batch.pt')

msa_feat:


RuntimeError: The size of tensor a (49) must match the size of tensor b (4) at non-singleton dimension 3