In [1]:
import os
import sys

# Make the project's source directory importable, then work from inside it so
# that the relative "../configs", "../data" and "../checkpoints" paths used
# further down resolve against that directory.
src_dir = os.path.abspath(os.path.join(os.getcwd(), "../src"))
sys.path.append(src_dir)
os.chdir("../src")

In [2]:
from lightning import seed_everything
# Seed every RNG Lightning knows about; workers=True asks Lightning to seed
# dataloader workers as well. Kept as the last expression so the seed is echoed.
seed_everything(seed=42, workers=True)

Seed set to 42


42

In [3]:
from src.models import MOVVAELightning
import yaml

In [4]:
# Load the experiment configuration (YAML) into a plain Python mapping.
# The encoding is pinned to UTF-8 so the file parses identically on every
# platform instead of depending on the OS locale default (e.g. cp1252 on Windows).
config_path = "../configs/movae_config.yml"
with open(config_path, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

In [5]:
import torch

# Opt in to the faster float32 matmul kernels; per the PyTorch docs "high"
# trades a little precision for speed where the hardware supports it.
torch.set_float32_matmul_precision(precision="high")

In [6]:
from src.data_loader import LightningMoleDataModule

# Dataset file consumed by the data module (presumably the one-hot encoded
# molecules, judging by the file name — confirm against the data pipeline).
data_path = "../data/moles_ohe_no_stereo_sv_500k.npy"

# Batch size and split seed are fixed here for this experiment.
data_module = LightningMoleDataModule(
    data_path,
    batch_size=512,
    seed=42,
)

In [7]:
import pickle
# Index -> character vocabulary that accompanies the dataset.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load artifacts produced by this project.
vocab_path = "../data/int_to_char_no_stereo_sv_500k.pkl"
with open(vocab_path, "rb") as file:
    int_to_char = pickle.load(file)

In [8]:
# Inspect the index -> character vocabulary loaded from the pickle.
print(int_to_char)

{0: 'N', 1: 'O', 2: 'H', 3: 'S', 4: '2', 5: '-', 6: ']', 7: '1', 8: 'C', 9: '3', 10: '4', 11: 'r', 12: '[', 13: 'P', 14: '6', 15: 'l', 16: 'F', 17: '=', 18: '#', 19: 'c', 20: 'o', 21: '+', 22: 'I', 23: 'n', 24: '(', 25: 'B', 26: 's', 27: '5', 28: ')', 29: 'i', 30: '^', 31: '$', 32: '?'}


In [12]:
# Dimensions the network needs: vocabulary size from the mapping and the
# sequence length exposed by the data module.
charset_size = len(int_to_char)
seq_len = data_module.seq_length

# Take the "model" section of the YAML config and override the learning rate
# for this run (the loaded config dict is modified in place).
model_config = config.get("model", {})
model_config["args"].update(lr=1e-3)

In [13]:
# Arguments fixed by this notebook (data shape, vocabulary, loss) are kept
# apart from the hyper-parameters that come from the YAML config.
notebook_args = {
    "seq_len": seq_len,
    "charset_size": charset_size,
    "int_to_char": int_to_char,
    "loss": "ce",
}
model = MOVVAELightning(**model_config["args"], **notebook_args)

# Show the wrapped network's module tree as the cell output.
model.model

MOVAE(
  (encoder): MOAVEncoder(
    (softplus): Softplus(beta=1.0, threshold=20.0)
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (conv_layers): Sequential(
      (0): Conv1d(33, 9, kernel_size=(9,), stride=(1,))
      (1): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
      (3): Conv1d(9, 9, kernel_size=(9,), stride=(1,))
      (4): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Tanh()
      (6): Conv1d(9, 10, kernel_size=(11,), stride=(1,))
      (7): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): Tanh()
    )
    (dense_layers): Sequential(
      (0): Linear(in_features=460, out_features=436, bias=True)
      (1): BatchNorm1d(436, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
      (3): Dropout(p=0.083, inplace=False)
    )
  )
  (decoder): MOVAEDecoder(
    (dense_layers): Sequential(
      (0): Linear(in

In [14]:
from lightning.pytorch.loggers import WandbLogger

# Weights & Biases metadata for this experiment; run_name is also reused for
# the checkpoint directory further down.
run_name = "MAE-notebook-ce-step_scheduler"
run_tags = ["ce", "step-scheduler", "batch-size-512", "no-stereo", "500k"]

wandb_logger = WandbLogger(
    name=run_name,
    project="MolsVAE",
    tags=run_tags,
    config=model_config,
    log_model=True,
)

# Ask W&B to track the model (log="all" requests both gradients and parameters).
wandb_logger.watch(model, log="all")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: jedrasowicz. Use `wandb login --relogin` to force relogin


wandb: logging graph, to disable use `wandb.watch(log_graph=False)`


In [15]:
from lightning.pytorch.callbacks import EarlyStopping, LearningRateFinder, ModelCheckpoint

# Validation metric that decides which checkpoints are kept.
monitor = "val/cross_entropy_recon_loss"

# Keep the three checkpoints with the lowest monitored loss plus the latest one.
# auto_insert_metric_name is disabled and the file name template spells out its
# own "key=" labels, which avoids the "/" in the metric name leaking into names.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"../checkpoints/{run_name}",
    filename=f'epoch={{epoch:02d}}-step={{step}}-loss={{{monitor}:.2f}}',
    monitor=monitor,
    mode="min",
    save_top_k=3,
    save_last=True,
    auto_insert_metric_name=False,
)
callbacks = [checkpoint_callback]

# Early stopping is currently disabled; kept for reference.
# callbacks.append(
#     EarlyStopping(
#         monitor=monitor,
#         patience=10,
#         mode="min",
#     )
# )

In [16]:
from lightning import Trainer

trainer = Trainer(
    # --- hardware / reproducibility ---
    accelerator="auto",
    devices=1,
    deterministic=True,
    # --- schedule: full dataset every epoch, no sanity validation pass ---
    max_epochs=100,
    num_sanity_val_steps=0,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    # --- logging and checkpointing ---
    logger=wandb_logger,
    log_every_n_steps=1,
    callbacks=callbacks,
    enable_checkpointing=True,
    enable_progress_bar=True,
    enable_model_summary=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..


In [17]:
# Train from scratch; the data module supplies the train/validation loaders.
trainer.fit(model=model, datamodule=data_module)

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [15]:
# Evaluate on the test split using the best checkpoint of the previous fit.
# The data module is passed through its dedicated `datamodule` parameter
# (rather than relying on Lightning re-routing it from `dataloaders`), and the
# checkpoint choice is explicit instead of the implicit "best" fallback that
# Lightning applies, with a warning, when no model and no ckpt_path are given.
trainer.test(datamodule=data_module, ckpt_path="best")

Restoring states from the checkpoint path at D:\Side-Projects\SMILESculptor\checkpoints\MAE-notebook-bce-step_scheduler\3-3124.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at D:\Side-Projects\SMILESculptor\checkpoints\MAE-notebook-bce-step_scheduler\3-3124.ckpt


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                 DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  test/binary_ce_recon_loss      0.037198010832071304
test/cross_entropy_recon_loss     0.7965277433395386
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/binary_ce_recon_loss': 0.037198010832071304,
  'test/cross_entropy_recon_loss': 0.7965277433395386}]

In [16]:
# Checkpoint selection for the resumed run is driven by the binary
# cross-entropy validation metric.
# NOTE(review): run_name (and thus dirpath) refers to the "ce" experiment while
# this monitor is the binary CE loss — confirm this combination is intended.
monitor = "val/binary_ce_recon_loss"

# No file name template is given, so Lightning's default naming applies.
best_and_last_ckpt = ModelCheckpoint(
    dirpath=f"../checkpoints/{run_name}",
    monitor=monitor,
    mode="min",
    save_top_k=3,
    save_last=True,
    auto_insert_metric_name=False,
)
callbacks = [best_and_last_ckpt]

In [17]:
# Fresh Trainer for the resumed run; same settings as before, but wired to the
# callbacks list defined just above.
trainer = Trainer(
    # where and how to run
    accelerator="auto",
    devices=1,
    deterministic=True,
    # how long and on how much data
    max_epochs=100,
    num_sanity_val_steps=0,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    # what to record
    logger=wandb_logger,
    log_every_n_steps=1,
    callbacks=callbacks,
    enable_checkpointing=True,
    enable_progress_bar=True,
    enable_model_summary=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..


In [18]:
# Continue training from the most recent checkpoint ("last" is resolved by Lightning).
trainer.fit(model=model, datamodule=data_module, ckpt_path="last")

C:\Users\jedra\anaconda3\envs\KE\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:654: Checkpoint directory D:\Side-Projects\SMILESculptor\checkpoints\MAE-notebook-bce-step_scheduler exists and is not empty.
Restoring states from the checkpoint path at D:\Side-Projects\SMILESculptor\checkpoints\MAE-notebook-bce-step_scheduler\last.ckpt
C:\Users\jedra\anaconda3\envs\KE\Lib\site-packages\lightning\pytorch\trainer\call.py:277: Be aware that when using `ckpt_path`, callbacks used to create the checkpoint need to be provided during `Trainer` instantiation. Please add the following callbacks: ["EarlyStopping{'monitor': 'val/binary_ce_recon_loss', 'mode': 'min'}"].
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                  | Type                | Params | Mode 
-----------------------------------------------------------------------
0  | model                 | MOVAE               | 4.6 M  | train
1  | bce_loss              | BCEWithLogitsLoss   | 0      | train
2  

Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [19]:
import wandb
# Mark the active W&B run as finished; the run history/summary is printed as the cell output.
wandb.finish()

0,1
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇█
test/accuracy,▁
test/binarized_accuracy,▁
test/binarized_f1,▁
test/binarized_precision,▁
test/binarized_recall,▁
test/binary_ce_recon_loss,▁
test/cross_entropy_recon_loss,▁
test/f1,▁
test/perfect_reconstruction,▁

0,1
epoch,80.0
test/accuracy,0.73231
test/binarized_accuracy,0.98513
test/binarized_f1,0.70837
test/binarized_precision,0.87295
test/binarized_recall,0.596
test/binary_ce_recon_loss,0.0372
test/cross_entropy_recon_loss,0.79653
test/f1,0.73231
test/perfect_reconstruction,0.0
