In [98]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd

In [99]:
# Encoder
class Encoder(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc_mu = nn.Linear(128, latent_dim)  # Mean of latent space
        self.fc_log_var = nn.Linear(128, latent_dim)  # Log variance of latent space

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        mu = self.fc_mu(x)
        log_var = self.fc_log_var(x)
        return mu, log_var

# Dedcoder
class Decoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, z):
        z = torch.relu(self.fc1(z))
        # output = torch.sigmoid(self.fc2(z)) 
        output = self.fc2(z)
        return output

# VAE
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(latent_dim, input_dim)

    def reparameterize(self, mu, log_var):
        # Reparameterization
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mu + epsilon * std

    def forward(self, x):
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        reconstructed = self.decoder(z)
        return reconstructed, mu, log_var

# Loss function
def vae_loss(reconstructed, original, mu, log_var):
    
    # Reconstruction loss
    reconstruction_loss = nn.functional.mse_loss(reconstructed, original, reduction='mean')

    # KL Divergence loss
    kl_divergence = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())

    return reconstruction_loss + kl_divergence

In [100]:

# Data
protein_emb = pd.read_parquet('data/protein_emb.parquet')       # protein embedding
lm_3utr_emb = pd.read_parquet('data/lm_3utr_emb.parquet')       # 3utr embedding
lm_5utr_emb = pd.read_parquet('data/lm_5utr_emb.parquet')       # 5utr embedding

In [101]:
# Normalize embeddings
protein_emb_z = pd.DataFrame(StandardScaler().fit_transform(protein_emb), index=protein_emb.index, columns=protein_emb.columns)
lm_3utr_emb_z = pd.DataFrame(StandardScaler().fit_transform(lm_3utr_emb), index=lm_3utr_emb.index, columns=lm_3utr_emb.columns)
lm_5utr_emb_z = pd.DataFrame(StandardScaler().fit_transform(lm_5utr_emb), index=lm_5utr_emb.index, columns=lm_5utr_emb.columns)

In [102]:
lm_5utr_emb

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
YAL002W,0.081238,0.030906,0.014586,-0.015279,0.028113,0.121253,0.068083,0.171768,-0.014086,0.544913,...,-0.030622,-0.045234,-0.263034,0.105608,0.000841,-0.068895,0.092609,0.012248,0.040396,-0.008431
YAL003W,0.032581,0.025514,-0.031749,-0.042321,-0.052058,0.100852,-0.040193,0.227620,-0.041441,0.127843,...,-0.004374,-0.035856,-0.159351,0.127666,0.066430,-0.019931,0.045678,0.022475,-0.025326,-0.090283
YAL004W,0.072448,-0.004530,0.082222,-0.067203,-0.010486,-0.015001,0.229780,0.198991,-0.101902,0.422294,...,-0.029051,-0.134829,-0.129500,0.040292,0.037237,0.001085,0.122111,0.084727,0.389339,0.085969
YAL008W,0.018980,0.342648,-0.041645,0.006749,-0.031477,0.081134,0.195077,0.200070,-0.049627,0.392190,...,0.044176,-0.022111,-0.216389,0.075199,0.044513,-0.020134,0.051188,0.055149,-0.255109,-0.061252
YAL009W,-0.009529,0.003373,0.000113,-0.007516,-0.010976,0.063841,0.181502,0.260650,-0.035609,0.194044,...,-0.007081,0.005693,-0.196746,0.071950,0.073255,0.002211,0.107910,-0.004822,-0.406600,-0.065601
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
YPR195C,0.068890,-0.117648,-0.035925,-0.100736,-0.065239,-0.019518,0.087728,0.295325,-0.094921,-0.036480,...,-0.025029,-0.038674,-0.232437,0.077736,0.033196,-0.027186,-0.032995,0.063515,-0.396488,0.041914
YPR197C,0.028925,0.348345,-0.017416,0.018307,-0.003333,-0.108214,0.327870,0.141014,0.052964,0.057818,...,0.009432,-0.045269,-0.246033,0.101393,0.072494,0.002908,0.146454,0.024336,-0.409820,0.048543
YPR199C,0.047805,0.006884,-0.046718,-0.047041,-0.023929,0.129827,0.087715,0.212255,0.011288,0.298354,...,0.051227,-0.005029,-0.221753,0.070380,0.056157,-0.003136,0.052018,0.061356,-0.283577,-0.040348
YPR200C,0.015357,0.254043,-0.029917,0.020158,-0.023756,-0.068058,0.228225,0.169323,0.069267,0.154024,...,0.072548,-0.033315,-0.242285,0.080294,0.068319,0.018254,0.084115,0.048010,-0.287566,0.052717


In [103]:
combined_embeddings = protein_emb_z.merge(lm_3utr_emb_z, on='gene_id').merge(lm_5utr_emb_z, on='gene_id').dropna()
combined_embeddings

Unnamed: 0_level_0,0_x,1_x,2_x,3_x,4_x,5_x,6_x,7_x,8_x,9_x,...,758,759,760,761,762,763,764,765,766,767
gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
YAL037W,1.225586,-0.764648,0.621582,1.170898,0.698242,-0.569824,1.037109,0.938477,0.245117,0.244507,...,-0.866974,-0.520096,-0.361732,0.737331,-1.280365,-2.040913,-0.448777,-0.490042,0.087562,1.271699
YAL016W,0.920898,1.061523,-0.992188,0.445312,-0.085999,0.120972,0.513672,-1.166016,0.159424,-0.334473,...,0.864854,-0.555629,-1.719696,-0.050071,-0.554177,0.696141,0.166101,1.516820,-0.802275,0.330826
YAL003W,1.151367,0.094421,-1.620117,0.570801,0.791992,2.529297,0.982910,-1.930664,0.286377,0.069397,...,-0.438211,-0.320423,0.810747,1.731299,1.136185,-0.699910,0.229856,-0.503283,1.116165,-1.666283
YAL053W,-0.709473,0.158691,-0.918457,1.225586,0.238403,-0.818848,1.323242,-0.783691,-0.724121,0.206543,...,-0.055519,0.166752,0.561709,2.328072,0.730168,0.808177,0.266933,-1.027493,0.345758,-0.387323
YAL031W-A,0.558105,1.243164,2.222656,-0.280029,-0.056213,0.715820,-1.269531,1.657227,2.470703,0.871094,...,-1.591291,0.035943,2.677892,-2.289967,-0.393135,0.429102,1.042011,-1.085483,0.209537,0.031343
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
YPL199C,0.565430,-2.208984,-0.330811,0.873047,-1.205078,0.432617,0.249268,-0.777344,-0.702148,0.985840,...,0.374927,-0.074745,-0.458241,1.955727,-4.351127,-1.551018,-1.626283,-4.049374,-0.253084,0.609092
YPL259C,-1.661133,1.403320,-0.637207,-0.623047,-1.816406,-1.172852,0.229858,-1.070312,-2.251953,-1.614258,...,0.251988,-0.412844,-1.872880,-0.000190,-0.622126,-0.590178,-0.853981,0.510699,-1.118729,1.635189
YPR179C,-0.149902,-0.697266,0.238892,1.750977,-0.744629,0.052734,-0.376221,0.222290,-0.202393,-0.553223,...,1.251696,-0.928374,1.307272,-0.884555,1.054958,-0.816577,0.860868,0.527546,0.594717,0.738711
YPR096C,1.947266,0.278564,-0.707520,0.021393,-0.857422,1.104492,-1.563477,1.338867,2.140625,1.521484,...,0.180405,0.238247,0.265211,-0.263454,-1.711692,0.736894,0.811675,-1.350364,-0.732571,-0.195301


In [104]:
combined_embeddings = torch.tensor(combined_embeddings.values)

In [105]:
# Normalise again ? not sure about this
combined_embeddings_z = StandardScaler().fit_transform(combined_embeddings)

In [106]:
np.var(combined_embeddings_z)

1.0000000000000007

In [107]:
combined_embeddings_z.shape

(6579, 2816)

In [108]:
# Train-test split
train_embeddings, test_embeddings = train_test_split(
    combined_embeddings_z, test_size=0.2, random_state=42
)

# Convert to PyTorch tensors
train_embeddings = torch.tensor(train_embeddings, dtype=torch.float32)
test_embeddings = torch.tensor(test_embeddings, dtype=torch.float32)

In [109]:
# Training
input_dim = combined_embeddings_z.shape[1]
latent_dim = 64

vae = VAE(input_dim, latent_dim)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3, weight_decay=1e-5)

epochs = 20
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_embeddings, batch_size=batch_size, shuffle=True)

In [110]:
vae.train()
for epoch in range(epochs):
    total_loss = 0
    for batch_embeddings in train_loader:
        optimizer.zero_grad()
        reconstructed, mu, log_var = vae(batch_embeddings)
        loss = vae_loss(reconstructed, batch_embeddings, mu, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(train_loader):.4f}")

# Testing
vae.eval()
with torch.no_grad():
    reconstructed, mu, log_var = vae(test_embeddings)
    test_loss = vae_loss(reconstructed, test_embeddings, mu, log_var)
    print(f"Test Loss: {test_loss.item():.4f}")

# Extract latent embeddings
with torch.no_grad():
    train_latent = vae.encoder(train_embeddings)[0]
    test_latent = vae.encoder(test_embeddings)[0]

print("Latent embeddings extracted:", train_latent.shape, test_latent.shape)

Epoch 1/20, Loss: 1.0255
Epoch 2/20, Loss: 0.9563
Epoch 3/20, Loss: 0.9347
Epoch 4/20, Loss: 0.9163
Epoch 5/20, Loss: 0.9008
Epoch 6/20, Loss: 0.8909
Epoch 7/20, Loss: 0.8779
Epoch 8/20, Loss: 0.8693
Epoch 9/20, Loss: 0.8605
Epoch 10/20, Loss: 0.8518
Epoch 11/20, Loss: 0.8431
Epoch 12/20, Loss: 0.8373
Epoch 13/20, Loss: 0.8317
Epoch 14/20, Loss: 0.8243
Epoch 15/20, Loss: 0.8171
Epoch 16/20, Loss: 0.8092
Epoch 17/20, Loss: 0.8074
Epoch 18/20, Loss: 0.8003
Epoch 19/20, Loss: 0.8000
Epoch 20/20, Loss: 0.7942
Test Loss: 0.7968
Latent embeddings extracted: torch.Size([5263, 64]) torch.Size([1316, 64])
