In [56]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt 
from transformer import RT
import seaborn as sns 

from checkpointer import Checkpointer


In [57]:

from models.mpnn import AlignedMPNN
import jax 
import jax.numpy as jnp
import haiku as hk
from pathlib import Path
# Trained checkpoints live next to the notebook's working directory; create the
# folder up front so later save/load calls never fail on a missing parent.
MODEL_DIR = Path.cwd() / "trained_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

In [75]:
import os
from enum import StrEnum

import numpy as np


class DatasetPath(StrEnum):
    """Locations of the dataset splits on disk.

    Members are ``str`` instances (``StrEnum``), so they can be passed straight
    to ``os.listdir`` / ``os.path.join`` as ``dataloader`` does. The paths are
    relative and therefore resolve against the current working directory.
    """

    TRAIN_PATH = "dataset/train"
    VALIDATION_PATH = "dataset/validation"
    TEST_PATH = "dataset/test"


def load_batch(batch_path: str, num_layers: int = 3):
    """Load one pre-serialised batch from a directory of ``.npy`` files.

    The directory is expected to contain:

    * six model inputs: ``input_node_features``, ``input_edge_features``,
      ``input_graph_features``, ``input_adjacency_matrix``,
      ``input_hidden_node_features``, ``input_hidden_edge_features``;
    * ``num_layers`` per-layer node-feature files ``out_node_features_{i}``
      for ``i`` in ``0 .. num_layers - 1``;
    * one ``out_edge_features`` file.

    Args:
        batch_path: Directory holding the batch's ``.npy`` files.
        num_layers: How many ``out_node_features_{i}.npy`` files to read.
            Defaults to 3, the count that used to be hard-coded.

    Returns:
        ``(inputs, node_features_all_layers, out_edge_features)`` where
        ``inputs`` is a 6-tuple in the order listed above and
        ``node_features_all_layers`` is a list of ``num_layers`` arrays.

    Raises:
        FileNotFoundError: if any of the expected files is missing.
    """

    def _load(stem: str) -> np.ndarray:
        # Every artifact of a batch sits in the same directory as "<stem>.npy".
        return np.load(os.path.join(batch_path, f"{stem}.npy"))

    inputs = tuple(
        _load(stem)
        for stem in (
            "input_node_features",
            "input_edge_features",
            "input_graph_features",
            "input_adjacency_matrix",
            "input_hidden_node_features",
            "input_hidden_edge_features",
        )
    )
    node_features_all_layers = [
        _load(f"out_node_features_{i}") for i in range(num_layers)
    ]
    out_edge_features = _load("out_edge_features")

    return inputs, node_features_all_layers, out_edge_features


def dataloader(dataset_path: DatasetPath):
    batch_dirs = [
        os.path.join(dataset_path, d)
        for d in sorted(os.listdir(dataset_path))
        if os.path.isdir(os.path.join(dataset_path, d))
    ]
    for batch_dir in batch_dirs:
        yield load_batch(batch_dir)


In [76]:
# One call configures the global plot theme; `sns.set` is only a legacy alias
# for `set_theme`, and the separate `set_context` call is folded in here.
sns.set_theme(style="whitegrid", context="notebook")

In [77]:
# Stream the training split and take its first batch: the six model inputs,
# the stored per-layer node features and the stored edge embedding (compared
# against recomputed model outputs further down).
train_dataloader = dataloader(DatasetPath.TRAIN_PATH)
(
    input_node_features, input_edge_features, input_graph_features,
    input_adjacency_matrix, input_hidden_node_features, input_hidden_edge_features,
), transformer_node_features_all_layers, transformer_edge_embedding = next(train_dataloader)

(4, 16, 16, 192)

In [150]:
def model_fn(node_fts, edge_fts, graph_fts, adj_mat, hidden, e_hidden):
    """Build the ``rt`` processor and run it on one batch (for ``hk.transform``).

    The hyper-parameters are fixed here; they must match the checkpoint that is
    loaded later or the parameter trees will not line up.
    """
    rt_module = RT(
        name="rt",
        # depth / width
        nb_layers=3,
        nb_heads=12,
        vec_size=192,
        node_hid_size=32,
        edge_hid_size_1=16,
        edge_hid_size_2=8,
        # behaviour switches
        graph_vec="att",
        disable_edge_updates=True,
        # embedding dumps disabled (empty strings)
        save_emb_sub_dir="",
        save_embeddings="",
    )
    return rt_module(
        node_fts, edge_fts, graph_fts, adj_mat, hidden, e_hidden=e_hidden
    )


# Pure init/apply pair; because of `without_apply_rng`, `model.apply` takes the
# params followed directly by the model inputs (no PRNG key).
model = hk.without_apply_rng(hk.transform(model_fn))

# Trace `model_fn` on the first batch to create a fresh parameter tree with a
# fixed seed. Arguments follow `model_fn`'s positional order.
parameters = model.init(
    jax.random.PRNGKey(42),
    input_node_features,
    input_edge_features,
    input_graph_features,
    input_adjacency_matrix,
    input_hidden_node_features,
    input_hidden_edge_features,
)


In [151]:
# List the module paths of the freshly initialised parameter tree.
parameters.keys()

dict_keys(['rt/linear', 'rt/linear_1', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_3', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_4', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_5', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_1', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_2', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_6', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_7', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_8', 'rt/rt_layer0/rt_layer0/linear', 'rt/rt_layer0/rt_layer0/layer_norm', 'rt/rt_layer0/rt_layer0/linear_1', 'rt/rt_layer0/rt_layer0/linear_2', 'rt/rt_layer0/rt_layer0/layer_norm_1', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear_3', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear_4', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear_5', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear_1', 'rt/rt_layer1/rt_layer1/rt_att

In [152]:
# Load the trained RT processor checkpoint.
# NOTE(review): the .pkl extension suggests Checkpointer unpickles the file;
# only load checkpoints from trusted sources.
ckpt = Checkpointer(str(MODEL_DIR / "rt_jarvis_march.pkl"))
loaded_parameters = ckpt.load()

In [153]:
def _filter_processor(params: hk.Params, module: str = "rt") -> hk.Params:
    """Keep only the parameters belonging to the ``module`` sub-network.

    A Haiku module is kept when ``module`` is one of the ``/``-separated
    components of its path (e.g. ``".../rt/linear"``). Matching whole
    components avoids the false positives of the substring test
    ``"rt" in module_name``, which also accepts unrelated names that merely
    contain the letters "rt" (anything with "sort" or "start" in it).

    Args:
        params: Haiku parameter mapping (module name -> {param name -> array}).
        module: Path component identifying the sub-network. Defaults to "rt".

    Returns:
        A new ``hk.Params`` containing only the matching modules.
    """
    return hk.data_structures.filter(
        lambda module_name, name, value: module in module_name.split("/"),
        params,
    )

In [154]:
# Replace the checkpoint's parameter tree with just the processor's modules.
loaded_parameters.update(params=_filter_processor(loaded_parameters["params"]))

In [155]:
# Top-level entries of the checkpoint: the saved optimizer state and the parameters.
loaded_parameters.keys()

dict_keys(['opt_state', 'params'])

In [156]:
# Drop the 4-character outer prefix from each module path so the checkpoint's
# keys line up with this notebook's freshly initialised "rt/..." tree.
# Keys that already start with "rt/" are left as-is, which makes the cell safe
# to re-run (slicing unconditionally removed 4 more characters on every run).
loaded_parameters["params"] = {
    (k if k.startswith("rt/") else k[4:]): v
    for k, v in loaded_parameters["params"].items()
}

In [157]:
# Module paths after prefix stripping; these should match `parameters.keys()`.
loaded_parameters["params"].keys()

dict_keys(['rt/linear', 'rt/linear_1', 'rt/rt_layer0/rt_layer0/layer_norm', 'rt/rt_layer0/rt_layer0/layer_norm_1', 'rt/rt_layer0/rt_layer0/linear', 'rt/rt_layer0/rt_layer0/linear_1', 'rt/rt_layer0/rt_layer0/linear_2', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_1', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_2', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_3', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_4', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_5', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_6', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_7', 'rt/rt_layer0/rt_layer0/rt_attention_layer/linear_8', 'rt/rt_layer1/rt_layer1/layer_norm', 'rt/rt_layer1/rt_layer1/layer_norm_1', 'rt/rt_layer1/rt_layer1/linear', 'rt/rt_layer1/rt_layer1/linear_1', 'rt/rt_layer1/rt_layer1/linear_2', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear', 'rt/rt_layer1/rt_layer1/rt_attention_layer/linear_1', 'rt/rt_layer1

In [158]:
def _param_shapes(params) -> dict:
    """Map ``"module/param"`` to the array shape for a two-level Haiku params dict."""
    return {
        f"{module_name}/{param_name}": jnp.shape(value)
        for module_name, module_params in params.items()
        for param_name, value in module_params.items()
    }


# Matching module keys is not enough: a hyper-parameter mismatch (e.g. vec_size)
# keeps the keys but changes the shapes. Compare names and shapes together.
_param_shapes(parameters) == _param_shapes(loaded_parameters["params"])

True

In [214]:
import optax

# Advance to the next training batch; every run of this cell consumes one batch.
(
    (
        input_node_features,
        input_edge_features,
        input_graph_features,
        input_adjacency_matrix,
        input_hidden_node_features,
        input_hidden_edge_features,
    ),
    transformer_node_features_all_layers,
    transformer_edge_embedding,
) = next(train_dataloader)

optimizer = optax.adam(0.001)
# Initialise Adam over the parameter tree only: the checkpoint dict also holds
# an "opt_state" entry, which is not a set of trainable parameters.
optimizer_state = optimizer.init(loaded_parameters["params"])

# `model` was wrapped with `hk.without_apply_rng`, so `apply` takes the params
# followed directly by the model inputs. Passing a PRNG key here would shift
# every argument by one position and overflow `model_fn`'s signature.
out_transformer_node_features, out_transformer_edge_embedding, _ = model.apply(
    loaded_parameters["params"],
    input_node_features,
    input_edge_features,
    input_graph_features,
    input_adjacency_matrix,
    input_hidden_node_features,
    input_hidden_edge_features,
)


    
    

In [215]:
def _matches(actual, expected) -> bool:
    """True when both arrays have the same shape and numerically close values."""
    return jnp.shape(actual) == jnp.shape(expected) and bool(
        jnp.allclose(actual, expected)
    )


# On a fresh kernel `prev_node_features` does not exist yet; treat that first
# run as "batch changed" instead of raising NameError.
_prev_features = globals().get("prev_node_features")
print(_prev_features is None or not jnp.array_equal(input_node_features, _prev_features))

# Recomputed outputs vs. the stored reference outputs. A tolerance is used
# instead of bitwise equality because float results may differ slightly across
# devices / compilation settings.
print(_matches(out_transformer_node_features, transformer_node_features_all_layers[-1]))
print(_matches(out_transformer_edge_embedding, transformer_edge_embedding))
prev_node_features = input_node_features

True
True
True


In [204]:
# Peek at the current batch's input node features (NumPy's repr elides the
# middle of large arrays with "...").
input_node_features

array([[[-2.7629688 ,  0.3663008 ,  1.3824675 , ..., -3.969361  ,
         -2.6140122 , -1.1244776 ],
        [ 1.1017193 , -0.45300224,  0.06967385, ..., -0.33809122,
         -1.7713001 , -0.27791956],
        [-0.85767794,  0.89120144, -0.8950694 , ..., -0.30672598,
         -1.1762357 ,  1.192628  ],
        ...,
        [ 0.0547842 ,  0.0309753 ,  0.95297354, ..., -0.987455  ,
          0.9850799 , -0.6780912 ],
        [-0.7856734 ,  0.23537533,  1.1660761 , ..., -0.66634315,
          3.2128007 , -0.82988054],
        [-0.04846592, -0.6029063 ,  2.167494  , ..., -0.4702959 ,
          4.8446207 , -2.1364338 ]],

       [[ 0.26297855, -2.515852  ,  4.258234  , ..., -3.3750892 ,
          0.773173  , -5.109954  ],
        [-2.6042683 ,  2.3243716 , -2.2892258 , ..., -0.354484  ,
         -2.0560188 ,  3.0930686 ],
        [-1.854869  ,  1.1578159 , -0.78675157, ...,  0.14333628,
          1.2153314 ,  1.1349628 ],
        ...,
        [-0.9993882 ,  1.4985843 , -0.76536715, ..., -