In [1]:
import torch
from torch import nn, optim
import numpy as np
import mlflow
import pathlib

from pytorch_lightning.core.lightning import LightningModule

In [None]:
import pytorch_lightning
# Display the installed PyTorch Lightning version (this cell has not been executed: In [None]).
# NOTE(review): the first cell imports LightningModule from the legacy module path
# `pytorch_lightning.core.lightning`; confirm that path still exists in this version,
# otherwise import it from the package top level (`from pytorch_lightning import LightningModule`).
pytorch_lightning.__version__

In [2]:
# Enable MLflow autologging for PyTorch.
# NOTE(review): MLflow's PyTorch autologging is documented to hook `pytorch_lightning.Trainer.fit`;
# this notebook trains through the custom `Generator.fit` method (no Trainer), so confirm that
# anything is actually logged — explicit `mlflow.log_metric` calls may be needed.
mlflow.pytorch.autolog()



In [3]:
class Generator(LightningModule):
    """Recurrent (LSTM or GRU) token-by-token SMILES generator.

    Each step embeds one token, feeds it through a stacked recurrent layer and
    projects the output onto the vocabulary. The class supports several operations:
      * teacher-forced scoring (`likelihood`);
      * a policy-gradient update (`PGLoss`);
      * plain sampling (`sample`);
      * two sampling variants (`evolve`, `evolve1`) that mix in the output of
        other networks passed as `crover` / `mutate`;
      * a maximum-likelihood training loop (`fit`).

    Args:
        voc: vocabulary. This class reads `voc.size`, `voc.tk2ix` (with "GO"
            and "EOS" keys) and `voc.max_len`, and calls `voc.check_smiles`.
        embed_size: token embedding dimension.
        hidden_size: hidden size of every recurrent layer.
        is_lstm: build an `nn.LSTM` if True, otherwise an `nn.GRU`.
        lr: learning rate of the Adam optimizer used by `PGLoss`.
        dev_mode: accepted for interface compatibility; currently unused.
        num_layers: number of stacked recurrent layers. The default of 3 is the
            value that was previously hard-coded.
    """

    def __init__(
        self,
        voc,
        embed_size=128,
        hidden_size=512,
        is_lstm=True,
        lr: float = 1e-3,
        dev_mode: bool = False,
        num_layers: int = 3,
    ):
        super().__init__()
        self.voc = voc
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = voc.size

        self.embed = nn.Embedding(voc.size, embed_size)
        self.is_lstm = is_lstm
        rnn_layer = nn.LSTM if is_lstm else nn.GRU
        self.rnn = rnn_layer(
            embed_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        self.linear = nn.Linear(hidden_size, voc.size)
        # Optimizer used by PGLoss; `fit` builds its own so it can take a separate lr.
        self.optim = optim.Adam(self.parameters(), lr=lr)

    def _param_device(self):
        """Device holding the model parameters; fresh tensors are created there."""
        return self.embed.weight.device

    def _start_tokens(self, batch_size):
        """LongTensor of shape (batch_size,) filled with the "GO" token index."""
        return torch.full(
            (batch_size,),
            self.voc.tk2ix["GO"],
            dtype=torch.long,
            device=self._param_device(),
        )

    @staticmethod
    def _unique_row_indices(seqs):
        """Indices of the first occurrence of each distinct row of 2-D `seqs`, in original order."""
        rows = seqs.detach().cpu().numpy()
        _, first_idx = np.unique(rows, axis=0, return_index=True)
        return torch.as_tensor(np.sort(first_idx), dtype=torch.long, device=seqs.device)

    def forward(self, input, h):
        """Run a single decoding step.

        Args:
            input: LongTensor of token indices, shape (batch,).
            h: hidden state as returned by `init_h` (a tuple `(h, c)` for an LSTM).

        Returns:
            `(logits, h_out)`, where logits has shape (batch, voc.size).
        """
        output = self.embed(input.unsqueeze(-1))  # (batch, 1, embed_size)
        output, h_out = self.rnn(output, h)
        output = self.linear(output).squeeze(1)
        return output, h_out

    def init_h(self, batch_size, labels=None):
        """Create a random initial hidden state for `batch_size` sequences.

        Tensors have shape (num_layers, batch_size, hidden_size) and are drawn from
        U[0, 1). If `labels` is given, it is written into feature 0 of the first
        layer's hidden state for every sequence.

        Returns:
            `(h, c)` for an LSTM, `h` for a GRU.
        """
        device = self._param_device()
        h = torch.rand(self.num_layers, batch_size, self.hidden_size, device=device)
        if labels is not None:
            # Index every sequence (`:`); indexing position `batch_size` is always out of range.
            h[0, :, 0] = labels
        if not self.is_lstm:
            return h
        c = torch.rand(self.num_layers, batch_size, self.hidden_size, device=device)
        return h, c

    def likelihood(self, target):
        """Per-token log-likelihood of `target` under teacher forcing.

        Args:
            target: LongTensor of token indices, shape (batch, seq_len).

        Returns:
            Float tensor (batch, seq_len). Entry [b, t] is the log-probability the
            model assigns to target[b, t] given the "GO" token and target[b, :t].
        """
        batch_size, seq_len = target.size()
        target = target.to(self._param_device())
        x = self._start_tokens(batch_size)
        h = self.init_h(batch_size)
        scores = torch.zeros(batch_size, seq_len, device=x.device)
        for step in range(seq_len):
            logits, h = self(x, h)
            log_probs = logits.log_softmax(dim=-1)
            # squeeze(1) rather than squeeze() keeps a 1-D result even for batch_size == 1.
            scores[:, step] = log_probs.gather(1, target[:, step : step + 1]).squeeze(1)
            x = target[:, step]
        return scores

    def PGLoss(self, loader):
        """Policy-gradient updates over `loader`.

        For each `(seq, reward)` batch, take one `self.optim` step that minimizes
        the negative mean of `likelihood(seq) * reward`.
        """
        for seq, reward in loader:
            self.zero_grad()
            score = self.likelihood(seq)
            # Move the rewards next to the scores so GPU-resident models also work.
            reward = torch.as_tensor(reward, device=score.device)
            loss = -(score * reward).mean()
            loss.backward()
            self.optim.step()

    @torch.no_grad()
    def sample(self, batch_size):
        """Sample `batch_size` token sequences of length `voc.max_len`.

        Once a sequence emits "EOS", every later position is filled with "EOS".
        Generation stops early when all sequences have ended. Gradients are not
        tracked because only the sampled indices are returned.
        """
        eos = self.voc.tk2ix["EOS"]
        x = self._start_tokens(batch_size)
        h = self.init_h(batch_size)
        sequences = torch.zeros(
            batch_size, self.voc.max_len, dtype=torch.long, device=x.device
        )
        is_end = torch.zeros(batch_size, dtype=torch.bool, device=x.device)

        for step in range(self.voc.max_len):
            logit, h = self(x, h)
            proba = logit.softmax(dim=-1)
            x = torch.multinomial(proba, 1).view(-1)
            x[is_end] = eos
            sequences[:, step] = x
            is_end |= x == eos
            if is_end.all():
                break
        return sequences

    @torch.no_grad()
    def evolve(self, batch_size, epsilon=0.01, crover=None, mutate=None):
        """Sample sequences while mixing in other networks' distributions.

        At every step, this network's softmax output is modified as follows:
          * if `crover` is given, it is blended with `crover`'s softmax output
            using a per-sequence uniform random weight;
          * if `mutate` is given, each sequence's distribution is replaced by
            `mutate`'s with probability `epsilon`.
        "EOS" handling and early stopping follow `sample`.
        """
        eos = self.voc.tk2ix["EOS"]
        x = self._start_tokens(batch_size)
        device = x.device
        # Hidden states of this (exploitation) network and the two exploration networks.
        h = self.init_h(batch_size)
        h1 = self.init_h(batch_size)
        h2 = self.init_h(batch_size)
        sequences = torch.zeros(batch_size, self.voc.max_len, dtype=torch.long, device=device)
        is_end = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(self.voc.max_len):
            logit, h = self(x, h)
            proba = logit.softmax(dim=-1)
            if crover is not None:
                ratio = torch.rand(batch_size, 1, device=device)
                logit1, h1 = crover(x, h1)
                proba = proba * ratio + logit1.softmax(dim=-1) * (1 - ratio)
            if mutate is not None:
                logit2, h2 = mutate(x, h2)
                is_mutate = torch.rand(batch_size, device=device) < epsilon
                proba[is_mutate, :] = logit2.softmax(dim=-1)[is_mutate, :]
            x = torch.multinomial(proba, 1).view(-1)

            is_end |= x == eos
            x[is_end] = eos
            sequences[:, step] = x
            if is_end.all():
                break
        return sequences

    @torch.no_grad()
    def evolve1(self, batch_size, epsilon=0.01, crover=None, mutate=None):
        """Sampling variant with per-step network switching.

        At each step, with probability 0.5 the `crover` network (if given)
        produces the logits instead of this one; both share the hidden state `h`.
        If `mutate` is given, its softmax output is blended in with a
        per-sequence weight drawn from U[0, epsilon). "EOS" handling and early
        stopping follow `sample`.
        """
        eos = self.voc.tk2ix["EOS"]
        x = self._start_tokens(batch_size)
        device = x.device
        h = self.init_h(batch_size)
        h2 = self.init_h(batch_size)
        sequences = torch.zeros(batch_size, self.voc.max_len, dtype=torch.long, device=device)
        is_end = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(self.voc.max_len):
            # Draw unconditionally so the RNG stream does not depend on `crover`.
            is_change = torch.rand(1) < 0.5
            if crover is not None and is_change:
                logit, h = crover(x, h)
            else:
                logit, h = self(x, h)
            if mutate is not None:
                logit2, h2 = mutate(x, h2)
                ratio = torch.rand(batch_size, 1, device=device) * epsilon
                proba = logit.softmax(dim=-1) * (1 - ratio) + logit2.softmax(dim=-1) * ratio
            else:
                proba = logit.softmax(dim=-1)
            x = torch.multinomial(proba, 1).view(-1)

            x[is_end] = eos
            sequences[:, step] = x
            is_end |= x == eos
            if is_end.all():
                break
        return sequences

    def _validation_loss(self, loader_valid):
        """Mean negative log-likelihood per sequence and per `voc.max_len` position over `loader_valid`.

        Runs without gradients. Returns +inf for an empty loader instead of dividing by zero.
        """
        total_nll, size = 0.0, 0
        with torch.no_grad():
            for batch in loader_valid:
                size += batch.size(0)
                total_nll += -self.likelihood(batch).sum().item()
        if size == 0:
            return float("inf")
        return total_nll / size / self.voc.max_len

    def fit(
        self, loader_train, out: pathlib.Path, loader_valid=None, epochs=100, lr=1e-3
    ):
        """Maximum-likelihood training with periodic sampling-based evaluation.

        Evaluation runs every 10th step, or every step when `loader_valid` is given.
        It samples twice the batch size, drops duplicate sequences, and measures the
        fraction not flagged valid by `voc.check_smiles`. The weights are saved to
        `out.pkg` whenever the validation loss improves (when `loader_valid` is
        given) or, otherwise, whenever that error rate improves. Progress lines and
        the sampled SMILES are written to `out.log`.

        Args:
            loader_train: iterable of LongTensor batches, shape (batch, seq_len).
            out: path stem; the ".log" and ".pkg" suffixes are applied to it.
            loader_valid: optional iterable of validation batches.
            epochs: number of passes over `loader_train`.
            lr: Adam learning rate for this training run.
        """
        optimizer = optim.Adam(self.parameters(), lr=lr)
        best_error = np.inf
        checkpoint = out.with_suffix(".pkg")
        # `with` closes the log even if training raises or is interrupted.
        with open(out.with_suffix(".log"), "w") as log:
            for epoch in range(epochs):
                for step, batch in enumerate(loader_train):
                    optimizer.zero_grad()
                    loss_train = -self.likelihood(batch).mean()
                    loss_train.backward()
                    optimizer.step()
                    if step % 10 != 0 and loader_valid is None:
                        continue

                    with torch.no_grad():
                        # Twice the batch size (`len(batch * 2)` would just equal `len(batch)`).
                        seqs = self.sample(len(batch) * 2)
                        seqs = seqs[self._unique_row_indices(seqs)]
                        smiles, valids = self.voc.check_smiles(seqs)
                        error = 1 - sum(valids) / len(seqs)
                        info = "Epoch: %d step: %d error_rate: %.3f loss_train: %.3f" % (
                            epoch,
                            step,
                            error,
                            loss_train.item(),
                        )
                        if loader_valid is not None:
                            loss_valid = self._validation_loss(loader_valid)
                            if loss_valid < best_error:
                                torch.save(self.state_dict(), checkpoint)
                                best_error = loss_valid
                            info += " loss_valid: %.3f" % loss_valid
                        elif error < best_error:
                            torch.save(self.state_dict(), checkpoint)
                            best_error = error
                    print(info, file=log)
                    for is_valid, smile in zip(valids, smiles):
                        print("%d\t%s" % (is_valid, smile), file=log)
                    # Flush so progress survives an interrupted run.
                    log.flush()

In [6]:
from torch.utils.data import DataLoader

from src.drugexr.config.constants import MODEL_PATH, PROC_DATA_PATH, TEST_RUN
from src.drugexr.data.preprocess import logger
from src.drugexr.data_structs.vocabulary import Vocabulary
from src.drugexr.models.generator import Generator

import pandas as pd

In [12]:
# NOTE(review): `Generator` is whichever binding ran last. The import cell
# (`from src.drugexr.models.generator import Generator`, In [6]) ran after the
# class cell (In [3]), so the imported class — not the one defined in this
# notebook — is what gets trained here. Keep only one of the two definitions.
SEED = 42  # torch's RNG drives DataLoader shuffling, random hidden-state init and sampling
torch.manual_seed(SEED)

voc = Vocabulary(vocabulary_path=pathlib.Path(PROC_DATA_PATH / "chembl_voc.txt"))
out_dir = MODEL_PATH / "output/rnn"
# fit() writes <out>.log / <out>.pkg here; opening the log fails if the directory is missing.
out_dir.mkdir(parents=True, exist_ok=True)
BATCH_SIZE = 512
N_EPOCHS = 50
CORPUS_PATH = PROC_DATA_PATH / "chembl_corpus_DEV_1000.txt"

netP_path = out_dir / "lstm_chembl_R_dev"
netE_path = out_dir / "lstm_ligand_R_dev"

prior = Generator(voc, is_lstm=True)

# Train loop
chembl_tokens = pd.read_table(CORPUS_PATH).Token
chembl_encoded = torch.LongTensor(voc.encode([seq.split(" ") for seq in chembl_tokens]))
# drop_last=True discards a trailing partial batch, so a corpus smaller than one
# batch would produce an empty loader and fit() would silently train nothing.
assert len(chembl_encoded) >= BATCH_SIZE, "corpus is smaller than one batch"
chembl_loader = DataLoader(
    chembl_encoded, batch_size=BATCH_SIZE, shuffle=True, drop_last=True
)
prior.fit(chembl_loader, out=netP_path, epochs=N_EPOCHS)

RDKit ERROR: [19:25:38] SMILES Parse Error: extra open parentheses for input: '[S-]NBrOB0Br[O+][SiH][S+][Si][n-]o[O]sS[S+]=[O]([nH][C-][Te][n-][se+][CH2-][BH-]46on[Te][Se]'
[19:25:38] SMILES Parse Error: extra open parentheses for input: '[S-]NBrOB0Br[O+][SiH][S+][Si][n-]o[O]sS[S+]=[O]([nH][C-][Te][n-][se+][CH2-][BH-]46on[Te][Se]'
RDKit ERROR: [19:25:38] SMILES Parse Error: unclosed ring for input: '[c-]3N[NH+][SH]-3[Si][Se+]0S[NH-]6cC41=[c-][C+]BrB[S+][se][N][NH2+][n+][Si]1N[N+]3S[I+]7[As+][N-]8cb[BH2-][se][N-]p[SH2]83[N+][CH][O-][SiH][IH]s=[BH-][I+][cH-][se+][Si][SiH][b-][BH2-]5Br[SH2]'
RDKit ERROR: [19:25:38] SMILES Parse Error: syntax error while parsing: [BH2-]9[NH2+]Cl[SiH]I9[SH+]-Cl[BH2-]P[N-]N[se]5[BH3-]4B[SeH][NH+][BH-]#[nH][O-][SH+]o[se][te+][cH-][nH+]7p[As+][B-]o[SiH2][c+]P[Te]#[SH]F[SeH]93[BH2-]([CH2][N-][BH-]O5[se]8[C-][SH+][SH+]5c[te][SeH]IO[As]%[CH2][NH-][Si][O+][b-]3[SiH2][As][nH][cH-][c+]no8
RDKit ERROR: [19:25:38] SMILES Parse Error: Failed parsing SMILES '[BH2-]9[NH2

[19:25:38] SMILES Parse Error: unclosed ring for input: '[c-]3N[NH+][SH]-3[Si][Se+]0S[NH-]6cC41=[c-][C+]BrB[S+][se][N][NH2+][n+][Si]1N[N+]3S[I+]7[As+][N-]8cb[BH2-][se][N-]p[SH2]83[N+][CH][O-][SiH][IH]s=[BH-][I+][cH-][se+][Si][SiH][b-][BH2-]5Br[SH2]'
[19:25:38] SMILES Parse Error: syntax error while parsing: [BH2-]9[NH2+]Cl[SiH]I9[SH+]-Cl[BH2-]P[N-]N[se]5[BH3-]4B[SeH][NH+][BH-]#[nH][O-][SH+]o[se][te+][cH-][nH+]7p[As+][B-]o[SiH2][c+]P[Te]#[SH]F[SeH]93[BH2-]([CH2][N-][BH-]O5[se]8[C-][SH+][SH+]5c[te][SeH]IO[As]%[CH2][NH-][Si][O+][b-]3[SiH2][As][nH][cH-][c+]no8
[19:25:38] SMILES Parse Error: Failed parsing SMILES '[BH2-]9[NH2+]Cl[SiH]I9[SH+]-Cl[BH2-]P[N-]N[se]5[BH3-]4B[SeH][NH+][BH-]#[nH][O-][SH+]o[se][te+][cH-][nH+]7p[As+][B-]o[SiH2][c+]P[Te]#[SH]F[SeH]93[BH2-]([CH2][N-][BH-]O5[se]8[C-][SH+][SH+]5c[te][SeH]IO[As]%[CH2][NH-][Si][O+][b-]3[SiH2][As][nH][cH-][c+]no8' for input: '[BH2-]9[NH2+]Cl[SiH]I9[SH+]-Cl[BH2-]P[N-]N[se]5[BH3-]4B[SeH][NH+][BH-]#[nH][O-][SH+]o[se][te+][cH-][nH+]7p[As+][B-]o

KeyboardInterrupt: 

][te+][OH+][C+][SiH][P+][O-]9o%[nH]' for input: '[N-]C[SH2][n-]P7[se+]S[o+][CH2][PH][nH](o0[cH-][n-][N][CH2][S-]PO[BH-]%-[Se+]I[b-][BH3-][CH2][B-][c-][As]F[O+][BH2-][S-][s+]Clb[O+]1[Se+][se]=[IH][As+][P+]c[Te][PH]1O=32[SH+][O][c-]N[O+][N+][o+]B[nH][o+][te+][te]-[Se+][As][Se+][o+]=)%[Si][Si][O+]c[Se+][te+][Te][nH]9[se][NH-]8[s+][te+][OH+][C+][SiH][P+][O-]9o%[nH]'
[19:25:38] SMILES Parse Error: extra close parentheses while parsing: [CH][B-]F[N-][n+][Se][C-][n+]Cl[IH])#[n+]O[Te][PH][C-][PH][BH-]O[OH+]#[PH]o[NH+][c+][BH3-]0[se][Si]P[O-]5[nH][As][S+][SH2][b-]P[SH2]3n[SH+][B-][B])[se+]9[SiH][NH2+]n3[b-][NH2+][n-]Br[se+]3[B][te+]5[IH]b[Si]s#05p0p[O+]2[IH]-[P+]s[SH2][O+]1[O][nH][S-][se+]1[SeH][nH][c-][N]S[te+][As+][CH][O]Bo[Te][O][PH]
[19:25:38] SMILES Parse Error: Failed parsing SMILES '[CH][B-]F[N-][n+][Se][C-][n+]Cl[IH])#[n+]O[Te][PH][C-][PH][BH-]O[OH+]#[PH]o[NH+][c+][BH3-]0[se][Si]P[O-]5[nH][As][S+][SH2][b-]P[SH2]3n[SH+][B-][B])[se+]9[SiH][NH2+]n3[b-][NH2+][n-]Br[se+]3[B][te+]5[IH]b[Si]s#