In [1]:
from __future__ import (
    print_function,
    division,
    absolute_import,
)
import collections
from copy import copy

import pandas
import random
import numpy as np
import torch
from six import StringIO

In [2]:
AA_SYMOLS = ['A', 'R', 'N', 'D', 'C',
             'Q', 'E', 'G', 'H', 'I',
             'L', 'K', 'M', 'F', 'P',
             'S', 'T', 'W', 'Y', 'V']

In [3]:
def encode_ligand(ligand):
    m = list()
    for symbol in AA_SYMOLS:
        channel = list()
        for aa in ligand:
            if aa.upper() == symbol: channel.append(1.0)
            else: channel.append(random.uniform(0.001, 0.01))
        m.append(channel)
    m = np.array(m).reshape(1, len(ligand), 20)
    return m

In [7]:
print(encode_ligand('KMYEYVFKG'))

[[[0.00712255 0.00147424 0.00173254 0.00733454 0.00486762 0.00701955
   0.00566584 0.00993252 0.00538337 0.00690312 0.00698338 0.00783121
   0.00554433 0.00109508 0.00335705 0.00770352 0.00652109 0.00372498
   0.00875786 0.0030225 ]
  [0.00824988 0.00887811 0.005264   0.00918274 0.00145545 0.00300292
   0.00728582 0.00526036 0.00332123 0.00212888 0.00572468 0.00127711
   0.00199793 0.00859027 0.00207352 0.00658875 0.00180687 0.00579378
   0.00896766 0.00684221]
  [0.00646873 0.0013779  0.00171367 0.00554197 0.00159154 0.0081159
   0.00977888 0.00449383 0.0086963  0.00308931 0.0089887  0.00509944
   0.0086759  0.00493609 0.00663178 0.00497767 0.00415516 1.
   0.0039738  0.00910814]
  [0.00556365 0.00915135 0.00119691 0.00852738 0.00542975 0.00541686
   0.00715161 0.00835746 0.00932109 0.00697007 0.00361128 1.
   0.00754489 0.00179376 0.00252989 0.00949418 0.00328642 0.00428285
   0.00767927 0.00852271]
  [0.0048442  0.00790229 0.00555354 0.00199327 0.00215958 0.00234197
   0.00987413 0.

In [4]:
COMMON_AMINO_ACIDS = collections.OrderedDict(sorted({
    "A": "Alanine",
    "R": "Arginine",
    "N": "Asparagine",
    "D": "Aspartic Acid",
    "C": "Cysteine",
    "E": "Glutamic Acid",
    "Q": "Glutamine",
    "G": "Glycine",
    "H": "Histidine",
    "I": "Isoleucine",
    "L": "Leucine",
    "K": "Lysine",
    "M": "Methionine",
    "F": "Phenylalanine",
    "P": "Proline",
    "S": "Serine",
    "T": "Threonine",
    "W": "Tryptophan",
    "Y": "Tyrosine",
    "V": "Valine",
}.items()))
COMMON_AMINO_ACIDS_WITH_UNKNOWN = copy(COMMON_AMINO_ACIDS)
COMMON_AMINO_ACIDS_WITH_UNKNOWN["X"] = "Unknown"

AMINO_ACID_INDEX = dict(
    (letter, i) for (i, letter) in enumerate(COMMON_AMINO_ACIDS_WITH_UNKNOWN))

AMINO_ACIDS = list(COMMON_AMINO_ACIDS_WITH_UNKNOWN.keys())

BLOSUM62_MATRIX = pandas.read_csv(StringIO("""
   A  R  N  D  C  Q  E  G  H  I  L  K  M  F  P  S  T  W  Y  V  X
A  4 -1 -2 -2  0 -1 -1  0 -2 -1 -1 -1 -1 -2 -1  1  0 -3 -2  0  0
R -1  5  0 -2 -3  1  0 -2  0 -3 -2  2 -1 -3 -2 -1 -1 -3 -2 -3  0
N -2  0  6  1 -3  0  0  0  1 -3 -3  0 -2 -3 -2  1  0 -4 -2 -3  0
D -2 -2  1  6 -3  0  2 -1 -1 -3 -4 -1 -3 -3 -1  0 -1 -4 -3 -3  0
C  0 -3 -3 -3  9 -3 -4 -3 -3 -1 -1 -3 -1 -2 -3 -1 -1 -2 -2 -1  0
Q -1  1  0  0 -3  5  2 -2  0 -3 -2  1  0 -3 -1  0 -1 -2 -1 -2  0
E -1  0  0  2 -4  2  5 -2  0 -3 -3  1 -2 -3 -1  0 -1 -3 -2 -2  0
G  0 -2  0 -1 -3 -2 -2  6 -2 -4 -4 -2 -3 -3 -2  0 -2 -2 -3 -3  0
H -2  0  1 -1 -3  0  0 -2  8 -3 -3 -1 -2 -1 -2 -1 -2 -2  2 -3  0
I -1 -3 -3 -3 -1 -3 -3 -4 -3  4  2 -3  1  0 -3 -2 -1 -3 -1  3  0
L -1 -2 -3 -4 -1 -2 -3 -4 -3  2  4 -2  2  0 -3 -2 -1 -2 -1  1  0
K -1  2  0 -1 -3  1  1 -2 -1 -3 -2  5 -1 -3 -1  0 -1 -3 -2 -2  0
M -1 -1 -2 -3 -1  0 -2 -3 -2  1  2 -1  5  0 -2 -1 -1 -1 -1  1  0
F -2 -3 -3 -3 -2 -3 -3 -3 -1  0  0 -3  0  6 -4 -2 -2  1  3 -1  0
P -1 -2 -2 -1 -3 -1 -1 -2 -2 -3 -3 -1 -2 -4  7 -1 -1 -4 -3 -2  0
S  1 -1  1  0 -1  0  0  0 -1 -2 -2  0 -1 -2 -1  4  1 -3 -2 -2  0
T  0 -1  0 -1 -1 -1 -1 -2 -2 -1 -1 -1 -1 -2 -1  1  5 -2 -2  0  0 
W -3 -3 -4 -4 -2 -2 -3 -2 -2 -3 -2 -3 -1  1 -4 -3 -2 11  2 -3  0
Y -2 -2 -2 -3 -2 -1 -2 -3  2 -1 -1 -2 -1  3 -3 -2 -2  2  7 -1  0
V  0 -3 -3 -3 -1 -2 -2 -3 -3  3  1 -2  1 -1 -2 -2  0 -3 -1  4  0
X  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1
"""), sep='\s+').loc[AMINO_ACIDS, AMINO_ACIDS]
assert (BLOSUM62_MATRIX == BLOSUM62_MATRIX.T).all().all()

In [5]:
BLOSUM62_MATRIX.head()

Unnamed: 0,A,C,D,E,F,G,H,I,K,L,...,N,P,Q,R,S,T,V,W,Y,X
A,4,0,-2,-1,-2,0,-2,-1,-1,-1,...,-2,-1,-1,-1,1,0,0,-3,-2,0
C,0,9,-3,-4,-2,-3,-3,-1,-3,-1,...,-3,-3,-3,-3,-1,-1,-1,-2,-2,0
D,-2,-3,6,2,-3,-1,-1,-3,-1,-4,...,1,-1,0,-2,0,-1,-3,-4,-3,0
E,-1,-4,2,5,-3,-2,0,-3,1,-3,...,0,-1,2,0,0,-1,-2,-3,-2,0
F,-2,-2,-3,-3,6,-3,-1,0,-3,0,...,-3,-4,-3,-3,-2,-2,-1,1,3,0


In [20]:
from torch.autograd import Variable
#a = torch.randn(5, 2)

# params
n_ex = 10
nets = 5
epochs = 1000
samples = 100
torch.manual_seed(1)

x_train = Variable(torch.rand(n_ex, 1), requires_grad=False)
y_train = Variable(torch.rand(n_ex, 1), requires_grad=False)

netlist = []
params = []

class MLP(torch.nn.Module):
    def __init__(self, h):
        super(MLP, self).__init__()
        self.main = torch.nn.Sequential(
                    torch.nn.Linear(1, h),
                    torch.nn.ReLU(),
                    torch.nn.Linear(h, h),
                    torch.nn.ReLU(),
                    torch.nn.Linear(h, 1),
                    torch.nn.Tanh()
                    )
    def forward(self, x):
        return self.main(x)

for i in range(1, nets+1):
    h = i
    print('Model built with h =', h)
    netlist.append(MLP(h))
params += list(netlist[i-1].parameters())

optimizer = torch.optim.SGD(params, 0.03)
crit = torch.nn.MSELoss(size_average=False)

for i in range(epochs):
    loss_list = []
    for net in netlist:
        optimizer.zero_grad()
        out = net(x_train)
        loss = crit(out, y_train)
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

    if i % samples == 0:
        print('epoch [{}/{}]'.format(i, epochs-samples))
        for j, model in enumerate(netlist):
            print('    model {} loss {:.5f}'.format(j, loss_list[j]))

# x data for plotting
xdata = Variable(torch.arange(0, 0.999, 0.001).unsqueeze_(1))
pred = torch.Tensor(nets, xdata.size(0), 1)
xdata_plot = xdata.data.numpy()

# get predictions for each network on 1000 x values (xdata)
for i, net in enumerate(netlist):
    pred[i] = net(xdata)
pred

Model built with h = 1
Model built with h = 2
Model built with h = 3
Model built with h = 4
Model built with h = 5
epoch [0/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 5.47783




epoch [100/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.25934
epoch [200/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.25034
epoch [300/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.24352
epoch [400/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.23554
epoch [500/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.22473
epoch [600/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.21042
epoch [700/900]
    model 0 loss 3.02213
    model 1 loss 1.02515
    model 2 loss 10.39546
    model 3 loss 4.01103
    model 4 loss 0.19263
epoch 

tensor([[[ 0.0478],
         [ 0.0478],
         [ 0.0478],
         ...,
         [ 0.0478],
         [ 0.0478],
         [ 0.0478]],

        [[ 0.2937],
         [ 0.2937],
         [ 0.2937],
         ...,
         [ 0.3529],
         [ 0.3531],
         [ 0.3533]],

        [[-0.5085],
         [-0.5084],
         [-0.5082],
         ...,
         [-0.3368],
         [-0.3365],
         [-0.3363]],

        [[-0.0301],
         [-0.0302],
         [-0.0303],
         ...,
         [-0.0566],
         [-0.0567],
         [-0.0567]],

        [[ 0.5331],
         [ 0.5336],
         [ 0.5340],
         ...,
         [ 0.3529],
         [ 0.3520],
         [ 0.3511]]], grad_fn=<CopySlices>)