In [1]:
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

PAD_WIDTH = 100

def pad_sequence_feats(data, width=PAD_WIDTH):
    """Truncate and right-pad every feature in a record to a fixed length.

    Usable as a ``transform`` for ``TFRecordDataset``. Each value is converted
    to an array, cut to at most ``width`` entries along its first axis, and
    then zero-padded along that axis up to exactly ``width``. Sequences longer
    than ``width`` are truncated instead of raising: the first loaded batch
    has lengths of 119 and 121, both above the default of 100.

    Args:
        data: Mapping of feature name -> array-like with at least one axis.
        width: Target length along axis 0 (defaults to ``PAD_WIDTH``).

    Returns:
        A new dict with the same keys. Each value is an ``np.ndarray`` whose
        first dimension equals ``width``. The input mapping is not modified.
    """
    padded = {}
    for key, value in data.items():
        arr = np.asarray(value)[:width]
        # Pad only the leading (sequence) axis; trailing axes stay as they are.
        pad_spec = [(0, width - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1)
        padded[key] = np.pad(arr, pad_spec, mode="constant")
    return padded

def collate_fn(batch, padding_value=0.0):
    """Collate a list of record dicts into a dict of padded, batch-first tensors.

    Each feature value becomes a tensor of the default floating dtype. This is
    the same dtype the legacy ``torch.Tensor(...)`` constructor produced.
    Variable-length sequences are then right-padded with ``padding_value`` to
    the longest one in the batch.

    Args:
        batch: Non-empty list of dicts. The keys of ``batch[0]`` decide which
            features are collated. Values are presumably 1-D sequences
            (TODO confirm for any multi-dimensional feature).
        padding_value: Fill value for padded positions (default 0.0).

    Returns:
        Dict mapping each key to a tensor of shape ``(len(batch), max_len, ...)``.
    """
    from torch.nn.utils import rnn

    dtype = torch.get_default_dtype()
    # torch.as_tensor avoids the legacy torch.Tensor(n) pitfall: there, a bare
    # int is read as a *size* and produces uninitialized memory.
    return {
        key: rnn.pad_sequence(
            [torch.as_tensor(sample[key], dtype=dtype) for sample in batch],
            batch_first=True,
            padding_value=padding_value,
        )
        for key in batch[0]
    }


# Feature schema read from each TFRecord example. Key order is preserved.
description = dict(smiles="byte", token="float")

# Single TFRecord file, read without an index file.
index_path = None
tfrecord_path = "data/molecule_net/molecule_test.tfrecord"

# No per-record transform is applied. collate_fn pads each batch to its own
# longest sequence instead (the first batch has lengths 119 / 121).
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=32,
)


In [3]:
# Pull one batch and show the padded shape of each feature, one per line.
data = next(iter(loader))
print(*(data[name].shape for name in ("smiles", "token")), sep="\n")

torch.Size([32, 119])
torch.Size([32, 121])


In [51]:
from torch import nn

# Encoder-decoder Transformer. With batch_first=True, inputs and outputs are
# laid out as (batch, seq, d_model).
model = nn.Transformer(
    d_model=128,
    nhead=8,
    num_encoder_layers=8,
    num_decoder_layers=8,
    dim_feedforward=512,
    dropout=0.1,
    activation='gelu',
    batch_first=True,
)


Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inpla

In [52]:
# Smoke test: feed random (batch=32, seq=10, d_model) inputs through the model
# to confirm the forward pass runs and returns the target's shape.
# A dedicated generator makes the inputs reproducible without reseeding the
# global RNG. no_grad skips building an autograd graph that is never used.
# NOTE(review): model is still in train mode, so dropout makes `out`'s values
# (not its shape) vary between runs. No tgt_mask is passed, so the decoder is
# not causally masked here.
input_gen = torch.Generator().manual_seed(0)
src = torch.rand((32, 10, model.d_model), generator=input_gen)
tgt = torch.rand((32, 10, model.d_model), generator=input_gen)
with torch.no_grad():
    out = model(src, tgt)
out.shape

torch.Size([32, 10, 128])