In [12]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

pd.set_option('display.float_format', lambda x: '%.5f' % x)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
df = pd.read_csv('../data/sample.csv', index_col=0)
df.reset_index(drop=True, inplace=True)
print(f'Shape of the dataframe: {df.shape}')
target = 'Stability'
X = df.drop(target, axis=1)
y = df[target]

# Normalization
scaler = StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X), columns=X.columns)

X = X.astype(np.float32)
y = y.astype(np.int64)

Shape of the dataframe: (372, 1411)


In [79]:
import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_hidden = 128
        self.encoder = nn.Sequential(
            nn.Linear(1410, 516),  # input size: 784, output size: 516
            nn.ReLU(),  # apply the ReLU activation function
            nn.Linear(516, self.num_hidden),  # input size: 516, output size: num_hidden
            nn.ReLU(),  # apply the ReLU activation function
        )
        self.decoder = nn.Sequential(
            nn.Linear(self.num_hidden, 516),  # input size: num_hidden, output size: 516
            nn.ReLU(),  # apply the ReLU activation function
            nn.Linear(516, 1410),  # input size: 516, output size: 784
            nn.Sigmoid(),  # apply the sigmoid activation function to compress the output to a range of (0, 1)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
    
X_train = torch.from_numpy(X.values)
model = AutoEncoder()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.MSELoss()
model.to('cpu')
train_loader = torch.utils.data.DataLoader(X_train, batch_size=12, shuffle=True)

EPOCHS = 200
for epoch in range(EPOCHS):
    for batch_idx, data in enumerate(train_loader):
        data = data.to('cpu')
        encoded, decoded = model(data)
        loss = criterion(decoded, data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if epoch % 10 == 0:
        print(f'Epoch [{epoch}/{EPOCHS}], Loss: {loss.item():.4f}')
            

Epoch [0/200], Loss: 1.4465
Epoch [10/200], Loss: 0.8531
Epoch [20/200], Loss: 0.8378
Epoch [30/200], Loss: 0.7261
Epoch [40/200], Loss: 0.7659
Epoch [50/200], Loss: 0.6877
Epoch [60/200], Loss: 0.6325
Epoch [70/200], Loss: 0.6627
Epoch [80/200], Loss: 0.6469
Epoch [90/200], Loss: 0.6006
Epoch [100/200], Loss: 0.6216
Epoch [110/200], Loss: 0.8733
Epoch [120/200], Loss: 0.6254
Epoch [130/200], Loss: 0.5846
Epoch [140/200], Loss: 0.6240
Epoch [150/200], Loss: 0.6137
Epoch [160/200], Loss: 0.6334
Epoch [170/200], Loss: 0.6163
Epoch [180/200], Loss: 0.5820
Epoch [190/200], Loss: 0.5956


In [80]:
gen_X = model(X_train)[1].detach().numpy()
gen_X = pd.DataFrame(gen_X, columns=X.columns)
gen_X

Unnamed: 0,Gamma_AATA,Gamma_IPPS,Gamma_IPPSm,Gamma_DHQTi,Gamma_ADCS,Gamma_ABTA,Gamma_PGL,Gamma_ACACT1r,Gamma_ACOAHim,Gamma_ACOTAim,...,sigma_km_substrate_ccm2tp,sigma_km_product_ccm2tp,sigma_km_substrate_pca2tp,sigma_km_product_pca2tp,sigma_km_substrate_r2073_1,sigma_km_product_r2073_1,sigma_km_substrate1_r_4235,sigma_km_product1_r_4235,sigma_km_substrate2_r_4235,sigma_km_product2_r_4235
0,0.00000,0.00001,0.00000,0.00001,0.83094,0.00003,0.00006,0.00111,0.00000,0.00000,...,0.00002,0.00000,0.00000,0.00000,0.35805,0.00000,0.99959,0.00000,0.00000,0.00000
1,0.00000,0.00000,0.00003,0.00000,1.00000,0.00003,0.00011,0.00004,0.00000,0.00000,...,0.93780,0.44513,0.00000,0.99927,0.00022,0.99958,0.58766,0.00218,0.00189,0.00004
2,0.00000,0.00002,0.00000,0.00002,0.93840,0.00002,0.00006,0.00067,0.00000,0.00000,...,0.00002,0.00000,0.00000,0.00000,0.35089,0.00000,0.99958,0.00000,0.00000,0.00000
3,0.00000,0.00000,0.00005,0.00000,1.00000,0.00002,0.00010,0.00002,0.00000,0.00000,...,0.93853,0.46261,0.00000,0.99897,0.00024,0.99961,0.61008,0.00188,0.00203,0.00004
4,0.00000,0.00002,0.00002,0.00001,0.87038,0.00021,0.00010,0.00092,0.00000,0.00000,...,0.00000,1.00000,0.00000,0.00000,0.00000,0.00069,0.00001,0.13128,1.00000,0.99999
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
367,0.99649,0.00000,0.00000,0.00000,0.00000,0.00011,0.00032,0.33998,0.00000,0.00000,...,0.00000,0.00000,0.00000,0.70926,0.00000,0.00041,0.00000,0.00000,0.00000,0.97631
368,0.00004,0.00000,0.00000,0.00000,0.00000,0.00004,0.00002,0.26345,0.00000,0.00000,...,0.00000,0.99904,0.00000,0.00000,0.00004,0.00000,0.00000,0.00000,0.00000,0.00000
369,0.97680,0.00000,0.00000,0.00000,0.00000,0.00043,0.00069,0.34696,0.00000,0.00000,...,0.00000,0.99932,0.00000,0.00039,0.00000,0.99998,0.00000,0.00005,1.00000,1.00000
370,0.00000,0.00000,0.00000,0.00000,0.00000,0.00001,0.00006,0.31991,0.00000,0.00000,...,0.00000,0.99990,0.00000,0.00000,0.00039,0.00000,0.28188,0.00000,0.38899,0.00000


In [81]:
X

Unnamed: 0,Gamma_AATA,Gamma_IPPS,Gamma_IPPSm,Gamma_DHQTi,Gamma_ADCS,Gamma_ABTA,Gamma_PGL,Gamma_ACACT1r,Gamma_ACOAHim,Gamma_ACOTAim,...,sigma_km_substrate_ccm2tp,sigma_km_product_ccm2tp,sigma_km_substrate_pca2tp,sigma_km_product_pca2tp,sigma_km_substrate_r2073_1,sigma_km_product_r2073_1,sigma_km_substrate1_r_4235,sigma_km_product1_r_4235,sigma_km_substrate2_r_4235,sigma_km_product2_r_4235
0,-1.15502,-1.05570,-0.07834,-0.95020,2.14162,-1.22708,-0.82135,-1.95823,-0.08789,-1.07325,...,-0.15560,-0.83392,-1.29568,-0.78388,0.35996,0.56465,1.39888,-0.85900,-0.75376,0.04081
1,-1.15502,-1.05570,-0.07834,-0.95020,2.14162,-1.22708,-0.82135,-1.95823,-0.08789,-1.07325,...,0.93404,0.45383,-0.64452,1.10484,-1.58496,1.50831,0.60724,-0.07813,-0.18908,0.00545
2,-1.16600,-1.04058,-0.07834,-1.00039,2.16442,-1.22331,-0.72483,-1.94499,-0.08789,-1.05804,...,-0.15560,-0.83392,-1.29568,-0.78388,0.35996,0.56465,1.39888,-0.85900,-0.75376,0.04081
3,-1.16600,-1.04058,-0.07834,-1.00039,2.16442,-1.22331,-0.72483,-1.94499,-0.08789,-1.05804,...,0.93404,0.45383,-0.64452,1.10484,-1.58496,1.50831,0.60724,-0.07813,-0.18908,0.00545
4,-1.16600,-1.04058,-0.07834,-1.00039,2.16442,-1.22331,-0.72483,-1.94499,-0.08789,-1.05804,...,-1.00163,1.42485,0.44313,-0.98671,0.06677,-1.06887,-1.11711,0.15654,1.08306,1.10362
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
367,2.90041,-1.23083,-0.07465,1.74067,-1.46207,-3.38558,-1.30585,0.34653,-0.08789,2.34649,...,0.53647,-0.00018,-0.89084,0.67510,-1.09354,-0.03246,0.98505,-0.99212,1.48204,0.88742
368,2.90041,-1.23083,-0.07465,1.74067,-1.46207,-3.38558,-1.30585,0.34653,-0.08789,2.34649,...,1.46449,1.29931,1.11685,1.65453,-1.66489,-0.96976,0.02322,-0.01580,-0.26979,-0.95833
369,2.90041,-1.23083,-0.07465,1.74067,-1.46207,-3.38558,-1.30585,0.34653,-0.08789,2.34649,...,-1.05140,1.15323,0.94377,-0.92009,-1.52985,1.14493,-1.85184,-1.51292,1.02730,1.19431
370,2.90041,-1.23083,-0.07465,1.74067,-1.46207,-3.38558,-1.30585,0.34653,-0.08789,2.34649,...,0.82920,1.47659,-0.87015,-0.80147,-0.53493,-1.18567,0.37712,-1.24713,0.31335,1.30485
