In [None]:
%load_ext autoreload
%autoreload 2

from katana import remote
from katana.remote import export_data, import_data

# Create the katana remote client; later cells use it to look up the
# database and graph.
my_client = remote.Client()
# Bare last expression so the notebook shows the server version as output.
my_client.server_version

In [None]:
from config import hyperparams

# Load the input configuration from the project's hyperparams module; it is
# passed to the split step and read again when saving the graph.
input_config = hyperparams.load_input_config()
# Bare last expression so the notebook shows the loaded configuration.
input_config

In [None]:
from timeit import default_timer

# Import the module that uses Dask to import drug data
from src import dask_ingestion

# NOTE: graph creation and ingestion are disabled below, so nothing in this
# cell defines `graph`. The notebook instead binds an existing graph to
# `my_graph` in the "CONNECT TO GRAPH" cell.
#  graph = client.create_graph(num_partitions=input_config.num_partitions)

#  if input_config.use_train_rdg:
#      print(f"Import pretrained graph from: {input_config.trained_rdg_path}")
#      import_data.rdg(graph, input_config.trained_rdg_path)
#  else:
#      print("Generate the graph with data from source")
#      dask_ingestion.generate_deepcdr_graph(graph, input_config)

# Presumably just a visual marker that the cell ran.
print("--")


In [None]:

# Notebook-level settings for locating the graph to work on.

# Partition count. Not referenced by any live code visible in this notebook.
NUM_PARTITIONS = 3

# Names of the database and graph that the connect cell looks up.
DB_NAME = "my_db"
GRAPH_NAME = "my_graph"

print("--")


In [None]:

# Connect to an existing graph by name.
#
# The lookup result is materialised into a list so that an empty result can
# be reported with a clear message instead of the opaque "not enough values
# to unpack" raised by starred unpacking. The exception type is still
# ValueError. When several graphs share the name, the first one returned is
# used, exactly as before.
matching_graphs = list(
    my_client.get_database(name=DB_NAME).find_graphs_by_name(GRAPH_NAME)
)
if not matching_graphs:
    raise ValueError(
        f"No graph named {GRAPH_NAME!r} was found in database {DB_NAME!r}."
    )
my_graph = matching_graphs[0]

print(my_graph)


In [None]:
from src import katana_pipeline

# Wrap the connected graph in the project's RecipePipeline; the later cells
# call its feature, split, train, test, plot and inference steps.
rec_pipeline = katana_pipeline.RecipePipeline(my_graph)
# Presumably renders the graph schema as the cell output - confirm.
rec_pipeline.graph.schema().view()

In [None]:

# Time the pipeline's feature-generation step.
start_time = default_timer()

rec_pipeline.feature_generator()

# Report the elapsed time measured around the call above.
print(f"***Took {default_timer() - start_time} seconds to generate the features.***")


Collecting genomics_expression features...


          0/? [?op/s]

Collecting genomics_methylation features...


          0/? [?op/s]

Collecting genomics_mutation features...


          0/? [?op/s]

In [None]:

#  stats = rec_pipeline.stats()
#  stats


In [None]:

# Run the pipeline's split-generation step with the input configuration.
# Timing and the schema display for this step are currently commented out.
#  start_time = default_timer()

rec_pipeline.split_generator(input_config)

#  print(f"***Took {default_timer() - start_time} seconds to generate the split.***")
#  rec_pipeline.graph.schema().view()

# Presumably just a visual marker that the cell ran.
print("--")



In [None]:
# Load the model configuration; it is passed to train() and
# infer_embeddings() further down.
model_config = hyperparams.load_model_config()
# Bare last expression so the notebook shows the loaded configuration.
model_config

In [None]:
# Load the training configuration; it is passed to train(), test() and
# infer() further down.
training_config = hyperparams.load_training_config()
# Bare last expression so the notebook shows the loaded configuration.
training_config

In [None]:
# Train the model with the loaded model and training configs, timing the call.
start_time = default_timer()
# The value returned by train() is reported below as the validation metric.
validation_metric = rec_pipeline.train(model_config, training_config)
print(f"***Took {default_timer() - start_time} seconds to train the model.***")
print("Validation metric: ", validation_metric)

### Testing

Test the trained model on the test data and compare the results with other architecture baselines. We compare three different metrics:
- [Pearson correlation coefficient](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient): a measure of linear correlation between two sets of data. A value closer to 1 is better.
- [Spearman correlation coefficient](https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient): how well the relationship between two variables can be described using a monotonic function. A value closer to 1 is better.
- [Root mean square error (RMSE)](https://en.wikipedia.org/wiki/Root-mean-square_deviation): a measure of the differences between values predicted by a model and the values observed. A lower value is better.

In [None]:
# Run the pipeline's test step, timing the call. test() returns three values:
# metrics, predictions and labels; the last two are passed to the plot cell.
start_time = default_timer()
metrics, predictions, labels = rec_pipeline.test(training_config)
print(f"***Took {default_timer() - start_time} seconds to test the model.***")
# Bare last expression so the notebook shows the metrics as output.
metrics

### Plot Predictions

Create a scatter plot of the `IC50` value predicted by the trained model as a function of the real `IC50` value.

In [None]:
# Plot the labels against the predictions returned by the test step, timing
# the call.
start_time = default_timer()
rec_pipeline.plot(labels, predictions)
print(f"***Took {default_timer() - start_time} seconds to plot figures.***")

<a id='step-4'></a>
## Trained model inference 

<img src="images/pipeline_inference.png" style="width: 1500px;"/>



In [None]:
# Run inference with only the training configuration, i.e. without naming a
# specific drug or cell line (the next cell passes both explicitly).
rec_pipeline.infer(training_config)

In [None]:
# Run inference for one explicit drug / cell-line pair.
# NOTE(review): the drug string looks like a SMILES encoding - confirm the
# format expected by infer().
bortezomib = "B(C(CC(C)C)NC(=O)C(CC1=CC=CC=C1)NC(=O)C2=NC=CN=C2)(O)O"
cell_line = "ACH-000001"
rec_pipeline.infer(training_config, drug=bortezomib, cell_line=cell_line)

## Run trained model to save node embeddings

Save the trained model embeddings as a node property. `drug_embeddings`, `epigenomics_embeddings`, `genomics_embeddings` and `transcriptomics_embeddings` will be created and saved. Those embeddings can be used for a future downstream task with a different model.

In [None]:
# Run the pipeline's embedding-inference step with the model config, timing
# the call.
start_time = default_timer()
rec_pipeline.infer_embeddings(model_config)
print(f"***Took {default_timer() - start_time} seconds to save node embeddings.***")
# Presumably renders the graph schema again so the new properties can be
# checked - confirm.
rec_pipeline.graph.schema().view()

## Save Graph

If `input_config.save_graph_path` is set, save the graph to the bucket location it names.

In [None]:
# Export the graph only when a destination path is configured.
#
# The export must use `my_graph`, the handle bound in the "CONNECT TO GRAPH"
# cell. The name `graph` is only assigned in commented-out code, so using it
# here raises NameError on a fresh kernel.
# NOTE(review): `rec_pipeline.graph` is presumably the same graph - confirm.
if input_config.save_graph_path:
    export_data.rdg(my_graph, input_config.save_graph_path)