In [1]:
import sys

# Repository root, relative to the notebook's working directory. Later cells
# also join config, dataset and library paths onto it.
repo_root = "../"

# Put the repo root at the front of the import path (only once, so re-running
# this cell is harmless) so local modules like `language`, `transition` and
# `utils` can be imported.
if all(entry != repo_root for entry in sys.path):
    sys.path.insert(0, repo_root)

In [2]:
%load_ext autoreload
%autoreload 2
from datasets import load_dataset
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from language import DynamicLanguage
from transition import RNNLanguageModel

def train_rnn_with_dynamic_language(
    lang: DynamicLanguage,
    dataset_path: str,
    test_dataset_path: str = None,
    test_size: float = 0.1,
    batch_size=64,
    lr=1e-3,
    num_epochs=10,
    rnn_type="GRU",
    embed_size=None,
    hidden_size=256,
    num_layers=2,
    dropout=0.3,
    seed: int = None,
):
    """Train an RNN language model on a one-example-per-line text corpus.

    The corpus is loaded with the HuggingFace ``text`` loader and tokenized
    with ``lang``. The vocabulary is built from the loaded splits. The model
    is then trained with next-token cross-entropy, ignoring padding positions.
    After each epoch it is evaluated on the held-out split, and a snapshot of
    the weights with the lowest validation loss is kept.

    Args:
        lang: language object; this function uses its ``build_vocab``,
            ``sentence2indices``, ``pad_id`` and ``vocab`` methods.
        dataset_path: path to the training text file.
        test_dataset_path: optional path to a separate validation text file.
            When ``None``, a ``test_size`` fraction of ``dataset_path`` is
            held out instead.
        test_size: held-out fraction used when ``test_dataset_path`` is None.
        batch_size, lr, num_epochs: optimisation settings (Adam optimiser).
        rnn_type, embed_size, hidden_size, num_layers, dropout: forwarded to
            ``RNNLanguageModel``.
        seed: optional seed for the train/test split and the torch RNG
            (weight init, shuffling). ``None`` keeps the previous
            non-deterministic behaviour.

    Returns:
        ``(model, best_state_dict)``:
        - ``model``: the model in its final (last-epoch) state.
        - ``best_state_dict``: an independent copy of the parameters from the
          epoch with the lowest validation loss. It is ``None`` if no epoch
          produced a finite validation loss (e.g. an empty validation split).
    """
    import copy

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Is CUDA available: " + str(torch.cuda.is_available()))
    if seed is not None:
        torch.manual_seed(seed)

    # make dataset and build vocabs
    if test_dataset_path is None:
        ds = load_dataset("text", data_files={"train": dataset_path})
        ds = ds["train"].train_test_split(test_size=test_size, seed=seed)
    else:
        ds = load_dataset("text", data_files={"train": dataset_path, "test": test_dataset_path})
    lang.build_vocab(ds)

    ds_tokenized = ds.map(lambda x: {"ids": lang.sentence2indices(x["text"])}, remove_columns=["text"])
    train_dataset = ds_tokenized["train"]
    test_dataset = ds_tokenized["test"]
    pad_id = lang.pad_id()

    def collate(batch):
        """Right-pad a batch of id lists with ``pad_id``; return (input, target) shifted by one."""
        seqs = [torch.tensor(ex["ids"], dtype=torch.long) for ex in batch]
        padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_id)
        return padded[:, :-1], padded[:, 1:]  # input, target

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate)
    # Evaluation order does not affect the loss, so keep validation deterministic.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    model = RNNLanguageModel(pad_id=pad_id, vocab_size=len(lang.vocab()), embed_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, rnn_type=rnn_type, dropout=dropout).to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    def batch_loss(x, y):
        """Return (summed CE over non-pad targets, count of non-pad targets) for one batch."""
        logits, _ = model(x)
        # Summing and then dividing by the token count equals the default 'mean'
        # reduction. It also lets epoch averages be weighted per token rather than
        # per batch (the last batch is usually smaller).
        loss_sum = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), y.reshape(-1),
            ignore_index=pad_id, reduction="sum",
        )
        n_tokens = (y != pad_id).sum()
        return loss_sum, n_tokens

    best_val_loss = float("inf")
    best_state_dict = None

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss_sum, train_tokens = 0.0, 0
        for x, y in tqdm(train_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(device), y.to(device)
            loss_sum, n_tokens = batch_loss(x, y)
            # clamp avoids a 0/0 NaN when every target in the batch is padding
            loss = loss_sum / n_tokens.clamp(min=1)
            optim.zero_grad()
            loss.backward()
            optim.step()
            train_loss_sum += loss_sum.item()
            train_tokens += n_tokens.item()
        avg_train_loss = train_loss_sum / max(train_tokens, 1)

        # validation
        model.eval()
        val_loss_sum, val_tokens = 0.0, 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                loss_sum, n_tokens = batch_loss(x, y)
                val_loss_sum += loss_sum.item()
                val_tokens += n_tokens.item()
        # NaN (never "better" than best_val_loss) instead of ZeroDivisionError on an empty split
        avg_val_loss = val_loss_sum / val_tokens if val_tokens else float("nan")

        print(f"[{epoch}] train_loss: {avg_train_loss:.4f}  val_loss: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() holds references to the live parameter tensors, which later
            # optimizer steps overwrite in place. Deep-copy to keep a true snapshot.
            best_state_dict = copy.deepcopy(model.state_dict())

    return model, best_state_dict

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import yaml
import os
from utils import class_from_package

config_path = "config/train_rnn_guacamol.yaml"
with open(os.path.join(repo_root, config_path)) as f:
    conf = yaml.safe_load(f)

# Keys consumed by the notebook itself are popped. Whatever remains in `conf`
# is passed as keyword arguments to train_rnn_with_dynamic_language.
output_dir = os.path.join(repo_root, conf.pop("output_dir"))
lang_class = class_from_package("language", conf.pop("lang_class"))
# An empty `lang_args:` entry in YAML loads as None; treat it as "no arguments".
lang = lang_class(**(conf.pop("lang_args", None) or {}))
# Dataset paths in the config are relative to the repository root.
conf["dataset_path"] = os.path.join(repo_root, conf["dataset_path"])
# `test_dataset_path: null` is left as None, so the training function falls back to a split.
if conf.get("test_dataset_path") is not None:
    conf["test_dataset_path"] = os.path.join(repo_root, conf["test_dataset_path"])

model, best_state_dict = train_rnn_with_dynamic_language(lang=lang, **conf)

Is CUDA available: True


Map: 100%|██████████| 1273104/1273104 [00:36<00:00, 34882.81 examples/s]
Map: 100%|██████████| 238706/238706 [00:06<00:00, 34481.00 examples/s]
Epoch 1:  42%|████▏     | 1051/2487 [00:35<00:48, 29.40it/s]

In [10]:
# Save the final-epoch weights as "last" and the lowest-validation-loss weights as "best".
os.makedirs(output_dir, exist_ok=True)
model.save(os.path.join(output_dir, "last"))
if best_state_dict is None:
    # The training function returns None when no epoch produced a finite validation loss.
    print("No best checkpoint available (no finite validation loss); skipping 'best'.")
else:
    # NOTE: after this call the in-memory `model` holds the best weights, not the last ones.
    model.load_state_dict(best_state_dict)
    model.save(os.path.join(output_dir, "best"))

In [None]:
# save yaml and lang
import shutil
# Copy the training config next to the checkpoints so the run can be reproduced.
# shutil.copy fails if the destination directory is missing, so create it first.
os.makedirs(output_dir, exist_ok=True)
src = os.path.join(repo_root, config_path)
dst = os.path.join(output_dir, "setting.yaml")
shutil.copy(src, dst)

# For HELM languages, load the monomer libraries (with culling) before saving the language.
if lang.__class__.__name__ == "HELM":
    lib_dir = os.path.join(repo_root, "data", "helm", "library")
    lib_files = [
        "chembl_35_monomer_library.xml",
        "chembl_35_monomer_library_diff.xml",
        "HELMCoreLibrary.json",
        "monomerLib2.0.json",
    ]
    # os.path.join works whether or not repo_root ends with a separator.
    lang.load_monomer_library(*[os.path.join(lib_dir, name) for name in lib_files], culling=True)
lang.save(os.path.join(output_dir, "language.lang"))