In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import os
from cnndockbench.net import ConvProt, FullNN
import torch
from cnndockbench.utils import getTensor
from torch import nn

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [4]:
def splitTrainTests(prots, ligs, corrVal, ratio=0.8):
    idxs = np.arange(prots.shape[0])
    np.random.shuffle(idxs)
    train_samples = int(prots.shape[0] * 0.8)
    test_samples = prots.shape[0] - train_samples
    print("Number of train samples: ", train_samples)
    print('Number of test samples: ', test_samples)
    prots_test_train = [prots[:train_samples], prots[train_samples:]]
    ligs_test_train = [ligs[:train_samples], ligs[train_samples:]]
    corrVal_test_train = [corrVal[:train_samples], corrVal[train_samples:]]
    
    return prots_test_train, ligs_test_train, corrVal_test_train
    
    

In [5]:
# load arrays
prots = np.load(os.path.join('data', 'protFeats.npy'))
ligs = np.load(os.path.join('data', 'ligFeats.npy'))
corrVal = np.load(os.path.join('data', 'correctValues.npy'))

In [6]:
prots.shape, ligs.shape, corrVal.shape

((199, 24, 24, 24, 7), (199, 1024), (199, 3, 17))

In [7]:
protsSplits, ligsSplits, corrValSplits = splitTrainTests(prots, ligs, corrVal)

Number of train samples:  159
Number of test samples:  40


In [16]:
ligsSplits[0].shape

(159, 1024)

## Train

In [8]:
prots_train = protsSplits[0]
ligs_train = ligsSplits[0]
corrVal_train = corrValSplits[0]

In [9]:
modelProt = ConvProt().to(device)

In [10]:
modelComples = FullNN().to(device)

In [27]:
criterion1 = nn.MSELoss()
criterion2 = nn.MSELoss()
criterion3 = nn.PoissonNLLLoss()

In [32]:
numEpochs = 25
learning_rate = 0.001
optimizer = torch.optim.Adam(modelComples.parameters(), lr=learning_rate)
for epoch in range(numEpochs):
    for i, (p,l, c) in enumerate(zip(prots_train, ligs_train, corrVal_train)):
        p = p.reshape(7,24,24,24)
        tp = getTensor(p, 'float32', device)
        tp = tp.unsqueeze(0)

        l = np.array(list(l)).reshape(1, 1024)
        tl = getTensor(l, 'float32', device)
        
        out = modelProt(tp)
        
        o1,o2,o3 = modelComples(tl, out)
        
        L1 = torch.from_numpy(c[0, :].astype('float32').reshape(1, 17)).to(device)
        L2 = torch.from_numpy(c[1, :].astype('float32').reshape(1, 17)).to(device)
        L3 = torch.from_numpy(c[2, :].astype('float32').reshape(1, 17)).to(device)

        loss1 = criterion1(o1, L1)
        loss2 = criterion2(o2, L2)
        loss3 = criterion3(o3, L3)
        
        loss = loss1 + loss2 + loss3
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 10 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, numEpochs, 
                                                                     i+1, ligs_train.shape[0],
                                                                    loss.item()) )
        

Epoch [1/25], Step [10/159], Loss: nan
Epoch [1/25], Step [20/159], Loss: nan
Epoch [1/25], Step [30/159], Loss: nan
Epoch [1/25], Step [40/159], Loss: nan
Epoch [1/25], Step [50/159], Loss: nan
Epoch [1/25], Step [60/159], Loss: nan
Epoch [1/25], Step [70/159], Loss: nan
Epoch [1/25], Step [80/159], Loss: nan
Epoch [1/25], Step [90/159], Loss: nan
Epoch [1/25], Step [100/159], Loss: nan
Epoch [1/25], Step [110/159], Loss: nan
Epoch [1/25], Step [120/159], Loss: nan
Epoch [1/25], Step [130/159], Loss: nan
Epoch [1/25], Step [140/159], Loss: nan
Epoch [1/25], Step [150/159], Loss: nan
Epoch [2/25], Step [10/159], Loss: nan
Epoch [2/25], Step [20/159], Loss: nan
Epoch [2/25], Step [30/159], Loss: nan
Epoch [2/25], Step [40/159], Loss: nan
Epoch [2/25], Step [50/159], Loss: nan
Epoch [2/25], Step [60/159], Loss: nan
Epoch [2/25], Step [70/159], Loss: nan
Epoch [2/25], Step [80/159], Loss: nan


KeyboardInterrupt: 