In [2]:
from pathlib import Path
from datetime import datetime
import random
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, unpad_sequence, pack_sequence, pad_packed_sequence
from torchmetrics.classification import BinaryF1Score
from tqdm.auto import tqdm
import wandb
from utils import find_file_by_stem

# Print tensors with plain decimals instead of scientific notation.
torch.set_printoptions(sci_mode=False)
# Seed passed to the torch.Generator instances that drive the train/valid split
# and the training DataLoader shuffling in the training cells.
RANDOM_SEED = 1
# Device string used for models and batches: GPU when available, else CPU.
DEV = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class OsuDataset:
    """
    Beatmap + audio dataset.

    Every ``*.pt`` file under ``beatmap_path`` is one sample. Its audio file is
    looked up in ``audio_path`` by the part of the beatmap stem before the
    first ``-``.

    GETITEM: specs, beat_phase, beat_num, difficulty, onsets, actions

    Each item is a random crop of ``num_frames`` frames (the original code
    describes the default 3000 as "30s"). ``actions`` is one element longer
    than the other sequences (``num_frames + 1``): it also contains the frame
    just before the crop, so a consumer can build a shifted input
    (``actions[:-1]``) and a target (``actions[1:]``) of ``num_frames`` each.

    Args:
        beatmap_path: directory (``pathlib.Path``) holding beatmap ``*.pt`` files.
        audio_path: directory (``pathlib.Path``) holding audio ``*.pt`` files.
        num_frames: crop length in frames. Defaults to 3000.
    """

    def __init__(self, beatmap_path, audio_path, num_frames=3000):
        # Sorted so that indices are stable across runs / file-system orders.
        self.beatmap_fns = sorted(list(beatmap_path.glob('*.pt')))
        self.audio_fns = list(audio_path.glob('*.pt'))
        self.num_frames = num_frames

    def __len__(self):
        """Number of beatmaps (one sample per beatmap file)."""
        return len(self.beatmap_fns)

    def __getitem__(self, idx):
        """
        Load beatmap ``idx`` with its audio features and return a random crop.

        Raises:
            FileNotFoundError: no audio file matches the beatmap.
            IndexError: the clip has fewer than ``num_frames + 1`` frames.
        """
        beatmap_fn = self.beatmap_fns[idx]
        audio_fn = find_file_by_stem(
            self.audio_fns, beatmap_fn.stem.split('-')[0])
        # find_file_by_stem signals "not found" with -1.
        if audio_fn == -1:
            raise FileNotFoundError(
                f'No audio file found for beatmap: {beatmap_fn.name}')
        # NOTE(review): unpacking relies on the insertion order of the saved
        # dicts (third beatmap entry is ignored) -- confirm against the
        # preprocessing script.
        actions, onsets, _, difficulty = torch.load(beatmap_fn).values()
        specs, beat_phase, beat_num = torch.load(audio_fn).values()

        # One extra leading frame is needed for the shifted `actions`.
        window = self.num_frames + 1
        if specs.shape[1] < window:
            raise IndexError(
                f'Beatmap shorter than {self.num_frames} frames: {beatmap_fn.name}')

        # Randomly slice the data to `num_frames` frames. A clip of exactly
        # `window` frames is valid (start == 0); the original `>` rejected it.
        start = random.randint(0, specs.shape[1] - window)
        stop = start + window
        actions = actions[start:stop]
        onsets = onsets[start + 1:stop]
        specs = specs[:, start + 1:stop, :]
        beat_phase = beat_phase[start + 1:stop]
        beat_num = beat_num[start + 1:stop]
        # Broadcast the per-beatmap difficulty to one value per frame.
        difficulty = torch.tensor(difficulty).expand(self.num_frames)

        return specs, beat_phase, beat_num, difficulty, onsets, actions

In [3]:
class PadCollater:
    """
    Collate function for DataLoader: stacks each field of the samples.

    Despite the name, nothing is padded; every sample must already have
    matching shapes per field so that ``torch.stack`` succeeds.
    """

    def __call__(self, batch):
        """
        Args:
            batch: sequence of samples, each indexable as
                (specs, beat_phase, beat_num, difficulty, onsets, actions).

        Returns:
            Tuple of six tensors in the same field order, each with a new
            leading batch dimension.
        """
        field_count = 6
        return tuple(
            torch.stack([sample[field] for sample in batch])
            for field in range(field_count)
        )

## Baseline Model

In [5]:
class OsuModel(nn.Module):
    """
    Baseline two-stage model for variable-length sequences.

    Stage 1 ("note placement", NP): a bidirectional GRU over per-frame audio
    features plus beat-phase / beat-number / difficulty embeddings predicts a
    per-frame probability.
    Stage 2 ("note selection", NS): a unidirectional GRU over the same inputs
    plus a projection of the NP hidden features predicts per-frame logits over
    ``num_tokens`` classes.

    Args:
        device: kept as ``self.device`` for backward compatibility; the
            forward pass now keeps tensors on whatever device the inputs and
            parameters live on.
        bp_emb_dim, bn_emb_dim, diff_emb_dim: embedding sizes for beat phase
            (49 ids), beat number (4 ids) and difficulty (21 ids).
        np_hidden_size, np_num_layers: NP GRU size / depth.
        ns_pre_proj_size: size of the NP-feature projection fed to the NS GRU.
        ns_hidden_size, ns_num_layers: NS GRU size / depth.
        num_tokens: number of NS output classes.
    """

    def __init__(self, device, bp_emb_dim=16, bn_emb_dim=8, diff_emb_dim=8,
                 np_hidden_size=256, np_num_layers=2, ns_pre_proj_size=32,
                 ns_hidden_size=256, ns_num_layers=2, num_tokens=256):
        super().__init__()
        self.device = device
        self.gelu = nn.GELU()
        self.sigmoid = nn.Sigmoid()
        # Collapses the 3 input channels into one feature map; 'same' padding
        # keeps both spatial axes unchanged.
        self.conv = nn.Conv2d(in_channels=3,
                              out_channels=1,
                              kernel_size=(15, 5),
                              padding='same')
        self.beat_phase_emb = nn.Embedding(49, bp_emb_dim)
        self.beat_num_emb = nn.Embedding(4, bn_emb_dim)
        self.difficulty_emb = nn.Embedding(21, diff_emb_dim)

        # 80 = size of the last spectrogram axis expected by this model.
        self.np_gru = nn.GRU(input_size=80 + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=np_hidden_size,
                             num_layers=np_num_layers,
                             batch_first=True,
                             bidirectional=True)
        self.np_proj_1 = nn.Linear(np_hidden_size*2, 128)
        self.np_proj_2 = nn.Linear(128, 1)

        self.ns_pre_proj = nn.Linear(128, ns_pre_proj_size)
        self.ns_gru = nn.GRU(input_size=80 + ns_pre_proj_size + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=ns_hidden_size,
                             num_layers=ns_num_layers,
                             batch_first=True,
                             bidirectional=False)
        self.ns_proj_1 = nn.Linear(ns_hidden_size, ns_hidden_size)
        self.ns_proj_2 = nn.Linear(ns_hidden_size, num_tokens)

    def forward(self, specs, beat_phases, beat_nums, difficulties, lengths):
        """
        Args:
            specs: iterable of per-sample spectrograms, each (3, T_i, 80).
            beat_phases: (B, T_max) integer ids, padded along time.
            beat_nums: (B, T_max) integer ids, padded along time.
            difficulties: (B,) integer ids, one per sample.
            lengths: (B,) tensor with the true length T_i of each sample.

        Returns:
            np_pred: (B, T_max) probabilities (zero-padded GRU output past T_i
                still flows through the projections).
            ns_logit: (B, T_max, num_tokens) unnormalised scores.
        """
        seq_lens = [int(n) for n in lengths.tolist()]
        max_len = max(seq_lens)

        conv_outs = []
        for spec in specs:
            feat = self.gelu(self.conv(spec))
            # The conv has one output channel: keep only (T_i, F). Unlike a
            # bare squeeze(), this cannot collapse a length-1 time axis.
            conv_outs.append(feat.reshape(feat.shape[-2], feat.shape[-1]))

        bp_all = self.beat_phase_emb(beat_phases)
        bn_all = self.beat_num_emb(beat_nums)
        diff = difficulties.unsqueeze(1).expand(-1, max_len)
        diff_all = self.difficulty_emb(diff)

        # Trim padding by slicing; equivalent to unpad_sequence but without
        # moving every embedding to the CPU and back.
        bp_emb = [bp_all[i, :n] for i, n in enumerate(seq_lens)]
        bn_emb = [bn_all[i, :n] for i, n in enumerate(seq_lens)]
        diff_emb = [diff_all[i, :n] for i, n in enumerate(seq_lens)]

        # ========== Note Placement ========== #

        np_in = [torch.cat([conv_outs[i], bp_emb[i], bn_emb[i], diff_emb[i]],
                           dim=-1)
                 for i in range(len(seq_lens))]
        np_in_packed = pack_sequence(np_in, enforce_sorted=False)
        np_out, _ = self.np_gru(np_in_packed)
        np_out_padded, _ = pad_packed_sequence(np_out, batch_first=True)

        np_proj_1_out = self.gelu(self.np_proj_1(np_out_padded))
        # squeeze(-1) drops only the size-1 feature axis, so a batch of one
        # sample keeps its batch dimension.
        np_pred = self.sigmoid(self.np_proj_2(np_proj_1_out)).squeeze(-1)

        # ========== Note Selection ========== #

        ns_pre_proj_padded = self.ns_pre_proj(np_proj_1_out)
        ns_in = [torch.cat([conv_outs[i],
                            ns_pre_proj_padded[i, :n],
                            bp_emb[i],
                            bn_emb[i],
                            diff_emb[i]],
                           dim=-1)
                 for i, n in enumerate(seq_lens)]
        ns_in_packed = pack_sequence(ns_in, enforce_sorted=False)
        ns_out, _ = self.ns_gru(ns_in_packed)
        ns_out_padded, _ = pad_packed_sequence(ns_out, batch_first=True)

        ns_proj_1_out = self.gelu(self.ns_proj_1(ns_out_padded))
        ns_logit = self.ns_proj_2(ns_proj_1_out)

        return np_pred, ns_logit

## ConvStack Model

In [4]:
class StackModel(nn.Module):
    """
    Two-stage model with a strided conv stack as the audio front end.

    Stage 1 ("note placement", NP): bidirectional GRU -> per-frame probability.
    Stage 2 ("note selection", NS): unidirectional GRU that additionally
    consumes the embedding of the previous action token -> per-frame logits
    over ``num_tokens`` classes.

    Args:
        num_stacks, conv_hidden: accepted for interface compatibility but not
            used; the conv stack below is fixed (4 layers, 8/16/32/64 channels).
        bp_emb_dim, bn_emb_dim: embedding sizes for beat phase (49 ids) and
            beat number (4 ids).
        diff_emb_dim: output size of the difficulty MLP (scalar -> vector).
        np_hidden_size, np_num_layers: NP GRU size / depth.
        ns_pre_proj_size: size of the NP-feature projection fed to the NS GRU.
        ns_hidden_size, ns_num_layers: NS GRU size / depth; also stored as
            attributes so step-wise decoding can build an initial hidden state.
        num_tokens: number of action classes.
        action_emb_dim: embedding size of the previous-action token.
    """

    def __init__(self, num_stacks=7, conv_hidden=16,
                 bp_emb_dim=16, bn_emb_dim=8, diff_emb_dim=8,
                 np_hidden_size=256, np_num_layers=2, ns_pre_proj_size=32,
                 ns_hidden_size=256, ns_num_layers=2,
                 num_tokens=256, action_emb_dim=32):
        super().__init__()
        # Exposed because autoregressive inference reads them from the model.
        self.ns_hidden_size = ns_hidden_size
        self.ns_num_layers = ns_num_layers

        self.gelu = nn.GELU()
        self.sigmoid = nn.Sigmoid()
        # Each layer keeps the first spatial axis (stride 1, padding 2 for
        # kernel 5) and roughly halves the second (stride 2).
        self.stack = nn.Sequential(
            nn.Conv2d(3, 8, (5, 3), stride=(1, 2), padding=(2, 1)),
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 16, (5, 3), stride=(1, 2), padding=(2, 1)),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, (5, 3), stride=(1, 2), padding=(2, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, (5, 3), stride=(1, 2), padding=(2, 1)),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.beat_phase_emb = nn.Embedding(49, bp_emb_dim)
        self.beat_num_emb = nn.Embedding(4, bn_emb_dim)
        self.difficulty_proj = nn.Sequential(
            nn.Linear(1, diff_emb_dim),
            nn.ReLU(),
            nn.Linear(diff_emb_dim, diff_emb_dim),
            nn.ReLU()
        )
        self.action_emb = nn.Embedding(num_tokens, action_emb_dim)

        # 320 = 64 channels x 5 remaining bins, i.e. this assumes the last
        # spectrogram axis has 80 bins (80 -> 40 -> 20 -> 10 -> 5).
        self.np_gru = nn.GRU(input_size=320 + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=np_hidden_size,
                             num_layers=np_num_layers,
                             batch_first=True,
                             bidirectional=True)
        self.np_proj_1 = nn.Linear(np_hidden_size*2, 128)
        self.np_proj_2 = nn.Linear(128, 1)

        self.ns_pre_proj = nn.Linear(128, ns_pre_proj_size)
        self.ns_gru = nn.GRU(input_size=320 + ns_pre_proj_size + bp_emb_dim + bn_emb_dim + diff_emb_dim + action_emb_dim,
                             hidden_size=ns_hidden_size,
                             num_layers=ns_num_layers,
                             batch_first=True,
                             bidirectional=False)
        self.ns_proj_1 = nn.Linear(ns_hidden_size, ns_hidden_size)
        self.ns_proj_2 = nn.Linear(ns_hidden_size, num_tokens)

    def forward(self, specs, beat_phases, beat_nums, difficulties, actions):
        """
        Args:
            specs: (B, 3, T, F) spectrograms.
            beat_phases: (B, T) integer ids.
            beat_nums: (B, T) integer ids.
            difficulties: per-frame difficulty, either (B, T) or (B, T, 1).
            actions: (B, T) integer ids of the previous action at each frame.

        Returns:
            np_pred: (B, T) probabilities.
            ns_logit: (B, T, num_tokens) unnormalised scores.
            ns_last_hidden: final hidden state of the NS GRU.
        """
        conv_outs = self.stack(specs)
        # (B, C, T, F') -> (B, T, C * F')
        conv_outs = conv_outs.permute(0, 2, 1, 3).reshape(conv_outs.shape[0], conv_outs.shape[2], -1)
        bp_emb = self.beat_phase_emb(beat_phases)
        bn_emb = self.beat_num_emb(beat_nums)

        # difficulty_proj starts with Linear(1, ...), so the scalar must live
        # on a trailing size-1 axis; accept the (B, T) layout as well.
        if difficulties.dim() == 2:
            difficulties = difficulties.unsqueeze(-1)
        difficulties = difficulties.to(conv_outs.dtype)
        diff_proj = self.difficulty_proj(difficulties)

        # ========== Note Placement ========== #
        np_in = torch.cat([conv_outs, bp_emb, bn_emb, diff_proj], dim=-1)
        np_out, _ = self.np_gru(np_in)

        np_proj_1_out = self.gelu(self.np_proj_1(np_out))
        # squeeze(-1) keeps the batch axis even when B == 1.
        np_pred = self.sigmoid(self.np_proj_2(np_proj_1_out)).squeeze(-1)

        # ========== Note Selection ========== #
        ns_pre_proj = self.gelu(self.ns_pre_proj(np_proj_1_out))
        action_emb = self.action_emb(actions)
        ns_in = torch.cat(
            [conv_outs, ns_pre_proj, bp_emb, bn_emb, diff_proj, action_emb], dim=-1)
        ns_out, ns_last_hidden = self.ns_gru(ns_in)

        ns_proj_1_out = self.gelu(self.ns_proj_1(ns_out))
        ns_logit = self.ns_proj_2(ns_proj_1_out)

        return np_pred, ns_logit, ns_last_hidden

In [None]:
def infer(model, specs, beat_phases, beat_nums, difficulties, device):
    """
    Greedy autoregressive decoding of action tokens.

    The note-placement branch runs once over the whole sequence. The
    note-selection GRU is then stepped frame by frame, feeding back the
    arg-max token of the previous frame (token 0 is used for the first frame).

    Args:
        model: a model exposing ``stack``, ``beat_phase_emb``, ``beat_num_emb``,
            ``difficulty_proj``, ``np_gru``, ``np_proj_1``, ``ns_pre_proj``,
            ``action_emb``, ``ns_gru``, ``ns_proj_1``, ``ns_proj_2`` and
            ``gelu``. It is switched to eval mode but NOT moved to ``device``;
            it must already live there.
        specs: (B, 3, T, F) spectrograms.
        beat_phases: (B, T) integer ids.
        beat_nums: (B, T) integer ids.
        difficulties: per-frame difficulty, (B, T) or (B, T, 1).
        device: device the inputs are moved to.

    Returns:
        (B, T) int64 tensor of predicted action tokens.
    """
    specs = specs.to(device)
    beat_phases = beat_phases.to(device)
    beat_nums = beat_nums.to(device)
    difficulties = difficulties.to(device)

    model.eval()
    with torch.inference_mode():
        conv_outs = model.stack(specs)
        # (B, C, T, F') -> (B, T, C * F')
        conv_outs = conv_outs.permute(0, 2, 1, 3).flatten(2, 3)
        bp_emb = model.beat_phase_emb(beat_phases)
        bn_emb = model.beat_num_emb(beat_nums)
        # The difficulty MLP expects the scalar on a trailing size-1 axis.
        if difficulties.dim() == 2:
            difficulties = difficulties.unsqueeze(-1)
        diff_proj = model.difficulty_proj(difficulties.to(conv_outs.dtype))

        np_in = torch.cat([conv_outs, bp_emb, bn_emb, diff_proj], dim=-1)
        np_out, _ = model.np_gru(np_in)
        np_proj_1_out = model.gelu(model.np_proj_1(np_out))
        ns_pre_proj = model.gelu(model.ns_pre_proj(np_proj_1_out))

        batch_size, num_frames = specs.shape[0], specs.shape[2]
        # Token ids are integers, so the output buffer is int64.
        out = torch.zeros([batch_size, num_frames],
                          device=device, dtype=torch.long)
        prev_action = torch.zeros([batch_size, 1],
                                  device=device, dtype=torch.long)
        # Read the GRU geometry from the module itself rather than from
        # attributes the model may not define.
        last_hidden = torch.zeros(
            [model.ns_gru.num_layers, batch_size, model.ns_gru.hidden_size],
            device=device, dtype=conv_outs.dtype)

        # Everything except the previous-action embedding is known up front.
        ns_static = torch.cat(
            [conv_outs, ns_pre_proj, bp_emb, bn_emb, diff_proj], dim=-1)
        for i in range(num_frames):
            # N x 1 x C: concatenate along the feature axis.
            step_in = torch.cat(
                [ns_static[:, i:i+1], model.action_emb(prev_action)], dim=-1)
            ns_out, last_hidden = model.ns_gru(step_in, last_hidden)
            ns_proj_1_out = model.gelu(model.ns_proj_1(ns_out))
            ns_logit = model.ns_proj_2(ns_proj_1_out)
            prev_action = ns_logit.argmax(dim=-1)  # N x 1
            out[:, i] = prev_action.squeeze(1)

    return out

In [5]:
class Trainer():
    """
    Training / validation loop with focal losses, wandb logging and
    per-epoch checkpoints.

    The model is called as
    ``model(specs, beat_phases, beat_nums, difficulties, actions_shifted)``
    and must return ``(np_pred, ns_logit, <ignored>)``.

    Args:
        model: the network to train.
        optimizer: optimizer bound to ``model``'s parameters.
        train_loader, valid_loader: iterables yielding batches of
            (specs, beat_phases, beat_nums, difficulties, onsets, actions).
        device: device used for the model and every batch.
        checkpoint_path: directory for checkpoints (created if missing,
            including parent directories).
        np_fl_gamma, np_fl_weight: focal-loss parameters of the binary
            (note placement) loss.
        ns_fl_gamma, ns_fl_weight: focal-loss parameters of the multi-class
            (note selection) loss.
        np_loss_multiplier: factor applied to the binary loss before it is
            added to the multi-class loss.
    """

    def __init__(self, model, optimizer, train_loader,
                 valid_loader, device, checkpoint_path: Path,
                 np_fl_gamma=2, np_fl_weight=0.8,
                 ns_fl_gamma=2, ns_fl_weight=0.8,
                 np_loss_multiplier=7):
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.device = device
        self.checkpoint_path = checkpoint_path
        # parents=True: nested directories such as 'checkpoints/convstack/'
        # must work on a fresh tree.
        checkpoint_path.mkdir(parents=True, exist_ok=True)
        self.np_fl_gamma = np_fl_gamma
        self.np_fl_weight = np_fl_weight
        self.ns_fl_gamma = ns_fl_gamma
        self.ns_fl_weight = ns_fl_weight
        self.np_loss_multiplier = np_loss_multiplier
        self.start_epoch = 0
        self.f1 = BinaryF1Score().to(self.device)
        # TODO: metrics (Perplexity, F-score, AUC...)

    def load_checkpoint(self, fn):
        """
        Restore model / optimizer state from ``fn`` and resume at the epoch
        after the one stored in the checkpoint.
        """
        checkpoint = torch.load(fn, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Make sure every optimizer state tensor lives on the training device.
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(self.device)

    def save_checkpoint(self, epoch, fn):
        """Write epoch number, model state and optimizer state to ``fn``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()
        }
        torch.save(checkpoint, fn)

    def binary_focal_loss(self, y, pred, gamma, pos_weight, eps=1e-7):
        """
        Binary focal loss for when y=1 is the minority class.

        INPUT
            y: targets in {0, 1}
            pred: predicted probabilities of y=1
            gamma: factor for suppressing loss for easy examples (gamma > 1)
            pos_weight: how much to suppress loss when y=0 (0 <= pos_weight <= 1)
            eps: predictions are clamped to [eps, 1 - eps] so a saturated
                sigmoid cannot produce log(0) and a NaN loss
        """
        pred = pred.clamp(eps, 1 - eps)
        return -(y * (1-pred).pow(gamma) * pred.log() +
                 pos_weight * (1-y) * pred.pow(gamma) * (1-pred).log()).mean()

    def multi_focal_loss(self, y, pred, gamma, pos_weight, eps=1e-7):
        """
        Multi-class focal loss for when y=0 is the majority class.

        INPUT
            y: (N,) integer class labels
            pred: (N, C) class probabilities
            gamma: factor for suppressing loss for easy examples (gamma > 1)
            pos_weight: how much to suppress loss when y=0 (0 <= pos_weight <= 1)
            eps: lower clamp on the true-class probability, avoids log(0)
        """
        # Index tensor is created on pred's device to avoid mixing devices.
        rows = torch.arange(len(pred), device=pred.device)
        p_y = pred[rows, y].clamp_min(eps)
        weight_mask = torch.where(y == 0, pos_weight, 1)
        return -(weight_mask * (1 - p_y).pow(gamma) * p_y.log()).mean()

    def _run_batch(self, batch):
        """
        Forward one batch and compute its losses and metrics.

        Returns (np_loss, ns_loss, ns_acc, np_f1) as tensors; np_loss already
        includes np_loss_multiplier.
        """
        specs, beat_phases, beat_nums, difficulties, onsets, actions = batch
        specs = specs.to(self.device)
        beat_phases = beat_phases.to(self.device)
        beat_nums = beat_nums.to(self.device)
        difficulties = difficulties.to(self.device)
        onsets = onsets.to(self.device)
        # Teacher forcing: input is actions[:-1], target is actions[1:].
        actions_gt = actions[:, 1:].to(self.device)
        actions_shifted = actions[:, :-1].to(self.device)

        np_pred, ns_logit, _ = self.model(
            specs, beat_phases, beat_nums, difficulties, actions_shifted)

        # Flatten batch and time so every frame is one example.
        np_pred = torch.reshape(np_pred, [-1])
        np_label = torch.reshape(onsets, [-1])

        ns_pred = torch.reshape(
            ns_logit, [-1, ns_logit.shape[-1]]).softmax(dim=-1)
        ns_label = torch.reshape(actions_gt, [-1])

        np_loss = self.binary_focal_loss(
            np_label, np_pred, self.np_fl_gamma, self.np_fl_weight) * self.np_loss_multiplier
        ns_loss = self.multi_focal_loss(
            ns_label, ns_pred, self.ns_fl_gamma, self.ns_fl_weight)

        ns_acc = (ns_pred.argmax(dim=-1) == ns_label).float().mean()
        np_f1 = self.f1(np_pred.detach(), np_label.int())
        return np_loss, ns_loss, ns_acc, np_f1

    def train(self, num_epochs):
        """
        Train from ``self.start_epoch`` up to (excluding) ``num_epochs``.

        Per training batch the losses, accuracy and F1 are logged to wandb;
        after each epoch the validation means are logged and a checkpoint is
        written to ``self.checkpoint_path``.
        """
        self.model.to(self.device)
        for epoch in tqdm(range(self.start_epoch, num_epochs)):
            self.model.train()
            for batch in tqdm(self.train_loader, leave=False):
                np_loss, ns_loss, ns_acc, np_f1 = self._run_batch(batch)

                batch_loss = np_loss + ns_loss
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

                wandb.log({'train_np_loss': np_loss.item(),
                           'train_ns_loss': ns_loss.item(),
                           'train_loss': batch_loss.item(),
                           'train_acc': ns_acc.item(),
                           'train_np_f1': np_f1.item()})

            self.model.eval()
            valid_np_loss_sum = 0
            valid_ns_loss_sum = 0
            valid_loss_sum = 0
            valid_np_f1_sum = 0
            valid_ns_acc_sum = 0
            with torch.inference_mode():
                for batch in tqdm(self.valid_loader, leave=False):
                    np_loss, ns_loss, ns_acc, np_f1 = self._run_batch(batch)
                    valid_np_loss_sum += np_loss.item()
                    valid_ns_loss_sum += ns_loss.item()
                    valid_loss_sum += (np_loss + ns_loss).item()
                    valid_ns_acc_sum += ns_acc.item()
                    valid_np_f1_sum += np_f1.item()

            # Means are over batches; skip logging for an empty loader
            # instead of dividing by zero.
            num_valid_batches = len(self.valid_loader)
            if num_valid_batches > 0:
                wandb.log({'valid_np_loss': valid_np_loss_sum / num_valid_batches,
                           'valid_ns_loss': valid_ns_loss_sum / num_valid_batches,
                           'valid_loss': valid_loss_sum / num_valid_batches,
                           'valid_acc': valid_ns_acc_sum / num_valid_batches,
                           'valid_np_f1': valid_np_f1_sum / num_valid_batches})

            timestamp = datetime.now().strftime('%m-%d-%H-%M-%S')
            self.save_checkpoint(
                epoch, self.checkpoint_path / f'{timestamp}-epoch{epoch}.pt')

## Train Baseline Model

In [7]:
# Data locations, relative to the notebook's working directory.
beatmap_path = Path('osu_dataset/beatmap/4keys/')
audio_path = Path('osu_dataset/audio/')

base_set = OsuDataset(beatmap_path, audio_path)
# Seeded generator: makes the 80/20 split and the training shuffle repeatable.
generator = torch.Generator().manual_seed(RANDOM_SEED)
train_set, valid_set = torch.utils.data.random_split(
    base_set, [0.8, 0.2], generator)

collater = PadCollater()
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=12, shuffle=True, generator=generator, collate_fn=collater, drop_last=True)
valid_loader = torch.utils.data.DataLoader(
    valid_set, batch_size=4, shuffle=False, collate_fn=collater, drop_last=True)

# NOTE(review): OsuModel.forward takes `lengths` as its fifth argument and
# returns two values, while Trainer.train passes shifted actions and unpacks
# three return values. This cell only builds the trainer and never calls
# trainer.train(); calling it with this model would fail.
model = OsuModel(device=DEV)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
trainer = Trainer(model, optimizer, train_loader, valid_loader,
                  device=DEV, checkpoint_path=Path('checkpoints'),
                  ns_fl_gamma=2, ns_fl_weight=0.8,
                  np_fl_gamma=3, np_fl_weight=0.3)

## Train ConvStack Model

In [None]:
# ---- Hyperparameters ------------------------------------------------------
# NUM_CONV_STACKS and CONV_HIDDEN are passed to StackModel, whose __init__
# accepts but does not use them (its conv stack is fixed).
NUM_CONV_STACKS = 7
CONV_HIDDEN = 16
BP_EMB_DIM = 16
BN_EMB_DIM = 8
DIFF_EMB_DIM = 8
NP_HIDDEN_SIZE = 256
NP_NUM_LAYERS = 2
NS_PRE_PROJ_SIZE = 32
NS_HIDDEN_SIZE = 256
NS_NUM_LAYERS = 2
ACTION_EMB_DIM = 32
NP_FL_GAMMA = 2
NP_FL_WEIGHT = 0.8
NS_FL_GAMMA = 3
NS_FL_WEIGHT = 0.5
LEARNING_RATE = 1e-4
BATCH_SIZE = 6
VALID_BATCH_SIZE = 4
NUM_EPOCHS = 60

# Seed every RNG in play: OsuDataset crops with Python's `random`, and model
# initialisation uses torch's global generator. The explicit torch.Generator
# below only covers the split and the shuffle.
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Log everything needed to reproduce the run, including the focal-loss
# parameters and the seed.
wandb.init(project='AutoOsu',
           config={
               'num_conv_stacks': NUM_CONV_STACKS,
               'conv_hidden': CONV_HIDDEN,
               'bp_emb_dim': BP_EMB_DIM,
               'bn_emb_dim': BN_EMB_DIM,
               'diff_emb_dim': DIFF_EMB_DIM,
               'np_hidden_size': NP_HIDDEN_SIZE,
               'np_num_layers': NP_NUM_LAYERS,
               'ns_pre_proj_size': NS_PRE_PROJ_SIZE,
               'ns_hidden_size': NS_HIDDEN_SIZE,
               'ns_num_layers': NS_NUM_LAYERS,
               'action_emb_dim': ACTION_EMB_DIM,
               'np_fl_gamma': NP_FL_GAMMA,
               'np_fl_weight': NP_FL_WEIGHT,
               'ns_fl_gamma': NS_FL_GAMMA,
               'ns_fl_weight': NS_FL_WEIGHT,
               'learning_rate': LEARNING_RATE,
               'batch_size': BATCH_SIZE,
               'valid_batch_size': VALID_BATCH_SIZE,
               'num_epochs': NUM_EPOCHS,
               'random_seed': RANDOM_SEED
           })

# Data locations, relative to the notebook's working directory.
beatmap_path = Path('osu_dataset/beatmap/4keys/')
audio_path = Path('osu_dataset/audio/')

base_set = OsuDataset(beatmap_path, audio_path)
# Seeded generator: repeatable 80/20 split and training shuffle.
generator = torch.Generator().manual_seed(RANDOM_SEED)
train_set, valid_set = torch.utils.data.random_split(
    base_set, [0.8, 0.2], generator)

collater = PadCollater()
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True, generator=generator, collate_fn=collater, drop_last=False)
valid_loader = torch.utils.data.DataLoader(
    valid_set, batch_size=VALID_BATCH_SIZE, shuffle=False, collate_fn=collater, drop_last=False)

model = StackModel(num_stacks=NUM_CONV_STACKS, conv_hidden=CONV_HIDDEN,
                   bp_emb_dim=BP_EMB_DIM, bn_emb_dim=BN_EMB_DIM, diff_emb_dim=DIFF_EMB_DIM,
                   np_hidden_size=NP_HIDDEN_SIZE, np_num_layers=NP_NUM_LAYERS,
                   ns_pre_proj_size=NS_PRE_PROJ_SIZE, ns_hidden_size=NS_HIDDEN_SIZE, ns_num_layers=NS_NUM_LAYERS,
                   action_emb_dim=ACTION_EMB_DIM)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
trainer = Trainer(model, optimizer, train_loader, valid_loader,
                  device=DEV, checkpoint_path=Path('checkpoints/convstack/'),
                  np_fl_gamma=NP_FL_GAMMA, np_fl_weight=NP_FL_WEIGHT,
                  ns_fl_gamma=NS_FL_GAMMA, ns_fl_weight=NS_FL_WEIGHT)

trainer.train(NUM_EPOCHS)
wandb.finish()

In [None]:
# Resume logging into an existing wandb run (resume='must' fails if the run id
# does not exist) and continue training for epochs 40..59.
# NOTE(review): this sets start_epoch by hand and does not call
# trainer.load_checkpoint(), so it continues from whatever weights `trainer`
# currently holds in the kernel. After a kernel restart, load the matching
# checkpoint first -- confirm this is the intended workflow.
wandb.init(project='AutoOsu', id='wwjh1mq4', resume='must')
trainer.start_epoch = 40
trainer.train(60)
wandb.finish()

In [None]:
# Manual cleanup: closes the active wandb run, e.g. after an interrupted
# training cell that never reached its own wandb.finish().
wandb.finish()