In [3]:
from encoder_model import EncoderModel
import torch
from dataset import SMILESDataset
from torch.utils.data import DataLoader

In [4]:
# Device handle reused by later cells (checkpoint map_location and the
# dataset's `device` argument). This notebook runs on the CPU.
device_name = "cpu"
device = torch.device(device_name)

In [5]:
# Keyword arguments for EncoderModel (unpacked with ** when the model is built).
# Values that occur more than once are named a single time so the copies
# cannot drift apart.
_D_MODEL = 512     # embedding width, also the input width of the output head
_SOURCE_SIZE = 41  # size of the embedding table, also the output head's width
_PAD_TOKEN = 38    # index used as the embedding's padding_idx

model_args = dict(
    d_model=_D_MODEL,
    d_out=None,
    dim_feedforward=512,
    is_causal=True,
    nhead=32,
    num_layers=2,
    output_head="LogitOut",
    output_head_opts=dict(d_model=_D_MODEL, d_out=_SOURCE_SIZE),
    permute_output=True,
    pooler="IdentityPool",
    pooler_opts=dict(),
    source_size=_SOURCE_SIZE,
    src_embed="nn.embed",
    src_forward_function="src_fwd_fxn_basic",
    src_pad_token=_PAD_TOKEN,
)

In [6]:
# Build the encoder from the hyperparameters collected in model_args.
model = EncoderModel(**model_args)
# Switch the model to evaluation mode. Because this call is the cell's last
# expression, its return value is what the notebook displays as the cell
# output (the module summary printed below).
model.eval()

EncoderModel(
  (network): EncoderNetwork(
    (src_embed): Embedding(41, 512, padding_idx=38)
    (pooler): IdentityPool()
    (output_head): LogitOut(
      (network): Linear(in_features=512, out_features=41, bias=True)
    )
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2

In [7]:
# Read the saved weights onto `device`; the checkpoint file keeps them under
# the "model_state_dict" key.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only open checkpoints that come from a trusted source.
ckpt = torch.load("RESTART_checkpoint.pt",
                  map_location=device)["model_state_dict"]
curr_state_dict = model.state_dict()

# Keep only the entries the current model can actually accept: the parameter
# name must exist in the model AND the tensor shapes must agree. The
# membership test comes first so an unknown checkpoint key is skipped instead
# of raising KeyError on the lookup.
pretrained_dict = {k: v
                   for k, v in ckpt.items()
                   if k in curr_state_dict and curr_state_dict[k].shape == v.shape}

# strict=False below would hide any entry dropped by the filter above, so
# report such entries explicitly. Nothing is printed when every checkpoint
# entry is usable.
skipped_keys = [k for k in ckpt if k not in pretrained_dict]
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} checkpoint entries "
          f"(missing from the model or shape mismatch): {skipped_keys}")

# Last expression of the cell: the returned key-matching report is displayed.
model.load_state_dict(pretrained_dict, strict=False)

<All keys matched successfully>

# Create Dataset

In [8]:
# Open the SMILES dataset from the HDF5 file, handing it the device chosen
# earlier in the notebook.
dataset = SMILESDataset(
    data_file="generative_smiles_dset.h5",
    device=device,
)

# Iterate in file order (no shuffling), 16 samples per batch.
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
)

# Compute Loss (Sanity Check)

In [9]:
# Cross-entropy that leaves padded target positions out of the loss. The pad
# index is taken from model_args (the same value the model was built with)
# rather than repeated as a literal, so the two cannot fall out of sync.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=model_args["src_pad_token"])

In [10]:
# Sanity check on a single batch: take the first batch the loader yields,
# compute its loss, print it, and stop iterating.
for x, y in dataloader:
    loss_outputs = model.get_loss(x, y, loss_fn)
    # get_loss hands back three values; only the first (the loss) is used.
    loss, _, _ = loss_outputs
    print(loss)
    break

tensor(2.7282, grad_fn=<NllLoss2DBackward0>)




# Save Embeddings

In [11]:
# Export embeddings for the first batch only.
# Autograd is enabled by default in this kernel, and nothing here needs
# gradients, so the call is wrapped in torch.no_grad() to avoid building a
# computation graph during the export.
with torch.no_grad():
    for x, y in dataloader:
        # Only the inputs are needed; the targets `y` are unpacked but unused.
        model.save_embeddings(x)
        break

In [13]:
# Show the dataset object itself, then its first sample, one per line.
print(dataset, dataset[0], sep="\n")

<dataset.SMILESDataset object at 0x146997bb0>
((tensor([39., 12., 34.,  4., 34., 34., 34., 35., 34.,  4., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38.]), b'Brc1cccnc1'), (tensor([12., 34.,  4., 34., 34., 34., 35., 34.,  4., 40., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38., 38.,
        38., 38., 38., 38., 38., 38., 38.]),))
