In [1]:
# Local-checkout setup: make the repository root (one directory up) importable.
import sys

repo_root = "../"
# Prepend only when absent so re-running this cell never stacks duplicate entries.
if sys.path.count(repo_root) == 0:
    sys.path.insert(0, repo_root)

In [None]:
%load_ext autoreload
%autoreload 2
from datasets import load_dataset
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from language import DynamicLanguage
from transition import RNNLanguageModel

def train_rnn_with_dynamic_language(lang: DynamicLanguage, dataset_path: str, test_size=0.1, batch_size=64, lr=1e-3, num_epochs=10, rnn_type="GRU", embed_size=None, hidden_size=256, num_layers=2, dropout=0.3, seed=None):
    """Train an RNN language model on a line-based text dataset.

    The text file at ``dataset_path`` is loaded with the ``"text"`` dataset
    builder, split into train/test parts, and used to build the vocabulary of
    ``lang`` (``lang.build_vocab`` receives both splits). Every example is
    tokenized with ``lang.sentence2ids`` and the model is trained with Adam on
    next-token cross-entropy, ignoring padding positions.

    Args:
        lang: Language object providing ``build_vocab``, ``sentence2ids``,
            ``pad_id`` and ``vocab``. Its vocabulary is built in place.
        dataset_path: Path of the text file to train on.
        test_size: Forwarded to ``train_test_split`` to size the held-out split.
        batch_size: Batch size for both the training and validation loaders.
        lr: Adam learning rate.
        num_epochs: Number of passes over the training split.
        rnn_type, embed_size, hidden_size, num_layers, dropout: Forwarded
            unchanged to ``RNNLanguageModel``.
        seed: Optional seed. When given, it seeds the train/test split and
            ``torch`` so a run can be reproduced; ``None`` (the default) keeps
            the previous unseeded behaviour.

    Returns:
        ``(model, best_state_dict)``: the model holding the weights after the
        final epoch, and an independent copy of the state dict taken at the
        epoch with the lowest validation loss. ``best_state_dict`` is ``None``
        if no epoch ever improved on the initial ``inf`` (e.g. ``num_epochs < 1``).
    """
    # Function-scope import so the fix does not depend on the notebook's import cell.
    from copy import deepcopy

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Is CUDA available: " + str(torch.cuda.is_available()))

    if seed is not None:
        torch.manual_seed(seed)

    # make dataset and build vocabs
    ds = load_dataset("text", data_files={"train": dataset_path})
    ds = ds["train"].train_test_split(test_size=test_size, seed=seed)
    lang.build_vocab(ds)

    ds_tokenized = ds.map(lambda x: {"ids": lang.sentence2ids(x["text"])}, remove_columns=["text"])
    train_dataset = ds_tokenized["train"]
    test_dataset = ds_tokenized["test"]
    pad_id = lang.pad_id()

    def collate(batch):
        """Right-pad a batch to a common length and return (input, target) shifted by one token."""
        seqs = [torch.tensor(ex["ids"]) for ex in batch]
        maxlen = max(len(s) for s in seqs)
        padded = torch.full((len(seqs), maxlen), pad_id, dtype=torch.long)
        for i, s in enumerate(seqs):
            padded[i, :len(s)] = s
        return padded[:, :-1], padded[:, 1:]  # input, target

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    # Validation batches are kept in a fixed order: the reported value is a mean of
    # per-batch means, so reshuffling would change it from epoch to epoch for
    # reasons unrelated to the model and make best-epoch selection noisier.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    model = RNNLanguageModel(pad_id=pad_id, vocab_size=len(lang.vocab()), embed_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type=rnn_type, dropout=dropout).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    def batch_loss(x, y):
        """Run the model on one batch and return the padding-masked cross-entropy."""
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1), ignore_index=pad_id)

    best_val_loss = float('inf')
    best_state_dict = None

    for epoch in range(1, num_epochs + 1):
        model.train()
        total_loss = 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch}"):
            loss = batch_loss(x, y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)

        # validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for x, y in test_loader:
                val_loss += batch_loss(x, y).item()
        avg_val_loss = val_loss / len(test_loader)

        print(f"[{epoch}] train_loss: {avg_train_loss:.4f}  val_loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() hands back references to the live parameter tensors, which
            # the optimizer keeps updating in place. Without a copy the "best" snapshot
            # would silently turn into the last-epoch weights.
            best_state_dict = deepcopy(model.state_dict())

    return model, best_state_dict

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import yaml
import os
from utils import class_from_package

# Read the training configuration; the path is relative to the repository root.
config_path = "config/train_rnn.yaml"
with open(os.path.join(repo_root, config_path)) as config_file:
    conf = yaml.safe_load(config_file)

# Pop the entries consumed by this cell. Whatever is left in `conf` afterwards
# is forwarded verbatim as keyword arguments to the training function.
output_dir = os.path.join(repo_root, conf.pop("output_dir"))
lang_class = class_from_package("language", conf.pop("lang_class"))
lang_kwargs = conf.pop("lang_args", {})
lang = lang_class(**lang_kwargs)
dataset_path = os.path.join(repo_root, conf.pop("dataset_path"))

model, best_state_dict = train_rnn_with_dynamic_language(
    lang=lang,
    dataset_path=dataset_path,
    **conf,
)

Is CUDA available: True


Map: 100%|██████████| 9721/9721 [00:00<00:00, 18145.54 examples/s]
Map: 100%|██████████| 1081/1081 [00:00<00:00, 16234.00 examples/s]
Epoch 1: 100%|██████████| 19/19 [00:00<00:00, 21.98it/s]


[1] train_loss: 3.6272  val_loss: 2.8061


Epoch 2: 100%|██████████| 19/19 [00:00<00:00, 29.23it/s]


[2] train_loss: 2.5610  val_loss: 2.4276


Epoch 3: 100%|██████████| 19/19 [00:00<00:00, 29.61it/s]


[3] train_loss: 2.2963  val_loss: 2.2200


Epoch 4: 100%|██████████| 19/19 [00:00<00:00, 29.68it/s]


[4] train_loss: 2.1734  val_loss: 2.1401


Epoch 5: 100%|██████████| 19/19 [00:00<00:00, 29.81it/s]


[5] train_loss: 2.0803  val_loss: 2.0139


Epoch 6: 100%|██████████| 19/19 [00:00<00:00, 29.44it/s]


[6] train_loss: 1.9942  val_loss: 1.9579


Epoch 7: 100%|██████████| 19/19 [00:00<00:00, 30.01it/s]


[7] train_loss: 1.9222  val_loss: 1.8793


Epoch 8: 100%|██████████| 19/19 [00:00<00:00, 29.89it/s]


[8] train_loss: 1.8613  val_loss: 1.8512


Epoch 9: 100%|██████████| 19/19 [00:00<00:00, 29.69it/s]


[9] train_loss: 1.8036  val_loss: 1.7786


Epoch 10: 100%|██████████| 19/19 [00:00<00:00, 29.76it/s]


[10] train_loss: 1.7510  val_loss: 1.7431


Epoch 11: 100%|██████████| 19/19 [00:00<00:00, 29.56it/s]


[11] train_loss: 1.7008  val_loss: 1.6533


Epoch 12: 100%|██████████| 19/19 [00:00<00:00, 29.50it/s]


[12] train_loss: 1.6527  val_loss: 1.6081


Epoch 13: 100%|██████████| 19/19 [00:00<00:00, 29.59it/s]


[13] train_loss: 1.6050  val_loss: 1.5444


Epoch 14: 100%|██████████| 19/19 [00:00<00:00, 29.48it/s]


[14] train_loss: 1.5593  val_loss: 1.5821


Epoch 15: 100%|██████████| 19/19 [00:00<00:00, 29.46it/s]


[15] train_loss: 1.5193  val_loss: 1.4994


Epoch 16: 100%|██████████| 19/19 [00:00<00:00, 29.23it/s]


[16] train_loss: 1.4794  val_loss: 1.3985


Epoch 17: 100%|██████████| 19/19 [00:00<00:00, 29.39it/s]


[17] train_loss: 1.4421  val_loss: 1.3397


Epoch 18: 100%|██████████| 19/19 [00:00<00:00, 29.15it/s]


[18] train_loss: 1.4076  val_loss: 1.3799


Epoch 19: 100%|██████████| 19/19 [00:00<00:00, 25.47it/s]


[19] train_loss: 1.3766  val_loss: 1.2702


Epoch 20: 100%|██████████| 19/19 [00:00<00:00, 29.19it/s]


[20] train_loss: 1.3483  val_loss: 1.3179


Epoch 21: 100%|██████████| 19/19 [00:00<00:00, 29.34it/s]


[21] train_loss: 1.3227  val_loss: 1.2830


Epoch 22: 100%|██████████| 19/19 [00:00<00:00, 29.24it/s]


[22] train_loss: 1.3002  val_loss: 1.2875


Epoch 23: 100%|██████████| 19/19 [00:00<00:00, 29.69it/s]


[23] train_loss: 1.2772  val_loss: 1.2185


Epoch 24: 100%|██████████| 19/19 [00:00<00:00, 29.84it/s]


[24] train_loss: 1.2558  val_loss: 1.1800


Epoch 25: 100%|██████████| 19/19 [00:00<00:00, 29.99it/s]


[25] train_loss: 1.2398  val_loss: 1.2751


Epoch 26: 100%|██████████| 19/19 [00:00<00:00, 29.84it/s]


[26] train_loss: 1.2197  val_loss: 1.1641


Epoch 27: 100%|██████████| 19/19 [00:00<00:00, 29.57it/s]


[27] train_loss: 1.2035  val_loss: 1.1584


Epoch 28: 100%|██████████| 19/19 [00:00<00:00, 29.70it/s]


[28] train_loss: 1.1882  val_loss: 1.1891


Epoch 29: 100%|██████████| 19/19 [00:00<00:00, 29.69it/s]


[29] train_loss: 1.1735  val_loss: 1.1023


Epoch 30: 100%|██████████| 19/19 [00:00<00:00, 29.74it/s]


[30] train_loss: 1.1584  val_loss: 1.1652


Epoch 31: 100%|██████████| 19/19 [00:00<00:00, 29.87it/s]


[31] train_loss: 1.1442  val_loss: 1.1669


Epoch 32: 100%|██████████| 19/19 [00:00<00:00, 29.55it/s]


[32] train_loss: 1.1340  val_loss: 1.0982


Epoch 33: 100%|██████████| 19/19 [00:00<00:00, 29.81it/s]


[33] train_loss: 1.1203  val_loss: 1.0992


Epoch 34: 100%|██████████| 19/19 [00:00<00:00, 29.69it/s]


[34] train_loss: 1.1102  val_loss: 1.1785


Epoch 35: 100%|██████████| 19/19 [00:00<00:00, 29.98it/s]


[35] train_loss: 1.0992  val_loss: 1.1835


Epoch 36: 100%|██████████| 19/19 [00:00<00:00, 29.60it/s]


[36] train_loss: 1.0889  val_loss: 1.1444


Epoch 37: 100%|██████████| 19/19 [00:00<00:00, 29.37it/s]


[37] train_loss: 1.0789  val_loss: 1.1933


Epoch 38: 100%|██████████| 19/19 [00:00<00:00, 29.34it/s]


[38] train_loss: 1.0686  val_loss: 1.1083


Epoch 39: 100%|██████████| 19/19 [00:00<00:00, 29.64it/s]


[39] train_loss: 1.0601  val_loss: 1.0730


Epoch 40: 100%|██████████| 19/19 [00:00<00:00, 29.14it/s]


[40] train_loss: 1.0512  val_loss: 1.0862


Epoch 41: 100%|██████████| 19/19 [00:00<00:00, 29.49it/s]


[41] train_loss: 1.0437  val_loss: 1.0783


Epoch 42: 100%|██████████| 19/19 [00:00<00:00, 29.50it/s]


[42] train_loss: 1.0350  val_loss: 1.0822


Epoch 43: 100%|██████████| 19/19 [00:00<00:00, 29.38it/s]


[43] train_loss: 1.0275  val_loss: 1.0599


Epoch 44: 100%|██████████| 19/19 [00:00<00:00, 29.32it/s]


[44] train_loss: 1.0196  val_loss: 1.1797


Epoch 45: 100%|██████████| 19/19 [00:00<00:00, 29.73it/s]


[45] train_loss: 1.0142  val_loss: 1.0474


Epoch 46: 100%|██████████| 19/19 [00:00<00:00, 29.32it/s]


[46] train_loss: 1.0065  val_loss: 1.0539


Epoch 47: 100%|██████████| 19/19 [00:00<00:00, 29.50it/s]


[47] train_loss: 0.9997  val_loss: 1.1567


Epoch 48: 100%|██████████| 19/19 [00:00<00:00, 29.18it/s]


[48] train_loss: 0.9905  val_loss: 1.0193


Epoch 49: 100%|██████████| 19/19 [00:00<00:00, 29.45it/s]


[49] train_loss: 0.9861  val_loss: 1.0538


Epoch 50: 100%|██████████| 19/19 [00:00<00:00, 29.35it/s]


[50] train_loss: 0.9807  val_loss: 1.1063


Epoch 51: 100%|██████████| 19/19 [00:00<00:00, 29.39it/s]


[51] train_loss: 0.9748  val_loss: 1.1390


Epoch 52: 100%|██████████| 19/19 [00:00<00:00, 29.18it/s]


[52] train_loss: 0.9683  val_loss: 1.1650


Epoch 53: 100%|██████████| 19/19 [00:00<00:00, 29.41it/s]


[53] train_loss: 0.9625  val_loss: 1.0757


Epoch 54: 100%|██████████| 19/19 [00:00<00:00, 29.68it/s]


[54] train_loss: 0.9567  val_loss: 1.0749


Epoch 55: 100%|██████████| 19/19 [00:00<00:00, 29.49it/s]


[55] train_loss: 0.9518  val_loss: 1.0068


Epoch 56: 100%|██████████| 19/19 [00:00<00:00, 29.86it/s]


[56] train_loss: 0.9461  val_loss: 1.1459


Epoch 57: 100%|██████████| 19/19 [00:00<00:00, 29.80it/s]


[57] train_loss: 0.9425  val_loss: 1.1094


Epoch 58: 100%|██████████| 19/19 [00:00<00:00, 29.86it/s]


[58] train_loss: 0.9374  val_loss: 0.9880


Epoch 59: 100%|██████████| 19/19 [00:00<00:00, 29.87it/s]


[59] train_loss: 0.9323  val_loss: 1.0595


Epoch 60: 100%|██████████| 19/19 [00:00<00:00, 29.43it/s]


[60] train_loss: 0.9264  val_loss: 1.0780


Epoch 61: 100%|██████████| 19/19 [00:00<00:00, 29.29it/s]


[61] train_loss: 0.9240  val_loss: 1.0193


Epoch 62: 100%|██████████| 19/19 [00:00<00:00, 29.38it/s]


[62] train_loss: 0.9193  val_loss: 1.0208


Epoch 63: 100%|██████████| 19/19 [00:00<00:00, 29.53it/s]


[63] train_loss: 0.9147  val_loss: 1.0446


Epoch 64: 100%|██████████| 19/19 [00:00<00:00, 29.44it/s]


[64] train_loss: 0.9112  val_loss: 1.0166


Epoch 65: 100%|██████████| 19/19 [00:00<00:00, 29.43it/s]


[65] train_loss: 0.9064  val_loss: 1.0521


Epoch 66: 100%|██████████| 19/19 [00:00<00:00, 29.45it/s]


[66] train_loss: 0.9030  val_loss: 1.0401


Epoch 67: 100%|██████████| 19/19 [00:00<00:00, 29.74it/s]


[67] train_loss: 0.8979  val_loss: 1.0256


Epoch 68: 100%|██████████| 19/19 [00:00<00:00, 29.76it/s]


[68] train_loss: 0.8937  val_loss: 0.9661


Epoch 69: 100%|██████████| 19/19 [00:00<00:00, 29.58it/s]


[69] train_loss: 0.8918  val_loss: 0.9878


Epoch 70: 100%|██████████| 19/19 [00:00<00:00, 29.17it/s]


[70] train_loss: 0.8874  val_loss: 1.0315


Epoch 71: 100%|██████████| 19/19 [00:00<00:00, 29.26it/s]


[71] train_loss: 0.8834  val_loss: 1.0278


Epoch 72: 100%|██████████| 19/19 [00:00<00:00, 29.40it/s]


[72] train_loss: 0.8789  val_loss: 1.0454


Epoch 73: 100%|██████████| 19/19 [00:00<00:00, 30.05it/s]


[73] train_loss: 0.8754  val_loss: 1.0043


Epoch 74: 100%|██████████| 19/19 [00:00<00:00, 29.70it/s]


[74] train_loss: 0.8721  val_loss: 1.1374


Epoch 75: 100%|██████████| 19/19 [00:00<00:00, 30.38it/s]


[75] train_loss: 0.8712  val_loss: 1.0392


Epoch 76: 100%|██████████| 19/19 [00:00<00:00, 29.20it/s]


[76] train_loss: 0.8659  val_loss: 1.0881


Epoch 77: 100%|██████████| 19/19 [00:00<00:00, 25.72it/s]


[77] train_loss: 0.8641  val_loss: 1.0610


Epoch 78: 100%|██████████| 19/19 [00:00<00:00, 29.46it/s]


[78] train_loss: 0.8596  val_loss: 1.0375


Epoch 79: 100%|██████████| 19/19 [00:00<00:00, 29.68it/s]


[79] train_loss: 0.8568  val_loss: 1.0507


Epoch 80: 100%|██████████| 19/19 [00:00<00:00, 30.08it/s]


[80] train_loss: 0.8545  val_loss: 1.0005


Epoch 81: 100%|██████████| 19/19 [00:00<00:00, 30.16it/s]


[81] train_loss: 0.8501  val_loss: 1.0648


Epoch 82: 100%|██████████| 19/19 [00:00<00:00, 30.01it/s]


[82] train_loss: 0.8484  val_loss: 1.1151


Epoch 83: 100%|██████████| 19/19 [00:00<00:00, 29.52it/s]


[83] train_loss: 0.8446  val_loss: 1.0692


Epoch 84: 100%|██████████| 19/19 [00:00<00:00, 29.78it/s]


[84] train_loss: 0.8425  val_loss: 1.0373


Epoch 85: 100%|██████████| 19/19 [00:00<00:00, 29.59it/s]


[85] train_loss: 0.8407  val_loss: 1.0247


Epoch 86: 100%|██████████| 19/19 [00:00<00:00, 29.68it/s]


[86] train_loss: 0.8371  val_loss: 1.0626


Epoch 87: 100%|██████████| 19/19 [00:00<00:00, 29.59it/s]


[87] train_loss: 0.8349  val_loss: 1.0362


Epoch 88: 100%|██████████| 19/19 [00:00<00:00, 29.83it/s]


[88] train_loss: 0.8335  val_loss: 1.0645


Epoch 89: 100%|██████████| 19/19 [00:00<00:00, 29.73it/s]


[89] train_loss: 0.8285  val_loss: 1.0914


Epoch 90: 100%|██████████| 19/19 [00:00<00:00, 29.64it/s]


[90] train_loss: 0.8273  val_loss: 0.9955


Epoch 91: 100%|██████████| 19/19 [00:00<00:00, 30.34it/s]


[91] train_loss: 0.8246  val_loss: 0.9878


Epoch 92: 100%|██████████| 19/19 [00:00<00:00, 29.82it/s]


[92] train_loss: 0.8199  val_loss: 1.0006


Epoch 93: 100%|██████████| 19/19 [00:00<00:00, 30.04it/s]


[93] train_loss: 0.8190  val_loss: 1.0052


Epoch 94: 100%|██████████| 19/19 [00:00<00:00, 30.18it/s]


[94] train_loss: 0.8156  val_loss: 1.0068


Epoch 95: 100%|██████████| 19/19 [00:00<00:00, 29.88it/s]


[95] train_loss: 0.8145  val_loss: 1.0596


Epoch 96: 100%|██████████| 19/19 [00:00<00:00, 29.81it/s]


[96] train_loss: 0.8124  val_loss: 1.0702


Epoch 97: 100%|██████████| 19/19 [00:00<00:00, 29.99it/s]


[97] train_loss: 0.8109  val_loss: 1.0899


Epoch 98: 100%|██████████| 19/19 [00:00<00:00, 29.80it/s]


[98] train_loss: 0.8074  val_loss: 1.0480


Epoch 99: 100%|██████████| 19/19 [00:00<00:00, 29.78it/s]


[99] train_loss: 0.8055  val_loss: 1.0443


Epoch 100: 100%|██████████| 19/19 [00:00<00:00, 29.73it/s]

[100] train_loss: 0.8037  val_loss: 1.0471





In [10]:
# Save the model as it stands after training under "last".
last_checkpoint = os.path.join(output_dir, "last")
model.save(last_checkpoint)

# Then load `best_state_dict` into the same model object and save it under "best".
# NOTE: after this cell `model` holds the `best_state_dict` weights, not the last-epoch ones.
best_checkpoint = os.path.join(output_dir, "best")
model.load_state_dict(best_state_dict)
model.save(best_checkpoint)

In [None]:
# save yaml and lang
import shutil

# Create the destination up front so this cell does not depend on an earlier
# cell having already written something into output_dir.
os.makedirs(output_dir, exist_ok=True)

# Keep a copy of the exact configuration file next to the saved artifacts.
src = os.path.join(repo_root, config_path)
dst = os.path.join(output_dir, "setting.yaml")
shutil.copy(src, dst)

# HELM languages get the monomer libraries loaded before being saved
# (presumably so the saved language carries them; confirm against HELM.save).
if lang.__class__.__name__ == "HELM":
    lib_files = [
        "chembl_35_monomer_library.xml",
        "chembl_35_monomer_library_diff.xml",
        "HELMCoreLibrary.json",
        "monomerLib2.0.json"
    ]
    # os.path.join instead of string concatenation: consistent with the other
    # path handling and correct even if repo_root has no trailing slash.
    lib_paths = [os.path.join(repo_root, f"data/helm/library/{name}") for name in lib_files]
    lang.load_monomer_library(*lib_paths, culling=True)
lang.save(os.path.join(output_dir, "language.lang"))