In [10]:
%matplotlib ipympl
import matplotlib
import numpy as np
import matplotlib.pyplot as plt

from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d.art3d import Line3DCollection
import sys
sys.path.append('../../..')
import os 
import pickle

import torch 

In [11]:
from src.proteins.data_ops import get_train_data_loader
from src.proteins.models import MODEL_DICT
from src.proteins.visualization.visualize_spatial import visualize_spatial
from src.proteins.visualization.color_dict import COLOR_DICT
from src.proteins.visualization.spatial_weights_report import spatial_weights_report
from src.utils import load_model

In [12]:
# Root of the local research checkout. Set the RESEARCH_HOME environment
# variable to point elsewhere instead of editing this machine-specific default.
HOME = os.environ.get('RESEARCH_HOME', '/Users/isaachenrion/x/research/')

In [13]:
# Load the model saved under <HOME>/model_store/2 with the project helper
# load_model; only its first return value (the model) is kept.
filename = os.path.join(HOME, 'model_store', '2')
model, _ = load_model(MODEL_DICT, filename)

In [14]:
# Split sizes and batch size handed straight to the project loader; only the
# second loader it returns (used below) is kept.
# NOTE(review): n_train = -1 presumably means "use all remaining examples" —
# confirm in get_train_data_loader.
n_train, n_valid, batch_size = -1, 100, 1
data_dir = os.path.join(HOME, 'graphs', 'data', 'proteins/pdb25/small')
_, data_loader = get_train_data_loader(data_dir, n_train, n_valid, batch_size)

In [15]:

def visualize_spatials(spatials, string_sequence, nrows=2, ncols=3):
    """Plot each spatial embedding as a residue-coloured 3-D line trace.

    Prints ``string_sequence``, then draws one 3-D subplot per entry of
    ``spatials`` in a single new figure. Consecutive rows (points) of an
    embedding are joined by line segments, and segment k is drawn in the
    colour of residue k from ``COLOR_DICT``.

    Args:
        spatials: sequence of array-likes; the first three columns of each
            are used as x, y, z. Assumes one row per residue of
            ``string_sequence`` — TODO confirm.
        string_sequence: residue letters; each must be a key of COLOR_DICT.
        nrows, ncols: subplot grid (default 2x3, the original layout).
            ``nrows`` grows automatically when ``len(spatials)`` exceeds
            ``nrows * ncols``, instead of add_subplot raising.

    Returns:
        None. The figure is created via pyplot.
    """
    print(string_sequence)
    # Divide by 255 to get matplotlib's 0-1 colour floats (presumably the
    # COLOR_DICT entries are 0-255 RGB tuples — confirm in color_dict).
    color_sequence = [[channel / 255.0 for channel in COLOR_DICT[residue]]
                      for residue in string_sequence]

    # Ceil-divide so every embedding gets a panel.
    nrows = max(nrows, -(-len(spatials) // ncols))

    fig = plt.figure(figsize=(14, 10))
    for i, spatial in enumerate(spatials):
        ax = fig.add_subplot(nrows, ncols, i + 1, projection='3d')

        # (n_points, 3) -> segments of shape (n_points - 1, 2, 3) joining
        # each point to the next, as Line3DCollection expects.
        coords = np.asarray(spatial)[:, :3]
        points = coords.reshape(-1, 1, 3)
        segments = np.concatenate([points[:-1], points[1:]], axis=1)

        lc = Line3DCollection(segments, colors=color_sequence)
        lc.set_linewidth(3)
        # The collection is already 3-D; add_collection3d's zs/zdir only
        # apply to 2-D collections, so they are not passed.
        ax.add_collection3d(lc)

        # Set limits after adding the collection so they are the final word,
        # fitted tightly to this embedding's extent.
        lo, hi = coords.min(axis=0), coords.max(axis=0)
        ax.set_xlim(lo[0], hi[0])
        ax.set_ylim(lo[1], hi[1])
        ax.set_zlim(lo[2], hi[2])
        ax.set_title('Spatial embedding {}'.format(i))

In [16]:
# Collect the spatial embedding of the first batch at every stage of the
# model: one per NMP block (computed from that block's input) plus the final
# spatial embedding, then plot them side by side.
# Only the first batch is needed, so take it directly rather than iterating
# (and loading) every batch of data_loader and discarding all but one.
# NOTE(review): the sequence is looked up as dataset.proteins[0]; this only
# matches the first batch if data_loader does not shuffle — confirm.
# NOTE(review): model.eval() is not called here; if the model has dropout or
# batch-norm, these embeddings may not be deterministic — confirm.
batch = next(iter(data_loader), None)
assert batch is not None, 'data_loader yielded no batches'
string_sequence = data_loader.dataset.proteins[0].string_sequence
(x, y, y_mask, batch_mask) = batch
x = model.initial_embedding(x)
spatials = []
for nmp in model.nmp_blocks:
    s = nmp.spatial_embedding(x)
    # [0] keeps the first sample of the batch (batch_size is 1 above).
    spatials.append(s.data.numpy()[0])
    x = nmp(x, batch_mask)

s = model.final_spatial_embedding(x)
spatials.append(s.data.numpy()[0])
visualize_spatials(spatials, string_sequence)

        
        
    

MQQSLAVKTFEDLFAELGDRARTRPADSTTVAALDGGVHALGKKLLEEAGEVWLAAEHESNDALAEEISQLLYWTQVLMISRGLSLDDVYRKL


FigureCanvasNbAgg()

In [8]:
from sklearn.decomposition import PCA
def weights_pca_report(U_k, n_components=3):
    """Summarise the principal directions of a weight matrix as text.

    Fits scikit-learn PCA to ``U_k.data.numpy()`` (rows as samples, columns
    as features) and builds one line per fitted component. Each line holds
    the component vector (1 decimal per entry) and its explained-variance
    ratio (2 decimals). Every line is prefixed with a newline, so a non-empty
    report starts with one.

    Args:
        U_k: object exposing ``.data.numpy()`` (e.g. a torch tensor);
            presumably a 2-D weight matrix — TODO confirm against callers.
        n_components: number of principal components to report (default 3,
            the value previously hard-coded).

    Returns:
        str: the formatted report.
    """
    X = U_k.data.numpy()
    pca = PCA(n_components=n_components)
    pca.fit(X)

    lines = []
    for ratio, component in zip(pca.explained_variance_ratio_, pca.components_):
        comp_string = ', '.join('{:.1f}'.format(value) for value in component)
        lines.append('\nComponent [{}] has weight {:.2f}'.format(comp_string, ratio))
    return ''.join(lines)

In [9]:
# Print the text report from the project helper spatial_weights_report for the
# loaded model. The saved output below lists, per layer, three components and
# their weights.
print(spatial_weights_report(model))

Layer 0
Component [-0.2, -0.5, -0.8] has weight 0.67
Component [0.9, 0.2, -0.4] has weight 0.17
Component [-0.4, 0.8, -0.5] has weight 0.15
Layer 1
Component [-0.8, -0.1, -0.6] has weight 0.64
Component [-0.6, 0.3, 0.8] has weight 0.32
Component [0.1, 1.0, -0.2] has weight 0.04
Layer 2
Component [0.8, 0.5, 0.2] has weight 0.90
Component [-0.6, 0.6, 0.5] has weight 0.06
Component [-0.1, 0.5, -0.8] has weight 0.04
Layer 3
Component [-0.5, -0.6, -0.6] has weight 0.88
Component [-0.8, 0.0, 0.6] has weight 0.08
Component [-0.4, 0.8, -0.5] has weight 0.03
Layer 4
Component [-0.2, 0.6, 0.8] has weight 0.35
Component [-0.4, 0.7, -0.6] has weight 0.33
Component [0.9, 0.4, -0.1] has weight 0.31

