In [None]:
import pandas as pd, numpy as np, torch, torch.nn as nn, torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from utils.exft import exFTModel
from utils.wmdl import Composer
from utils.nntool import UniDataset, SimpTrainer
from sklearn.model_selection import train_test_split
from pickle import dump

In [None]:
# Load the fastText model used to embed every word below. The file name
# suggests 128-dimensional vectors (cczh128). TODO(review): confirm that this
# matches the input size passed to Composer.
FT_MODEL_PATH = r"./ft_model/cczh128.bin"
ft_ = exFTModel(FT_MODEL_PATH)

In [None]:
# TODO: data cleaning / processing goes here later.
#
# Load the raw word/compound table. The feature cell needs:
#   - "subs":  component words joined with "+"
#   - "compd": the compound word those components should compose into
# FIXME: the original `pd.read_csv()` call had no path and always raised a
# TypeError. DATA_PATH is a placeholder; point it at the real CSV.
DATA_PATH = r"./data/compounds.csv"

df: pd.DataFrame = pd.read_csv(DATA_PATH)    # *

# Fail early with a readable message instead of a KeyError inside the feature loop.
_missing_cols = {"subs", "compd"} - set(df.columns)
assert not _missing_cols, f"{DATA_PATH} is missing required columns: {sorted(_missing_cols)}"

In [None]:
# Build (input, target) pairs from `df`:
#   Xdata[k] - tuple of torch tensors, one fastText vector per "+"-separated
#              component in row k's "subs" (tuples can differ in length per row)
#   ydata    - tensor stacking the fastText vector of each row's "compd" word
# `ft_.get_word_vec` is assumed to return a numpy array (it is passed to
# torch.from_numpy). torch.from_numpy shares memory with that array.
Xdata, ydata_rows = [], []

_subs: str
for _subs, compd_word in zip(df["subs"], df["compd"]):   # *
    subs = tuple(torch.from_numpy(ft_.get_word_vec(part)) for part in _subs.split("+"))
    Xdata.append(subs)
    ydata_rows.append(torch.from_numpy(ft_.get_word_vec(compd_word)))

# torch.stack raises an opaque error on an empty list, so check explicitly.
assert ydata_rows, "df has no rows - nothing to build training pairs from"
ydata = torch.stack(ydata_rows)

# Print a summary instead of the whole dataset; full dumps flood the notebook output.
print(f"{len(Xdata)} samples; target tensor shape: {tuple(ydata.shape)}")

In [None]:
# Hold out 20% of the samples for evaluation. The fixed seed makes the split
# reproducible across Restart & Run All. Seeding torch here also fixes
# torch-driven randomness in later cells (presumably Composer's weight init
# and any DataLoader shuffling).
SEED = 42
torch.manual_seed(SEED)

Xtr, Xts, ytr, yts = train_test_split(Xdata, ydata, test_size = 0.2, random_state = SEED)

In [None]:
# Model. TODO(review): confirm the meaning of Composer's arguments in
# utils.wmdl. The 128 presumably matches the 128-d fastText vectors (cczh128);
# 32 is presumably a hidden/projection size.
mdl = Composer(
    128,
    32,
)

# Identity collate: each batch is the plain list of samples produced by
# UniDataset. The per-row component tuples vary in length ("+"-split), which
# is presumably why the default stacking collate is not used.
train_loader = (
    UniDataset(Xtr, ytr)
    .to_dataloader(batch_size = 10, _collate_fn = lambda x: x)
)

# TODO(review): open question from the author - would a cosine-based loss
# suit word-vector targets better than MSE here?
loss_fn = nn.MSELoss()

opt = optim.AdamW(
    mdl.parameters(),
    lr = 3e-4,
    weight_decay = 1E-2,
)

trainer = SimpTrainer(mdl, train_loader, loss_fn, opt)

# Multiply the learning rate by 0.63 every 5 calls to sched.step().
sched = StepLR(opt, step_size = 5, gamma = 0.63)

In [None]:
# Train for N_EPOCHS epochs and step the LR scheduler once per epoch, after
# that epoch's optimizer updates. Every EVAL_EVERY-th epoch, evaluate on the
# held-out split and print both losses.
# NOTE(review): confirm that trainer.eval_ puts the model back in train mode
# if it switches to eval mode.
N_EPOCHS = 50
EVAL_EVERY = 10

for i in range(N_EPOCHS):
    ls_tr = trainer.train_epoch(verbose = f"epoch - {i}")
    sched.step()

    if i % EVAL_EVERY == EVAL_EVERY - 1:
        ls_ts = trainer.eval_(Xts, yts)
        # The original printed the literal text "checkpoint-()"; show the epoch index instead.
        print(f"checkpoint-{i}  tr:{ls_tr: .4f}; ts:{ls_ts: .4f}")

In [None]:
# Persist the trained model by pickling the whole module object.
# - Loading it requires the same Composer class to be importable.
#   torch.save(mdl.state_dict(), ...) would be more portable; pickle is kept
#   for compatibility with existing loaders.
# - Only unpickle files you trust: pickle.load can execute arbitrary code.
import os

# open(..., "wb") fails with FileNotFoundError if the directory does not exist yet.
os.makedirs("./models", exist_ok = True)
with open(r"./models/w0.bin", "wb") as f:
    dump(mdl, f)