In [6]:
# For local runs: put the repository root (one level above this notebook)
# at the front of the import path so project modules can be imported.
import sys

repo_root = "../"
if sys.path.count(repo_root) == 0:
    sys.path.insert(0, repo_root)

In [7]:
%load_ext autoreload
%autoreload 2
from datasets import load_dataset
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from language import DynamicLanguage
from transition import RNNLanguageModel

def train_rnn_with_dynamic_language(lang: DynamicLanguage, dataset_path: str, test_size=0.1, batch_size=64, lr=1e-3, num_epochs=10, rnn_type="GRU", embed_size=None, hidden_size=256, num_layers=2, dropout=0.3):
    """Train an ``RNNLanguageModel`` for next-token prediction on a text dataset.

    The file at ``dataset_path`` is loaded with the ``datasets`` "text" builder,
    split into train/test, used to build the vocabulary of ``lang``, tokenized
    with ``lang.sentence2ids`` and then fitted with Adam and a cross-entropy
    loss that ignores padding positions.

    Args:
        lang: Language object; must provide ``build_vocab``, ``sentence2ids``,
            ``pad_id`` and ``vocab``. Its vocabulary is built here (side effect).
        dataset_path: Path of the text file handed to ``load_dataset``.
        test_size: Forwarded to ``train_test_split`` for the held-out split.
        batch_size: Batch size of both data loaders.
        lr: Learning rate of the Adam optimizer.
        num_epochs: Number of passes over the training split.
        rnn_type, embed_size, hidden_size, num_layers, dropout: Forwarded
            unchanged to the ``RNNLanguageModel`` constructor.

    Returns:
        The trained model, left on the device used for training.
    """
    cuda_available = torch.cuda.is_available()
    device = "cuda:0" if cuda_available else "cpu"
    print("Is CUDA available: " + str(cuda_available))

    # make dataset and build vocabs
    ds = load_dataset("text", data_files={"train": dataset_path})
    ds = ds["train"].train_test_split(test_size=test_size)
    # The vocabulary is built from the split dataset (train and test together).
    lang.build_vocab(ds)

    ds_tokenized = ds.map(lambda x: {"ids": lang.sentence2ids(x["text"])}, remove_columns=["text"])
    train_dataset = ds_tokenized["train"]
    test_dataset = ds_tokenized["test"]
    pad_id = lang.pad_id()

    def collate(batch):
        """Right-pad the id sequences with pad_id and return (input, target)."""
        seqs = [torch.tensor(ex["ids"]) for ex in batch]
        maxlen = max(len(s) for s in seqs)
        padded = torch.full((len(seqs), maxlen), pad_id, dtype=torch.long)
        for i, s in enumerate(seqs):
            padded[i, :len(s)] = s
        # Target is the input shifted left by one position (next-token prediction).
        return padded[:, :-1], padded[:, 1:]  # input, target

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    # Evaluation data is not shuffled: the reported value is a mean of per-batch
    # means, so a fixed batch composition keeps it comparable across epochs.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    model = RNNLanguageModel(pad_id=pad_id, vocab_size=len(lang.vocab()), embed_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type=rnn_type, dropout=dropout).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    def batch_loss(x, y):
        """Cross-entropy of the model on one batch; padding targets are ignored."""
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=pad_id)

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch}"):
            loss = batch_loss(x, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            total_loss += loss.item()

        # max(..., 1) avoids a ZeroDivisionError if a loader yields no batches.
        avg_train_loss = total_loss / max(len(train_loader), 1)

        # validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y in test_loader:
                val_loss += batch_loss(x, y).item()
        avg_val_loss = val_loss / max(len(test_loader), 1)

        print(f"[{epoch}] train_loss: {avg_train_loss:.4f}  val_loss: {avg_val_loss:.4f}")

    return model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
import yaml
import os
from utils import class_from_package

# Load the training configuration (path is relative to the repository root).
config_path = "config/train_rnn.yaml"
with open(os.path.join(repo_root, config_path), encoding="utf-8") as f:
    conf = yaml.safe_load(f)

# Remove every key that is not a keyword argument of
# train_rnn_with_dynamic_language, so the remainder can be forwarded as **conf.
output_dir = os.path.join(repo_root, conf.pop("output_dir"))
lang_class = class_from_package("language", conf.pop("lang_class"))
# "lang_conf" must be popped (not just read): leaving it in conf would pass an
# unexpected keyword to the training function. A missing or null entry means
# "no constructor arguments".
lang = lang_class(**(conf.pop("lang_conf", None) or {}))
dataset_path = os.path.join(repo_root, conf.pop("dataset_path"))

model = train_rnn_with_dynamic_language(lang=lang, dataset_path=dataset_path, **conf)

Is CUDA available: True


Generating train split: 153253 examples [00:00, 3260822.68 examples/s]
Map: 100%|██████████| 137927/137927 [00:06<00:00, 21980.71 examples/s]
Map: 100%|██████████| 15326/15326 [00:00<00:00, 21856.93 examples/s]
Epoch 1: 100%|██████████| 270/270 [00:08<00:00, 32.26it/s]


[1] train_loss: 1.3900  val_loss: 0.9161


Epoch 2: 100%|██████████| 270/270 [00:08<00:00, 32.44it/s]


[2] train_loss: 0.8836  val_loss: 0.7913


Epoch 3: 100%|██████████| 270/270 [00:08<00:00, 33.31it/s]


[3] train_loss: 0.7968  val_loss: 0.7432


Epoch 4: 100%|██████████| 270/270 [00:08<00:00, 33.52it/s]


[4] train_loss: 0.7575  val_loss: 0.7168


Epoch 5: 100%|██████████| 270/270 [00:08<00:00, 33.03it/s]


[5] train_loss: 0.7336  val_loss: 0.7001


Epoch 6: 100%|██████████| 270/270 [00:08<00:00, 33.35it/s]


[6] train_loss: 0.7175  val_loss: 0.6891


Epoch 7: 100%|██████████| 270/270 [00:08<00:00, 33.18it/s]


[7] train_loss: 0.7052  val_loss: 0.6805


Epoch 8: 100%|██████████| 270/270 [00:08<00:00, 32.99it/s]


[8] train_loss: 0.6954  val_loss: 0.6701


Epoch 9: 100%|██████████| 270/270 [00:08<00:00, 33.27it/s]


[9] train_loss: 0.6874  val_loss: 0.6651


Epoch 10: 100%|██████████| 270/270 [00:08<00:00, 33.24it/s]


[10] train_loss: 0.6807  val_loss: 0.6592


Epoch 11: 100%|██████████| 270/270 [00:08<00:00, 32.79it/s]


[11] train_loss: 0.6749  val_loss: 0.6559


Epoch 12: 100%|██████████| 270/270 [00:08<00:00, 33.11it/s]


[12] train_loss: 0.6700  val_loss: 0.6504


Epoch 13: 100%|██████████| 270/270 [00:08<00:00, 33.36it/s]


[13] train_loss: 0.6655  val_loss: 0.6479


Epoch 14: 100%|██████████| 270/270 [00:08<00:00, 32.95it/s]


[14] train_loss: 0.6615  val_loss: 0.6448


Epoch 15: 100%|██████████| 270/270 [00:08<00:00, 33.38it/s]


[15] train_loss: 0.6578  val_loss: 0.6406


Epoch 16: 100%|██████████| 270/270 [00:08<00:00, 33.51it/s]


[16] train_loss: 0.6544  val_loss: 0.6389


Epoch 17: 100%|██████████| 270/270 [00:08<00:00, 33.28it/s]


[17] train_loss: 0.6516  val_loss: 0.6363


Epoch 18: 100%|██████████| 270/270 [00:08<00:00, 33.07it/s]


[18] train_loss: 0.6487  val_loss: 0.6339


Epoch 19: 100%|██████████| 270/270 [00:08<00:00, 33.54it/s]


[19] train_loss: 0.6463  val_loss: 0.6316


Epoch 20: 100%|██████████| 270/270 [00:08<00:00, 33.48it/s]


[20] train_loss: 0.6437  val_loss: 0.6289


Epoch 21: 100%|██████████| 270/270 [00:08<00:00, 33.20it/s]


[21] train_loss: 0.6415  val_loss: 0.6281


Epoch 22: 100%|██████████| 270/270 [00:08<00:00, 33.44it/s]


[22] train_loss: 0.6397  val_loss: 0.6262


Epoch 23: 100%|██████████| 270/270 [00:08<00:00, 33.21it/s]


[23] train_loss: 0.6376  val_loss: 0.6256


Epoch 24: 100%|██████████| 270/270 [00:08<00:00, 32.63it/s]


[24] train_loss: 0.6358  val_loss: 0.6245


Epoch 25: 100%|██████████| 270/270 [00:08<00:00, 31.86it/s]


[25] train_loss: 0.6340  val_loss: 0.6219


Epoch 26: 100%|██████████| 270/270 [00:08<00:00, 32.93it/s]


[26] train_loss: 0.6326  val_loss: 0.6208


Epoch 27: 100%|██████████| 270/270 [00:08<00:00, 32.81it/s]


[27] train_loss: 0.6310  val_loss: 0.6200


Epoch 28: 100%|██████████| 270/270 [00:08<00:00, 33.20it/s]


[28] train_loss: 0.6295  val_loss: 0.6194


Epoch 29: 100%|██████████| 270/270 [00:08<00:00, 33.15it/s]


[29] train_loss: 0.6280  val_loss: 0.6187


Epoch 30: 100%|██████████| 270/270 [00:08<00:00, 33.23it/s]


[30] train_loss: 0.6267  val_loss: 0.6168


Epoch 31: 100%|██████████| 270/270 [00:08<00:00, 33.68it/s]


[31] train_loss: 0.6254  val_loss: 0.6160


Epoch 32: 100%|██████████| 270/270 [00:08<00:00, 33.55it/s]


[32] train_loss: 0.6243  val_loss: 0.6162


Epoch 33: 100%|██████████| 270/270 [00:08<00:00, 32.75it/s]


[33] train_loss: 0.6233  val_loss: 0.6141


Epoch 34: 100%|██████████| 270/270 [00:08<00:00, 32.50it/s]


[34] train_loss: 0.6222  val_loss: 0.6141


Epoch 35: 100%|██████████| 270/270 [00:08<00:00, 33.59it/s]


[35] train_loss: 0.6213  val_loss: 0.6127


Epoch 36: 100%|██████████| 270/270 [00:08<00:00, 33.25it/s]


[36] train_loss: 0.6198  val_loss: 0.6125


Epoch 37: 100%|██████████| 270/270 [00:08<00:00, 33.21it/s]


[37] train_loss: 0.6191  val_loss: 0.6114


Epoch 38: 100%|██████████| 270/270 [00:07<00:00, 33.81it/s]


[38] train_loss: 0.6184  val_loss: 0.6113


Epoch 39: 100%|██████████| 270/270 [00:08<00:00, 33.38it/s]


[39] train_loss: 0.6173  val_loss: 0.6102


Epoch 40: 100%|██████████| 270/270 [00:08<00:00, 32.60it/s]


[40] train_loss: 0.6165  val_loss: 0.6092


Epoch 41: 100%|██████████| 270/270 [00:08<00:00, 32.21it/s]


[41] train_loss: 0.6160  val_loss: 0.6089


Epoch 42: 100%|██████████| 270/270 [00:08<00:00, 31.92it/s]


[42] train_loss: 0.6150  val_loss: 0.6085


Epoch 43: 100%|██████████| 270/270 [00:08<00:00, 32.48it/s]


[43] train_loss: 0.6146  val_loss: 0.6083


Epoch 44: 100%|██████████| 270/270 [00:08<00:00, 32.67it/s]


[44] train_loss: 0.6134  val_loss: 0.6088


Epoch 45: 100%|██████████| 270/270 [00:08<00:00, 32.96it/s]


[45] train_loss: 0.6128  val_loss: 0.6079


Epoch 46: 100%|██████████| 270/270 [00:08<00:00, 32.57it/s]


[46] train_loss: 0.6121  val_loss: 0.6073


Epoch 47: 100%|██████████| 270/270 [00:08<00:00, 32.53it/s]


[47] train_loss: 0.6113  val_loss: 0.6077


Epoch 48: 100%|██████████| 270/270 [00:08<00:00, 33.01it/s]


[48] train_loss: 0.6107  val_loss: 0.6058


Epoch 49: 100%|██████████| 270/270 [00:08<00:00, 32.25it/s]


[49] train_loss: 0.6098  val_loss: 0.6057


Epoch 50: 100%|██████████| 270/270 [00:08<00:00, 32.85it/s]


[50] train_loss: 0.6098  val_loss: 0.6051


Epoch 51: 100%|██████████| 270/270 [00:08<00:00, 32.75it/s]


[51] train_loss: 0.6090  val_loss: 0.6053


Epoch 52: 100%|██████████| 270/270 [00:08<00:00, 32.66it/s]


[52] train_loss: 0.6081  val_loss: 0.6055


Epoch 53: 100%|██████████| 270/270 [00:08<00:00, 32.75it/s]


[53] train_loss: 0.6079  val_loss: 0.6048


Epoch 54: 100%|██████████| 270/270 [00:08<00:00, 32.60it/s]


[54] train_loss: 0.6079  val_loss: 0.6048


Epoch 55: 100%|██████████| 270/270 [00:08<00:00, 32.54it/s]


[55] train_loss: 0.6072  val_loss: 0.6041


Epoch 56: 100%|██████████| 270/270 [00:08<00:00, 32.40it/s]


[56] train_loss: 0.6063  val_loss: 0.6034


Epoch 57: 100%|██████████| 270/270 [00:08<00:00, 32.94it/s]


[57] train_loss: 0.6057  val_loss: 0.6026


Epoch 58: 100%|██████████| 270/270 [00:08<00:00, 32.81it/s]


[58] train_loss: 0.6055  val_loss: 0.6032


Epoch 59: 100%|██████████| 270/270 [00:08<00:00, 31.56it/s]


[59] train_loss: 0.6048  val_loss: 0.6031


Epoch 60: 100%|██████████| 270/270 [00:08<00:00, 32.66it/s]


[60] train_loss: 0.6044  val_loss: 0.6023


Epoch 61: 100%|██████████| 270/270 [00:08<00:00, 32.63it/s]


[61] train_loss: 0.6038  val_loss: 0.6023


Epoch 62: 100%|██████████| 270/270 [00:08<00:00, 32.27it/s]


[62] train_loss: 0.6034  val_loss: 0.6013


Epoch 63: 100%|██████████| 270/270 [00:08<00:00, 32.48it/s]


[63] train_loss: 0.6030  val_loss: 0.6022


Epoch 64: 100%|██████████| 270/270 [00:08<00:00, 32.08it/s]


[64] train_loss: 0.6027  val_loss: 0.6021


Epoch 65: 100%|██████████| 270/270 [00:08<00:00, 31.84it/s]


[65] train_loss: 0.6022  val_loss: 0.6024


Epoch 66: 100%|██████████| 270/270 [00:08<00:00, 32.66it/s]


[66] train_loss: 0.6018  val_loss: 0.6023


Epoch 67: 100%|██████████| 270/270 [00:08<00:00, 33.42it/s]


[67] train_loss: 0.6016  val_loss: 0.6011


Epoch 68: 100%|██████████| 270/270 [00:08<00:00, 33.30it/s]


[68] train_loss: 0.6011  val_loss: 0.6005


Epoch 69: 100%|██████████| 270/270 [00:08<00:00, 33.26it/s]


[69] train_loss: 0.6010  val_loss: 0.6006


Epoch 70: 100%|██████████| 270/270 [00:08<00:00, 33.26it/s]


[70] train_loss: 0.6004  val_loss: 0.6004


Epoch 71: 100%|██████████| 270/270 [00:08<00:00, 33.52it/s]


[71] train_loss: 0.6004  val_loss: 0.6003


Epoch 72: 100%|██████████| 270/270 [00:08<00:00, 33.23it/s]


[72] train_loss: 0.5995  val_loss: 0.5998


Epoch 73: 100%|██████████| 270/270 [00:08<00:00, 33.51it/s]


[73] train_loss: 0.5997  val_loss: 0.6001


Epoch 74: 100%|██████████| 270/270 [00:08<00:00, 33.37it/s]


[74] train_loss: 0.5991  val_loss: 0.5997


Epoch 75: 100%|██████████| 270/270 [00:08<00:00, 32.76it/s]


[75] train_loss: 0.5987  val_loss: 0.5998


Epoch 76: 100%|██████████| 270/270 [00:08<00:00, 32.89it/s]


[76] train_loss: 0.5982  val_loss: 0.5990


Epoch 77: 100%|██████████| 270/270 [00:08<00:00, 32.74it/s]


[77] train_loss: 0.5982  val_loss: 0.5995


Epoch 78: 100%|██████████| 270/270 [00:08<00:00, 31.98it/s]


[78] train_loss: 0.5980  val_loss: 0.5994


Epoch 79: 100%|██████████| 270/270 [00:08<00:00, 32.82it/s]


[79] train_loss: 0.5975  val_loss: 0.5986


Epoch 80: 100%|██████████| 270/270 [00:08<00:00, 33.00it/s]


[80] train_loss: 0.5974  val_loss: 0.5987


Epoch 81: 100%|██████████| 270/270 [00:08<00:00, 32.56it/s]


[81] train_loss: 0.5971  val_loss: 0.5995


Epoch 82: 100%|██████████| 270/270 [00:08<00:00, 33.06it/s]


[82] train_loss: 0.5967  val_loss: 0.5984


Epoch 83: 100%|██████████| 270/270 [00:08<00:00, 33.00it/s]


[83] train_loss: 0.5965  val_loss: 0.5986


Epoch 84: 100%|██████████| 270/270 [00:08<00:00, 32.44it/s]


[84] train_loss: 0.5974  val_loss: 0.5983


Epoch 85: 100%|██████████| 270/270 [00:08<00:00, 32.83it/s]


[85] train_loss: 0.5964  val_loss: 0.5981


Epoch 86: 100%|██████████| 270/270 [00:08<00:00, 33.02it/s]


[86] train_loss: 0.5956  val_loss: 0.5999


Epoch 87: 100%|██████████| 270/270 [00:08<00:00, 33.32it/s]


[87] train_loss: 0.5954  val_loss: 0.5985


Epoch 88: 100%|██████████| 270/270 [00:08<00:00, 32.73it/s]


[88] train_loss: 0.5953  val_loss: 0.5973


Epoch 89: 100%|██████████| 270/270 [00:08<00:00, 32.79it/s]


[89] train_loss: 0.5946  val_loss: 0.5978


Epoch 90: 100%|██████████| 270/270 [00:08<00:00, 33.05it/s]


[90] train_loss: 0.5951  val_loss: 0.5976


Epoch 91: 100%|██████████| 270/270 [00:08<00:00, 32.80it/s]


[91] train_loss: 0.5945  val_loss: 0.5972


Epoch 92: 100%|██████████| 270/270 [00:08<00:00, 33.29it/s]


[92] train_loss: 0.5943  val_loss: 0.5983


Epoch 93: 100%|██████████| 270/270 [00:08<00:00, 33.19it/s]


[93] train_loss: 0.5939  val_loss: 0.5979


Epoch 94: 100%|██████████| 270/270 [00:08<00:00, 32.64it/s]


[94] train_loss: 0.5940  val_loss: 0.5968


Epoch 95: 100%|██████████| 270/270 [00:08<00:00, 32.99it/s]


[95] train_loss: 0.5936  val_loss: 0.5972


Epoch 96: 100%|██████████| 270/270 [00:08<00:00, 32.88it/s]


[96] train_loss: 0.5934  val_loss: 0.5973


Epoch 97: 100%|██████████| 270/270 [00:08<00:00, 32.55it/s]


[97] train_loss: 0.5930  val_loss: 0.5968


Epoch 98: 100%|██████████| 270/270 [00:08<00:00, 32.97it/s]


[98] train_loss: 0.5928  val_loss: 0.5965


Epoch 99: 100%|██████████| 270/270 [00:08<00:00, 32.65it/s]


[99] train_loss: 0.5924  val_loss: 0.5959


Epoch 100: 100%|██████████| 270/270 [00:08<00:00, 31.92it/s]


[100] train_loss: 0.5930  val_loss: 0.5970


In [9]:
# Save the trained model under output_dir (the path resolved from the YAML config).
# NOTE(review): whether save() creates output_dir itself is not visible here — confirm.
model.save(output_dir)

In [10]:
# save yaml and lang
import shutil

# Make sure the destination exists so this cell does not depend on an earlier
# cell having created it.
os.makedirs(output_dir, exist_ok=True)

# Keep a verbatim copy of the training config next to the saved artifacts.
src = os.path.join(repo_root, config_path)
dst = os.path.join(output_dir, "setting.yaml")
shutil.copy(src, dst)

# The class is matched by name, so it does not have to be imported here.
if lang.__class__.__name__ == "HELM":
    lib_files = [
        "chembl_35_monomer_library.xml",
        "chembl_35_monomer_library_diff.xml",
        "HELMCoreLibrary.json",
        "monomerLib2.0.json"
    ]
    # Build the paths with os.path.join (as everywhere else in this notebook)
    # instead of relying on repo_root ending with a path separator.
    lib_dir = os.path.join(repo_root, "data/helm/library")
    lang.load_monomer_library(*[os.path.join(lib_dir, name) for name in lib_files], culling=True)
# The file name is "smiles.lang" regardless of the language class in use.
lang.save(os.path.join(output_dir, "smiles.lang"))