In [1]:
import torch
import numpy as np

from torch.nn import Parameter, Linear, BatchNorm1d, ReLU, LeakyReLU, Linear, Dropout, CrossEntropyLoss, BCEWithLogitsLoss
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryAccuracy, BinaryAUROC
from torchmetrics.regression import R2Score, MeanSquaredError, MeanAbsoluteError
from torchmetrics.classification import Accuracy, AUROC

from molsetrep.utils.torch_trainer import TorchTrainer
from molsetrep.utils.multiset_torch_trainer import MultisetTorchTrainer
from molsetrep.utils.datasets import molnet_loader
from molsetrep.utils.converters import molnet_to_pyg
from molsetrep.utils.root_mean_squared_error import RootMeanSquaredError
from molsetrep.utils.imbalanced_sampler import ImbalancedSampler
from molsetrep.encoders import TripleSetEncoder
from molsetrep.models import FocalLoss

from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight

import matplotlib.pyplot as plt

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import wandb
from wandb import finish as wandb_finish


Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


## Setup

### Lightning Module

In [2]:
class TripleSetClassifier(pl.LightningModule):
    def __init__(
        self,
        n_hidden_sets,
        n_hidden_sets_2,
        n_hidden_sets_3,
        n_elements,
        n_elements_2,
        n_elements_3,
        d,
        d_2,
        d_3,
        n_classes,
        class_weights,
        learning_rate=0.001,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.n_hidden_sets = n_hidden_sets
        self.n_elements = n_elements

        self.n_hidden_sets_2 = n_hidden_sets_2
        self.n_elements_2 = n_elements_2

        self.n_hidden_sets_3 = n_hidden_sets_3
        self.n_elements_3 = n_elements_3

        self.class_weights = class_weights
        self.learning_rate = learning_rate

        self.Wc = Parameter(torch.FloatTensor(d, n_hidden_sets * n_elements))
        self.Wc_2 = Parameter(torch.FloatTensor(d_2, n_hidden_sets_2 * n_elements_2))
        self.Wc_3 = Parameter(torch.FloatTensor(d_3, n_hidden_sets_3 * n_elements_3))

        self.fc1 = Linear(n_hidden_sets, 32)
        self.fc1_2 = Linear(n_hidden_sets_2, 32)
        self.fc1_3 = Linear(n_hidden_sets_3, 32)

        self.bn = BatchNorm1d(n_hidden_sets)
        self.bn_2 = BatchNorm1d(n_hidden_sets_2)
        self.bn_3 = BatchNorm1d(n_hidden_sets_3)

        self.dropout_1 = Dropout(0.8)
        self.dropout_2 = Dropout(0.8)
        self.dropout_3 = Dropout(0.8)

        # MLP
        self.fc2 = Linear(32 * 3, 32)
        self.bn_mlp = BatchNorm1d(32)
        self.fc3 = Linear(32, 16)
        self.fc4 = Linear(16, n_classes)

        # Init weights
        self.Wc.data.uniform_(-1, 1)
        self.Wc_2.data.uniform_(-1, 1)
        self.Wc_3.data.uniform_(-1, 1)

        # Metrics
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.train_auroc = AUROC(task="multiclass", num_classes=n_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.val_auroc = AUROC(task="multiclass", num_classes=n_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.test_auroc = AUROC(task="multiclass", num_classes=n_classes)

        self.criterion = CrossEntropyLoss(
            weight=torch.FloatTensor(self.class_weights).to(self.device)
        )
        self.criterion_eval = CrossEntropyLoss()


    def forward(self, X, X2, X3):
        # First sets (e.g. atoms)
        t = torch.matmul(X, self.Wc)

        t = torch.relu(t)
        t = t.view(t.size()[0], t.size()[1], self.n_elements, self.n_hidden_sets)
        t, _ = torch.max(t, dim=2)
        t = torch.sum(t, dim=1)
        t = self.bn(t)

        t = self.fc1(t)
        # t = self.dropout_1(t)
        t = torch.relu(t)

        # Second sets (e.g. bonds)
        t_2 = torch.matmul(X2, self.Wc_2)
        t_2 = torch.relu(t_2)
        t_2 = t_2.view(
            t_2.size()[0], t_2.size()[1], self.n_elements_2, self.n_hidden_sets_2
        )
        t_2, _ = torch.max(t_2, dim=2)
        t_2 = torch.sum(t_2, dim=1)
        t_2 = self.bn_2(t_2)
        t_2 = self.fc1_2(t_2)
        # t_2 = self.dropout_1(t_2)
        t_2 = torch.relu(t_2)

        # Third sets
        t_3 = torch.matmul(X3, self.Wc_3)
        t_3 = torch.relu(t_3)
        t_3 = t_3.view(
            t_3.size()[0], t_3.size()[1], self.n_elements_3, self.n_hidden_sets_3
        )
        t_3, _ = torch.max(t_3, dim=2)
        t_3 = torch.sum(t_3, dim=1)
        t_3 = self.bn_3(t_3)
        t_3 = self.fc1_3(t_3)
        # t_2 = self.dropout_1(t_2)
        t_3 = torch.relu(t_3)

        # Concat, mlp, and softmax
        out = self.fc2(torch.cat((t, t_2, t_3), 1))

        out = self.bn_mlp(out)
        out = torch.relu(out)
        # out = self.dropout_1(out)
        out = self.fc3(out)
        out = torch.relu(out)
        out = self.fc4(out)

        return out

    def training_step(self, batch, batch_idx):
        x, x2, x3, y = batch

        out = self(x, x2, x3)
        loss = self.criterion(out, y)

        # Metrics
        self.train_accuracy(out, y)
        self.train_auroc(out, y)

        self.log("train/loss", loss, on_step=False, on_epoch=True)
        self.log("train/acc", self.train_accuracy, on_step=False, on_epoch=True)
        self.log("train/auroc", self.train_auroc, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, val_batch, batch_idx):
        x, x2, x3, y = val_batch

        out = self.forward(x, x2, x3)
        loss = self.criterion_eval(out, y)

        # Metrics
        self.val_accuracy(out, y)
        self.val_auroc(out, y)

        self.log("val/loss", loss, on_step=False, on_epoch=True)
        self.log("val/acc", self.val_accuracy, on_step=False, on_epoch=True)
        self.log("val/auroc", self.val_auroc, on_step=False, on_epoch=True)

    def test_step(self, val_batch, batch_idx):
        x, x2, x3, y = val_batch

        out = self.forward(x, x2, x3)
        loss = self.criterion_eval(out, y)

        # Metrics
        self.test_accuracy(out, y)
        self.test_auroc(out, y)

        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", self.test_accuracy)
        self.log("test/auroc", self.test_auroc)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

## Train

### Utilities

In [3]:
def get_class_weights(y, task_idx):
    y_t = y.T

    unique, counts = np.unique(y_t[task_idx], return_counts=True)
    weights = [1 - c / y_t[task_idx].shape[0] for c in counts]

    return np.array(weights), np.array(counts)

### Load Data & Fit

In [4]:
data_set_name = "clintox"

train, valid, test, tasks = molnet_loader(data_set_name, splitter="random")
enc = TripleSetEncoder()

for task in range(len(tasks)):
    class_weights, class_counts = get_class_weights(train.y, task)
    print(class_weights)
    print(class_counts)

    train_dataset = enc.encode(train.ids, [y[task] for y in train.y], label_dtype=torch.long)
    valid_dataset = enc.encode(valid.ids, [y[task] for y in valid.y], label_dtype=torch.long)
    test_dataset = enc.encode(test.ids, [y[task] for y in test.y], label_dtype=torch.long)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, num_workers=8, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8, drop_last=True)

    d = len(train_dataset[0][0][0])
    d2 = len(train_dataset[0][1][0])
    d3 = len(train_dataset[0][2][0])

    for _ in range(4):
        # Make sure no run is ongoing
        wandb_finish()
        
        # Setup wandb logging
        wandb_logger = wandb.WandbLogger(project=f"MolRepSet-triple-{data_set_name}")
        wandb_logger.experiment.config["task"] = tasks[task]

        # Define callback for which callpoint to load for testing
        checkpoint_callback = ModelCheckpoint(monitor="val/auroc", mode="max")

        trainer = pl.Trainer(callbacks=[checkpoint_callback], max_epochs=50, log_every_n_steps=1, logger=wandb_logger)

        model = TripleSetClassifier(32, 32, 16, 16, 16, 8, d, d2, d3, 2, class_weights=class_weights, learning_rate=0.01)
        wandb_logger.watch(model, log="all")

        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
        trainer.test(ckpt_path="best", dataloaders=test_loader)

        wandb_logger.finalize("success")
        wandb_finish()

[0.9357022 0.0642978]
[  76 1106]


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdaenuprobst[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
You are using a CUDA device ('NVIDIA GeForce RTX 4070 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 1.1 K 
1  | fc1_2          | Linear             | 1.1 K 
2  | fc1_3          | Linear             | 544   
3  | bn             | BatchNorm1d        | 64    
4  | bn_2           | BatchNorm1d        | 64    
5

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-triple-clintox/kb6fdn2n/checkpoints/epoch=38-step=702.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-triple-clintox/kb6fdn2n/checkpoints/epoch=38-step=702.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,██▇▂▃▃▂▃▃▃▃▃▄▄▅▁▃▃▃▅▅▄▄▅▄▃▅▄▆▃▅▅▆▆▆▆▆▇▆▇
train/auroc,▁▃▃▃▃▄▄▄▅▅▅▅▆▇▆▅▆▆▆▆▇▆▆▇▆▆▆▇▇▆▇▇▇█▇▇████
train/loss,█▇▇▇▇▇▇▆▆▆▆▆▅▄▅▆▆▅▅▅▄▄▄▄▄▄▄▄▃▄▃▃▂▂▂▂▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,██▁▁▆▄███▅▁▇▁█▆▁██▁██▆▇█▇▅▇▇█▂▇█████████
val/auroc,▆▁▄▇▆▆▆▄▅▇▄█▇█▇█▆▄▇▇▆█▇▆▇▇██▅▇██▇▇▇▇▇▇▇▇

0,1
epoch,50.0
test/acc,0.90625
test/auroc,0.61458
test/loss,0.54365
train/acc,0.87674
train/auroc,0.95317
train/loss,0.25347
trainer/global_step,900.0
val/acc,0.94531
val/auroc,0.66352


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666787388330704, max=1.0)…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 1.1 K 
1  | fc1_2          | Linear             | 1.1 K 
2  | fc1_3          | Linear             | 544   
3  | bn             | BatchNorm1d        | 64    
4  | bn_2           | BatchNorm1d        | 64    
5  | bn_3           | BatchNorm1d        | 32    
6  | dropout_1      | Dropout            | 0     
7  | dropout_2      | Dropout            | 0     
8  | dropout_3      | Dropout            | 0     
9  | fc2            | Linear             | 3.1 K 
10 | bn_mlp         | BatchNorm1d        | 64    
11 | fc3            | Linear             | 528   
12 | fc4          

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-triple-clintox/l10ewhgn/checkpoints/epoch=22-step=414.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-triple-clintox/l10ewhgn/checkpoints/epoch=22-step=414.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,▁▅▅▄▃▂▃▃▃▃▅▄▅▅▅▅▄▆▆▅▅▆▇▇▇▆▇▇▇█▇█▇███▇▇██
train/auroc,▁▂▂▂▄▃▄▄▅▄▅▅▆▆▅▆▆▆▆▆▇▆▇▇▇▇█▇█▇██████████
train/loss,█▇▇▇▇▇▆▇▆▆▆▆▅▅▅▅▅▄▄▅▄▄▃▃▄▃▂▃▁▂▂▁▂▁▁▂▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,███████▁████████████████████████████████
val/auroc,▆▁▂▅▄▄▃▇▃▅▇▇▇▇▇▇▇██▇▇▇▇▇▇▇▆▇▇▇▇▇▆▇▇▇▇▇▇▇

0,1
epoch,50.0
test/acc,0.9375
test/auroc,0.69062
test/loss,0.33315
train/acc,0.89149
train/auroc,0.97223
train/loss,0.19024
trainer/global_step,900.0
val/acc,0.94531
val/auroc,0.66824


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666786748337472, max=1.0)…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 1.1 K 
1  | fc1_2          | Linear             | 1.1 K 
2  | fc1_3          | Linear             | 544   
3  | bn             | BatchNorm1d        | 64    
4  | bn_2           | BatchNorm1d        | 64    
5  | bn_3           | BatchNorm1d        | 32    
6  | dropout_1      | Dropout            | 0     
7  | dropout_2      | Dropout            | 0     
8  | dropout_3      | Dropout            | 0     
9  | fc2            | Linear             | 3.1 K 
10 | bn_mlp         | BatchNorm1d        | 64    
11 | fc3            | Linear             | 528   
12 | fc4          

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.
Restoring states from the checkpoint path at ./MolRepSet-triple-clintox/rxobjth8/checkpoints/epoch=48-step=882.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at ./MolRepSet-triple-clintox/rxobjth8/checkpoints/epoch=48-step=882.ckpt


Testing: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
test/acc,▁
test/auroc,▁
test/loss,▁
train/acc,▂▁▄▄▃▃▄▄▃▅▃▄▅▄▄▅▅▅▆▄▇▆▆▆▇▇▇█▇▇▇▆▇▇█▆▇▇██
train/auroc,▁▂▂▃▃▄▄▄▅▄▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇██████▇████████
train/loss,██▇▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▅▄▃▃▃▃▂▂▂▂▂▂▁▃▂▂▂▁▂▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
val/acc,▁███▃█▅██▃█▅█▅██████████████████████████
val/auroc,█▇▃▄▂▁▂▂▂▂▃▄▄▅▅▅▂▃▂▂▂▁▃▅▁▁▁▁▁▁▂▂▂▂▂▃▃▂▄▃

0,1
epoch,50.0
test/acc,0.9375
test/auroc,0.62708
test/loss,0.22777
train/acc,0.90365
train/auroc,0.96835
train/loss,0.22949
trainer/global_step,900.0
val/acc,0.94531
val/auroc,0.44274


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666886365001119, max=1.0)…

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 1.1 K 
1  | fc1_2          | Linear             | 1.1 K 
2  | fc1_3          | Linear             | 544   
3  | bn             | BatchNorm1d        | 64    
4  | bn_2           | BatchNorm1d        | 64    
5  | bn_3           | BatchNorm1d        | 32    
6  | dropout_1      | Dropout            | 0     
7  | dropout_2      | Dropout            | 0     
8  | dropout_3      | Dropout            | 0     
9  | fc2            | Linear             | 3.1 K 
10 | bn_mlp         | BatchNorm1d        | 64    
11 | fc3            | Linear             | 528   
12 | fc4          

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]



Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Exception ignored in: <function _releaseLock at 0x7f1f09b4beb0>
Traceback (most recent call last):
  File "/home/daenu/miniconda3/envs/molsetrep/lib/python3.10/logging/__init__.py", line 228, in _releaseLock
    def _releaseLock():
KeyboardInterrupt: 


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]