# Molecular set representation learning - Molecular property prediction

## Imports

In [None]:
from multiprocessing import cpu_count

import torch

import pandas as pd
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

# from torch.utils.data import DataLoader
from torch_geometric.loader import DataLoader

from torch_geometric.datasets import LRGBDataset

from molsetrep.models import LightningSRGNNClassifier, LightningSRGNNRegressor

## Prepare the data

### Get LRGB Data

In [None]:
# Build the three Peptides-struct splits under ./tmp, in train/val/test order.
dataset_train, dataset_val, dataset_test = (
    LRGBDataset("./tmp", "Peptides-struct", split)
    for split in ("train", "val", "test")
)

In [None]:
# [node feature count, edge feature count] as reported by the training split;
# consumed further down when the model is constructed.
dims_graph = [dataset_train.num_node_features, dataset_train.num_edge_features]
dims_graph  # bare expression so the notebook displays the value

### Get torch data loaders

#### Set-enhanced GNN (SR-GNN)

In [None]:
# Settings shared by all three loaders. The number of worker processes is the
# machine's CPU count, capped at 8.
loader_kwargs = dict(
    batch_size=64,
    num_workers=min(cpu_count(), 8),
)

# Training: reshuffle every epoch and drop the final incomplete batch so that
# every optimisation step sees a full batch of 64 graphs.
train_loader = DataLoader(
    dataset_train,
    shuffle=True,
    drop_last=True,
    **loader_kwargs,
)

# Evaluation: keep the final partial batch. With drop_last=True the trailing
# (len(dataset) % 64) graphs would never be scored, so the validation and test
# metrics would be computed on only part of each split.
val_loader = DataLoader(
    dataset_val,
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)

test_loader = DataLoader(
    dataset_test,
    shuffle=False,
    drop_last=False,
    **loader_kwargs,
)

## Train

### Initialise the model

#### Set-enhanced GNN (SR-GNN)

In [None]:
# Alternative kept for reference: the classifier variant with the same layer
# configuration. NOTE(review): presumably intended for a 10-class multilabel
# dataset rather than Peptides-struct; confirm before enabling.
#
# model_graph = LightningSRGNNClassifier(
#     [128, 128], [64, 64],
#     n_hidden_channels=[128, 64],
#     n_in_channels=dims_graph[0],
#     n_edge_channels=dims_graph[1],
#     n_layers=8,
#     n_classes=10,
#     metrics=[""],
#     metrics_task="multilabel"
# )

# dims_graph holds [node feature count, edge feature count].
n_node_channels, n_edge_channels = dims_graph

# Regressor with 11 output tasks. The two leading positional lists are passed
# through unchanged; NOTE(review): their meaning is defined by molsetrep's
# constructor signature, which is not visible here.
model_graph = LightningSRGNNRegressor(
    [128, 128],
    [64, 64],
    n_hidden_channels=[128, 64],
    n_in_channels=n_node_channels,
    n_edge_channels=n_edge_channels,
    n_layers=8,
    n_tasks=11,
)

### Initialise the trainer and fit

#### Set-enhanced GNN (SR-GNN)

In [None]:
# Keep a single checkpoint under ./tmp.
# NOTE(review): no `monitor` is passed, so which checkpoint is kept follows
# ModelCheckpoint's default policy rather than a validation metric - confirm
# this is intended.
checkpoint_callback = ModelCheckpoint(save_top_k=1, dirpath="./tmp")

# Train for 10 epochs with the checkpoint callback attached.
trainer_graph = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=10)

trainer_graph.fit(
    model_graph,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

# No model or ckpt_path is passed here, so the trainer decides which weights
# are evaluated on the test loader.
trainer_graph.test(dataloaders=test_loader)