In [1]:
# Automatically reload imported modules before each cell runs, so edits to the
# project modules imported below (e.g. encoder, ir_dataset, clams) take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# DetaNet sources live on Google Drive; NOTE(review): this path is only valid
# on Colab with the drive mounted — it is added to sys.path in both branches.
detanet_dir = "/content/drive/MyDrive/Colab Notebooks/DetaNet/code"

# Detect Colab from the IPython shell's repr; mount Drive and use the project
# folder there, otherwise treat the current working directory as the project root.
if 'google.colab' in str(get_ipython()):
  from google.colab import drive

  print('Running on CoLab, mounting google drive...')
  drive.mount('/content/drive')

  base_dir = "/content/drive/MyDrive/Colab Notebooks/CLAMS"
else:
  base_dir = os.getcwd()
  print('Not running on CoLab')

src_dir = os.path.join(base_dir, "src")

# Append each import root only once so re-running this cell does not keep
# growing sys.path with duplicate entries.
for _import_root in (base_dir, src_dir, detanet_dir):
  if _import_root not in sys.path:
    sys.path.append(_import_root)

Running on CoLab, mounting google drive...
Mounted at /content/drive


In [3]:
# Install the third-party packages used by the cells below.
# NOTE(review): `dataset` is the SQL-database toolkit (it pulls in
# sqlalchemy<2.0.0, see output); if the Hugging Face library was intended this
# should be `datasets` — confirm.
!pip install dataset
!pip install transformers
!pip install rdkit
!pip install tqdm
!pip install accelerate -U
!pip install e3nn
!pip install torch_geometric
# PyG extension wheels are fetched from the index matching torch 2.2.1 + CUDA 12.1;
# update the URL if the runtime's torch/CUDA build differs.
!pip install torch-cluster==1.6.3 -f https://data.pyg.org/whl/torch-2.2.1+cu121.html
!pip install torch-scatter==2.1.2 -f https://data.pyg.org/whl/torch-2.2.1+cu121.html

Collecting dataset
  Downloading dataset-1.6.2-py2.py3-none-any.whl (18 kB)
Collecting sqlalchemy<2.0.0,>=1.3.2 (from dataset)
  Downloading SQLAlchemy-1.4.52-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting alembic>=0.6.2 (from dataset)
  Downloading alembic-1.13.1-py3-none-any.whl (233 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m233.4/233.4 kB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting banal>=1.0.1 (from dataset)
  Downloading banal-1.0.6-py2.py3-none-any.whl (6.1 kB)
Collecting Mako (from alembic>=0.6.2->dataset)
  Downloading Mako-1.3.5-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.6/78.6 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: banal, sqlalchemy, Mako, alembic, datase

In [4]:
import logging
import sys
import warnings
# NOTE(review): this silences *all* warnings (including deprecations) for the
# rest of the session.
warnings.filterwarnings('ignore')

# Configure the root logger so every module's INFO-level records are shown.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Log to stdout (not stderr) so records appear inline with the cell output.
handler = logging.StreamHandler(sys.stdout)

# Each line: timestamp, logger name, level, message.
formatter = logging.Formatter('%(asctime)-15s %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)

# Replace any existing root handlers (including one added by a previous run of
# this cell) so each record is emitted exactly once.
logger.handlers = [handler]

# Read Data

In [5]:
import json


def _make_run_config(run_name, num_classes, image_size, canonicalize):
    """Build the nested config (ViT, ViT training, image-captioning training) for one run.

    All runs share the same architecture and optimisation settings; only the
    number of classes, the input image size and the ``canonicalize`` flag differ.

    Args:
        run_name: Run identifier; used as the leaf directory name under
            ``<base_dir>/models/vit_models`` and ``<base_dir>/models/ic_models``.
        num_classes: Number of classes (functional groups) the ViT predicts.
        image_size: ``(height, width)`` of the input image fed to the ViT.
        canonicalize: Stored as ``ic_training['canonicalize']`` (later passed
            to ``IrDataset(canonicalize=...)``).

    Returns:
        A freshly built dict with keys ``'vit'``, ``'vit_training'`` and
        ``'ic_training'``; no sub-dict is shared between runs, so mutating one
        run's config cannot leak into another.
    """
    return {
        'vit': {
            'num_classes': num_classes,
            'hidden_size': 288,
            'num_hidden_layers': 9,
            'num_attention_heads': 9,
            'intermediate_size': 576,
            'num_channels': 1,
            'image_size': image_size,
            'patch_size': (6, 6),
            'hidden_dropout_prob': 0.1,
            'attention_probs_dropout_prob': 0.1,
            'batch_size': 300,
            'model_dir': os.path.join(base_dir, "models", "vit_models", run_name)
        },
        'vit_training': {
            'num_epochs': 100,
            'lr': 1e-3,
            'step_size': 1,
            'gamma': 0.975,
            'early_stopping_epochs': 5,
        },
        'ic_training': {
            'batch_size': 300,
            'model_dir': os.path.join(base_dir, "models", "ic_models", run_name),
            'num_train_epochs': 120,
            'save_total_limit': 3,
            'max_length': 30,
            'num_beams': 5,
            'early_stopping_patience': 5,
            'canonicalize': canonicalize
        },
    }


# Per-run differences only; everything else comes from _make_run_config.
model_config = {
    'run4': _make_run_config('run4', num_classes=37, image_size=(66, 66), canonicalize=True),
    'run7': _make_run_config('run7', num_classes=18, image_size=(66, 66), canonicalize=False),
    'run8': _make_run_config('run8', num_classes=18, image_size=(60, 60), canonicalize=False),
}

# Persist the configs; create the directory first so a fresh checkout works.
config_dir = os.path.join(base_dir, "configs")
os.makedirs(config_dir, exist_ok=True)
with open(os.path.join(config_dir, "model_config.json"), "w") as f_hd:
    json.dump(model_config, f_hd)

In [6]:
from ir_dataset import IrDataset
from ir_smarts import SMARTS

# Select which entry of model_config this session uses.
run = 'run8'
config = model_config[run]

# IDs passed to IrDataset(further_remove=...); presumably samples to exclude
# on top of the dataset's own filtering — confirm their meaning in ir_dataset.
FURTHER_REMOVE_IDS = [
    23, 169, 176, 303, 464, 745, 791, 3663, 8195,
    8416, 13839, 20761, 20770, 20774, 20784, 20785,
    20985, 22712, 22716, 22742, 22748, 22758, 22776,
    31672, 122842, 122845, 122852, 122855, 122857,
    124370, 124378, 124455, 124521, 125697, 125916,
    126030, 126087, 127048, 127051, 127066, 127090,
    127092, 127093, 127097, 127123,
]

# NOTE(review): the recorded log shows UV and NMR files being loaded even
# though ir_only=True — confirm this is intended inside IrDataset.load().
ds = IrDataset(
    data_list=None,
    data_path=os.path.join(base_dir, "data"),
    use_transmittance=False,
    ir_only=True,
    canonicalize=config['ic_training']['canonicalize'],
    smarts=SMARTS,
    further_remove=FURTHER_REMOVE_IDS,
)
ds.load()

2024-06-22 11:02:26,279 rdkit - INFO - Enabling RDKit 2023.09.6 jupyter extensions
2024-06-22 11:02:29,407 numexpr.utils - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-06-22 11:02:29,409 numexpr.utils - INFO - NumExpr defaulting to 8 threads.
2024-06-22 11:02:29,713 root - INFO - Loading ir feature data from /content/drive/MyDrive/Colab Notebooks/CLAMS/data/ir_features.pth...
2024-06-22 11:05:00,863 root - INFO - Loading ir label data from /content/drive/MyDrive/Colab Notebooks/CLAMS/data/ir_labels.pth...
2024-06-22 11:05:18,999 root - INFO - Loading uv feature data from /content/drive/MyDrive/Colab Notebooks/CLAMS/data/uv_features.pth...
2024-06-22 11:05:36,051 root - INFO - Loading uv label data from /content/drive/MyDrive/Colab Notebooks/CLAMS/data/uv_labels.pth...
2024-06-22 11:05:45,508 root - INFO - Loading nmr feature data from /content/drive/MyDrive/Colab Notebooks/CLAMS/data/nmr_features.pth...
2024-06-22 11:06:14

In [7]:
# Number of samples in the loaded dataset; this count is split 80/10/10 below.
len(ds)

127465

# Prepare Data for Training

In [10]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split

batch_size = config['vit']['batch_size']

# 80 % train, 10 % validation, and the remainder (absorbs rounding) for test.
n_samples = len(ds)
train_size = int(0.8 * n_samples)
val_size = int(0.1 * n_samples)
test_size = n_samples - train_size - val_size

# Seed the global torch RNG immediately before splitting so the split is reproducible.
torch.manual_seed(622)
train_dataset, val_dataset, test_dataset = random_split(
    ds, [train_size, val_size, test_size])

logging.info(f"Training set size: {len(train_dataset)}")
logging.info(f"Validation set size: {len(val_dataset)}")
logging.info(f"Testing set size: {len(test_dataset)}")

num_workers = 4
prefetch_factor = 2

# Train/val loaders use background worker processes; the test loader keeps the
# default (loading in the main process).
worker_kwargs = dict(num_workers=num_workers, prefetch_factor=prefetch_factor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **worker_kwargs)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **worker_kwargs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

2024-06-22 11:08:06,062 root - INFO - Training set size: 101972
2024-06-22 11:08:06,063 root - INFO - Validation set size: 12746
2024-06-22 11:08:06,064 root - INFO - Testing set size: 12747


# Inspect the ViT Model Architecture

In [11]:
from encoder import Encoder
import torch

# Build the encoder on CPU only to display its module tree; no training happens here.
inspect_device = torch.device("cpu")
config = model_config[run]
vit_model = Encoder(config['vit'], inspect_device)
vit_model

Encoder(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(1, 288, kernel_size=(6, 6), stride=(6, 6))
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-8): 9 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=288, out_features=288, bias=True)
              (key): Linear(in_features=288, out_features=288, bias=True)
              (value): Linear(in_features=288, out_features=288, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=288, out_features=288, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=288, out_features

# Train the ViT + MLP for Functional-Group Classification

In [12]:
import torch

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.info("Device: %s", device)

2024-06-22 11:08:22,268 root - INFO - Device: cuda


In [None]:
from encoder import Encoder


config = model_config[run]
vit_model = Encoder(config['vit'])

model_dir = config['vit']['model_dir']

# Resume from previously saved weights when available. Failure is expected on
# the first run (no checkpoint yet), so training then proceeds from the freshly
# constructed Encoder instead of aborting.
try:
    vit_model.load_weights(model_dir)
    logging.info("Model weights loaded from %s! Resuming training...", model_dir)
except Exception as ex:
    logging.warning("Could not load model weights, training from scratch: %s", ex)

vit_model.train_model(**config['vit_training'],
                      model_dir=model_dir,
                      train_loader=train_loader,
                      val_loader=val_loader,
                      device=device)

# Export the ViT backbone via Hugging Face save_pretrained into the same
# directory; presumably what create_clams_model later loads — confirm.
vit_model.vit.save_pretrained(model_dir)

2024-05-12 17:05:47,690 root - ERROR - Error loading model weights: [Errno 2] No such file or directory: '/content/drive/MyDrive/Colab Notebooks/AISpec/models/vit_models/run8/model_weights.pth'


100%|██████████| 341/341 [02:06<00:00,  2.70it/s]

Epoch 1/100, Train Loss: 0.2899025086070704





2024-05-12 17:07:59,633 root - INFO - Epoch 1/100, Validation Loss: 0.28625208260723023
2024-05-12 17:07:59,635 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 2/100, Train Loss: 0.26189831262506247





2024-05-12 17:10:13,830 root - INFO - Epoch 2/100, Validation Loss: 0.23782741613163125
2024-05-12 17:10:13,833 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:09<00:00,  2.63it/s]

Epoch 3/100, Train Loss: 0.20412713914300945





2024-05-12 17:12:29,163 root - INFO - Epoch 3/100, Validation Loss: 0.18881928223776787
2024-05-12 17:12:29,165 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 4/100, Train Loss: 0.1555533471564229





2024-05-12 17:14:42,980 root - INFO - Epoch 4/100, Validation Loss: 0.1496768934057232
2024-05-12 17:14:42,982 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 5/100, Train Loss: 0.13444840106197212





2024-05-12 17:16:56,728 root - INFO - Epoch 5/100, Validation Loss: 0.16081195618978025


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 6/100, Train Loss: 0.12491377347218562





2024-05-12 17:19:11,086 root - INFO - Epoch 6/100, Validation Loss: 0.19623891443642866


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 7/100, Train Loss: 0.11794425587557444





2024-05-12 17:21:25,078 root - INFO - Epoch 7/100, Validation Loss: 0.1205512910389487
2024-05-12 17:21:25,080 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:09<00:00,  2.63it/s]

Epoch 8/100, Train Loss: 0.10925794707710683





2024-05-12 17:23:40,610 root - INFO - Epoch 8/100, Validation Loss: 0.1078373360293546
2024-05-12 17:23:40,612 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:07<00:00,  2.67it/s]

Epoch 9/100, Train Loss: 0.10277196992623713





2024-05-12 17:25:54,139 root - INFO - Epoch 9/100, Validation Loss: 0.10334949246687698
2024-05-12 17:25:54,141 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 10/100, Train Loss: 0.09828780257640404





2024-05-12 17:28:08,444 root - INFO - Epoch 10/100, Validation Loss: 0.11226142676283206


100%|██████████| 341/341 [02:08<00:00,  2.64it/s]

Epoch 11/100, Train Loss: 0.09440113872239407





2024-05-12 17:30:23,090 root - INFO - Epoch 11/100, Validation Loss: 0.09485719369146069
2024-05-12 17:30:23,095 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 12/100, Train Loss: 0.08833534563675029





2024-05-12 17:32:37,005 root - INFO - Epoch 12/100, Validation Loss: 0.08558442971419562
2024-05-12 17:32:37,007 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 13/100, Train Loss: 0.08459591110343653





2024-05-12 17:34:51,149 root - INFO - Epoch 13/100, Validation Loss: 0.08661892465384051


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 14/100, Train Loss: 0.08056481032143732





2024-05-12 17:37:05,709 root - INFO - Epoch 14/100, Validation Loss: 0.08132459704744686
2024-05-12 17:37:05,711 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 15/100, Train Loss: 0.07661593667479176





2024-05-12 17:39:19,924 root - INFO - Epoch 15/100, Validation Loss: 0.0728769741012259
2024-05-12 17:39:19,926 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:09<00:00,  2.63it/s]

Epoch 16/100, Train Loss: 0.0713585765964996





2024-05-12 17:41:35,511 root - INFO - Epoch 16/100, Validation Loss: 0.07099249728597218
2024-05-12 17:41:35,514 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 17/100, Train Loss: 0.0677624975833126





2024-05-12 17:43:49,639 root - INFO - Epoch 17/100, Validation Loss: 0.0668135767716556
2024-05-12 17:43:49,640 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 18/100, Train Loss: 0.06397501877962997





2024-05-12 17:46:04,301 root - INFO - Epoch 18/100, Validation Loss: 0.08219477135383348


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 19/100, Train Loss: 0.06339102293196056





2024-05-12 17:48:18,230 root - INFO - Epoch 19/100, Validation Loss: 0.0651554493592184
2024-05-12 17:48:18,233 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:09<00:00,  2.63it/s]

Epoch 20/100, Train Loss: 0.05563810427533875





2024-05-12 17:50:33,852 root - INFO - Epoch 20/100, Validation Loss: 0.06063599958980787
2024-05-12 17:50:33,855 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 21/100, Train Loss: 0.05209184752380032





2024-05-12 17:52:48,386 root - INFO - Epoch 21/100, Validation Loss: 0.05571534001923535
2024-05-12 17:52:48,388 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 22/100, Train Loss: 0.049567941261589664





2024-05-12 17:55:02,969 root - INFO - Epoch 22/100, Validation Loss: 0.06561533750183357


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 23/100, Train Loss: 0.04833324796386293





2024-05-12 17:57:17,129 root - INFO - Epoch 23/100, Validation Loss: 0.05602726836668147


100%|██████████| 341/341 [02:08<00:00,  2.64it/s]

Epoch 24/100, Train Loss: 0.042751071224084516





2024-05-12 17:59:31,866 root - INFO - Epoch 24/100, Validation Loss: 0.05366129110048036
2024-05-12 17:59:31,868 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 25/100, Train Loss: 0.040403935180497993





2024-05-12 18:01:46,017 root - INFO - Epoch 25/100, Validation Loss: 0.051255984402595246
2024-05-12 18:01:46,020 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 26/100, Train Loss: 0.03717445564366373





2024-05-12 18:04:00,301 root - INFO - Epoch 26/100, Validation Loss: 0.04951017985837514
2024-05-12 18:04:00,303 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 27/100, Train Loss: 0.03491513711352782





2024-05-12 18:06:14,255 root - INFO - Epoch 27/100, Validation Loss: 0.049181003832294
2024-05-12 18:06:14,257 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 28/100, Train Loss: 0.03267519675601607





2024-05-12 18:08:28,535 root - INFO - Epoch 28/100, Validation Loss: 0.047750630275521705
2024-05-12 18:08:28,538 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:09<00:00,  2.63it/s]

Epoch 29/100, Train Loss: 0.030063558486591563





2024-05-12 18:10:44,163 root - INFO - Epoch 29/100, Validation Loss: 0.054171679850716505


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 30/100, Train Loss: 0.02809321002825965





2024-05-12 18:12:58,230 root - INFO - Epoch 30/100, Validation Loss: 0.049524489559689536


100%|██████████| 341/341 [02:08<00:00,  2.64it/s]

Epoch 31/100, Train Loss: 0.026142647414528923





2024-05-12 18:15:12,966 root - INFO - Epoch 31/100, Validation Loss: 0.05192383664444179


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 32/100, Train Loss: 0.02500926282508973





2024-05-12 18:17:27,358 root - INFO - Epoch 32/100, Validation Loss: 0.046211751700137124
2024-05-12 18:17:27,361 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 33/100, Train Loss: 0.023843943762562513





2024-05-12 18:19:41,856 root - INFO - Epoch 33/100, Validation Loss: 0.05026492320773003


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 34/100, Train Loss: 0.02310588339125426





2024-05-12 18:21:56,119 root - INFO - Epoch 34/100, Validation Loss: 0.04692121770263008


100%|██████████| 341/341 [02:08<00:00,  2.66it/s]

Epoch 35/100, Train Loss: 0.01940952537835647





2024-05-12 18:24:09,926 root - INFO - Epoch 35/100, Validation Loss: 0.045413940617308204
2024-05-12 18:24:09,932 root - INFO - Saving model weights...


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 36/100, Train Loss: 0.018164532568756237





2024-05-12 18:26:24,188 root - INFO - Epoch 36/100, Validation Loss: 0.04942832896609361


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 37/100, Train Loss: 0.01847372567760264





2024-05-12 18:28:38,852 root - INFO - Epoch 37/100, Validation Loss: 0.04872690643973514


100%|██████████| 341/341 [02:09<00:00,  2.63it/s]

Epoch 38/100, Train Loss: 0.01554073481817237





2024-05-12 18:30:54,242 root - INFO - Epoch 38/100, Validation Loss: 0.04763898480066822


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 39/100, Train Loss: 0.015064824955638271





2024-05-12 18:33:08,823 root - INFO - Epoch 39/100, Validation Loss: 0.04893451567723764


100%|██████████| 341/341 [02:08<00:00,  2.65it/s]

Epoch 40/100, Train Loss: 0.014586447902978706





2024-05-12 18:35:23,228 root - INFO - Epoch 40/100, Validation Loss: 0.04756009100999386
2024-05-12 18:35:23,230 root - INFO - Early stopping triggered.


In [14]:
# Export the ViT backbone (Hugging Face save_pretrained format) to the run's ViT
# model_dir. This repeats the last call of the training cell; presumably kept so
# the export can be redone without retraining.
vit_model.vit.save_pretrained(config['vit']['model_dir'])

# Test the Trained ViT Model

In [15]:
from encoder import Encoder
from ir_smarts import SMARTS

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info("Device: %s", device)

# A fresh Encoder is enough: test_model loads the saved weights from model_dir itself.
config = model_config[run]
vit_model = Encoder(config['vit'])

# Report class names, in SMARTS key order.
functional_groups = list(SMARTS.keys())
vit_test_ret = vit_model.test_model(
    model_dir=config['vit']['model_dir'],
    test_loader=test_loader,
    device=device,
    labels=functional_groups,
)

2024-06-22 11:09:25,629 root - INFO - Device: cuda
2024-06-22 11:09:25,826 root - INFO - Model weights loaded from /content/drive/MyDrive/Colab Notebooks/CLAMS/models/vit_models/run8! Calculating metrics...


100%|██████████| 43/43 [00:01<00:00, 23.70it/s]


2024-06-22 11:09:27,797 root - INFO - Accuracy: 0.909547
2024-06-22 11:09:27,799 root - INFO - 
Classification Report:
              precision    recall  f1-score   support

      alkane       1.00      1.00      1.00     11846
      alkene       0.98      0.96      0.97      1635
      alkyne       0.99      0.99      0.99      1629
       arene       1.00      1.00      1.00        27
  haloalkane       0.94      0.89      0.92       217
     alcohol       1.00      1.00      1.00      4128
    aldehyde       0.99      0.99      0.99      1484
      ketone       0.97      0.96      0.97      1434
       ester       0.92      0.97      0.95       410
       ether       0.96      0.98      0.97      5589
       amine       0.95      0.92      0.94      3845
       amide       0.99      0.95      0.97       875
     nitrile       0.99      0.98      0.98      1541
       imide       0.98      0.88      0.92        56
       thial       0.99      0.99      0.99      1484
      phenol    

In [None]:
# Persist the classification-report DataFrame as HDF5 (key 'report'); pandas'
# to_hdf needs the optional PyTables package.
# NOTE(review): unlike the JSON results file, this filename has no run suffix.
report_path = os.path.join(config['vit']['model_dir'], "vit_test_results.h5")
vit_test_ret['report_df'].to_hdf(report_path, key='report', mode='w')

In [None]:
# Save the remaining test results as JSON. The report DataFrame is stored
# separately (HDF5), so it is excluded here. Filtering into a new dict instead
# of deleting the key leaves vit_test_ret intact, which keeps this cell and the
# HDF5 cell re-runnable in any order.
vit_test_summary = {key: value for key, value in vit_test_ret.items() if key != 'report_df'}
with open(os.path.join(config['vit']['model_dir'], f"vit_test_results_{run}.json"), 'w') as fd:
    json.dump(vit_test_summary, fd)

# Train the CLAMS Model

In [16]:
from ir_dataset import generate_ic_dataset
import json
from transformers import AutoTokenizer

# Pre-trained SMILES BPE tokenizer from the Hugging Face Hub. The same checkpoint
# id is later passed to create_clams_model as the decoder source.
# (The variable name's "decodert" spelling is kept because later cells use it.)
pretrained_decodert_dir = "seyonec/PubChem10M_SMILES_BPE_450k"
roberta_tokenizer = AutoTokenizer.from_pretrained(pretrained_decodert_dir)

# Image-captioning datasets for train/val/test, tokenised with the configured max_length.
ic_max_length = config['ic_training']['max_length']
ic_train_set, ic_val_set, ic_test_set = (
    generate_ic_dataset(split, roberta_tokenizer, max_length=ic_max_length)
    for split in (train_dataset, val_dataset, test_dataset)
)

tokenizer_config.json:   0%|          | 0.00/62.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/165k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/101k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

In [17]:
# Pick the GPU for CLAMS training when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logging.info("Device: %s", device)

2024-06-22 11:11:04,701 root - INFO - Device: cuda


In [None]:
from clams import create_clams_model, train_clams_model
from transformers import VisionEncoderDecoderModel
import torch

# True: build a new CLAMS model from the trained ViT (config['vit']['model_dir'])
# and the pretrained SMILES decoder. False: reload a previously saved model
# from config['ic_training']['model_dir'].
create_new = True

# NOTE(review): an earlier comment here mentioned setting encoder/decoder tying
# to True; nothing in this cell does so — presumably handled in create_clams_model.
if not create_new:
    logging.info("Loading existing model")
    # NOTE(review): the reloaded model is not moved to `device` here; presumably
    # train_clams_model handles placement — confirm.
    model = VisionEncoderDecoderModel.from_pretrained(config['ic_training']['model_dir'])
else:
    logging.info("Creating new model")
    model = create_clams_model(config['ic_training'], config['vit']['model_dir'],
                               pretrained_decodert_dir, roberta_tokenizer, device=device)

train_clams_model(model, config['ic_training'], ic_train_set, ic_val_set)


2024-05-11 22:48:32,288 root - INFO - Creating new model


pytorch_model.bin:   0%|          | 0.00/336M [00:00<?, ?B/s]

Some weights of RobertaForCausalLM were not initialized from the model checkpoint at seyonec/PubChem10M_SMILES_BPE_450k and are newly initialized: ['roberta.encoder.layer.0.crossattention.output.LayerNorm.bias', 'roberta.encoder.layer.0.crossattention.output.LayerNorm.weight', 'roberta.encoder.layer.0.crossattention.output.dense.bias', 'roberta.encoder.layer.0.crossattention.output.dense.weight', 'roberta.encoder.layer.0.crossattention.self.key.bias', 'roberta.encoder.layer.0.crossattention.self.key.weight', 'roberta.encoder.layer.0.crossattention.self.query.bias', 'roberta.encoder.layer.0.crossattention.self.query.weight', 'roberta.encoder.layer.0.crossattention.self.value.bias', 'roberta.encoder.layer.0.crossattention.self.value.weight', 'roberta.encoder.layer.1.crossattention.output.LayerNorm.bias', 'roberta.encoder.layer.1.crossattention.output.LayerNorm.weight', 'roberta.encoder.layer.1.crossattention.output.dense.bias', 'roberta.encoder.layer.1.crossattention.output.dense.weight'

Epoch,Training Loss,Validation Loss
1,0.7228,0.269946
2,0.2626,0.223774
3,0.2217,0.196689
4,0.1952,0.175975
5,0.1765,0.162704
6,0.1618,0.154973
7,0.1495,0.146573
8,0.1389,0.138451
9,0.1299,0.133828
10,0.1216,0.130295


Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}
Non-default generation parameters: {'max_length': 30, 'early

In [None]:
# Save the trained CLAMS model into the current run's ic_training model_dir.
# That is the directory the training cell's "Loading existing model" branch reads
# from, so the model can be reloaded. Previously it went to a hard-coded,
# run-specific path relative to the working directory.
model.save_pretrained(config['ic_training']['model_dir'])

Non-default generation parameters: {'max_length': 30, 'early_stopping': True, 'num_beams': 5}


In [24]:
# Display the CLAMS VisionEncoderDecoderModel module tree (ViT encoder + decoder).
model

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(1, 288, kernel_size=(6, 6), stride=(6, 6))
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-8): 9 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=288, out_features=288, bias=True)
              (key): Linear(in_features=288, out_features=288, bias=True)
              (value): Linear(in_features=288, out_features=288, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=288, out_features=288, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_feat