In [1]:
# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# !conda install pyg -c pyg
# !conda install gemmi -c conda-forge
# !conda install biopython

In [2]:
import os
import torch
torch_version = torch.__version__.split("+")
os.environ["TORCH"] = torch_version[0]
os.environ["CUDA"] = torch_version[1] if len(torch_version) > 1 else "cu118"

In [3]:
# %%capture
# !pip install pyg-lib torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
# !pip install torch-geometric
# !pip install wandb
# !pip install plotly
# !pip install --upgrade scipy
# !wget "https://gist.githubusercontent.com/mogproject/50668d3ca60188c50e6ef3f5f3ace101/raw/e11d5ac2b83fb03c0e5a9448ee3670b9dfcd5bf9/visualize.py"

In [4]:
import wandb
use_wandb = True #@param {type:"boolean"}
wandb_project = "structural_binding_affinity_predictions_using_gnn" #@param {type:"string"}
wandb_run_name = "Test_run" #@param {type:"string"}

wandb.init(project=wandb_project, name=wandb_run_name)

[34m[1mwandb[0m: Currently logged in as: [33mmwfjord[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [5]:
import torch
import torch_geometric
import gemmi
import Bio
# General imports
import os
import json
import collections

# Data science imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import scipy.sparse as sp

# Import Weights & Biases for Experiment Tracking
import wandb

# Graph imports
import torch
from torch import Tensor
import torch.nn.functional as F

import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx

import networkx as nx
from networkx.algorithms import community

from tqdm.auto import trange
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import torch
from torch_geometric.loader import DataLoader
%run visualize.py
from visualize import GraphVisualization

# Load preprocessed datasets
train_dataset = torch.load("data/train_dataset.pt")
val_dataset = torch.load("data/val_dataset.pt")
test_dataset = torch.load("data/test_dataset.pt")

print(train_dataset)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)


data_details = {
    "num_node_features": train_dataset.num_node_features,
    "num_edge_features": train_dataset.num_edge_features,
    "num_node_labels": train_dataset.num_node_labels,
    "num_edge_labels": train_dataset.num_edge_labels
}


# Log all the details about the data to W&B.
wandb.log(data_details) #🪄🐝

    
def create_graph(graph):
    g = to_networkx(graph)
    pos = nx.spring_layout(g)
    vis = GraphVisualization(
        g, pos, node_text_position='top left', node_size=20,
    )
    fig = vis.create_figure()
    return fig

fig = create_graph(train_dataset[0])
fig.show()


# Log exploratory visualizations for each data point to W&B
table = wandb.Table(columns=["Graph", "Number of Nodes", "Number of Edges", "Label"])
for graph in train_dataset:
    fig = create_graph(graph)
    n_nodes = graph.num_nodes
    n_edges = graph.num_edges
    label = graph.y.item()

table.add_data(wandb.Html(plotly.io.to_html(fig)), n_nodes, n_edges, label)
wandb.log({"data": table})


# Log the dataset to W&B as an artifact.
dataset_artifact = wandb.Artifact(name="KLF1_K_d", type="dataset", metadata=data_details)
dataset_artifact.add_dir("./data")
wandb.log_artifact(dataset_artifact)

# End the W&B run
wandb.finish()

  train_dataset = torch.load("data/train_dataset.pt")
  val_dataset = torch.load("data/val_dataset.pt")
  test_dataset = torch.load("data/test_dataset.pt")


[Data(x=[1779, 1], edge_index=[2, 43891], y=[1], pos=[1779, 3]), Data(x=[5059, 1], edge_index=[2, 128822], y=[1], pos=[5059, 3]), Data(x=[2599, 1], edge_index=[2, 65457], y=[1], pos=[2599, 3]), Data(x=[2189, 1], edge_index=[2, 54696], y=[1], pos=[2189, 3]), Data(x=[1369, 1], edge_index=[2, 32787], y=[1], pos=[1369, 3]), Data(x=[5059, 1], edge_index=[2, 128662], y=[1], pos=[5059, 3]), Data(x=[5059, 1], edge_index=[2, 128905], y=[1], pos=[5059, 3]), Data(x=[1369, 1], edge_index=[2, 32921], y=[1], pos=[1369, 3]), Data(x=[2599, 1], edge_index=[2, 65255], y=[1], pos=[2599, 3]), Data(x=[5059, 1], edge_index=[2, 124304], y=[1], pos=[5059, 3]), Data(x=[1779, 1], edge_index=[2, 43874], y=[1], pos=[1779, 3]), Data(x=[2189, 1], edge_index=[2, 54810], y=[1], pos=[2189, 3]), Data(x=[2599, 1], edge_index=[2, 65444], y=[1], pos=[2599, 3]), Data(x=[1779, 1], edge_index=[2, 43504], y=[1], pos=[1779, 3]), Data(x=[1369, 1], edge_index=[2, 32755], y=[1], pos=[1369, 3]), Data(x=[2189, 1], edge_index=[2, 54

AttributeError: 'list' object has no attribute 'num_node_features'

In [6]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn as nn

class ProteinDNAGNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim=1):
        super(ProteinDNAGNN, self).__init__()

        # Graph convolutional layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Fully connected layer for regression
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))

        # Global pooling to get a graph-level representation
        x = global_mean_pool(x, batch)

        # Output the binding affinity prediction
        return self.fc(x)


# Initialize the model

In [7]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = train_dataset[0].num_node_features 
hidden_dim = 32  

model = ProteinDNAGNN(input_dim, hidden_dim).to(device[3])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()  # Mean Squared Error for regression


In [8]:
def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def validate():
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            total_loss += loss.item()
    return total_loss / len(val_loader)


In [None]:
epochs = 1000  # You can increase this based on your dataset size

for epoch in range(1, epochs + 1):
    train_loss = train()
    val_loss = validate()
    
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "protein_dna_gnn_model.pth")


In [None]:
def test():
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            total_loss += loss.item()
    return total_loss / len(test_loader)

test_loss = test()
print(f"Test Loss: {test_loss:.4f}")


# Make Predictions on New Data

In [None]:
def predict(data):
    model.eval()
    data = data.to(device)
    with torch.no_grad():
        out = model(data.x, data.edge_index, data.batch)
    return out.item()

# Example prediction on a single protein-DNA complex
sample_data = test_dataset[0]  # Take the first test sample
predicted_affinity = predict(sample_data)
print(f"Predicted Binding Affinity: {predicted_affinity:.4f}")
