In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
import torch
from itertools import count

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
### If notebook not at root of knossos:
import os

# Path to the knossos repo. Set the KNOSSOS_ROOT environment variable to point at your
# own checkout; the fallback is the location that was hard-coded in this notebook.
KNOSSOS_ROOT = os.environ.get("KNOSSOS_ROOT", "/home/t-salewi/knossos")
os.chdir(KNOSSOS_ROOT)

In [None]:
from rlo import factory
from rlo.expression import Expression
from rlo import expr_sets
from train_on_dataset import get_symtab_free_var_types
from rlo.dataset import StateValueDataset
from rlo.expression_util import make_toplevel
from rlo import sparser
from rlo.pipelines import graph_pipeline
from rlo.pipelines.training_pipeline import _spec_to_input
from rlo.torch_dataset import MultiEdgeTypeBatch


##### Manually specify configs for the torch and TF regressors

In [None]:
# Config obtained by running train_on_dataset.py and getting the config.json from the run
# This dict is passed to factory.regressor_from_config to build the TensorFlow model, and a
# modified copy of it (with "tensorflow" set to False) is used to build the torch model.
tf_config = {
    "eager": True,
    "scenario": "binding_simplify_astar",
    "run_id": "binding_simplify_astar_2021_05_27_09_52_35_17255",
    "gitlog": "41043ebc@t-salewi/pytorch-dummy+local_changes",
    "output_dir": "outputs",
    "force_gpu": False,
    "gpu_memory_fraction": None,
    "save_all_models": False,
    "num_parallel": 1,
    "dataset_path": "datasets/value_dataset.json",
    "dist_plots": None,
    "cost_bins": 10,
    "value_bins": 10,
    "node_bins": 10,
    "exprs_per_generation": 0,
    # Also read directly by the notebook when it builds the TF SparsePipeline.
    "use_subtree_match_edges": True,
    "num_propagations": 10,
    "nonlinear_messages": False,
    "aggregation_over_edge_types": "sum",
    "decoder_readout": "sum",
    "message_from_sender_receiver": False,
    "one_hot_embedding": False,
    "hidden_dim": 200,
    "output_hidden_dim": 200,
    "gamma": 0.1,
    "max_num_episodes_train": 4096,
    "max_num_episodes_eval": 100,
    "num_positive_examples": 10,
    "simulation_depth": 11,
    "maxing": "accumulator",
    "min_epochs": 10,
    "max_epochs": 30,
    "num_repetitions": 8,
    # Overwritten with 0.999999999999 in a later cell before the models are built.
    "graph_state_keep_prob": 0.5,
    "output_keep_prob": 0.5,
    "cost_normalization": None,
    "patience_epochs": 4,
    "num_generations": None,
    "total_train_time": 3600,
    "num_episode_clusters": 5,
    "template_path": None,
    "test_on_defs": None,
    "train_on_defs": None,
    "seed_all_reps": None,
    "loss": "pinball=0.9",
    "lr": 0.0001,
    "grad_clip_value": 0,
    "split": 0.9,
    "value_bin_splits": None,
    "time_bin_splits": None,
    "episode_bin_splits": None,
    "extra_plots": [],
    "v2": False,
    "verbose": True,
    "dataset_refiners": ["best_across_generations_refiner"],
    "train_search": "astar",
    "eval_search": "astar",
    "cost_per_step": None,
    "max_gnn_train": 1000,
    "max_gnn_eval": 1000,
    "search_batch_size": 16,
    "hybrid_merge_handling": "STOP",
    "hybrid_prob_rollout": 1.0,
    # Left disabled: presumably because a bare `inf` (as it appears in the JSON) is not a
    # name defined in this notebook -- TODO confirm whether this key is needed.
    # "hybrid_alpha": inf,
    "alpha_test": 5.0,
    "init_alpha": 1.0,
    "alpha_scaling_factor": 1.1,
    "alpha_scaling_factor_fail": 1.0,
    "sparse_gnn": True,
    # Flipped to False in the torch copy of this config.
    "tensorflow": True,
    "num_gnn_blocks": 1,
    "stacked_gnn_double_hidden": False,
    "max_nodes_per_batch": 10000,
    "cumsum": None,
    "two_value_func": None,
    "two_value_func_var_frac_train": None,
    "rules": "binding_simplify_rules",
    "test_exprs": "binding_simplify_expressions",
    # Name of the expression set used to derive symtab / free-variable types when loading data.
    "train_exprs": "binding_simplify_expressions",
    "oracle": True,
    "extra_scenario_params": "+decoder_readout:sum+max_epochs:30+loss:pinball=0.9+tensorflow:True",
    "result_save_path": "outputs/Run_binding_simplify_astar_2021_05_27_09_52_35_17255/0",
    "repetition": 0,
}



In [None]:
import copy

# Keep-probability of (almost) 1 -- presumably so that dropout on the graph state is
# effectively switched off and both models behave deterministically; TODO confirm.
tf_config["graph_state_keep_prob"] = 0.999999999999

# The torch config is an independent copy of the TF one, minus the 'eager' key and with
# the 'tensorflow' switch turned off.
torch_config = copy.deepcopy(tf_config)
torch_config.pop('eager')
torch_config.update(tensorflow=False)

# The deep copy must not have affected the original config.
assert tf_config['tensorflow'] == True

### Make the tensorflow and torch models

In [None]:
# Construct the two models that should be the same
# Both come from the same factory function; the two configs differ only in the
# 'tensorflow' flag and the 'eager' key (removed from the torch config).
tf_model = factory.regressor_from_config(tf_config)
torch_model = factory.regressor_from_config(torch_config)

##### Load the data into a framework-agnostic dataset

In [None]:
# Load the raw data points from the dataset file named in the config
# (the same "datasets/value_dataset.json" that used to be spelled out here).
with open(torch_config["dataset_path"]) as f:
    data_points = json.load(f)['data_points']

# Symbol table and free-variable types come from the training expression set.
symtab, free_var_types = get_symtab_free_var_types(
    expr_sets.get_expression_set(torch_config["train_exprs"])
)

# Each data point is a triple whose middle element is an expression string: parse it and
# wrap it with make_toplevel, keeping the first and last elements untouched.
# `data_points` (the raw JSON list) is kept separate from `dataset` (the StateValueDataset)
# so that one name never refers to two different kinds of object.
dataset = StateValueDataset.build_from_triples(
    (
        t,
        make_toplevel(
            sparser.parse_expr(expr_str),
            symtab=symtab,
            free_var_types=free_var_types,
        ),
        v,
    )
    for t, expr_str, v in data_points
)

raw_examples = dataset.get_examples()

In [None]:
# Display the first example; its element [0] is what both data converters are fed below.
raw_examples[0]

##### Load the data-loading utilities for torch and tf

In [None]:
# Torch data converter
# Built from the torch config, then applied to element [0] of the first raw example.
data_converter = factory.data_converter_from_config(torch_config)
torch_graph = data_converter.prepare_exprtab(raw_examples[0][0])
print(torch_graph)

In [None]:
# Tensorflow data converter
# Uses the same subtree-match-edges setting as the model config, and is applied to the
# same input (element [0] of the first raw example) as the torch converter.
pipeline = graph_pipeline.SparsePipeline(use_subtree_match_edges=tf_config['use_subtree_match_edges'])
np_graph = pipeline.prepare_example(raw_examples[0][0])

### Check torch inputs and tf inputs are the same

In [None]:
# There are 9 edge types
# (this length check also guards the zip below, which would otherwise silently stop at
# the shorter of the two lists)
assert len(np_graph.edge_indices) == len(torch_graph.edge_indices) == 9
# Torch and numpy edges should be the same, just transposed.
for np_edges, torch_edges in zip(np_graph.edge_indices, torch_graph.edge_indices):
    np.testing.assert_equal(np_edges.T, torch_edges.numpy())


## Summarise tf_model and its layers

In [None]:

# Turn every entry of pipeline.batched_spec into a model input and build the Keras model with them.
inputs = tf.nest.map_structure(_spec_to_input, pipeline.batched_spec)
# print(inputs.keys())
tf_model.keras_model.build_and_compile(inputs=inputs)
# Manually mark the model as built so that summary() can be called on it.
tf_model.keras_model.built = True  # Otherwise we get ValueError. build_and_compile method says something about why it calls 'call' not 'build'
# tf_model.keras_model.summary()

def print_details(model):
    """Print `model.summary()` and then do the same, depth first, for its sub-layers.

    If calling `summary()` on an object raises, that object is skipped together with
    everything below it. Objects without a `layers` attribute are treated as leaves.
    """
    try:
        model.summary()
    except Exception:
        # No usable summary for this object: do not descend into it either.
        return
    for sub_layer in getattr(model, 'layers', ()):
        print_details(sub_layer)
            
# Walk the whole Keras model, printing a summary for it and for each nested layer.
print_details(tf_model.keras_model)        
# tf_model.keras_model.layers[0].summary()

## Summarise torch_model

In [None]:
# Print the torch module structure (the counterpart of the Keras summaries above).
print(torch_model.model)


In [None]:
def num_parameters(m):
    """Return the total number of scalar entries across all of `m.parameters()`."""
    sizes = [np.prod(param.shape) for param in m.parameters()]
    return sum(sizes)

# Parameter counts for the whole torch model and for its encoder and regressor parts.
print('Whole torch model', num_parameters(torch_model.model))
print('Torch encoder', num_parameters(torch_model.model.encoder))   # We have an extra 600 params in encoder
print('Torch regressor', num_parameters(torch_model.model.regressor))

### Print-out and compare shapes of parameter tensors for each layer

In [None]:
def summarise_tf_weights_shapes(weights):
    """Print one line per weight: its name padded to 70 columns, a tab, then its shape."""
    rows = ("{:<70}\t{}".format(w.name, w.shape) for w in weights)
    for row in rows:
        print(row)
        
def summarise_torch_weights_shapes(model: torch.nn.Module):
    """Print one line per named parameter of `model`: name padded to 70 columns, tab, shape.

    Args:
        model: the torch module whose parameters are listed. The parameters are read from
            this argument, not from the notebook-global `torch_model`.
    """
    for name, param in model.named_parameters():
        print(f"{name:<70}\t{param.shape}")

In [None]:
# List parameter names and shapes of both models one after the other for a visual comparison.
print("-" * 10 + " Tensorflow " + "-" * 10)
summarise_tf_weights_shapes(tf_model.keras_model.weights)
print("-" * 10 + " Torch " + "-" * 10)
summarise_torch_weights_shapes(torch_model.model)

##### Helper function for getting keras parameters by name:

In [None]:
def get_param_by_name(keras_model, param_name):
    """Return, as a numpy array, the single weight of `keras_model` named `param_name`.

    Raises:
        ValueError: if no weight, or more than one weight, has exactly that name.
    """
    matches = []
    for weight in keras_model.weights:
        if weight.name == param_name:
            matches.append(weight.numpy())
    if not matches:
        raise ValueError(f"No such parameter {param_name} in model.")
    if len(matches) > 1:
        raise ValueError(f"There are multiple parameters matching name {param_name}")
    return matches[0]

#### Make plots comparing parameter distributions in each layer

###### Embedding layer

In [None]:
# Compare embeddings
# Overlay histograms of the (flattened) embedding weights of both freshly built models.
embedding_weights_tf = get_param_by_name(tf_model.keras_model, "sparse_gnn_encoder/embedding/embeddings:0")
embedding_weights_torch = torch_model.model.encoder.embedding.weight.detach().numpy()
plt.hist(embedding_weights_tf.ravel(), alpha=0.5, label="tensorflow")
plt.hist(embedding_weights_torch.ravel(), alpha=0.5, label="Torch")
plt.legend()

In [None]:
# embedding = torch.nn.Embedding(num_embeddings=28, embedding_dim=200)
# torch.nn.init.xavier_uniform_(embedding.weight)
# plt.hist(embedding.weight.detach().numpy().ravel(), bins=40);

###### Message functions

In [None]:
# Number of edge types, read from the first (index 0) torch GNN block.
num_edge_types = torch_model.model.encoder.gnn.gnn_blocks[0].num_edge_types
print(f'There are {num_edge_types} edge types')

In [None]:
# Compare message_functions: one row of histograms per edge type
# (left column: weights, right column: biases).
# squeeze=False keeps `axes` two-dimensional even if there is only a single edge type,
# so that axes[i, 0] / axes[i, 1] below always work.
fig, axes = plt.subplots(ncols=2, nrows=num_edge_types, figsize=(12,3.5 * num_edge_types), squeeze=False)

# Look the torch parameters up by name once instead of re-scanning them for every tensor.
torch_params = dict(torch_model.model.named_parameters())
# TensorFlow stores a single bias tensor for all edge types; it is indexed per edge type below.
all_kernel_bias_tf = get_param_by_name(tf_model.keras_model, "sparse_gnn_encoder/sparse_gnn/bias:0")

for i in range(num_edge_types):
    # get_param_by_name raises a descriptive ValueError if the weight name is missing.
    kernel_weights_tf = get_param_by_name(tf_model.keras_model, f"sparse_gnn_encoder/sparse_gnn/kernel{i}:0")
    kernel_bias_tf = all_kernel_bias_tf[i]

    kernel_weights_torch = torch_params[f"encoder.gnn.gnn_blocks.0.message_functions.{i}.weight"].detach().numpy()
    kernel_bias_torch = torch_params[f"encoder.gnn.gnn_blocks.0.message_functions.{i}.bias"].detach().numpy()

    axes[i, 0].hist(kernel_weights_tf.ravel(), alpha=0.5, label="tensorflow", density=True)
    axes[i, 0].hist(kernel_weights_torch.ravel(), alpha=0.5, label="Torch", density=True)
    axes[i, 0].set_title("Message function weights")
    axes[i, 0].legend()
    axes[i, 1].hist(kernel_bias_tf.ravel(), alpha=0.5, label="tensorflow", density=True)
    axes[i, 1].hist(kernel_bias_torch.ravel(), alpha=0.5, label="Torch", density=True)
    # The x-range follows the torch bias values.
    axes[i, 1].set_xlim(kernel_bias_torch.min(), kernel_bias_torch.max())
    axes[i, 1].set_title("Message function bias")
    axes[i, 1].legend()

In [None]:
# Is kernel bias in tensorflow == 0 for all?
# Takes the whole bias tensor (all edge types at once) and prints True only if every entry is exactly 0.
kernel_bias_tf = next(x.numpy() for x in tf_model.keras_model.weights if x.name == f"sparse_gnn_encoder/sparse_gnn/bias:0")
print(np.all(kernel_bias_tf.ravel() == 0.))

###### Print histograms for all parameter tensors (TORCH)

In [None]:

# One histogram per torch parameter tensor, laid out side by side in a single row.
num_param_tensors = len(list(torch_model.model.parameters()))
color = "black"
fig, axes = plt.subplots(ncols=num_param_tensors, nrows=1, figsize=(7 * num_param_tensors,5))

for i, (name, param) in enumerate(torch_model.model.named_parameters()):
    axes[i].hist(param.detach().numpy().ravel(), color=color, density=True)
    axes[i].set_title(name)

###### Print histograms for all parameter tensors (TENSORFLOW)

In [None]:
# One histogram per TensorFlow weight tensor, laid out side by side in a single row
# (same layout as the torch histograms, in a different colour).
num_param_tensors = len(list(tf_model.keras_model.weights))
color = "orange"
fig, axes = plt.subplots(ncols=num_param_tensors, nrows=1, figsize=(7 * num_param_tensors,5))

for i, weight in enumerate(tf_model.keras_model.weights):
    axes[i].hist(weight.numpy().ravel(), color=color, density=True)
    axes[i].set_title(weight.name)

### Specify matching sets of parameters:

In [None]:
# Helper to print out parameter name for copy pasting later in the code

def summarise_tf_weights_names(weights):
    """Print the name of every weight, one per line, padded to 70 columns."""
    padded_names = ("{:<70}".format(w.name) for w in weights)
    for padded in padded_names:
        print(padded)
        
def summarise_torch_weights_names(model: torch.nn.Module):
    """Print the name of every named parameter of `model`, one per line, padded to 70 columns.

    Args:
        model: the torch module whose parameter names are listed. The names are read from
            this argument, not from the notebook-global `torch_model`.
    """
    for name, _param in model.named_parameters():
        print(f"{name:<70}")
        
# TensorFlow names first, then a separator line, then the torch names.
summarise_tf_weights_names(tf_model.keras_model.weights)
print("-"*40)
summarise_torch_weights_names(torch_model.model)

In [None]:
name_mapping = {}  # map (torch_name -> (tf_name, extractor_fn))
from functools import partial
# Extractor functions: each takes the TF weight as a numpy array and returns the array
# to store under the torch parameter name.
# Identity: use the TF array as it is.
def do_nothing(x): return x
# Swap the axes of a 2-D kernel.
def transpose(x): return x.T
# Pick entry i along the first axis (used with partial to bind i).
def getitem(x,i):
    return x[i]

# Message functions: TF keeps one shared bias tensor that is indexed per edge type, and
# one kernel tensor per edge type.
# NOTE(review): the message kernels are copied without a transpose while the dense
# kernels below are transposed -- confirm this matches the torch layer's weight layout.
for i in range(num_edge_types):
    name_mapping[f"encoder.gnn.gnn_blocks.0.message_functions.{i}.bias"] = (f"sparse_gnn_encoder/sparse_gnn/bias:0", partial(getitem, i=i))
    name_mapping[f"encoder.gnn.gnn_blocks.0.message_functions.{i}.weight"] = (f"sparse_gnn_encoder/sparse_gnn/kernel{i}:0", do_nothing)
# Embedding table and the regressor's dense layers (kernels transposed, biases as they are).
name_mapping[f"encoder.embedding.weight"] = (f"sparse_gnn_encoder/embedding/embeddings:0", do_nothing)
name_mapping[f"regressor.mlp.1.weight"] = (f"gated_regression/out_layer/dense/kernel:0", transpose)
name_mapping[f"regressor.mlp.1.bias"] = (f"gated_regression/out_layer/dense/bias:0", do_nothing)
name_mapping[f"regressor.mlp.3.weight"] = (f"gated_regression/regression_transform/dense_2/kernel:0", transpose)

name_mapping[f"regressor.mlp.3.bias"] = (f"gated_regression/regression_transform/dense_2/bias:0", do_nothing)
name_mapping[f"regressor.gate.1.weight"] = (f"gated_regression/regression_gate/dense_1/kernel:0", transpose)
name_mapping[f"regressor.gate.1.bias"] = (f"gated_regression/regression_gate/dense_1/bias:0", do_nothing)
# RNN (ignore bias): only the input and recurrent kernels are mapped, no bias entries.
name_mapping["encoder.gnn.gnn_blocks.0.rnn.weight_ih"] = ("sparse_gnn_encoder/sparse_gnn/seeded_gru_cell/kernel:0", transpose)
name_mapping["encoder.gnn.gnn_blocks.0.rnn.weight_hh"] = ("sparse_gnn_encoder/sparse_gnn/seeded_gru_cell/recurrent_kernel:0", transpose)

In [None]:
def assign_tf_weights_to_torch_model(model: torch.nn.Module, keras_model):
    """Overwrite parameters of `model` with the matching weights of `keras_model`.

    Which torch entry receives which TensorFlow weight, and how the TensorFlow array is
    transformed first, is taken from the notebook-level `name_mapping` dict. Before
    copying, the function prints which torch state-dict keys are not covered by the
    mapping and which mapping keys are absent from the state dict.

    Args:
        model: torch module whose state dict is updated and loaded back.
        keras_model: Keras model providing the source weights (looked up by name).
    """
    state_dict = model.state_dict()
    torch_keys = set(state_dict.keys())
    mapped_keys = set(name_mapping.keys())
    print('torch keys not updated:', torch_keys - mapped_keys)
    print('keys added:', mapped_keys - torch_keys)
    for torch_name in name_mapping:
        tf_name, extract = name_mapping[torch_name]
        tf_value = get_param_by_name(keras_model, tf_name)
        state_dict[torch_name] = torch.FloatTensor(extract(tf_value))
    model.load_state_dict(state_dict)



##### Assign TF parameters to torch model

In [None]:
# Copy the TensorFlow weights into the torch model (modifies torch_model.model in place).
assign_tf_weights_to_torch_model(torch_model.model, tf_model.keras_model)

In [None]:
# # Verify that some parameters are indeed the same
# Commented out because it is slow

# ncols=3
# fig, axes = plt.subplots(ncols=ncols, nrows=len(name_mapping)//ncols + 1,figsize=(50, 100))
# state_dict = torch_model.model.state_dict()

# for n, (torch_name, (tf_name, fn)) in enumerate(name_mapping.items()):
#     ax = axes[n//ncols][n%ncols]
#     ax.plot(state_dict[torch_name], fn(get_param_by_name(tf_model.keras_model, tf_name)),'.')
#     ax.set_title(torch_name, fontsize=20)
    


### Compare output distributions

In [None]:
# Make a torch graph
torch_graph = data_converter.prepare_exprtab(raw_examples[0][0])

# Make a tf graph (same data, different format)
np_graph = pipeline.prepare_example(raw_examples[0][0])
tf_input = tf.nest.map_structure(tf.convert_to_tensor, np_graph._asdict())
# Rename / add entries -- presumably to match the input names the Keras model expects
# (TODO confirm): node_reps -> node_type, edge_indices -> adjacency, and row splits
# describing a batch that holds just this one graph.
tf_input['node_type'] = tf_input['node_reps']
tf_input['node_row_splits'] = [0, len(tf_input['node_reps'])]
tf_input['adjacency'] = tf_input['edge_indices']
del tf_input['node_reps']
del tf_input['edge_indices']

print('tf_input', tf_input)
print('torch_input', torch_graph.x, torch_graph.edge_indices)

In [None]:
# Forward passes through tf_model and torch_model with the same inputs

# TensorFlow side, run with training=False.
tf_output = tf_model.keras_model(tf_input, training=False)

print('tf model outputs', tf_output)
print('mean abs', tf.math.reduce_mean(tf.math.abs(tf_output)))
# Torch side: wrap the single graph in a batch and put the model in eval mode first.
torch_batch = MultiEdgeTypeBatch.from_data_list([torch_graph])

torch_model.model.eval()
torch_output = torch_model.model(torch_batch)
print('torch model outputs', torch_output, 'mean abs', torch.mean(torch.abs(torch_output)), ' std', torch_output.std())

# Scatter of torch outputs (x axis) against TF outputs (y axis).
plt.plot(torch_output.detach().numpy(), tf_output,'.')
plt.xlabel('torch')
plt.ylabel('tf')

def plot_cumdist(vals, label):
    """Plot the sorted, flattened values of `vals` against their cumulative fraction.

    Args:
        vals: object with a `.numpy()` method returning an array.
        label: legend label for the plotted line.
    """
    sorted_vals = np.sort(vals.numpy().ravel())
    total = len(sorted_vals)
    cum_fraction = np.arange(total) / total
    plt.plot(cum_fraction, sorted_vals, label=label)
    plt.xlabel('cum prob')
    plt.ylabel('value')

    
def compare_cumdist(torch_vals, tf_vals, title):
    """Open a new figure showing the cumulative distributions of torch and TF values.

    Args:
        torch_vals: torch tensor; it is detached before plotting.
        tf_vals: TensorFlow tensor (anything with a `.numpy()` method).
        title: title of the figure.
    """
    plt.figure()
    series = ((torch_vals.detach(), 'torch'), (tf_vals, 'tf'))
    for values, name in series:
        plot_cumdist(values, label=name)
    plt.title(title)
    plt.legend()

# Cumulative-distribution view of the same full-model outputs as the scatter plot above.
compare_cumdist(torch_output, tf_output, 'Full-model outputs')

In [None]:
# Run only the encoder of each model on the same graph and compare the outputs.
torch_gnn_output = torch_model.model.encoder(x=torch_batch.x, edge_indices=torch_batch.edge_indices)
tf_gnn_output = tf_model.keras_model.encoder(tf_input)
compare_cumdist(torch_gnn_output, tf_gnn_output, 'Encoder only')

In [None]:
# Compare just the Embedding at the input to the GNN
# NOTE: this rebinds tf_output / torch_output, which until here held the full-model outputs.

tf_output = tf_model.keras_model.encoder.initial_node_embedding(tf_input['node_type'], training=False)
torch_output = torch_model.model.encoder.embedding(torch_batch.x)
compare_cumdist(torch_output, tf_output, 'Embedding only')

In [None]:

# Feed the same random input through both regressors: 5 rows of 400 features, all
# assigned to a single graph (batch_assignment all zeros on the torch side, row splits
# [0, 5] on the TF side). No random seed is set, so the values differ between runs.
# NOTE(review): 400 is presumably 2 * hidden_dim -- confirm against the encoder output size.
random_regressor_input = np.random.rand(5, 400).astype(np.float32)
torch_input = torch.tensor(random_regressor_input, dtype=torch.float)
torch_output = torch_model.model.regressor(torch_input, 
                                           batch_assignment = torch.tensor([0,0,0,0,0]))
tf_output = tf_model.keras_model.regression(random_regressor_input, [0, 5])
compare_cumdist(torch_output, tf_output, 'Regressor only')
tf_regressor = tf_model.keras_model.regression
torch_regressor = torch_model.model.regressor


# Gate multiplied by transformed input, i.e. the per-row values before pooling.
tf_gated_outputs = tf_regressor.regression_gate(random_regressor_input) * tf_regressor.regression_transform(
            tf_regressor.out_layer(random_regressor_input)
        )
torch_gated_outputs = torch_regressor.gate(torch_input) * torch_regressor.mlp(torch_input)

compare_cumdist(torch_gated_outputs,tf_gated_outputs,  'pre-pooling')

# Gate outputs on their own (the *_gated_outputs names are reused here for the gates only).
tf_gated_outputs = tf_regressor.regression_gate(random_regressor_input) 
torch_gated_outputs = torch_regressor.gate(torch_input) 
compare_cumdist(torch_gated_outputs, tf_gated_outputs, 'gates')


# Transform branch on its own: torch `mlp` against TF out_layer followed by regression_transform.
torch_x = torch_regressor.mlp(torch_input)
tf_x = tf_regressor.regression_transform(
            tf_regressor.out_layer(random_regressor_input)
        )
compare_cumdist(torch_x, tf_x, 'mlp')

### GRU cells

In [None]:
%matplotlib inline
from rlo.model.gru import GRU

# The GRU cell used inside the TensorFlow GNN.
tf_gru = tf_model.keras_model.encoder.gnn.gru_cell
print('TF gru cell has type', type(tf_gru), 'dropout', tf_gru.dropout)


# Examine properties of the tf_gru
# These determine what happens in 'call' method of this class
print('TF GRU settings:')
for k in ['use_bias', 'reset_after', 'implementation', 'dropout', 'activation', 'recurrent_activation']:
    # None of the names listed above starts with '_', so this filter never skips anything.
    if not k.startswith('_'):
        try:
            print(f"{k:<20}", getattr(tf_gru, k))
        except Exception:
            print('cannot print', k)

            
# The recurrent cell of the (single) torch GNN block.
torch_gru = torch_model.model.encoder.gnn.gnn_blocks[0].rnn
print('Torch GRU cell has type', type(torch_gru), 'dropout', torch_gru._dropout)

# Define some random inputs
# 5 rows of 200 values; row r is offset by r so that the rows cover different value ranges.
gru_input_np = np.random.rand(5,200).astype(np.float32) + np.arange(5)[:, None].astype(np.float32)
gru_hidden_np = np.random.rand(5,200).astype(np.float32) + np.arange(5)[:, None].astype(np.float32)

# Compare output on the random inputs
tf_output, _ = tf_gru(inputs=gru_input_np, states=gru_hidden_np)
torch_output = torch_gru(input=torch.tensor(gru_input_np), hidden=torch.tensor(gru_hidden_np))
compare_cumdist(torch_output, tf_output, 'gru cell')

# New GRU is a torch copy of tensorflow GRU behaviour
# It is built from the existing torch cell's input size and weight tensors.
new_gru = GRU(input_size = torch_gru.input_size, weight_ih = torch_gru.weight_ih, weight_hh = torch_gru.weight_hh)
new_torch_output = new_gru(inputs=torch.tensor(gru_input_np), states=torch.tensor(gru_hidden_np))

# ... and indeed it gives same outputs, unlike the old one
compare_cumdist(new_torch_output, tf_output, 'new torch gru cell')
# Scatter of both torch cells' outputs (x axis) against the TF cell's outputs (y axis).
plt.figure()
plt.plot(torch_output.detach().numpy().ravel(), tf_output.numpy().ravel(),'.', label='old torch')
plt.plot(new_torch_output.detach().numpy().ravel(), tf_output.numpy().ravel(), '.', label='new torch')
plt.legend()
plt.xlabel('torch'); plt.ylabel('tf')
plt.title('TF output vs torch output')




            

### Replace torch GRU with something just like the tensorflow GRU in the GNN encoder

In [None]:
assert len(torch_model.model.encoder.gnn.gnn_blocks) ==1
# Swap only the recurrent cell of the GNN block. `torch_gru` was read from
# gnn_blocks[0].rnn, so that is the attribute to replace; assigning to gnn_blocks[0]
# itself would throw away the whole block, including its message functions.
# NOTE(review): new_gru is called above with `inputs=` / `states=` while the old cell
# was called with `input=` / `hidden=` -- confirm how the GNN block invokes its rnn.
torch_model.model.encoder.gnn.gnn_blocks[0].rnn = new_gru
# Keep the encoder output computed with the old cell for reference.
old_torch_gnn_output = torch_gnn_output
torch_gnn_output = torch_model.model.encoder(x=torch_batch.x, edge_indices=torch_batch.edge_indices)
tf_gnn_output = tf_model.keras_model.encoder(tf_input)
compare_cumdist(torch_gnn_output, tf_gnn_output, 'Encoder only, with new GRU')

# Compare tf vs. torch gradients for GNN (encoder)

In [None]:

# Torch: gradient of the summed encoder output with respect to the embedding table.
torch_model.model.encoder.zero_grad()
torch_gnn_loss  = torch_gnn_output.sum()
torch_gnn_loss.backward()
torch_embed_gradients = torch_model.model.encoder.embedding.weight.grad

# TensorFlow: same loss, gradient with respect to the embedding layer's weights.
tf_embedding = tf_model.keras_model.encoder.initial_node_embedding
with tf.GradientTape() as tape:
    tf_gnn_output = tf_model.keras_model.encoder(tf_input)
    tf_gnn_loss = tf.math.reduce_sum(tf_gnn_output)

slices = tape.gradient(tf_gnn_loss, tf_embedding.weights)

# The gradient comes back in sparse form (row indices plus one value row per index).
# The same row index may occur more than once, and the dense gradient is the SUM of all
# value rows sharing an index, so accumulate with np.add.at instead of assigning
# (assignment would keep only the last occurrence). The dense shape is taken from the
# embedding table itself rather than hard-coded.
tf_grads = np.zeros(tf_embedding.weights[0].numpy().shape)
print(tf_grads.shape)
print(slices[0].values.shape)
np.add.at(tf_grads, slices[0].indices.numpy(), slices[0].values.numpy())


compare_cumdist(torch_embed_gradients, tf.convert_to_tensor(tf_grads), 'embedding gradients')

In [None]:
# NOTE: at this point torch_output is whatever the last cell that assigned it produced
# (the old torch GRU cell's output in a top-to-bottom run), not the full-model output.
torch_output