In [None]:
from dataset import MolerDataset
from torch_geometric.loader import DataLoader
from model import BaseModel
from model_utils import get_params
from torch.utils.data import RandomSampler
from sampler import DuplicatedIndicesSamplerWrapper
import pandas as pd
from pytorch_lightning import Trainer

# --- Configuration -----------------------------------------------------------
# Paths and hyperparameters collected in one place so they are easy to tune.
DATA_ROOT = "/data/ongh0068"
L1000_DIR = f"{DATA_ROOT}/l1000"
TRACE_DIR = f"{L1000_DIR}/trace_playground"
PYG_OUTPUT_DIR = f"{L1000_DIR}/pyg_output_playground"
BATCH_SIZE = 16
# Index of the GPU to train on. Lightning owns device placement, so the index
# is handed to the Trainer rather than calling `model.to(...)` manually.
GPU_INDEX = 1

# Keys for which the PyG DataLoader additionally builds `<key>_batch` and
# `<key>_ptr` vectors when collating (visible in the batch repr printed later).
FOLLOW_BATCH_KEYS = [
    "correct_edge_choices",
    "correct_edge_types",
    "valid_edge_choices",
    "valid_attachment_point_choices",
    "correct_attachment_point_choice",
    "correct_node_type_choices",
    "original_graph_x",
    "correct_first_node_type_choices",
]

# --- Data --------------------------------------------------------------------
train_dataset = MolerDataset(
    root=DATA_ROOT,
    raw_moler_trace_dataset_parent_folder=TRACE_DIR,
    output_pyg_trace_dataset_parent_folder=PYG_OUTPUT_DIR,
    split="train",
)

# Per-item generation-step counts. The CSV is read after constructing the
# dataset because it lives under the dataset's output folder and may be
# (re)generated by its processing step — TODO confirm in dataset.py.
processed_file_metadata = f"{PYG_OUTPUT_DIR}/train/processed_file_paths.csv"
molecule_gen_steps_lengths = pd.read_csv(processed_file_metadata)[
    "molecule_gen_steps_length"
].tolist()

# frequency_mapping below is keyed by row position, so the CSV rows must line
# up one-to-one with dataset indices; fail loudly instead of mis-pairing them.
assert len(molecule_gen_steps_lengths) == len(train_dataset), (
    f"{processed_file_metadata} has {len(molecule_gen_steps_lengths)} rows "
    f"but the train dataset has {len(train_dataset)} items"
)

# RandomSampler only uses len(data_source), so a range suffices (no list built).
random_sampler = RandomSampler(data_source=range(len(train_dataset)))
# Presumably re-emits each sampled index according to its mapped frequency
# (number of generation steps) — TODO confirm against sampler.py.
sampler = DuplicatedIndicesSamplerWrapper(
    sampler=random_sampler,
    frequency_mapping=dict(enumerate(molecule_gen_steps_lengths)),
)

# shuffle must stay False: torch's DataLoader rejects shuffle=True when an
# explicit sampler is supplied (the sampler already randomizes the order).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    sampler=sampler,
    follow_batch=FOLLOW_BATCH_KEYS,
)

# --- Model & trainer ---------------------------------------------------------
params = get_params()
model = BaseModel(params, train_dataset)

# `devices=1` means "one GPU" (Lightning picks index 0), not "GPU 1"; a list
# selects GPUs by index so training actually runs on GPU_INDEX.
# For debugging, add `overfit_batches=1` to overfit a single batch.
trainer = Trainer(accelerator="gpu", devices=[GPU_INDEX])


2022-12-25 10:17:10.772730: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Show the model as this cell's output. `model` is passed to Trainer.fit, so it
# is a LightningModule whose repr lists its nested submodules — useful for
# inspecting the architecture before training.
model 

In [None]:
# Run training. NOTE(review): the training loader is also passed as the
# validation loader, so validation metrics are computed on training data and
# only serve as an overfitting/smoke check — use a held-out split for real
# evaluation.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=train_dataloader,
)


In [5]:
# Peek at a single collated batch to sanity-check tensor shapes and the
# `*_batch` / `*_ptr` vectors generated for the follow_batch keys.
batch = next(iter(train_dataloader), None)
if batch is not None:
    print(batch)

MolerDataBatch(x=[4411, 32], edge_index=[2, 13381], original_graph_edge_type=[22916], original_graph_node_categorical_features=[7184], focus_node=[256], edge_type=[13381], edge_features=[1555, 3], correct_edge_choices=[1555], correct_edge_choices_batch=[1555], correct_edge_choices_ptr=[257], num_correct_edge_choices=[256], stop_node_label=[256], valid_edge_choices=[1555, 2], valid_edge_choices_batch=[1555], valid_edge_choices_ptr=[257], valid_edge_types=[117, 3], correct_edge_types=[117, 3], correct_edge_types_batch=[117], correct_edge_types_ptr=[257], partial_node_categorical_features=[4411], correct_attachment_point_choice=[28], correct_attachment_point_choice_batch=[28], correct_attachment_point_choice_ptr=[257], correct_node_type_choices=[118, 139], correct_node_type_choices_batch=[118], correct_node_type_choices_ptr=[257], correct_first_node_type_choices=[256, 139], correct_first_node_type_choices_batch=[256], correct_first_node_type_choices_ptr=[257], sa_score=[256], clogp=[256],

MolerDataBatch(x=[4352, 32], edge_index=[2, 13038], original_graph_edge_type=[24491], original_graph_node_categorical_features=[7601], focus_node=[256], edge_type=[13038], edge_features=[2169, 3], correct_edge_choices=[2169], correct_edge_choices_batch=[2169], correct_edge_choices_ptr=[257], num_correct_edge_choices=[256], stop_node_label=[256], valid_edge_choices=[2169, 2], valid_edge_choices_batch=[2169], valid_edge_choices_ptr=[257], valid_edge_types=[142, 3], correct_edge_types=[142, 3], correct_edge_types_batch=[142], correct_edge_types_ptr=[257], partial_node_categorical_features=[4352], correct_attachment_point_choice=[9], correct_attachment_point_choice_batch=[9], correct_attachment_point_choice_ptr=[257], correct_node_type_choices=[121, 139], correct_node_type_choices_batch=[121], correct_node_type_choices_ptr=[257], correct_first_node_type_choices=[256, 139], correct_first_node_type_choices_batch=[256], correct_first_node_type_choices_ptr=[257], sa_score=[256], clogp=[256], m