In [1]:
# !pip install https://files.pythonhosted.org/packages/03/16/63d67d3222044e702d0172cb412ef4a102240795740dacd6044bcc05786a/dlprog-1.2.8.tar.gz

In [2]:
from pathlib import Path
import json
from datetime import datetime, timedelta, timezone

import torch
import torch.nn as nn
import torch.optim as optim
from dlprog import train_progress

from models import load_model, save_model
from data import SyncedPianos

In [3]:
# Device placement: `pc` is loaded on device_pc and `amt` on device_amt (see the
# load_model calls). With more than one visible GPU the two models are split
# across cuda:0 and cuda:1; otherwise both share cuda:0.
device_pc = torch.device("cuda:0")
if torch.cuda.device_count() > 1:
    device_amt = torch.device("cuda:1")
else:
    device_amt = device_pc

# Progress reporter used by train().
prog = train_progress(width=20, defer=True)

# Project configuration: number of input frames and the path `pc` is saved to.
CONFIG = json.loads(Path("models/config.json").read_text())
n_frames = CONFIG["data"]["input"]["num_frame"]
path_pc = CONFIG["default"]["pc"]

# Fixed UTC+9 timezone used for checkpoint directory names and log timestamps.
JST = timezone(timedelta(hours=9), "JST")

device_pc, device_amt

(device(type='cuda', index=0), device(type='cuda', index=1))

In [4]:
# Both models come from load_model with amt=True; only `pc` is built with
# with_sv=True. Each model is placed on its own device.
amt, pc = (
    load_model(amt=True, with_sv=False, device=device_amt),
    load_model(amt=True, with_sv=True, device=device_pc),
)

# Wrap both models with torch.compile.
# NOTE(review): train() later hands the compiled `pc` to save_model; confirm
# save_model copes with the compiled wrapper (e.g. state_dict key prefixes).
amt, pc = torch.compile(amt), torch.compile(pc)

In [5]:
# Batch size forwarded to SyncedPianos when the dataset is constructed.
batch_size = 16

In [6]:
# Dataset root handed to SyncedPianos (presumably pre-computed spectrograms,
# going by the directory name; verify against data.SyncedPianos).
dir_specs = Path("dataset/spec/")

# train() iterates this object once per epoch; each element it yields is itself
# iterated for (orig, piano, sv) batches.
synced_pianos = SyncedPianos(
    dir_specs,
    n_frames=n_frames,
    batch_size=batch_size,
)

In [12]:
def select(label, thr, random_prob=0.3):
    """Build a boolean mask of the positions that should contribute to a loss.

    A position is selected when any of the following holds:

    * ``label > thr`` at that position,
    * ``label > thr`` at one of its two neighbours along the last dimension
      (``torch.roll`` wraps around, so the first and last positions count as
      neighbours of each other),
    * an independent uniform draw falls below ``random_prob``.

    Args:
        label: Tensor of label values (float or integer).
        thr: Threshold above which a label counts as positive.
        random_prob: Probability with which any position is additionally
            selected at random. ``0`` disables the random part, ``1`` selects
            everything.

    Returns:
        Bool tensor with the same shape and device as ``label``.
    """
    idx = label > thr
    shifted_p = torch.roll(idx, 1, -1)
    shifted_n = torch.roll(idx, -1, -1)
    # Draw the noise directly on the label's device: no host-to-device copy and
    # no dependency on a module-level device, so CPU tensors work as well.
    random = torch.rand(idx.shape, device=idx.device) < random_prob
    return idx | shifted_p | shifted_n | random


# Shared criteria used by loss_fn. Both average over the elements they receive
# ("mean" is the PyTorch default, spelled out here for clarity).
BCE_LOSS = nn.BCELoss(reduction="mean")  # expects probabilities, not logits
CE_LOSS = nn.CrossEntropyLoss(reduction="mean")  # expects class scores + class indices

def _select_binary_targets(pred_f, pred_t, label, thr):
    """Pick the positions to score for one binary head and build their 0/1 targets."""
    idx = select(label, thr)
    # Threshold the raw label values exactly once, so the result stays correct
    # for any threshold value rather than only those inside [0, 1).
    target = (label[idx] > thr).float()
    return pred_f[idx], pred_t[idx], target


def loss_fn(pred, label, thr_onset=0.5, thr_offset=0.5, thr_mpe=0.5):
    """Sum of the onset/offset/mpe/velocity losses for both prediction heads.

    Args:
        pred: Sequence of 9 tensors: the onset, offset, mpe and velocity
            outputs of the ``_f`` head, one ignored entry, then the same four
            outputs of the ``_t`` head.
        label: Sequence of 9 tensors with the same layout as ``pred``; only
            the last four (onset, offset, mpe, velocity) are used as targets.
        thr_onset: Threshold that turns the onset label into a 0/1 target.
        thr_offset: Threshold that turns the offset label into a 0/1 target.
        thr_mpe: Threshold that turns the mpe label into a 0/1 target.

    Returns:
        Scalar tensor: BCE on onset, offset and mpe plus cross-entropy on
        velocity, for the ``_f`` and the ``_t`` predictions.

    Onset, offset and velocity are scored only on the positions chosen by
    ``select`` (positives, their neighbours and a random subset); mpe is scored
    on every position.
    """
    # unpack (the fifth entry of each sequence is not used by the loss)
    onset_f_pred, offset_f_pred, mpe_f_pred, velocity_f_pred, _, \
    onset_t_pred, offset_t_pred, mpe_t_pred, velocity_t_pred = pred

    _, _, _, _, _, \
    onset_label, offset_label, mpe_label, velocity_label = label

    # select: the order onset -> offset -> velocity matters because each
    # select() call draws from the random number generator
    onset_f_pred, onset_t_pred, onset_label = _select_binary_targets(
        onset_f_pred, onset_t_pred, onset_label, thr_onset
    )
    offset_f_pred, offset_t_pred, offset_label = _select_binary_targets(
        offset_f_pred, offset_t_pred, offset_label, thr_offset
    )

    mpe_label = (mpe_label > thr_mpe).float()

    # Velocity targets are class indices; class 0 counts as "no note" for the
    # selection, and only 1% of the remaining positions are sampled.
    velocity_label = velocity_label.argmax(dim=-1)
    velocity_idx = select(velocity_label, 0, 0.01)
    velocity_f_pred = velocity_f_pred[velocity_idx]
    velocity_t_pred = velocity_t_pred[velocity_idx]
    velocity_label = velocity_label[velocity_idx]

    # loss
    loss_onset_f = BCE_LOSS(onset_f_pred, onset_label)
    loss_offset_f = BCE_LOSS(offset_f_pred, offset_label)
    loss_mpe_f = BCE_LOSS(mpe_f_pred, mpe_label)
    loss_velocity_f = CE_LOSS(velocity_f_pred, velocity_label)

    loss_onset_t = BCE_LOSS(onset_t_pred, onset_label)
    loss_offset_t = BCE_LOSS(offset_t_pred, offset_label)
    loss_mpe_t = BCE_LOSS(mpe_t_pred, mpe_label)
    loss_velocity_t = CE_LOSS(velocity_t_pred, velocity_label)

    loss = \
        loss_onset_f + loss_offset_f + loss_mpe_f + loss_velocity_f + \
        loss_onset_t + loss_offset_t + loss_mpe_t + loss_velocity_t

    return loss

In [13]:
# Root directory under which train() creates one timestamped sub-directory per run.
DIR_CHECKPOINTS = Path("models/params/checkpoints/")
# Name of the text log written inside each run's checkpoint directory.
FILE_NAME_LOG = "log.txt"

def train(pc, amt, optimizer, scheduler, n_epochs=100):
    """Train ``pc`` so that its outputs match what ``amt`` produces for the paired input.

    For every batch ``(orig, piano, sv)``, ``pc`` is run on ``orig`` (with
    ``sv``) and ``amt`` is run on ``piano`` under ``torch.no_grad()``; the
    outputs of ``amt`` serve as labels for ``loss_fn``.

    Args:
        pc: Model being optimised; put into train mode.
        amt: Model providing the labels; put into eval mode, never updated here.
        optimizer: Optimizer stepped once per batch.
        scheduler: Scheduler stepped once per epoch with ``prog.now_values()``.
        n_epochs: Number of passes over ``synced_pianos``.

    Side effects:
        * Creates ``DIR_CHECKPOINTS/<YYYYmmddHHMMSS>/`` (timestamp in JST).
        * Saves ``pc`` to ``path_pc`` after every song and to
          ``<run dir>/<epoch>.pth`` after every epoch.
        * Appends progress lines to ``<run dir>/log.txt``.

    Relies on the module-level ``synced_pianos``, ``prog``, ``device_pc``,
    ``device_amt``, ``path_pc`` and ``JST``.
    """
    pc.train()
    amt.eval()

    date = datetime.now(JST).strftime("%Y%m%d%H%M%S")
    dir_checkpoint = DIR_CHECKPOINTS / date
    # parents=True: on a fresh checkout the checkpoint root may not exist yet.
    # exist_ok stays False so that a timestamp collision fails loudly instead of
    # mixing two runs in one directory.
    dir_checkpoint.mkdir(parents=True)
    file_log = dir_checkpoint / FILE_NAME_LOG

    prog.start(n_iter=len(synced_pianos), n_epochs=n_epochs)
    for epoch in range(1, n_epochs+1):
        for n, sync_piano in enumerate(synced_pianos, 1):
            for orig, piano, sv in sync_piano:
                # Each model gets its input on its own device; `sv` is passed
                # through as-is (not moved to a device here).
                orig = orig.to(device_pc)
                piano = piano.to(device_amt)
                optimizer.zero_grad()

                out_pc = pc(orig, sv=sv)
                with torch.no_grad():
                    out_amt = amt(piano)
                # Bring the labels to the device the loss is computed on.
                out_amt = [x.to(device_pc) for x in out_amt]

                loss = loss_fn(out_pc, out_amt)
                loss.backward()
                optimizer.step()
                # advance=0: record the loss without moving the bar; the bar
                # advances once per song below.
                prog.update(loss.item(), advance=0)

            prog.update(note=f"song: {n}/{len(synced_pianos)}")
            # Overwrite the "latest" weights after every song so an interrupted
            # run loses at most one song of progress.
            save_model(pc, path_pc)
            with open(file_log, "a") as f:
                f.write(f"{n}, {prog.now_values()}\n")

        scheduler.step(prog.now_values())
        path_pc_epoch = dir_checkpoint / f"{epoch}.pth"
        save_model(pc, path_pc_epoch)
        with open(file_log, "a") as f:
            # Named `timestamp` rather than `time` to avoid shadowing the
            # standard-library module name.
            timestamp = datetime.now(JST).strftime("%Y/%m/%d %H:%M")
            f.write(f"{timestamp}, epoch {epoch} finished, loss: {prog.now_values()}\n")

        prog.memo()

In [14]:
# Switch float32 matrix multiplications from PyTorch's default "highest"
# precision to "high"; per the PyTorch documentation this allows faster
# reduced-precision kernels (e.g. TF32) where the hardware provides them.
torch.set_float32_matmul_precision("high")

In [15]:
# Adam over the parameters of `pc` only; `amt` is not handed to the optimizer.
optimizer = optim.Adam(params=pc.parameters(), lr=1e-4)
# Plateau-based LR scheduler with library defaults; train() steps it once per
# epoch with prog.now_values().
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)

In [16]:
# Start training with train()'s default n_epochs=100. The call blocks until all
# epochs finish or the kernel is interrupted.
train(pc, amt, optimizer, scheduler)

  1/100: ##                    12% [00:09:03.45] loss: 3.03792, song: 17/139 

KeyboardInterrupt: 