In [138]:
from sequence_models.pretrained import load_model_and_alphabet
from sequence_models.pdb_utils import parse_PDB, process_coords
import torch
import numpy as np
import pandas as pd
import os
from Bio import AlignIO

align = AlignIO.read('../alignment/seqs.afa', "fasta")
model, collater = load_model_and_alphabet('checkpoints/mif.pt')

In [139]:
coords, wt, _ = parse_PDB('AlphaFold2_models/GTs_best/AFJ52996.1/AFJ52996.1_ranked_0.pdb')
coords = {
        'N': coords[:, 0],
        'CA': coords[:, 1],
        'C': coords[:, 2]
    }
dist, omega, theta, phi = process_coords(coords)
batch = [[wt, torch.tensor(dist, dtype=torch.float),
          torch.tensor(omega, dtype=torch.float),
          torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]]
src, nodes, edges, connections, edge_mask = collater(batch)
# can use result='repr' or result='logits'. Default is 'repr'.
rep = model(src, nodes, edges, connections, edge_mask) 

In [171]:
raw_rep_dict = {}
mean_rep_dict = {}
rep_mean_pooling_df = pd.DataFrame(columns=np.arange(256))
rep_mean_pooling_Nt_df = pd.DataFrame(columns=np.arange(256))
for path, dirs, files in os.walk('AlphaFold2_models/GTs_best/'):
    if len(dirs) == 0:
        prot_name = os.path.basename(path)
        coords, wt, _ = parse_PDB(path+'/'+prot_name+'_ranked_0.pdb')
        coords = {
        'N': coords[:, 0],
        'CA': coords[:, 1],
        'C': coords[:, 2]
            }
        dist, omega, theta, phi = process_coords(coords)
        batch = [[wt, torch.tensor(dist, dtype=torch.float),
                torch.tensor(omega, dtype=torch.float),
                torch.tensor(theta, dtype=torch.float), torch.tensor(phi, dtype=torch.float)]]
        src, nodes, edges, connections, edge_mask = collater(batch)
        # can use result='repr' or result='logits'. Default is 'repr'.
        rep = model(src, nodes, edges, connections, edge_mask)
        raw_rep_dict[prot_name] = rep[0].detach().numpy().flatten()
        mean_rep_dict[prot_name] = rep[0].detach().numpy().mean(axis=1)
        rep_mean_pooling_df.loc[prot_name] = rep[0].detach().numpy().mean(axis=0)
        rep_mean_pooling_Nt_df.loc[prot_name] = rep[0].detach().numpy()[:int(len(rep[0])/2)].mean(axis=0)

In [173]:
rep_mean_aligned_df = pd.DataFrame(columns=np.arange(1230))
rep_mean_aligned_Nt_df = pd.DataFrame(columns=np.arange(int(1230/2)))
for prot in align:
    seq_ = np.array(prot.seq)
    id_ = prot.id
    if id_.count('/') == 1:
        id_ = id_.replace('/','_')
    array_ = np.zeros(1230)
    array_[np.where(seq_!='-')] = mean_rep_dict[id_]
    rep_mean_aligned_df.loc[id_] = array_
    rep_mean_aligned_Nt_df.loc[id_] = array_[:int(len(array_)/2)]



In [177]:
rep_mean_aligned_df.index.names = ['enzyme']
rep_mean_pooling_df.index.names = ['enzyme']
rep_mean_aligned_Nt_df.index.names = ['enzyme']
rep_mean_pooling_Nt_df.index.names = ['enzyme']

In [178]:
rep_mean_aligned_df.to_csv('../encodings/mif_mean_aligned.tsv',sep='\t')
rep_mean_pooling_df.to_csv('../encodings/mif_mean_pooling.tsv',sep='\t')
rep_mean_aligned_Nt_df.to_csv('../encodings/mif_mean_aligned_Nt.tsv',sep='\t')
rep_mean_pooling_Nt_df.to_csv('../encodings/mif_mean_pooling_Nt.tsv',sep='\t')

In [176]:
rep_mean_pooling_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,246,247,248,249,250,251,252,253,254,255
At_73C3,0.798275,0.172551,-0.343319,0.185870,0.063439,-0.266953,-0.143875,-0.103003,0.325972,0.116491,...,0.592617,0.218935,0.286373,0.209605,-0.043539,-0.130537,-2.614981,0.374716,-0.240361,0.083336
At_73C4,0.779366,0.191016,-0.358959,0.171749,0.092846,-0.276260,-0.130294,-0.084544,0.340099,0.102399,...,0.592424,0.237597,0.258683,0.249600,-0.058800,-0.144022,-2.578619,0.398127,-0.224097,0.063335
Nt_UGT1,0.670855,0.154348,-0.389531,0.133537,0.069738,-0.370396,-0.105383,-0.150647,0.354411,0.096507,...,0.580596,0.269597,0.199412,0.284548,-0.066262,-0.168655,-2.705203,0.449160,-0.190234,0.100679
At_85A5,0.664204,0.194804,-0.369377,0.187246,0.032297,-0.294046,-0.103419,-0.090984,0.393898,0.056157,...,0.608657,0.194131,0.181928,0.289314,-0.057854,-0.162848,-2.503097,0.435016,-0.225531,0.094676
Lj_72L6,0.714345,0.136195,-0.346540,0.170281,0.179434,-0.377024,-0.099853,-0.164740,0.335972,0.081895,...,0.641644,0.301074,0.155014,0.221720,-0.044994,-0.145239,-2.712664,0.459210,-0.172612,0.029947
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
PgCGT1_UGT708A44,0.466756,0.164475,-0.477007,0.136100,0.435534,-0.244583,-0.037698,-0.001062,0.324374,0.323881,...,1.027306,0.116716,-0.144413,0.494078,-0.142917,-0.429550,-2.642405,0.657699,-0.517062,-0.067148
Zm_71B1,0.556898,0.152396,-0.344993,0.165175,0.327043,-0.306747,0.004290,-0.058974,0.336446,0.195468,...,0.873595,0.144034,0.020911,0.321267,-0.121102,-0.299260,-2.559200,0.598338,-0.408216,-0.016231
Lu_71M1,0.714976,0.157229,-0.363578,0.127835,0.237376,-0.352023,-0.070645,-0.091494,0.362785,0.128387,...,0.595850,0.244608,0.134498,0.312922,-0.068452,-0.219887,-2.672823,0.520080,-0.223292,0.028281
At_89C1,0.684921,0.153683,-0.402246,0.152336,0.111235,-0.271743,-0.124045,-0.102007,0.359531,0.136241,...,0.663636,0.159286,0.164335,0.317986,-0.121072,-0.177273,-2.595211,0.512483,-0.231423,0.070819


In [197]:
with open('path.txt','w') as f:
    for path, dirs, files in os.walk('AlphaFold2_models/GTs_best/'):
        prot_name = os.path.basename(path)
        f.write(path+'/'+prot_name+'_ranked_0.pdb')
        f.write('\t')
        f.write('d3urra1')
        f.write('\n')

In [195]:
fps, domids, notes = [], [], []
with open('path.txt') as f:
    for line in f.readlines():
        cols = line.strip().split(None, 2)
        fps.append(cols[0])
        domids.append(cols[1])
        notes.append(cols[2] if len(cols) > 2 else "-")

In [196]:
domids

['At_73C3',
 'At_73C4',
 'Nt_UGT1',
 'At_85A5',
 'Lj_72L6',
 'Sr_88B1',
 'Gm_88E3',
 'UGT71G1',
 'At_89B1',
 'TwUGT2',
 'UGT72B10',
 'At_71C1',
 'Ct_71E5',
 'At_73C5',
 'At_85A4',
 'At_76E11',
 'Bm_88D9',
 'Lc_71A12',
 'UGT86A5',
 'Zm_85854',
 'Zm_15991',
 'PpCGT1_UGT708A45',
 'SiCGT3_UGT708A33',
 'Lc_72B10',
 'IroB',
 'Mt_72L1',
 'BdCGT2_UGT708A8',
 'TaCGT1-A_UGT708A14',
 'QCD86231.1',
 'PhCGT2_UGT708A46',
 'SsfS6',
 'Zm_72068',
 'Lu_71Q2',
 'At_72B3',
 'Fe_72AC1',
 'At_74B1',
 'At_78D1',
 'SiCGT1_UGT708A31',
 'GmCGT_UGT708D1',
 'Lu_72M2',
 'At_72D1',
 'UGT78G1',
 'Fi_88A10',
 'Lj_72V3',
 'ApUFGT3',
 'Nt_71A6',
 'At_72E2',
 'Sr_71E1',
 'Zm_56026',
 'OleD',
 'At_84A1',
 'Gj_71A23',
 'At_76E4',
 'Lu_71P1',
 'UGT84C2',
 'ZmCGT5_UGT708A42',
 'At_74F1',
 'Mt_88E1',
 'UGT74B1',
 'Os_72F1',
 'At_72E3',
 'ZmCGT2_UGT708A5',
 'UGT74B6',
 'UGT85A73',
 'ApUFGT2',
 'Fe_72B19',
 'OsCGT6_UGT708A40',
 'GgCGT_UGT708B4',
 'At_76E2',
 'TaCGT2-A_UGT708A15',
 'MiCGT',
 'UGT85A20',
 'At_72C1',
 'At_76E5',
