In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory.
# The original cell walked every file under /kaggle/input and discarded each
# joined path, so it touched every image file without showing anything. The
# BMS image trees are sharded three directories deep and hold many files, so
# only the top-level dataset directories are listed here.

import os
INPUT_DIR = '/kaggle/input'
if os.path.isdir(INPUT_DIR):
    for entry in sorted(os.listdir(INPUT_DIR)):
        print(os.path.join(INPUT_DIR, entry))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import torch
from torch import nn
from torch.utils import data
from torchvision import transforms
from torchvision import models

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image

In [None]:
%%time

from tqdm.auto import tqdm #重要
tqdm.pandas()

train = pd.read_csv('../input/bms-molecular-translation/train_labels.csv')
test = pd.read_csv('../input/bms-molecular-translation/sample_submission.csv')

def _bms_image_path(split, image_id, root="../input/bms-molecular-translation"):
    """Return the PNG path of ``image_id`` inside the ``split`` directory.

    Images are sharded by the first three characters of the id:
    ``<root>/<split>/<id[0]>/<id[1]>/<id[2]>/<id>.png``.
    """
    return "{}/{}/{}/{}/{}/{}.png".format(
        root, split, image_id[0], image_id[1], image_id[2], image_id
    )


def get_train_file_path(image_id):
    """Path of a training image (see _bms_image_path for the layout)."""
    return _bms_image_path('train', image_id)


def get_test_file_path(image_id):
    """Path of a test image (see _bms_image_path for the layout)."""
    return _bms_image_path('test', image_id)

# Attach the on-disk PNG path of every image, then show sizes and a preview of both frames.
train['file_path'] = train['image_id'].progress_apply(get_train_file_path)
test['file_path'] = test['image_id'].progress_apply(get_test_file_path)

print(f'train.shape: {train.shape}  test.shape: {test.shape}')
display(train.head(), test.head())

In [None]:
import cv2

# Preview the first five training images with their InChI labels.
for i in range(5):
    img_path = train.loc[i, 'file_path']
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising when the file cannot be read.
    if img is None:
        raise FileNotFoundError(f'could not read image: {img_path}')
    label = train.loc[i, 'InChI']
    print(img.shape)
    # Show the 2-D grayscale array with a gray colormap (the default colormap
    # would render it in false color).
    plt.imshow(img, cmap='gray')
    plt.title(label)
    plt.show()

In [None]:
# Scoring function: mean Levenshtein distance between true and predicted InChI strings
import Levenshtein
def get_score(y_true, y_pred):
    """Mean Levenshtein distance between paired true and predicted strings (lower is better).

    Args:
        y_true: sequence of reference strings.
        y_pred: sequence of predicted strings, aligned element-wise with ``y_true``.

    Returns:
        The mean edit distance as a float, or NaN when both inputs are empty.

    Raises:
        ValueError: if the two sequences differ in length (zip would silently truncate).
    """
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred must have the same length ({len(y_true)} != {len(y_pred)})"
        )
    # Pair the elements with zip. The original `for true, pred in y_true, y_pred`
    # iterated over the 2-tuple (y_true, y_pred) instead of over the pairs.
    scores = [Levenshtein.distance(true, pred) for true, pred in zip(y_true, y_pred)]
    if not scores:
        return float('nan')
    # The original computed the mean but never returned it.
    return float(np.mean(scores))
    

In [None]:
import re
def split_form(form):
    """Tokenize a molecular-formula layer into space-separated element/count tokens.

    Example: 'C13H20OS' -> 'C 13 H 20 O S'.
    """
    tokens = []
    for chunk in re.findall(r"[A-Z][^A-Z]*", form):
        # Leading non-digit run is the element symbol; whatever remains is its count.
        element = re.match(r"\D+", chunk).group()
        count = chunk.replace(element, "")
        tokens.append(element if count == "" else f"{element} {str(count)}")
    return ' '.join(tokens).rstrip(' ')

def split_form2(form):
    """Tokenize the InChI layers that follow the formula layer.

    Each layer starts with a lowercase letter (e.g. 'c', 'h'). It is emitted as
    '/<letter>' followed by its numbers, with every non-digit character after a
    number emitted as its own token.

    Example: 'c1-2-3/h1H' -> '/c 1 - 2 - 3 /h 1 H'.

    NOTE(review): `[0-9]+[^0-9]*` only matches runs that start with a digit, so
    any non-digit characters at the very start of a layer body (before its
    first number) are silently dropped. Confirm this never happens for the
    InChI layers in this dataset.
    """
    string = ''
    # Split into segments that each begin with a single lowercase layer letter.
    for i in re.findall(r"[a-z][^a-z]*", form):
        elem = i[0]
        # The segment contains no other lowercase letters, so this removes only
        # the leading layer letter; the trailing '/' separator is removed too.
        num = i.replace(elem, "").replace('/', "")
        num_string = ''
        # Each piece is one number followed by any non-digit characters.
        for j in re.findall(r"[0-9]+[^0-9]*", num):
            num_list = list(re.findall(r'\d+', j))
            assert len(num_list) == 1, f"len(num_list) != 1"
            _num = num_list[0]
            if j == _num:
                num_string += f"{_num} "
            else:
                # Split the trailing punctuation/letters into single-character tokens.
                extra = j.replace(_num, "")
                num_string += f"{_num} {' '.join(list(extra))} "
        string += f"/{elem} {num_string}"
    return string.rstrip(' ')

# InChI_1 is the formula layer (second '/'-separated field). InChI_text is the
# tokenized formula followed by the tokenized remaining layers.
train['InChI_1'] = train['InChI'].progress_apply(lambda inchi: inchi.split('/')[1])
train['InChI_text'] = (
    train['InChI_1'].progress_apply(split_form)
    + ' '
    + train['InChI']
    .apply(lambda inchi: '/'.join(inchi.split('/')[2:]))
    .progress_apply(split_form2)
    .values
)

In [None]:
# Preview the new InChI_1 / InChI_text columns.
train.head(n=5)

In [None]:
class Tokenizer(object):
    """Map space-separated InChI tokens to integer ids and back.

    Ids are assigned to the sorted set of tokens seen by ``fit_on_texts``,
    followed by the special tokens '<sos>', '<eos>' and '<pad>' (in that order).
    """

    def __init__(self):
        self.stoi = {}  # token -> id
        self.itos = {}  # id -> token

    def __len__(self):
        return len(self.stoi)

    def fit_on_texts(self, texts):
        """Build the vocabulary from an iterable of space-separated strings.

        The mappings are rebuilt from scratch, so refitting does not keep stale tokens.
        """
        vocab = set()
        for text in texts:
            vocab.update(text.split())
        # sorted() returns a new list, giving a deterministic id order.
        vocab = sorted(vocab)
        vocab.extend(['<sos>', '<eos>', '<pad>'])
        self.stoi = {s: i for i, s in enumerate(vocab)}
        # Reverse mapping so lookups work in both directions.
        self.itos = {i: s for s, i in self.stoi.items()}

    def text_to_sequence(self, text):
        """Encode ``text`` as [<sos>, token ids..., <eos>]; unknown tokens raise KeyError."""
        sequence = [self.stoi['<sos>']]
        sequence.extend(self.stoi[s] for s in text.split(' '))
        sequence.append(self.stoi['<eos>'])
        return sequence

    def texts_to_sequences(self, texts):
        """Encode each text with ``text_to_sequence``."""
        return [self.text_to_sequence(text) for text in texts]

    def sequence_to_text(self, sequence):
        """Decode ids by concatenating their tokens with no separator (special tokens included)."""
        return ''.join(self.itos[x] for x in sequence)

    def sequences_to_texts(self, sequences):
        """Decode each sequence with ``sequence_to_text``."""
        return [self.sequence_to_text(sequence) for sequence in sequences]

    def predict_caption(self, sequence):
        """Decode a predicted id sequence, stopping at the first <eos> or <pad>.

        The original broke on <sos> and appended to an undefined ``captions`` variable.
        """
        stop_ids = (self.stoi['<eos>'], self.stoi['<pad>'])
        caption = ''
        for i in sequence:
            if i in stop_ids:
                break
            caption += self.itos[i]
        return caption

    def predict_captions(self, sequences):
        """Decode each predicted sequence with ``predict_caption``."""
        return [self.predict_caption(sequence) for sequence in sequences]

    # Original misspelled name, kept as an alias for backward compatibility.
    prdict_captions = predict_captions

In [None]:
# Preview only; evaluating the bare frame would render the whole training table.
train.head()

In [None]:
# Build the token vocabulary from every tokenized InChI string.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(texts=train['InChI_text'].values)

In [None]:
# Token count per sample, excluding the <sos>/<eos> markers that text_to_sequence adds.
lengths = [
    len(tokenizer.text_to_sequence(text)) - 2
    for text in tqdm(train['InChI_text'].values, total=len(train))
]
train['InChI_length'] = lengths

In [None]:
# Check the new InChI_length column on a preview rather than the full frame.
train.head()

In [None]:
# Longest tokenized InChI; CFG.max_len (the decoding budget) should cover it.
train.loc[:, 'InChI_length'].max()

In [None]:
class CFG:
    """Hyper-parameters and run settings for training and validation."""
    debug = False
    max_len = 275               # greedy-decoding step budget used at validation
    print_freq = 1000
    num_workers = 4
    # timm has no plain 'resnet' model (create_model fails); resnet34 emits
    # 512-channel features, matching the decoder's default encoder_dim=512.
    model_name = 'resnet34'
    size = 224
    scheduler = 'CosineAnnealingLR'  # learning rate is adjusted once per epoch
    epochs = 35
    T_max = 4                   # CosineAnnealingLR
    # Read by get_scheduler only for the alternative schedulers; placeholder values.
    factor = 0.2                # ReduceLROnPlateau
    patience = 4                # ReduceLROnPlateau
    eps = 1e-6                  # ReduceLROnPlateau
    T_0 = 4                     # CosineAnnealingWarmRestarts
    encoder_lr = 1e-4
    decoder_lr = 4e-4
    decorder_lr = decoder_lr    # original misspelling, kept as an alias
    min_lr = 1e-6
    batch_size = 64
    # Weight decay penalizes large weights to curb overfitting.
    weight_decay = 1e-6
    gradient_accumulation_steps = 1
    gradient_acuumulation_steps = gradient_accumulation_steps  # original misspelling alias
    max_grad_norm = 5
    attention_dim = 256
    embed_dim = 256
    decoder_dim = 512
    dropout = 0.5
    seed = 42
    n_fold = 5
    trn_fold = [0]  # [0, 1, 2, 3, 4]
    train = True

In [None]:
# Debug mode: a single epoch on a reproducible 1,000-row sample.
if CFG.debug:
    CFG.epochs = 1
    train = (
        train.sample(n=1000, random_state=CFG.seed)
        .reset_index(drop=True)
    )

In [None]:
!pip install timm

In [None]:
import sys
sys.path.append('../input/pytorch-image-models/pytorch-image-models-master')

import os
import gc
import re
import math
import time
import random
import shutil
import pickle
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import Levenshtein
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

from functools import partial

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, Blur
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

import warnings
# Silence every warning category for the rest of the session (keeps training logs readable).
warnings.simplefilter('ignore')

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Setting PYTHONHASHSEED here does not change hashing in the already-running
    interpreter. It only affects Python processes started afterwards.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Correct spelling: the original 'determinstic' just created an unused attribute.
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different kernels run to run, so disable it.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before folds are split and models are built.
seed_torch(seed=CFG.seed)

In [None]:
# Assign every row a CV fold, stratified on tokenized InChI length so each fold
# sees a similar length distribution.
folds = train.copy()
splitter = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
for fold_idx, (_, valid_idx) in enumerate(splitter.split(folds, folds['InChI_length'])):
    # Rows in this split's validation part belong to fold `fold_idx`.
    folds.loc[valid_idx, 'fold'] = int(fold_idx)
folds['fold'] = folds['fold'].astype(int)

In [None]:
# The bare `*` makes `data` a keyword-only argument: call image_transform(data='train').
def image_transform(*, data):
    """Build the albumentations preprocessing pipeline for a dataset split.

    Both 'train' and 'valid' resize to CFG.size x CFG.size, normalize with the
    ImageNet mean/std and convert to a channels-first torch tensor.

    Raises:
        ValueError: if ``data`` is not 'train' or 'valid' (the original returned None).
    """
    if data not in ('train', 'valid'):
        raise ValueError(f"unknown data split: {data!r} (expected 'train' or 'valid')")
    # No augmentation is applied yet, so both splits share one pipeline.
    return Compose([
        Resize(CFG.size, CFG.size),
        Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        # Must be an instance: the original 'valid' branch passed the ToTensorV2 class.
        ToTensorV2(),
    ])

In [None]:
class TrainDataset(Dataset):
    """Training dataset yielding (image, token ids, token count) triples.

    ``df`` must have 'file_path' and 'InChI_text' columns. ``tokenizer`` must
    provide ``text_to_sequence``.
    """

    def __init__(self, df, tokenizer, transform=None):
        super().__init__()
        self.df = df
        self.tokenizer = tokenizer
        self.transform = transform
        self.file_paths = df['file_path'].values
        self.labels = df['InChI_text'].values

    def __len__(self):
        return len(self.df)

    def _encode_label(self, idx):
        """Tokenize label ``idx`` and return (ids, length) as LongTensors.

        The length is a 1-element tensor so bms_collate can torch.stack a batch of
        them. The original assigned the LongTensor class itself instead.
        """
        sequence = self.tokenizer.text_to_sequence(self.labels[idx])
        # int64 (LongTensor) is the dtype that embedding lookups and CrossEntropyLoss targets take.
        return torch.LongTensor(sequence), torch.LongTensor([len(sequence)])

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        image = cv2.imread(file_path)
        # cv2.imread returns None instead of raising for unreadable files.
        if image is None:
            raise FileNotFoundError(f'could not read image: {file_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            image = self.transform(image=image)['image']
        label, label_length = self._encode_label(idx)
        return image, label, label_length

class TestDataset(Dataset):
    """Inference dataset yielding one preprocessed image per row of ``df``."""

    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df
        self.transform = transform
        # Positional array, so __getitem__ does not depend on the DataFrame's index labels.
        self.file_paths = df['file_path'].values

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # The original passed the whole path column to cv2.imread instead of one path.
        image_path = self.file_paths[idx]
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        if self.transform:
            image = self.transform(image=image)['image']
        # Always bound, even without a transform (the original returned an unbound name).
        return image
        

In [None]:
def bms_collate(batch, pad_value=None):
    """Collate (image, token ids, length) samples into batch tensors.

    Args:
        batch: list of (image tensor, 1-D LongTensor of ids, 1-element LongTensor length).
        pad_value: id used to right-pad label sequences. Defaults to the '<pad>' id
            of the module-level ``tokenizer``.

    Returns:
        (images stacked on a new leading batch dim,
         labels of shape (batch, longest label),
         lengths of shape (batch, 1)).
    """
    if pad_value is None:
        pad_value = tokenizer.stoi['<pad>']
    # The original created `label` but appended to an undefined `labels`.
    imgs, labels, label_lengths = [], [], []
    for image, label, label_length in batch:
        imgs.append(image)
        labels.append(label)
        label_lengths.append(label_length)
    # batch_first=True -> (batch, max_len) instead of (max_len, batch).
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=pad_value)
    return torch.stack(imgs), padded_labels, torch.stack(label_lengths).reshape(-1, 1)

In [None]:
# Dataset over the full training frame, used below to inspect one sample.
size = CFG.size
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # standard ImageNet normalization values
train_dataset = TrainDataset(train, tokenizer, transform=image_transform(data='train'))

In [None]:
# Sanity-check the first training sample: tensor shape, decoded label and image.
for i in range(1):
    image, label, label_length = train_dataset[i]
    print(image.shape)
    text = tokenizer.sequence_to_text(label.numpy())
    # CHW -> HWC for imshow. Values are normalized, so matplotlib clips them for display.
    plt.imshow(image.permute(1, 2, 0))
    # The original computed `text` but never showed it.
    plt.title(text)
    # The original referenced plt.show without calling it.
    plt.show()

In [None]:
class Encoder(nn.Module):
    """CNN feature extractor built on a timm backbone.

    The backbone's global pooling and classifier are replaced by identities, so
    forward returns the spatial feature map instead of class logits.
    """

    def __init__(self, model_name='resnet18', pretrained=False):
        super().__init__()
        self.cnn = timm.create_model(model_name, pretrained=pretrained)
        # Channel count of the final feature map (the classifier's input width);
        # read before the classifier is replaced.
        self.n_features = self.cnn.fc.in_features
        # Original misspelled attribute, kept for backward compatibility.
        self.n_feaatures = self.n_features
        self.cnn.global_pool = nn.Identity()
        self.cnn.fc = nn.Identity()

    def forward(self, x):
        """Return backbone features in channels-last layout.

        Input is a (batch, channels, height, width) image batch. Assumes the
        headless backbone returns (batch, C, H', W'), as timm ResNets do;
        TODO confirm for other backbones.
        """
        features = self.cnn(x)
        # (batch, C, H', W') -> (batch, H', W', C) so the decoder can flatten the
        # spatial positions, with channels as the last dim.
        # The original ended with the undefined name `return_features`.
        return features.permute(0, 2, 3, 1)

In [None]:
class Attention(nn.Module):
    """Additive soft attention over encoder positions (ReLU scoring).

    Args:
        encoder_dim: feature size of each encoder position.
        decoder_dim: size of the decoder hidden state.
        attention_dim: size of the shared projection space both inputs are mapped to.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        # Both inputs go to attention_dim so they can be added (the original
        # projected the encoder to decoder_dim, which breaks the sum).
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # One score per position.
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        # Normalize over positions (dim 1 of the (batch, num_pixels) scores).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, decoder_out):
        """Compute attention weights and the weighted encoding.

        Args:
            encoder_out: (batch, num_pixels, encoder_dim).
            decoder_out: (batch, decoder_dim), the current decoder hidden state.

        Returns:
            (weighted encoding (batch, encoder_dim), alpha (batch, num_pixels)).
        """
        att1 = self.encoder_att(encoder_out)  # (batch, num_pixels, attention_dim)
        att2 = self.decoder_att(decoder_out)  # (batch, attention_dim)
        # unsqueeze(1) broadcasts the decoder term over all positions.
        att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2)  # (batch, num_pixels)
        alpha = self.softmax(att)
        # Weighted sum over positions.
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return attention_weighted_encoding, alpha


class DecoderWithAttention(nn.Module):
    """LSTM caption decoder that attends over encoder features at every step.

    Args:
        attention_dim: size of the attention projection space.
        embed_dim: token embedding size.
        decoder_dim: LSTM hidden size.
        vocab_size: number of output tokens.
        device: device on which per-batch output buffers are allocated.
        encoder_dim: channel size of the encoder features (512 matches resnet18/34;
            verify for other backbones).
        dropout: dropout probability applied to the hidden state before the output layer.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, device, encoder_dim=512, dropout=0.5):
        super(DecoderWithAttention, self).__init__()
        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.device = device
        self.attention = Attention(encoder_dim, decoder_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # A module, not the raw float: forward calls self.dropout(h).
        self.dropout = nn.Dropout(p=dropout)
        # LSTM input = token embedding concatenated with the attention-weighted encoding.
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)
        # Initial hidden/cell states are projected from the mean encoder feature.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # Sigmoid gate that scales the attention output per channel.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)
        self.init_weights()

    def init_weights(self):
        """Uniform(-0.1, 0.1) init for embedding and output weights; zero output bias."""
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        """Replace the embedding matrix with ``embeddings`` (vocab_size, embed_dim)."""
        self.embedding.weight = nn.Parameter(embeddings)

    def fine_tune_embeddings(self, fine_tune=True):
        """Enable (True) or freeze (False) gradient updates of the embedding."""
        for p in self.embedding.parameters():
            p.requires_grad = fine_tune

    def init_hidden_state(self, encoder_out):
        """Return initial (h, c), each (batch, decoder_dim), from the mean over positions."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_length):
        """Teacher-forced decoding.

        Args:
            encoder_out: (batch, ..., encoder_dim); all middle dims are flattened into positions.
            encoded_captions: (batch, max_len) token ids starting with <sos>.
            caption_length: (batch, 1) token counts including <sos>/<eos>.

        Returns:
            predictions (batch, max(decode_lengths), vocab_size), captions sorted by
            decreasing length, decode_lengths (list of int), alphas
            (batch, max(decode_lengths), num_pixels), and the sort indices.
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)  # (batch, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)
        # Sort by decreasing length so the samples still decoding at step t are a batch prefix.
        caption_length, sort_ind = caption_length.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]
        embeddings = self.embedding(encoded_captions)  # (batch, max_len, embed_dim)
        h, c = self.init_hidden_state(encoder_out)
        # The final <eos> is never fed as input, so each sample decodes length - 1 steps.
        decode_lengths = (caption_length - 1).tolist()
        max_steps = max(decode_lengths)
        predictions = torch.zeros(batch_size, max_steps, vocab_size, device=self.device)
        alphas = torch.zeros(batch_size, max_steps, num_pixels, device=self.device)
        for t in range(max_steps):
            # Number of samples whose caption is still longer than t.
            batch_size_t = sum(l > t for l in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]),
            )
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha
        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

    def predict(self, encoder_out, decoder_length, tokenizer):
        """Greedy decoding for up to ``decoder_length`` steps.

        Starts every sequence from tokenizer.stoi['<sos>'] and stops early once
        every sequence has produced '<eos>'. Steps after an early stop stay zero.

        Returns:
            logits of shape (batch, decoder_length, vocab_size).
        """
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size
        encoder_out = encoder_out.reshape(batch_size, -1, encoder_dim)
        start_tokens = torch.full((batch_size,), tokenizer.stoi['<sos>'], dtype=torch.long, device=self.device)
        embeddings = self.embedding(start_tokens)
        h, c = self.init_hidden_state(encoder_out)
        predictions = torch.zeros(batch_size, decoder_length, vocab_size, device=self.device)
        eos = tokenizer.stoi['<eos>']
        finished = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        for t in range(decoder_length):
            attention_weighted_encoding, _ = self.attention(encoder_out, h)
            gate = self.sigmoid(self.f_beta(h))
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(torch.cat([embeddings, attention_weighted_encoding], dim=1), (h, c))
            preds = self.fc(self.dropout(h))  # (batch, vocab_size)
            predictions[:, t, :] = preds
            next_tokens = torch.argmax(preds, -1)
            # Per-sample check: the original compared a flattened batch-wide argmax to <eos>.
            finished |= next_tokens == eos
            if bool(finished.all()):
                break
            embeddings = self.embedding(next_tokens)
        return predictions

In [None]:
#Helper functions
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch-mean loss over n samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# Original misspelled name (train_fn/valid_fn use AverageMeter); kept as an alias.
AverageMator = AverageMeter
    
def asMinutes(s):
    """Format a duration in seconds as '<minutes>m <seconds>s' (seconds truncated)."""
    m = math.floor(s / 60)
    s -= m * 60
    # '%dm': the original '%m' is not a valid %-format conversion and raised ValueError.
    return '%dm %ds' % (m, s)
    
def timeSince(since, percent):
    """Describe elapsed time since ``since`` and the estimated time remaining.

    ``percent`` is the completed fraction of the work. The total duration is
    extrapolated linearly from it.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))
    
def train_fn(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_optimizer, epochs,
            encode_scheduler, decoder_scheduler, device):
    """Train the encoder and decoder for one epoch with teacher forcing.

    Args:
        train_loader: yields (images, labels, label_lengths) batches as built by bms_collate.
        criterion: loss over (num_tokens, vocab_size) logits and (num_tokens,) targets.
        encoder_optimizer, decoder_optimizer: stepped every accumulation window.
        epochs: index of the current epoch, used only in the progress log
            (the plural name is kept for backward compatibility).
        encode_scheduler, decoder_scheduler: unused here; the caller steps them per epoch.
        device: device the batch tensors are moved to.

    Returns:
        The sample-weighted mean training loss over the epoch.
    """
    # CFG historically spelled this setting 'gradient_acuumulation_steps'; accept either spelling.
    accum_steps = getattr(CFG, 'gradient_accumulation_steps',
                          getattr(CFG, 'gradient_acuumulation_steps', 1))
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    encoder.train()
    decoder.train()
    start = end = time.time()
    for step, (images, labels, label_lengths) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        label_lengths = label_lengths.to(device)
        batch_size = images.size(0)

        features = encoder(images)
        predictions, caps_sorted, decode_lengths, _alphas, _sort_ind = decoder(features, labels, label_lengths)
        # The decoder predicts token t+1 from token t, so targets drop the leading <sos>.
        targets = caps_sorted[:, 1:]
        # Packing keeps only the non-padded time steps. .data flattens them into
        # (num_tokens, ...) tensors that the loss can consume.
        predictions = pack_padded_sequence(predictions, decode_lengths, batch_first=True).data
        targets = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        loss = criterion(predictions, targets)
        losses.update(loss.item(), batch_size)
        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()
        encoder_grad_norm = torch.nn.utils.clip_grad_norm_(encoder.parameters(), CFG.max_grad_norm)
        decoder_grad_norm = torch.nn.utils.clip_grad_norm_(decoder.parameters(), CFG.max_grad_norm)
        if (step + 1) % accum_steps == 0:
            encoder_optimizer.step()
            decoder_optimizer.step()
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader) - 1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Encoder Grad: {encoder_grad_norm:.4f}  '
                  'Decoder Grad: {decoder_grad_norm:.4f}  '
                  .format(
                      epochs + 1, step, len(train_loader), batch_time=batch_time,
                      data_time=data_time, loss=losses,
                      remain=timeSince(start, float(step + 1) / len(train_loader)),
                      encoder_grad_norm=encoder_grad_norm,
                      decoder_grad_norm=decoder_grad_norm,
                  ))
    return losses.avg


def valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device):
    """Greedy-decode the validation loader and return the predicted caption strings.

    Runs encoder and decoder in eval mode under torch.no_grad(), decodes up to
    CFG.max_len steps per batch and turns ids into text with
    tokenizer.predict_captions. ``criterion`` is accepted but not used.
    Returns a numpy array with one caption per sample, in loader order.
    """
    batch_time, data_time = AverageMeter(), AverageMeter()
    encoder.eval()
    decoder.eval()
    per_batch_captions = []
    start = end = time.time()
    for step, images in enumerate(valid_loader):
        data_time.update(time.time() - end)
        images = images.to(device)
        with torch.no_grad():
            features = encoder(images)
            predictions = decoder.predict(features, CFG.max_len, tokenizer)
        token_ids = predictions.detach().cpu().argmax(-1).numpy()
        per_batch_captions.append(tokenizer.predict_captions(token_ids))
        batch_time.update(time.time() - end)
        end = time.time()
        last_step = len(valid_loader) - 1
        if step % CFG.print_freq == 0 or step == last_step:
            print('EVAL: [{0}/{1}] '
                  'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                  'Elapsed {remain:s} '
                  .format(
                      step, len(valid_loader), batch_time=batch_time,
                      data_time=data_time,
                      remain=timeSince(start, float(step + 1) / len(valid_loader)),
                  ))
    return np.concatenate(per_batch_captions)



In [None]:
def init_logger():
    """Return this module's INFO-level logger, printing bare messages to stderr.

    Safe to call repeatedly (e.g. when the cell is re-run): a handler is only
    attached when the logger has none, so messages are not printed several times.
    """
    from logging import getLogger, INFO, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    if not logger.handlers:
        handler = StreamHandler()
        handler.setFormatter(Formatter("%(message)s"))
        logger.addHandler(handler)
    return logger

LOGGER = init_logger()

In [None]:
def get_scheduler(optimizer):
    """Build the learning-rate scheduler named by CFG.scheduler for ``optimizer``.

    Raises:
        ValueError: for an unrecognized CFG.scheduler (the original fell through
            and failed with UnboundLocalError).
    """
    if CFG.scheduler == 'ReduceLROnPlateau':
        return ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    if CFG.scheduler == 'CosineAnnealingLR':
        return CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    if CFG.scheduler == 'CosineAnnealingWarmRestarts':
        return CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
    raise ValueError(f"unsupported CFG.scheduler: {CFG.scheduler!r}")


def train_loop(folds, fold, output_dir='./'):
    """Train and validate one cross-validation fold, checkpointing the best epoch.

    In the original, get_scheduler was defined at column 0 in the middle of this
    function. That ended train_loop after the loaders were built, and left the
    model/epoch code as unreachable lines after get_scheduler's ``return``, so
    no training ran. The full body now lives in this function.

    Args:
        folds: DataFrame with a 'fold' column plus the 'file_path', 'InChI_text'
            and 'InChI' columns that the datasets and scoring read.
        fold: index of the fold used for validation.
        output_dir: directory for the best-score checkpoint. Defaults to the
            working directory; the original referenced an undefined OUTPUT_DIR.
    """
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds['InChI'].values

    train_dataset = TrainDataset(train_folds, tokenizer, transform=image_transform(data='train'))
    valid_dataset = TestDataset(valid_folds, transform=image_transform(data='valid'))

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=True,
                              collate_fn=bms_collate)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers,
                              pin_memory=True,
                              drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    encoder = Encoder(CFG.model_name, pretrained=True)
    encoder.to(device)
    encoder_optimizer = Adam(encoder.parameters(), lr=CFG.encoder_lr, weight_decay=CFG.weight_decay, amsgrad=False)
    encoder_scheduler = get_scheduler(encoder_optimizer)

    decoder = DecoderWithAttention(attention_dim=CFG.attention_dim,
                                   embed_dim=CFG.embed_dim,
                                   decoder_dim=CFG.decoder_dim,
                                   vocab_size=len(tokenizer),
                                   dropout=CFG.dropout,
                                   device=device)
    decoder.to(device)
    # CFG historically spelled this 'decorder_lr'; accept either spelling.
    decoder_lr = CFG.decoder_lr if hasattr(CFG, 'decoder_lr') else CFG.decorder_lr
    decoder_optimizer = Adam(decoder.parameters(), lr=decoder_lr, weight_decay=CFG.weight_decay, amsgrad=False)
    decoder_scheduler = get_scheduler(decoder_optimizer)

    # ====================================================
    # loop
    # ====================================================
    # Padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<pad>"])

    best_score = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(train_loader, encoder, decoder, criterion,
                            encoder_optimizer, decoder_optimizer, epoch,
                            encoder_scheduler, decoder_scheduler, device)

        # eval
        text_preds = valid_fn(valid_loader, encoder, decoder, tokenizer, criterion, device)
        text_preds = [f"InChI=1S/{text}" for text in text_preds]
        LOGGER.info(f"labels: {valid_labels[:5]}")
        LOGGER.info(f"preds: {text_preds[:5]}")

        # scoring (mean Levenshtein distance: lower is better)
        score = get_score(valid_labels, text_preds)

        # Plateau schedulers need the metric; the cosine schedulers step per epoch.
        for scheduler in (encoder_scheduler, decoder_scheduler):
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(score)
            elif isinstance(scheduler, (CosineAnnealingLR, CosineAnnealingWarmRestarts)):
                scheduler.step()

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}')

        if score < best_score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'encoder': encoder.state_dict(),
                        'encoder_optimizer': encoder_optimizer.state_dict(),
                        'encoder_scheduler': encoder_scheduler.state_dict(),
                        'decoder': decoder.state_dict(),
                        'decoder_optimizer': decoder_optimizer.state_dict(),
                        'decoder_scheduler': decoder_scheduler.state_dict(),
                        'text_preds': text_preds,
                        },
                       os.path.join(output_dir, f'{CFG.model_name}_fold{fold}_best.pth'))

In [None]:
def main():
    """Train every fold in range(CFG.n_fold) that is listed in CFG.trn_fold.

    Does nothing when CFG.train is False. Relies on the module-level ``folds``
    frame prepared earlier.
    """
    if not CFG.train:
        return
    for fold in range(CFG.n_fold):
        if fold not in CFG.trn_fold:
            continue
        train_loop(folds, fold)

In [None]:
# A notebook kernel runs with __name__ == "__main__", so executing this cell starts training.
if "__main__" == __name__:
    main()