# Colab

In [None]:
# Colab-only setup: uncomment these lines when running on Google Colab.
# They mount Google Drive, change into the project folder, then remove any old
# copy of the tf2-gnn fork, re-clone its `sudoku` branch and pip-install it.
# from google.colab import drive
# drive.mount('/content/gdrive')

# %cd gdrive/My\ Drive/Colab\ Notebooks/gnn/sudoku-test

# !rm -r ./tf2-gnn
# !git clone --branch sudoku https://github.com/oliverdutton/tf2-gnn.git 

# !pip install ./tf2-gnn

# Imports

In [None]:
import os

import tensorflow as tf
import tf2_gnn as gnn
import numpy as np
from dpu_utils.utils import RichPath

In [None]:
# Optional development setup: uncomment to have IPython re-import edited
# modules (e.g. a local tf2-gnn checkout) without restarting the kernel.
# %load_ext autoreload
# %autoreload 2

# Dataset

### Create the dataset object that the data will be loaded into

In [None]:
# Start from the PSS dataset defaults and override the 'connectivity' setting,
# then display the resulting hyperparameters.
params = gnn.data.PSSDataset.get_default_hyperparameters()
params.update(connectivity=3)
params

In [None]:
# Allow up to one million nodes in a single batch.
params.update(max_nodes_per_batch=1_000_000)

In [None]:
# Instantiate the PSS dataset with the hyperparameters configured above;
# the data itself is loaded further down via `dataset.load_data`.
dataset = gnn.data.PSSDataset(params)

### Load the data

In [None]:
# Show the kernel's working directory: relative paths used later
# (e.g. './best_models/...' in log_info, save_dir='./') resolve from here.
!pwd

In [None]:
# Location of the PSS data. Set the PSS_DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific absolute path
# (on Colab, a relative path such as './data/' is typical).
DATA_DIR = os.environ.get('PSS_DATA_DIR', '/Users/personal/Documents/Sudoku/data/pss-data/')
path = RichPath.create(DATA_DIR)

In [None]:
# Load the data found under `path` using load_data's default folds (presumably
# all of them). To load only the training fold, use the commented variant.
# dataset.load_data(path, folds_to_load=[gnn.DataFold.TRAIN])
dataset.load_data(path)

In [None]:
# Build a TensorFlow dataset over the training fold without worker threads;
# it is used in the Miscellaneous section to pull a single batch for inspection.
tf_dataset = dataset.get_tensorflow_dataset(gnn.DataFold.TRAIN, use_worker_threads=False)

# Build model

In [None]:
# Model hyperparameters for the GNN-Edge-MLP variant: library defaults first,
# then the overrides below (applied in this order), then display the result.
params = gnn.models.NodeMulticlassTask.get_default_hyperparameters('gnn_edge_mlp')
params.update({
    # Message-passing stack: 2 layers of width 16, edge MLPs with 2 hidden layers.
    'gnn_hidden_dim': 16,
    'gnn_num_edge_MLP_hidden_layers': 2,
    'gnn_num_layers': 2,
    'gnn_share_weights_between_mlps': False,
    'gnn_message_activation_function': "gelu",
    # -1 / 1e5 presumably mean "never" for these periodic extras with only 2 layers.
    'gnn_dense_every_num_layers': -1,
    'gnn_residual_every_num_layers': 1e5,
    'gnn_global_exchange_every_num_layers': 1e5,
    'use_intermediate_gnn_results': False,
    'loss_at_every_layer': False,
})
params

In [None]:
# Instantiate the PSS task model from the hyperparameters above and the loaded dataset.
model = gnn.models.PSSTask(params,dataset)

In [None]:
# Create the model's variables from the batch feature shapes the dataset describes.
input_shapes = dataset.get_batch_tf_data_description().batch_features_shapes
model.build(input_shapes)

# Inspect model: print each variable's value shape next to its name, layer by layer.
for layer in model.layers:
    for variable, value in zip(layer.variables, layer.get_weights()):
        print(value.shape, variable.name)

In [None]:
# Layer-by-layer summary of the built model (layers, output shapes, parameter counts).
model.summary()

# Load saved model weights (optional)

In [None]:
# Optional: uncomment to restore previously saved weights into `model`
# (presumably the model must be built first, as the layer sweep below does).
# gnn.load_weights_verbosely(
#     save_file='/Users/personal/Documents/Sudoku/tf2-gnn/best_models/pss_1.hdf5',
#     model=model
# )

# Train Model

In [None]:
def log_info(x, log_path="./best_models/pss_1.txt"):
  """Print a log message and append it as one line to a log file.

  Args:
    x: the message string to log.
    log_path: file the message is appended to. Defaults to the original
      hardcoded location; missing parent directories are created.
  """
  print(x)
  log_dir = os.path.dirname(log_path)
  if log_dir:
    # open(..., "a") raises FileNotFoundError if the directory does not exist yet.
    os.makedirs(log_dir, exist_ok=True)
  # Context manager closes (and flushes) the file after every message instead of
  # leaving an anonymous handle for the garbage collector.
  with open(log_path, "a") as log_file:
    log_file.write(x + '\n')

In [None]:
gnn.train(
    model,
    dataset,
    # Schedule: at most 5 epochs. A patience of 48 exceeds that, so early
    # stopping presumably cannot trigger in this run.
    max_epochs=5,
    patience=48,
    run_id=0,
    # Outputs and logging (log_info also appends to a log file).
    save_dir='./',
    log_fun=log_info,
    quiet=False,
    aml_run=None,
    # No worker threads, matching the other dataset calls in this notebook.
    use_worker_threads=False,
)

In [None]:
# Evaluate the trained model, reporting results through log_info.
gnn.test(model, dataset, log_fun=log_info, use_worker_threads=False)

# Miscellaneous: inspect a single batch and re-evaluate with more message-passing layers

In [None]:
# Grab the first batch of the training dataset; the cells below use datum[0]
# as model inputs and datum[1] as labels.
datum = [*tf_dataset.take(1)][0]

In [None]:
# Forward pass on the batch's inputs in inference mode (training=False).
output = model.call(datum[0], training=False)

In [None]:
# Metrics for that single batch. The first argument must be the batch's features
# (datum[0], the same inputs fed to model.call above), not the static
# `input_shapes` shape description.
model.compute_task_metrics(datum[0], output, datum[1])

In [None]:
# Re-evaluate saved weights with increasingly many message-passing layers.
# NOTE(review): the model trained above is gnn.models.PSSTask; confirm this
# checkpoint was produced by a NodeMulticlassTask with compatible weights.
CHECKPOINT_PATH = '/Users/personal/Documents/PSS/tf2-gnn/best_models/colab_best.hdf5'
# Batch feature shapes do not depend on the layer count, so compute them once.
input_shapes = dataset.get_batch_tf_data_description().batch_features_shapes
for num_layers in [4, 8, 16, 32, 64, 128, 256]:
    print(f"\n{num_layers}")
    # Per-run copy, so the shared `params` is not left at gnn_num_layers=256.
    layer_params = {**params, 'gnn_num_layers': num_layers}
    # Same class (via gnn.models) whose defaults built `params`.
    model = gnn.models.NodeMulticlassTask(layer_params, dataset)
    model.build(input_shapes)
    gnn.load_weights_verbosely(save_file=CHECKPOINT_PATH, model=model)
    gnn.test(
        model,
        dataset,
        log_fun=log_info,
        use_worker_threads=False,
    )

# FiLM Model

In [None]:
# Model hyperparameters for the GNN-FiLM variant: library defaults first, then
# the keyword overrides below (applied in this order), then display the result.
params = gnn.models.NodeMulticlassTask.get_default_hyperparameters('GNN_FiLM')
params.update(
    # Message-passing stack: 4 layers of width 32.
    gnn_hidden_dim=32,
    gnn_num_edge_MLP_hidden_layers=0,
    gnn_num_layers=4,
    gnn_share_weights_between_mlps=False,
    gnn_message_activation_function="gelu",
    # Residual every 2 layers; 1e5 presumably disables the dense and
    # global-exchange extras for a 4-layer model.
    gnn_dense_every_num_layers=1e5,
    gnn_residual_every_num_layers=2,
    gnn_global_exchange_every_num_layers=1e5,
    use_intermediate_gnn_results=False,
    loss_at_every_layer=False,
)
params

In [None]:
# Instantiate the FiLM model. The class is referenced via gnn.models, as in the
# hyperparameter cell above that produced `params`.
model = gnn.models.NodeMulticlassTask(params, dataset)

In [None]:
# Create the FiLM model's variables from the dataset's batch feature shapes.
input_shapes = dataset.get_batch_tf_data_description().batch_features_shapes
model.build(input_shapes)

# Inspect model: for every layer, print each weight array's shape with its variable name.
for layer in model.layers:
    variables = layer.variables
    values = layer.get_weights()
    for variable, value in zip(variables, values):
        print(value.shape, variable.name)

In [None]:
# Layer-by-layer summary of the built FiLM model.
model.summary()