In [1]:
from dataclasses import dataclass

import torch
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader

from molsetrep.encoders import RXNSetEncoder
from molsetrep.utils.datasets import get_class_weights
from molsetrep.models import LightningDualSRClassifier

from sklearn.preprocessing import LabelEncoder

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import wandb
from wandb import finish as wandb_finish


Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


## Setup

### Data Loader

In [2]:
@dataclass
class DataSet:
    """Plain container pairing sample identifiers with their labels.

    In this notebook the instances are built by ``schneider_loader``, which
    fills ``ids`` from the ``rxn`` column and ``y`` from the integer-encoded
    ``rxn_class`` column of the data file.
    """

    # One entry per sample; the loader stores the values of the ``rxn`` column here.
    ids: np.ndarray
    # Integer class labels; the loader wraps each label in its own row
    # (one single-element list per sample).
    y: np.ndarray

def schneider_loader(
    name: str,
    preproc: bool = False,
    *,
    path: str = "../data/schneider50k.tsv.gz",
    **kwargs,
):
    """Load the Schneider reaction data file and split it into train/valid/test.

    The tab-separated file must provide the columns ``rxn`` (stored as the
    sample identifier), ``rxn_class`` (the label, integer-encoded here with a
    ``LabelEncoder`` fitted on the whole file) and ``split`` (rows marked
    ``"train"`` or ``"test"``; any other value is ignored).

    Args:
        name: Not used by this loader.
            NOTE(review): presumably kept for signature parity with other
            loaders -- confirm.
        preproc: Not used by this loader.
        path: Location of the tab-separated data file. Defaults to the path
            that was previously hard-coded.
        **kwargs: Accepted and ignored.

    Returns:
        A ``(train, valid, test)`` tuple of ``DataSet`` objects. The file
        ships no validation split, so the test set is returned as the
        validation set too: ``valid`` and ``test`` are the very same object.
        Any model selection done on ``valid`` therefore sees the test data.
        ``ids`` is a 1-D array of the ``rxn`` values and ``y`` an integer
        array of shape ``(n_samples, 1)``.
    """
    df = pd.read_csv(path, sep="\t")
    df["class"] = LabelEncoder().fit_transform(df["rxn_class"])

    def _subset(split_name: str) -> DataSet:
        """Build the DataSet for every row whose ``split`` equals ``split_name``."""
        rows = df[df["split"] == split_name]
        # Pull whole columns at once instead of walking rows with ``iterrows``,
        # which materialises a Series per row and was run twice per split.
        ids = np.array(rows["rxn"].tolist())
        # One label per row, kept as a column vector (n_samples, 1).
        y = rows["class"].to_numpy(dtype=int).reshape(-1, 1)
        return DataSet(ids, y)

    train = _subset("train")
    test = _subset("test")

    # Just use the test set as the validation set, as no validation set is provided.
    return (train, test, test)


## Train

### Load Data

In [3]:
data_set_name = "schneider"

# DataLoader settings shared by all three loaders, named once so they cannot drift apart.
BATCH_SIZE = 64
NUM_WORKERS = 8

# NOTE: schneider_loader returns the test split as the validation split as well,
# so `valid` and `test` refer to the same data.
train, valid, test = schneider_loader(data_set_name)

enc = RXNSetEncoder()

# get_class_weights returns a pair; only the first element (the weights) is used.
class_weights, _ = get_class_weights(train.y.flatten())
print(class_weights)

# Labels are stored one per row in a column vector; unwrap them into a flat list.
train_dataset = enc.encode(train.ids, [y[0] for y in train.y], label_dtype=torch.long)
valid_dataset = enc.encode(valid.ids, [y[0] for y in valid.y], label_dtype=torch.long)
test_dataset = enc.encode(test.ids, [y[0] for y in test.y], label_dtype=torch.long)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)
# Evaluation loaders keep the final partial batch: with drop_last=True any
# trailing samples that do not fill a batch would silently be excluded from
# the validation and test metrics.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)

# Lengths read off the first training sample; the fit cell passes them to the
# classifier as `[d, d2]`.
# NOTE(review): presumably the per-element feature widths of the two encoded
# sets -- confirm against RXNSetEncoder.
d = len(train_dataset[0][0][0])
d2 = len(train_dataset[0][1][0])

[0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98
 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98
 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98
 0.98 0.98 0.98 0.98 0.98 0.98 0.98 0.98]


### Fit

In [4]:
# Seed Python, NumPy and PyTorch (and DataLoader workers) before the model is
# created, so weight initialisation and batch shuffling are repeatable.
pl.seed_everything(42, workers=True)

# NOTE(review): the two [8, 8] lists are model hyper-parameters whose meaning
# is defined by LightningDualSRClassifier -- confirm against its signature.
model = LightningDualSRClassifier(
    [8, 8],
    [8, 8],
    [d, d2],
    len(class_weights),
    class_weights=class_weights,
)

# Keep the checkpoint with the highest validation AUROC. The validation loader
# is built from the split schneider_loader returns as `valid`, which is the
# test split, so the "best" checkpoint is selected on test data.
checkpoint_callback = ModelCheckpoint(monitor="val/auroc", mode="max")

wandb_logger = wandb.WandbLogger(project="MSR_Schneider_RXN")
wandb_logger.watch(model, log="all")

trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=50,
    log_every_n_steps=1,
    logger=wandb_logger,
)

try:
    trainer.fit(
        model, train_dataloaders=train_loader, val_dataloaders=valid_loader
    )
    # Evaluate the best checkpoint (by val/auroc) rather than the last epoch.
    trainer.test(ckpt_path="best", dataloaders=test_loader)
except BaseException:
    # Close the wandb run as failed so it is not left open in the kernel
    # (a later run would otherwise keep logging into it), then re-raise.
    wandb_finish(exit_code=1)
    raise
else:
    wandb_logger.finalize("success")
    wandb_finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdaenuprobst[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Ti Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name           | Type                       | Params
---------------------------------------------------------------
0  | sr_classifier  | DualSRClassifier           | 133 K 
1  | criterion      | CrossEntropyLoss           | 0     
2  | criterion_eval | CrossEntropyLoss           | 0     
3  | train_accuracy | MulticlassAccuracy         |

Sanity Checking: 0it [00:00, ?it/s]



Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]