In [1]:
import torch
import numpy as np

from torch.nn import Parameter, Linear, BatchNorm1d, ReLU, LeakyReLU, Linear, Dropout
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torcheval.metrics import BinaryAccuracy, BinaryAUROC
from torchmetrics.regression import R2Score, MeanSquaredError, MeanAbsoluteError
from torchmetrics.classification import Accuracy, AUROC

from molsetrep.utils.torch_trainer import TorchTrainer
from molsetrep.utils.multiset_torch_trainer import MultisetTorchTrainer
from molsetrep.utils.datasets import molnet_loader
from molsetrep.utils.converters import molnet_to_pyg
from molsetrep.utils.root_mean_squared_error import RootMeanSquaredError
from molsetrep.utils.imbalanced_sampler import ImbalancedSampler
# from molsetrep.models import SetRepClassifier, SetRepRegressor, GNNDeepSetClassifier, DeepSet, DualSetRepClassifier, DualSetRepRegressor
from molsetrep.encoders import SECMQNFPEncoder, SECFPEncoder, ECFPEncoder, Mol2VecEncoder, Mol2SetEncoder

from sklearn.preprocessing import StandardScaler
from sklearn.utils.class_weight import compute_class_weight

import matplotlib.pyplot as plt

import lightning.pytorch as pl


Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


## Setup

### Lightning Module

In [5]:
class DualSetClassifier(pl.LightningModule):
    def __init__(self, n_hidden_sets, n_hidden_sets_2, n_elements, n_elements_2, d, d_2, n_classes, class_weights):
        super().__init__()
        self.n_hidden_sets = n_hidden_sets
        self.n_elements = n_elements

        self.n_hidden_sets_2 = n_hidden_sets_2
        self.n_elements_2 = n_elements_2

        self.class_weights = class_weights

        self.Wc = Parameter(torch.FloatTensor(d, n_hidden_sets * n_elements))
        self.Wc_2 = Parameter(torch.FloatTensor(d_2, n_hidden_sets_2 * n_elements_2))
        self.fc1 = Linear(n_hidden_sets, 32)
        self.fc1_2 = Linear(n_hidden_sets_2, 32)
        self.bn = BatchNorm1d(n_hidden_sets)
        self.bn_2 = BatchNorm1d(n_hidden_sets_2)
        self.dropout_1 = Dropout(0.8)
        self.dropout_2 = Dropout(0.8)
        self.fc2 = Linear(32 * 2, 32)
        self.bn_3 = BatchNorm1d(32)
        self.fc3 = Linear(32, 16)
        self.fc4 = Linear(16, n_classes)

        
        # Init weights
        self.Wc.data.normal_()
        self.Wc_2.data.normal_()

        # Metrics
        self.train_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.train_auroc = AUROC(task="multiclass", num_classes=n_classes)
        self.valid_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.valid_auroc = AUROC(task="multiclass", num_classes=n_classes)
        self.test_accuracy = Accuracy(task="multiclass", num_classes=n_classes)
        self.test_auroc = AUROC(task="multiclass", num_classes=n_classes)

    def forward(self, X, X2):
        

        # First sets (e.g. atoms)
        t = torch.matmul(X, self.Wc)
        if torch.isnan(t).any():
            print(t)
            return None
        t = torch.relu(t)
        t = t.view(t.size()[0], t.size()[1], self.n_elements, self.n_hidden_sets)
        t, _ = torch.max(t, dim=2)
        t = torch.sum(t, dim=1)
        t = self.bn(t)
        t = self.fc1(t)
        # t = self.dropout_1(t)
        t = torch.relu(t)

        # Second sets (e.g. bonds)
        t_2 = torch.matmul(X2, self.Wc_2)
        t_2 = torch.relu(t_2)
        t_2 = t_2.view(
            t_2.size()[0], t_2.size()[1], self.n_elements_2, self.n_hidden_sets_2
        )
        t_2, _ = torch.max(t_2, dim=2)
        t_2 = torch.sum(t_2, dim=1)
        t_2 = self.bn_2(t_2)
        t_2 = self.fc1_2(t_2)
        # t_2 = self.dropout_1(t_2)
        t_2 = torch.relu(t_2)

        # Concat, mlp, and softmax
        out = self.fc2(torch.cat((t, t_2), 1))
        out = self.bn_3(out)
        out = torch.relu(out)
        # out = self.dropout_1(out)
        out = self.fc3(out)
        out = torch.relu(out)
        out = self.fc4(out)
        out = F.log_softmax(out, dim=1)

        return out

    def training_step(self, batch, batch_idx):
        x, x2, y = batch
        out = self(x, x2)
        loss = F.nll_loss(out, y, weight=torch.FloatTensor(self.class_weights).to(self.device))

        # Metrics
        self.train_accuracy(out, y)
        self.train_auroc(out, y)

        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, x2, y = val_batch
        out = self.forward(x, x2)
        loss = F.nll_loss(out, y)

        # Metrics
        self.valid_accuracy(out, y)
        self.valid_auroc(out, y)

        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

    def test_step(self, val_batch, batch_idx):
        x, x2, y = val_batch
        out = self.forward(x, x2)
        loss = F.nll_loss(out, y)

        # Metrics
        self.test_accuracy(out, y)
        self.test_auroc(out, y)

        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log("test_accuracy", self.test_accuracy, prog_bar=True, on_step=False, on_epoch=True)
        self.log("test_auroc", self.test_auroc, prog_bar=True, on_step=False, on_epoch=True)

    def on_train_epoch_end(self):
        self.log("train_acc_epoch", self.train_accuracy)
        self.log("train_auroc_epoch", self.train_auroc)

        print("Train AUROC", self.train_auroc.compute())

    def on_validation_epoch_end(self):
        self.log("valid_acc_epoch", self.valid_accuracy)
        self.log("valid_auroc_epoch", self.valid_auroc)

        print("Valid AUROC", self.valid_auroc.compute())
        

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

## Train

### Load Data

In [3]:
train, valid, test = molnet_loader("bbbp", splitter="scaffold")

enc = ECFPEncoder()

class_weights = compute_class_weight("balanced", classes=np.sort(np.unique(train.y.flatten())), y=train.y.flatten())
print(class_weights)

train_dataset = enc.encode(train.ids, [y[0] for y in train.y], label_dtype=torch.long)
valid_dataset = enc.encode(valid.ids, [y[0] for y in valid.y], label_dtype=torch.long)
test_dataset = enc.encode(test.ids, [y[0] for y in test.y], label_dtype=torch.long)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False, num_workers=8)#, sampler=ImbalancedSampler(train_dataset))
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)

d = len(train_dataset[0][0][0])
d2 = len(train_dataset[0][1][0])

[2.81206897 0.60812826]


### Fit

In [6]:
trainer = pl.Trainer(max_epochs=100, log_every_n_steps=1, gradient_clip_val=0.5)
model = DualSetClassifier(16, 16, 8, 8, d, d2, 2, class_weights=class_weights)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)
trainer.test(ckpt_path="best", dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type               | Params
-------------------------------------------------------
0  | fc1            | Linear             | 544   
1  | fc1_2          | Linear             | 544   
2  | bn             | BatchNorm1d        | 32    
3  | bn_2           | BatchNorm1d        | 32    
4  | dropout_1      | Dropout            | 0     
5  | dropout_2      | Dropout            | 0     
6  | fc2            | Linear             | 2.1 K 
7  | bn_3           | BatchNorm1d        | 64    
8  | fc3            | Linear             | 528   
9  | fc4            | Linear             | 34    
10 | train_accuracy | MulticlassAccuracy | 0     
11 | train_auroc    | MulticlassAUROC    | 0     
12 | valid_accuracy | MulticlassAccuracy | 0     
13 | valid_auroc    | MulticlassAUROC    | 0     

Sanity Checking: 0it [00:00, ?it/s]

Valid AUROC tensor(0.8029, device='cuda:0')




Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.3693, device='cuda:0')
Train AUROC tensor(0.5152, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.8407, device='cuda:0')
Train AUROC tensor(0.6153, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.8753, device='cuda:0')
Train AUROC tensor(0.6589, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.8908, device='cuda:0')
Train AUROC tensor(0.6769, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.8994, device='cuda:0')
Train AUROC tensor(0.6909, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9054, device='cuda:0')
Train AUROC tensor(0.7023, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9098, device='cuda:0')
Train AUROC tensor(0.7113, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9147, device='cuda:0')
Train AUROC tensor(0.7177, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9189, device='cuda:0')
Train AUROC tensor(0.7259, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9246, device='cuda:0')
Train AUROC tensor(0.7359, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9263, device='cuda:0')
Train AUROC tensor(0.7416, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9303, device='cuda:0')
Train AUROC tensor(0.7463, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9328, device='cuda:0')
Train AUROC tensor(0.7493, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9356, device='cuda:0')
Train AUROC tensor(0.7540, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9362, device='cuda:0')
Train AUROC tensor(0.7564, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9391, device='cuda:0')
Train AUROC tensor(0.7603, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9385, device='cuda:0')
Train AUROC tensor(0.7613, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9409, device='cuda:0')
Train AUROC tensor(0.7639, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9393, device='cuda:0')
Train AUROC tensor(0.7656, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9421, device='cuda:0')
Train AUROC tensor(0.7677, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9424, device='cuda:0')
Train AUROC tensor(0.7694, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9444, device='cuda:0')
Train AUROC tensor(0.7707, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9437, device='cuda:0')
Train AUROC tensor(0.7715, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9428, device='cuda:0')
Train AUROC tensor(0.7731, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9424, device='cuda:0')
Train AUROC tensor(0.7746, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9425, device='cuda:0')
Train AUROC tensor(0.7754, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9421, device='cuda:0')
Train AUROC tensor(0.7748, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9463, device='cuda:0')
Train AUROC tensor(0.7765, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9481, device='cuda:0')
Train AUROC tensor(0.7816, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9461, device='cuda:0')
Train AUROC tensor(0.7839, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9473, device='cuda:0')
Train AUROC tensor(0.7865, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9485, device='cuda:0')
Train AUROC tensor(0.7888, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9489, device='cuda:0')
Train AUROC tensor(0.7911, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9490, device='cuda:0')
Train AUROC tensor(0.7931, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9493, device='cuda:0')
Train AUROC tensor(0.7955, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9495, device='cuda:0')
Train AUROC tensor(0.7973, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9500, device='cuda:0')
Train AUROC tensor(0.8000, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9514, device='cuda:0')
Train AUROC tensor(0.8031, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9522, device='cuda:0')
Train AUROC tensor(0.8055, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9525, device='cuda:0')
Train AUROC tensor(0.8081, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9527, device='cuda:0')
Train AUROC tensor(0.8110, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9526, device='cuda:0')
Train AUROC tensor(0.8135, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9517, device='cuda:0')
Train AUROC tensor(0.8164, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9513, device='cuda:0')
Train AUROC tensor(0.8196, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9519, device='cuda:0')
Train AUROC tensor(0.8222, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9505, device='cuda:0')
Train AUROC tensor(0.8239, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9513, device='cuda:0')
Train AUROC tensor(0.8259, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9494, device='cuda:0')
Train AUROC tensor(0.8275, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9510, device='cuda:0')
Train AUROC tensor(0.8292, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9496, device='cuda:0')
Train AUROC tensor(0.8305, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9497, device='cuda:0')
Train AUROC tensor(0.8317, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9497, device='cuda:0')
Train AUROC tensor(0.8330, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9481, device='cuda:0')
Train AUROC tensor(0.8343, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9484, device='cuda:0')
Train AUROC tensor(0.8354, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9474, device='cuda:0')
Train AUROC tensor(0.8367, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9487, device='cuda:0')
Train AUROC tensor(0.8384, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9492, device='cuda:0')
Train AUROC tensor(0.8397, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9492, device='cuda:0')
Train AUROC tensor(0.8414, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9483, device='cuda:0')
Train AUROC tensor(0.8417, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9490, device='cuda:0')
Train AUROC tensor(0.8429, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9475, device='cuda:0')
Train AUROC tensor(0.8437, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9486, device='cuda:0')
Train AUROC tensor(0.8449, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9476, device='cuda:0')
Train AUROC tensor(0.8454, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9487, device='cuda:0')
Train AUROC tensor(0.8464, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9482, device='cuda:0')
Train AUROC tensor(0.8472, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9467, device='cuda:0')
Train AUROC tensor(0.8485, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9486, device='cuda:0')
Train AUROC tensor(0.8492, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9477, device='cuda:0')
Train AUROC tensor(0.8498, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9467, device='cuda:0')
Train AUROC tensor(0.8502, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9462, device='cuda:0')
Train AUROC tensor(0.8512, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9450, device='cuda:0')
Train AUROC tensor(0.8516, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9438, device='cuda:0')
Train AUROC tensor(0.8523, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9439, device='cuda:0')
Train AUROC tensor(0.8524, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9431, device='cuda:0')
Train AUROC tensor(0.8540, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9454, device='cuda:0')
Train AUROC tensor(0.8547, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9421, device='cuda:0')
Train AUROC tensor(0.8560, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9443, device='cuda:0')
Train AUROC tensor(0.8569, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9439, device='cuda:0')
Train AUROC tensor(0.8575, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9440, device='cuda:0')
Train AUROC tensor(0.8587, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9437, device='cuda:0')
Train AUROC tensor(0.8591, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9438, device='cuda:0')
Train AUROC tensor(0.8601, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9432, device='cuda:0')
Train AUROC tensor(0.8611, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9439, device='cuda:0')
Train AUROC tensor(0.8614, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9428, device='cuda:0')
Train AUROC tensor(0.8627, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9417, device='cuda:0')
Train AUROC tensor(0.8631, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9421, device='cuda:0')
Train AUROC tensor(0.8639, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9418, device='cuda:0')
Train AUROC tensor(0.8643, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9416, device='cuda:0')
Train AUROC tensor(0.8650, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9395, device='cuda:0')
Train AUROC tensor(0.8654, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9424, device='cuda:0')
Train AUROC tensor(0.8663, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9396, device='cuda:0')
Train AUROC tensor(0.8663, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9408, device='cuda:0')
Train AUROC tensor(0.8675, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9417, device='cuda:0')
Train AUROC tensor(0.8680, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9411, device='cuda:0')
Train AUROC tensor(0.8690, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9391, device='cuda:0')
Train AUROC tensor(0.8691, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9387, device='cuda:0')
Train AUROC tensor(0.8701, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9361, device='cuda:0')
Train AUROC tensor(0.8710, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9377, device='cuda:0')
Train AUROC tensor(0.8717, device='cuda:0')


Validation: 0it [00:00, ?it/s]

Valid AUROC tensor(0.9359, device='cuda:0')
Train AUROC tensor(0.8725, device='cuda:0')


Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


Valid AUROC tensor(0.9382, device='cuda:0')
Train AUROC tensor(0.8727, device='cuda:0')


Restoring states from the checkpoint path at /home/daenu/code/molsetrep/notebooks/lightning_logs/version_2/checkpoints/epoch=99-step=2600.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/daenu/code/molsetrep/notebooks/lightning_logs/version_2/checkpoints/epoch=99-step=2600.ckpt


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.7739259600639343,
  'test_accuracy': 0.5882353186607361,
  'test_auroc': 0.6561325788497925}]