In [1]:
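# Dependency: dlprog 1.2.8 (provides `train_progress`). To install it into the kernel's environment, uncomment:
# %pip install https://files.pythonhosted.org/packages/03/16/63d67d3222044e702d0172cb412ef4a102240795740dacd6044bcc05786a/dlprog-1.2.8.tar.gz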

In [2]:
from pathlib import Path
import json
from datetime import datetime, timedelta, timezone

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dlprog import train_progress

from data import PianoCoversDataset
from models import load_model, save_model
from train import loss_fn


# Progress reporter from dlprog; train() drives it via start/update/memo.
# NOTE(review): defer=True presumably postpones closing an epoch's output
# line until prog.memo() is called — confirm against the dlprog docs.
prog = train_progress(defer=True, width=20)

# Project configuration. The "pc" entry of its "default" section is the path
# train() hands to save_model() for its periodic in-epoch snapshots.
with open("models/config.json", "r") as f:
    CONFIG = json.load(f)
path_pc = CONFIG["default"]["pc"]

# Fixed UTC+9 zone, used for checkpoint directory names and log timestamps.
JST = timezone(timedelta(hours=9), "JST")
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Load the network onto `device`.
# NOTE(review): the meaning of amt=True / with_sv=True is defined in
# models.load_model, which is not visible here.
model = load_model(amt=True, with_sv=True, device=device)
# Set the float32 matmul precision before wrapping the model; torch.compile
# itself is lazy and only builds graphs on the first forward call.
torch.set_float32_matmul_precision("high")
model = torch.compile(model)

In [6]:
# Number of samples per mini-batch yielded by the DataLoader.
batch_size = 4
# Directory passed to PianoCoversDataset (relative to the working directory).
dir_dataset = Path("dataset/")
dataset = PianoCoversDataset(dir_dataset)
# Reshuffle the samples on every pass over the data.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [9]:
# Root directory under which train() creates one timestamped folder per run.
DIR_CHECKPOINTS = Path("models/params/checkpoints/")
# Name of the plain-text log file written inside each run folder.
FILE_NAME_LOG = "log.txt"

def train(model, dataloader, optimizer, scheduler, n_epochs=100):
    """Train `model` for `n_epochs` passes over `dataloader`, checkpointing as it goes.

    A run directory named after the current JST timestamp is created under
    DIR_CHECKPOINTS. It receives one `<epoch>.pth` file per finished epoch
    and a text log. In addition, on every 100th batch of an epoch (batch 0
    included) the model is saved to `path_pc` and the running metrics are
    appended to the log.

    Args:
        model: Network called as `model(spec, sv)`; its output is passed to
            `loss_fn` together with the targets.
        dataloader: Sized iterable of 6-tuples
            `(spec, sv, onset, offset, mpe, velocity)` whose items support
            `.to(device)`.
        optimizer: Optimizer that updates the model's parameters.
        scheduler: Scheduler stepped once per epoch with the loss value
            reported by `prog.now_values()`.
        n_epochs: Number of epochs to run.

    Returns:
        None.
    """
    date = datetime.now(JST).strftime("%Y%m%d%H%M%S")
    dir_checkpoint = DIR_CHECKPOINTS / date
    # parents=True: the checkpoints root may not exist yet on a fresh checkout.
    dir_checkpoint.mkdir(parents=True)
    file_log = dir_checkpoint / FILE_NAME_LOG

    model.train()
    # prog.update() is called once per batch, so an epoch consists of
    # len(dataloader) steps (not len(dataset) samples).
    prog.start(n_iter=len(dataloader), n_epochs=n_epochs, label=["loss", "f1"])
    for epoch in range(1, n_epochs + 1):
        for n, batch in enumerate(dataloader):
            # Move the input features and all four targets to the training device.
            spec, sv, onset, offset, mpe, velocity = (t.to(device) for t in batch)

            optimizer.zero_grad()
            out = model(spec, sv)
            loss, f1 = loss_fn(out, (onset, offset, mpe, velocity))
            loss.backward()
            optimizer.step()

            prog.update([loss.item(), f1])
            if n % 100 == 0:
                # Periodic snapshot so an interrupted epoch is not lost entirely.
                # NOTE(review): `model` may be a torch.compile wrapper here;
                # confirm save_model stores weights that load_model can read back.
                save_model(model, path_pc)
                with open(file_log, "a") as f:
                    f.write(f"{n}, {prog.now_values()}\n")

        # Metrics as currently aggregated by the progress tracker for this epoch.
        loss, f1 = prog.now_values()
        scheduler.step(loss)
        path_pc_epoch = dir_checkpoint / f"{epoch}.pth"
        save_model(model, path_pc_epoch)
        with open(file_log, "a") as f:
            timestamp = datetime.now(JST).strftime("%Y/%m/%d %H:%M")
            f.write(f"{timestamp}, epoch {epoch} finished, loss: {loss}, f1: {f1}\n")

        # NOTE(review): memo() is called without text — confirm dlprog accepts
        # that and uses it to finalize the deferred epoch line.
        prog.memo()

In [11]:
# Adam over all model parameters with an initial learning rate of 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# Plateau-based learning-rate scheduler with the library's default settings;
# train() steps it once per epoch with that epoch's loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)

In [None]:
# Start the training run; checkpoints and the log go under DIR_CHECKPOINTS.
train(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=100,
)