In [1]:
from collections import Counter
from networkx import get_node_attributes
import shutil
import torch
from torch.nn import Linear
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.nn import GCNConv, GraphConv, SAGEConv
import wandb

In [2]:
from rnaglib.tasks import BenchmarkLigandBindingSiteDetection, BindingSiteDetection
from rnaglib.representations import GraphRepresentation
from rnaglib.data_loading import Collater

Index file not found at /home/vmallet/.rnaglib/indexes/rnaglib-nr-1.0.0.json. Run rnaglib_index
Index file not found at /home/vmallet/.rnaglib/indexes/rnaglib-nr-1.0.0.json. Run rnaglib_index
Index file not found at /home/vmallet/.rnaglib/indexes/rnaglib-nr-1.0.0.json. Run rnaglib_index


In [3]:
shutil.rmtree('test_fri')
ta = BenchmarkLigandBindingSiteDetection(root='test_fri')
ta.dataset.add_representation(GraphRepresentation(framework='pyg'))
# get_node_attributes(ta.dataset[0]['rna'], 'nt_code')

Database was found and not overwritten
>>> Computing splits...
>>> Saving dataset.
>>> Done


In [4]:
train_ind, val_ind, test_ind = ta.split()
train_set = ta.dataset.subset(train_ind)
val_set = ta.dataset.subset(val_ind)
test_set = ta.dataset.subset(test_ind)

collater = Collater(train_set)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=collater)
val_loader = DataLoader(val_set, batch_size=2, shuffle=False, collate_fn=collater)
test_loader = DataLoader(test_set, batch_size=2, shuffle=False, collate_fn=collater)

>>> Loading splits...


In [5]:
for batch in train_loader:
    print(batch)
    graph = batch['graph']
    print(f'Batch node features shape: \t{graph.x.shape}')
    print(f'Batch edge index shape: \t{graph.edge_index.shape}')
    print(f'Batch labels shape: \t\t{graph.y.shape}')
    break

{'graph': DataBatch(x=[357, 4], edge_index=[2, 958], edge_attr=[958], y=[357, 1], batch=[357], ptr=[9]), 'rna': [<networkx.classes.digraph.DiGraph object at 0x7f22c8915b80>, <networkx.classes.digraph.DiGraph object at 0x7f22c89158b0>, <networkx.classes.digraph.DiGraph object at 0x7f22c89158e0>, <networkx.classes.digraph.DiGraph object at 0x7f231042da60>, <networkx.classes.digraph.DiGraph object at 0x7f22c8915ac0>, <networkx.classes.digraph.DiGraph object at 0x7f22cad30580>, <networkx.classes.digraph.DiGraph object at 0x7f22c9f98160>, <networkx.classes.digraph.DiGraph object at 0x7f231042d8e0>]}
Batch node features shape: 	torch.Size([357, 4])
Batch edge index shape: 	torch.Size([2, 958])
Batch labels shape: 		torch.Size([357, 1])


In [6]:
import numpy as np


# Assuming train_loader is defined and loaded with the dataset
def calculate_length_statistics(dataset):
    lengths = [data['graph'].x.shape[0] for data in dataset]

    max_length = np.max(lengths)
    min_length = np.min(lengths)
    avg_length = np.mean(lengths)
    median_length = np.median(lengths)

    return {"max_length": max_length,
            "min_length": min_length,
            "average_length": avg_length,
            "median_length": median_length}


# Example usage
stats = calculate_length_statistics(train_set)
print("Max Length:", stats["max_length"])
print("Min Length:", stats["min_length"])
print("Average Length:", stats["average_length"])
print("Median Length:", stats["median_length"])

Max Length: 414
Min Length: 19
Average Length: 53.574074074074076
Median Length: 32.5


In [7]:
def calculate_fraction_of_ones(loader):
    total_ones = 0
    total_elements = 0
    for batch in loader.dataset:
        y = batch['graph'].y
        total_ones += (y == 1).sum().item()
        total_elements += y.numel()
    fraction_of_ones = total_ones / total_elements if total_elements > 0 else 0
    return fraction_of_ones


# Example usage
fraction = calculate_fraction_of_ones(train_loader)
print("Fraction of ones:", fraction)

Fraction of ones: 0.20255789837538887


In [8]:
# Assuming train_loader is defined and loaded with the dataset
def count_unique_edge_attrs(train_loader):
    unique_edge_attrs = set()
    for batch in train_loader:
        unique_edge_attrs.update(batch['graph'].edge_attr.tolist())
    return len(unique_edge_attrs), unique_edge_attrs


# Example usage
num_unique_edge_attrs, unique_edge_attrs = count_unique_edge_attrs(train_loader)
print("Number of unique edge attributes:", num_unique_edge_attrs)
print("Unique edge attributes:", unique_edge_attrs)

Number of unique edge attributes: 20
Unique edge attributes: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}


In [9]:
def add_edge_features_to_nodes(data):
    # Assuming edge_attr is a tensor of shape [num_edges]
    row, col = data.edge_index
    edge_features = data.edge_attr

    # Add edge features to the corresponding node features
    # Here we are adding a new dimension to edge_features to match the dimensions
    data.x[row] += edge_features.view(-1, 1)
    data.x[col] += edge_features.view(-1, 1)
    return data

## Model

In [10]:
wandb.init(project="gcn-node-classification", config={
    "learning_rate": 0.0001,
    "epochs": 2000,
    "batch_size": 1
})

[34m[1mwandb[0m: Currently logged in as: [33mvincentx15[0m ([33matomiclearning[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(num_node_features, 16)
        #self.conv2 = GCNConv(16, 32) 
        #self.conv3 = GCNConv(32, 16) 
        self.conv4 = GraphConv(16, num_classes)

    def forward(self, data):
        data = add_edge_features_to_nodes(data)
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        #x = self.conv2(x, edge_index)
        #x = F.relu(x)
        #x = self.conv3(x, edge_index)
        #x = F.relu(x)
        x = self.conv4(x, edge_index)

        return F.log_softmax(x, dim=1)


num_classes = 2
model = GCN(train_set.input_dim, num_classes)

## Training

In [12]:
all_labels = []
for batch in train_loader:
    batch_labels = batch['graph'].y
    all_labels.extend(torch.flatten(batch_labels).tolist())
class_counts = Counter(all_labels)
total_samples = len(all_labels)
class_weights = {cls: total_samples / count for cls, count in class_counts.items()}
weights = torch.tensor([class_weights[i] for i in range(num_classes)])

optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss(weight=weights)

In [13]:
def train():
    model.train()
    for batch in train_loader:
        graph = batch['graph']
        graph = graph.to(device)
        optimizer.zero_grad()
        out = model(graph)
        loss = criterion(out, torch.flatten(graph.y).long())
        loss.backward()
        optimizer.step()
        wandb.log({"loss": loss.item()})


def test(loader):
    model.eval()
    correct = 0
    for batch in loader:
        graph = batch['graph']
        graph = graph.to(device)
        out = model(graph)
        pred = out.argmax(dim=1)
        correct += (pred == graph.y).sum().item()
    return correct / len(loader.dataset)

In [14]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
for epoch in range(50):
    train()
    train_acc = test(train_loader)
    val_acc = test(val_loader)
    print(f'Epoch: {epoch}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
    wandb.log({"train_acc": train_acc, "val_acc": val_acc})

Epoch: 0, Train Acc: 15110.5370, Val Acc: 914.8333
Epoch: 1, Train Acc: 16308.7222, Val Acc: 895.1667
Epoch: 2, Train Acc: 14553.5741, Val Acc: 883.0000
Epoch: 3, Train Acc: 14298.7222, Val Acc: 871.0000
Epoch: 4, Train Acc: 12514.0000, Val Acc: 872.0000
Epoch: 5, Train Acc: 11732.5556, Val Acc: 871.5000
Epoch: 6, Train Acc: 12029.8519, Val Acc: 878.6667
Epoch: 7, Train Acc: 11670.6667, Val Acc: 878.6667
Epoch: 8, Train Acc: 11933.2407, Val Acc: 878.6667
Epoch: 9, Train Acc: 12664.5926, Val Acc: 874.8333
Epoch: 10, Train Acc: 11805.6296, Val Acc: 875.3333
Epoch: 11, Train Acc: 12056.5926, Val Acc: 880.6667
Epoch: 12, Train Acc: 11490.7222, Val Acc: 880.6667
Epoch: 13, Train Acc: 12331.4259, Val Acc: 881.1667
Epoch: 14, Train Acc: 11793.7963, Val Acc: 881.1667
Epoch: 15, Train Acc: 13052.3519, Val Acc: 881.1667
Epoch: 16, Train Acc: 11532.9815, Val Acc: 877.3333
Epoch: 17, Train Acc: 13242.4630, Val Acc: 877.3333
Epoch: 18, Train Acc: 12451.1481, Val Acc: 877.3333
Epoch: 19, Train Acc: 

## Evaluation

In [15]:
test_acc = test(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')
wandb.log({"test_acc": test_acc})

Test Accuracy: 1212.2778


In [16]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def get_predictions(loader):
    model.eval()
    all_preds = []
    all_labels = []
    
    for batch in loader:
        graph = batch['graph']
        graph = graph.to(device)
        out = model(graph)
        preds = out.argmax(dim=1)
        all_preds.extend(preds.tolist())
        all_labels.extend(torch.flatten(graph.y).tolist())
    return all_preds, all_labels


def calculate_metrics(loader):
    preds, labels = get_predictions(loader)
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return accuracy, f1, auc


test_accuracy, test_f1, test_auc = calculate_metrics(test_loader)
print(f'Test Accuracy: {test_accuracy:.4f}, Test F1 Score: {test_f1:.4f}, Test AUC: {test_auc:.4f}')
wandb.log({"test_accuracy": test_accuracy, "test_f1": test_f1, "test_auc": test_auc})

Test Accuracy: 0.4763, Test F1 Score: 0.3600, Test AUC: 0.4958


In [17]:
wandb.finish()


VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▆▃▄▃▄▃▃▂▃▃▃▃▃▂▁▃▃▁▁▂▁▂▂▂▁▂▂▂▂▁▂▂▂▂▂▃▁▂▁
test_acc,▁
test_accuracy,▁
test_auc,▁
test_f1,▁
train_acc,▆█▅▅▁▂▁▂▁▂▁▂▃▁▄▂▃▄▁▂▂▃▂▃▃▂▂▂▂▃▂▂▃▂▂▁▂▂▃▁
val_acc,█▅▃▁▁▂▂▂▂▃▃▃▃▂▂▂▃▃▂▃▂▂▂▂▃▂▂▂▂▂▁▂▁▁▁▂▃▃▃▃

0,1
loss,1.06153
test_acc,1212.27778
test_accuracy,0.47627
test_auc,0.49581
test_f1,0.36
train_acc,11698.27778
val_acc,881.66667
