In [57]:
from rnaglib.tasks import BenchmarkLigandBindingSiteDetection, BindingSiteDetection
from rnaglib.representations import GraphRepresentation
import torch
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, SAGEConv
import torch.optim as optim
import wandb
from collections import Counter
from torch.nn import Linear



In [2]:
import shutil
shutil.rmtree('test_fri')


In [3]:
ta = BenchmarkLigandBindingSiteDetection('test_fri')

Dataset was found and not overwritten
>>> Computing splits...
>>> Saving dataset.
saving


In [4]:
train_ind, val_ind, test_ind = ta.splitter(ta.dataset)

In [5]:
ta.dataset.add_representation(GraphRepresentation(framework = 'pyg'))

In [7]:
from networkx import get_node_attributes
get_node_attributes(ta.dataset[0]['rna'], 'nt_code')

{'3skt.A.23': 'g',
 '3skt.A.24': 'G',
 '3skt.A.25': 'C',
 '3skt.A.26': 'U',
 '3skt.A.27': 'U',
 '3skt.A.28': 'A',
 '3skt.A.29': 'U',
 '3skt.A.30': 'A',
 '3skt.A.31': 'C',
 '3skt.A.32': 'A',
 '3skt.A.33': 'G',
 '3skt.A.34': 'G',
 '3skt.A.35': 'G',
 '3skt.A.36': 'U',
 '3skt.A.37': 'A',
 '3skt.A.38': 'G',
 '3skt.A.39': 'C',
 '3skt.A.40': 'A',
 '3skt.A.41': 'U',
 '3skt.A.42': 'A',
 '3skt.A.43': 'A',
 '3skt.A.44': 'U',
 '3skt.A.45': 'G',
 '3skt.A.46': 'G',
 '3skt.A.47': 'G',
 '3skt.A.48': 'C',
 '3skt.A.49': 'U',
 '3skt.A.50': 'A',
 '3skt.A.51': 'C',
 '3skt.A.52': 'U',
 '3skt.A.53': 'G',
 '3skt.A.54': 'A',
 '3skt.A.55': 'C',
 '3skt.A.56': 'C',
 '3skt.A.57': 'C',
 '3skt.A.58': 'C',
 '3skt.A.59': 'G',
 '3skt.A.60': 'C',
 '3skt.A.61': 'C',
 '3skt.A.62': 'U',
 '3skt.A.63': 'U',
 '3skt.A.64': 'C',
 '3skt.A.65': 'A',
 '3skt.A.66': 'A',
 '3skt.A.67': 'A',
 '3skt.A.68': 'C',
 '3skt.A.69': 'C',
 '3skt.A.70': 'U',
 '3skt.A.71': 'A',
 '3skt.A.72': 'U',
 '3skt.A.73': 'U',
 '3skt.A.74': 'U',
 '3skt.A.75'

## Preprocessing (w/o rnaglib)

In [8]:
train_data_list, val_data_list, test_data_list = [], [], []

In [9]:
for ind in train_ind:
    train_data_list.append(ta.dataset[ind]['graph'])
for ind in val_ind:
    val_data_list.append(ta.dataset[ind]['graph'])
for ind in test_ind:
    test_data_list.append(ta.dataset[ind]['graph'])

In [10]:
train_loader = DataLoader(train_data_list, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data_list, batch_size=2, shuffle=False)
test_loader = DataLoader(test_data_list, batch_size=2, shuffle=False)

In [11]:
# rnaglib converter provides right format for pyg objects but its contents are lists or lists of lists, not tensors.
def pyg_converter(loader):
    for data in loader.dataset:
        data.edge_index =  torch.tensor(data.edge_index).t()
        data.edge_attr =  torch.tensor(data.edge_attr)
        data.y = data.y.squeeze().long()


In [12]:
pyg_converter(train_loader)
pyg_converter(val_loader)
pyg_converter(test_loader)

In [13]:
for batch in train_loader:
    print(batch)
    print(f'Batch node features shape: {batch.x.shape}')
    print(f'Batch edge index shape: {batch.edge_index.shape}')
    print(f'Batch labels shape: {batch.y.shape}')
    break

DataBatch(x=[650, 4], edge_index=[2, 1818], edge_attr=[1818], y=[650], batch=[650], ptr=[9])
Batch node features shape: torch.Size([650, 4])
Batch edge index shape: torch.Size([2, 1818])
Batch labels shape: torch.Size([650])


In [14]:
num_node_features = train_data_list[0].num_features
num_classes = 2 

In [33]:
import numpy as np

# Assuming train_loader is defined and loaded with the dataset
def calculate_length_statistics(loader):
    lengths = [data.x.shape[0] for data in loader.dataset]
    
    max_length = np.max(lengths)
    min_length = np.min(lengths)
    avg_length = np.mean(lengths)
    median_length = np.median(lengths)
    
    return {
        "max_length": max_length,
        "min_length": min_length,
        "average_length": avg_length,
        "median_length": median_length
    }

# Example usage
stats = calculate_length_statistics(train_loader)
print("Max Length:", stats["max_length"])
print("Min Length:", stats["min_length"])
print("Average Length:", stats["average_length"])
print("Median Length:", stats["median_length"])

Max Length: 414
Min Length: 19
Average Length: 53.574074074074076
Median Length: 32.5


In [39]:
def calculate_fraction_of_ones(loader):
    total_ones = 0
    total_elements = 0
    
    for data in loader.dataset:
        y = data.y
        total_ones += (y == 1).sum().item()
        total_elements += y.numel()
    
    fraction_of_ones = total_ones / total_elements if total_elements > 0 else 0
    return fraction_of_ones

# Example usage
fraction = calculate_fraction_of_ones(train_loader)
print("Fraction of ones:", fraction)

Fraction of ones: 0.20255789837538887


In [59]:
# Assuming train_loader is defined and loaded with the dataset

def count_unique_edge_attrs(train_loader):
    unique_edge_attrs = set()
    
    for data in train_loader.dataset:
        unique_edge_attrs.update(data.edge_attr.tolist())
    
    return len(unique_edge_attrs), unique_edge_attrs

# Example usage
num_unique_edge_attrs, unique_edge_attrs = count_unique_edge_attrs(train_loader)
print("Number of unique edge attributes:", num_unique_edge_attrs)
print("Unique edge attributes:", unique_edge_attrs)

Number of unique edge attributes: 20
Unique edge attributes: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}


In [62]:
def add_edge_features_to_nodes(data):
    # Assuming edge_attr is a tensor of shape [num_edges]
    row, col = data.edge_index
    edge_features = data.edge_attr

    # Add edge features to the corresponding node features
    # Here we are adding a new dimension to edge_features to match the dimensions
    data.x[row] += edge_features.view(-1, 1)
    data.x[col] += edge_features.view(-1, 1)

    return data


## Model

In [63]:
wandb.init(project="gcn-node-classification", config={
    "learning_rate": 0.0001,
    "epochs": 2000,
    "batch_size": 1
})

In [64]:
class GCN(torch.nn.Module):
    def __init__(self, num_node_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(num_node_features, 16)
        #self.conv2 = GCNConv(16, 32) 
        #self.conv3 = GCNConv(32, 16) 
        self.conv4 = GraphConv(16, num_classes)

    def forward(self, data):
        data = add_edge_features_to_nodes(data)
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        #x = self.conv2(x, edge_index)
        #x = F.relu(x)
        #x = self.conv3(x, edge_index)
        #x = F.relu(x)
        x = self.conv4(x, edge_index)
    
        
        return F.log_softmax(x, dim=1)


model = GCN(num_node_features, num_classes)

## Training

In [65]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

all_labels = []
for data in train_data_list:
    all_labels.extend(data.y.tolist())
class_counts = Counter(all_labels)
total_samples = len(all_labels)
class_weights = {cls: total_samples/count for cls, count in class_counts.items()}
weights = [class_weights[i] for i in range(num_classes)]
class_weights_tensor = torch.tensor(weights).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights_tensor)

def train():
    model.train()
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        wandb.log({"loss": loss.item()})

def test(loader):
    model.eval()    
    correct = 0
    for data in loader:
        data = data.to(device)
        out = model(data)
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
    return correct / len(loader.dataset)

for epoch in range(2000):
    train()
    train_acc = test(train_loader)
    val_acc = test(val_loader)
    print(f'Epoch: {epoch}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
    wandb.log({"train_acc": train_acc, "val_acc": val_acc})



Epoch: 0, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 1, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 2, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 3, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 4, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 5, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 6, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 7, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 8, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 9, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 10, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 11, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 12, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 13, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 14, Train Acc: 10.8704, Val Acc: 13.5000
Epoch: 15, Train Acc: 10.8889, Val Acc: 13.5000
Epoch: 16, Train Acc: 11.0370, Val Acc: 13.5000
Epoch: 17, Train Acc: 12.1481, Val Acc: 13.8333
Epoch: 18, Train Acc: 15.0741, Val Acc: 14.8333
Epoch: 19, Train Acc: 19.0741, Val Acc: 18.0000
Epoch: 20, Train Acc: 24.4259, Val Acc: 18.3333
Ep

## Evaluation

In [21]:
test_acc = test(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')
wandb.log({"test_acc": test_acc})

Test Accuracy: 16.6667


In [22]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

def get_predictions(loader):
    model.eval()
    all_preds = []
    all_labels = []
    for data in loader:
        data = data.to(device)
        out = model(data)
        preds = out.argmax(dim=1)
        all_preds.extend(preds.tolist())
        all_labels.extend(data.y.tolist())
    return all_preds, all_labels

def calculate_metrics(loader):
    preds, labels = get_predictions(loader)
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return accuracy, f1, auc

test_accuracy, test_f1, test_auc = calculate_metrics(test_loader)
print(f'Test Accuracy: {test_accuracy:.4f}, Test F1 Score: {test_f1:.4f}, Test AUC: {test_auc:.4f}')
wandb.log({"test_accuracy": test_accuracy, "test_f1": test_f1, "test_auc": test_auc})

Test Accuracy: 0.4910, Test F1 Score: 0.3742, Test AUC: 0.5116


In [None]:
wandb.finish()


0,1
loss,▃▄▃▂█▅▃▅▂▄▂▄▅▇▄▅▃▅▅▄▄▄▁▆▇▆▄▄▅▅▄▆▅▂▂▅▆▆▄▆
train_acc,█████████████████████▇█████████████▁▃▂▅▅
val_acc,▄▄▄▄▄▄▄▄▄▄▄▄█▄▅▄▄▄▄▄▄▇▄▄▄▄▄▄▄▄▄▄▄▄▇▁▆▆▇█

0,1
loss,0.79258
train_acc,35.48148
val_acc,19.0
