## Install requirements
Uncomment the following line to install the requirements.

In [None]:
# !pip install -U message-passing-nn

## Clone the repository to get the data folders

In [None]:
# !git clone https://github.com/kovanostra/message-passing-nn/

## Imports

In [None]:
import torch
import datetime
from message_passing_nn.model.model_trainer import ModelTrainer
from message_passing_nn.graph.graph_rnn_encoder import GraphRNNEncoder
from message_passing_nn.graph.graph_gru_encoder import GraphGRUEncoder
from message_passing_nn.data.data_preprocessor import DataPreprocessor
from message_passing_nn.repository.file_system_repository import FileSystemRepository

## Set up the variables

In [None]:
# Device to train on. torch only understands device types such as "cuda" and
# "cpu" ("gpu" is rejected by torch.device), so use CUDA when it is available
# and fall back to the CPU otherwise. "cuda" works for the GraphRNNEncoder, but
# it is currently advised to use "cpu" for the GraphGRUEncoder.
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 10
model = 'RNN'
loss_function = 'MSE'
optimizer = 'SGD'
batch_size = 5
# Some of the adjacency matrices in our dataset are too big; this variable
# controls the maximum size of the matrices to load. To load the whole dataset
# set this value to -1.
maximum_number_of_nodes = 250
maximum_number_of_features = 10  # Similarly for the number of features
validation_split = 0.2  # fraction of the dataset used for validation
test_split = 0.1  # fraction of the dataset used for testing
time_steps = 1  # The time steps of the message passing algorithm
# The validation loss is computed on epochs that are multiples of this value.
validation_period = 20

configuration_dictionary = {'time_steps': time_steps,
                            'model': model,
                            'loss_function': loss_function,
                            'optimizer': optimizer}

## Preprocess the dataset
We load the protein-folding dataset, in which each graph consists of three pickle files:
  1. The features of each node (a tensor of size torch.Size([M, N]))
  2. The adjacency matrix (a tensor of size torch.Size([M, M]))
  3. The labels to predict (a tensor of size torch.Size([L]))

where M is the number of graph nodes, N the number of features per node, and L the number of values to predict.

The dataset contains features and labels from 31 proteins from (https://www.rcsb.org). We apply a limit to the size of the proteins (so as not to crash the runtime), so we end up with 17 proteins, which we equalize in size and split into training, validation and test datasets.

In [None]:
# Read every graph of the protein-folding dataset from the data folder of the
# cloned repository. Each item of raw_dataset can be unpacked as
# (node_features, adjacency_matrix, labels), as in the example cell below.
data_directory = 'message-passing-nn/data/'
dataset_name = 'protein-folding'
file_system_repository = FileSystemRepository(data_directory, dataset_name)
raw_dataset = file_system_repository.get_all_data()

Please uncomment the following block to see examples of the data used as input to the model.

In [None]:
# node_features_example, adjacency_matrix_example, labels_example = raw_dataset[0]
# print(node_features_example.size(), adjacency_matrix_example.size(), labels_example.size())

### Next, we equalize the tensor sizes and split the data into training, validation and test sets

In [None]:
data_preprocessor = DataPreprocessor()

# Bring all graphs to common tensor dimensions, bounded by the node and
# feature limits chosen in the configuration cell.
equalized_dataset = data_preprocessor.equalize_dataset_dimensions(
    raw_dataset, maximum_number_of_nodes, maximum_number_of_features
)

# Split into training / validation / test sets using the configured fractions
# (batch_size is passed through too — presumably for mini-batching; confirm).
training_data, validation_data, test_data = data_preprocessor.train_validation_test_split(
    equalized_dataset, batch_size, validation_split, test_split
)

# Dimensions of the equalized data; passed to the ModelTrainer below.
data_dimensions = data_preprocessor.extract_data_dimensions(equalized_dataset)

## Instantiate the model and the trainer

The ModelTrainer is responsible for the instantiation, training and evaluation of the model. It also controls whether mini-batch normalization is applied to the node features and labels. The ModelTrainer can use either the GraphRNNEncoder or the GraphGRUEncoder.

In [None]:
# Rebuild the configuration from the current values of the settings, so any
# edits made after running the configuration cell are picked up here.
configuration_dictionary = dict(
    time_steps=time_steps,
    model=model,
    loss_function=loss_function,
    optimizer=optimizer,
)
model_trainer = ModelTrainer(data_preprocessor, device)
# Set up the model described by the configuration for the extracted data
# dimensions.
model_trainer.instantiate_attributes(data_dimensions, configuration_dictionary)

## Train the model
This block will train the model and output the training, validation and test losses along with the time. Our use case contains fully connected graphs and therefore the time to train is significantly longer than for sparsely connected graphs.

In [None]:
# Train for `epochs` epochs, printing the training loss and elapsed time every
# epoch. The validation loss is printed every `validation_period` epochs and
# also after the final epoch, so the trained model is always validated even
# when validation_period exceeds epochs. Finally, print the test-set loss.
start_time = datetime.datetime.now()
for epoch in range(epochs):
    training_loss = model_trainer.do_train(training_data, epoch)
    print("Epoch", epoch, "Training loss:", training_loss,
          "Elapsed:", datetime.datetime.now() - start_time)
    if epoch % validation_period == 0 or epoch == epochs - 1:
        validation_loss = model_trainer.do_evaluate(validation_data, epoch)
        print("Epoch", epoch, "Validation loss:", validation_loss)
test_loss = model_trainer.do_evaluate(test_data)
# Report the test loss itself (the original printed validation_loss here).
print("Test loss:", test_loss,
      "Total time:", datetime.datetime.now() - start_time)