In [1]:
import torch
import numpy as np

from torch.nn import Parameter, Linear, BatchNorm1d, ReLU, LeakyReLU, Linear, Dropout, CrossEntropyLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryAccuracy, BinaryAUROC
from torchmetrics.regression import R2Score, MeanSquaredError, MeanAbsoluteError
from torchmetrics.classification import Accuracy, AUROC

from molsetrep.utils.torch_trainer import TorchTrainer
from molsetrep.utils.multiset_torch_trainer import MultisetTorchTrainer
from molsetrep.utils.datasets import molnet_loader
from molsetrep.utils.converters import molnet_to_pyg
from molsetrep.utils.root_mean_squared_error import RootMeanSquaredError
from molsetrep.utils.imbalanced_sampler import ImbalancedSampler
# from molsetrep.models import SetRepClassifier, SetRepRegressor, GNNDeepSetClassifier, DeepSet, DualSetRepClassifier, DualSetRepRegressor
from molsetrep.encoders import SECMQNFPEncoder, SECFPEncoder, ECFPEncoder, Mol2VecEncoder, Mol2SetEncoder

from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight

import matplotlib.pyplot as plt

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import wandb
from wandb import finish as wandb_finish


Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


## Setup

### Lightning Module

In [2]:
class DualSetClassifier(pl.LightningModule):
    def __init__(self, n_hidden_sets, n_hidden_sets_2, n_elements, n_elements_2, d, d_2, n_classes, class_weights):
        super().__init__()
        self.save_hyperparameters()
        
        self.n_hidden_sets = n_hidden_sets
        self.n_elements = n_elements

        self.n_hidden_sets_2 = n_hidden_sets_2
        self.n_elements_2 = n_elements_2

        self.class_weights = class_weights

        self.Wc = Parameter(torch.FloatTensor(d, n_hidden_sets * n_elements))
        self.Wc_2 = Parameter(torch.FloatTensor(d_2, n_hidden_sets_2 * n_elements_2))
        self.fc1 = Linear(n_hidden_sets, 32)
        self.fc1_2 = Linear(n_hidden_sets_2, 32)
        self.bn = BatchNorm1d(n_hidden_sets)
        self.bn_2 = BatchNorm1d(n_hidden_sets_2)
        self.dropout_1 = Dropout(0.8)
        self.dropout_2 = Dropout(0.8)
        self.fc2 = Linear(32 * 2, 32)
        self.bn_3 = BatchNorm1d(32)
        self.fc3 = Linear(32, 16)
        self.fc4 = Linear(16, n_classes)

        
        # Init weights
        # self.Wc.data.normal_()
        # self.Wc_2.data.normal_()

        self.Wc.data.uniform_(-1, 1)
        self.Wc_2.data.uniform_(-1, 1)

        # Metrics
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.train_auroc = AUROC(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_auroc = AUROC(task="multiclass", num_classes=n_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.test_auroc = AUROC(task="multiclass", num_classes=n_classes)

        self.criterion = CrossEntropyLoss(weight=torch.FloatTensor(self.class_weights).to(self.device))
        self.criterion_eval = CrossEntropyLoss()

    def forward(self, X, X2):
        # First sets (e.g. atoms)
        t = torch.matmul(X, self.Wc)

        t = torch.relu(t)
        t = t.view(t.size()[0], t.size()[1], self.n_elements, self.n_hidden_sets)
        t, _ = torch.max(t, dim=2)
        t = torch.sum(t, dim=1)
        t = self.bn(t)

        t = self.fc1(t)
        # t = self.dropout_1(t)
        t = torch.relu(t)

        # Second sets (e.g. bonds)
        t_2 = torch.matmul(X2, self.Wc_2)
        t_2 = torch.relu(t_2)
        t_2 = t_2.view(
            t_2.size()[0], t_2.size()[1], self.n_elements_2, self.n_hidden_sets_2
        )
        t_2, _ = torch.max(t_2, dim=2)
        t_2 = torch.sum(t_2, dim=1)
        t_2 = self.bn_2(t_2)
        t_2 = self.fc1_2(t_2)
        # t_2 = self.dropout_1(t_2)
        t_2 = torch.relu(t_2)

        # Concat, mlp, and softmax
        out = self.fc2(torch.cat((t, t_2), 1))

        out = self.bn_3(out)
        out = torch.relu(out)
        # out = self.dropout_1(out)
        out = self.fc3(out)
        out = torch.relu(out)
        out = self.fc4(out)

        return out

    def training_step(self, batch, batch_idx):
        x, x2, y = batch

        out = self(x, x2)

        loss = self.criterion(out, y)

        # Metrics
        self.train_accuracy(out, y)
        self.train_auroc(out, y)

        self.log("train/loss", loss, on_step=False, on_epoch=True)
        self.log("train/acc", self.train_accuracy, on_step=False, on_epoch=True)
        self.log("train/auroc", self.train_auroc, on_step=False, on_epoch=True)

        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, x2, y = val_batch
        out = self.forward(x, x2)

        loss = self.criterion_eval(out, y)

        # Metrics
        self.val_accuracy(out, y)
        self.val_auroc(out, y)

        self.log("val/loss", loss, on_step=False, on_epoch=True)
        self.log("val/acc", self.val_accuracy, on_step=False, on_epoch=True)
        self.log("val/auroc", self.val_auroc, on_step=False, on_epoch=True)

    def test_step(self, val_batch, batch_idx):
        x, x2, y = val_batch
        out = self.forward(x, x2)
        loss = self.criterion_eval(out, y)

        # Metrics
        self.test_accuracy(out, y)
        self.test_auroc(out, y)

        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", self.test_accuracy)
        self.log("test/auroc", self.test_auroc)
        

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

## Train

### Utilities

In [3]:
def get_class_weights(y):
    n = len(y)
    unique, counts = np.unique(y, return_counts=True)
    weights = [1 - c / n for c in counts]
    return np.array(weights)

### Load Data

In [4]:
data_set_name = "bace_classification"

train, valid, test = molnet_loader(data_set_name, splitter="scaffold")

enc = ECFPEncoder()

class_weights = get_class_weights(train.y.flatten())
print(class_weights)

train_dataset = enc.encode(train.ids, [y[0] for y in train.y], label_dtype=torch.long)
valid_dataset = enc.encode(valid.ids, [y[0] for y in valid.y], label_dtype=torch.long)
test_dataset = enc.encode(test.ids, [y[0] for y in test.y], label_dtype=torch.long)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, num_workers=8, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8, drop_last=True)

d = len(train_dataset[0][0][0])
d2 = len(train_dataset[0][1][0])

[0.42561983 0.57438017]


### Fit

In [5]:
for _ in range(4):
    # Setup wandb logging
    wandb_logger = wandb.WandbLogger(project=f"MolRepSet-{data_set_name}")

    # Define callback for which callpoint to load for testing
    checkpoint_callback = ModelCheckpoint(monitor="val/auroc", mode="max")

    trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=50, log_every_n_steps=1, logger=wandb_logger)

    model = DualSetClassifier(128, 128, 64, 64, d, d2, 2, class_weights=class_weights)
    wandb_logger.watch(model, log="all")

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
    trainer.test(ckpt_path="best", dataloaders=test_loader)

    wandb_logger.finalize("success")
    wandb_finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdaenuprobst[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Ti Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 4.1 K 
1  | fc1_2          | Linear             | 4.1 K 
2  | bn             | BatchNorm1d        | 256   
3  | bn_2           | BatchNorm1d        | 256   
4  | dropout_1      | Dropout           

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-bace_classification/v3m04twl/checkpoints/epoch=39-step=720.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-bace_classification/v3m04twl/checkpoints/epoch=39-step=720.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,▁▁▁▁▅▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇███▇██████▇█████
train/auroc,▁▁▂▂▅▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇▇████████████
train/loss,██▇▇▆▅▄▄▄▃▃▄▃▃▃▂▃▃▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▁▁▂▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,▂▁▁▁▅▅▃▅▅▅▅▅▅▅▆▅▅▇▆▅▇█▅▆▇▅▇▅█▇▆▆▆▆▅▅▅▇▇▇
val/auroc,▁▁▁▂▆▆▅▄▆▄▄▆▄▄▆▃▆▇▅▄▄▆▆▇▆▅▄▅▇▇▆▆▆▇▆▅▄▇██

0,1
epoch,50.0
test/acc,0.70312
test/auroc,0.7924
test/loss,0.80363
train/acc,0.86024
train/auroc,0.93455
train/loss,0.32071
trainer/global_step,900.0
val/acc,0.64844
val/auroc,0.69513


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667263093334744, max=1.0)…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 4.1 K 
1  | fc1_2          | Linear             | 4.1 K 
2  | bn             | BatchNorm1d        | 256   
3  | bn_2           | BatchNorm1d        | 256   
4  | dropout_1      | Dropout            | 0     
5  | dropout_2      | Dropout            | 0     
6  | fc2            | Linear             | 2.1 K 
7  | bn_3           | BatchNorm1d        | 64    
8  | fc3            | Linear             | 528   
9  | fc4            | Linear             | 34    
10 | train_accuracy | MulticlassAccuracy | 0     
11 | train_auroc    | MulticlassAUROC    | 0     
12 | val_accuracy 

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-bace_classification/6o7foops/checkpoints/epoch=29-step=540.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-bace_classification/6o7foops/checkpoints/epoch=29-step=540.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.089 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,▁▂▁▂▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇▇█▇█▇███▇████
train/auroc,▁▂▃▄▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████
train/loss,█▇▇▇▅▅▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,▁▁▃▄▅▄▄▄▄▅▆▇▅▆▅▅▆▆▇▇▆▇▆▇▇▇▅▇▆▅▇▅▆▆█▅▆▆▆▆
val/auroc,▁▂▁▃▄▅▄▃▂▄▃▅▄▄▄▇▄▄▅▅▅▇▇███▆▅▆▄▇▇▆▆█▄▆▆▇▆

0,1
epoch,50.0
test/acc,0.67969
test/auroc,0.77892
test/loss,0.66756
train/acc,0.87153
train/auroc,0.94445
train/loss,0.29951
trainer/global_step,900.0
val/acc,0.61719
val/auroc,0.66773


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016672817100000733, max=1.0…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 4.1 K 
1  | fc1_2          | Linear             | 4.1 K 
2  | bn             | BatchNorm1d        | 256   
3  | bn_2           | BatchNorm1d        | 256   
4  | dropout_1      | Dropout            | 0     
5  | dropout_2      | Dropout            | 0     
6  | fc2            | Linear             | 2.1 K 
7  | bn_3           | BatchNorm1d        | 64    
8  | fc3            | Linear             | 528   
9  | fc4            | Linear             | 34    
10 | train_accuracy | MulticlassAccuracy | 0     
11 | train_auroc    | MulticlassAUROC    | 0     
12 | val_accuracy 

Sanity Checking: 0it [00:00, ?it/s]

Traceback (most recent call last):
  File "/home/daenu/miniconda3/envs/molsetrep/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/home/daenu/miniconda3/envs/molsetrep/lib/python3.9/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/daenu/miniconda3/envs/molsetrep/lib/python3.9/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/home/daenu/miniconda3/envs/molsetrep/lib/python3.9/shutil.py", line 740, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/daenu/miniconda3/envs/molsetrep/lib/python3.9/shutil.py", line 738, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-90nee_px'


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-bace_classification/p35z023p/checkpoints/epoch=45-step=828.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-bace_classification/p35z023p/checkpoints/epoch=45-step=828.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,▁▂▃▄▆▆▆▆▆▇▇▆▇▇▇▇▇▇▇▇█▇▇███▇███▇████▇████
train/auroc,▁▂▃▄▆▆▆▆▇▇▇▆▇▇▇▇▇▇▇▇█▇▇█████████████████
train/loss,█▇▇▇▅▄▄▄▄▄▃▄▃▃▃▃▃▃▂▃▂▂▂▂▂▁▂▂▂▁▂▁▂▁▁▂▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,▁▂▂▅▄▃▄▅▅▅▆▅▆▅▆▆▅▅▅▇▆▆▆▆▆▇▆▆▆▇▇▇▄▅▆▅██▆▇
val/auroc,▁▂▁▆▃▁▃▃▄▃▅▅▆▄▆▆▄█▇▅▇▆▇█▆▆▆▇▇▇██▅▇▇▆█▇▇▅

0,1
epoch,50.0
test/acc,0.70312
test/auroc,0.80441
test/loss,0.674
train/acc,0.86111
train/auroc,0.93942
train/loss,0.30913
trainer/global_step,900.0
val/acc,0.64062
val/auroc,0.64693


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016674429083347302, max=1.0…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 4.1 K 
1  | fc1_2          | Linear             | 4.1 K 
2  | bn             | BatchNorm1d        | 256   
3  | bn_2           | BatchNorm1d        | 256   
4  | dropout_1      | Dropout            | 0     
5  | dropout_2      | Dropout            | 0     
6  | fc2            | Linear             | 2.1 K 
7  | bn_3           | BatchNorm1d        | 64    
8  | fc3            | Linear             | 528   
9  | fc4            | Linear             | 34    
10 | train_accuracy | MulticlassAccuracy | 0     
11 | train_auroc    | MulticlassAUROC    | 0     
12 | val_accuracy 

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-bace_classification/6ogh0qfl/checkpoints/epoch=44-step=810.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-bace_classification/6ogh0qfl/checkpoints/epoch=44-step=810.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,▁▁▂▃▅▅▅▆▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██▇████████▇
train/auroc,▁▂▃▄▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇█████████████
train/loss,██▇▆▅▅▅▄▄▄▃▃▃▃▃▃▃▃▂▂▂▃▂▂▂▃▂▂▁▁▂▁▂▁▁▁▁▁▁▂
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,▁▁▄▄▆▆▄▅▅▅▅▅▆▆▅▇█▇▆▄▅▇▆▆▆▇▅▅▅▆▄▇▇▅▅▆▆▇█▆
val/auroc,▁▁▃▃▇▇▄▅▅▅▃▆▅▄▄▆▃▅▆▅▆▆▆▄▆▄▆▅▆█▄█▇▆▅▇▇▇█▅

0,1
epoch,50.0
test/acc,0.73438
test/auroc,0.82721
test/loss,0.69977
train/acc,0.8559
train/auroc,0.93607
train/loss,0.3206
trainer/global_step,900.0
val/acc,0.60938
val/auroc,0.65182
