In [1]:
import copy

import torch
from torch.utils.data import DataLoader
from molsetrep.utils.datasets import molnet_loader, get_class_weights
from wandb import finish as wandb_finish
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import wandb
import lightning.pytorch as pl

Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
def _encode_split(encoder, split, task, label_dtype):
    """Encode one dataset split, keeping only the label column `task`."""
    return encoder.encode(split.ids, [y[task] for y in split.y], label_dtype=label_dtype)


def _make_loaders(train_dataset, valid_dataset, test_dataset, batch_size, num_workers):
    """Build the train/valid/test DataLoaders for one task.

    Only the training loader shuffles and drops the trailing partial batch.
    Validation and test loaders keep every sample so metrics cover the whole
    split (and a split smaller than `batch_size` still yields one batch).
    """
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers, drop_last=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, drop_last=False)
    return train_loader, valid_loader, test_loader


def _fresh_instance(template, base_cls):
    """Return a new object for a single training run.

    If `template` is an instance of `base_cls` it is deep-copied, so state built
    up during one run (model weights, checkpoint best-k bookkeeping and resolved
    dirpath) never leaks into the next. Otherwise `template` is treated as a
    zero-argument factory and called.
    """
    if isinstance(template, base_cls):
        return copy.deepcopy(template)
    return template()


def molnet_test_runner(data_set_name, encoder, model, label_dtype, checkpoint_callback,
                       max_epochs=50, batch_size=64, n_runs=4, num_workers=8):
    """Train and test a model on every task of a dataset, `n_runs` times per task.

    The dataset is loaded with ``molnet_loader`` using a random split. Each
    (task, run) pair gets:
      - its own W&B run in project ``MolRepSet-triple-<data_set_name>``;
      - a fresh copy of the model and of the checkpoint callback;
      - a test-set evaluation using that run's own best checkpoint.

    Args:
        data_set_name: Dataset name forwarded to ``molnet_loader``.
        encoder: Object whose ``encode(ids, labels, label_dtype=...)`` returns a
            dataset usable by ``DataLoader``.
        model: LightningModule template, deep-copied per run so runs are
            independent. The caller's instance is therefore left untrained.
            Alternatively, a zero-argument factory returning a new LightningModule.
            Use a factory for different initial weights per run.
        label_dtype: Forwarded to ``encoder.encode``.
        checkpoint_callback: Callback template (e.g. ``ModelCheckpoint``),
            deep-copied per run, or a zero-argument factory. Pass an unused
            callback: a copy of one that already ran inherits its state.
        max_epochs: Maximum training epochs per run.
        batch_size: Batch size for all loaders.
        n_runs: Number of repeated runs per task.
        num_workers: DataLoader worker processes.
    """
    train, valid, test, tasks = molnet_loader(data_set_name, splitter="random")

    for task, task_name in enumerate(tasks):
        class_weights, class_counts = get_class_weights(train.y, task)
        print(class_weights)
        print(class_counts)

        train_loader, valid_loader, test_loader = _make_loaders(
            _encode_split(encoder, train, task, label_dtype),
            _encode_split(encoder, valid, task, label_dtype),
            _encode_split(encoder, test, task, label_dtype),
            batch_size,
            num_workers,
        )

        for _ in range(n_runs):
            # Close any run left open, e.g. by an interrupted previous cell.
            wandb_finish()

            # Fresh model/callback per run. Reusing them would continue training
            # the previous run's weights and let ckpt_path="best" resolve to an
            # earlier run's checkpoint.
            run_model = _fresh_instance(model, pl.LightningModule)
            run_checkpoint = _fresh_instance(checkpoint_callback, pl.Callback)

            wandb_logger = wandb.WandbLogger(project=f"MolRepSet-triple-{data_set_name}")
            wandb_logger.experiment.config["task"] = task_name

            trainer = pl.Trainer(callbacks=[run_checkpoint], max_epochs=max_epochs,
                                 log_every_n_steps=1, logger=wandb_logger)
            wandb_logger.watch(run_model, log="all")

            trainer.fit(run_model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
            trainer.test(ckpt_path="best", dataloaders=test_loader)

            wandb_logger.finalize("success")
            wandb_finish()