In [1]:
%matplotlib notebook

In [70]:
import torch 
import numpy
import time
from src.data_ops.load_dataset import load_train_dataset
from src.data_ops.utils import pad_tensors

In [3]:
filename = 'casp11'
n_train = 100
n_valid = 100
redo = False
train, valid = load_train_dataset('data/proteins', filename, n_train, n_valid, redo)



In [81]:
import torch
from src.data_ops.wrapping import wrap

def pad_tensors(tensor_list):
    data = tensor_list

    data_dim = data[0].size()[-1]

    seq_lengths = [len(x) for x in data]
    max_seq_length = max(seq_lengths)
    invariant_data_size = data[0].size()[2:]
    #print(invariant_data_size)
    if len(invariant_data_size) > 0:
        extra_dims = invariant_data_size
        padded_data = torch.zeros(len(data), max_seq_length, data_dim, *extra_dims)
    else:
        padded_data = torch.zeros(len(data), max_seq_length, data_dim)
    for i, x in enumerate(data):
        len_x = len(x)
        padded_data[i, :len_x, :data_dim] = x
    padded_data = wrap(padded_data)

    mask = torch.ones(len(data), max_seq_length, max_seq_length)
    for i, x in enumerate(data):
        seq_length = len(x)
        if seq_length < max_seq_length:
            mask[i, seq_length:, :].fill_(0)
            mask[i, :, seq_length:].fill_(0)
    mask = wrap(mask)

    return (padded_data, mask)


In [4]:
import matplotlib.pyplot as plot

In [5]:
idx = 10

In [6]:
X = train[idx][0][:,:20]

In [7]:
Y = train[idx][1]

In [9]:
proteins = train.proteins
proteins[idx].pdb_id

In [351]:
Y[100][:,1]

array([  194.4,  2056.8,  1218.5])

In [45]:
Z = numpy.zeros((X.shape[0], X.shape[0]))
for i in range(X.shape[0]):
    for j in range(i, X.shape[0]):
        z_ = numpy.sqrt(((Y[i].mean(1) - Y[j].mean(1))**2).sum())
        #Z[i/j] = 1./z_
        if z_ > 0.:
            Z[i,j] =1./z_
#         Z[i,j] =numpy.sqrt(((Y[i].mean(1) - Y[j].mean(1))**2).sum())


In [75]:
 sigma2 = 1e-10

In [206]:

def compute_adjacency2(coords, p=2.5):
    bs, n_vertices, n_atoms, space_dim = coords.shape

    coords_i = coords.view(bs, 1, n_vertices, n_atoms, space_dim).repeat(1, n_vertices, 1, 1, 1)
    coords_j = coords.view(bs, n_vertices, 1, n_atoms, space_dim).repeat(1, 1, n_vertices, 1, 1)
    dij = ((coords_i - coords_j).mean(3)**2)
    dij = torch.sqrt(dij.sum(-1))**p

    dij = 1/dij
    #dij = torch.exp(dij)
    return dij

In [352]:
def compute_adjacency(coords):
    bs, n_vertices, n_atoms, space_dim = coords.shape
    
    coords_i = coords.view(bs, 1, n_vertices, n_atoms, space_dim).repeat(1, n_vertices, 1, 1, 1)
    coords_j = coords.view(bs, n_vertices, 1, n_atoms, space_dim).repeat(1, 1, n_vertices, 1, 1)
    dij = ((coords_i - coords_j)[:,:,:,:,1]**2)
    dij = torch.sqrt(dij.sum(-1))
    
    #dij = (dij < threshold).float()
   
    #dij = -dij/(2*sigma**2)
    #dij = torch.exp(dij)
    return dij

In [312]:
from src.data_ops.proteins.preprocessing import string_vectorizer

In [344]:
def make_mask(masktext):
    #print(masktext)
    v = string_vectorizer(masktext, ['+'])
    v = torch.Tensor(v)
    mask = v.view(len(v), 1)
    matrix_mask = torch.mm(mask,torch.transpose(mask, 0, 1))
    return matrix_mask
    #print(v)
    

In [354]:
idx = 1
coords = [torch.from_numpy(p.tertiary) for p in train.proteins[idx:2*idx]]
masks = [make_mask(p.mask[0]) for p in train.proteins[idx:2*idx]]

In [355]:
def contact_map(dij, threshold=8):
    return (dij < threshold).float()

In [356]:
coords, mask = pad_tensors(coords)
coords = coords.data

In [399]:
#Z1, t1 = compute_adjacency(coords[:, :10])
Z = compute_adjacency(coords)
Z = contact_map(Z, threshold=800)

In [400]:
print(Z[0,:5,:5])
#Z2.max(), Z2.min()


 1  1  1  1  0
 1  1  1  1  0
 1  1  1  1  0
 1  1  1  1  0
 0  0  0  0  1
[torch.FloatTensor of size 5x5]



In [401]:
for z in Z:
    plot.figure()
    plot.imshow(z)
    plot.show()
    plot.colorbar()

<IPython.core.display.Javascript object>

In [346]:
for mask in masks:
    plot.figure()
    plot.imshow(mask)
    plot.show()
    plot.colorbar()

<IPython.core.display.Javascript object>