In [1]:
!pip install torch_geometric
from transformers import AutoTokenizer, AutoModel
from torch_geometric.datasets import MoleculeNet
from tqdm import tqdm
from torch.nn import LayerNorm, Dropout
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from sklearn.metrics import roc_auc_score
import torch
from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.nn import ASAPooling
from torch.nn import LayerNorm, Linear, ReLU, Sigmoid

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m23.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model_name = "seyonec/ChemBERTa-zinc-base-v1"
model_name = "DeepChem/ChemBERTa-10M-MTR"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name).to(device)
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# bert_model = AutoModel.from_pretrained(model_name).to(device)

NUM_EPOCHS = 100
BATCH_SIZE = 32

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/17.7k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/6.96k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/8.26k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/420 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/14.0M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-10M-MTR and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
!pip install rdkit
dataset = MoleculeNet(root='data/MoleculeNet', name='HIV')
# dataset = MoleculeNet(root='data/MoleculeNet', name='BBBP')
# dataset = MoleculeNet(root='data/MoleculeNet', name='BACE')

smiles_list = []
for i in range(len(dataset)):
    mol = dataset[i].smiles
    smiles_list.append(mol)

Collecting rdkit
  Downloading rdkit-2024.9.6-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2024.9.6-cp311-cp311-manylinux_2_28_x86_64.whl (34.3 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/34.3 MB[0m [31m?[0m eta [36m-:--:--[0m

model.safetensors:   0%|          | 0.00/14.0M [00:00<?, ?B/s]

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m39.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.9.6


Downloading https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/HIV.csv
Processing...
Done!


In [4]:
class MultiModalGIN(torch.nn.Module):
    def __init__(self,
                 num_node_features:int,
                 hidden_channels:int,
                 out_channels:int,
                 bert_dim:int,
                 dropout:float=0.5,
                 use_asap:bool=False,
                 asap_ratio:float=0.5):
        super().__init__()
        # GIN layers
        self.conv1 = GINConv(torch.nn.Sequential(
            Linear(num_node_features, hidden_channels), ReLU(),
            Linear(hidden_channels, hidden_channels)))
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)

        self.conv2 = GINConv(torch.nn.Sequential(
            Linear(hidden_channels, hidden_channels), ReLU(),
            Linear(hidden_channels, hidden_channels)))
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)

        self.conv3 = GINConv(torch.nn.Sequential(
            Linear(hidden_channels, out_channels), ReLU()))
        self.bn3 = torch.nn.BatchNorm1d(out_channels)

        # ASAPooling
        self.use_asap = use_asap
        if use_asap:
            self.asap = ASAPooling(hidden_channels, ratio=asap_ratio)

        # fusion gating
        fusion_dim = out_channels * 3 + bert_dim
        self.gate = Linear(fusion_dim, fusion_dim)
        self.sigmoid = Sigmoid()

        # classification head
        self.head = torch.nn.Sequential(
            Linear(fusion_dim, 256), LayerNorm(256), ReLU(),
            Linear(256, 128), LayerNorm(128), ReLU(),
            Linear(128, 1), Sigmoid()
        )
        self.dropout = dropout

    def forward(self, x, edge_index, batch, bert_emb):
        # GIN conv layers
        x = F.dropout(F.relu(self.bn1(self.conv1(x, edge_index))), p=self.dropout, training=self.training)
        x = F.dropout(F.relu(self.bn2(self.conv2(x, edge_index))), p=self.dropout, training=self.training)

        if self.use_asap:
            x, edge_index, _, batch, _ = self.asap(x, edge_index, batch=batch)

        x = F.dropout(F.relu(self.bn3(self.conv3(x, edge_index))), p=self.dropout, training=self.training)

        # pooling layer
        x_mean = global_mean_pool(x, batch)
        x_max  = global_max_pool(x, batch)
        x_sum  = global_add_pool(x, batch)
        x_graph = torch.cat([x_mean, x_max, x_sum], dim=-1)

        # fuse graph and bert embeddings
        fused = torch.cat([x_graph, bert_emb], dim=-1)
        g = self.sigmoid(self.gate(fused))
        fused = fused * g

        # Classification
        out = self.head(fused)
        return out

In [5]:
model = MultiModalGIN(
    num_node_features=dataset.num_node_features,
    hidden_channels=128,
    out_channels=64,
    bert_dim=384,
    dropout=0.3,
    use_asap=True,
    asap_ratio=0.5
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.BCELoss()

torch.manual_seed(0)
train_dataset = dataset[:int(0.8 * len(dataset))]
val_dataset = dataset[int(0.8 * len(dataset)):int(0.9 * len(dataset))]
test_dataset = dataset[int(0.9 * len(dataset)):]

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)



In [6]:
from sklearn.metrics import accuracy_score
patience = 5
best_val_auc = 0  # best val AUC
best_val_acc = 0  # best val accuracy
patience_counter = 0  # initialize patience counter

for epoch in range(1, NUM_EPOCHS + 1):
    # Training
    model.train()
    total_loss = 0
    all_train_preds, all_train_labels = [], []  # for accuracy calculation
    for batch in tqdm(train_loader, desc=f"Epoch {epoch} [Train]"):
        batch = batch.to(device)
        # BERT embedding from SMILES strings
        smiles = batch.smiles
        encoded = tokenizer(list(smiles), return_tensors='pt', padding=True, truncation=True).to(device)
        with torch.no_grad():
            outputs = bert_model(**encoded)
            emb_bert = outputs.last_hidden_state[:, 0, :]
        preds = model(batch.x.float(), batch.edge_index, batch.batch, emb_bert)
        y = batch.y.float().view(-1, 1)

        preds_labels = (preds > 0.5).float()  # threshold at 0.5 for binary classification
        all_train_preds.append(preds_labels.cpu())
        all_train_labels.append(y.cpu())

        loss = criterion(preds, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_graphs

    avg_loss = total_loss / len(train_dataset)
    # accuracy on training set
    all_train_preds = torch.cat(all_train_preds)
    all_train_labels = torch.cat(all_train_labels)
    train_acc = accuracy_score(all_train_labels, all_train_preds)
    print(f"Epoch {epoch} -> Train Loss: {avg_loss:.4f}, Train ACC: {train_acc:.4f}")

    # Validation
    model.eval()
    all_preds, all_labels = [], []
    all_val_preds, all_val_labels = [], []  # for accuracy calculation
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            smiles = batch.smiles
            encoded = tokenizer(list(smiles), return_tensors='pt', padding=True, truncation=True).to(device)
            outputs = bert_model(**encoded)
            emb_bert = outputs.last_hidden_state[:, 0, :]
            preds = model(batch.x.float(), batch.edge_index, batch.batch, emb_bert)
            all_preds.append(preds.cpu())
            all_labels.append(batch.y.view(-1, 1).cpu())

            preds_labels = (preds > 0.5).float()
            all_val_preds.append(preds_labels.cpu())
            all_val_labels.append(batch.y.view(-1, 1).cpu())

    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    val_auc = roc_auc_score(all_labels, all_preds)

    # accuracy on validation set
    all_val_preds = torch.cat(all_val_preds)
    all_val_labels = torch.cat(all_val_labels)
    val_acc = accuracy_score(all_val_labels, all_val_preds)

    print(f"Epoch {epoch} -> Val AUC: {val_auc:.4f}, Val ACC: {val_acc:.4f}\n")

    if val_auc > best_val_auc:
        best_val_auc = val_auc
        best_val_acc = val_acc
        patience_counter = 0  # reset counter
        print("Validation AUC improved. Saving model...")
        torch.save(model.state_dict(), "best_model.pth")  # save the best model
    else:
        patience_counter += 1
        print(f"No improvement in validation AUC. Patience counter: {patience_counter}/{patience}")
        if patience_counter >= patience:
            print("Early stopping triggered. Ending training.")
            break

# Load the best model for testing
print("Loading the best model for evaluation...")
model.load_state_dict(torch.load("best_model.pth"))

# test set evaluation
model.eval()
all_preds, all_labels = [], []
all_test_preds, all_test_labels = [], []
with torch.no_grad():
    for batch in test_loader:
        batch = batch.to(device)
        smiles = batch.smiles
        encoded = tokenizer(list(smiles), return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = bert_model(**encoded)
        emb_bert = outputs.last_hidden_state[:, 0, :]
        preds = model(batch.x.float(), batch.edge_index, batch.batch, emb_bert)
        all_preds.append(preds.cpu())
        all_labels.append(batch.y.view(-1, 1).cpu())

        preds_labels = (preds > 0.5).float()
        all_test_preds.append(preds_labels.cpu())
        all_test_labels.append(batch.y.view(-1, 1).cpu())

all_preds = torch.cat(all_preds).numpy()
all_labels = torch.cat(all_labels).numpy()
test_auc = roc_auc_score(all_labels, all_preds)

# accuracy on test set
all_test_preds = torch.cat(all_test_preds)
all_test_labels = torch.cat(all_test_labels)
test_acc = accuracy_score(all_test_labels, all_test_preds)

print(f"Test AUC: {test_auc:.4f}, Test ACC: {test_acc:.4f}")

  adj = torch.sparse_csr_tensor(
Epoch 1 [Train]: 100%|██████████| 1028/1028 [00:50<00:00, 20.25it/s]


Epoch 1 -> Train Loss: 0.1284, Train ACC: 0.9691
Epoch 1 -> Val AUC: 0.7237, Val ACC: 0.9647

Validation AUC improved. Saving model...


Epoch 2 [Train]: 100%|██████████| 1028/1028 [00:27<00:00, 37.07it/s]


Epoch 2 -> Train Loss: 0.1109, Train ACC: 0.9708
Epoch 2 -> Val AUC: 0.7261, Val ACC: 0.9625

Validation AUC improved. Saving model...


Epoch 3 [Train]: 100%|██████████| 1028/1028 [00:28<00:00, 36.17it/s]


Epoch 3 -> Train Loss: 0.1045, Train ACC: 0.9726
Epoch 3 -> Val AUC: 0.7556, Val ACC: 0.9667

Validation AUC improved. Saving model...


Epoch 4 [Train]: 100%|██████████| 1028/1028 [00:28<00:00, 36.41it/s]


Epoch 4 -> Train Loss: 0.0972, Train ACC: 0.9744
Epoch 4 -> Val AUC: 0.7460, Val ACC: 0.9691

No improvement in validation AUC. Patience counter: 1/5


Epoch 5 [Train]: 100%|██████████| 1028/1028 [00:28<00:00, 36.19it/s]


Epoch 5 -> Train Loss: 0.0902, Train ACC: 0.9757
Epoch 5 -> Val AUC: 0.7310, Val ACC: 0.9638

No improvement in validation AUC. Patience counter: 2/5


Epoch 6 [Train]: 100%|██████████| 1028/1028 [00:28<00:00, 36.51it/s]


Epoch 6 -> Train Loss: 0.0849, Train ACC: 0.9766
Epoch 6 -> Val AUC: 0.7475, Val ACC: 0.9635

No improvement in validation AUC. Patience counter: 3/5


Epoch 7 [Train]: 100%|██████████| 1028/1028 [00:27<00:00, 37.30it/s]


Epoch 7 -> Train Loss: 0.0808, Train ACC: 0.9777
Epoch 7 -> Val AUC: 0.7523, Val ACC: 0.9635

No improvement in validation AUC. Patience counter: 4/5


Epoch 8 [Train]: 100%|██████████| 1028/1028 [00:28<00:00, 35.98it/s]


Epoch 8 -> Train Loss: 0.0745, Train ACC: 0.9788
Epoch 8 -> Val AUC: 0.7059, Val ACC: 0.9621

No improvement in validation AUC. Patience counter: 5/5
Early stopping triggered. Ending training.
Loading the best model for evaluation...
Test AUC: 0.7444, Test ACC: 0.9404
