In [1]:
from models import TransformerM, CVAE
import torch
import torch.nn as nn
from torch.nn import functional as F
#from train import train_CVAE
import numpy as np
from data.utils import load_data


In [2]:
molecules_input, molecules_output, char, vocab, labels, length = load_data('data/smiles_prop.txt',120)

In [3]:
num_train_data = int(len(molecules_input)*0.75)
train_molecules_input = molecules_input[0:num_train_data]
test_molecules_input = molecules_input[num_train_data:-1]

train_molecules_output = molecules_output[0:num_train_data]
test_molecules_output = molecules_output[num_train_data:-1]

train_labels = labels[0:num_train_data]
test_labels = labels[num_train_data:-1]

train_length = length[0:num_train_data]
test_length = length[num_train_data:-1]

In [4]:
n = np.random.randint(len(train_molecules_input), size = 4)
x = nn.functional.one_hot(torch.tensor([train_molecules_input[i] for i in n], dtype=torch.int64), num_classes=len(vocab))
y = torch.tensor([train_molecules_output[i] for i in n], dtype=torch.int64)
l = torch.tensor(np.array([train_length[i] for i in n]), dtype=torch.int64)
c = torch.tensor(np.array([train_labels[i] for i in n]).astype(float),dtype=torch.float).unsqueeze(1)

  x = nn.functional.one_hot(torch.tensor([train_molecules_input[i] for i in n], dtype=torch.int64), num_classes=len(vocab))


In [68]:
packed_x_embed = torch.nn.utils.rnn.pack_padded_sequence(input= x_embed, lengths=l, batch_first=True, enforce_sorted=False)

In [69]:
packed_x_embed

PackedSequence(data=tensor([[-2.4341,  1.3255, -0.3168,  ..., -0.7388, -0.8555, -0.0082],
        [-2.4341,  1.3255, -0.3168,  ..., -0.7388, -0.8555, -0.0082],
        [-2.4341,  1.3255, -0.3168,  ..., -0.7388, -0.8555, -0.0082],
        ...,
        [ 0.7519, -0.2907, -0.1540,  ..., -0.3306, -0.1111, -0.8612],
        [ 0.7519, -0.2907, -0.1540,  ..., -0.3306, -0.1111, -0.8612],
        [-1.0513,  0.1108,  0.8604,  ...,  1.5877,  0.4810, -0.8234]],
       grad_fn=<PackPaddedSequenceBackward0>), batch_sizes=tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1,
        1, 1]), sorted_indices=tensor([0, 2, 3, 1]), unsorted_indices=tensor([0, 3, 1, 2]))

In [49]:
class Embedding(nn.Module):
    def __init__(self, emb_dim=300, num_emb=35):
        super(Embedding,self).__init__()
        self.emb =nn.Embedding(embedding_dim=emb_dim, num_embeddings=num_emb)

    def forward(self,x):
        return self.emb(x.argmax(dim=-1))

In [9]:
class Encoder(nn.Module):
    def __init__(self, input_dim=35, emb_dim=300, num_emb=35, hidden_units=1024, num_layers=3, seq_len=120, cond_dim=3):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_units
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.cond_dim = cond_dim

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_units,
            num_layers = num_layers,
            batch_first=True,
            bidirectional=False,
        )

        self.emb = Embedding(emb_dim=emb_dim,
                             num_emb=num_emb)

    def forward(self, x, c):
        # x: tensor of shape (batch_size, seq_length, hidden_size)
        x_emb = self.emb(x)
        c = torch.nn.functional.interpolate(c.unsqueeze(1), size=(self.seq_len, self.cond_dim), mode='nearest').squeeze(1)
        x_emb = torch.cat([x_emb,c], dim=-1)
        outputs, (hidden, cell) = self.lstm(x_emb)
        return outputs, (hidden, cell)

In [10]:
class Parametrizator(nn.Module):
    def __init__(self,hidden_units, latent_dim, num_layers):
        super(Parametrizator, self).__init__()

        self.hidden_size = hidden_units
        self.latent_size = latent_dim
        self.lstm_factor = num_layers

        self.mean = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)
        self.log_variance = torch.nn.Linear(in_features= self.hidden_size * self.lstm_factor, out_features= self.latent_size)

    def reparametize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)

        z = mu + noise * std
        return z

    def forward(self, hid_state):
        enc_h = hid_state.view(hid_state.shape[1], self.hidden_size*self.lstm_factor)
        mu = self.mean(enc_h)
        log_var = self.log_variance(enc_h)

        z = self.reparametize(mu,log_var)
        return z, mu, log_var


In [11]:
class Decoder(nn.Module):
    def __init__(self, cond_dim, seq_len, latent_dim=120, hidden_units=1024, num_layers=3):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_units
        self.num_layers = num_layers
        self.latent_size = latent_dim
        self.lstm_factor = num_layers
        self.seq_len = seq_len
        self.cond_dim = cond_dim

        self.init_hidden_decoder = torch.nn.Linear(in_features= self.latent_size, out_features= self.hidden_size)

        self.lstm = nn.LSTM(
            input_size=latent_dim+cond_dim,
            hidden_size=hidden_units,
            num_layers = num_layers,
            batch_first=True,
            bidirectional=False,
        )

    def forward(self, z, c):
        c = torch.nn.functional.interpolate(c.unsqueeze(1), size=(self.seq_len, self.cond_dim), mode='nearest').squeeze(1)

        z_inp = z.repeat(1, self.seq_len, 1)
        hidden = z.repeat(1,self.num_layers,1)

        batch_size = c.shape[0]

        z_inp = z_inp.view(batch_size, self.seq_len, self.latent_size)
        hidden = hidden.view(self.num_layers, batch_size, self.latent_size)

        hidden_decoder = self.init_hidden_decoder(hidden)
        hidden_decoder = (hidden_decoder, hidden_decoder)

        z_inp = torch.cat([z_inp,c], dim=-1)

        outputs, (hidden, cell) = self.lstm(z_inp,hidden_decoder)

        return outputs

In [12]:
class Predictor(nn.Module):
    def __init__(self, hidden_units, classes):
        super(Predictor, self).__init__()
        self.hidden_units = hidden_units
        self.classes = classes

        self.fc1 = nn.Linear(hidden_units, 256)
        self.fc2 = nn.Linear(256,128)
        self.fc3 = nn.Linear(128,classes)

        self.predictor = nn.Sequential(self.fc1,self.fc2,self.fc3)

    def forward(self, x):
        return self.predictor(x)

In [13]:
class CVAE(nn.Module):
    def __init__(self, cond_dim = 3, hidden_units = 512, num_layers = 3, emb_dim = 300, latent_dim = 256,
                 vocab_size = len(vocab), seq_len = 120):
        super(CVAE, self).__init__()
        self.cond_dim = cond_dim
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.emb_dim = emb_dim
        self.latent_dim = latent_dim
        self.num_emb = vocab_size
        self.seq_len = seq_len

        self.enc = Encoder(input_dim=emb_dim + cond_dim,
                           emb_dim=emb_dim,
                           num_emb=vocab_size,
                           hidden_units=hidden_units,
                           num_layers=num_layers,
                           seq_len=seq_len,
                           cond_dim=cond_dim)

        self.param = Parametrizator(hidden_units=hidden_units,
                                    latent_dim=latent_dim,
                                    num_layers=num_layers)

        self.dec = Decoder(cond_dim=cond_dim,
                           seq_len=seq_len,
                           latent_dim=latent_dim,
                           hidden_units=hidden_units,
                           num_layers=num_layers)

        self.pred = Predictor(hidden_units=hidden_units,
                              classes=self.num_emb)

    def forward(self, x, c):
        #encoding
        out, state= self.enc(x,c)

        #parametrization
        z, mu, log_var = self.param(state[0])

        #decoding
        out = self.dec(z,c)

        return self.pred(out).argmax(dim=-1), mu, log_var

    def sample(self, z,c):
        c = torch.nn.functional.interpolate(c.unsqueeze(1), size=(self.seq_len, self.cond_dim), mode='nearest').squeeze(1)
        out = self.dec(z,c)
        return self.pred(out).argmax(dim=-1)


In [None]:
class LSTMVAE(nn.Module):
    """LSTM-based Variational Auto Encoder"""

    def __init__(
        self, seq_size, cond_size, hidden_size, latent_size, device=torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    ):
        """
        input_size: int, batch_size x sequence_length x input_dim
        hidden_size: int, output size of LSTM AE
        latent_size: int, latent z-layer size
        num_lstm_layer: int, number of layers in LSTM
        """
        super(LSTMVAE, self).__init__()
        self.device = device

        # dimensions
        self.input_size = seq_size + cond_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_layers = 1

        #embeddings
        self.emb = nn.Embedding(num_embeggings = self.vocab_size, embedding_dim=self.embed_size)
        # lstm ae
        self.lstm_enc = Encoder(
            seq_size=seq_size, cond_size = cond_size, hidden_size=hidden_size, num_layers=self.num_layers
        )
        self.lstm_dec = Decoder(
            input_size=latent_size + cond_size,
            output_size=seq_size,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
        )

        self.fc21 = nn.Linear(self.hidden_size, self.latent_size)
        self.fc22 = nn.Linear(self.hidden_size, self.latent_size)
        self.fc3 = nn.Linear(self.latent_size, self.hidden_size)

    def reparametize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)

        z = mu + noise * std
        return z

    def forward(self, x, c):
        batch_size, seq_len, feature_dim = x.shape

        # encode input space to hidden space
        enc_hidden = self.lstm_enc(x, c)
        enc_h = enc_hidden[0].view(batch_size, self.hidden_size)

        # extract latent variable z(hidden space to latent space)
        mean = self.fc21(enc_h)
        logvar = self.fc22(enc_h)
        z = self.reparametize(mean, logvar)  # batch_size x latent_size

        # initialize hidden state as inputs
        h_ = self.fc3(z).unsqueeze(0)

        # decode latent space to input space
        z = z.repeat(1, seq_len, 1)
        z = z.view(batch_size, seq_len, self.latent_size)

        # initialize hidden state
        hidden = (h_.contiguous(), h_.contiguous())
        reconstruct_output, hidden = self.lstm_dec(z,c,hidden)

        x_hat = reconstruct_output

        # calculate vae loss
        losses = self.loss_function(x_hat, x, mean, logvar)
        m_loss, recon_loss, kld_loss = (
            losses["loss"],
            losses["Reconstruction_Loss"],
            losses["KLD"],
        )

        return x_hat, m_loss

    def loss_function(self, *args, **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args:
        :param kwargs:
        :return:
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        kld_weight = 0.00025  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        loss = recons_loss + kld_weight * kld_loss
        return {
            "loss": loss,
            "Reconstruction_Loss": recons_loss.detach(),
            "KLD": -kld_loss.detach(),
        }

In [5]:
cvae = CVAE()

In [6]:
out = cvae(x,c,l)

torch.Size([4, 256]) torch.Size([4, 256])


In [10]:
def sequence_mask(lengths, maxlen, dtype=torch.int32):
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).cumsum(dim=1).t() > lengths).t()
    mask.type(dtype)
    return mask

def get_losses(y_hat, y, l, mu, logvar, kld_weight=0.0025):
    #weight = sequence_mask(l,y.shape[1])
    #weight = torch.randint(0,1,(120,4))
    loss = nn.CrossEntropyLoss()
    #print(y_hat.shape, torch.permute(y_hat,(0,2,1)).shape, y.shape, weight.shape)
    recons_loss = loss(torch.permute(y_hat,(0,2,1)), y)
    kld_loss = torch.mean(
            -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp(), dim=1), dim=0
        )
    final_loss = recons_loss + kld_weight * kld_loss

    return recons_loss, kld_loss, final_loss

In [11]:
get_losses(out[0],y,l,out[2],out[3])

(tensor(3.5558, grad_fn=<NllLoss2DBackward0>),
 tensor(5.7151, grad_fn=<MeanBackward1>),
 tensor(3.5701, grad_fn=<AddBackward0>))