In [37]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
from torch.utils.data import DataLoader, TensorDataset, Dataset
import pandas as pd
import numpy as np
from scipy.stats import t, shapiro, kstest
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import itertools

In [38]:
#data preparation
def prepare_data():
    #loading data
    abundance = pd.read_csv('~/icr/simko/data/simko2_data/passport_prots.csv', index_col=0)
    abundance.index = abundance.index.astype(str)
    #removing cell lines with over 4000 nans
    nans_per_cl = abundance.isna().sum(axis=0)
    abundance_cl_filtered = abundance.loc[:, nans_per_cl<4000]
    #getting rid of protein with over 80% NaN (from the dataset filtered by CLs)
    prot_nan_count = abundance_cl_filtered.isna().sum(axis=1)
    prot_nan_percent = (prot_nan_count/abundance_cl_filtered.shape[1])*100
    abundance_filtered = abundance_cl_filtered[prot_nan_percent<80]

    #imputing witht the lower quartile average for each protein
    #set the protein names as the index - ignores it while we find the lower quartile
    def average_lower_quartile(x):
        sorted_abundances = x.dropna().sort_values()
        lower_qt_values = sorted_abundances.iloc[:int(len(sorted_abundances) * 0.25)]
        return lower_qt_values.mean()


    lower_qt_averages = abundance_filtered.apply(average_lower_quartile, axis=1)
    abundance_filtered_no_nan = abundance_filtered.apply(lambda x: x.fillna(lower_qt_averages[x.name]), axis=1)

    #transposing
    abundance_imputed = abundance_filtered_no_nan.T
    #scaling the imputed data
    scaler = StandardScaler()
    scaled_data = pd.DataFrame(scaler.fit_transform(abundance_imputed), index=abundance_imputed.index, columns=abundance_imputed.columns)
    #scaled_data = scaled_data.T
    return scaled_data

scaled_data = prepare_data()

In [39]:
#getting pbrm1 values so we can make the continuous - put the values between 0 and 1
raw_pbrm1 = scaled_data["PBRM1"].values
raw_min, raw_max = raw_pbrm1.min(), raw_pbrm1.max()
#condition = 
c = (raw_pbrm1 - raw_min) / (raw_max - raw_min) 

In [40]:
# getting rid of pbrm1 from data set to use in training
X = scaled_data.drop(columns=["PBRM1"]).values.astype(np.float32)

In [41]:
n_proteins = X.shape[1]     # after dropping PBRM1

In [42]:
#creating custom data set --> makes it easier to proccess it later on (data augmentation)
class ProteomeDataset(Dataset):
    def __init__(self, X, c):
        self.X = torch.from_numpy(X).float()
        self.c = torch.from_numpy(c).float().unsqueeze(-1)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, i):
        return self.X[i], self.c[i]

full_ds = ProteomeDataset(X, c)
loader  = DataLoader(full_ds, batch_size=32, shuffle=True, drop_last=True)

In [43]:
# 1) Factory for a 3-layer CVAE
def make_cvae_3layer(n_proteins, latent_dim, activation, dropout):
    act_fn = {
        'relu': nn.ReLU(),
        'lrelu': nn.LeakyReLU(0.1),
        'elu': nn.ELU(),
        'gelu': nn.GELU()
    }[activation]
    
    # Encoder: [in → 3 hidden layers → μ/logvar]
    enc = nn.Sequential(
        nn.Linear(n_proteins + 1, 800),
        nn.BatchNorm1d(800), act_fn, nn.Dropout(dropout),
        nn.Linear(800, 400),
        nn.BatchNorm1d(400), act_fn, nn.Dropout(dropout),
        nn.Linear(400, 200),
        nn.BatchNorm1d(200), act_fn, nn.Dropout(dropout),
    )
    fc_mu     = nn.Linear(200, latent_dim)
    fc_logvar = nn.Linear(200, latent_dim)

    # Decoder: [z+c → 3 hidden layers → out]
    dec = nn.Sequential(
        nn.Linear(latent_dim + 1, 200),
        nn.BatchNorm1d(200), act_fn, nn.Dropout(dropout),
        nn.Linear(200, 400),
        nn.BatchNorm1d(400), act_fn, nn.Dropout(dropout),
        nn.Linear(400, 800),
        nn.BatchNorm1d(800), act_fn, nn.Dropout(dropout),
        nn.Linear(800, n_proteins),
        nn.ReLU(),
    )

    class CVAE(nn.Module):
        def __init__(self):
            super().__init__()
            self.enc, self.fc_mu, self.fc_logvar, self.dec = enc, fc_mu, fc_logvar, dec

        def reparameterize(self, mu, logvar):
            std = (0.5 * logvar).exp()
            eps = torch.randn_like(std)
            return mu + eps * std

        def forward(self, x, c):
            h = self.enc(torch.cat([x, c], 1))
            mu, logvar = self.fc_mu(h), self.fc_logvar(h)
            z = self.reparameterize(mu, logvar)
            x_rec = self.dec(torch.cat([z, c], 1))
            return x_rec, mu, logvar

    return CVAE()

In [44]:
# 3) Training function with per-element MSE & KL logging
# ------------------------------------------------------
def train_one(model, loader, epochs=30, warmup=20, eps=0.02):
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    n_proteins = next(iter(loader))[0].shape[1]

    for epoch in range(1, epochs+1):
        model.train()
        β = min(1.0, epoch / warmup)
        sum_recon, sum_kl = 0.0, 0.0
        for x, c in loader:
            x, c = x, c
            # jitter c
            c_noisy = torch.clamp(c + torch.randn_like(c)*eps, 0, 1)
            x_rec, mu, logvar = model(x, c_noisy)
            recon = F.mse_loss(x_rec, x, reduction='sum')
            kl    = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss  = recon + β * kl

            opt.zero_grad(); loss.backward(); opt.step()
            sum_recon += recon.item()
            sum_kl    += kl.item()

        # per-element metrics
        total_elems = len(loader.dataset) * n_proteins
        mse_elem = sum_recon / total_elems
        rmse_elem = np.sqrt(mse_elem)
        avg_kl = sum_kl / len(loader.dataset)

        # print diagnostic
        print(f"[{epoch:02d}/{epochs}] β={β:.2f}  "
              f"MSE_elem={mse_elem:.4f}  RMSE={rmse_elem:.4f}  KL_per_samp={avg_kl:.2f}")

In [45]:
# 4) Evaluation function returning per-element MSE
# -------------------------------------------------
def eval_mse(model, loader):
    model.eval()
    sum_se = 0.0
    with torch.no_grad():
        for x, c in loader:
            x, c = x, c
            x_rec, _, _ = model(x, c)
            sum_se += F.mse_loss(x_rec, x, reduction='sum').item()
    total_elems = len(loader.dataset) * n_proteins
    return sum_se / total_elems

In [46]:
# 3) Hyperparameter grid
activations = ['relu','lrelu','elu','gelu']
dropouts    = [0.0, 0.2, 0.5]
latents     = [10, 20, 50, 100]

results = []

for act in activations:
    for dr in dropouts:
        for zdim in latents:
            print(f"\n>>> Testing act={act}, dropout={dr}, latent={zdim}")
            model = make_cvae_3layer(n_proteins, zdim, act, dr)
            train_one(model, loader, epochs=40, warmup=10, eps=0.02)
            mse = eval_mse(model, loader)
            print(f"→ Final per-element MSE: {mse:.4f}")
            results.append((act, dr, zdim, mse))


>>> Testing act=relu, dropout=0.0, latent=10
[01/40] β=0.10  MSE_elem=0.9966  RMSE=0.9983  KL_per_samp=7.26
[02/40] β=0.20  MSE_elem=0.9776  RMSE=0.9887  KL_per_samp=16.84
[03/40] β=0.30  MSE_elem=0.9712  RMSE=0.9855  KL_per_samp=19.56
[04/40] β=0.40  MSE_elem=0.9684  RMSE=0.9841  KL_per_samp=19.34
[05/40] β=0.50  MSE_elem=0.9663  RMSE=0.9830  KL_per_samp=19.31
[06/40] β=0.60  MSE_elem=0.9623  RMSE=0.9810  KL_per_samp=18.92
[07/40] β=0.70  MSE_elem=0.9602  RMSE=0.9799  KL_per_samp=19.08
[08/40] β=0.80  MSE_elem=0.9585  RMSE=0.9790  KL_per_samp=18.43
[09/40] β=0.90  MSE_elem=0.9568  RMSE=0.9782  KL_per_samp=17.93
[10/40] β=1.00  MSE_elem=0.9556  RMSE=0.9775  KL_per_samp=17.53
[11/40] β=1.00  MSE_elem=0.9535  RMSE=0.9765  KL_per_samp=17.45
[12/40] β=1.00  MSE_elem=0.9504  RMSE=0.9749  KL_per_samp=17.79
[13/40] β=1.00  MSE_elem=0.9491  RMSE=0.9742  KL_per_samp=18.10
[14/40] β=1.00  MSE_elem=0.9473  RMSE=0.9733  KL_per_samp=18.40
[15/40] β=1.00  MSE_elem=0.9460  RMSE=0.9726  KL_per_samp=1

KeyboardInterrupt: 

In [12]:
best = min(results, key=lambda x: x[-1])
print("\n=== Best config ===")
print(f"Activation:   {best[0]}")
print(f"Dropout:      {best[1]}")
print(f"Latent dim:   {best[2]}")
print(f"MSE per elt.: {best[3]:.4f}")


>>> Best hyperparams: (1, 400, 50, 'lrelu', 0.0) with per-element MSE = 0.45979664098698225
