In [1]:
import sys

sys.path.append("/home/calvin/code/cmpnn_revised")

In [2]:
import pickle
from torch.utils.data import DataLoader
from grover_cmpnn.dataset.atom_vocab import AtomVocab
from grover_cmpnn.dataset import atom_dataset

In [3]:
with open("/home/calvin/code/cmpnn_revised/grover_cmpnn/builder/atom_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)
with open("/home/calvin/code/cmpnn_revised/grover_cmpnn/builder/token_ids.pkl", "rb") as f:
    token_seqs = pickle.load(f)

In [4]:
# Chek if there is None in the token_seqs
for i, token_seq in enumerate(token_seqs):
    if token_seq is None:
        print(f"token_seqs[{i}] is None")
        break

In [5]:
token_seqs = [seq for seq in token_seqs if seq is not None]

In [6]:
vocab = AtomVocab.load("/home/calvin/code/cmpnn_revised/grover_cmpnn/builder/atom_vocab.pkl")

In [7]:
# Check if there is None in the vocab
vocab.stoi

{'<pad>': 0,
 '<other>': 1,
 'C2sc': 2,
 'CO': 3,
 'Cd': 4,
 'Cdd': 5,
 'Cs': 6,
 'Csc': 7,
 'Ct': 8,
 'F1s': 9,
 'H0': 10,
 'N1sc': 11,
 'N3d': 12,
 'N3s': 13,
 'N3t': 14,
 'N5dc': 15,
 'N5sc': 16,
 'O0sc': 17,
 'O2d': 18,
 'O2s': 19,
 '<mask>': 20}

In [8]:
dataset = atom_dataset.MaskedAtomDataset(token_sequences=token_seqs, vocab=vocab, add_cls_token=False)

In [9]:
dataset.mask_token_id

20

In [10]:
dataset.pad_token_id

0

In [11]:
from torch import nn
import torch
from cmpnn.models.cmpnn import CMPNNEncoder

class CMPNNPretrainModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, cmpnn_config):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.encoder = CMPNNEncoder(atom_fdim=embed_dim, bond_fdim=0, **cmpnn_config)
        self.proj = nn.Linear(cmpnn_config['hidden_dim'], vocab_size)

    def forward(self, input_ids, a2b, b2a, b2revb, a_scope):
        f_atoms = self.embedding(input_ids)
        f_bonds = torch.zeros((a2b.size(0), 1), device=f_atoms.device)  # dummy
        atom_repr = self.encoder(f_atoms, f_bonds, a2b, b2a, b2revb, a_scope)
        return self.proj(atom_repr)


In [12]:
from grover_cmpnn.dataset.masked_graph import MaskedMoleculeDataset
from cmpnn.featurizer.molecule_dataset import MoleculeDataset
from cmpnn.featurizer.atom_bond import AtomFeaturizer, BondFeaturizer

In [13]:
from rdkit import Chem
molecules = Chem.SupplierFromFilename("/home/calvin/code/cmpnn_revised/molnet_data/qm9_data/raw/gdb9.sdf", sanitize=False, removeHs=False)
# Save molecules/smiles to a csv file
import pandas as pd
from tqdm import tqdm

# molecules_list = []

# for mol in tqdm(molecules):
#     if mol is None:
#         continue
#     try:
#         smiles = Chem.MolToSmiles(mol)
#         molecules_list.append(smiles)
#     except Exception as e:
#         print(f"Error: {e}")
#         continue

# df = pd.DataFrame(molecules_list, columns=["smiles"])
# df.to_csv("/home/calvin/code/cmpnn_revised/molnet_data/qm9_data/raw/gdb9_smiles.csv", index=False)

In [14]:
vocab.stoi['<mask>']

20

In [15]:
df = pd.read_csv("/home/calvin/code/cmpnn_revised/grover_cmpnn/builder/smiles.csv")

In [16]:
from collections import OrderedDict

# Keep the first occurrence only (preserves order)
unique_data = OrderedDict()
for smi, toks in zip(df["smiles"].tolist(), token_seqs):
    if smi not in unique_data:
        unique_data[smi] = toks

unique_smiles = list(unique_data.keys())
unique_token_ids = list(unique_data.values())

In [40]:
dataset = MaskedMoleculeDataset(smiles_list=unique_smiles, token_ids_list=unique_token_ids, atom_featurizer=AtomFeaturizer(v2=False), bond_featurizer=BondFeaturizer(), mask_token_id=20,
                                add_hs=True, sanitize=False, mask_prob=0.4, k_per_class=3, string_dedupe=True, canonicalize_dedupe=True, atom_messages=True)

Using all atomic numbers from 1 to 100


In [41]:
dataset[0]

MoleculeData(f_atoms=[5, 133], f_bonds=[8, 14], a2b=[5], b2a=[8], a_scope=[1], b_scope=[1], bonds=[4, 2], smiles='[H]C([H])([H])[H]', b2revb=[8], input_ids=[5], labels=[5])

In [42]:
token_seqs[13]

[10, 19, 6, 10, 10, 6, 10, 10, 10]

In [43]:
df["smiles"].tolist()[13]

'[H]OC([H])([H])C([H])([H])[H]'

In [44]:
dataset[13].labels

tensor([  10,   19,    6,   10, -100,    6,   10, -100, -100])

In [45]:
def visualize_masking(dataset, vocab, idx: int = 0):
    inverse_vocab = {v: k for k, v in vocab.stoi.items()}
    item = dataset[idx]
    input_ids = item.input_ids.tolist()
    labels = item.labels.tolist()

    print(f"{'Pos':<5} {'Input ID':<10} {'Input Token':<15} {'Label':<10} {'Label Token'}")
    print("-" * 60)
    for i, (inp, lbl) in enumerate(zip(input_ids, labels)):
        input_tok = inverse_vocab.get(inp, f"<unk:{inp}>")
        label_tok = inverse_vocab.get(lbl, "-") if lbl != -100 else "-"
        print(f"{i:<5} {inp:<10} {input_tok:<15} {lbl:<10} {label_tok}")


In [46]:
visualize_masking(dataset, vocab, idx=16)


Pos   Input ID   Input Token     Label      Label Token
------------------------------------------------------------
0     20         <mask>          10         H0
1     20         <mask>          6          Cs
2     20         <mask>          10         H0
3     20         <mask>          19         O2s
4     20         <mask>          6          Cs
5     10         H0              -100       -
6     20         <mask>          10         H0


In [47]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from sklearn.metrics import classification_report
from cmpnn.models.cmpnn import CMPNNEncoder

class MaskedCMPNNPretrainModule(pl.LightningModule):
    def __init__(self,
                 atom_fdim: int,
                 bond_fdim: int,
                 hidden_dim: int,
                 vocab_size: int,
                 dropout: float = 0.1,
                 depth: int = 3,
                 lr: float = 1e-3,
                 weight_decay: float = 0.0,
                 vocab: dict = None,
                 comm_mode: str = 'add',
                 booster: str = 'sum',
                 dynamic_depth: str = None):
        super().__init__()

        self.save_hyperparameters()

        self.encoder = CMPNNEncoder(
            atom_fdim=atom_fdim,
            bond_fdim=bond_fdim,
            atom_messages=True,
            depth=depth,
            hidden_dim=hidden_dim,
            dropout=dropout,
            comm_mode=comm_mode,
            booster=booster,
            dynamic_depth=dynamic_depth
        )
        #self.mask_vector = nn.Parameter(torch.zeros(atom_fdim), requires_grad=True)
        # Tank the accuracy of the model
        self.mask_vector = nn.Parameter(torch.randn(atom_fdim), requires_grad=False)
        self.vocab = vocab
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

        self.val_preds = []
        self.val_labels = []

    def forward(self, batch):
        f_atoms_masked = batch.f_atoms.clone()

        # Replace RDKit features at masked positions
        mask = (batch.labels != -100)
        f_atoms_masked = batch.f_atoms.clone()
        #f_atoms_masked[mask] = self.mask_vector 
        #f_atoms_masked[mask] = torch.randn_like(self.mask_vector)
        f_atoms_masked[mask] = torch.full_like(f_atoms_masked[mask], fill_value=-1)

        # if mask.sum() == 0:
        #     print("⚠️ No masked atoms in batch")

        # print("Mask shape:", mask.shape, "Masked count:", mask.sum().item())
        # print("Diff in f_atoms masked?", not torch.equal(batch.f_atoms[mask], f_atoms_masked[mask]))

        atom_repr = self.encoder(
            f_atoms=f_atoms_masked,
            f_bonds=batch.f_bonds,
            a2b=batch.a2b,
            b2a=batch.b2a,
            b2revb=batch.b2revb,
            a_scope=batch.a_scope,
            mask=mask,
        )
        return self.output_layer(atom_repr)

    def compute_loss_and_acc(self, logits, labels, input_ids):
        mask = (labels != -100) & (input_ids != -100)

        logits_masked = logits[mask]
        labels_masked = labels[mask]
        preds = logits.argmax(dim=-1)[mask]

        loss = F.cross_entropy(logits_masked, labels_masked)
        acc = (preds == labels_masked).float().mean()

        return loss, acc

    def training_step(self, batch, batch_idx):
        logits = self.forward(batch)
        loss, acc = self.compute_loss_and_acc(logits, batch.labels, batch.input_ids)

        num_masked = (batch.labels != -100).sum().item()
        total_atoms = (batch.labels != -100).numel()
        mask_rate = num_masked / total_atoms

        self.log("train_loss", loss, prog_bar=True, batch_size=batch.n_mols) 
        self.log("train_acc", acc, prog_bar=True, batch_size=batch.n_mols) 
        self.log("train_mask_rate", mask_rate, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        logits = self.forward(batch)
        labels = batch.labels

        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        mask = labels != -100

        logits_masked = logits[mask]
        labels_masked = labels[mask]
        preds = logits_masked.argmax(dim=-1)

        # Store for epoch-end
        self.val_preds.append(preds)
        self.val_labels.append(labels_masked)

        loss = F.cross_entropy(logits_masked, labels_masked)
        acc = (preds == labels_masked).float().mean()

        self.log("val_loss", loss, prog_bar=True, batch_size=batch.n_mols)
        self.log("val_acc", acc, prog_bar=True, batch_size=batch.n_mols)

        return loss

    def on_validation_epoch_end(self):
        print(f"[Epoch {self.current_epoch}] Entered on_validation_epoch_end")
        print(f"val_preds: {len(self.val_preds)}, val_labels: {len(self.val_labels)}")

        if self.val_preds and self.val_labels:
            print(f"[Epoch {self.current_epoch}] Entered on_validation_epoch_end - if loop")
            preds = torch.cat(self.val_preds)
            labels = torch.cat(self.val_labels)

            cm_figure = self.plot_confusion_matrix(labels, preds)
            print(f"[Epoch {self.current_epoch}] Entered on_validation_epoch_end - plot_confusion_matrix")
            if hasattr(self.logger, "experiment"):
                self.logger.experiment.add_figure("Confusion Matrix", cm_figure, self.current_epoch)
            else:
                print("No logger attached to experiment")
    def on_validation_epoch_start(self):
        self.val_preds = []
        self.val_labels = []

    def test_step(self, batch, batch_idx):
        logits = self.forward(batch)
        labels = batch.labels
        input_ids = batch.input_ids

        # Flatten
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)

        # Mask
        mask = (labels != -100) & (input_ids != -100)
        if mask.sum() == 0:
            return {"test_loss": torch.tensor(0.0, device=self.device), "test_acc": torch.tensor(0.0, device=self.device)}

        logits_masked = logits[mask]
        labels_masked = labels[mask]
        preds = logits_masked.argmax(dim=-1)

        loss = F.cross_entropy(logits_masked, labels_masked)
        acc = (preds == labels_masked).float().mean()

        self.log("test_loss", loss, prog_bar=True, batch_size=batch.n_mols) 
        self.log("test_acc", acc, prog_bar=True, batch_size=batch.n_mols)

        return {"test_loss": loss, "test_acc": acc}
    
    def visualize_predictions(self, logits, labels, inverse_vocab, max_rows=20):
        """
        Visualizes predictions and ground truths for masked tokens.
        """
        preds = logits.argmax(dim=-1)
        output = []

        for i, (true_id, pred_id) in enumerate(zip(labels.tolist(), preds.tolist())):
            if true_id == -100:
                continue
            true_tok = inverse_vocab.get(true_id, '<unk>')
            pred_tok = inverse_vocab.get(pred_id, '<unk>')
            output.append((i, true_tok, pred_tok))
            if len(output) >= max_rows:
                break

        # Print table
        print(f"{'Pos':>4}  {'True Token':>12}  {'Pred Token':>12}")
        print('-' * 36)
        for i, true_tok, pred_tok in output:
            print(f"{i:>4}  {true_tok:>12}  {pred_tok:>12}")

    def plot_confusion_matrix(self, labels, preds):
        import matplotlib.pyplot as plt
        import seaborn as sns
        from sklearn.metrics import confusion_matrix
        import io
        from PIL import Image
        import numpy as np
        import torch
        labels = labels.cpu().numpy()
        preds = preds.cpu().numpy()
        cm = confusion_matrix(labels, preds, labels=list(range(len(self.vocab))))

        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.vocab.itos,
                    yticklabels=self.vocab.itos,
                    ax=ax)
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_title("Confusion Matrix")
        # plt.tight_layout()
        # plt.show()
        return fig
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)


In [48]:
from grover_cmpnn.dataset.collate import mol_masked_collate

total_size = len(dataset)
print(f"Total dataset size: {total_size}")
train_size = int(0.8 * total_size)
print(f"Train dataset size: {train_size}")
val_size = int(0.1 * total_size)
print(f"Validation dataset size: {val_size}")
test_size = total_size - train_size - val_size
print(f"Test dataset size: {test_size}")
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, collate_fn=mol_masked_collate, num_workers=6)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, collate_fn=mol_masked_collate, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=mol_masked_collate)


# train_loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=mol_masked_collate)

Total dataset size: 131557
Train dataset size: 105245
Validation dataset size: 13155
Test dataset size: 13157
Train dataset size: 105245
Validation dataset size: 13155
Test dataset size: 13157


In [49]:
from rdkit import Chem
from collections import Counter

def canonicalize(smi):
    mol = Chem.MolFromSmiles(smi)
    return Chem.MolToSmiles(mol, canonical=True) if mol else None

# Suppose smiles_list is your raw list
canonical_smiles = [canonicalize(smi) for smi in df["smiles"].tolist()]
canonical_smiles = [smi for smi in canonical_smiles if smi is not None]

counter = Counter(canonical_smiles)
duplicates = {smi: count for smi, count in counter.items() if count > 1}

print(f"Found {len(duplicates)} canonical duplicates.")
for smi, count in list(duplicates.items())[:10]:
    print(f"{smi}: {count} times")


Found 263 canonical duplicates.
CC#CC: 2 times
Cc1cnccn1: 2 times
Nc1cnccn1: 2 times
Oc1cnccn1: 2 times
N/C=N/CC(=O)O: 2 times
C/[NH+]=C(\N)C(=O)[O-]: 2 times
Fc1cnccn1: 2 times
Nc1cc(=O)cc[nH]1: 3 times
Cc1cnc(N)cn1: 2 times
Cc1cnc(O)cn1: 2 times


In [50]:
all_smiles = dataset.fixed_smiles  # should already be set
print(f"Unique SMILES in full dataset: {len(set(all_smiles))}")


Unique SMILES in full dataset: 131557


In [51]:
train_smiles = set(dataset.fixed_smiles[i] for i in train_dataset.indices)
val_smiles = set(dataset.fixed_smiles[i] for i in val_dataset.indices)
test_smiles = set(dataset.fixed_smiles[i] for i in test_dataset.indices)

overlap_tv = train_smiles & val_smiles
overlap_tt = train_smiles & test_smiles
overlap_vt = val_smiles & test_smiles

print(f"Overlap Train/Val: {len(overlap_tv)}")
print(f"Overlap Train/Test: {len(overlap_tt)}")
print(f"Overlap Val/Test: {len(overlap_vt)}")


Overlap Train/Val: 0
Overlap Train/Test: 0
Overlap Val/Test: 0


In [52]:
train_dataset

<torch.utils.data.dataset.Subset at 0x7fc3828005b0>

In [53]:
dataset[22]

MoleculeData(f_atoms=[10, 133], f_bonds=[18, 14], a2b=[10], b2a=[18], a_scope=[1], b_scope=[1], bonds=[9, 2], smiles='[H]CC#CC[H]', b2revb=[18], input_ids=[10], labels=[10])

In [54]:
token_seqs[22]

[10, 6, 8, 8, 6, 10, 10, 10, 10, 10]

In [55]:
df["smiles"].tolist()[22]

'[H]CC#CC[H]'

In [56]:
batch = next(iter(train_loader))

In [57]:
batch

DataMoleculeDataBatch(smiles=[64], n_mols=64, max_num_bonds=4, f_atoms=[1202, 133], f_bonds=[2475, 14], a2b=[1202, 4], b2a=[2475], b2revb=[2475], bonds=[2, 2475], a_scope=[64], b_scope=[64], input_ids=[1202], labels=[1202])

In [58]:
len(vocab)

21

In [61]:
encoder = CMPNNEncoder(
            atom_fdim=133,
            bond_fdim=14,
            atom_messages=True,
            depth=3,
            hidden_dim=128,
            dropout=0.1,
        )

atom_hiddens = encoder(
    f_atoms=batch.f_atoms,
    f_bonds=batch.f_bonds,
    a2b=batch.a2b,
    b2a=batch.b2a,
    b2revb=batch.b2revb,
    a_scope=batch.a_scope,
)

vocab

logits = nn.Linear(128, 21)(atom_hiddens)

mask = batch.labels != -100
logits_masked = logits[mask]              # [N_masked, vocab_size]
labels_masked = batch.labels[mask]        # [N_masked]

loss = F.cross_entropy(logits_masked, labels_masked)


In [62]:
loss

tensor(2.9747, grad_fn=<NllLossBackward0>)

In [63]:
labels = batch.labels                # [n_atoms]

print("logits", logits.shape)
print("labels", labels.shape)
print("mask", mask.shape)
print("masked logits", logits[mask].shape)
print("masked labels", labels[mask].shape)


logits torch.Size([1202, 21])
labels torch.Size([1202])
mask torch.Size([1202])
masked logits torch.Size([602, 21])
masked labels torch.Size([602])


In [64]:
logits.shape

torch.Size([1202, 21])

In [65]:
loss = F.cross_entropy(logits[mask], labels[mask])

In [66]:
atom_hiddens.shape

torch.Size([1202, 128])

In [67]:
# Just once before training
train_smiles_set = set(train_dataset.dataset.smiles)
val_smiles_set = set(val_dataset.dataset.smiles)
print("Overlap in SMILES:", len(train_smiles_set & val_smiles_set))


Overlap in SMILES: 131557


In [68]:
mol1 = dataset[0]
mol2 = dataset[0]
print("Equal input_ids?", torch.equal(mol1.input_ids, mol2.input_ids))

Equal input_ids? False


In [69]:
train_dataset.dataset.smiles

('[H]C([H])([H])[H]',
 '[H]N([H])[H]',
 '[H]O[H]',
 '[H]C#C[H]',
 '[H]C#N',
 '[H]C([H])=O',
 '[H]C([H])([H])C([H])([H])[H]',
 '[H]OC([H])([H])[H]',
 '[H]C#CC([H])([H])[H]',
 '[H]C([H])([H])C#N',
 '[H]C(=O)C([H])([H])[H]',
 '[H]C(=O)N([H])[H]',
 '[H]C([H])([H])C([H])([H])C([H])([H])[H]',
 '[H]OC([H])([H])C([H])([H])[H]',
 '[H]C([H])([H])OC([H])([H])[H]',
 '[H]C1([H])C([H])([H])C1([H])[H]',
 '[H]C1([H])OC1([H])[H]',
 '[H]C([H])([H])C(=O)C([H])([H])[H]',
 '[H]N([H])C(=O)C([H])([H])[H]',
 '[H]N([H])C(=O)N([H])[H]',
 '[H]C([H])([H])C([H])(C([H])([H])[H])C([H])([H])[H]',
 '[H]OC([H])(C([H])([H])[H])C([H])([H])[H]',
 '[H]CC#CC[H]',
 '[H]CC#C[NH3+]',
 'N#CC#N',
 '[H]C#CC([H])=O',
 '[H]C(=O)C#N',
 '[H]C(=O)C([H])=O',
 '[H]C#CC([H])([H])C([H])([H])[H]',
 '[H]C([H])([H])C([H])([H])C#N',
 '[H]N([H])C([H])([H])C#N',
 '[H]C#CC([H])([H])O[H]',
 '[H]OC([H])([H])C#N',
 '[H]C(=O)C([H])([H])C([H])([H])[H]',
 '[H]C(=O)N([H])C([H])([H])[H]',
 '[H]C(=O)OC([H])([H])[H]',
 '[H]OC([H])([H])C([H])=O',
 '[H]C([H

In [70]:
from pytorch_lightning import Trainer

from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

checkpoint_callback = ModelCheckpoint(
    monitor="val_acc",
    dirpath="checkpoints/",
    filename="best-checkpoint",
    save_top_k=1,
    mode="max",
)
early_stopping_callback = EarlyStopping(
    monitor="val_acc",
    patience=5,
    verbose=True,
    mode="max",
)

trainer = Trainer(
    max_epochs=2,
    accelerator="gpu",
    callbacks=[checkpoint_callback, early_stopping_callback],
    logger=True,
    enable_checkpointing=True,
    enable_model_summary=True,
    check_val_every_n_epoch=1,
    num_sanity_val_steps=0
)
model = MaskedCMPNNPretrainModule(
    atom_fdim=133,
    bond_fdim=14,
    hidden_dim=128,
    vocab=vocab,
    vocab_size=len(vocab),
    dropout=0.1,
    depth=3,
    lr=1e-3,
    weight_decay=0.0,
    comm_mode='add',
    booster='sum',
)
trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/calvin/miniforge3/envs/dmpnn_rocm/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /home/calvin/code/cmpnn_revised/grover_cmpnn/dataset/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type         | Params | Mode 
------------------------------------------------------
0 | encoder      | CMPNNEncoder | 332 K  | train
1 | output_layer | Linear       | 2.7 K  | train
  | other params | n/a          | 133    | n/a  
------------------------------------------------------
335 K     Trainable params
133       Non-trainable params
335 K     Total params
1.342     Total estimated model params size (MB)
13        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] st

Validation: |          | 0/? [00:00<?, ?it/s]

input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] st

Metric val_acc improved. New best score: 0.984


[Epoch 0] Entered on_validation_epoch_end - plot_confusion_matrix
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_

Validation: |          | 0/? [00:00<?, ?it/s]

input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] st

Metric val_acc improved by 0.002 >= min_delta = 0.0. New best score: 0.986
`Trainer.fit` stopped: `max_epochs=2` reached.


[Epoch 1] Entered on_validation_epoch_end - plot_confusion_matrix


In [None]:
y_pred = trainer.predict(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/calvin/miniforge3/envs/dmpnn_rocm/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Predicting: |          | 0/? [00:00<?, ?it/s]

input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] st

In [None]:
y_pred.shape

AttributeError: 'list' object has no attribute 'shape'

In [71]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/calvin/miniforge3/envs/dmpnn_rocm/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] std: 0.0
input_atom[mask] mean: 0.0
input_atom[mask] st

[{'test_loss': 0.0313100591301918, 'test_acc': 0.9865939021110535}]

In [None]:
for i, b in enumerate(dataset):
    print(i)
    mol1 = dataset[i]
    mol2 = dataset[i]

    print("1st mask:", mol1.input_ids.tolist())
    print("2nd mask:", mol2.input_ids.tolist())

    if i == 10:
        break

0
1st mask: [20, 20, 20, 10, 10]
2nd mask: [20, 20, 10, 10, 20]
1
1st mask: [10, 20, 20, 20]
2nd mask: [20, 20, 10, 20]
2
1st mask: [20, 20, 20]
2nd mask: [20, 20, 20]
3
1st mask: [20, 20, 20, 20]
2nd mask: [20, 20, 20, 20]
4
1st mask: [20, 20, 20]
2nd mask: [20, 20, 20]
5
1st mask: [20, 20, 20, 20]
2nd mask: [20, 20, 20, 20]
6
1st mask: [10, 20, 10, 20, 20, 10, 20, 10]
2nd mask: [20, 20, 20, 10, 20, 10, 10, 10]
7
1st mask: [20, 20, 20, 20, 10, 10]
2nd mask: [20, 20, 20, 10, 10, 20]
8
1st mask: [20, 20, 20, 20, 10, 20, 10]
2nd mask: [10, 20, 20, 20, 10, 20, 20]
9
1st mask: [10, 20, 20, 20, 20, 20]
2nd mask: [10, 20, 20, 20, 20, 20]
10
1st mask: [10, 20, 20, 20, 20, 10, 20]
2nd mask: [20, 20, 20, 20, 20, 10, 10]


In [None]:
from collections import Counter
print(Counter(labels.tolist()))


Counter({-100: 684, 10: 127, 6: 112, 19: 56, 4: 51, 3: 26, 18: 26, 13: 26, 8: 17, 12: 16, 16: 8, 11: 8, 14: 6, 7: 3, 15: 1, 9: 1})


In [None]:
for i in range(15):
    data = dataset[i]
    print("input_ids:", data.input_ids.tolist())
    print("labels    :", data.labels.tolist())
    print()


input_ids: [20, 20, 10, 20, 10]
labels    : [10, 6, -100, 10, -100]

input_ids: [20, 20, 10, 20]
labels    : [10, 13, -100, 10]

input_ids: [20, 20, 20]
labels    : [10, 19, 10]

input_ids: [20, 20, 20, 20]
labels    : [10, 8, 8, 10]

input_ids: [20, 20, 20]
labels    : [10, 8, 14]

input_ids: [20, 20, 20, 20]
labels    : [10, 3, 10, 18]

input_ids: [10, 20, 20, 10, 20, 20, 10, 10]
labels    : [-100, 6, 10, -100, 6, 10, -100, -100]

input_ids: [10, 20, 20, 20, 10, 20]
labels    : [-100, 19, 6, 10, -100, 10]

input_ids: [10, 20, 20, 20, 20, 20, 10]
labels    : [-100, 8, 8, 6, 10, 10, -100]

input_ids: [10, 20, 20, 20, 20, 20]
labels    : [-100, 6, 10, 10, 8, 14]

input_ids: [10, 20, 20, 20, 10, 20, 20]
labels    : [-100, 3, 18, 6, -100, 10, 10]

input_ids: [20, 20, 20, 20, 10, 20]
labels    : [10, 3, 18, 13, -100, 10]

input_ids: [10, 6, 10, 10, 20, 10, 10, 20, 20, 20, 10]
labels    : [-100, -100, -100, -100, 6, -100, -100, 6, 10, 10, -100]

input_ids: [10, 20, 20, 10, 10, 20, 20, 20, 1

In [None]:
from collections import Counter
flat_tokens = [tok for ids in token_seqs for tok in ids]
print(Counter(flat_tokens).most_common(10))


[(10, 1252508), (6, 606943), (19, 128364), (4, 110493), (13, 59414), (18, 56019), (8, 55987), (3, 55882), (12, 30717), (16, 19624)]


In [None]:
for i in range(10):
    print(f"{i}: input={batch.input_ids[i].item()}  label={batch.labels[i].item()}")


0: input=-100  label=-100
1: input=10  label=-100
2: input=20  label=3
3: input=20  label=18
4: input=20  label=4
5: input=20  label=1
6: input=20  label=18
7: input=20  label=2
8: input=20  label=10
9: input=20  label=6
