In [None]:
from dataclasses import dataclass

import torch
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader

from molsetrep.encoders import RXNSetEncoder
from molsetrep.utils.datasets import get_class_weights
from molsetrep.models import LightningDualSRClassifier

from sklearn.preprocessing import LabelEncoder

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import wandb
from wandb import finish as wandb_finish


## Setup

### Data Loader

In [None]:
@dataclass
class DataSet:
    """Plain container pairing per-sample inputs with their labels.

    Instances are built positionally as ``DataSet(ids, y)``.
    """

    # Per-sample inputs; schneider_loader fills this with the "rxn" column values.
    ids: np.ndarray
    # Labels aligned with `ids`; schneider_loader stores one single-element row per sample.
    y: np.ndarray

def schneider_loader(
    name: str, preproc: bool = False, **kwargs
):
    """Load the Schneider 50k reaction table and return train/valid/test sets.

    Reads the tab-separated file ``../data/schneider50k.tsv.gz``, encodes the
    ``rxn_class`` column as integers with a ``LabelEncoder`` (stored in a new
    ``class`` column) and partitions the rows by the ``split`` column.

    Args:
        name: Not used by this loader; accepted so call sites can pass a
            data set name.
        preproc: Not used by this loader.
        **kwargs: Not used by this loader.

    Returns:
        A ``(train, valid, test)`` tuple of ``DataSet`` objects. ``ids`` holds
        the ``rxn`` values of the split and ``y`` the encoded classes as an
        integer array of shape ``(n_samples, 1)``. No separate validation set
        is built, so the test set is returned in the validation slot as well
        (the very same object).
    """
    df = pd.read_csv("../data/schneider50k.tsv.gz", sep="\t")

    df["class"] = LabelEncoder().fit_transform(df["rxn_class"])

    def _subset(split_name: str) -> DataSet:
        # Column-wise extraction: filter each split once instead of iterating
        # row by row with iterrows().
        part = df[df["split"] == split_name]
        ids = np.array(part["rxn"].tolist())
        # One single-element row per sample, matching the (n, 1) label layout.
        labels = part["class"].to_numpy(dtype=int).reshape(-1, 1)
        return DataSet(ids, labels)

    train = _subset("train")
    test = _subset("test")

    # Use the test set as the validation set too, since no validation set is provided.
    return (train, test, test)


## Train

### Load Data

In [None]:
data_set_name = "schneider"

# Loader settings shared by all three DataLoaders.
BATCH_SIZE = 64
NUM_WORKERS = 8

train, valid, test = schneider_loader(data_set_name)

enc = RXNSetEncoder()

# Class weights are computed from the training labels only.
class_weights, _ = get_class_weights(train.y.flatten())
print(class_weights)

# Labels are stored as one single-element row per sample; pass a flat list of scalars.
train_dataset = enc.encode(train.ids, [y[0] for y in train.y], label_dtype=torch.long)
valid_dataset = enc.encode(valid.ids, [y[0] for y in valid.y], label_dtype=torch.long)
test_dataset = enc.encode(test.ids, [y[0] for y in test.y], label_dtype=torch.long)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)
# Evaluation loaders keep the trailing partial batch so that validation and
# test metrics are computed over every sample rather than a truncated set.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    drop_last=False,
)

# Lengths of the first element in each of the first sample's two leading
# entries; presumably the per-element feature widths of the two sets
# (TODO confirm against RXNSetEncoder).
d = len(train_dataset[0][0][0])
d2 = len(train_dataset[0][1][0])

### Fit

In [None]:
# The third positional argument carries the two widths measured from the
# encoded training data; the fourth is the number of class weights
# (presumably the number of classes — TODO confirm against the model signature).
model = LightningDualSRClassifier(
    [32, 32],
    [32, 32],
    [d, d2],
    len(class_weights),
    n_hidden_channels=[128, 16],
    class_weights=class_weights,
)

# Keep the checkpoint with the highest validation AUROC.
checkpoint_callback = ModelCheckpoint(monitor="val/auroc", mode="max")

wandb_logger = wandb.WandbLogger(project="MSR_Schneider_RXN")
wandb_logger.watch(model, log="all")

trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    max_epochs=50,
    log_every_n_steps=1,
    logger=wandb_logger,
)

try:
    trainer.fit(
        model, train_dataloaders=train_loader, val_dataloaders=valid_loader
    )
    # Evaluate the best checkpoint (by the monitored metric), not the last epoch.
    trainer.test(ckpt_path="best", dataloaders=test_loader)
except BaseException:
    # Always close the W&B run, marking it as failed, so an error or an
    # interrupted cell does not leave a dangling run in the kernel.
    wandb_finish(exit_code=1)
    raise
else:
    wandb_logger.finalize("success")
    wandb_finish()