In [None]:
import numpy as np
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool

import pickle

In [None]:
# Domains to run inference on.
test_domains = [
    '1qlc', '2dfe', '2fbo',
    '2hoa', '4azq', '4id7',
]

# Graph construction / target selection.
dist_th = 12.0   # CA-CA distance cutoff: residue pairs closer than this get an edge
idx_mode = 63    # zero-based mode index (0 = first mode, 63 = 64th mode)

# num_layers and epochs are also part of the checkpoint file name loaded below.
num_layers = 4
epochs = 50
batch_size = 32

# Seed both RNGs so repeated runs of the notebook behave the same.
torch.manual_seed(0)
np.random.seed(0)

# Task identifier: cutoff (truncated to an integer) plus the selected mode index.
task_name_temp = f'all_th{int(dist_th)}'
task_name = f'{task_name_temp}_mode{idx_mode}'

In [None]:
# Deserialize onto the CPU first: without map_location, a checkpoint that was
# saved from a CUDA device cannot be loaded on a CPU-only machine. The weights
# are copied onto the model's own device later by load_state_dict.
checkpoint = torch.load('data/'+task_name+'_Nlayer'+str(num_layers)+
                        '_checkpoint'+str(epochs)+'.pt',
                        map_location='cpu')

In [None]:
def load_vocab(filename):
    """Read a vocabulary file and map each entry to its line index.

    Every line of ``filename`` is stripped of surrounding whitespace and used
    as a key; the value is the zero-based line number. If the same entry
    occurs more than once, the last occurrence wins.

    Args:
        filename: Path of the text file to read, one entry per line.

    Returns:
        dict mapping each stripped line to its zero-based line index.

    Raises:
        OSError: If the file cannot be opened or read. The exception carries
            the offending ``filename`` and is chained to the original error.
    """
    vocab = {}
    try:
        with open(filename) as f:
            for idx, word in enumerate(f):
                vocab[word.strip()] = idx
    except OSError as err:
        # The previous code raised an undefined exception class here, which
        # surfaced as a NameError and hid the real I/O problem. Passing errno
        # lets Python pick the matching OSError subclass (e.g. FileNotFoundError).
        raise OSError(err.errno,
                      f'cannot read vocabulary file ({err.strerror})',
                      filename) from err
    return vocab

# Character -> integer index lookup used by get_node_features to encode sequences.
# NOTE(review): Net embeds node ids with Embedding(20, 75), so this file
# presumably lists exactly 20 symbols -- confirm against the data file.
vocab_chars = load_vocab("food_chars.txt")


def get_node_features(sequence):
    """Encode a sequence string as a 1-D tensor of vocabulary indices.

    Each character of ``sequence`` is looked up in the module-level
    ``vocab_chars`` mapping; an unknown character raises ``KeyError``.

    Args:
        sequence: String whose characters are keys of ``vocab_chars``.

    Returns:
        ``torch.long`` tensor with one index per character, in order.
    """
    indices = [vocab_chars[residue] for residue in sequence]
    return torch.tensor(indices, dtype=torch.long)


def domain2graph(domain):
    """Build the graph tensors for one domain from the files under ``data/``.

    Reads ``data/<domain>_dist.pickle`` (a dict with keys ``'domain'`` and
    ``'dist'``) and ``data/<domain>.fasta``. Each residue of the sequence is a
    node; a directed edge ``i -> j`` (``i != j``) is created whenever
    ``dist[i][j]`` is below the module-level ``dist_th``, in both directions.

    Args:
        domain: Domain identifier used to build both file names.

    Returns:
        Tuple ``(node_features, edge_index, edge_attr)``:
        ``node_features`` is the result of ``get_node_features(sequence)``;
        ``edge_index`` is a ``torch.long`` tensor of shape ``[2, num_edges]``
        (``[2, 0]`` when no pair is within the cutoff); ``edge_attr`` is a
        ``torch.float`` tensor holding the distance of each edge.

    Raises:
        AssertionError: If the pickle's ``'domain'`` entry differs from ``domain``.
        ValueError: If the FASTA file holds no sequence after the header line.
        OSError: If either input file cannot be read.
    """
    # load distance matrix
    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # only point this at trusted, locally generated files.
    with open('data/'+domain+'_dist.pickle', 'rb') as handle:
        protein_dict = pickle.load(handle)

    assert domain == protein_dict['domain']
    dist = protein_dict['dist']

    # load sequence: the first line is the header, everything after it is
    # sequence. Lines are stripped and joined so that a missing final newline,
    # CRLF line endings, wrapped sequences and trailing blank lines all yield
    # the full sequence (slicing off the last character did not).
    with open('data/'+domain+'.fasta', 'r') as fasta:
        next(fasta, None)
        sequence = ''.join(line.strip() for line in fasta)
    if not sequence:
        raise ValueError('no sequence found in data/'+domain+'.fasta')
    node_features = get_node_features(sequence)

    n_res = len(sequence)
    edges = []
    CA_distances = []
    for idx1 in range(n_res):
        row = dist[idx1]
        for idx2 in range(n_res):
            if idx1 == idx2:
                continue  # no self-loops
            distance = row[idx2]
            if distance < dist_th:
                edges.append([idx1, idx2])
                CA_distances.append(distance)

    # reshape(-1, 2) keeps the [num_edges, 2] layout even when `edges` is
    # empty, so the transposed result is always [2, num_edges].
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2)
    edge_attr = torch.tensor(CA_distances, dtype=torch.float)

    return node_features, edge_index.t().contiguous(), edge_attr

In [None]:
# Turn every test domain into a graph object. `y` is a constant -1 that
# nothing in this notebook's visible cells reads back.
test_dataset = []
for domain in test_domains:
    node_features, edge_index, edge_attr = domain2graph(domain)
    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([-1], dtype=torch.float),
    )
    test_dataset.append(data)

In [None]:
# Keep the graphs in their original order (shuffle=False): the predictions are
# later paired with test_domains by position.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class Net(torch.nn.Module):
    """Graph-level regressor: embedding -> stacked PNAConv layers -> sum pool -> MLP.

    Construction reads two module-level names that must be bound before
    ``Net()`` is called: ``num_layers`` (number of PNAConv/BatchNorm pairs) and
    ``deg`` (passed to every PNAConv as its ``deg`` argument).

    Attribute names and creation order are part of the checkpoint format
    (``state_dict`` keys), so they must not be changed.
    """

    def __init__(self):
        super().__init__()

        # 20 possible node ids, each embedded into 75 channels.
        self.node_emb = Embedding(20, 75)

        aggregators = ['mean', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.convs = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(num_layers):
            conv = PNAConv(in_channels=75, out_channels=75,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           edge_dim=1, towers=5, pre_layers=1, post_layers=1,
                           divide_input=False)
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(75))

        # Readout head: 75 -> 50 -> 25 -> 1 value per graph.
        self.mlp = Sequential(Linear(75, 50), ReLU(), Linear(50, 25), ReLU(),
                              Linear(25, 1))

    def forward(self, x, edge_index, edge_attr, batch):
        """Return one prediction per graph in the batch.

        Args:
            x: Integer node ids (indices into the embedding table).
            edge_index: Edge list passed through to each PNAConv.
            edge_attr: One scalar per edge; a trailing channel dimension is
                added here to match ``edge_dim=1``.
            batch: Graph assignment of every node, used for sum pooling.

        Returns:
            Output of the final MLP applied to the sum-pooled node features.
        """
        # Flatten to a 1-D index vector. A bare squeeze() would also drop the
        # node dimension when the input holds a single node, leaving the
        # embedding output without a node axis.
        x = self.node_emb(x.reshape(-1))

        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = F.relu(batch_norm(conv(x, edge_index, edge_attr.unsqueeze(1))))

        x = global_add_pool(x, batch)
        return self.mlp(x)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Net.__init__ reads the module-level name `deg`, so it has to be bound
# before Net() is instantiated on the next line.
deg = checkpoint['deg']
model = Net().to(device)
# Restore the parameters stored under 'model_state_dict' in the checkpoint.
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
@torch.no_grad()
def get_test_results(loader):
    """Run the module-level ``model`` over ``loader`` and collect its outputs.

    The model is switched to eval mode and gradients are disabled for the
    whole call. Each batch is moved to the module-level ``device`` first.

    Args:
        loader: Iterable of batches exposing ``to(device)``, ``x``,
            ``edge_index``, ``edge_attr`` and ``batch``.

    Returns:
        Flat list of Python floats: every element of every batch output, in
        loader order.
    """
    model.eval()
    pred = []
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        # reshape(-1) always yields a 1-D tensor. A bare squeeze() turns a
        # single-graph batch into a 0-d tensor whose tolist() is a float,
        # which cannot be appended with list concatenation.
        pred.extend(out.reshape(-1).tolist())
    return pred

In [None]:
# Model outputs for the test graphs, in the same order as test_loader yields them.
pred = get_test_results(test_loader)

In [None]:
# Report one line per test domain: identifier followed by the model output.
print('Domain    ML freq (cm-1)')
for domain, f_ML in zip(test_domains, pred):
    print(f'{domain} {f_ML}')