In [1]:
# Local-run setup: put the repository root at the front of the import search
# path (only once) so modules located there can be imported from this notebook.
import sys

repo_root = "../"
if sys.path.count(repo_root) == 0:
    sys.path.insert(0, repo_root)

In [None]:
%load_ext autoreload
%autoreload 2
from datasets import load_dataset
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from language import DynamicLanguage
from transition import RNNLanguageModel

def _make_collate_fn(pad_id):
    """Build a DataLoader collate function that right-pads id sequences with ``pad_id``."""
    def collate(batch):
        # Explicit dtype so that an empty id list does not produce a float tensor.
        seqs = [torch.tensor(ex["ids"], dtype=torch.long) for ex in batch]
        maxlen = max(len(s) for s in seqs)
        padded = torch.full((len(seqs), maxlen), pad_id, dtype=torch.long)
        for i, s in enumerate(seqs):
            padded[i, :len(s)] = s
        # Next-token prediction: inputs drop the last position, targets drop the first.
        return padded[:, :-1], padded[:, 1:]
    return collate


def _run_epoch(model, loader, device, pad_id, optim=None):
    """Run one pass over ``loader`` and return the mean loss per non-pad target token.

    When ``optim`` is given, a gradient step is taken for every batch; otherwise
    the pass is evaluation-only. The caller is responsible for ``model.train()`` /
    ``model.eval()`` and for wrapping evaluation in ``torch.no_grad()``.
    Returns NaN when the loader yields no non-pad target tokens at all.
    """
    loss_sum_total = 0.0
    token_total = 0
    for x, y in loader:
        # Counted on the CPU batch, before the transfer to ``device``.
        n_tokens = int((y != pad_id).sum())
        if n_tokens == 0:
            # Nothing to predict in this batch (e.g. every sequence has length <= 1);
            # a mean over zero targets would be NaN and would poison the update.
            continue
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)
        loss_sum = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
            ignore_index=pad_id,
            reduction="sum",
        )
        if optim is not None:
            # Same quantity as the default mean reduction with ignore_index.
            loss = loss_sum / n_tokens
            optim.zero_grad()
            loss.backward()
            optim.step()
        loss_sum_total += loss_sum.item()
        token_total += n_tokens
    return loss_sum_total / token_total if token_total else float("nan")


def train_rnn_with_dynamic_language(lang: DynamicLanguage, dataset_path: str, test_size=0.1, batch_size=64, lr=1e-3, num_epochs=10, rnn_type="GRU", embed_size=None, hidden_size=256, num_layers=2, dropout=0.3, seed=None):
    """Train an ``RNNLanguageModel`` for next-token prediction on a text file.

    The file at ``dataset_path`` is loaded with the ``"text"`` dataset loader,
    split into train/test parts, used to build the vocabulary of ``lang`` (both
    splits are passed to ``lang.build_vocab``), tokenized with
    ``lang.sentence2ids`` and fed to the model in padded batches.

    Args:
        lang: Language object providing ``build_vocab``, ``sentence2ids``,
            ``pad_id`` and ``vocab``.
        dataset_path: Path to the text dataset file.
        test_size: Forwarded to ``train_test_split`` for the held-out split.
        batch_size: Batch size for both loaders.
        lr: Learning rate of the Adam optimizer.
        num_epochs: Number of passes over the training split.
        rnn_type, embed_size, hidden_size, num_layers, dropout: Forwarded to
            ``RNNLanguageModel``.
        seed: Optional integer. When given, it seeds the train/test split and
            the global torch RNG so that a run can be repeated. ``None`` keeps
            the previous unseeded behaviour.

    Returns:
        The trained model, placed on CUDA device 0 when available, else on CPU.

    Notes:
        The reported losses are averaged per non-pad target token over the
        whole epoch. Averaging per-batch means with equal weight would skew
        the value whenever batches contain different numbers of tokens.
    """
    device="cuda:0" if torch.cuda.is_available() else "cpu"
    print("Is CUDA available: " + str(torch.cuda.is_available()))

    if seed is not None:
        torch.manual_seed(seed)

    # make dataset and build vocabs
    ds = load_dataset("text", data_files={"train": dataset_path})
    ds = ds["train"].train_test_split(test_size=test_size, seed=seed)
    lang.build_vocab(ds)

    ds_tokenized = ds.map(lambda x: {"ids": lang.sentence2ids(x["text"])}, remove_columns=["text"])
    train_dataset = ds_tokenized["train"]
    test_dataset  = ds_tokenized["test"]
    pad_id = lang.pad_id()

    collate = _make_collate_fn(pad_id)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    # Evaluation does not benefit from shuffling; a fixed order keeps it repeatable.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    model = RNNLanguageModel(vocab_size=len(lang.vocab()), embed_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type=rnn_type, dropout=dropout).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        model.train()
        avg_train_loss = _run_epoch(model, tqdm(train_loader, desc=f"Epoch {epoch}"), device, pad_id, optim=optim)

        # validation
        model.eval()
        with torch.no_grad():
            avg_val_loss = _run_epoch(model, test_loader, device, pad_id)

        print(f"[{epoch}] train_loss: {avg_train_loss:.4f}  val_loss: {avg_val_loss:.4f}")

    return model

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import yaml
import os
from utils import class_from_package

# Load the training configuration, build the language object it names, and
# forward every remaining key to the training function as a keyword argument.
config_path = "config/train_rnn_smiles.yaml"
with open(os.path.join(repo_root, config_path)) as f:
    conf = yaml.safe_load(f)

# Keys consumed here are popped so that only training arguments remain in conf.
output_dir = os.path.join(repo_root, conf.pop("output_dir"))
lang_class = class_from_package("language", conf.pop("lang_class"))
# lang_conf must be popped as well: left in conf it would reach
# train_rnn_with_dynamic_language as an unexpected keyword argument.
# "or {}" also covers a key that is present but empty (None) in the YAML.
lang = lang_class(**(conf.pop("lang_conf", None) or {}))
dataset_path = os.path.join(repo_root, conf.pop("dataset_path"))

model = train_rnn_with_dynamic_language(lang=lang, dataset_path=dataset_path, **conf)

Is CUDA available: True


Map: 100%|██████████| 224510/224510 [00:10<00:00, 20739.66 examples/s]
Map: 100%|██████████| 24946/24946 [00:01<00:00, 20941.84 examples/s]
Epoch 1: 100%|██████████| 439/439 [00:10<00:00, 40.51it/s]


[1] train_loss: 1.2152  val_loss: 0.8508


Epoch 2: 100%|██████████| 439/439 [00:10<00:00, 39.95it/s]


[2] train_loss: 0.8176  val_loss: 0.7637


Epoch 3: 100%|██████████| 439/439 [00:11<00:00, 39.24it/s]


[3] train_loss: 0.7594  val_loss: 0.7258


Epoch 4: 100%|██████████| 439/439 [00:11<00:00, 39.47it/s]


[4] train_loss: 0.7298  val_loss: 0.7045


Epoch 5: 100%|██████████| 439/439 [00:11<00:00, 39.58it/s]


[5] train_loss: 0.7112  val_loss: 0.6912


Epoch 6: 100%|██████████| 439/439 [00:11<00:00, 39.37it/s]


[6] train_loss: 0.6974  val_loss: 0.6791


Epoch 7: 100%|██████████| 439/439 [00:11<00:00, 39.28it/s]


[7] train_loss: 0.6870  val_loss: 0.6707


Epoch 8: 100%|██████████| 439/439 [00:11<00:00, 39.14it/s]


[8] train_loss: 0.6787  val_loss: 0.6630


Epoch 9: 100%|██████████| 439/439 [00:11<00:00, 39.31it/s]


[9] train_loss: 0.6716  val_loss: 0.6568


Epoch 10: 100%|██████████| 439/439 [00:11<00:00, 39.44it/s]


[10] train_loss: 0.6658  val_loss: 0.6527


Epoch 11: 100%|██████████| 439/439 [00:11<00:00, 38.98it/s]


[11] train_loss: 0.6604  val_loss: 0.6481


Epoch 12: 100%|██████████| 439/439 [00:11<00:00, 39.44it/s]


[12] train_loss: 0.6560  val_loss: 0.6450


Epoch 13: 100%|██████████| 439/439 [00:11<00:00, 38.91it/s]


[13] train_loss: 0.6521  val_loss: 0.6422


Epoch 14: 100%|██████████| 439/439 [00:11<00:00, 38.81it/s]


[14] train_loss: 0.6485  val_loss: 0.6390


Epoch 15: 100%|██████████| 439/439 [00:11<00:00, 39.28it/s]


[15] train_loss: 0.6451  val_loss: 0.6357


Epoch 16: 100%|██████████| 439/439 [00:11<00:00, 39.51it/s]


[16] train_loss: 0.6422  val_loss: 0.6349


Epoch 17: 100%|██████████| 439/439 [00:11<00:00, 39.10it/s]


[17] train_loss: 0.6394  val_loss: 0.6328


Epoch 18: 100%|██████████| 439/439 [00:10<00:00, 39.95it/s]


[18] train_loss: 0.6370  val_loss: 0.6306


Epoch 19: 100%|██████████| 439/439 [00:11<00:00, 39.66it/s]


[19] train_loss: 0.6346  val_loss: 0.6284


Epoch 20: 100%|██████████| 439/439 [00:11<00:00, 39.31it/s]


[20] train_loss: 0.6327  val_loss: 0.6261


Epoch 21: 100%|██████████| 439/439 [00:11<00:00, 39.60it/s]


[21] train_loss: 0.6306  val_loss: 0.6254


Epoch 22: 100%|██████████| 439/439 [00:11<00:00, 39.36it/s]


[22] train_loss: 0.6290  val_loss: 0.6238


Epoch 23: 100%|██████████| 439/439 [00:11<00:00, 39.11it/s]


[23] train_loss: 0.6273  val_loss: 0.6233


Epoch 24: 100%|██████████| 439/439 [00:11<00:00, 39.26it/s]


[24] train_loss: 0.6257  val_loss: 0.6217


Epoch 25: 100%|██████████| 439/439 [00:11<00:00, 39.28it/s]


[25] train_loss: 0.6243  val_loss: 0.6229


Epoch 26: 100%|██████████| 439/439 [00:11<00:00, 38.72it/s]


[26] train_loss: 0.6230  val_loss: 0.6203


Epoch 27: 100%|██████████| 439/439 [00:11<00:00, 38.86it/s]


[27] train_loss: 0.6216  val_loss: 0.6194


Epoch 28: 100%|██████████| 439/439 [00:11<00:00, 39.12it/s]


[28] train_loss: 0.6202  val_loss: 0.6184


Epoch 29: 100%|██████████| 439/439 [00:11<00:00, 39.41it/s]


[29] train_loss: 0.6188  val_loss: 0.6184


Epoch 30: 100%|██████████| 439/439 [00:11<00:00, 39.54it/s]


[30] train_loss: 0.6181  val_loss: 0.6165


Epoch 31: 100%|██████████| 439/439 [00:11<00:00, 38.92it/s]


[31] train_loss: 0.6170  val_loss: 0.6162


Epoch 32: 100%|██████████| 439/439 [00:11<00:00, 39.32it/s]


[32] train_loss: 0.6157  val_loss: 0.6164


Epoch 33: 100%|██████████| 439/439 [00:13<00:00, 32.94it/s]


[33] train_loss: 0.6149  val_loss: 0.6152


Epoch 34: 100%|██████████| 439/439 [00:11<00:00, 39.00it/s]


[34] train_loss: 0.6140  val_loss: 0.6151


Epoch 35: 100%|██████████| 439/439 [00:12<00:00, 36.53it/s]


[35] train_loss: 0.6133  val_loss: 0.6138


Epoch 36: 100%|██████████| 439/439 [00:11<00:00, 38.00it/s]


[36] train_loss: 0.6124  val_loss: 0.6142


Epoch 37: 100%|██████████| 439/439 [00:11<00:00, 37.60it/s]


[37] train_loss: 0.6116  val_loss: 0.6136


Epoch 38: 100%|██████████| 439/439 [00:11<00:00, 38.76it/s]


[38] train_loss: 0.6108  val_loss: 0.6123


Epoch 39: 100%|██████████| 439/439 [00:11<00:00, 38.79it/s]


[39] train_loss: 0.6099  val_loss: 0.6127


Epoch 40: 100%|██████████| 439/439 [00:11<00:00, 39.10it/s]


[40] train_loss: 0.6095  val_loss: 0.6125


Epoch 41: 100%|██████████| 439/439 [00:11<00:00, 38.33it/s]


[41] train_loss: 0.6086  val_loss: 0.6116


Epoch 42: 100%|██████████| 439/439 [00:11<00:00, 38.30it/s]


[42] train_loss: 0.6082  val_loss: 0.6111


Epoch 43: 100%|██████████| 439/439 [00:11<00:00, 38.55it/s]


[43] train_loss: 0.6073  val_loss: 0.6114


Epoch 44: 100%|██████████| 439/439 [00:11<00:00, 39.13it/s]


[44] train_loss: 0.6069  val_loss: 0.6110


Epoch 45: 100%|██████████| 439/439 [00:11<00:00, 39.34it/s]


[45] train_loss: 0.6063  val_loss: 0.6109


Epoch 46: 100%|██████████| 439/439 [00:11<00:00, 38.15it/s]


[46] train_loss: 0.6062  val_loss: 0.6097


Epoch 47: 100%|██████████| 439/439 [00:11<00:00, 39.48it/s]


[47] train_loss: 0.6051  val_loss: 0.6101


Epoch 48: 100%|██████████| 439/439 [00:11<00:00, 39.49it/s]


[48] train_loss: 0.6047  val_loss: 0.6100


Epoch 49: 100%|██████████| 439/439 [00:11<00:00, 38.93it/s]


[49] train_loss: 0.6042  val_loss: 0.6094


Epoch 50: 100%|██████████| 439/439 [00:11<00:00, 39.10it/s]


[50] train_loss: 0.6037  val_loss: 0.6096


Epoch 51: 100%|██████████| 439/439 [00:11<00:00, 39.45it/s]


[51] train_loss: 0.6034  val_loss: 0.6091


Epoch 52: 100%|██████████| 439/439 [00:11<00:00, 38.95it/s]


[52] train_loss: 0.6028  val_loss: 0.6084


Epoch 53: 100%|██████████| 439/439 [00:11<00:00, 38.57it/s]


[53] train_loss: 0.6024  val_loss: 0.6099


Epoch 54: 100%|██████████| 439/439 [00:11<00:00, 38.94it/s]


[54] train_loss: 0.6019  val_loss: 0.6089


Epoch 55: 100%|██████████| 439/439 [00:11<00:00, 38.69it/s]


[55] train_loss: 0.6014  val_loss: 0.6078


Epoch 56: 100%|██████████| 439/439 [00:11<00:00, 39.25it/s]


[56] train_loss: 0.6012  val_loss: 0.6102


Epoch 57: 100%|██████████| 439/439 [00:11<00:00, 39.65it/s]


[57] train_loss: 0.6010  val_loss: 0.6079


Epoch 58: 100%|██████████| 439/439 [00:11<00:00, 39.00it/s]


[58] train_loss: 0.6006  val_loss: 0.6076


Epoch 59: 100%|██████████| 439/439 [00:11<00:00, 39.06it/s]


[59] train_loss: 0.6001  val_loss: 0.6093


Epoch 60: 100%|██████████| 439/439 [00:11<00:00, 39.33it/s]


[60] train_loss: 0.5997  val_loss: 0.6082


Epoch 61: 100%|██████████| 439/439 [00:11<00:00, 38.91it/s]


[61] train_loss: 0.5993  val_loss: 0.6070


Epoch 62: 100%|██████████| 439/439 [00:11<00:00, 39.00it/s]


[62] train_loss: 0.5992  val_loss: 0.6071


Epoch 63: 100%|██████████| 439/439 [00:11<00:00, 39.14it/s]


[63] train_loss: 0.5988  val_loss: 0.6067


Epoch 64: 100%|██████████| 439/439 [00:11<00:00, 39.54it/s]


[64] train_loss: 0.5984  val_loss: 0.6066


Epoch 65: 100%|██████████| 439/439 [00:11<00:00, 39.73it/s]


[65] train_loss: 0.5980  val_loss: 0.6064


Epoch 66: 100%|██████████| 439/439 [00:11<00:00, 39.18it/s]


[66] train_loss: 0.5977  val_loss: 0.6058


Epoch 67: 100%|██████████| 439/439 [00:11<00:00, 39.27it/s]


[67] train_loss: 0.5975  val_loss: 0.6059


Epoch 68: 100%|██████████| 439/439 [00:11<00:00, 39.26it/s]


[68] train_loss: 0.5973  val_loss: 0.6062


Epoch 69: 100%|██████████| 439/439 [00:11<00:00, 39.20it/s]


[69] train_loss: 0.5969  val_loss: 0.6059


Epoch 70: 100%|██████████| 439/439 [00:11<00:00, 39.33it/s]


[70] train_loss: 0.5967  val_loss: 0.6067


Epoch 71: 100%|██████████| 439/439 [00:11<00:00, 38.68it/s]


[71] train_loss: 0.5963  val_loss: 0.6051


Epoch 72: 100%|██████████| 439/439 [00:11<00:00, 39.23it/s]


[72] train_loss: 0.5962  val_loss: 0.6068


Epoch 73: 100%|██████████| 439/439 [00:11<00:00, 39.35it/s]


[73] train_loss: 0.5959  val_loss: 0.6063


Epoch 74: 100%|██████████| 439/439 [00:12<00:00, 35.75it/s]


[74] train_loss: 0.5954  val_loss: 0.6053


Epoch 75: 100%|██████████| 439/439 [00:11<00:00, 39.03it/s]


[75] train_loss: 0.5953  val_loss: 0.6052


Epoch 76: 100%|██████████| 439/439 [00:11<00:00, 39.67it/s]


[76] train_loss: 0.5948  val_loss: 0.6065


Epoch 77: 100%|██████████| 439/439 [00:11<00:00, 38.97it/s]


[77] train_loss: 0.5950  val_loss: 0.6051


Epoch 78: 100%|██████████| 439/439 [00:11<00:00, 38.81it/s]


[78] train_loss: 0.5950  val_loss: 0.6052


Epoch 79: 100%|██████████| 439/439 [00:11<00:00, 39.05it/s]


[79] train_loss: 0.5944  val_loss: 0.6053


Epoch 80: 100%|██████████| 439/439 [00:11<00:00, 39.10it/s]


[80] train_loss: 0.5941  val_loss: 0.6052


Epoch 81: 100%|██████████| 439/439 [00:11<00:00, 38.76it/s]


[81] train_loss: 0.5937  val_loss: 0.6052


Epoch 82: 100%|██████████| 439/439 [00:11<00:00, 38.92it/s]


[82] train_loss: 0.5938  val_loss: 0.6045


Epoch 83: 100%|██████████| 439/439 [00:11<00:00, 37.86it/s]


[83] train_loss: 0.5938  val_loss: 0.6041


Epoch 84: 100%|██████████| 439/439 [00:11<00:00, 38.63it/s]


[84] train_loss: 0.5931  val_loss: 0.6045


Epoch 85: 100%|██████████| 439/439 [00:11<00:00, 39.32it/s]


[85] train_loss: 0.5932  val_loss: 0.6044


Epoch 86: 100%|██████████| 439/439 [00:11<00:00, 39.49it/s]


[86] train_loss: 0.5931  val_loss: 0.6046


Epoch 87: 100%|██████████| 439/439 [00:11<00:00, 39.53it/s]


[87] train_loss: 0.5927  val_loss: 0.6039


Epoch 88: 100%|██████████| 439/439 [00:11<00:00, 39.28it/s]


[88] train_loss: 0.5925  val_loss: 0.6045


Epoch 89: 100%|██████████| 439/439 [00:11<00:00, 39.06it/s]


[89] train_loss: 0.5924  val_loss: 0.6043


Epoch 90: 100%|██████████| 439/439 [00:11<00:00, 39.16it/s]


[90] train_loss: 0.5923  val_loss: 0.6048


Epoch 91: 100%|██████████| 439/439 [00:11<00:00, 39.15it/s]


[91] train_loss: 0.5918  val_loss: 0.6046


Epoch 92: 100%|██████████| 439/439 [00:11<00:00, 39.12it/s]


[92] train_loss: 0.5921  val_loss: 0.6047


Epoch 93: 100%|██████████| 439/439 [00:11<00:00, 39.17it/s]


[93] train_loss: 0.5920  val_loss: 0.6041


Epoch 94: 100%|██████████| 439/439 [00:11<00:00, 38.92it/s]


[94] train_loss: 0.5915  val_loss: 0.6039


Epoch 95: 100%|██████████| 439/439 [00:11<00:00, 39.26it/s]


[95] train_loss: 0.5912  val_loss: 0.6039


Epoch 96: 100%|██████████| 439/439 [00:11<00:00, 38.70it/s]


[96] train_loss: 0.5911  val_loss: 0.6037


Epoch 97: 100%|██████████| 439/439 [00:11<00:00, 38.70it/s]


[97] train_loss: 0.5907  val_loss: 0.6036


Epoch 98: 100%|██████████| 439/439 [00:11<00:00, 39.10it/s]


[98] train_loss: 0.5907  val_loss: 0.6032


Epoch 99: 100%|██████████| 439/439 [00:11<00:00, 39.14it/s]


[99] train_loss: 0.5907  val_loss: 0.6033


Epoch 100: 100%|██████████| 439/439 [00:11<00:00, 39.12it/s]


[100] train_loss: 0.5906  val_loss: 0.6037


In [None]:
# Save the trained model to output_dir.
# The language object is attached to the model first, presumably because
# model.save() relies on it — TODO confirm against RNNLanguageModel.save.
model.lang = lang
model.save(output_dir)

In [None]:
# Save a copy of the config used for this run and the language object
# into output_dir.
import shutil

# Make sure the destination exists so this cell does not depend on an earlier
# cell having created it.
os.makedirs(output_dir, exist_ok=True)

src = os.path.join(repo_root, config_path)
dst = os.path.join(output_dir, "setting.yaml")
shutil.copy(src, dst)

if lang.__class__.__name__ == "HELM":
    # For HELM languages the monomer library files are loaded into the language
    # object before it is saved.
    lib_files = [
        "chembl_35_monomer_library.xml",
        "chembl_35_monomer_library_diff.xml",
        "HELMCoreLibrary.json",
        "monomerLib2.0.json"
    ]
    # os.path.join instead of string concatenation, so the paths stay valid
    # even if repo_root is written without a trailing slash.
    lib_dir = os.path.join(repo_root, "data/helm/library")
    lang.load_monomer_library(*[os.path.join(lib_dir, name) for name in lib_files], culling=True)
lang.save(os.path.join(output_dir, "smiles.lang"))