In [16]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.utils import to_dense_adj
import random

In [17]:
class RNAGraphAutoencoder(nn.Module):
    """Graph autoencoder built from two stacks of three GCNConv layers.

    Encoder dims: input_dim -> hidden_dim -> hidden_dim -> latent_dim.
    Decoder dims: latent_dim -> hidden_dim -> hidden_dim -> input_dim.
    Every layer (including the last of each stack) is followed by ReLU, so
    both the latent codes and the reconstructions are non-negative.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        enc_dims = [input_dim, hidden_dim, hidden_dim, latent_dim]
        dec_dims = [latent_dim, hidden_dim, hidden_dim, input_dim]
        # Consecutive (in, out) pairs; encoder is built before decoder so the
        # submodule names (encoder.0..2, decoder.0..2) and init order match
        # saved checkpoints.
        self.encoder = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(enc_dims[:-1], enc_dims[1:])
        )
        self.decoder = nn.ModuleList(
            GCNConv(d_in, d_out) for d_in, d_out in zip(dec_dims[:-1], dec_dims[1:])
        )
        self.relu = nn.ReLU()

    def _run_stack(self, layers, h, edge_index):
        """Apply each conv layer followed by ReLU, in order."""
        for conv in layers:
            h = self.relu(conv(h, edge_index))
        return h

    def encode(self, x, edge_index):
        """Map node features to latent node embeddings."""
        return self._run_stack(self.encoder, x, edge_index)

    def decode(self, x, edge_index):
        """Map latent node embeddings back to the input feature space."""
        return self._run_stack(self.decoder, x, edge_index)

    def forward(self, data):
        """Return ``(reconstructed, latent)`` for a batch exposing ``.x`` and ``.edge_index``."""
        edge_index = data.edge_index
        latent = self.encode(data.x, edge_index)
        reconstructed = self.decode(latent, edge_index)
        return reconstructed, latent


In [18]:
# Raw dataset; path is relative to the notebook's working directory.
df = pd.read_csv(filepath_or_buffer="../p04_test.csv")

In [19]:
# Keep rows that have both columns and whose `vienna_mfe` value contains no
# literal 'x'. Missing values are dropped first: a NaN in `vienna_mfe` breaks
# the `~` mask (str.contains yields NaN there), and a NaN in `sec_struc`
# would later crash build_graph.
# NOTE(review): "sequences" are taken from the `sec_struc` column and
# "structures" from `vienna_mfe`; build_graph expects bases (A/T/G/C) first and
# dot-bracket second — confirm these columns are not swapped.
df_data = df.dropna(subset=["sec_struc", "vienna_mfe"])
df_data = df_data[~df_data.vienna_mfe.str.contains('x', regex=False)]

RNA_sequences = df_data.sec_struc.to_list()
RNA_structures = df_data.vienna_mfe.to_list()

In [20]:
from sklearn.model_selection import train_test_split

# Two seeded splits: 23% of the data is held out as test; the remaining 77%
# is split 80/20 into train/validation.
(RNA_seq_train_val, RNA_seq_test,
 RNA_struct_train_val, RNA_struct_test) = train_test_split(
    RNA_sequences,
    RNA_structures,
    test_size=0.23,
    random_state=42,
    shuffle=True,
)

(RNA_seq_train, RNA_seq_val,
 RNA_struct_train, RNA_struct_val) = train_test_split(
    RNA_seq_train_val,
    RNA_struct_train_val,
    test_size=0.2,
    random_state=42,
)


In [21]:
def build_graph(sequence, structure):
    """Build a graph for one RNA molecule from its sequence and dot-bracket structure.

    Nodes are nucleotides. Each node's feature vector has 7 entries: a one-hot
    base (A, T/U, G, C) followed by a one-hot structure symbol ('.', '(', ')').
    Edges are stored in both directions and connect
      * neighbouring positions i-1 and i (the backbone), and
      * positions paired by matching '(' / ')' brackets.
    Unmatched brackets add no pairing edge.

    Args:
        sequence: nucleotide string; 'U' is encoded in the same slot as 'T'.
        structure: dot-bracket string of the same length as ``sequence``.

    Returns:
        ``Data(x=..., edge_index=...)`` where ``x`` is a float tensor of shape
        ``(len(sequence), 7)`` and ``edge_index`` is a long tensor of shape
        ``(2, num_edges)``.

    Raises:
        ValueError: if ``sequence`` and ``structure`` differ in length.
        KeyError: on a base or structure character not listed above.
    """
    if len(sequence) != len(structure):
        # zip() would silently truncate the node features, while edges could
        # still be built for positions past the last node.
        raise ValueError(
            f"sequence and structure lengths differ: {len(sequence)} != {len(structure)}"
        )

    base_to_idx = {'A': 0, 'T': 1, 'U': 1, 'G': 2, 'C': 3}
    struct_to_idx = {'.': 0, '(': 1, ')': 2}

    node_features = []
    for base, struct in zip(sequence, structure):
        base_feature = [0, 0, 0, 0]
        base_feature[base_to_idx[base]] = 1
        struct_feature = [0, 0, 0]
        struct_feature[struct_to_idx[struct]] = 1
        node_features.append(base_feature + struct_feature)

    # reshape keeps x two-dimensional, (0, 7), for an empty sequence.
    node_features = torch.tensor(node_features, dtype=torch.float).reshape(-1, 7)

    edges = []
    stack = []  # indices of '(' still waiting for their matching ')'
    for i, struct in enumerate(structure):
        # Backbone edge to the previous nucleotide, added exactly once. A former
        # second pass over range(len(sequence) - 1) re-added every backbone
        # edge, so GCNConv counted each one twice.
        # NOTE(review): checkpoints trained on the old graphs saw doubled
        # backbone edges; expect some shift when reusing them.
        if i > 0:
            edges.append((i - 1, i))
            edges.append((i, i - 1))

        if struct == '(':
            stack.append(i)
        elif struct == ')' and stack:
            j = stack.pop()
            edges.append((i, j))
            edges.append((j, i))

    # reshape(-1, 2) yields a well-formed (2, 0) edge_index when there are no
    # edges (length-0 or length-1 input); a bare .t() on an empty tensor would not.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    return Data(x=node_features, edge_index=edge_index)

# Turn every (sequence, structure) pair of each split into a graph.
graphs_train = list(map(build_graph, RNA_seq_train, RNA_struct_train))
graphs_test = list(map(build_graph, RNA_seq_test, RNA_struct_test))
graphs_val = list(map(build_graph, RNA_seq_val, RNA_struct_val))

In [22]:
# Mini-batch loaders; only the training split is shuffled.
# NOTE(review): in PyG 2.x, torch_geometric.data.DataLoader is a deprecated
# alias of torch_geometric.loader.DataLoader — confirm the installed version.
batch_size = 64
data_loader, val_data_loader, test_data_loader = (
    DataLoader(graphs, batch_size=batch_size, shuffle=shuffle)
    for graphs, shuffle in (
        (graphs_train, True),
        (graphs_val, False),
        (graphs_test, False),
    )
)



In [23]:
input_dim = 7  # 4 one-hot base channels + 3 one-hot structure channels (see build_graph)
hidden_dim = 64
latent_dim = 32
model = RNAGraphAutoencoder(input_dim, hidden_dim, latent_dim)

# Pretrained weights to fine-tune from.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
PRETRAINED_CHECKPOINT = '/home/ec2-user/internship/modeling/GNN_autoencoder/autoencoder_model_v2.pth'
# The model is never moved to a GPU in this notebook, so remap any CUDA
# tensors in the checkpoint onto the CPU while loading.
model.load_state_dict(torch.load(PRETRAINED_CHECKPOINT, map_location="cpu"))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# Per-epoch mean losses, appended by the training loop.
train_losses = []
val_losses = []
def train(model, loader, criterion, optimizer):
    """Run one training epoch of the autoencoder.

    For each batch: clear gradients, reconstruct ``batch.x`` with ``model``,
    compute ``criterion`` against the input features, back-propagate, and step
    ``optimizer``.

    Returns:
        The mean of the per-batch loss values (every batch weighted equally).
    """
    model.train()
    batch_losses = []
    for batch in loader:
        optimizer.zero_grad()
        reconstructed, _latent = model(batch)
        loss = criterion(reconstructed, batch.x)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)

def validate(model, loader, criterion):
    """Evaluate reconstruction loss without updating the model.

    Puts ``model`` in eval mode, disables gradient tracking, and scores each
    batch's reconstruction of ``batch.x`` with ``criterion``.

    Returns:
        The mean of the per-batch loss values (every batch weighted equally).
    """
    model.eval()
    with torch.no_grad():
        losses = [criterion(model(batch)[0], batch.x).item() for batch in loader]
    return sum(losses) / len(loader)


num_epochs = 100
FINE_TUNED_CHECKPOINT = './fine_tuned_autoencoder2.pth'

# Seed torch's global RNG so the training loader's shuffle order is
# reproducible (same seed as the data splits).
torch.manual_seed(42)

for epoch in range(num_epochs):
    train_loss = train(model, data_loader, criterion, optimizer)
    val_loss = validate(model, val_data_loader, criterion)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Checkpoint after every epoch. Saving only once after the loop meant an
    # interrupted run (as happened here) lost all progress; after a complete
    # run the file holds the same final-epoch weights as before.
    torch.save(model.state_dict(), FINE_TUNED_CHECKPOINT)

KeyboardInterrupt: 

In [24]:
input_dim = 7  # 4 one-hot base channels + 3 one-hot structure channels (see build_graph)
hidden_dim = 64
latent_dim = 32
model = RNAGraphAutoencoder(input_dim, hidden_dim, latent_dim)

# NOTE(review): the fine-tuning cell saves to './fine_tuned_autoencoder2.pth',
# but this cell evaluates './fine_tuned_autoencoder.pth' — confirm which
# checkpoint is meant to be tested.
# The model stays on the CPU, so remap any CUDA tensors in the checkpoint.
model.load_state_dict(torch.load('./fine_tuned_autoencoder.pth', map_location="cpu"))
# Evaluation only: no optimizer is needed here.
criterion = nn.MSELoss()
def test(model, loader, criterion):
    model.eval()  
    total_loss = 0
    with torch.no_grad():
        for batch in loader:
            reconstructed, _ = model(batch)
            loss = criterion(reconstructed, batch.x)
            total_loss += loss.item()
    return total_loss / len(loader)
# Score the loaded checkpoint on the held-out test split.
test_loss = test(model, test_data_loader, criterion)
print(f"Test Loss: {test_loss:.4f}")

KeyboardInterrupt: 