In [1]:
import torch
from pathlib import Path
import torch.nn as nn
from src.data import AllVertices, ProteinRecord
from src.nn_model import AmberNN
from run_opts import config_runtime

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Settings pulled from the shared runtime configuration (run_opts.config_runtime).
n_cls = 3  # number of output classes
(
    conf_dev,
    train_f,
    learning_rate,
    batch_size,
    hid_size,
    log_step,
    epochs,
    run_name,
) = (
    config_runtime['device'],
    config_runtime['train_frac'],
    config_runtime['learning_rate'],
    config_runtime['batch_size'],
    config_runtime['hidden_size'],
    config_runtime['log_step'],
    config_runtime['num_epochs'],
    config_runtime['run_name'],
)

In [3]:
# Torch device named by the runtime config; X, y, p and the model are moved onto it below.
device = torch.device(conf_dev)

In [4]:
def save_for_paraview(filename, xyz_tensor, color_tensor, z_shift = 0):
    """Write point coordinates plus one scalar colour column to a CSV for ParaView.

    The file starts with the header ``x,y,z,c`` and has one row per point.

    Args:
        filename: Destination path (str or path-like); overwritten if it exists.
        xyz_tensor: Tensor of shape (N, 3) holding point coordinates.
        color_tensor: Tensor with N values, shape (N,) or (N, 1), written as the
            ``c`` column (e.g. predicted or true class labels).
        z_shift: Offset added to every z coordinate, e.g. to place two exported
            point clouds next to each other.

    Returns:
        0, kept for compatibility with existing callers.

    Raises:
        ValueError: if ``xyz_tensor`` is not two-dimensional with 3 columns.
    """
    if xyz_tensor.dim() != 2 or xyz_tensor.shape[1] != 3:
        raise ValueError(
            f"xyz_tensor must have shape (N, 3), got {tuple(xyz_tensor.shape)}"
        )

    # Detach and move each tensor to the CPU *before* concatenating, so inputs that
    # live on different devices are accepted.
    xyz = xyz_tensor.detach().cpu()
    # reshape(-1, 1) accepts both (N,) and (N, 1) colour tensors.
    colors = color_tensor.detach().cpu().reshape(-1, 1)
    # torch.cat type-promotes mixed dtypes, so integer labels next to float
    # coordinates are written as floats (e.g. 2.0).
    rows = torch.cat([xyz, colors], dim=1).numpy()

    with open(filename, 'w') as out_file:
        out_file.write("x,y,z,c\n")
        for x, y, z, c in rows:
            out_file.write(f'{x},{y},{z + z_shift},{c}\n')

    return 0

# Read input structures and vertex data

In [5]:
# Structure IDs to evaluate, one per line (entries such as '1FZW_B' or '6BOY_BC').
# Blank lines are skipped so they never reach AllVertices as empty IDs.
with open("../data/lists/test.txt", 'r') as list_file:
    trial_pdbs = [line.strip() for line in list_file if line.strip()]

In [6]:
# Display the structure IDs read from the list file.
trial_pdbs

['1FZW_B', '6BOY_BC']

In [7]:
# Load data for the listed structures via AllVertices. In the saved output of
# pd.info(), the entry '6BOY_BC' shows up as separate chains '6BOY_B' and '6BOY_C'.
# NOTE(review): `pd` shadows the conventional pandas alias; consider renaming it
# (e.g. `vertex_data`) together with its uses in the later cells.
pd = AllVertices(trial_pdbs)

# Select a chain, load the trained model, and predict

In [8]:
# Per-chain summary: each chain ID maps to a pair of index bounds
# (presumably a [start, end) row range into the combined data — verify in src.data).
pd.info()

{'1FZW_B': [0, 13034], '6BOY_B': [13034, 31124], '6BOY_C': [31124, 37675]}

In [9]:
# Chain to predict and export; presumably must be one of the keys shown by pd.info().
pidcid = '6BOY_C'

In [13]:
# X.shape

In [14]:
# Fetch model inputs (X), ground-truth labels (y) and point coordinates (p) for the
# selected chain, then move all three onto the inference device.
X, y, p = pd.get_protein(pidcid)
X, y, p = (tensor.to(device) for tensor in (X, y, p))

In [15]:
# Rebuild the network with the architecture the checkpoint was trained with
# (111 inputs -> 256 hidden units -> n_cls logits) and load its weights.
n_features = 111
hidden_units = 256  # fixed by the checkpoint; not taken from the config on purpose
checkpoint_path = "dnn_2023-03-02_17-29.pytorch"
assert X.shape[1] == n_features, f"expected {n_features} input features, got {X.shape[1]}"

model = AmberNN(n_features, n_cls, hidden_units).to(device)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints
# (weights_only=True is available on torch >= 1.13).
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# eval() switches the dropout layer off for inference.
model.eval()

AmberNN(
  (fc1): Linear(in_features=111, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=3, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [16]:
# Inference only: no autograd graph is needed.
with torch.no_grad():
    logits = model(X)
    # Class probabilities per row; the predicted class is the most probable one.
    pred_probab = torch.softmax(logits, dim=1)
    y_pred = torch.argmax(pred_probab, dim=1)

In [18]:
# Export predicted and ground-truth labels on the same coordinates so the two CSVs
# can be compared side by side in ParaView. The loop avoids duplicated calls and
# keeps the helper's meaningless 0 return value out of the cell output.
for label_kind, labels in (("predicted", y_pred), ("real", y)):
    save_for_paraview(f"{label_kind}_{pidcid}.csv", p, labels)

0