In [52]:
import torch
import torch.nn as nn
from tfrecord.torch.dataset import TFRecordDataset
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

import numpy as np

PAD_LEN = 100
MASKING_RATE = 0.15
def pad_sequence_feats(data):
    if len(data['token']) >= PAD_LEN:
        data['token'][:PAD_LEN]
        
    data['masked_token'] = data['token'].copy()
    token_length = len(data['masked_token']) - 2
    index = np.random.choice(np.arange(1, token_length), int(token_length * MASKING_RATE), replace=False)
    data['masked_token'][index] = 3
    
    for k, v in data.items():       
        data[k] = np.pad(v, (0, PAD_LEN - len(v)), 'constant')
    
    return data

def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    
    batch_ = {k: [torch.Tensor(d[k]) for d in batch] for k in batch[0]}
    return {k: rnn.pad_sequence(f, True) for (k, f) in batch_.items()}


train_tfrecord_path = "../data/molecule_net/molecule_train.tfrecord"
valid_tfrecord_path = "../data/molecule_net/molecule_valid.tfrecord"
test_tfrecord_path = "../data/molecule_net/molecule_test.tfrecord"

index_path = None
description = {"smiles": "byte",
               "token": "float"}

train_dataset = TFRecordDataset(train_tfrecord_path, index_path, description, transform=pad_sequence_feats)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=512)

valid_dataset = TFRecordDataset(valid_tfrecord_path, index_path, description, transform=pad_sequence_feats)
valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=512)

test_dataset = TFRecordDataset(test_tfrecord_path, index_path, description, transform=pad_sequence_feats)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=512)

In [51]:
class Transformer(pl.LightningModule):
    def __init__(self, learning_rate):
        super(Transformer, self).__init__()
        self.model = nn.Transformer(nhead=8,
                                    num_encoder_layers=8, 
                                    num_decoder_layers=8,
                                    d_model=128,
                                    dim_feedforward=512,
                                    dropout=0.1,
                                    activation='gelu',
                                    batch_first=True)
        
        self.learning_rate = learning_rate
        
        self.train_accuracy = torchmetrics.Accuracy()
        self.valid_accuracy = torchmetrics.Accuracy()
        self.test_accuracy = torchmetrics.Accuracy()
    
    
    def training_step(self, batch, batch_idx):
        x = batch['masked_token']
        y = batch['token']
        
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        
        self.log("train_loss", loss)
        self.log("train_accuracy", self.train_accuracy(y_hat, y), on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss
    
        
    def validation_step(self, batch, batch_idx):
        x = batch['masked_token']
        y = batch['token']
        
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        
        self.log("valid_loss", loss)
        self.log("valid_accuracy", self.valid_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)

    
    def test_step(self, batch, batch_idx):
        x = batch['masked_token']
        y = batch['token']
        
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        
        self.log("test_loss", loss)
        self.log("test_accuracy", self.test_accuracy(y_hat, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)
          
        
    def configure_optimizers(self):
        optimizer = BertAdam(params, lr=self.learning_rate, warmup=-1, t_total=-1, schedule='WarmupLinearSchedule',
                             b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0)
        
        return {"optimizer": optimizer}

    
def define_callbacks(patience, ckpt_path):
    early_stopping = EarlyStopping('valid_loss', patience=patience)
    check_points = ModelCheckpoint(monitor="valid_loss", mode="min", dirpath=ckpt_path, save_top_k=1)
    
    return [early_stopping, check_points]


LEARNING_RATE = 1e-4

transformer = Transformer(LEARNING_RATE)
callbacks = define_callbacks(10, "./weights")


Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inpla

In [None]:
N_EPOCHS = 50

trainer = pl.Trainer(gpus=1, max_epochs=N_EPOCHS, enable_progress_bar=True, callbacks=callbacks)
trainer.fit(classifier, train_dataloader, valid_dataloader)