In [17]:
import shelve
from pathlib import Path

import torch
from torch_scatter import scatter
from torch_cluster import knn_graph

from utils.dataset import UnclusteredProteinChainDataset, idealize_backbone_coords
from utils.utility_functions import compute_residue_frames, compute_center_of_mass, center_coords

In [18]:
# Shelve files are looked up under ./dataset relative to the notebook's working directory.
file_path = Path('.').resolve()
params = {}
params['dataset_shelve_path'] = file_path.joinpath('dataset', 'dataset_shelve')
params['metadata_shelve_path'] = file_path.joinpath('dataset', 'metadata_shelve')

In [19]:
dataset = UnclusteredProteinChainDataset(params['dataset_shelve_path'], params['metadata_shelve_path'])

# Example chain used throughout this notebook; change it here to inspect another one.
# The second value returned by dataset[...] is not needed.
chain_key = '6w70_1-A-A'
chain_pair = ('A', 'A')
able_chain_A, _ = dataset[dataset.chain_key_to_index[chain_key]]
bb_coords = able_chain_A[chain_pair]['backbone_coords']
phi_psi_angles = able_chain_A[chain_pair]['phi_psi_angles']

# Rebuild the backbone from its dihedrals, then centre it on the centre of mass
# of backbone atom index 1 (CA).
ideal_bb_coords = idealize_backbone_coords(bb_coords, phi_psi_angles)
ca_com = compute_center_of_mass(ideal_bb_coords[:, 1])
centered_ideal_bb_coords = center_coords(ideal_bb_coords, ca_com)

# One 3x3 matrix per residue; the cells below treat column k as that residue's
# local k-th axis expressed in global coordinates.
frames = compute_residue_frames(centered_ideal_bb_coords)

# The apply / invert_apply checks below rely on these being proper rotations:
# R^T R = I and det(R) = +1 for every residue.
identity = torch.eye(3, dtype=frames.dtype, device=frames.device)
assert torch.allclose(frames.transpose(-1, -2) @ frames, identity.expand_as(frames), atol=1e-4), \
    'residue frames are not orthonormal'
frame_dets = torch.linalg.det(frames)
assert torch.allclose(frame_dets, torch.ones_like(frame_dets), atol=1e-4), \
    'residue frames are not right-handed rotations'

### Visualize rotation frames

In [20]:
import plotly.express as px
from plotly import graph_objects as go

# Overview plot: backbone atoms as small black markers, plus each residue's local
# frame drawn as three unit-length axes anchored at its CA (atom index 1).
ca_coords = centered_ideal_bb_coords[:, 1]

fig = px.scatter_3d(x=ca_coords[:, 0], y=ca_coords[:, 1], z=ca_coords[:, 2])
fig.update_traces(marker=dict(size=2.5, color='black'))

# NOTE(review): atom index 3 is labelled 'C' here; confirm this against the atom
# order produced by idealize_backbone_coords.
for atom_idx, atom_label in ((0, 'N'), (3, 'C')):
    atom_coords = centered_ideal_bb_coords[:, atom_idx]
    fig.add_trace(
        go.Scatter3d(
            x=atom_coords[:, 0],
            y=atom_coords[:, 1],
            z=atom_coords[:, 2],
            mode='markers',
            marker=dict(size=2.5, color='black'),
            hovertext=atom_label,
        )
    )

colors = ['red', 'green', 'blue']

# Column `dim` of each rotation matrix is that residue's local axis `dim` in global
# coordinates. All segments of one colour share a single trace, separated by None
# gaps, rather than adding one trace per residue per axis (3*N traces).
ca_points = ca_coords.tolist()
for dim, color in enumerate(colors):
    tip_points = (ca_coords + frames[:, :, dim]).tolist()
    seg_x, seg_y, seg_z = [], [], []
    for (x0, y0, z0), (x1, y1, z1) in zip(ca_points, tip_points):
        seg_x += [x0, x1, None]
        seg_y += [y0, y1, None]
        seg_z += [z0, z1, None]
    fig.add_trace(
        go.Scatter3d(
            x=seg_x,
            y=seg_y,
            z=seg_z,
            mode='lines',
            line=dict(color=color, width=5)
        )
    )

fig.update_layout(showlegend=False, width=1000, height=800)
fig.show()

### Test local-to-global transformation (`Rigid.apply`): map a point given in residue 0's frame into global coordinates

In [21]:
from utils.rigid_utils import Rigid, Rotation

# One rigid transform per residue: rotation = that residue's frame,
# translation = its CA (atom index 1) position.
ofold_rigid_objs = Rigid(rots=Rotation(frames), trans=centered_ideal_bb_coords[:, 1])

# `pre` is a point expressed in a residue's local frame: 2 units along local y.
# NOTE: `pre` and `out` are reassigned by the invert_apply cell further down;
# re-run this cell before re-running the plot that follows it.
pre = torch.tensor([0.0, 2.0, 0.0])
# The [0] keeps residue 0's transformed point (a length-3 vector).
out = ofold_rigid_objs.apply(pre)[0]

# Sanity check (assumes the OpenFold convention apply(p) = R @ p + t): the result
# must be residue 0's CA plus 2x its local y axis (column 1 of its rotation).
torch.testing.assert_close(
    out,
    centered_ideal_bb_coords[0, 1] + 2.0 * frames[0, :, 1],
    atol=1e-4,
    rtol=0,
    check_dtype=False,
)

In [26]:
# Residue 0 only: its CA, the local-frame input `pre`, the transformed point
# `out`, residue 0's frame axes drawn from its CA, and the global basis at the
# origin. `pre` is read against the origin axes; `out` should sit 2 units along
# residue 0's green (local y) axis.
# NOTE: `pre`/`out` are reassigned by the invert_apply cell below; re-run the
# apply cell first if this plot is re-run later.
frame_idx = 0
ca_0 = centered_ideal_bb_coords[frame_idx, 1]

fig = px.scatter_3d(
    x=ca_0[0].unsqueeze(0),
    y=ca_0[1].unsqueeze(0),
    z=ca_0[2].unsqueeze(0),
)

fig.add_trace(
    go.Scatter3d(
        x=pre[0].unsqueeze(0),
        y=pre[1].unsqueeze(0),
        z=pre[2].unsqueeze(0),
        mode='markers',
        hovertext='original'
    )
)

colors = ['red', 'green', 'blue']

# Marker sizes for every trace are set by update_traces at the end of the cell.
fig.add_trace(
    go.Scatter3d(
        x=out[0].unsqueeze(0),
        y=out[1].unsqueeze(0),
        z=out[2].unsqueeze(0),
        mode='markers',
        marker=dict(color='blue'),
        hovertext='transformed'
    )
)

# Column k of residue 0's rotation matrix is its local k-th axis in global coordinates.
for axis_idx, color in enumerate(colors):
    tip = ca_0 + frames[frame_idx, :, axis_idx]
    fig.add_trace(
        go.Scatter3d(
            x=[ca_0[0], tip[0]],
            y=[ca_0[1], tip[1]],
            z=[ca_0[2], tip[2]],
            mode='lines',
            line=dict(color=color, width=5),
            name=f'frame: {frame_idx}',
            hoverinfo='name'
        )
    )

# Global unit axes drawn from the origin, for reference.
bases = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
for basis, color in zip(bases, colors):
    fig.add_trace(
        go.Scatter3d(
            x=[0, basis[0]],
            y=[0, basis[1]],
            z=[0, basis[2]],
            mode='lines',
            line=dict(color=color, width=5),
            name='origin',
            hoverinfo='name'
        )
    )

fig.update_traces(marker=dict(size=5))
# Fixed axis ranges (equal 26-unit spans) tuned by hand for this chain.
fig.update_layout(
    showlegend=False,
    width=1000,
    height=800,
    scene=dict(
        xaxis=dict(range=[-11, 15]),
        yaxis=dict(range=[-20, 6]),
        zaxis=dict(range=[-1, 25])
    )
)
fig.show()

### Test global-to-local transformation (`Rigid.invert_apply`): map a global point back into residue 0's frame

In [23]:
# Residue 0's rotation matrix and its column 1 (the local y axis drawn in green above).
frames[0], frames[0][:, 1]

(tensor([[ 0.9988,  0.0154, -0.0473],
         [ 0.0311,  0.5498,  0.8347],
         [ 0.0389, -0.8352,  0.5486]]),
 tensor([ 0.0154,  0.5498, -0.8352]))

In [27]:
# Reuses `ofold_rigid_objs` from the apply cell above (same frames and CA
# translations), so no second import or rebuild is needed.
# `pre` is a global point: residue 0's CA plus 3 units along its local y axis
# (column 1 of its rotation matrix).
pre = (3.0 * frames[0, :, 1]) + centered_ideal_bb_coords[0, 1]
# invert_apply maps the global point into each residue's frame; keep residue 0's.
out = ofold_rigid_objs.invert_apply(pre)[0]

# Sanity check: undoing residue 0's transform should leave exactly 3 units along
# local y, i.e. (0, 3, 0) up to float error.
torch.testing.assert_close(out, out.new_tensor([0.0, 3.0, 0.0]), atol=1e-4, rtol=0)

In [28]:
# Residue 0 only: its CA, the global input `pre`, the inverse-transformed point
# `out`, residue 0's frame axes drawn from its CA, and the global basis at the
# origin. `out` is in residue 0's local coordinates, so it should sit 3 units
# along the green (y) origin axis.
frame_idx = 0
ca_0 = centered_ideal_bb_coords[frame_idx, 1]

fig = px.scatter_3d(
    x=ca_0[0].unsqueeze(0),
    y=ca_0[1].unsqueeze(0),
    z=ca_0[2].unsqueeze(0),
)
fig.update_traces(marker=dict(size=2.5, color='black'))

fig.add_trace(
    go.Scatter3d(
        x=pre[0].unsqueeze(0),
        y=pre[1].unsqueeze(0),
        z=pre[2].unsqueeze(0),
        mode='markers',
        marker=dict(size=2.5, color='black'),
        hovertext='original'
    )
)

colors = ['red', 'green', 'blue']

fig.add_trace(
    go.Scatter3d(
        x=out[0].unsqueeze(0),
        y=out[1].unsqueeze(0),
        z=out[2].unsqueeze(0),
        mode='markers',
        marker=dict(size=2.5, color='blue'),
        hovertext='transformed'
    )
)

# Column k of residue 0's rotation matrix is its local k-th axis in global coordinates.
for axis_idx, color in enumerate(colors):
    tip = ca_0 + frames[frame_idx, :, axis_idx]
    fig.add_trace(
        go.Scatter3d(
            x=[ca_0[0], tip[0]],
            y=[ca_0[1], tip[1]],
            z=[ca_0[2], tip[2]],
            mode='lines',
            line=dict(color=color, width=5),
            name=f'frame: {frame_idx}',
            hoverinfo='name'
        )
    )

# Global unit axes drawn from the origin, for reference.
bases = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
for basis, color in zip(bases, colors):
    fig.add_trace(
        go.Scatter3d(
            x=[0, basis[0]],
            y=[0, basis[1]],
            z=[0, basis[2]],
            mode='lines',
            line=dict(color=color, width=5),
            name='origin',
            hoverinfo='name'
        )
    )

# Fixed axis ranges (equal 26-unit spans) tuned by hand for this chain.
fig.update_layout(
    showlegend=False,
    width=1000,
    height=800,
    scene=dict(
        xaxis=dict(range=[-11, 15]),
        yaxis=dict(range=[-20, 6]),
        zaxis=dict(range=[-1, 25])
    )
)
fig.show()