In [None]:
!conda install -y -c rdkit rdkit
!conda install -y -c conda-forge editdistance
!pip install deepsmiles
!pip install selfies

In [None]:
from typing import List, Optional
import logging
from itertools import islice

import yaml
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.io
from torchvision import transforms
import pytorch_lightning as pl

import albumentations as A
import cv2

from rdkit import Chem
from rdkit.Chem import inchi
import deepsmiles
import selfies

try:
    from editdistance import eval as lsdistance
except ImportError:
    from Levenshtein import distance as lsdistance


# ============ CONFIG ============

# Kaggle input locations: the competition data root (contains the sharded
# `train`/`test` image trees and the sample submission listing test ids) and a
# custom label CSV that must contain `image_id`, `InChI` and the column named
# by the `target` config key (see load_targets).
DATASET = "/kaggle/input/bms-molecular-translation"
TARGETS = "/kaggle/input/bms-custom/train_labels.csv"
TEST_PATH = "/kaggle/input/bms-molecular-translation/sample_submission.csv"

# Run configuration, parsed with yaml.safe_load in the training cell and
# passed by key name to the data module, the model and the Trainer.
# NOTE(review): `encoder_dim` has to equal the channel count of the encoder
# backbone's final feature map (512 for resnet18) — confirm when changing
# `encoder_model`.
CONFIG = """
epochs: 1
target: selfies
imsize: 224
batch_size: 64
encoded_image_size: 14
input_grayscale: False
encoder_model: resnet18
encoder_fine_tune: False
encoder_pretrained: False
attention_dim: 256
embedding_dim: 256
decoder_dim: 512
encoder_dim: 512
decoder_dropout: 0.5
learning_rate: 0.004
max_grad_norm: 5
max_pred_len: 120
gpus: 1
"""


# ============ UTILS ============


def inchi2smiles(inchi_str: str):
    """Convert an InChI string into a SMILES string via RDKit.

    The InChI is parsed with ``inchi.MolFromInchi`` and serialised with
    ``Chem.MolToSmiles``. No ``None`` check is made on the parsed molecule,
    so an unparsable InChI surfaces as an RDKit error rather than a fallback.
    """
    return Chem.MolToSmiles(inchi.MolFromInchi(inchi_str))


def smiles2inchi(smiles_str: str):
    """Convert a SMILES string into an InChI string via RDKit.

    When RDKit cannot parse ``smiles_str`` (``MolFromSmiles`` returns
    ``None``) the bare prefix ``"InChI=1S/"`` is returned instead of raising.
    """
    molecule = Chem.MolFromSmiles(smiles_str)
    if molecule is None:
        return "InChI=1S/"
    return inchi.MolToInchi(molecule)


def smiles2deepsmiles(smiles: List[str], rings=True, branches=True):
    """Encode SMILES strings as DeepSMILES, preserving input order.

    :param smiles: SMILES strings to encode
    :param rings: forwarded to ``deepsmiles.Converter``
    :param branches: forwarded to ``deepsmiles.Converter``
    :return: list of DeepSMILES strings
    """
    encode = deepsmiles.Converter(rings=rings, branches=branches).encode
    return list(map(encode, smiles))


def deepsmiles2smiles(dsm: List[str], rings=True, branches=True):
    """Decode DeepSMILES strings back to SMILES, preserving input order.

    :param dsm: DeepSMILES strings to decode
    :param rings: forwarded to ``deepsmiles.Converter``
    :param branches: forwarded to ``deepsmiles.Converter``
    :return: list of SMILES strings
    """
    decode = deepsmiles.Converter(rings=rings, branches=branches).decode
    return list(map(decode, dsm))


def smiles2selfies(smiles: List[str]):
    """Encode each SMILES string with ``selfies.encoder``, preserving order."""
    return list(map(selfies.encoder, smiles))


def selfies2smiles(selfies_list: List[str]):
    """Decode each SELFIES string with ``selfies.decoder``, preserving order."""
    return list(map(selfies.decoder, selfies_list))


def selfies2inchi(selfies_str: str) -> str:
    """Translate a SELFIES string to InChI, going through SMILES.

    The SELFIES is decoded with ``selfies.decoder`` and the resulting SMILES
    is handed to ``smiles2inchi`` for the final conversion.
    """
    return smiles2inchi(selfies.decoder(selfies_str))


# ============ EVALUATION ============


def eval_ld_batch(y_true: List[str], y_pred: List[str]):
    """Return the mean Levenshtein distance over paired strings.

    Strings are paired positionally and pairing stops at the shorter input.
    ``lsdistance`` is whichever edit-distance backend was imported at module
    load (``editdistance.eval`` or ``Levenshtein.distance``).
    """
    distances = list(map(lsdistance, y_true, y_pred))
    return np.mean(distances)


class Evaluator:
    """Scores predicted strings against a dataset's reference strings.

    The wrapped ``dataset`` must implement ``get_original_targets(idxs)``.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def eval_batch(self, idxs, y_pred):
        """Return the mean Levenshtein distance for one batch.

        :param idxs: indices identifying the batch samples in ``dataset``
        :param y_pred: predicted strings, aligned with ``idxs``
        """
        references = self.dataset.get_original_targets(idxs)
        return eval_ld_batch(references, y_pred)


# ============ DATA ============


# Per-channel RGB mean/std passed to transforms.Normalize in the data module.
# NOTE(review): these are the ImageNet statistics commonly used with
# torchvision pretrained backbones; they assume inputs scaled to [0, 1].
# Separate TEST_* names let the test pipeline diverge; today both are equal.
PRETRAINED_MEAN = [0.485, 0.456, 0.406]
PRETRAINED_STD = [0.229, 0.224, 0.225]
PRETRAINED_TEST_MEAN = [0.485, 0.456, 0.406]
PRETRAINED_TEST_STD = [0.229, 0.224, 0.225]


def get_img_file_path(
        image_id: str,
        split: str,
        dataset_path: str = "../data/raw"
):
    """Build the PNG path of ``image_id`` in the sharded dataset layout.

    Images live under ``<dataset_path>/<split>/<c0>/<c1>/<c2>/<image_id>.png``
    where ``c0``..``c2`` are the first three characters of the image id.
    """
    first, second, third = image_id[0], image_id[1], image_id[2]
    return f"{dataset_path}/{split}/{first}/{second}/{third}/{image_id}.png"


def load_targets(targets_path: str, target: str) -> pd.DataFrame:
    """Read the label CSV and expose the chosen label column as ``"target"``.

    Only ``image_id``, ``InChI`` and the ``target`` column are loaded; the
    ``target`` column is then renamed to ``"target"``.
    """
    wanted_columns = list({"image_id", "InChI", target})
    frame = pd.read_csv(targets_path, usecols=wanted_columns)
    return frame.rename(columns={target: "target"})


class Tokenizer:
    """Base class for label tokenizers.

    Subclasses populate ``stoi`` (token -> index) and ``itos`` (index -> token)
    in :meth:`fit` and implement the conversions between label strings and
    index lists. ``aux_tokens`` holds the special tokens.
    """

    def __init__(self):
        self.aux_tokens = ['<SOS>', '<EOS>', '<UNK>', '<PAD>']
        self.itos = {}
        self.stoi = {}

    def __len__(self):
        # Vocabulary size.
        return len(self.stoi)

    def fit(self, labels: List[str]):
        """Build the vocabulary from ``labels``; implemented by subclasses."""
        raise NotImplementedError()

    def tokenize(self, label: str) -> List[int]:
        """Map a label string to token indices; implemented by subclasses."""
        raise NotImplementedError()

    def reverse_tokenize(self, idxs: List[int]) -> str:
        """Map token indices back to a string; implemented by subclasses."""
        raise NotImplementedError()


class SelfiesTokenizer(Tokenizer):
    """Tokenizer over SELFIES symbols.

    The vocabulary is the sorted SELFIES alphabet of the fitted labels,
    followed by the auxiliary tokens (``<SOS>``, ``<EOS>``, ``<UNK>``,
    ``<PAD>``).
    """

    def fit(self, labels: List[str]):
        """Build ``stoi``/``itos`` from the SELFIES alphabet of ``labels``."""
        logging.info("Constructing SELFIES vocabulary")
        vocabulary = sorted(selfies.get_alphabet_from_selfies(tqdm(labels)))
        vocabulary.extend(self.aux_tokens)
        self.stoi = {s: i for i, s in enumerate(vocabulary)}
        self.itos = dict(enumerate(vocabulary))

    def tokenize(self, label: str) -> List[int]:
        """Encode a SELFIES string as ``[<SOS>, symbols..., <EOS>]`` indices.

        The decoder is trained to predict ``caption[1:]`` from ``caption[:-1]``,
        and greedy decoding starts from ``<SOS>`` and stops at ``<EOS>``. Both
        markers must therefore be part of every training sequence; otherwise
        the model never learns to emit ``<EOS>`` and the first symbol is never
        a prediction target. Symbols missing from the vocabulary map to
        ``<UNK>``.
        """
        unk_idx = self.stoi['<UNK>']
        body = [self.stoi.get(token, unk_idx) for token in selfies.split_selfies(label)]
        return [self.stoi['<SOS>'], *body, self.stoi['<EOS>']]

    def reverse_tokenize(self, idxs: List[int], filter_aux: bool = True,
                         stop_at_eos: bool = True) -> str:
        """Decode token indices back into a SELFIES string.

        :param idxs: token indices (Python or numpy integers)
        :param filter_aux: drop auxiliary tokens from the output
        :param stop_at_eos: ignore everything from the first ``<EOS>`` on.
            Greedy decoding keeps producing (or zero-fills) positions after
            ``<EOS>``, so those positions carry no meaning.
        :return: concatenated SELFIES symbols
        """
        pieces = []
        for idx in idxs:
            token = self.itos.get(idx)
            if stop_at_eos and token == '<EOS>':
                break
            if filter_aux and token in self.aux_tokens:
                continue
            pieces.append(token)
        return ''.join(pieces)


class MolecularCaptioningDataset(Dataset):
    """Training dataset yielding ``(image, token sequence, sequence length)``.

    Images are read from the ``train`` split of ``dataset_path`` and passed
    through ``transforms``; labels come from the ``target`` column of
    ``df_targets`` and are encoded with ``tokenizer``.
    """

    def __init__(
            self,
            dataset_path: str,
            df_targets: pd.DataFrame,
            tokenizer: Tokenizer,
            transforms
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.df_targets = df_targets
        self.tokenizer = tokenizer
        self.transforms = transforms

    def __len__(self):
        return self.df_targets.shape[0]

    def __getitem__(self, idx):
        row = self.df_targets.iloc[idx]
        image_id, target = row["image_id"], row["target"]
        img_path = get_img_file_path(
            image_id,
            "train",
            dataset_path=self.dataset_path
        )
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # torchvision's ToTensor only rescales uint8 input to [0, 1]; a float
        # array passes through unscaled, so scale here to match the [0, 1]
        # range the Normalize mean/std expect.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        img = self.transforms(img)
        tgt_sequence = self.tokenizer.tokenize(target)
        tgt_len = torch.LongTensor([len(tgt_sequence)])
        tgt_sequence = torch.LongTensor(tgt_sequence)
        return img, tgt_sequence, tgt_len

    def collate_fn(self, batch):
        """Stack images, right-pad token sequences with ``<PAD>``.

        :return: (images, padded labels of shape (batch, max_len),
            label lengths of shape (batch, 1))
        """
        imgs, labels, label_lengths = [], [], []
        for data_point in batch:
            imgs.append(data_point[0])
            labels.append(data_point[1])
            label_lengths.append(data_point[2])
        labels = pad_sequence(labels, batch_first=True, padding_value=self.tokenizer.stoi["<PAD>"])
        return torch.stack(imgs), labels, torch.stack(label_lengths).reshape(-1, 1)


class MolecularCaptioningValDataset(MolecularCaptioningDataset):
    """Validation variant that also returns each sample's dataset position.

    The extra index lets the evaluator fetch the original ``InChI`` of every
    sample through :meth:`get_original_targets`.
    """

    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        return (*sample, torch.LongTensor([idx]))

    def collate_fn(self, batch):
        imgs, labels, label_lengths = super().collate_fn(batch)
        positions = torch.stack([item[-1] for item in batch]).reshape(-1, 1)
        return imgs, labels, label_lengths, positions

    def get_original_targets(self, idxs):
        """Return the ``InChI`` column for the given positional indices."""
        return self.df_targets.iloc[idxs]["InChI"]


class MolecularCaptioningTestDataset(Dataset):
    """Image-only dataset over the ``test`` split (no targets)."""

    def __init__(
            self,
            dataset_path: str,
            df_image_ids: pd.DataFrame,
            transforms
    ):
        super().__init__()
        self.dataset_path = dataset_path
        self.df_image_ids = df_image_ids
        self.transforms = transforms
        # Transpose followed by a vertical flip is a 90-degree rotation; used
        # to turn portrait images into landscape ones.
        self.fix_transform = A.Compose([A.Transpose(p=1), A.VerticalFlip(p=1)])

    def __len__(self):
        # Was misspelled `__len_`, which left the dataset without a length
        # and made DataLoader construction/iteration fail.
        return self.df_image_ids.shape[0]

    def __getitem__(self, idx):
        image_id = self.df_image_ids.iloc[idx]['image_id']
        img_path = get_img_file_path(
            image_id,
            "test",
            dataset_path=self.dataset_path
        )
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # Scale to [0, 1]: ToTensor does not rescale float arrays, and the
        # Normalize statistics expect [0, 1] input.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        # numpy images are (H, W, C); `.size` is an int attribute, not a method.
        h, w = img.shape[:2]
        if h > w:
            img = self.fix_transform(image=img)['image']
        return self.transforms(img)


class MolecularCaptioningDataModule(pl.LightningDataModule):
    """Lightning data module for the molecular-translation images.

    The label CSV and the test id list are read at construction. ``setup``
    splits the labels into train/validation sets, fits the SELFIES tokenizer
    on the training split and builds the datasets used by the loaders.
    """

    def __init__(
            self,
            dataset_path: str,
            target: str,
            targets_path: str,
            test_ids_path: str,
            imsize: int,
            batch_size: int,
            num_workers: int = 4,
            train_size: float = 0.8,
            val_batch_size: int = 4,
            test_batch_size: int = 1024
    ):
        """
        :param dataset_path: root directory containing the ``train``/``test`` image trees
        :param target: label column to train on (renamed to ``"target"``)
        :param targets_path: label CSV with ``image_id``, ``InChI`` and ``target``
        :param test_ids_path: CSV whose ``image_id`` column lists test images
        :param imsize: side length images are resized to
        :param batch_size: training batch size
        :param num_workers: DataLoader worker processes
        :param train_size: fraction of labelled rows used for training
        :param val_batch_size: validation batch size
        :param test_batch_size: test batch size
        """
        super().__init__()
        self.dataset_path = dataset_path
        self.targets_path = targets_path
        self.imsize = imsize
        self.batch_size = batch_size
        self.val_batch_size = val_batch_size
        self.test_batch_size = test_batch_size
        self.num_workers = num_workers
        self.train_size = train_size
        self.df_targets = load_targets(targets_path, target)
        self.df_test_ids = pd.read_csv(test_ids_path)
        self.tokenizer = SelfiesTokenizer()
        self.train, self.val, self.test = None, None, None
        self.df_train, self.df_val = None, None

    def prepare_data(self):
        pass

    def setup(self, stage: Optional[str] = None):
        if stage == 'fit' or stage is None:
            # Fixed random_state keeps the split (and thus the fitted
            # vocabulary) identical if setup runs more than once.
            df_train, df_val = train_test_split(
                self.df_targets,
                train_size=self.train_size,
                shuffle=True,
                random_state=42
            )
            self.df_train, self.df_val = df_train, df_val
            # Vocabulary comes from the training split only.
            self.tokenizer.fit(df_train['target'])
            self.train = MolecularCaptioningDataset(
                self.dataset_path,
                df_targets=df_train,
                tokenizer=self.tokenizer,
                transforms=self._init_tforms("train")
            )
            self.val = MolecularCaptioningValDataset(
                self.dataset_path,
                df_targets=df_val,
                tokenizer=self.tokenizer,
                transforms=self._init_tforms("test")
            )
        if stage == 'test' or stage is None:
            self.test = MolecularCaptioningTestDataset(
                self.dataset_path,
                df_image_ids=self.df_test_ids,
                transforms=self._init_tforms("test")
            )

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            collate_fn=self.train.collate_fn
        )

    def val_dataloader(self):
        return DataLoader(
            self.val,
            batch_size=self.val_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False,
            pin_memory=True,
            collate_fn=self.val.collate_fn
        )

    def test_dataloader(self):
        return DataLoader(
            self.test,
            batch_size=self.test_batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            drop_last=False,
        )

    def _init_tforms(self, stage: str):
        """Return the preprocessing pipeline for ``'train'`` or ``'test'``.

        Both pipelines convert to a tensor, resize to ``imsize x imsize`` and
        normalise with the stage's statistics. Any other stage yields ``None``.
        """
        stats = {
            'train': (PRETRAINED_MEAN, PRETRAINED_STD),
            'test': (PRETRAINED_TEST_MEAN, PRETRAINED_TEST_STD),
        }
        if stage not in stats:
            return None
        mean, std = stats[stage]
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(2 * (self.imsize,)),
            transforms.Normalize(mean=mean, std=std)
        ])

    def _aug_tfroms(self):
        """Random augmentation pipeline (currently unused in this file)."""
        return transforms.Compose([
            transforms.RandomErasing(p=0.2),
            transforms.RandomHorizontalFlip(p=0.2),
            transforms.RandomApply([transforms.RandomRotation(90)])
        ])


# ============= MODULES =============

class Encoder(nn.Module):
    """CNN image encoder built from a torchvision backbone.

    The last two children of the backbone are dropped (assumes a ResNet-like
    layout where they are the global pool and classifier). The remaining
    feature map is resized to ``encoded_image_size x encoded_image_size``
    with adaptive average pooling and returned channels-last.
    """

    def __init__(
            self,
            encoded_image_size: int,
            is_grayscale: bool,
            fine_tune: bool = True,
            model: str = 'resnet18',
            pretrained: bool = False
    ):
        """
        :param encoded_image_size: output tensor WxH dimensions
        :param is_grayscale: expect 1-channel input (replaces the stem conv)
        :param fine_tune: enable fine-tuning of residual blocks 2-4
        :param model: model name corresponding to the model from torchvision library
        :param pretrained: load pretrained weights
        """
        super(Encoder, self).__init__()
        self.encoded_image_size = encoded_image_size
        self.is_grayscale = is_grayscale
        self.fine_tune = fine_tune
        self.model = model
        self.pretrained = pretrained
        self.encoder_net = self._make_encoder_net()
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))
        if self.pretrained:
            self.encoder_net = self._fine_tune_params(self.encoder_net, fine_tune)

    def _make_encoder_net(self):
        model_builder = getattr(torchvision.models, self.model)
        backbone = model_builder(pretrained=self.pretrained)
        if self.is_grayscale:
            # Swap the stem on the torchvision model *before* wrapping it in
            # nn.Sequential: replacing an existing child keeps its position,
            # whereas setting `conv1` on the Sequential registers a new module
            # that runs last (on 512-channel features) and breaks forward.
            backbone.conv1 = nn.Conv2d(
                1, 64,
                kernel_size=(7, 7),
                stride=(2, 2),
                padding=(3, 3),
                bias=False
            )
        modules = list(backbone.children())[:-2]
        return nn.Sequential(*modules)

    def _fine_tune_params(self, resnet: nn.Module, fine_tune: bool):
        """Freeze the backbone; unfreeze children from index 5 on if ``fine_tune``."""
        for param in resnet.parameters():
            param.requires_grad = False
        for child in islice(resnet.children(), 5, None):
            for param in child.parameters():
                param.requires_grad = fine_tune
        return resnet

    def forward(self, x):
        """
        :param x: image batch, (batch_size, channels, H, W)
        :return: features, (batch_size, encoded_image_size, encoded_image_size, feature_channels)
        """
        y = self.encoder_net(x)
        # Previously defined but never applied, so `encoded_image_size` had
        # no effect on the feature grid the decoder attends over.
        y = self.adaptive_pool(y)
        return y.permute(0, 2, 3, 1)


class Attention(nn.Module):
    """Additive soft attention over encoder pixels.

    score = full_att(relu(encoder_att(pixels) + decoder_att(hidden))),
    softmax-normalised over pixels and used to weight the encoder features.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        """
        super().__init__()
        # Projections of image features and decoder state into a shared space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapses each joint vector to one unnormalised score.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """
        :param encoder_out: (batch_size, num_pixels, encoder_dim)
        :param decoder_hidden: (batch_size, decoder_dim)
        :return: (weighted encoding of shape (batch_size, encoder_dim),
            weights of shape (batch_size, num_pixels))
        """
        projected_pixels = self.encoder_att(encoder_out)
        projected_state = self.decoder_att(decoder_hidden)[:, None, :]
        scores = self.full_att(self.relu(projected_pixels + projected_state)).squeeze(2)
        weights = self.softmax(scores)
        context = (encoder_out * weights[:, :, None]).sum(dim=1)
        return context, weights


class DecoderWithAttention(nn.Module):
    """
    Gated-attention LSTM decoder.

    ``forward`` runs teacher-forced decoding for training; ``decode`` runs
    greedy decoding from ``<SOS>`` for inference/validation.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim, device, dropout=0.5):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param encoder_dim: feature size of encoded images
        :param device: stored as ``self.device``; not otherwise used here
        :param dropout: dropout
        """
        super(DecoderWithAttention, self).__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.device = device

        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  # decoding LSTMCell
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial hidden state of LSTMCell
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # linear layer to find initial cell state of LSTMCell
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # linear layer to create a sigmoid-activated gate
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # linear layer to find scores over vocabulary
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """
        Loads embedding layer with pre-trained embeddings.
        :param embeddings: pre-trained embeddings
        """
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """
        Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
        :param fine_tune: Allow?
        """
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """
        Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
        :return: hidden state, cell state
        """
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Forward propagation.
        :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices
        """

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort by decreasing length so that, at every step t, the still-active
        # sequences form a prefix [:batch_size_t] of the batch.
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word prediction scores and alphas
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).type_as(h)
        alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).type_as(h)

        # At each time-step, decode by
        # attention-weighing the encoder's output based on the decoder's previous hidden state output
        # then generate a new word in the decoder with the previous word and the attention weighted encoding
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

    def decode(self, encoder_out, decode_lengths, tokenizer):
        """
        Greedy decoding without teacher forcing.

        :param encoder_out: encoded images, (batch_size, enc_image_size, enc_image_size, encoder_dim)
        :param decode_lengths: maximum number of decoding steps (an int)
        :param tokenizer: supplies the ``<SOS>``/``<EOS>`` indices via ``stoi``
        :return: scores for vocabulary, (batch_size, decode_lengths, vocab_size), in the
            input batch order. Decoding stops once every sequence has produced
            ``<EOS>``; time steps after that are left as zeros.
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        device = self.embedding.weight.device
        eos_idx = tokenizer.stoi["<EOS>"]
        # Every sequence starts from the <SOS> embedding.
        start_tokens = torch.ones(batch_size, dtype=torch.long, device=device) * tokenizer.stoi["<SOS>"]
        embeddings = self.embedding(start_tokens)
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)
        predictions = torch.zeros(batch_size, decode_lengths, self.vocab_size).type_as(h)
        # Per-sequence flag: has this sequence emitted <EOS> yet?
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for t in range(decode_lengths):
            attention_weighted_encoding, _ = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))  # gating scalar, (batch_size, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings, attention_weighted_encoding], dim=1),
                (h, c))  # (batch_size, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size, vocab_size)
            predictions[:, t, :] = preds
            next_tokens = torch.argmax(preds, -1)  # (batch_size,)
            # Stop only when *all* sequences are done. The previous check ran
            # np.argmax over the flattened (batch, vocab) array, which only
            # matched <EOS> for the first sample.
            finished |= next_tokens == eos_idx
            if bool(finished.all()):
                break
            embeddings = self.embedding(next_tokens)
        return predictions


# ============= MODELS =============


class MolecularCaptioningModel(pl.LightningModule):
    """Image-to-sequence model: CNN encoder + gated-attention LSTM decoder.

    Training minimises cross-entropy on the teacher-forced, packed target
    sequence. Validation also greedy-decodes each batch and converts tokens to
    text with ``tokenizer.reverse_tokenize`` and ``translate_fn``. It then logs
    the mean Levenshtein distance reported by ``evaluator``.
    """

    def __init__(
            self,
            learning_rate,
            encoded_image_size,
            input_grayscale,
            encoder_model,
            encoder_fine_tune,
            encoder_pretrained,
            attention_dim,
            embedding_dim,
            decoder_dim,
            encoder_dim,
            decoder_dropout,
            tokenizer,
            target,
            translate_fn,
            evaluator,
            max_pred_len=120
    ):
        super(MolecularCaptioningModel, self).__init__()
        # tokenizer / translate_fn / evaluator are objects, not hyper-parameters,
        # so they are deliberately left out of the saved hparams.
        self.save_hyperparameters(
            "learning_rate",
            "encoded_image_size",
            "input_grayscale",
            "encoder_model",
            "encoder_fine_tune",
            "encoder_pretrained",
            "attention_dim",
            "embedding_dim",
            "decoder_dim",
            "encoder_dim",
            "decoder_dropout",
            "target",
            "max_pred_len"
        )
        self.learning_rate = learning_rate
        self.encoder = Encoder(
            encoded_image_size=encoded_image_size,
            is_grayscale=input_grayscale,
            model=encoder_model,
            fine_tune=encoder_fine_tune,
            pretrained=encoder_pretrained
        )
        self.decoder = DecoderWithAttention(
            attention_dim=attention_dim,
            embed_dim=embedding_dim,
            decoder_dim=decoder_dim,
            vocab_size=len(tokenizer),
            encoder_dim=encoder_dim,
            dropout=decoder_dropout,
            device=self.device
        )
        self.max_pred_len = max_pred_len
        self.tokenizer = tokenizer
        self.evaluator = evaluator
        self.target = target
        self.translate_fn = translate_fn
        self.criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<PAD>"])

    def forward(self, images):
        """Greedy-decode ``images`` (previously a stub returning ``None``).

        :return: vocabulary scores of shape (batch_size, max_pred_len, vocab_size)
        """
        features = self.encoder(images)
        return self.decoder.decode(features, self.max_pred_len, self.tokenizer)

    def _compute_loss(self, images, labels, label_lengths):
        """Teacher-forced loss for one batch.

        :return: (loss, encoder features); the features keep the batch's
            original order (the decoder sorts its own copy internally).
        """
        features = self.encoder(images)
        predictions_, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(features, labels, label_lengths)
        # Decoder step t predicts token t + 1.
        targets = caps_sorted[:, 1:]
        # Packing drops the time steps beyond each sequence's decode length.
        predictions = pack_padded_sequence(predictions_, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        loss = self.criterion(predictions, targets)
        return loss, features

    def training_step(self, batch, batch_idx):
        images, labels, label_lengths = batch
        loss, _ = self._compute_loss(images, labels, label_lengths)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels, label_lengths, idxs = batch
        loss, encoded = self._compute_loss(images, labels, label_lengths)
        predictions = self.decoder.decode(encoded, self.max_pred_len, self.tokenizer)
        # Reduce to token ids on the device before moving to the host. This
        # transfers a (B, T) index tensor instead of the full (B, T, V) scores,
        # and sidesteps missing CPU Half kernels for argmax on some PyTorch
        # versions when training with precision=16.
        predicted_sequence = predictions.detach().argmax(dim=-1).cpu().numpy()
        predicted_sequences = [self.tokenizer.reverse_tokenize(seq) for seq in predicted_sequence]
        translated_sequences = [self.translate_fn(seq) for seq in predicted_sequences]
        cv_ld = self.evaluator.eval_batch(idxs.detach().cpu().numpy().ravel(), translated_sequences)
        self.log('val_loss', loss)
        self.log('cv_ld', cv_ld)

    def configure_optimizers(self):
        optimizer = Adam(
            self.parameters(),
            lr=self.learning_rate,
            amsgrad=False
        )
        scheduler = ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.2,
            patience=4,
            eps=1e-6,
            verbose=True
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss'
        }


# ============= TRAINING =============

pl.seed_everything(42)
conf = yaml.safe_load(CONFIG)

datamodule = MolecularCaptioningDataModule(
    dataset_path=DATASET,
    target=conf['target'],
    targets_path=TARGETS,
    test_ids_path=TEST_PATH,
    imsize=conf['imsize'],
    batch_size=conf['batch_size']
)
# The fit-stage setup must run before the model is built. It fits the
# tokenizer, whose size sets the decoder vocabulary, and creates the
# validation set that the evaluator scores against.
datamodule.setup('fit')
evaluator = Evaluator(datamodule.val)

# Model keyword arguments that are taken verbatim from CONFIG.
_model_conf_keys = (
    'learning_rate',
    'encoded_image_size',
    'input_grayscale',
    'encoder_model',
    'encoder_fine_tune',
    'encoder_pretrained',
    'attention_dim',
    'embedding_dim',
    'decoder_dim',
    'encoder_dim',
    'decoder_dropout',
    'target',
    'max_pred_len',
)
model = MolecularCaptioningModel(
    tokenizer=datamodule.tokenizer,
    translate_fn=selfies2inchi,
    evaluator=evaluator,
    **{key: conf[key] for key in _model_conf_keys}
)

trainer = pl.Trainer(
    default_root_dir='/kaggle/working',
    deterministic=True,
    gpus=conf['gpus'],
    max_epochs=conf['epochs'],
    gradient_clip_val=conf['max_grad_norm'],
    precision=16
)
trainer.fit(model, datamodule)
