In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
from scripts.dataset import Dataset
from scripts.rgyr_model import Model as RGYR
from scripts.tdiff_model import Model as TDIFF
from scripts.rdiff_model import Model as RDIFF
from scripts.gendata import Protein

import warnings
warnings.filterwarnings('ignore')

In [None]:
# PDB files to analyze; fill in one or more paths before running the notebook.
files = []

if not files:
    # Raise instead of calling exit(): exit() targets the interpreter/kernel
    # rather than reporting an error, and it left `onefile` undefined for the
    # cells below. An exception stops "Run All" with a readable message.
    raise ValueError('Please select one or more PDB files.')

# True when exactly one file was supplied by the user.
onefile = len(files) == 1
if onefile:
    files.append(files[0]) # `Dataset` class has a hard time recognizing only one file

In [None]:
# Initialize variables
temp = 'temp.csv'                        # scratch CSV later read by `Dataset`
stdout0 = sys.stdout                     # original stdout, restored after the redirect
toppath = '/home/spencer/ml/hpro/models' # directory holding the trained model checkpoints

# Analyze PDBs
# stdout is pointed at the CSV file while each protein is processed, so
# whatever is printed during the loop is written to `temp`.
# NOTE(review): presumably `all_ca_cofm_dist` prints the feature row -- confirm.
with open(temp, 'w') as f:
    sys.stdout = f
    try:
        for pdb in files:
            protein = Protein(pdb)
            print(0, end=',') # 0 is a place holder for 'label' in `Dataset` class
            protein.all_ca_cofm_dist(1000)
    finally:
        # Always restore stdout, even if a PDB fails to load; otherwise every
        # later print in the session would target a closed file.
        sys.stdout = stdout0

In [None]:
def predict(p, Model, device=None):
    """Load a trained model and return its prediction for each row of `temp`.

    Parameters
    ----------
    p : str
        Path to the saved state dict; passed to ``torch.load``.
    Model : callable
        Called with no arguments to build the network whose weights are
        restored from ``p``.
    device : str or torch.device, optional
        Device used for the weights and the inputs. When omitted, ``'cuda'``
        is used if available and ``'cpu'`` otherwise, so the notebook also
        runs on machines without a GPU.

    Returns
    -------
    list of float
        One scalar per item yielded by the data loader, in loader order.

    Notes
    -----
    Reads the module-level ``temp`` path and wraps it in ``Dataset``. Each
    model output must hold exactly one element, because ``.item()`` is used
    to extract it.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model; map_location lets GPU-saved checkpoints load on a CPU-only host.
    m = Model()
    m.load_state_dict(torch.load(p, map_location=device))
    m.eval()
    m.to(device)

    # Predict properties
    with torch.no_grad():

        item = Dataset(temp)
        test_loader = torch.utils.data.DataLoader(item)

        predictions = []

        # The target column is only a placeholder in the CSV and is ignored.
        for features, _target in test_loader:
            features = features.to(device)

            prediction = m(features)

            predictions.append(prediction.item())

    return predictions

In [None]:
# Checkpoint locations for the three property models.
rgyr_weights = f'{toppath}/rgyr_other.pt'
tdiff_weights = f'{toppath}/tdiff.pt'
rdiff_weights = f'{toppath}/rdiff.pt'

# Run each model over the same temporary CSV; each call returns one value per row.
rgyrs = predict(rgyr_weights, RGYR)
tdiffs = predict(tdiff_weights, TDIFF)
rdiffs = predict(rdiff_weights, RDIFF)

In [None]:
# Drop the duplicate entry that was appended for the single-file case. The
# length check keeps this cell safe to re-run once the duplicate is gone.
if onefile and len(files) > 1:
    del files[1]

# Report the predicted properties per input file; the scale factors convert
# the raw model outputs to the units shown in each label.
for i, pdb in enumerate(files):
    print(f"\nFILE: '{pdb}'\n")
    print('                  Radius of Gyration:', rgyrs[i] / 1e8, 'cm')
    print(' Translational Diffusion Coefficient:', tdiffs[i] * 0.01, 'nm^2/ns')
    print('    Rotational Diffusion Coefficient:', rdiffs[i] * 0.001, 'ns^-1')

# Remove the scratch CSV; skip quietly if an earlier run already deleted it.
if os.path.exists(temp):
    os.remove(temp)
print()