In [12]:
import numpy as np
import plotly.express as px
import torch
import torch.nn as nn
import torch.optim as optim
from oxDNA_analysis_tools.UTILS.RyeReader import describe, get_confs, inbox
from torch.utils.data import DataLoader

from dataset.dataset import DNADataset
from model.origami_model import DNAOrigamiModel
# from model.origami_model import DNAOrigamiModel
from model.transformer import EncoderLayer, TransformerEncoder, TransformerDecoder, Transformer
from utils.logger import get_logger
from utils.parsing import parse_dna_origami_data
from utils.tokenize import tokenize_trajectory

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# One oxDNA trajectory and its topology file. Presumably the i-th trajectory
# pairs with the i-th topology — TODO confirm against DNADataset.
trajectory_filepaths = ["dataset/data/trajectory.dat"]
topology_filepaths = ["dataset/data/output.top"]
dataset = DNADataset(trajectory_filepaths, topology_filepaths)

# Seed the shuffle so batch order is reproducible across Restart & Run All.
SEED = 0
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)



In [3]:
# Size of the node tensor of the first sample. The output suggests
# (confs, particles, features) — TODO confirm the axis order in DNADataset.
dataset[0]["nodes"].size()

torch.Size([4, 86, 13])

In [11]:
# Next-configuration setup for the first trajectory: `src` holds configurations
# 0..T-2 and `tgt` holds configurations 1..T-1, so src[i] is paired with tgt[i].
nodes = dataset.graphs[0]['nodes']
src = nodes[:-1]
tgt = nodes[1:]
n_confs, n_particles = src.shape[:2]

# All-False boolean masks, i.e. nothing is masked and every position attends to
# every other. `torch.zeros(bool)` is the direct spelling of `~torch.ones(bool)`.
# NOTE(review): presumably consumed as key-padding masks (True = ignore) —
# confirm against model.transformer.
src_mask = torch.zeros((n_confs, n_particles), dtype=torch.bool)
tgt_mask = torch.zeros((n_confs, n_particles), dtype=torch.bool)

In [14]:
# Model hyper-parameters: d_model (52) is a multiple of nhead (13), giving 4 dims per head.
# NOTE(review): the printed model has 2 encoder and 2 decoder layers, matching
# 'num_layers'. 'num_encoder_layers' (6) may be unused — confirm in model.transformer.
cfg = dict(
    num_layers=2,
    n_features=13,
    d_model=52,
    nhead=13,
    num_encoder_layers=6,
    d_ff=64,
    dropout_rate=0.1,
    device='cpu',
)
model = Transformer(cfg)

In [15]:
# Weights from a previous training run.
# NOTE(review): absolute, machine-specific path — consider a config constant.
PATH = '/scratch/matthew/project_files/dnaOrigami/e3/dna_model.pth'
# map_location places the stored tensors on the configured device, so a
# checkpoint saved from a GPU run still loads on a CPU-only machine.
state_dict = torch.load(PATH, map_location=cfg['device'], weights_only=True)
model.load_state_dict(state_dict)
# Switch dropout to inference behaviour.
model.eval()

Transformer(
  (encoder): TransformerEncoder(
    (embed): Linear(in_features=13, out_features=52, bias=True)
    (layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (multi_head_attention): MultiHeadAttention(
          (attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=52, out_features=52, bias=True)
          )
        )
        (position_wise_feed_forward): PositionWiseFeedForward(
          (fc1): Linear(in_features=52, out_features=64, bias=False)
          (relu): ReLU()
          (fc2): Linear(in_features=64, out_features=52, bias=False)
        )
        (norm1): LayerNorm((52,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((52,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): TransformerDecoder(
    (embed): Linear(in_features=13, out_features=52, bias=True)
    (layers): ModuleList(
      (0-1): 2 x DecoderLayer(
        (multi_head

In [30]:
# Causal (square subsequent) mask over the particle axis: position i may attend
# only to positions <= i. This overwrites the all-False (n_confs, n_particles)
# `tgt_mask` built earlier with an (n_particles, n_particles) float mask.
tgt_mask = nn.Transformer.generate_square_subsequent_mask(n_particles)

# Inference only: disable autograd so no graph is built and `out` holds no grad history.
with torch.no_grad():
    out = model(src[:1], tgt[:1], src_mask[:1], tgt_mask)

In [41]:
# First three feature columns of every particle in the first predicted configuration.
out[0][:, :3]

tensor([[ -9.9639,  -0.5489,  -3.8592],
        [ -7.2686,  -0.7510,  -4.0224],
        [ -4.7048,  -1.3204,  -3.9317],
        [ -3.1997,  -0.6190,  -5.0667],
        [ -2.6330,  -0.2844,  -3.3280],
        [ -1.3469,  -1.7522,  -1.6005],
        [  0.3608,  -0.3806,  -0.5349],
        [ -0.6220,   0.5623,   1.1554],
        [ -0.3241,  -0.0930,   3.5268],
        [ -0.6207,   1.7325,   4.0328],
        [  0.9053,   3.2799,   4.0167],
        [  1.1267,   3.4154,   4.0021],
        [  1.2138,   4.8475,   5.0121],
        [  1.8436,   4.5764,   6.7801],
        [  1.6431,   4.4946,   7.2031],
        [  2.2147,   3.1833,   6.5277],
        [  1.5637,   4.0325,   3.8621],
        [  3.8150,   3.0382,   1.8404],
        [  3.0089,   1.1903,   1.7912],
        [  0.3119,   0.1833,   2.1237],
        [ -0.7573,   0.4903,  -0.4381],
        [ -2.7496,  -0.4629,  -2.1520],
        [  0.3005,  -2.7555,  -3.7222],
        [  2.3792,  -3.8261,  -5.7720],
        [  4.2390,  -3.0677,  -4.2808],


In [43]:
# Prediction minus target on the first three feature columns (per particle).
out[0][:, :3] - tgt[0][:, :3]

tensor([[ 6.2275e-02,  3.4770e-02, -6.5256e-02],
        [-1.4011e-02,  9.8590e-02, -1.7092e-01],
        [ 4.2602e-02,  1.6367e-01, -1.0346e-01],
        [ 3.5018e-02,  1.5315e-02,  5.8325e-02],
        [ 1.2437e-01,  1.1889e-01, -1.9345e-03],
        [ 1.0256e-01,  1.3958e-01, -3.2908e-03],
        [ 7.4155e-02,  1.1223e-01,  5.1529e-02],
        [ 1.8275e-01,  2.5307e-01,  1.5151e-01],
        [ 2.0151e-01,  1.8690e-01,  1.0557e-01],
        [ 1.3850e-01,  2.2779e-01,  1.1862e-01],
        [ 1.5707e-01,  2.2676e-01,  2.3738e-01],
        [ 3.8461e-01,  3.1676e-01,  2.4974e-01],
        [ 2.5062e-01,  1.6195e-01,  3.2322e-01],
        [ 1.0090e-01,  1.9253e-01,  9.1698e-02],
        [ 2.3089e-01,  2.0330e-01,  2.1799e-01],
        [ 2.7736e-01,  3.2959e-01,  2.6697e-01],
        [ 2.9753e-01,  2.2457e-01,  3.9652e-01],
        [ 2.3723e-01,  2.3744e-01,  3.6148e-01],
        [ 4.1411e-01,  4.2813e-01,  3.8368e-01],
        [ 2.1950e-01,  1.9745e-01,  2.2556e-01],
        [ 2.5673e-01

In [46]:
# Interactive 3-D scatter of the first three feature columns of tgt[0]
# (presumably x, y, z positions — confirm against the dataset's feature layout).
# Points are coloured by particle index, so the 'Point Index' label refers to a
# real colour channel instead of a commented-out one.
coords = tgt[0, :, :3].detach().cpu().numpy()
particle_index = np.arange(coords.shape[0])

fig = px.scatter_3d(
    x=coords[:, 0],
    y=coords[:, 1],
    z=coords[:, 2],
    color=particle_index,
    title='Interactive 3D Scatter Plot',
    labels={'color': 'Point Index'},
    opacity=0.7,
)

# Axis titles and a square canvas for easier rotation.
fig.update_layout(
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis',
    ),
    width=800,
    height=800,
)

fig.show()