In [1]:
from rnaglib.tasks import gRNAde, BindingSiteDetection, BenchmarkLigandBindingSiteDetection, InverseFolding
from rnaglib.representations import GraphRepresentation
from rnaglib.data_loading import Collater
import torch
from torch.utils.data import DataLoader
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, SAGEConv, RGCNConv
import torch.optim as optim
import wandb
from collections import Counter
from torch.nn import BatchNorm1d, Dropout
import shutil
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, matthews_corrcoef
from pathlib import Path

In [2]:
if Path('ifchim').exists():
    shutil.rmtree('ifchim')
ta = InverseFolding(root='ifchim')

Database was found and not overwritten
>>> Computing splits...
>>> Saving dataset.
>>> Done


In [3]:
ta.dataset.add_representation(GraphRepresentation(framework = 'pyg'))

In [4]:
train_ind, val_ind, test_ind = ta.split()
train_set = ta.dataset.subset(train_ind)
val_set = ta.dataset.subset(val_ind)
test_set = ta.dataset.subset(test_ind)

>>> Loading splits...


In [5]:
collater = Collater(train_set)
train_loader = DataLoader(train_set, shuffle=False, collate_fn=collater) #batch_size=20, #shuffle needs to be true
val_loader = DataLoader(val_set, shuffle=False, collate_fn=collater) #batch_size=2, 
test_loader = DataLoader(test_set,  shuffle=False, collate_fn=collater) #batch_size=2,

In [15]:
for batch in train_loader:
    print(batch)
    graph = batch['graph']
    print(f'Batch node features shape: \t{graph.x.shape}')
    print(f'Batch edge index shape: \t{graph.edge_index.shape}')
    print(f'Batch labels shape: \t\t{graph.y.shape}')
    break

def calculate_length_statistics(loader):
    lengths = [data['graph'].x.shape[0] for data in loader.dataset]
    
    max_length = np.max(lengths)
    min_length = np.min(lengths)
    avg_length = np.mean(lengths)
    median_length = np.median(lengths)
    
    return {
        "max_length": max_length,
        "min_length": min_length,
        "average_length": avg_length,
        "median_length": median_length
    }

stats = calculate_length_statistics(train_loader)
print("Max Length:", stats["max_length"])
print("Min Length:", stats["min_length"])
print("Average Length:", stats["average_length"])
print("Median Length:", stats["median_length"])

def count_unique_edge_attrs(train_loader):
    unique_edge_attrs = set()
    
    for batch in train_loader.dataset:
        unique_edge_attrs.update(batch['graph'].edge_attr.tolist())
    
    return len(unique_edge_attrs), unique_edge_attrs

num_unique_edge_attrs, unique_edge_attrs = count_unique_edge_attrs(train_loader)
print("Number of unique edge attributes:", num_unique_edge_attrs)
print("Unique edge attributes:", unique_edge_attrs)

{'graph': DataBatch(x=[951, 1], edge_index=[2, 2714], edge_attr=[2714], y=[951, 4], batch=[951], ptr=[2]), 'rna': [<networkx.classes.digraph.DiGraph object at 0x148c892eaa20>]}
Batch node features shape: 	torch.Size([951, 1])
Batch edge index shape: 	torch.Size([2, 2714])
Batch labels shape: 		torch.Size([951, 4])
Max Length: 5811
Min Length: 20
Average Length: 284.24665856622113
Median Length: 65.0
Number of unique edge attributes: 20
Unique edge attributes: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}


# Model

In [16]:
wandb.init(project="inverse_design", config={
    "learning_rate": 0.0001,
    "epochs": 300,
    "batch_size": 1,
    "dropout_rate": 0.1,  
    "num_layers": 2, 
    "batch_norm": True 
})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mluiswyss[0m ([33mmlsb[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [17]:
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_node_features, 16)
        self.bn1 = BatchNorm1d(16)  
        self.dropout1 = Dropout(0.1) 
        self.conv2 = GCNConv(16, num_classes)
        self.bn2 = BatchNorm1d(num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout1(x)
        x = self.conv2(x, edge_index)
        x = self.bn2(x)
        x = F.relu(x) 
    
        return F.log_softmax(x, dim=1)

In [18]:
num_classes = train_loader.dataset[0]['graph'].y.size(1)
model = GCN(train_set.input_dim, num_classes)

In [19]:
for batch in train_loader:
   print(batch)
   break


{'graph': DataBatch(x=[951, 1], edge_index=[2, 2714], edge_attr=[2714], y=[951, 4], batch=[951], ptr=[2]), 'rna': [<networkx.classes.digraph.DiGraph object at 0x148c892eaa20>]}


In [20]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss()#weight=class_weights_tensor)

In [21]:
def train():
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for batch in train_loader:
        graph = batch['graph']
        graph = graph.to(device)
        optimizer.zero_grad()
        out = model(graph)

        # Convert one-hot encoded labels to class indices
        labels = graph.y.argmax(dim=1).long()
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        
        # Log the loss
        wandb.log({"train_loss": loss.item()})
        running_loss += loss.item()

        # Convert one-hot encoded predictions to class indices
        preds = out.argmax(dim=1)
        correct_predictions += (preds == labels).sum().item()
        total_predictions += labels.size(0)

    # Calculate average loss and accuracy
    avg_loss = running_loss / len(train_loader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy

def evaluate(loader):
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in loader:
            graph = batch['graph']
            graph = graph.to(device)
            out = model(graph)

            # Convert one-hot encoded labels to class indices
            labels = graph.y.argmax(dim=1).long()
            loss = criterion(out, labels)
            total_loss += loss.item()

            # Convert one-hot encoded predictions to class indices
            preds = out.argmax(dim=1)
            correct_predictions += (preds == labels).sum().item()
            total_predictions += labels.size(0)

    avg_loss = total_loss / len(loader)
    accuracy = correct_predictions / total_predictions
    return avg_loss, accuracy

# Main training loop
for epoch in range(5000):
    train_loss, train_accuracy = train()
    val_loss, val_accuracy = evaluate(val_loader)
    print(f'Epoch: {epoch}, Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}')
    wandb.log({
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "val_loss": val_loss,
        "val_accuracy": val_accuracy
    })

Epoch: 0, Train Loss: 1.4162, Train Acc: 0.2589, Val Loss: 1.7368, Val Acc: 0.2659
Epoch: 1, Train Loss: 1.4119, Train Acc: 0.2678, Val Loss: 1.5801, Val Acc: 0.2664
Epoch: 2, Train Loss: 1.4087, Train Acc: 0.2726, Val Loss: 1.5238, Val Acc: 0.2665
Epoch: 3, Train Loss: 1.4066, Train Acc: 0.2759, Val Loss: 1.5299, Val Acc: 0.2672
Epoch: 4, Train Loss: 1.4027, Train Acc: 0.2778, Val Loss: 1.5487, Val Acc: 0.2803
Epoch: 5, Train Loss: 1.4001, Train Acc: 0.2779, Val Loss: 1.4970, Val Acc: 0.2803
Epoch: 6, Train Loss: 1.3980, Train Acc: 0.2800, Val Loss: 1.4859, Val Acc: 0.2803
Epoch: 7, Train Loss: 1.3969, Train Acc: 0.2810, Val Loss: 1.5221, Val Acc: 0.2803
Epoch: 8, Train Loss: 1.3960, Train Acc: 0.2814, Val Loss: 1.4868, Val Acc: 0.2804
Epoch: 9, Train Loss: 1.3943, Train Acc: 0.2808, Val Loss: 1.4818, Val Acc: 0.2812
Epoch: 10, Train Loss: 1.3938, Train Acc: 0.2801, Val Loss: 1.4878, Val Acc: 0.2809
Epoch: 11, Train Loss: 1.3924, Train Acc: 0.2796, Val Loss: 1.4638, Val Acc: 0.2819
Ep