In [2]:
!pip install pandas
!pip install pytorch
!pip install scikit-learn

Collecting pandas
  Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (89 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Collecting pytz>=2020.1 (from pandas)
  Downloading pytz-2025.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting tzdata>=2022.7 (from pandas)
  Downloading tzdata-2025.2-py2.py3-none-any.whl.metadata (1.4 kB)
Downloading pandas-2.2.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m00:01[0m:00:01[0m
[?25hDownloading pytz-2025.2-py2.py3-none-any.whl (509 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m509.2/509.2 kB[0m [31m134.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tzdata-2025.2-py2.py3-none-any.whl (347 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m347.8/347.8 kB

In [19]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CVAE(nn.Module):
    def __init__(self, input_dim, hidden_dim_1 = 1024, hidden_dim_2 = 512, latent_dim = 10, n_classes = 3):
        super(CVAE, self).__init__()
        #Encoder
        self.fc1 = nn.Linear(input_dim + n_classes, hidden_dim_1)
        self.fc2 = nn.Linear(hidden_dim_1, hidden_dim_2) 
        self.fc_logvar = nn.Linear(hidden_dim_2, latent_dim) #variance
        self.fc_mu = nn.Linear(hidden_dim_2, latent_dim) #mean
        #Decoder
        self.fc3 = nn.Linear(latent_dim+n_classes, hidden_dim_2)
        self.fc4 = nn.Linear(hidden_dim_2, hidden_dim_1)
        self.fc5 = nn.Linear(hidden_dim_1, input_dim)

    def encoder(self, x, y):
        concat_input = torch.concat([x,y], 1) 
        h = F.relu(self.fc1(concat_input))
        h = F.relu(self.fc2(h)) 
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterization(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu) 

    def decoder(self, z,y): 
        concat_input = torch.concat([z, y], 1)
        h = F.relu(self.fc3(concat_input))
        h = F.relu(self.fc4(h))
        return F.sigmoid(self.fc5(h))
    
    def forward(self, x, y):
        mu, logvar = self.encoder(x,y)
        z = self.reparameterization(mu, logvar)
        return self.decoder(z, y), mu, logvar

In [21]:
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split


df = pd.read_csv("CpGSitesWithLabel.csv")
sample_idx = df["Unnamed: 0"].values
df.index = sample_idx
df = df.drop(columns = "Unnamed: 0")

X = df.iloc[:, :-1].values.astype('float32') 
y = df['label'].values.reshape(-1, 1)    

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)

encoder = OneHotEncoder(sparse_output=False)
y_train_onehot = encoder.fit_transform(y_train).astype('float32')
y_test_onehot = encoder.transform(y_test).astype('float32')

from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train_onehot))
test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test_onehot))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [22]:
def cvae_loss(recon_x, x, mu, logvar, beta=0.5):
    MSE = F.mse_loss(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + beta * KLD, MSE, KLD

In [23]:
import torch.optim as optim


input_dim = X.shape[1]
n_classes = y_onehot.shape[1]

model = CVAE(input_dim=input_dim, n_classes=n_classes).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [24]:
num_epochs = 35000
for epoch in range(num_epochs):
    model.train()
    total_loss, total_mse, total_kld = 0, 0, 0
    beta = min(1.0, epoch / 1000)  # KL warmup

    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        recon, mu, logvar = model(x_batch, y_batch)
        loss, mse, kld = cvae_loss(recon, x_batch, mu, logvar, beta=beta)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_mse += mse.item()
        total_kld += kld.item()

    if (epoch + 1) % 5000 == 0:
        print(f"[Epoch {epoch+1}] Train Loss: {total_loss:.2f}, MSE: {total_mse:.2f}, KLD: {total_kld:.2f}, β: {beta:.2f}")

        # ==== Test Evaluation ====
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                x_batch = x_batch.to(device)
                y_batch = y_batch.to(device)
                recon, mu, logvar = model(x_batch, y_batch)
                loss, _, _ = cvae_loss(recon, x_batch, mu, logvar, beta=beta)
                test_loss += loss.item()
        print(f"         → Test Loss: {test_loss:.2f}")


[Epoch 5000] Train Loss: 32.00, MSE: 25.90, KLD: 6.10, β: 1.00
         → Test Loss: 10.92
[Epoch 10000] Train Loss: 32.08, MSE: 24.09, KLD: 7.99, β: 1.00
         → Test Loss: 11.84
[Epoch 15000] Train Loss: 38.58, MSE: 30.18, KLD: 8.40, β: 1.00
         → Test Loss: 10.52
[Epoch 20000] Train Loss: 30.29, MSE: 21.78, KLD: 8.50, β: 1.00
         → Test Loss: 10.85
[Epoch 25000] Train Loss: 35.78, MSE: 28.51, KLD: 7.27, β: 1.00
         → Test Loss: 10.67
[Epoch 30000] Train Loss: 35.29, MSE: 26.44, KLD: 8.86, β: 1.00
         → Test Loss: 10.97
[Epoch 35000] Train Loss: 31.57, MSE: 23.28, KLD: 8.29, β: 1.00
         → Test Loss: 10.70


In [25]:
torch.save(model.state_dict(), 'cvae_weights.pt')

In [26]:
model.eval()

# 3 classes: 0 = Healthy, 1 = CML, 2 = AML
class_labels = torch.eye(3).to(device)  # shape = (3, 3)

# Choose how many samples per class to generate
n_samples = 100


In [30]:
generated_all = []
import numpy as np

for class_index in range(3):  # for each class
    # ① Sample latent vectors z ~ N(0, I)
    z = torch.randn(n_samples, 10).to(device)
    
    # ② Repeat the one-hot class label
    y = class_labels[class_index].unsqueeze(0).repeat(n_samples, 1)  # shape = (n_samples, 3)
    
    # ③ Decode
    with torch.no_grad():
        generated = model.decoder(z, y).cpu().numpy()
    
    # ④ Add class label
    labeled_generated = np.hstack([generated, np.full((n_samples, 1), class_index)])
    generated_all.append(labeled_generated)

In [38]:
import pandas as pd
generated_all = np.vstack(generated_all)

columns = list(df.columns.values)
df_augmented = pd.DataFrame(generated_all, columns=columns)

# View the result
df_augmented

Unnamed: 0,cg22081084,cg03797768,cg25152348,cg23959187,cg03909902,cg13003239,cg14047339,cg08206623,cg03705947,cg20913106,...,cg22250642,cg18316234,cg14094063,cg26481784,cg10456628,cg17891715,cg20591167,cg03989244,cg10967101,label
0,0.066301,0.065529,0.233929,0.061816,0.149184,0.307091,0.226935,0.215911,0.067947,0.050109,...,0.017428,0.014454,0.023493,0.018369,0.712667,0.155970,0.066862,0.073908,0.275129,0.0
1,0.103899,0.106901,0.369763,0.090299,0.259406,0.425041,0.337653,0.358039,0.101214,0.091504,...,0.019757,0.019054,0.034926,0.022159,0.807066,0.189010,0.086575,0.061659,0.307003,0.0
2,0.067198,0.066690,0.237451,0.062482,0.152344,0.310289,0.230144,0.219824,0.068705,0.051274,...,0.017380,0.014496,0.023656,0.018350,0.716559,0.156896,0.067211,0.073502,0.276008,0.0
3,0.128893,0.134651,0.463158,0.107818,0.338809,0.502641,0.411509,0.455417,0.122080,0.120224,...,0.019910,0.021325,0.042165,0.023438,0.872248,0.211549,0.099567,0.049882,0.330377,0.0
4,0.124475,0.128714,0.443212,0.104791,0.321147,0.486729,0.395556,0.435724,0.119091,0.114668,...,0.020338,0.020975,0.040524,0.023539,0.853199,0.205521,0.095980,0.053008,0.323315,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,0.059151,0.060964,0.268311,0.054783,0.171481,0.333043,0.243688,0.251328,0.056624,0.053960,...,0.014772,0.012676,0.026222,0.017684,0.739669,0.144003,0.057945,0.074494,0.267058,2.0
296,0.060289,0.062135,0.270571,0.055790,0.173628,0.334927,0.245769,0.253704,0.057729,0.055047,...,0.015090,0.012983,0.026728,0.018040,0.739699,0.145382,0.058862,0.075047,0.268248,2.0
297,0.058515,0.060316,0.266721,0.054223,0.170008,0.331668,0.242261,0.249635,0.056010,0.053336,...,0.014649,0.012548,0.026007,0.017551,0.739218,0.143291,0.057510,0.074452,0.266414,2.0
298,0.057988,0.059764,0.265199,0.053968,0.169126,0.330240,0.240982,0.248070,0.055515,0.052995,...,0.014622,0.012497,0.025939,0.017531,0.737237,0.142515,0.057097,0.074771,0.265523,2.0


In [39]:
df

Unnamed: 0,cg22081084,cg03797768,cg25152348,cg23959187,cg03909902,cg13003239,cg14047339,cg08206623,cg03705947,cg20913106,...,cg22250642,cg18316234,cg14094063,cg26481784,cg10456628,cg17891715,cg20591167,cg03989244,cg10967101,label
GSM1548150_5730504006_R01C02,0.063296,0.083376,0.219855,0.060166,0.12968,0.28666,0.226909,0.196601,0.064553,0.046237,...,0.016432,0.01456,0.023253,0.019152,0.854583,0.179352,0.060282,0.098393,0.291366,0
GSM1548151_5730504006_R02C02,0.074654,0.081414,0.216102,0.066239,0.149205,0.306048,0.243436,0.207915,0.068971,0.05484,...,0.017135,0.013192,0.022788,0.016011,0.86494,0.177226,0.075187,0.080655,0.321814,0
GSM1548152_5730504006_R03C02,0.066592,0.109882,0.241437,0.059483,0.13572,0.287089,0.207555,0.182915,0.063726,0.047902,...,0.015622,0.013525,0.018745,0.01931,0.836266,0.178733,0.058835,0.078219,0.313415,0
GSM1548156_5730504020_R01C02,0.077961,0.041466,0.212549,0.049237,0.120039,0.249048,0.158421,0.173888,0.067007,0.043966,...,0.019059,0.013901,0.026665,0.019084,0.507866,0.125064,0.064233,0.055108,0.23992,0
GSM1548157_5730504020_R02C02,0.063508,0.05684,0.221196,0.056048,0.132497,0.302168,0.201435,0.227935,0.064513,0.051435,...,0.01907,0.014687,0.021116,0.018752,0.613723,0.126915,0.056562,0.061516,0.285479,0
GSM1548158_5730504020_R03C02,0.074771,0.060016,0.221104,0.067749,0.148899,0.256944,0.186309,0.194322,0.07526,0.044821,...,0.018141,0.013704,0.031082,0.019307,0.691876,0.158018,0.081373,0.064076,0.25368,0
GSM1548162_5730504028_R01C02,0.052506,0.070985,0.220845,0.057448,0.132664,0.281852,0.214527,0.183626,0.0612,0.035159,...,0.016724,0.014269,0.024371,0.018952,0.633654,0.130587,0.092352,0.075456,0.259461,0
GSM1548163_5730504028_R02C02,0.065304,0.044097,0.255395,0.045015,0.121226,0.346152,0.191549,0.183681,0.061833,0.036429,...,0.01312,0.013214,0.020083,0.016073,0.642787,0.131127,0.061192,0.046134,0.260914,0
GSM1548164_5730504028_R03C02,0.067019,0.055915,0.223882,0.069789,0.156035,0.316202,0.224603,0.224718,0.065567,0.046881,...,0.016858,0.01466,0.022778,0.018063,0.724699,0.142256,0.061967,0.08054,0.252822,0
GSM1548168_5730504029_R01C02,0.063748,0.046497,0.273614,0.058261,0.137743,0.384108,0.285818,0.250121,0.098927,0.056316,...,0.017235,0.016054,0.020591,0.020055,0.742656,0.160548,0.059541,0.08155,0.280494,0


In [40]:
df_augmented.to_csv("aug_CpGSample.csv") 