In [1]:
import os
import sys

# Resolve the project's src/ directory relative to where the kernel started
# (presumably a sibling directory such as notebooks/ — TODO confirm). Later
# cells use paths like "../configs", "../data" and "../checkpoints", which are
# relative to this directory once we chdir into it.
SRC_DIR = os.path.abspath(os.path.join(os.getcwd(), "../src"))

# Guard the append so re-running this cell does not stack duplicate entries
# onto sys.path (re-running from inside src/ resolves to the same SRC_DIR).
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
os.chdir(SRC_DIR)

In [2]:
from lightning import seed_everything

# Global seed for reproducibility; workers=True also seeds DataLoader workers.
# The call returns the seed, so the cell echoes 42.
seed_everything(seed=42, workers=True)

Seed set to 42


42

In [3]:
from src.models import MOVVAELightning
import yaml

In [4]:
# Load the model/training configuration file (path is relative to src/).
config_path = "../configs/movae_config.yml"

# Explicit UTF-8: without it, open() uses the locale's default encoding
# (e.g. cp1252 on Windows), which can mis-decode non-ASCII YAML content.
with open(config_path, "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# safe_load returns None for an empty document; fail here with a clear message
# instead of on the first `config.get(...)` call downstream.
assert isinstance(config, dict), (
    f"expected a mapping at the top of {config_path}, got {type(config).__name__}"
)

In [5]:
import torch

# "high" lets PyTorch use faster, reduced-precision internal math (e.g. TF32)
# for float32 matrix multiplications on hardware that supports it.
MATMUL_PRECISION = "high"
torch.set_float32_matmul_precision(MATMUL_PRECISION)

In [6]:
from src.data_loader import LightningMoleDataModule

# Pre-encoded molecule array (per the filename: one-hot, no stereo, 500k
# molecules — TODO confirm against the preprocessing script).
data_path = "../data/moles_ohe_no_stereo_sv_500k.npy"
BATCH_SIZE = 2048

data_module = LightningMoleDataModule(
    data_path,
    batch_size=BATCH_SIZE,
    seed=42,
)

In [7]:
import pickle

# Index -> character vocabulary matching the encoded dataset.
# NOTE(review): pickle.load can execute arbitrary code; only load vocab files
# produced by this project's own preprocessing.
with open("../data/int_to_char_no_stereo_sv_500k.pkl", "rb") as file:
    int_to_char = pickle.load(file)

# charset_size is later taken as len(int_to_char) and presumably used as the
# one-hot width (the encoder's first Conv1d has 33 input channels for a
# 33-entry vocab), so the indices must be exactly 0..N-1 with no gaps.
assert sorted(int_to_char) == list(range(len(int_to_char))), (
    "int_to_char keys must be the contiguous indices 0..N-1"
)

In [8]:
print(f"{int_to_char}")  # one-line dump of the full index -> character vocabulary

{0: 'N', 1: 'O', 2: 'H', 3: 'S', 4: '2', 5: '-', 6: ']', 7: '1', 8: 'C', 9: '3', 10: '4', 11: 'r', 12: '[', 13: 'P', 14: '6', 15: 'l', 16: 'F', 17: '=', 18: '#', 19: 'c', 20: 'o', 21: '+', 22: 'I', 23: 'n', 24: '(', 25: 'B', 26: 's', 27: '5', 28: ')', 29: 'i', 30: '^', 31: '$', 32: '?'}


In [9]:
import copy

# Dataset-derived sizes the model needs.
seq_len = data_module.seq_length
charset_size = len(int_to_char)

# Deep-copy the "model" section so the override below does not silently
# mutate the `config` dict loaded from YAML. Index it directly: a missing
# section should fail here with KeyError('model'), not later on "args".
model_config = copy.deepcopy(config["model"])

LEARNING_RATE = 1e-3  # per-run override of the YAML learning rate
model_config["args"]["lr"] = LEARNING_RATE
model_config

{'name': 'MOVVAELightning',
 'args': {'kl_weight': 1,
  'lr': 0.001,
  'params': {'encoder_params': {'conv_layers': {'conv_1': {'out_channels': 9,
      'kernel_size': 9,
      'activation': 'Tanh',
      'batch_norm': True,
      'name': 'encoder_conv_1'},
     'conv_2': {'out_channels': 9,
      'kernel_size': 9,
      'activation': 'Tanh',
      'batch_norm': True,
      'name': 'encoder_conv_2'},
     'conv_3': {'out_channels': 10,
      'kernel_size': 11,
      'activation': 'Tanh',
      'batch_norm': True,
      'name': 'encoder_conv_3'}},
    'flatten_layers': {'name': 'encoder_flatten'},
    'dense_layers': {'dense_1': {'dimension': 436,
      'activation': 'Tanh',
      'name': 'encoder_dense_1',
      'dropout': 0.083,
      'batch_norm': True}},
    'sampling_layers': {'activation': 'Tanh', 'mean': 0.0, 'stddev': 0.01},
    'latent_dimension': 192},
   'decoder_params': {'latent_dimension': 192,
    'dense_layers': {'dense_1': {'dimension': 436,
      'activation': 'Tanh',


In [10]:
# Build the Lightning module: hyper-parameters come from the YAML "args"
# section; sequence length, vocabulary size and the vocab itself come from the
# data. A key duplicated between the two still raises TypeError.
model = MOVVAELightning(
    seq_len=seq_len,
    charset_size=charset_size,
    int_to_char=int_to_char,
    loss="ce",  # cross-entropy reconstruction loss
    ignore_character="?",  # presumably the padding symbol, excluded from the loss — TODO confirm
    **model_config["args"],
)
model.model

MOVAE(
  (encoder): MOAVEncoder(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (conv_layers): Sequential(
      (0): Conv1d(33, 9, kernel_size=(9,), stride=(1,))
      (1): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
      (3): Conv1d(9, 9, kernel_size=(9,), stride=(1,))
      (4): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Tanh()
      (6): Conv1d(9, 10, kernel_size=(11,), stride=(1,))
      (7): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): Tanh()
    )
    (dense_layers): Sequential(
      (0): Linear(in_features=460, out_features=436, bias=True)
      (1): BatchNorm1d(436, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Tanh()
      (3): Dropout(p=0.083, inplace=False)
    )
    (sampling_layers): Sequential(
      (0): Linear(in_features=436, out_features=384, bias=True)
      (1): Tanh()
    )
  )
  (deco

In [11]:
from lightning.pytorch.loggers import WandbLogger

run_name = "VAE-nbook-ce-step_scheduler15-BS_2048_v1"
run_tags = [
    "VAE",
    "ce",
    "ignore_idx",
    "batch_size-2048",
    "nbook",
    "step_scheduler15",
    "no-stereo",
    "500k",
]

# A fixed run id with resume="allow" makes W&B attach to the same run when the
# notebook is re-executed (e.g. when resuming from a checkpoint), creating it
# only if it does not exist yet.
wandb_logger = WandbLogger(
    project="MolsVAE",
    name=run_name,
    id="ovca3kfy",
    resume="allow",
    config=model_config,
    tags=run_tags,
    log_model=True,
)

# Track gradients and parameters of the model (log="all").
wandb_logger.watch(model, log="all")

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: W&B API key is configured. Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888888925108, max=1.0…

wandb: logging graph, to disable use `wandb.watch(log_graph=False)`


In [12]:
from lightning.pytorch.callbacks import EarlyStopping, LearningRateFinder, ModelCheckpoint
from src.callbacks import EarlyStoppingExt

# Validation metric used to rank checkpoints (and, optionally, to stop early).
monitor = "val/cross_entropy_recon_loss"

# Keep the 3 lowest-loss checkpoints plus last.ckpt, which the resume cell
# loads via ckpt_path="last". The metric name contains "/", so the filename
# template references it explicitly and auto_insert_metric_name is disabled.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"../checkpoints/{run_name}",
    filename=f'epoch={{epoch:02d}}-step={{step}}-loss={{{monitor}:.2f}}',
    monitor=monitor,
    mode="min",
    save_top_k=3,
    save_last=True,
    auto_insert_metric_name=False,
)

# Early stopping is currently disabled; set to True to re-enable it.
USE_EARLY_STOPPING = False

callbacks = [checkpoint_callback]
if USE_EARLY_STOPPING:
    callbacks.append(
        EarlyStoppingExt(
            monitor=monitor,
            patience=10,
            mode="min",
            reset_on_improvement=True,
        )
    )

In [13]:
from lightning import Trainer

trainer_kwargs = dict(
    # Schedule: 120 epochs, no sanity-check validation, 100% of every split.
    max_epochs=120,
    num_sanity_val_steps=0,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    # Hardware / reproducibility: a single auto-selected device, deterministic algorithms.
    accelerator="auto",
    devices=1,
    deterministic=True,
    # Logging, checkpointing and console output.
    logger=wandb_logger,
    log_every_n_steps=1,
    callbacks=callbacks,
    enable_checkpointing=True,
    enable_progress_bar=True,
    enable_model_summary=True,
)
trainer = Trainer(**trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..


In [14]:
# Train, resuming from last.ckpt in the checkpoint directory when it exists.
# With ckpt_path="last" Lightning only warns and trains from scratch if no
# last checkpoint is found, so this is safe on a fresh run and does not throw
# away progress when the notebook is re-run after an interruption.
trainer.fit(model, datamodule=data_module, ckpt_path="last")

In [15]:
trainer.fit(model, datamodule=data_module, ckpt_path="last")  # resume training from last.ckpt

C:\Users\jedra\anaconda3\envs\KE\Lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:654: Checkpoint directory D:\Side-Projects\SMILESculptor\checkpoints\VAE-nbook-ce-step_scheduler15-BS_2048_v1 exists and is not empty.
Restoring states from the checkpoint path at D:\Side-Projects\SMILESculptor\checkpoints\VAE-nbook-ce-step_scheduler15-BS_2048_v1\last.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type               | Params | Mode 
---------------------------------------------------------------------
0 | model                 | MOVAE              | 4.7 M  | train
1 | bce_loss              | BCEWithLogitsLoss  | 0      | train
2 | ce_loss               | CrossEntropyLoss   | 0      | train
3 | perfect_recon_tracker | MeanMetric         | 0      | train
4 | accuracy              | MulticlassAccuracy | 0      | train
---------------------------------------------------------------------
4.7 M     Trainable params
0         Non-trainable params
4.

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=120` reached.


In [16]:
# Evaluate on the test split. Because `model` is passed without ckpt_path, the
# current in-memory (final-epoch) weights are tested, not the best checkpoint
# kept by ModelCheckpoint; pass ckpt_path="best" to evaluate that instead.
# NOTE(review): the reported test/loss equals test/cross_entropy_recon_loss even
# though kl_weight=1 and test/kl_loss > 0 — confirm the KL term is meant to be
# excluded from test/loss.
trainer.test(model, datamodule=data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                 DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  test/binary_ce_recon_loss       1.8101484775543213
test/cross_entropy_recon_loss     0.16551443934440613
        test/kl_loss              0.06992951035499573
          test/loss               0.16551443934440613
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test/binary_ce_recon_loss': 1.8101484775543213,
  'test/cross_entropy_recon_loss': 0.16551443934440613,
  'test/kl_loss': 0.06992951035499573,
  'test/loss': 0.16551443934440613}]

In [17]:
from wandb import finish as finish_wandb_run

# Mark the W&B run as finished and upload any remaining data.
finish_wandb_run()

0,1
epoch,▁▁▂▂▂▂▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇███
test/accuracy,▁
test/binary_ce_recon_loss,▁
test/cross_entropy_recon_loss,▁
test/f1,▁
test/kl_loss,▁
test/loss,▁
test/perfect_reconstruction,▁
test/precision,▁
test/recall,▁

0,1
epoch,120.0
test/accuracy,0.9692
test/binary_ce_recon_loss,1.81015
test/cross_entropy_recon_loss,0.16551
test/f1,0.84216
test/kl_loss,0.06993
test/loss,0.16551
test/perfect_reconstruction,0.4961
test/precision,0.91668
test/recall,0.81982
