In [1]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_dense_adj

### Load RNA dataset

In [2]:
# One DataFrame per source dataset, keyed by dataset name.
dict_dataset = {
    'GPN15k': pd.read_csv('data/GPN15k_silico_predictions.csv'),
    'PK50': pd.read_csv('data/PK50_silico_predictions.csv'),
    'PK90': pd.read_csv('data/PK90_silico_predictions.csv'),
    'R1': pd.read_csv('data/R1_silico_predictions.csv'),
}

# Parallel lists pooled over all datasets: RNA_sequences[i] goes with
# RNA_structures[i].
RNA_sequences = []
RNA_structures = []

for k in dict_dataset:
    df_data = dict_dataset[k]
    # Keep only the rows whose vienna2_mfe string does not contain an 'x'.
    has_x = df_data['vienna2_mfe'].str.contains('x')
    df_data = df_data.loc[~has_x]
    RNA_sequences.extend(df_data['sequence'].to_list())
    RNA_structures.extend(df_data['vienna2_mfe'].to_list())
    # Report how many rows of this dataset survived the filter.
    print(k, len(df_data))

GPN15k 15000
PK50 2729
PK90 2173
R1 119999


In [3]:
from sklearn.model_selection import train_test_split

# Stage 1: hold out 23% of all (sequence, structure) pairs as the test set.
(
    RNA_seq_train_val,
    RNA_seq_test,
    RNA_struct_train_val,
    RNA_struct_test,
) = train_test_split(
    RNA_sequences,
    RNA_structures,
    test_size=0.23,
    random_state=42,
    shuffle=True,
)

# Stage 2: split the remaining pairs again, 20% of them becoming the
# validation set and the rest the training set.
(
    RNA_seq_train,
    RNA_seq_val,
    RNA_struct_train,
    RNA_struct_val,
) = train_test_split(
    RNA_seq_train_val,
    RNA_struct_train_val,
    test_size=0.2,
    random_state=42,
)


107723 32178
107723 32178


### Build Graph

In [4]:
def build_graph(sequence, structure):
    """Turn an RNA sequence and its dot-bracket structure into a graph.

    Every position becomes one node whose 7-dimensional feature vector is
    the one-hot base (A, U, G, C) concatenated with the one-hot structure
    symbol ('.', '(', ')'). Two kinds of edges are created, each stored in
    both directions and exactly once:

    * backbone edges between consecutive positions ``i-1`` and ``i``;
    * base-pair edges between each ')' and the most recent unmatched '('.

    A ')' with no open '(' and any '(' left unmatched at the end are
    ignored rather than treated as errors.

    Args:
        sequence: String of bases drawn from ``'AUGC'``.
        structure: Dot-bracket string. Positions are paired with
            ``sequence`` via ``zip``, so only the common prefix is used if
            the two lengths differ.

    Returns:
        A ``Data`` object with ``x`` of shape ``[num_nodes, 7]`` (float) and
        ``edge_index`` of shape ``[2, num_edges]`` (long). The shapes hold
        even when there are no nodes or no edges.

    Raises:
        KeyError: If a base or structure symbol is outside the vocabularies
            above.
    """
    base_to_idx = {'A': 0, 'U': 1, 'G': 2, 'C': 3}
    struct_to_idx = {'.': 0, '(': 1, ')': 2}
    feature_dim = len(base_to_idx) + len(struct_to_idx)

    node_features = []
    edges = []
    stack = []  # indices of '(' still waiting for their partner
    for i, (base, struct) in enumerate(zip(sequence, structure)):
        base_feature = [0] * len(base_to_idx)
        base_feature[base_to_idx[base]] = 1
        struct_feature = [0] * len(struct_to_idx)
        struct_feature[struct_to_idx[struct]] = 1
        node_features.append(base_feature + struct_feature)

        # Backbone edge to the previous nucleotide. This is the only place
        # backbone edges are added, so each appears once per direction and
        # never points at a position that has no node.
        if i > 0:
            edges.append((i - 1, i))
            edges.append((i, i - 1))

        # Base-pair edge between matching brackets.
        if struct == '(':
            stack.append(i)
        elif struct == ')' and stack:
            j = stack.pop()
            edges.append((i, j))
            edges.append((j, i))

    # The reshapes are no-ops for non-empty input; for empty lists they turn
    # the 1-D empty tensor into the [0, 7] / [2, 0] layout expected downstream.
    x = torch.tensor(node_features, dtype=torch.float).reshape(-1, feature_dim)
    edge_index = (
        torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )

    return Data(x=x, edge_index=edge_index)

# Build one graph per (sequence, structure) pair for each of the three splits.
graphs_train = list(map(build_graph, RNA_seq_train, RNA_struct_train))
graphs_test = list(map(build_graph, RNA_seq_test, RNA_struct_test))
graphs_val = list(map(build_graph, RNA_seq_val, RNA_struct_val))

### GNN-AutoEncoder Model Design

In [6]:
class RNAGraphAutoencoder(nn.Module):
    """Node-feature autoencoder made of two stacks of three GCNConv layers.

    The encoder maps ``input_dim -> hidden_dim -> hidden_dim -> latent_dim``
    and the decoder runs the same widths in reverse. ReLU is applied after
    every layer, including the last layer of each stack, so both the latent
    codes and the reconstruction are non-negative.

    Args:
        input_dim: Width of the node features fed in and reconstructed.
        hidden_dim: Width of the intermediate layers.
        latent_dim: Width of the per-node latent code.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        encoder_widths = (input_dim, hidden_dim, hidden_dim, latent_dim)
        decoder_widths = encoder_widths[::-1]

        # Encoder layers are created before decoder layers.
        self.encoder = nn.ModuleList(self._conv_stack(encoder_widths))
        self.decoder = nn.ModuleList(self._conv_stack(decoder_widths))

        self.relu = nn.ReLU()

    @staticmethod
    def _conv_stack(widths):
        """Return one GCNConv per consecutive pair of widths, in order."""
        return [GCNConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]

    def _propagate(self, layers, x, edge_index):
        """Apply each layer followed by ReLU, threading the result through."""
        hidden = x
        for conv in layers:
            hidden = self.relu(conv(hidden, edge_index))
        return hidden

    def encode(self, x, edge_index):
        """Map node features ``x`` to latent codes using the encoder stack."""
        return self._propagate(self.encoder, x, edge_index)

    def decode(self, x, edge_index):
        """Map latent codes ``x`` back to feature space using the decoder stack."""
        return self._propagate(self.decoder, x, edge_index)

    def forward(self, data):
        """Encode then decode ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tuple ``(reconstructed, latent)``.
        """
        node_features, edge_index = data.x, data.edge_index
        latent = self.encode(node_features, edge_index)
        reconstructed = self.decode(latent, edge_index)
        return reconstructed, latent


In [7]:
# Seed the Python, NumPy and PyTorch RNGs before any weights are created so
# that a Restart & Run All produces the same initial model and the same
# torch-driven randomness afterwards.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Model dimensions.
input_dim = 7  # 4 for bases + 3 for structure
hidden_dim = 64
latent_dim = 32
model = RNAGraphAutoencoder(input_dim, hidden_dim, latent_dim)


### Train the Model

In [8]:
# Create data loaders
batch_size = 64
# Training batches are reshuffled on every pass over the loader.
data_loader = DataLoader(graphs_train, batch_size=batch_size, shuffle=True)
# Validation only measures the loss, so keep a fixed order: batch composition
# (and hence the average of per-batch losses) stays identical across epochs,
# making the validation curve comparable from one epoch to the next.
val_data_loader = DataLoader(graphs_val, batch_size=batch_size, shuffle=False)



In [9]:
# Reconstruction objective: mean squared error between the decoded node
# features and the input node features.
criterion = nn.MSELoss()
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

def train(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    For every batch the gradients are cleared, the model output is compared
    against ``batch.x`` with ``criterion``, and one optimizer step is taken.

    Args:
        model: Module returning ``(reconstructed, latent)`` for a batch.
        loader: Sized iterable of batches exposing ``x``.
        criterion: Callable ``(prediction, target) -> scalar loss tensor``.
        optimizer: Optimizer holding the model parameters.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for batch in loader:
        optimizer.zero_grad()
        reconstruction, _latent = model(batch)
        batch_loss = criterion(reconstruction, batch.x)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(loader)

def validate(model, loader, criterion):
    """Measure the loss over ``loader`` without updating the model.

    The model is switched to evaluation mode and every batch is processed
    with gradient tracking disabled.

    Args:
        model: Module returning ``(reconstructed, latent)`` for a batch.
        loader: Sized iterable of batches exposing ``x``.
        criterion: Callable ``(prediction, target) -> scalar loss tensor``.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    def _batch_loss(batch):
        reconstruction, _latent = model(batch)
        return criterion(reconstruction, batch.x).item()

    model.eval()
    with torch.no_grad():
        # The generator is consumed inside the no_grad context.
        loss_sum = sum(_batch_loss(batch) for batch in loader)
    return loss_sum / len(loader)

# Per-epoch loss histories, filled in by the loop below.
train_losses, val_losses = [], []

num_epochs = 150
for epoch in range(num_epochs):
    # One optimisation pass over the training loader, then one evaluation
    # pass over the validation loader.
    train_loss = train(model, data_loader, criterion, optimizer)
    val_loss = validate(model, val_data_loader, criterion)

    train_losses += [train_loss]
    val_losses += [val_loss]

    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

print("GNN-AutoEncoder training completed!")

# Persist both loss histories (plain Python lists) to disk.
for history, path in ((train_losses, 'train_losses.pth'), (val_losses, 'val_losses.pth')):
    torch.save(history, path)


# Plot the per-epoch training and validation losses on a single axes.
plt.rc('font', size=14)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_title('Training and Validation Losses')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')

# Light dashed grid on both major and minor ticks for easier reading.
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
# One legend call is enough; it picks up the labels given to plot().
ax.legend()
plt.show()

Epoch 1/100, Loss: 0.1301
Epoch 2/100, Loss: 0.1143
Epoch 3/100, Loss: 0.1078
Epoch 4/100, Loss: 0.1044


KeyboardInterrupt: 

In [None]:
# torch.save(model.state_dict(), "autoencoder_model_v2.pth")

In [None]:
# def extract_features(model, loader):
#     model.eval()
#     features = []
    
#     with torch.no_grad():
#         for batch in loader:
#             _, latent = model(batch)
#             features.append(latent)
    
#     return torch.cat(features, dim=0)

# # Extract features for all graphs
# feature_loader = DataLoader(graphs_test, batch_size=batch_size, shuffle=False)
# extracted_features = extract_features(model, feature_loader)

# print(f"Extracted features shape: {extracted_features.shape}")