In [None]:
pip install timm

In [None]:
from io import BytesIO
from IPython.display import HTML

import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')

import os
import gc
import re
import math
import time
import random
import shutil
import pickle
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import Levenshtein
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

from functools import partial

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, Blur
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

import warnings 
warnings.filterwarnings('ignore')

import re
import numpy as np

from matplotlib import pyplot as plt

from subprocess import Popen, PIPE

from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# DIRECTORIES

BASE_DIR = "../input/bms-molecular-translation"
TRAIN_DIR = f"{BASE_DIR}/train"
TEST_DIR = f"{BASE_DIR}/test"
OUTPUT_DIR = './'

DIR_TYPES = {
    "train": TRAIN_DIR,
    "test": TEST_DIR
}

# PATH UTILS

def get_path(items_ids):
    """Join path components with '/'.

    params:
    ------
        items_ids: iterable
            components of the path; each one is formatted as a string
    returns:
    -------
        The components joined by '/'
    """
    return "/".join(f"{item}" for item in items_ids)


def get_file_path(file_type, file_id, suffix = ""):
    """Build a path below the directory registered for `file_type`.

    params:
    ------
        file_type: 'train' | 'test'
            key looked up in DIR_TYPES
        file_id: iterable
            path components placed below that directory
        suffix: string
            text appended after the last component (e.g. an extension)
    """
    root = DIR_TYPES[file_type]
    return f"{root}/{get_path(file_id)}{suffix}"


def get_image_path(image_id, file_type = 'train'):
    """Get a image path

    The image is located at <dir>/<c0>/<c1>/<c2>/<image_id>.png, where c0..c2
    are the first three characters of the id.

    params:
    ------
        image_id: string
            image id to be retrieved
        file_type: 'train' | 'test'
    """
    components = [image_id[0], image_id[1], image_id[2], image_id]
    return get_file_path(file_type, components, '.png')


def get_image_path_test(image_id, file_type = 'test'):
    """Same as get_image_path, but `file_type` defaults to 'test'."""
    components = [image_id[0], image_id[1], image_id[2], image_id]
    return get_file_path(file_type, components, '.png')

def has_file(filename):
    """Verify if a file exists

    The check is done with os.path.exists, so a relative name is resolved
    against the current working directory. The previous implementation parsed
    the textual repr of the bytes printed by `ls`, which prefixed the first
    entry with "b'" and therefore never matched it.

    params:
    ------
        filename: string
            the filename (or path) to verify existence
    returns:
    -------
        Returns a boolean wheres concludes the file existence
    """
    return os.path.exists(filename)

In [None]:
# IMAGE PREVIEW UTILS

pd.set_option('display.max_colwidth', None)

def get_thumbnail(path):
    """Open the image at `path` and shrink it in place to fit a 200x200 box.

    Returns the PIL image after `thumbnail` has been applied with LANCZOS
    resampling.
    """
    max_size = (200, 200)
    thumb = Image.open(path)
    thumb.thumbnail(max_size, Image.LANCZOS)
    return thumb

def image_base64(im):
    """Encode an image as a base64 JPEG string.

    params:
    ------
        im: string | image object
            a path (turned into a thumbnail via get_thumbnail) or an object
            exposing `save(buffer, format)`
    returns:
    -------
        The JPEG bytes encoded as an ASCII base64 string
    """
    # base64 is not imported at the top of the notebook; import it here so the
    # function does not fail with NameError.
    import base64

    if isinstance(im, str):
        im = get_thumbnail(im)
    with BytesIO() as buffer:
        im.save(buffer, 'jpeg')
        return base64.b64encode(buffer.getvalue()).decode()

def image_formatter(im):
    """Return an HTML <img> tag embedding `im` as a base64 JPEG data URI."""
    encoded = image_base64(im)
    return f'<img src="data:image/jpeg;base64,{encoded}">'

In [None]:
# TOKENIZER CLASS

class Tokenizer(object):
    """Space-delimited tokenizer that maps tokens to integer ids and back.

    `stoi` maps a token string to its index, `itos` is the inverse mapping.
    Three special tokens are always part of the vocabulary: '<sos>', '<eos>'
    and '<pad>'.
    """

    # Special tokens are appended after the sorted regular tokens.
    SPECIAL_TOKENS = ('<sos>', '<eos>', '<pad>')

    def __init__(self):
        self.stoi = {}
        self.itos = {}

    def __len__(self):
        return len(self.stoi)

    def fit_on_texts(self, texts = []):
        """ Fill stoi and itos with vocabulary

        Indices are assigned deterministically: unique tokens in sorted order
        first, followed by '<sos>', '<eos>' and '<pad>'. (Iterating a set, as
        was done before, gives an order that can change between interpreter
        runs, so saved models and tokenizers could disagree.)

        Parameters
        ----------
        texts: list of string
            list of texts used to be tokenized, where fill the stoi(string to index) and
            itos(index to string) sets. Tokens are separated by single spaces;
            empty tokens are ignored.
        """
        tokens = {token for text in texts for token in text.split(' ') if token != ''}
        vocab = sorted(tokens.difference(self.SPECIAL_TOKENS))
        vocab.extend(self.SPECIAL_TOKENS)

        for idx, string in enumerate(vocab):
            self.stoi[string] = idx
            self.itos[idx] = string

    def text_to_sequence(self, text = ""):
        """ Parse the received text to a vector of indexes related to each token.

        The result starts with the '<sos>' index and ends with the '<eos>'
        index. Empty tokens (from repeated or leading/trailing spaces) are
        skipped, matching what fit_on_texts stores.

        Parameters
        ----------
        text: string
            Text to be parsed to a sequence.

        Returns
        -------
            Returns a array of integers.

        Raises
        ------
            KeyError if a non-empty token is not in the vocabulary.
        """
        sequence = [self.stoi['<sos>']]
        sequence.extend(self.stoi[string] for string in text.split(' ') if string != '')
        sequence.append(self.stoi['<eos>'])
        return sequence

    def texts_to_sequences(self, texts = []):
        """ Parse a batch of texts into a array of sequences

        Parameters
        ----------
        texts: [string]
            A array of texts.

        Returns
        -------
            Returns a array of array of integers.

        """
        return [self.text_to_sequence(text) for text in texts]

    def sequence_to_text(self, sequence = []):
        """ Parse a sequence into a text

        Parameters
        ----------
        sequence: [integer]
            Indexes looked up in itos; every index is converted, including
            special tokens.

        Returns
        -------
            Returns a string with the tokens joined by single spaces
        """
        return " ".join(self.itos[idx] for idx in sequence)

    def sequences_to_texts(self, sequences = []):
        """ Parse a batch of sequence into a array related texts

        Parameters
        ----------
        sequences: [[integer]]
            A array of sequences, each converted with sequence_to_text.

        Returns
        -------
            Array of strings
        """
        return [self.sequence_to_text(sequence) for sequence in sequences]

    def predict_caption(self, sequence = []):
        """ Parse a sequence into a text based and limited by eos(end of sequence) or
            pad tokens

        Parameters
        ----------
        sequence: [integer]
            Indexes looked up in itos; conversion stops at the first '<eos>'
            or '<pad>' token, which is not included in the output.

        Returns
        -------
            Returns a text

        """
        caption = []
        for idx in sequence:
            curr_string = self.itos[idx]
            if curr_string in ("<pad>", "<eos>"):
                break
            caption.append(curr_string)
        return " ".join(caption)

    def predict_captions(self, sequences = []):
        """ Parse a batch of sequences into a text based and limited by eos(end of sequence) or
            pad tokens

        Parameters
        ----------
        sequences: [[integer]]
            A array of sequences, each converted with predict_caption.

        Returns
        -------
            Returns a array of texts

        """
        return [self.predict_caption(sequence) for sequence in sequences]

def run_tokenizer_test(): # TODO Remove when ended
    """Smoke test: fit a Tokenizer on three sample texts and check its vocabulary.

    Raises Exception('Mismatching stoi') when any expected token is absent
    from the fitted `stoi`; returns None otherwise.
    """
    sample_texts = [
        'C 13 H 20 O S /c 1 - 9 ( 2 ) 8 - 15 - 13 - 6 - 5 - 10 ( 3 ) 7 - 12 ( 13 ) 11 ( 4 ) 14 /h 5 - 7 , 9 , 11 , 14 H , 8 H 2 , 1 - 4 H 3',
        'C 21 H 30 O 4 /c 1 - 12 ( 22 ) 25 - 14 - 6 - 8 - 20 ( 2 ) 13 ( 10 - 14 ) 11 - 17 ( 23 ) 19 - 15 - 4 - 5 - 18 ( 24 ) 21 ( 15 , 3 ) 9 - 7 - 16 ( 19 ) 20 /h 13 - 16 , 19 H , 4 - 11 H 2 , 1 - 3 H 3 /t 13 - , 14 + , 15 + , 16 - , 19 - , 20 + , 21 + /m 1 /s 1',
        'C 24 H 23 N 5 O 4 /c 1 - 14 - 13 - 15 ( 7 - 8 - 17 ( 14 ) 28 - 12 - 10 - 20 ( 28 ) 30 ) 27 - 11 - 9 - 16 - 21 ( 23 ( 25 ) 31 ) 26 - 29 ( 22 ( 16 ) 24 ( 27 ) 32 ) 18 - 5 - 3 - 4 - 6 - 19 ( 18 ) 33 - 2 /h 3 - 8 , 13 H , 9 - 12 H 2 , 1 - 2 H 3 , ( H 2 , 25 , 31 )'
    ]
    expected = ['(', ')', '+', ',', '-', '/c', '/h', '/m', '/s', '/t', '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '4', '5', '6', '7', '8', '9', 'C', 'H', 'N', 'O', 'S', '<sos>', '<eos>', '<pad>']

    fitted = Tokenizer()
    fitted.fit_on_texts(sample_texts)

    # Only checks that every expected token is present, not the reverse.
    known_tokens = list(fitted.stoi)
    missing = [token for token in expected if token not in known_tokens]
    if missing:
        raise Exception('Mismatching stoi')


    


run_tokenizer_test()


In [None]:
# DATASET READING

TOKENIZER_FILENAME = '../input/inchipreprocess2/tokenizer2.pth'
# Kept under its original name for backward compatibility; note that this is
# the *path* of the serialized tokenizer, not a Tokenizer instance.
tokenizer = TOKENIZER_FILENAME
TRAIN_PICKLE_FILENAME = '../input/inchipreprocess2/train2.pkl'
SAMPLE_SUBMISSION_FILENAME = '../input/bms-molecular-translation/sample_submission.csv'

class PreProcess:
    """Loads the serialized tokenizer plus the train and test tables.

    Attributes set by __init__:
        tokenizer: object returned by torch.load(tokenizer_path)
        train: DataFrame read from the train pickle, with a 'file_path' column
            derived from 'image_id' via get_image_path
        test: DataFrame read from the sample submission CSV, with a 'file_path'
            column derived from 'image_id' via get_image_path_test
    """

    def __init__(self,
                 tokenizer_path=TOKENIZER_FILENAME,
                 train_pickle_path=TRAIN_PICKLE_FILENAME,
                 sample_submission_path=SAMPLE_SUBMISSION_FILENAME):
        """Read all three inputs; the defaults are the module-level path constants.

        SECURITY: torch.load and pd.read_pickle both unpickle their input, which
        can execute arbitrary code. Only point these at trusted files.
        """
        print("INITIALIZED\n")
        print("LOADING TOKENIZER...\n")
        # NOTE(review): recent PyTorch releases may default torch.load to
        # weights_only=True, which would reject a pickled Tokenizer object —
        # confirm against the installed version.
        self.tokenizer = torch.load(tokenizer_path)

        print("LOADING TRAIN PICKLE...\n")
        self.train = pd.read_pickle(train_pickle_path)
        self.train['file_path'] = self.train['image_id'].apply(get_image_path)
        self.test = pd.read_csv(sample_submission_path)
        self.test['file_path'] = self.test['image_id'].apply(get_image_path_test)



In [None]:
pre_process = PreProcess()

In [None]:
print(pre_process.train.columns)
print(pre_process.test.columns)

In [None]:
print(pre_process.train.describe())
print("\n")
print(pre_process.test.describe())

In [None]:
pre_process.train.head(10)

In [None]:
pre_process.test.head(10)

In [None]:
print(pre_process.train.dtypes)
print("\n")
print(pre_process.test.dtypes)

In [None]:
# ITEM EXAMPLE

print(pre_process.train.loc[0])
print("\n")
print(pre_process.test.loc[0])

In [None]:
# TRAIN IMAGE SAMPLE

get_thumbnail(pre_process.train['file_path'].values[0])

In [None]:
# TEST IMAGE SAMPLE

get_thumbnail(get_image_path(pre_process.test.loc[0].image_id, 'test'))

In [None]:
def split_form(form):
    """Insert spaces between the leading non-digit part and the rest of each chunk.

    `form` is cut into chunks that start at every uppercase ASCII letter. For
    each chunk the leading run of non-digit characters is kept as the element
    and every occurrence of it is removed from the chunk to obtain the
    remainder, e.g. 'C13H20OS' -> 'C 13 H 20 O S'.
    """
    pieces = []
    for chunk in re.findall(r"[A-Z][^A-Z]*", form):
        elem = re.match(r"\D+", chunk).group()
        num = chunk.replace(elem, "")
        pieces.append(f"{elem}" if num == "" else f"{elem} {str(num)}")
    return " ".join(pieces).rstrip(' ')

def split_form2(form):
    """Space-separate the numbers and symbols of each lowercase-prefixed layer.

    `form` is cut into chunks that start at every lowercase ASCII letter. The
    first character of a chunk is the layer letter; all occurrences of that
    letter and of '/' are removed from the rest. The remainder is then split
    into "<digits><non-digits>" groups: digits are emitted as one token and
    each following character as its own token. Example:
    'c1-9(2)8/h5-7,9H' -> '/c 1 - 9 ( 2 ) 8 /h 5 - 7 , 9 H'.
    """
    layers = []
    for chunk in re.findall(r"[a-z][^a-z]*", form):
        elem = chunk[0]
        num = chunk.replace(elem, "").replace('/', "")
        tokens = []
        for group in re.findall(r"[0-9]+[^0-9]*", num):
            num_list = list(re.findall(r'\d+', group))
            assert len(num_list) == 1, f"len(num_list) != 1"
            digits = num_list[0]
            if group == digits:
                tokens.append(f"{digits} ")
            else:
                extra = group.replace(digits, "")
                tokens.append(f"{digits} {' '.join(extra)} ")
        layers.append(f"/{elem} {''.join(tokens)}")
    return "".join(layers).rstrip(' ')

# Takes the text after the first '=' of an InChI string, runs split_form on it
# and feeds the resulting space-separated string to split_form2.
# NOTE(review): split_form2 receives split_form's output (already containing
# spaces and uppercase element symbols) rather than the raw lowercase layers —
# confirm this composition is what was intended before relying on it.
split_form_compound = lambda inChI: split_form2(split_form(inChI.split('=')[1]))

In [None]:
splitByNumber = r'(\d+)'
splitText = r'/[a-zA-Z][a-zA-Z]'

def parse_formula(inChI = ''):
    """Return the sorted unique fragments of an InChI string.

    The string is first split on `splitText` ('/' followed by two ASCII
    letters), then every part is split around digit runs (the digit runs are
    kept). Empty fragments are discarded, the first remaining fragment is
    dropped, and the rest is passed through np.unique.
    """
    fragments = []
    for part in re.split(splitText, inChI):
        for fragment in re.split(splitByNumber, part):
            if fragment != '':
                fragments.append(fragment)
    return np.unique(fragments[1:])

def parse_formulas(formulas = []):
    """Return the sorted unique fragments over a batch of InChI strings.

    Each formula goes through parse_formula; the non-empty results are
    concatenated, the first item of the concatenation is dropped, and the rest
    is passed through np.unique.
    """
    merged = []
    for fragments in map(parse_formula, formulas):
        for fragment in fragments:
            if fragment != '':
                merged.append(fragment)
    return np.unique(merged[1:])

print(" ".join(parse_formula('InChI=1S/C21H30O4/c1-12(22)25-14-6-8-20(2)13(10-14)11-17(23)19-15-4-5-18(24)21(15,3)9-7-16(19)20/h13-16,19H,4-11H2,1-3H3/t13-,14+,15+,16-,19-,20+,21+/m1/s1')))
    

In [None]:
# Setting up CFG

class CFG:
    """Run configuration, read as class attributes throughout the notebook."""
    debug=True
    max_len=275
    print_freq=1000
    num_workers=4
    model_name='resnet34'
    size=224
    scheduler='CosineAnnealingLR' # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts']
    epochs=1
    # Hyper-parameters for every scheduler listed above are defined, so that
    # changing `scheduler` alone does not raise AttributeError in get_scheduler.
    factor=0.2 # ReduceLROnPlateau
    patience=4 # ReduceLROnPlateau
    eps=1e-6 # ReduceLROnPlateau
    T_max=4 # CosineAnnealingLR
    T_0=4 # CosineAnnealingWarmRestarts
    encoder_lr=1e-4
    decoder_lr=4e-4
    min_lr=1e-6
    batch_size=8
    weight_decay=1e-6
    gradient_accumulation_steps=1
    max_grad_norm=5
    attention_dim=256
    embed_dim=256
    decoder_dim=512
    dropout=0.5
    seed=42
    n_fold=2
    trn_fold=[0] 
    train=True

In [None]:
# Debug mode: train for a single epoch on a fixed-seed random subset of 30000
# rows of each table. The sampled frames replace the loaded ones, so re-running
# this cell samples again from the already reduced frames.
if CFG.debug:
    CFG.epochs = 1
    pre_process.train = pre_process.train.sample(n=30000, random_state=CFG.seed).reset_index(drop=True)
    pre_process.test = pre_process.test.sample(n=30000, random_state=CFG.seed).reset_index(drop=True)

In [None]:
pre_process.test.shape

In [None]:
# SEED TORCH

def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Also exports PYTHONHASHSEED with the same value.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(seed=CFG.seed)

In [None]:
# SCORING UTILS

def get_score(y_true, y_pred):
    """Mean Levenshtein distance between paired true and predicted strings.

    Pairs are formed with zip, so extra items in the longer input are ignored.
    """
    distances = [Levenshtein.distance(truth, guess) for truth, guess in zip(y_true, y_pred)]
    return np.mean(distances)


def init_logger(log_file=OUTPUT_DIR+'train.log'):
    """Return this module's logger, writing bare messages to the stream and to `log_file`.

    The logger is obtained with getLogger(__name__), so repeated calls return
    the same object. Handlers left over from a previous call are removed and
    closed first; otherwise re-running the cell would attach another pair of
    handlers and every message would be emitted several times.

    params:
    ------
        log_file: string
            path of the log file opened by the FileHandler
    returns:
    -------
        The configured logger (level INFO)
    """
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Iterate over a copy: removeHandler mutates logger.handlers.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()


# NOTE(review): this repeats the seed_torch definition from the "SEED TORCH"
# cell and silently shadows it; consider keeping a single definition.
def seed_torch(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs, export PYTHONHASHSEED and
    switch cuDNN to deterministic mode."""
    seed_text = str(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    np.random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed):
        torch_seeder(seed)
    torch.backends.cudnn.deterministic = True

seed_torch(seed=CFG.seed)

In [None]:
# CV SPLIT
# Assigns every training row to one of CFG.n_fold validation folds, stratified
# on the 'InChI_length' column. `folds` is only defined when CFG.train is True.
if(CFG.train):
    folds = pre_process.train.copy()
    Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
    for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['InChI_length'])):
        # val_index holds positions returned by split(), while .loc selects by
        # label — this assumes `folds` has a default 0..n-1 index (TODO confirm
        # for the non-debug path, where the frame is not re-indexed here).
        folds.loc[val_index, 'fold'] = int(n)
    folds['fold'] = folds['fold'].astype(int)
    print(folds.groupby(['fold']).size())

In [None]:
# DATASET
class TrainDataset(Dataset):
    """Dataset yielding (image, token ids, length) triples for training.

    params:
    ------
        df: DataFrame with 'file_path' and 'InChI_text' columns
        tokenizer: object exposing text_to_sequence(text) -> list of int
        transform: optional callable invoked as transform(image=...) and
            returning a mapping with an 'image' entry
    """

    def __init__(self, df, tokenizer, transform=None):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.file_paths = df['file_path'].values
        self.labels = df['InChI_text'].values
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        """Load sample `idx`.

        returns:
        -------
            image: the RGB image as float32 (after `transform`, if given)
            label: LongTensor with the token ids of the label text
            label_length: LongTensor of shape (1,) holding len(label)
        raises:
        ------
            FileNotFoundError when cv2 cannot read the image file
        """
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        # cv2.imread returns None instead of raising on a missing/corrupt file.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = self.tokenizer.text_to_sequence(self.labels[idx])
        # No per-sample printing here: this runs for every item of every batch.
        label_length = torch.LongTensor([len(label)])
        return image, torch.LongTensor(label), label_length
    
class TestDataset(Dataset):
    """Dataset yielding images only (no labels), for validation and inference.

    params:
    ------
        df: DataFrame with a 'file_path' column
        transform: optional callable invoked as transform(image=...) and
            returning a mapping with an 'image' entry
    """

    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.file_paths = df['file_path'].values
        self.transform = transform
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        """Load image `idx` as float32 RGB and apply `transform` if given.

        raises:
        ------
            FileNotFoundError when cv2 cannot read the image file
        """
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        # cv2.imread returns None instead of raising on a missing/corrupt file.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {file_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        return image

In [None]:
def bms_collate(batch):
    """Collate (image, label, label_length) samples into batch tensors.

    Labels are right-padded to the longest one with the '<pad>' index of
    pre_process.tokenizer; images are stacked, and the lengths are stacked and
    reshaped to a column of shape (-1, 1).
    """
    imgs = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    label_lengths = [sample[2] for sample in batch]
    pad_index = pre_process.tokenizer.stoi["<pad>"]
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=pad_index)
    lengths_column = torch.stack(label_lengths).reshape(-1, 1)
    return torch.stack(imgs), padded_labels, lengths_column

In [None]:
def get_transforms(*, data):
    """Build the preprocessing pipeline for the given split.

    Both splits currently share one pipeline: resize to CFG.size x CFG.size,
    normalize with the mean/std below, convert with ToTensorV2.

    params:
    ------
        data: 'train' | 'valid'
    raises:
    ------
        ValueError for any other value (previously None was returned silently,
        which disabled preprocessing without any warning)
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"Unknown data split: {data!r} (expected 'train' or 'valid')")

    return Compose([
        Resize(CFG.size, CFG.size),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ])

In [None]:
pre_process.tokenizer.sequence_to_text([1,2,3])

In [None]:
#Showing train dataset
train_dataset = TrainDataset(pre_process.train, pre_process.tokenizer, transform=get_transforms(data='train'))

for i in range(1):
    # Load the sample once; indexing the dataset reads and transforms the image.
    image, label, label_length = train_dataset[i]
    text = pre_process.tokenizer.sequence_to_text(label.numpy())
    # Move the first axis last (same result as transpose(0, 1).transpose(1, 2))
    # so imshow receives the channel axis in last position.
    plt.imshow(image.permute(1, 2, 0))
    plt.title(f'label: {label}  text: {text}  label_length: {label_length}')
    plt.show() 

# **MODEL**

In [None]:
class Encoder(nn.Module):
    """CNN backbone from timm with its pooling and classifier heads disabled.

    `global_pool` and `fc` of the created model are replaced by nn.Identity,
    and `n_features` records the input width of the original `fc` layer. This
    requires a backbone that exposes `fc` and `global_pool` attributes.
    """

    def __init__(self, model_name=CFG.model_name, pretrained=True):
        super().__init__()
        self.cnn = timm.create_model(model_name, pretrained=pretrained)
        # Read before `fc` is replaced below.
        self.n_features = self.cnn.fc.in_features
        self.cnn.global_pool = nn.Identity()
        self.cnn.fc = nn.Identity()

    def forward(self, x):
        """Run the backbone and move axis 1 of its 4-D output to the last position."""
        features = self.cnn(x)
        return features.permute(0, 2, 3, 1)

In [None]:
class Attention(nn.Module):
    """Additive attention over encoder positions, conditioned on the decoder state."""

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_hidden):
        """Return (weighted encoding, weights).

        Both inputs are projected to the attention width, summed (the decoder
        projection is broadcast along dim 1), passed through ReLU and reduced
        to one score per position. The scores are soft-maxed along dim 1 and
        used to sum `encoder_out` over that dimension.
        """
        projected_positions = self.encoder_att(encoder_out)
        projected_state = self.decoder_att(decoder_hidden).unsqueeze(1)
        scores = self.full_att(self.relu(projected_positions + projected_state))
        alpha = self.softmax(scores.squeeze(2))
        weighted = encoder_out * alpha.unsqueeze(2)
        attention_weighted_encoding = weighted.sum(dim=1)
        return attention_weighted_encoding, alpha


class DecoderWithAttention(nn.Module):
    """LSTM-cell decoder that attends over encoder features at every step.

    params:
    ------
        attention_dim: width of the attention projections
        embed_dim: token embedding width
        decoder_dim: LSTM hidden/cell width
        vocab_size: number of tokens (size of the output layer)
        device: stored on the instance as `self.device`
        encoder_dim: width of the encoder features (last axis of `encoder_out`)
        dropout: dropout probability applied before the output layer
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim=512, dropout=0.5):
        super(DecoderWithAttention, self).__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.device = device
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim) 
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True) 
        self.init_h = nn.Linear(encoder_dim, decoder_dim) 
        self.init_c = nn.Linear(encoder_dim, decoder_dim) 
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size) 
        self.init_weights()  

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for the embedding and output weights; zero output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding weight with the given tensor."""
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable or disable gradients for the embedding parameters."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """Derive the initial (h, c) from the mean of `encoder_out` over dim 1."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out) 
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """Teacher-forced decoding.

        The batch is sorted by caption length (descending) so that, at step t,
        only the first `batch_size_t` rows are still active.

        returns:
        -------
            predictions: (batch, max(decode_lengths), vocab_size) scores, zero
                where a sequence has already ended
            encoded_captions: the captions in sorted order
            decode_lengths: list of caption lengths minus one, sorted
            alphas: (batch, max(decode_lengths), num_pixels) attention weights
            sort_ind: permutation that was applied to the batch
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        # Flatten everything between the batch axis and the feature axis.
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim) 
        num_pixels = encoder_out.size(1)
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        embeddings = self.embedding(encoded_captions) 
        h, c = self.init_hidden_state(encoder_out) 
        # The last token of each caption is never fed as input.
        decode_lengths = (caption_lengths - 1).tolist()
        max_steps = max(decode_lengths)
        # Allocate on the input's device rather than on self.device, so the
        # buffers always live where the activations are.
        predictions = torch.zeros(batch_size, max_steps, vocab_size, device=encoder_out.device)
        alphas = torch.zeros(batch_size, max_steps, num_pixels, device=encoder_out.device)
        for t in range(max_steps):
            batch_size_t = sum([l > t for l in decode_lengths])
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t])) 
            preds = self.fc(self.dropout(h)) 
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind
    
    def predict(self, encoder_out, decode_lengths, tokenizer):
        """Greedy decoding for at most `decode_lengths` steps.

        Decoding starts from '<sos>' and feeds the arg-max token back in. It
        stops early only once *every* sequence in the batch has produced
        '<eos>'. (The previous check took the arg-max over the flattened
        batch x vocab scores, so it aborted the whole batch as soon as the
        first sample produced '<eos>'.)

        returns:
        -------
            predictions: (batch, decode_lengths, vocab_size) scores; steps
                after the early stop are left at zero
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  
        out_device = encoder_out.device
        eos_index = tokenizer.stoi["<eos>"]
        start_tockens = torch.ones(batch_size, dtype=torch.long, device=out_device) * tokenizer.stoi["<sos>"]
        embeddings = self.embedding(start_tockens)
        h, c = self.init_hidden_state(encoder_out)
        predictions = torch.zeros(batch_size, decode_lengths, vocab_size, device=out_device)
        # Per-sample flag: has this sequence emitted '<eos>' yet?
        finished = torch.zeros(batch_size, dtype=torch.bool, device=out_device)
        for t in range(decode_lengths):
            attention_weighted_encoding, alpha = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))  
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings, attention_weighted_encoding], dim=1),
                (h, c)) 
            preds = self.fc(self.dropout(h))  
            predictions[:, t, :] = preds
            next_tokens = torch.argmax(preds, -1)
            finished = finished | (next_tokens == eos_index)
            if bool(finished.all()):
                break
            embeddings = self.embedding(next_tokens)
        return predictions

In [None]:
class AverageMeter(object):
    """Keeps the latest value, the weighted sum, the count and the running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def asMinutes(s):
    """Format a duration in seconds as '<minutes>m <seconds>s' (both truncated to integers)."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def timeSince(since, percent):
    """Report elapsed time since `since` and the estimated remaining time.

    `percent` is the completed fraction of the work; the total is estimated as
    elapsed / percent and the remainder as total - elapsed.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))


def train_fn(train_loader, encoder, decoder, criterion, 
             encoder_optimizer, decoder_optimizer, epoch,
             encoder_scheduler, decoder_scheduler, device):
    """Train encoder and decoder for one pass over `train_loader`.

    For every batch the decoder is run with teacher forcing, predictions and
    targets (captions without their first token) are packed to drop padded
    positions, and `criterion` is applied. Gradients are accumulated over
    CFG.gradient_accumulation_steps batches; clipping to CFG.max_grad_norm is
    done once, right before the optimizer step, so that partially accumulated
    gradients are not rescaled on every micro-batch.

    `encoder_scheduler` and `decoder_scheduler` are accepted for signature
    compatibility but are not used here.

    returns:
    -------
        The batch-size-weighted mean of the (unscaled) loss
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    encoder.train()
    decoder.train()
    start = end = time.time()
    # Last measured gradient norms; reported as 0 until the first optimizer step.
    encoder_grad_norm = 0.0
    decoder_grad_norm = 0.0
    for step, (images, labels, label_lengths) in enumerate(train_loader):
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)
        batch_size = images.size(0)
        features = encoder(images)
        predictions, caps_sorted, decode_lengths, alphas, sort_ind = decoder(features, labels, label_lengths)
        # Targets are the captions shifted by one (the first token is input only).
        targets = caps_sorted[:, 1:]
        predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        loss = criterion(predictions, targets)
        losses.update(loss.item(), batch_size)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        loss.backward()
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            encoder_grad_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), CFG.max_grad_norm)
            decoder_grad_norm = torch.nn.utils.clip_grad_norm_(decoder.parameters(), CFG.max_grad_norm)
            encoder_optimizer.step()
            decoder_optimizer.step()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Encoder Grad: {encoder_grad_norm:.4f}  '
                  'Decoder Grad: {decoder_grad_norm:.4f}  '
                  #'Encoder LR: {encoder_lr:.6f}  '
                  #'Decoder LR: {decoder_lr:.6f}  '
                  .format(
                   epoch+1, step, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses,
                   remain=timeSince(start, float(step+1)/len(train_loader)),
                   encoder_grad_norm=encoder_grad_norm,
                   decoder_grad_norm=decoder_grad_norm,
                   #encoder_lr=encoder_scheduler.get_lr()[0],
                   #decoder_lr=decoder_scheduler.get_lr()[0],
                   ))
    return losses.avg


def valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device):
    """Greedy-decode every batch of `valid_loader` and return the predicted texts.

    Both models are switched to eval mode and run under torch.no_grad().
    Each batch is decoded for at most CFG.max_len steps, the arg-max token
    ids are turned into text with tokenizer.predict_captions, and the
    per-batch results are concatenated into one array. `criterion` is not used.
    """
    encoder.eval()
    decoder.eval()
    batch_time, data_time = AverageMeter(), AverageMeter()
    collected = []
    start = end = time.time()
    for step, images in enumerate(valid_loader):
        data_time.update(time.time() - end)
        images = images.to(device)
        with torch.no_grad():
            predictions = decoder.predict(encoder(images), CFG.max_len, tokenizer)
        token_ids = predictions.detach().cpu().argmax(dim=-1).numpy()
        collected.append(tokenizer.predict_captions(token_ids))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        is_last_step = step == (len(valid_loader)-1)
        if step % CFG.print_freq == 0 or is_last_step:
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  .format(
                   step, len(valid_loader), batch_time=batch_time,
                   data_time=data_time,
                   remain=timeSince(start, float(step+1)/len(valid_loader)),
                   ))
    return np.concatenate(collected)

In [None]:
def train_loop(folds, fold):
    """Train and validate the encoder/decoder pair on one cross-validation fold.

    Rows whose 'fold' column equals `fold` form the validation set, all other
    rows the training set. After each epoch the validation images are decoded,
    the mean Levenshtein distance against the 'InChI' column is computed, and
    whenever that score improves a checkpoint is written to
    OUTPUT_DIR + '<model_name>_fold<fold>_best.pth'.

    params:
    ------
        folds: DataFrame with 'fold', 'file_path', 'InChI_text' and 'InChI' columns
        fold: int
            index of the fold held out for validation
    raises:
    ------
        ValueError when CFG.scheduler is not one of the supported names
    """

    LOGGER.info(f"========== fold: {fold} training ==========")

    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds['InChI'].values

    train_dataset = TrainDataset(train_folds, pre_process.tokenizer, transform=get_transforms(data='train'))
    # Validation only needs images; labels are compared as raw strings below.
    valid_dataset = TestDataset(valid_folds, transform=get_transforms(data='valid'))

    train_loader = DataLoader(train_dataset, 
                              batch_size=CFG.batch_size, 
                              shuffle=True, 
                              num_workers=CFG.num_workers, 
                              pin_memory=True,
                              drop_last=True, 
                              collate_fn=bms_collate)
    valid_loader = DataLoader(valid_dataset, 
                              batch_size=CFG.batch_size, 
                              shuffle=False, 
                              num_workers=CFG.num_workers,
                              pin_memory=True, 
                              drop_last=False)
    
    def get_scheduler(optimizer):
        """Create the LR scheduler selected by CFG.scheduler.

        Scheduler-specific settings fall back to the values noted in the CFG
        comments when the attribute is not defined on CFG.
        """
        if CFG.scheduler=='ReduceLROnPlateau':
            return ReduceLROnPlateau(optimizer, mode='min',
                                     factor=getattr(CFG, 'factor', 0.2),
                                     patience=getattr(CFG, 'patience', 4),
                                     verbose=True,
                                     eps=getattr(CFG, 'eps', 1e-6))
        if CFG.scheduler=='CosineAnnealingLR':
            return CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        if CFG.scheduler=='CosineAnnealingWarmRestarts':
            return CosineAnnealingWarmRestarts(optimizer, T_0=getattr(CFG, 'T_0', 4), T_mult=1,
                                               eta_min=CFG.min_lr, last_epoch=-1)
        raise ValueError(f"Unsupported CFG.scheduler: {CFG.scheduler!r}")

    def step_scheduler(scheduler, score):
        """Advance a scheduler; ReduceLROnPlateau additionally needs the metric."""
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(score)
        elif isinstance(scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
            scheduler.step()

    encoder = Encoder(CFG.model_name, pretrained=True)
    encoder.to(device)
    encoder_optimizer = Adam(encoder.parameters(), lr=CFG.encoder_lr, weight_decay=CFG.weight_decay, amsgrad=False)
    encoder_scheduler = get_scheduler(encoder_optimizer)
    
    # Tie the decoder's feature width to the backbone instead of relying on the
    # decoder's default of 512, so other backbones work without edits.
    decoder = DecoderWithAttention(attention_dim=CFG.attention_dim,
                                   embed_dim=CFG.embed_dim,
                                   decoder_dim=CFG.decoder_dim,
                                   vocab_size=len(pre_process.tokenizer),
                                   dropout=CFG.dropout,
                                   device=device,
                                   encoder_dim=encoder.n_features)
    decoder.to(device)
    decoder_optimizer = Adam(decoder.parameters(), lr=CFG.decoder_lr, weight_decay=CFG.weight_decay, amsgrad=False)
    decoder_scheduler = get_scheduler(decoder_optimizer)

    # Padded target positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=pre_process.tokenizer.stoi["<pad>"])

    best_score = np.inf
    
    for epoch in range(CFG.epochs):
        
        start_time = time.time()
        
        # train
        avg_loss = train_fn(train_loader, encoder, decoder, criterion, 
                            encoder_optimizer, decoder_optimizer, epoch, 
                            encoder_scheduler, decoder_scheduler, device)

        # eval
        text_preds = valid_fn(valid_loader, encoder, decoder, pre_process.tokenizer, criterion, device)
        text_preds = [f"InChI=1S/{text}" for text in text_preds]
        LOGGER.info(f"labels: {valid_labels[:5]}")
        LOGGER.info(f"preds: {text_preds[:5]}")
        
        # scoring (lower is better)
        score = get_score(valid_labels, text_preds)
        
        step_scheduler(encoder_scheduler, score)
        step_scheduler(decoder_scheduler, score)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')
        
        if score < best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'encoder': encoder.state_dict(), 
                        'encoder_optimizer': encoder_optimizer.state_dict(), 
                        'encoder_scheduler': encoder_scheduler.state_dict(), 
                        'decoder': decoder.state_dict(), 
                        'decoder_optimizer': decoder_optimizer.state_dict(), 
                        'decoder_scheduler': decoder_scheduler.state_dict(), 
                        'text_preds': text_preds,
                       },
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best.pth')

In [None]:
# Train on the selected folds when CFG.train is True

def main():
    """Run train_loop for every fold index listed in CFG.trn_fold.

    Does nothing unless CFG.train is truthy. Fold indices are visited in
    increasing order from 0 to CFG.n_fold - 1, and each selected index is
    printed before training starts.
    """
    if not CFG.train:
        return

    for fold in range(CFG.n_fold):
        if fold not in CFG.trn_fold:
            continue
        print(fold)
        train_loop(folds, fold)

In [None]:
if __name__ == '__main__':
    main()

In [None]:
#def inference(test_loader, encoder, decoder, tokenizer, device):
#    encoder.eval()
#    decoder.eval()
#    text_preds = []
#    tk0 = tqdm(test_loader, total= 1616107)
#    for images in tk0:
#        images = images.to(device)
#        with torch.no_grad():
#            features = encoder(images)
#            predictions = decoder.predict(features, CFG.max_len, tokenizer)
#        predicted_sequence = torch.argmax(predictions.detach().cpu(), -1).numpy()
#        _text_preds = tokenizer.predict_captions(predicted_sequence)
#        text_preds.append(_text_preds)
#    text_preds = np.concatenate(text_preds)
#    return text_preds

In [None]:
#encoder = Encoder(CFG.model_name, pretrained= True)
#encoder.to(device)

#decoder = DecoderWithAttention(attention_dim=CFG.attention_dim,
#                               embed_dim=CFG.embed_dim,
#                               decoder_dim=CFG.decoder_dim,
#                               vocab_size=len(pre_process.tokenizer),
#                               dropout=CFG.dropout,
#                               device=device)
#decoder.to(device)

#test_dataset = TestDataset(pre_process.test, transform=get_transforms(data='valid'))
#test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=CFG.num_workers)
#predictions = inference(test_loader, encoder, decoder, pre_process.tokenizer, device)

In [None]:
# submission
#pre_process.test['InChI'] = [f"InChI=1S/{text}" for text in predictions]
#pre_process.test[['image_id', 'InChI']].to_csv('submission.csv', index=False)


In [None]:
#pre_process.test[['image_id', 'InChI']].shape