In [None]:
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

from madrigal.models.models import MLPEncoder
from madrigal.utils import DATA_DIR

In [None]:
class VAE(nn.Module):
    """Variational autoencoder assembled from two ``MLPEncoder`` stacks.

    The encoder output is passed through a ReLU and then through two linear
    heads that produce the mean and the log-variance of a diagonal Gaussian
    posterior. A latent sample is drawn with the reparameterisation trick and
    decoded by a second ``MLPEncoder`` instance.

    Args:
        vae_encoder_params: keyword arguments forwarded to ``MLPEncoder`` for
            the encoder. Its output width has to equal ``hidden_dim`` because
            both linear heads take ``hidden_dim`` input features.
        hidden_dim: input width of the ``fc_mu`` / ``fc_var`` heads.
        latent_dim: width of the latent code produced by the heads.
        vae_decoder_params: keyword arguments forwarded to ``MLPEncoder`` for
            the decoder.
    """

    def __init__(self, vae_encoder_params: dict, hidden_dim: int, latent_dim: int, vae_decoder_params: dict):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.encoder = MLPEncoder(**vae_encoder_params)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        # Despite its name this head is consumed as a log-variance
        # (see ``reparameterize`` and ``loss``); the attribute name is kept so
        # existing state dicts still load.
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        self.decoder = MLPEncoder(**vae_decoder_params)  # while it is an MLPEncoder object, it is actually a decoder

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h = F.relu(self.encoder(x))
        return self.fc_mu(h), self.fc_var(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent code back to input space (no output non-linearity)."""
        # TODO: validate the range of perturbation values; squashing the
        # output with torch.tanh was considered but is not applied.
        return self.decoder(z)

    def forward(self, x):
        """Return ``(z, recon, mu, logvar)`` for a batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return z, recon, mu, logvar

    def loss(self, recons, x, mu, logvar, kld_weight: float = 1.0):
        """Negative-ELBO style objective: MSE reconstruction plus closed-form KL.

        The reconstruction term is ``F.mse_loss`` (mean over every element),
        i.e. a squared-error term rather than a binary cross-entropy. The KL
        term is the analytic divergence between ``N(mu, exp(logvar))`` and
        ``N(0, I)``, summed over latent dimensions and averaged over the batch.

        Args:
            recons: reconstruction produced by ``decode``.
            x: target with the same shape as ``recons``.
            mu: posterior means, indexed ``(batch, latent)``.
            logvar: posterior log-variances, same shape as ``mu``.
            kld_weight: multiplier on the KL term. The default of ``1.0``
                reproduces the unweighted sum; smaller values compensate for
                the KL term being summed while the MSE term is averaged.

        Returns:
            Scalar tensor ``recons_loss + kld_weight * kld_loss``.
        """
        recons_loss = F.mse_loss(recons, x)
        kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)

        return recons_loss + kld_weight * kld_loss

In [21]:
class AE(nn.Module):
    """Plain autoencoder assembled from two ``MLPEncoder`` stacks.

    Args:
        ae_encoder_params: keyword arguments forwarded to ``MLPEncoder`` for
            the encoder.
        ae_decoder_params: keyword arguments forwarded to ``MLPEncoder`` for
            the decoder.
    """

    def __init__(self, ae_encoder_params: dict, ae_decoder_params: dict):
        super().__init__()
        self.encoder = MLPEncoder(**ae_encoder_params)
        self.decoder = MLPEncoder(**ae_decoder_params)  # while it is an MLPEncoder object, it is actually a decoder

    def encode(self, x):
        """Return the ReLU-activated encoder output (the latent code) for ``x``."""
        return F.relu(self.encoder(x))

    def decode(self, z):
        """Map a latent code back to input space (no output non-linearity)."""
        # TODO: validate the range of perturbation values; squashing the
        # output with torch.tanh was considered but is not applied.
        return self.decoder(z)

    def forward(self, x):
        """Return ``(h, recon)``: the latent code and its reconstruction."""
        # ``encode`` already applies the ReLU; a second ReLU here would be a
        # no-op, so the code is used as returned.
        h = self.encode(x)
        recon = self.decode(h)
        return h, recon

    def loss(self, recons, x):
        """Mean squared error between ``recons`` and ``x`` (mean over all elements)."""
        return F.mse_loss(recons, x)

In [None]:
# Load the feature table used to train the autoencoder. Later cells intersect
# its column labels with the test-drug SMILES, so columns are treated as
# samples. NOTE(review): presumably one column per compound keyed by SMILES — confirm.
cv = pd.read_csv(DATA_DIR+'views_features_new/cv/cv_cp_data.csv')

In [None]:
# Test-set tables of the DrugBank "split_by_drugs_random" split (paths as given);
# the `head` / `tail` columns of these frames are read in the next cell.
test_between = pd.read_csv(DATA_DIR+'polypharmacy/DrugBank/split_by_drugs_random/test_between_df.csv')
test_within = pd.read_csv(DATA_DIR+'polypharmacy/DrugBank/split_by_drugs_random/test_within_df.csv')

# Drug metadata; only the `canonical_smiles` column is kept, so `all_smiles`
# ends up as a Series carrying the CSV's default row index.
all_smiles = pd.read_csv(DATA_DIR+'views_features/combined_metadata_reindexed_ddi.csv')
all_smiles = all_smiles.canonical_smiles

In [6]:
# Gather the unique drug ids that occur in the test tables.
# NOTE(review): `test_between['tail']` is not included — presumably only the
# head of a "between" pair is a held-out drug; confirm against the split code.
# NOTE(review): the name `test` is rebound to a function in the training cell.
test = np.concatenate((test_between['head'].values, test_within['head'].values, test_within['tail'].values ))
test = np.unique(test)
# Label-based lookup: assumes the head/tail ids are index labels of `all_smiles` — TODO confirm.
test_smiles = all_smiles.loc[test].values

# Remove every column whose label is a test-drug SMILES so the autoencoder is
# not trained on test drugs, then drop any column whose label matches 'Unnamed'.
overlap = np.intersect1d(cv.columns, test_smiles)
cv = cv[cv.columns.difference(overlap)]
cv = cv[cv.columns.drop(list(cv.filter(regex='Unnamed')))]

# Written with the row index as first column; CVDataset reads it back with index_col=0.
cv.to_csv('cv_no_test.csv')

In [7]:
from torch.utils.data import Dataset
import torch

class CVDataset(Dataset):
    """Column-wise dataset over a CSV table: every column is one sample."""

    def __init__(self, data_file):
        # The first CSV column is consumed as the row index, so it is never
        # served as a sample.
        self.df = pd.read_csv(data_file, index_col=0)

    def __len__(self) -> int:
        # Number of samples == number of columns.
        _, n_columns = self.df.shape
        return n_columns

    def __getitem__(self, index):
        # Positional column selection; all rows of that column form the sample.
        column = self.df.iloc[:, index]
        return torch.tensor(column)

In [8]:
dataset = CVDataset('cv_no_test.csv')

# 80 % of the samples go to training, the remainder is held out.
train_length = int(0.8 * len(dataset))
test_length = len(dataset) - train_length

# Seed the split with a dedicated generator so the same samples are held out
# on every Restart & Run All.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    (train_length, test_length),
    generator=torch.Generator().manual_seed(0),
)

# Shuffle the training batches each epoch; keep evaluation order fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)

In [23]:
# Preferred device. NOTE(review): `device` is not referenced by any other
# visible cell, so model and batches stay on their default device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Mirror-image MLP configurations: 559 -> 512 -> 256 -> 128 and back.
encoder_params = {
    'in_dim': 559,
    'hidden_dims': [512, 256],
    'output_dim': 128,
    'p': 0.2,
    'norm': 'bn',
    'actn': 'relu',
    'order': 'nd',
}
decoder_params = {
    'in_dim': 128,
    'hidden_dims': [256, 512],
    'output_dim': 559,
    'p': 0.2,
    'norm': 'bn',
    'actn': 'relu',
    'order': 'nd',
}

model = AE(ae_encoder_params=encoder_params, ae_decoder_params=decoder_params)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): `loss_fn` is not used by the visible training code, which
# calls `model.loss` (mean-reduced MSE) instead.
loss_fn = nn.MSELoss(reduction='sum')

def train(x):
    """Run one optimisation step of the global ``model`` on batch ``x``.

    Puts the model in training mode, clears gradients, computes the
    reconstruction loss via ``model.loss``, back-propagates and steps the
    global ``optimizer``. Returns the loss tensor of this batch.
    """
    model.train()
    optimizer.zero_grad()
    _, reconstruction = model(x)
    batch_loss = model.loss(reconstruction, x)
    batch_loss.backward()
    optimizer.step()
    return batch_loss
    
@torch.no_grad()
def test(test_loader):
    """Evaluate the global ``model`` on ``test_loader`` and report the mean loss.

    Each batch is cast to float32 once and that same tensor is used both as
    model input and as reconstruction target, mirroring what the training
    loop feeds to ``train``. The reported number is the unweighted mean of
    the per-batch losses.

    Args:
        test_loader: iterable yielding batches of input tensors.

    Returns:
        The mean of the per-batch losses (also printed as ``Loss: ...``).
    """
    model.eval()
    losses = []
    for x in tqdm(test_loader):
        # Cast once so the target has the same dtype as the reconstruction.
        x = x.float()
        z, recon = model(x)
        loss = model.loss(recon, x)
        losses.append(loss.detach().cpu().numpy())
    mean_loss = np.array(losses).ravel().mean()
    print(f'Loss: {mean_loss}')
    return mean_loss

# Train for 9 epochs (range(1, 10)). After each epoch print the unweighted mean
# of the per-batch training losses, then evaluate on `test_loader`.
for epoch in range(1, 10):
    print(f'Training at epoch {epoch}')
    losses = []
    for batch in tqdm(train_loader):
        # Batches are cast to float32 before being handed to `train`.
        loss = train(batch.float())
        losses.append(loss.detach().cpu().numpy())
    print(f'Loss: {np.array(losses).ravel().mean()}')

    print('Testing')
    test(test_loader)
    
# Only the encoder's weights are written out; the decoder is not saved.
torch.save(model.encoder.state_dict(), 'cv_model_ae.pt')

Training at epoch 1


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.75it/s]


Loss: 0.5102753043174744
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.37s/it]


Loss: 0.34106629117754694
Training at epoch 2


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.83it/s]


Loss: 0.30785658955574036
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/it]


Loss: 0.2573168741283297
Training at epoch 3


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.85it/s]


Loss: 0.29913127422332764
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/it]


Loss: 0.2531611630551279
Training at epoch 4


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.84it/s]


Loss: 0.2922775447368622
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/it]


Loss: 0.26249992536556266
Training at epoch 5


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.81it/s]


Loss: 0.29167234897613525
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.34s/it]


Loss: 0.2506233064018257
Training at epoch 6


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.87it/s]


Loss: 0.28825557231903076
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]


Loss: 0.24849607025472054
Training at epoch 7


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.85it/s]


Loss: 0.28548455238342285
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.32s/it]


Loss: 0.24797853691229146
Training at epoch 8


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.87it/s]


Loss: 0.2831953763961792
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.31s/it]


Loss: 0.2497517443389612
Training at epoch 9


100%|███████████████████████████████████████████████████████████████| 27/27 [00:05<00:00,  4.81it/s]


Loss: 0.2808248996734619
Testing


100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.33s/it]

Loss: 0.24528804807880883





In [24]:
@torch.no_grad()
def plot_test(test_loader):
    """Collect inputs and reconstructions of the global ``model`` for every batch.

    Despite the name nothing is plotted. Returns two NumPy arrays built with
    ``np.array`` over the per-batch lists, so each carries a leading
    batch-index axis: ``(inputs, reconstructions)``.
    """
    model.eval()
    inputs, reconstructions = [], []
    for batch in tqdm(test_loader):
        # The model sees a float32 copy; the stored input keeps its own dtype.
        _, recon = model(batch.float())
        inputs.append(batch.detach().cpu().numpy())
        reconstructions.append(recon.detach().cpu().numpy())
    return np.array(inputs), np.array(reconstructions)

In [25]:
# Rebuild the held-out loader as a single batch holding the whole test set, so
# plot_test yields exactly one batch of inputs and reconstructions.
# NOTE(review): this rebinds `test_loader`; re-running the training cell
# afterwards evaluates with this single-batch loader.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=True)
x,pred = plot_test(test_loader)

100%|█████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.45s/it]


In [26]:
# Join inputs and reconstructions along axis 0 (np.concatenate's default).
# NOTE(review): x and pred each keep a leading per-batch axis from plot_test's
# np.array(...) call — presumably arr[0] is the input and arr[1] the
# reconstruction; confirm.
arr = np.concatenate([x,pred])

In [27]:
# Persist the joined inputs/reconstructions array to the working directory.
np.save('cv_arr.npy', arr)