In [1]:
#for local
import sys
repo_root = "../"
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

In [2]:
%load_ext autoreload
%autoreload 2
from datasets import load_dataset
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from language import DynamicLanguage
from transition import RNNLanguageModel

def train_rnn_with_dynamic_language(lang: DynamicLanguage, dataset_path: str, test_size=0.1, batch_size=64, lr=1e-3, num_epochs=10, rnn_type="GRU", embed_size=None, hidden_size=256, num_layers=2, dropout=0.3):
    device="cuda:0" if torch.cuda.is_available() else "cpu"
    print("Is CUDA available: " + str(torch.cuda.is_available()))
    
    # make dataset and build vocabs
    ds = load_dataset("text", data_files={"train": dataset_path})
    ds = ds["train"].train_test_split(test_size=test_size)
    lang.build_vocab(ds)

    ds_tokenized = ds.map(lambda x: {"ids": lang.sentence2ids(x["text"])}, remove_columns=["text"])
    train_dataset = ds_tokenized["train"]
    test_dataset  = ds_tokenized["test"]
    pad_id = lang.pad_id()
    
    def collate(batch):
        seqs = [torch.tensor(ex["ids"]) for ex in batch]
        maxlen = max(len(s) for s in seqs)
        padded = torch.full((len(seqs), maxlen), pad_id, dtype=torch.long)
        for i, s in enumerate(seqs):
            padded[i, :len(s)] = s
        return padded[:, :-1], padded[:, 1:]  # input, target
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    
    model = RNNLanguageModel(pad_id=lang.pad_id(), vocab_size=len(lang.vocab()), embed_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type=rnn_type, dropout=dropout).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=pad_id)
            optim.zero_grad()
            loss.backward()
            optim.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                logits, _ = model(x)
                loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=pad_id)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(test_loader)

        print(f"[{epoch}] train_loss: {avg_train_loss:.4f}  val_loss: {avg_val_loss:.4f}")

    return model

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [4]:
import yaml
import os
from utils import class_from_package

config_path = "config/train_rnn.yaml"
with open(os.path.join(repo_root, config_path)) as f:
    conf = yaml.safe_load(f)

output_dir = os.path.join(repo_root, conf.pop("output_dir"))
lang_class = class_from_package("language", conf.pop("lang_class"))
lang = lang_class(**conf.get("lang_conf", {}))
dataset_path = os.path.join(repo_root, conf.pop("dataset_path"))

model = train_rnn_with_dynamic_language(lang=lang, dataset_path=dataset_path, **conf)

Is CUDA available: True


Map: 100%|██████████| 224510/224510 [00:10<00:00, 21927.64 examples/s]
Map: 100%|██████████| 24946/24946 [00:01<00:00, 22163.52 examples/s]
Epoch 1: 100%|██████████| 439/439 [00:15<00:00, 29.07it/s]


[1] train_loss: 1.4204  val_loss: 0.9583


Epoch 2: 100%|██████████| 439/439 [00:14<00:00, 29.61it/s]


[2] train_loss: 0.9172  val_loss: 0.8206


Epoch 3: 100%|██████████| 439/439 [00:14<00:00, 29.75it/s]


[3] train_loss: 0.8242  val_loss: 0.7662


Epoch 4: 100%|██████████| 439/439 [00:14<00:00, 29.84it/s]


[4] train_loss: 0.7794  val_loss: 0.7361


Epoch 5: 100%|██████████| 439/439 [00:14<00:00, 29.48it/s]


[5] train_loss: 0.7516  val_loss: 0.7165


Epoch 6: 100%|██████████| 439/439 [00:14<00:00, 29.83it/s]


[6] train_loss: 0.7320  val_loss: 0.7004


Epoch 7: 100%|██████████| 439/439 [00:14<00:00, 29.63it/s]


[7] train_loss: 0.7176  val_loss: 0.6892


Epoch 8: 100%|██████████| 439/439 [00:14<00:00, 29.48it/s]


[8] train_loss: 0.7057  val_loss: 0.6789


Epoch 9: 100%|██████████| 439/439 [00:14<00:00, 29.51it/s]


[9] train_loss: 0.6958  val_loss: 0.6718


Epoch 10: 100%|██████████| 439/439 [00:14<00:00, 29.66it/s]


[10] train_loss: 0.6876  val_loss: 0.6637


Epoch 11: 100%|██████████| 439/439 [00:14<00:00, 29.59it/s]


[11] train_loss: 0.6803  val_loss: 0.6578


Epoch 12: 100%|██████████| 439/439 [00:15<00:00, 29.15it/s]


[12] train_loss: 0.6743  val_loss: 0.6541


Epoch 13: 100%|██████████| 439/439 [00:14<00:00, 29.85it/s]


[13] train_loss: 0.6688  val_loss: 0.6481


Epoch 14: 100%|██████████| 439/439 [00:14<00:00, 29.59it/s]


[14] train_loss: 0.6639  val_loss: 0.6454


Epoch 15: 100%|██████████| 439/439 [00:14<00:00, 29.83it/s]


[15] train_loss: 0.6595  val_loss: 0.6420


Epoch 16: 100%|██████████| 439/439 [00:14<00:00, 29.63it/s]


[16] train_loss: 0.6559  val_loss: 0.6380


Epoch 17: 100%|██████████| 439/439 [00:14<00:00, 29.28it/s]


[17] train_loss: 0.6522  val_loss: 0.6349


Epoch 18: 100%|██████████| 439/439 [00:14<00:00, 29.58it/s]


[18] train_loss: 0.6489  val_loss: 0.6319


Epoch 19: 100%|██████████| 439/439 [00:14<00:00, 29.74it/s]


[19] train_loss: 0.6459  val_loss: 0.6302


Epoch 20: 100%|██████████| 439/439 [00:15<00:00, 29.08it/s]


[20] train_loss: 0.6432  val_loss: 0.6275


Epoch 21: 100%|██████████| 439/439 [00:14<00:00, 29.55it/s]


[21] train_loss: 0.6426  val_loss: 0.6264


Epoch 22: 100%|██████████| 439/439 [00:14<00:00, 29.64it/s]


[22] train_loss: 0.6386  val_loss: 0.6245


Epoch 23: 100%|██████████| 439/439 [00:14<00:00, 29.48it/s]


[23] train_loss: 0.6362  val_loss: 0.6216


Epoch 24: 100%|██████████| 439/439 [00:14<00:00, 29.83it/s]


[24] train_loss: 0.6342  val_loss: 0.6205


Epoch 25: 100%|██████████| 439/439 [00:14<00:00, 29.62it/s]


[25] train_loss: 0.6325  val_loss: 0.6188


Epoch 26: 100%|██████████| 439/439 [00:14<00:00, 29.66it/s]


[26] train_loss: 0.6306  val_loss: 0.6174


Epoch 27: 100%|██████████| 439/439 [00:14<00:00, 29.83it/s]


[27] train_loss: 0.6288  val_loss: 0.6154


Epoch 28: 100%|██████████| 439/439 [00:14<00:00, 29.31it/s]


[28] train_loss: 0.6272  val_loss: 0.6148


Epoch 29: 100%|██████████| 439/439 [00:14<00:00, 29.80it/s]


[29] train_loss: 0.6255  val_loss: 0.6131


Epoch 30: 100%|██████████| 439/439 [00:14<00:00, 29.29it/s]


[30] train_loss: 0.6240  val_loss: 0.6124


Epoch 31: 100%|██████████| 439/439 [00:14<00:00, 29.66it/s]


[31] train_loss: 0.6227  val_loss: 0.6110


Epoch 32: 100%|██████████| 439/439 [00:14<00:00, 29.72it/s]


[32] train_loss: 0.6218  val_loss: 0.6105


Epoch 33: 100%|██████████| 439/439 [00:14<00:00, 29.94it/s]


[33] train_loss: 0.6202  val_loss: 0.6092


Epoch 34: 100%|██████████| 439/439 [00:14<00:00, 29.81it/s]


[34] train_loss: 0.6191  val_loss: 0.6081


Epoch 35: 100%|██████████| 439/439 [00:14<00:00, 29.68it/s]


[35] train_loss: 0.6179  val_loss: 0.6075


Epoch 36: 100%|██████████| 439/439 [00:15<00:00, 28.99it/s]


[36] train_loss: 0.6167  val_loss: 0.6061


Epoch 37: 100%|██████████| 439/439 [00:14<00:00, 29.82it/s]


[37] train_loss: 0.6158  val_loss: 0.6062


Epoch 38: 100%|██████████| 439/439 [00:14<00:00, 30.07it/s]


[38] train_loss: 0.6147  val_loss: 0.6044


Epoch 39: 100%|██████████| 439/439 [00:14<00:00, 29.86it/s]


[39] train_loss: 0.6139  val_loss: 0.6042


Epoch 40: 100%|██████████| 439/439 [00:14<00:00, 29.84it/s]


[40] train_loss: 0.6133  val_loss: 0.6031


Epoch 41: 100%|██████████| 439/439 [00:14<00:00, 29.50it/s]


[41] train_loss: 0.6118  val_loss: 0.6036


Epoch 42: 100%|██████████| 439/439 [00:15<00:00, 28.86it/s]


[42] train_loss: 0.6110  val_loss: 0.6030


Epoch 43: 100%|██████████| 439/439 [00:15<00:00, 28.91it/s]


[43] train_loss: 0.6102  val_loss: 0.6016


Epoch 44: 100%|██████████| 439/439 [00:14<00:00, 29.52it/s]


[44] train_loss: 0.6092  val_loss: 0.6002


Epoch 45: 100%|██████████| 439/439 [00:14<00:00, 29.58it/s]


[45] train_loss: 0.6088  val_loss: 0.6007


Epoch 46: 100%|██████████| 439/439 [00:14<00:00, 29.56it/s]


[46] train_loss: 0.6078  val_loss: 0.6000


Epoch 47: 100%|██████████| 439/439 [00:14<00:00, 29.68it/s]


[47] train_loss: 0.6071  val_loss: 0.6000


Epoch 48: 100%|██████████| 439/439 [00:14<00:00, 29.64it/s]


[48] train_loss: 0.6063  val_loss: 0.5992


Epoch 49: 100%|██████████| 439/439 [00:14<00:00, 29.79it/s]


[49] train_loss: 0.6056  val_loss: 0.5984


Epoch 50: 100%|██████████| 439/439 [00:14<00:00, 29.43it/s]


[50] train_loss: 0.6050  val_loss: 0.5977


Epoch 51: 100%|██████████| 439/439 [00:14<00:00, 29.61it/s]


[51] train_loss: 0.6054  val_loss: 0.5982


Epoch 52: 100%|██████████| 439/439 [00:14<00:00, 29.67it/s]


[52] train_loss: 0.6040  val_loss: 0.5971


Epoch 53: 100%|██████████| 439/439 [00:14<00:00, 29.79it/s]


[53] train_loss: 0.6031  val_loss: 0.5963


Epoch 54: 100%|██████████| 439/439 [00:14<00:00, 29.65it/s]


[54] train_loss: 0.6024  val_loss: 0.5968


Epoch 55: 100%|██████████| 439/439 [00:14<00:00, 29.68it/s]


[55] train_loss: 0.6021  val_loss: 0.5957


Epoch 56: 100%|██████████| 439/439 [00:14<00:00, 30.00it/s]


[56] train_loss: 0.6013  val_loss: 0.5957


Epoch 57: 100%|██████████| 439/439 [00:14<00:00, 29.45it/s]


[57] train_loss: 0.6008  val_loss: 0.5955


Epoch 58: 100%|██████████| 439/439 [00:14<00:00, 29.90it/s]


[58] train_loss: 0.6003  val_loss: 0.5957


Epoch 59: 100%|██████████| 439/439 [00:14<00:00, 29.72it/s]


[59] train_loss: 0.5998  val_loss: 0.5958


Epoch 60: 100%|██████████| 439/439 [00:14<00:00, 29.80it/s]


[60] train_loss: 0.5994  val_loss: 0.5946


Epoch 61: 100%|██████████| 439/439 [00:14<00:00, 29.86it/s]


[61] train_loss: 0.5991  val_loss: 0.5941


Epoch 62: 100%|██████████| 439/439 [00:14<00:00, 29.82it/s]


[62] train_loss: 0.5984  val_loss: 0.5941


Epoch 63: 100%|██████████| 439/439 [00:14<00:00, 29.81it/s]


[63] train_loss: 0.5978  val_loss: 0.5937


Epoch 64: 100%|██████████| 439/439 [00:14<00:00, 29.48it/s]


[64] train_loss: 0.5976  val_loss: 0.5935


Epoch 65: 100%|██████████| 439/439 [00:14<00:00, 29.71it/s]


[65] train_loss: 0.5970  val_loss: 0.5930


Epoch 66: 100%|██████████| 439/439 [00:14<00:00, 29.54it/s]


[66] train_loss: 0.5965  val_loss: 0.5926


Epoch 67: 100%|██████████| 439/439 [00:14<00:00, 29.56it/s]


[67] train_loss: 0.5963  val_loss: 0.5925


Epoch 68: 100%|██████████| 439/439 [00:14<00:00, 29.49it/s]


[68] train_loss: 0.5959  val_loss: 0.5920


Epoch 69: 100%|██████████| 439/439 [00:14<00:00, 29.94it/s]


[69] train_loss: 0.5953  val_loss: 0.5921


Epoch 70: 100%|██████████| 439/439 [00:14<00:00, 29.66it/s]


[70] train_loss: 0.5949  val_loss: 0.5915


Epoch 71: 100%|██████████| 439/439 [00:14<00:00, 29.87it/s]


[71] train_loss: 0.5948  val_loss: 0.5912


Epoch 72: 100%|██████████| 439/439 [00:14<00:00, 29.73it/s]


[72] train_loss: 0.5939  val_loss: 0.5921


Epoch 73: 100%|██████████| 439/439 [00:14<00:00, 29.86it/s]


[73] train_loss: 0.5941  val_loss: 0.5913


Epoch 74: 100%|██████████| 439/439 [00:14<00:00, 29.78it/s]


[74] train_loss: 0.5934  val_loss: 0.5904


Epoch 75: 100%|██████████| 439/439 [00:14<00:00, 29.81it/s]


[75] train_loss: 0.5929  val_loss: 0.5899


Epoch 76: 100%|██████████| 439/439 [00:14<00:00, 30.01it/s]


[76] train_loss: 0.5924  val_loss: 0.5902


Epoch 77: 100%|██████████| 439/439 [00:14<00:00, 29.87it/s]


[77] train_loss: 0.5923  val_loss: 0.5911


Epoch 78: 100%|██████████| 439/439 [00:14<00:00, 29.90it/s]


[78] train_loss: 0.5921  val_loss: 0.5912


Epoch 79: 100%|██████████| 439/439 [00:14<00:00, 29.80it/s]


[79] train_loss: 0.5918  val_loss: 0.5894


Epoch 80: 100%|██████████| 439/439 [00:14<00:00, 29.78it/s]


[80] train_loss: 0.5912  val_loss: 0.5901


Epoch 81: 100%|██████████| 439/439 [00:14<00:00, 29.88it/s]


[81] train_loss: 0.5911  val_loss: 0.5898


Epoch 82: 100%|██████████| 439/439 [00:14<00:00, 30.04it/s]


[82] train_loss: 0.5908  val_loss: 0.5894


Epoch 83: 100%|██████████| 439/439 [00:14<00:00, 30.01it/s]


[83] train_loss: 0.5904  val_loss: 0.5889


Epoch 84: 100%|██████████| 439/439 [00:14<00:00, 29.79it/s]


[84] train_loss: 0.5898  val_loss: 0.5887


Epoch 85: 100%|██████████| 439/439 [00:14<00:00, 30.01it/s]


[85] train_loss: 0.5956  val_loss: 0.5901


Epoch 86: 100%|██████████| 439/439 [00:14<00:00, 29.73it/s]


[86] train_loss: 0.5910  val_loss: 0.5894


Epoch 87: 100%|██████████| 439/439 [00:14<00:00, 29.92it/s]


[87] train_loss: 0.5896  val_loss: 0.5884


Epoch 88: 100%|██████████| 439/439 [00:14<00:00, 29.79it/s]


[88] train_loss: 0.5901  val_loss: 0.5885


Epoch 89: 100%|██████████| 439/439 [00:14<00:00, 30.04it/s]


[89] train_loss: 0.5890  val_loss: 0.5885


Epoch 90: 100%|██████████| 439/439 [00:14<00:00, 30.04it/s]


[90] train_loss: 0.5885  val_loss: 0.5874


Epoch 91: 100%|██████████| 439/439 [00:14<00:00, 30.00it/s]


[91] train_loss: 0.5878  val_loss: 0.5877


Epoch 92: 100%|██████████| 439/439 [00:14<00:00, 30.09it/s]


[92] train_loss: 0.5875  val_loss: 0.5878


Epoch 93: 100%|██████████| 439/439 [00:14<00:00, 29.92it/s]


[93] train_loss: 0.5873  val_loss: 0.5867


Epoch 94: 100%|██████████| 439/439 [00:14<00:00, 30.13it/s]


[94] train_loss: 0.5873  val_loss: 0.5878


Epoch 95: 100%|██████████| 439/439 [00:14<00:00, 29.89it/s]


[95] train_loss: 0.5869  val_loss: 0.5873


Epoch 96: 100%|██████████| 439/439 [00:14<00:00, 30.24it/s]


[96] train_loss: 0.5869  val_loss: 0.5870


Epoch 97: 100%|██████████| 439/439 [00:14<00:00, 29.92it/s]


[97] train_loss: 0.5865  val_loss: 0.5872


Epoch 98: 100%|██████████| 439/439 [00:14<00:00, 30.16it/s]


[98] train_loss: 0.5863  val_loss: 0.5875


Epoch 99: 100%|██████████| 439/439 [00:14<00:00, 29.86it/s]


[99] train_loss: 0.5862  val_loss: 0.5866


Epoch 100: 100%|██████████| 439/439 [00:14<00:00, 30.13it/s]


[100] train_loss: 0.5858  val_loss: 0.5867


In [5]:
# save model
model.save(output_dir)

In [6]:
# save yaml and lang
import shutil
src = os.path.join(repo_root, config_path)
dst = os.path.join(output_dir, "setting.yaml")
shutil.copy(src, dst)

if lang.__class__.__name__ == "HELM":
    lib_files = [
        "chembl_35_monomer_library.xml",
        "chembl_35_monomer_library_diff.xml",
        "HELMCoreLibrary.json",
        "monomerLib2.0.json"
    ]
    lang.load_monomer_library(*[repo_root + f"data/helm/library/{name}" for name in lib_files], culling=True)
lang.save(os.path.join(output_dir, "smiles.lang"))