In [None]:
import numpy as np
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch.nn import ModuleList, Embedding
from torch.nn import Sequential, ReLU, Linear
from torch_geometric.data import DataLoader
from torch_geometric.nn import PNAConv, BatchNorm, global_add_pool

import time
import matplotlib as mpl
import matplotlib.pyplot as plt

In [None]:
# ---- Run configuration ----
# Edges are built between CA atoms closer than this threshold
# (presumably in Angstrom — TODO confirm against the graph-building code).
dist_th = 12.0
# Which mode to predict: 0 is the first mode, 63 the last.
idx_mode = 0

# Evaluation / checkpoint settings.
batch_size = 32
total_epochs = 100  # epoch count that appears in the checkpoint file name
num_layers = 4      # number of PNAConv layers built by Net

# Seed both RNGs used in this notebook.
np.random.seed(0)
torch.manual_seed(0)

# Task identifiers used to locate the checkpoint, e.g. 'all_th12_mode0'.
task_name_temp = f'all_th{int(dist_th)}'
task_name = f'{task_name_temp}_mode{idx_mode}'

In [None]:
# Load the trained checkpoint for this task / depth / epoch count.
# map_location='cpu' lets a checkpoint saved on a GPU machine be opened on a
# CPU-only one; the model and each batch are moved to `device` later.
# NOTE(review): the checkpoint pickles a whole dataset object, not just
# tensors. torch.load unpickles arbitrary objects, so only load trusted files.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True and will refuse
# this file; pass weights_only=False explicitly there.
checkpoint_path = f'data/{task_name}_Nlayer{num_layers}_checkpoint{total_epochs}.pt'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
test_dataset = checkpoint['test_dataset']
test_loader_sub = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
deg = checkpoint['deg']

In [None]:
class Net(torch.nn.Module):
    """PNA graph network that regresses one scalar per graph.

    Layout (attribute names must match the checkpoint's ``model_state_dict``):
    ``node_emb`` embeds integer node indices (20 categories) into 75 channels;
    ``num_layers`` PNAConv -> BatchNorm -> ReLU blocks update node features;
    node features are summed per graph with ``global_add_pool``; ``mlp`` maps
    the 75-channel graph vector to a single output.

    Reads the module-level ``num_layers`` and ``deg`` (degree histogram loaded
    from the checkpoint) at construction time.
    """

    def __init__(self):
        super().__init__()

        self.node_emb = Embedding(20, 75)

        aggregators = ['mean', 'std']
        scalers = ['identity', 'amplification', 'attenuation']

        self.convs = ModuleList()
        self.batch_norms = ModuleList()
        for _ in range(num_layers):
            conv = PNAConv(in_channels=75, out_channels=75,
                           aggregators=aggregators, scalers=scalers, deg=deg,
                           edge_dim=1, towers=5, pre_layers=1, post_layers=1,
                           divide_input=False)
            self.convs.append(conv)
            self.batch_norms.append(BatchNorm(75))

        self.mlp = Sequential(Linear(75, 50), ReLU(), Linear(50, 25), ReLU(),
                              Linear(25, 1))

    def forward(self, x, edge_index, edge_attr, batch):
        """Return a (num_graphs, 1) tensor of predictions.

        Args:
            x: integer node indices, assumed shape (N,) or (N, 1) — TODO confirm.
            edge_index: graph connectivity passed straight to each PNAConv.
            edge_attr: one scalar per edge, shape (E,) or (E, 1).
            batch: graph id of each node, as consumed by ``global_add_pool``.
        """
        # Drop only a trailing size-1 feature axis ((N, 1) -> (N,)). A bare
        # squeeze() would also drop the node axis when the batch holds a
        # single node, producing a 0-d index and a mis-shaped embedding.
        if x.dim() > 1:
            x = x.squeeze(-1)
        x = self.node_emb(x)

        # edge_dim=1: PNAConv expects (E, 1) edge features. reshape(-1, 1)
        # yields that whether edge_attr arrives as (E,) or already as (E, 1).
        edge_feat = edge_attr.reshape(-1, 1)
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = F.relu(batch_norm(conv(x, edge_index, edge_feat)))

        x = global_add_pool(x, batch)
        return self.mlp(x)

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Rebuild the architecture, move it to `device`, then restore trained weights.
# The load_state_dict result (key-match report) is the cell's displayed output.
model = Net()
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
@torch.no_grad()
def get_test_results(loader, net=None, dev=None):
    """Run a model over ``loader`` and collect flattened predictions and targets.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch``, ``y`` and a ``.to(device)`` method.
        net: model to evaluate; defaults to the module-level ``model``.
        dev: device each batch is moved to; defaults to the module-level
            ``device``.

    Returns:
        (pred, truth): two flat Python lists — the model outputs and the
        ``data.y`` targets — concatenated across batches in loader order.
    """
    # Only evaluated when the caller does not supply a value, so explicit
    # arguments work without the notebook globals existing.
    net = model if net is None else net
    dev = device if dev is None else dev

    net.eval()
    pred, truth = [], []
    for data in loader:
        data = data.to(dev)
        out = net(data.x, data.edge_index, data.edge_attr, data.batch)
        # reshape(-1) instead of squeeze(): a final batch holding one graph
        # gives out of shape (1, 1); squeeze() turns it into a 0-d tensor
        # whose .tolist() is a bare float, and `list += float` raises.
        pred.extend(out.reshape(-1).tolist())
        truth.extend(data.y.reshape(-1).tolist())
    return pred, truth

In [None]:
# Time one full inference pass over the test split.
# perf_counter is monotonic and high-resolution, the intended clock for
# measuring an interval; time.time() follows the wall clock and can jump.
t = time.perf_counter()
pred, truth = get_test_results(test_loader_sub)
print(f'Time: {time.perf_counter()-t:.4f}')

In [None]:
# Upper axis limit (cm^-1) of the parity plot for each of the 64 modes,
# indexed by idx_mode: the first eight modes get individual limits and every
# later mode shares a limit of 40.
freqs_threshold = [8, 8, 10, 10, 12, 14, 14, 16] + [40] * (64 - 8)
freq_threshold = freqs_threshold[idx_mode]

In [None]:
# Parity plot: NMA (reference) frequency on x vs. ML prediction on y.
# The oversized poster styling is applied through rc_context so it stays local
# to this figure instead of permanently rewriting the global rcParams.
poster_style = {
    'axes.linewidth': 10,
    'xtick.major.size': 30,
    'xtick.major.width': 10,
    'ytick.major.size': 30,
    'ytick.major.width': 10,
}
fontsize = 150

with mpl.rc_context(poster_style):
    fig, ax = plt.subplots(figsize=(50, 50))
    # No colour array (`c`) is passed, so a colormap would have no effect;
    # cmap is omitted and points use the default colour.
    ax.scatter(truth, pred, s=100)
    ax.plot([0, freq_threshold], [0, freq_threshold], linewidth=5, color='grey')  # y = x reference line
    ax.set_xlabel('NMA frequency (cm$^{-1}$)', fontsize=fontsize)
    ax.set_ylabel('ML frequency (cm$^{-1}$)', fontsize=fontsize)
    ax.tick_params(axis='both', labelsize=fontsize)
    # Points above freq_threshold on either axis fall outside the view.
    ax.set_xlim([0, freq_threshold])
    ax.set_ylim([0, freq_threshold])
    plt.show()