In [1]:
import random
from datetime import datetime
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, unpad_sequence, pack_sequence, pad_packed_sequence
from torchmetrics.classification import BinaryF1Score
from torchmetrics.functional.text.perplexity import perplexity
from tqdm.auto import tqdm

from utils import index_to_combination

# Show tensor values as plain decimals instead of scientific notation.
torch.set_printoptions(sci_mode=False)

# Seed value shared by the notebook's stochastic steps.
RANDOM_SEED = 1

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEV = 'cuda'
else:
    DEV = 'cpu'

In [2]:
class OsuModel(nn.Module):
    """Two-headed GRU model over per-frame spectrogram features.

    A single ``Conv2d`` reduces each 3-channel spectrogram to one channel. The
    resulting 80-wide frame vectors are concatenated with beat-phase,
    beat-number and difficulty embeddings and fed to two heads:

    * note placement: a bidirectional GRU followed by two linear layers and a
      sigmoid, giving one score per frame;
    * note selection: a unidirectional GRU that additionally receives a linear
      projection of the placement head's hidden features, followed by two
      linear layers giving ``num_tokens`` logits per frame.

    Args:
        device: Stored on ``self.device``. ``forward`` does not move tensors
            itself; inputs must already live on the same device as the
            module's parameters.
        bp_emb_dim: Width of the beat-phase embedding (49 entries).
        bn_emb_dim: Width of the beat-number embedding (4 entries).
        diff_emb_dim: Width of the difficulty embedding (21 entries).
        np_hidden_size: Hidden size of the placement GRU (per direction).
        np_num_layers: Number of layers of the placement GRU.
        ns_pre_proj_size: Width of the placement features passed to the
            selection GRU.
        ns_hidden_size: Hidden size of the selection GRU.
        ns_num_layers: Number of layers of the selection GRU.
        num_tokens: Number of output classes of the selection head.
    """

    def __init__(self, device, bp_emb_dim=16, bn_emb_dim=8, diff_emb_dim=8,
                 np_hidden_size=256, np_num_layers=2, ns_pre_proj_size=32,
                 ns_hidden_size=256, ns_num_layers=2, num_tokens=256):
        super().__init__()
        self.device = device
        self.gelu = nn.GELU()
        self.sigmoid = nn.Sigmoid()
        self.conv = nn.Conv2d(in_channels=3,
                              out_channels=1,
                              kernel_size=(15, 5),
                              padding='same')
        self.beat_phase_emb = nn.Embedding(49, bp_emb_dim)
        self.beat_num_emb = nn.Embedding(4, bn_emb_dim)
        self.difficulty_emb = nn.Embedding(21, diff_emb_dim)

        self.np_gru = nn.GRU(input_size=80 + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=np_hidden_size,
                             num_layers=np_num_layers,
                             batch_first=True,
                             bidirectional=True)
        self.np_proj_1 = nn.Linear(np_hidden_size*2, 128)
        self.np_proj_2 = nn.Linear(128, 1)

        self.ns_pre_proj = nn.Linear(128, ns_pre_proj_size)
        self.ns_gru = nn.GRU(input_size=80 + ns_pre_proj_size + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=ns_hidden_size,
                             num_layers=ns_num_layers,
                             batch_first=True,
                             bidirectional=False)
        self.ns_proj_1 = nn.Linear(ns_hidden_size, ns_hidden_size)
        self.ns_proj_2 = nn.Linear(ns_hidden_size, num_tokens)

    def forward(self, specs, beat_phases, beat_nums, difficulties, lengths):
        """Run both heads on a batch of variable-length sequences.

        Args:
            specs: Iterable with one spectrogram per batch item; each item is
                passed to ``self.conv`` on its own and must yield
                ``lengths[i]`` frames of 80 features.
            beat_phases: Padded integer tensor ``(batch, max_len)`` of
                beat-phase indices.
            beat_nums: Padded integer tensor ``(batch, max_len)`` of
                beat-number indices.
            difficulties: Integer tensor ``(batch,)`` of difficulty indices.
            lengths: Integer tensor ``(batch,)`` with the true length of each
                sequence.

        Returns:
            ``(np_pred, ns_logit)`` where ``np_pred`` has shape
            ``(batch, max(lengths))`` with values in ``[0, 1]`` and
            ``ns_logit`` has shape ``(batch, max(lengths), num_tokens)``.
            The batch axis is kept even for a batch of one.
        """
        seq_lens = [int(n) for n in lengths.tolist()]

        # Collapse only the leading (channel) axis. A bare ``squeeze()`` would
        # also drop the time axis of a one-frame sequence.
        conv_outs = []
        for spec in specs:
            feat = self.gelu(self.conv(spec))
            conv_outs.append(feat.reshape(-1, feat.shape[-1]))

        # Embed once on the module's device and trim the padding by slicing;
        # this avoids copying every embedding to the CPU and back.
        bp_all = self.beat_phase_emb(beat_phases)
        bn_all = self.beat_num_emb(beat_nums)
        diff_all = self.difficulty_emb(difficulties)
        bp_emb = [bp_all[i, :n] for i, n in enumerate(seq_lens)]
        bn_emb = [bn_all[i, :n] for i, n in enumerate(seq_lens)]
        # The difficulty is constant over time: repeat its embedding per frame.
        diff_emb = [diff_all[i].expand(n, -1) for i, n in enumerate(seq_lens)]

        # ========== Note Placement ========== #

        np_in = [torch.cat([conv_outs[i], bp_emb[i], bn_emb[i], diff_emb[i]], dim=-1)
                 for i in range(len(seq_lens))]
        np_in_packed = pack_sequence(np_in, enforce_sorted=False)
        np_out, _ = self.np_gru(np_in_packed)
        np_out_padded, _ = pad_packed_sequence(np_out, batch_first=True)

        np_proj_1_out = self.gelu(self.np_proj_1(np_out_padded))
        # Remove only the trailing size-1 feature axis so that a batch of one
        # still comes back as (1, max_len).
        np_pred = self.sigmoid(self.np_proj_2(np_proj_1_out)).squeeze(-1)

        # ========== Note Selection ========== #

        ns_pre_proj_padded = self.ns_pre_proj(np_proj_1_out)
        ns_in = [torch.cat([conv_outs[i],
                            ns_pre_proj_padded[i, :n],
                            bp_emb[i],
                            bn_emb[i],
                            diff_emb[i]],
                           dim=-1)
                 for i, n in enumerate(seq_lens)]
        ns_in_packed = pack_sequence(ns_in, enforce_sorted=False)
        ns_out, _ = self.ns_gru(ns_in_packed)
        ns_out_padded, _ = pad_packed_sequence(ns_out, batch_first=True)

        ns_proj_1_out = self.gelu(self.ns_proj_1(ns_out_padded))
        ns_logit = self.ns_proj_2(ns_proj_1_out)

        return np_pred, ns_logit

In [6]:
class ConvBlock(nn.Module):
    """Residual block: 3x3 conv -> LayerNorm -> ReLU -> 1x1 conv -> LayerNorm -> ReLU.

    The second convolution maps back to ``in_channels`` so the block output
    can be added to its input. Note that the final ReLU is applied before the
    residual addition.

    Args:
        in_channels: Number of channels of the input (and of the output).
        hidden_size: Number of channels between the two convolutions.
        num_features: Size of the last input axis, which is the axis the
            ``LayerNorm`` layers normalise over. Defaults to 80.
    """

    def __init__(self, in_channels, hidden_size, num_features=80):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, hidden_size, 3, padding=1),
            nn.LayerNorm(num_features),
            nn.ReLU(),
            nn.Conv2d(hidden_size, in_channels, 1),
            nn.LayerNorm(num_features),
            nn.ReLU()
        )

    def forward(self, x):
        """Return ``block(x) + x``; the output has the same shape as ``x``."""
        return self.block(x) + x

class ConvStack(nn.Module):
    """1x1 conv stem, a chain of ``ConvBlock`` modules, and a 1x1 conv head.

    Both the stem and the head are followed by ``LayerNorm`` over the last
    axis and a ReLU.

    Args:
        num_blocks: Number of residual ``ConvBlock`` modules in the chain.
        in_channels: Number of channels of the input.
        hidden_size: Channel width used by the stem and by every block.
        out_channels: Number of channels of the output.
        num_features: Size of the last input axis (normalised by every
            ``LayerNorm`` in the stack). Defaults to 80.
    """

    def __init__(self, num_blocks, in_channels, hidden_size, out_channels, num_features=80):
        super().__init__()
        self.relu = nn.ReLU()
        self.first_conv = nn.Conv2d(in_channels, hidden_size, 1)
        self.first_norm = nn.LayerNorm(num_features)

        self.stack = nn.ModuleList([
            ConvBlock(hidden_size, hidden_size, num_features) for _ in range(num_blocks)
        ])

        self.last_conv = nn.Conv2d(hidden_size, out_channels, 1)
        self.last_norm = nn.LayerNorm(num_features)

    def forward(self, x):
        """Apply stem, residual blocks and head; only the channel axis changes size."""
        x = self.relu(self.first_norm(self.first_conv(x)))
        for block in self.stack:
            x = block(x)
        return self.relu(self.last_norm(self.last_conv(x)))

class StackModel(nn.Module):
    """Two-headed GRU model whose audio front-end is a ``ConvStack``.

    Each 3-channel spectrogram is reduced to one channel by
    ``ConvStack(num_stacks, 3, 16, 1)``. The resulting 80-wide frame vectors
    are concatenated with beat-phase, beat-number and difficulty embeddings
    and fed to two heads:

    * note placement: a bidirectional GRU followed by two linear layers and a
      sigmoid, giving one score per frame;
    * note selection: a unidirectional GRU that additionally receives a linear
      projection of the placement head's hidden features, followed by two
      linear layers giving ``num_tokens`` logits per frame.

    Args:
        device: Stored on ``self.device``. ``forward`` does not move tensors
            itself; inputs must already live on the same device as the
            module's parameters.
        num_stacks: Number of residual blocks inside the ``ConvStack``.
        bp_emb_dim: Width of the beat-phase embedding (49 entries).
        bn_emb_dim: Width of the beat-number embedding (4 entries).
        diff_emb_dim: Width of the difficulty embedding (21 entries).
        np_hidden_size: Hidden size of the placement GRU (per direction).
        np_num_layers: Number of layers of the placement GRU.
        ns_pre_proj_size: Width of the placement features passed to the
            selection GRU.
        ns_hidden_size: Hidden size of the selection GRU.
        ns_num_layers: Number of layers of the selection GRU.
        num_tokens: Number of output classes of the selection head.
    """

    def __init__(self, device, num_stacks=7, bp_emb_dim=16, bn_emb_dim=8, diff_emb_dim=8,
                 np_hidden_size=256, np_num_layers=2, ns_pre_proj_size=32,
                 ns_hidden_size=256, ns_num_layers=2, num_tokens=256):
        super().__init__()
        self.device = device
        self.gelu = nn.GELU()
        self.sigmoid = nn.Sigmoid()
        self.stack = ConvStack(num_stacks, 3, 16, 1)
        self.beat_phase_emb = nn.Embedding(49, bp_emb_dim)
        self.beat_num_emb = nn.Embedding(4, bn_emb_dim)
        self.difficulty_emb = nn.Embedding(21, diff_emb_dim)

        self.np_gru = nn.GRU(input_size=80 + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=np_hidden_size,
                             num_layers=np_num_layers,
                             batch_first=True,
                             bidirectional=True)
        self.np_proj_1 = nn.Linear(np_hidden_size*2, 128)
        self.np_proj_2 = nn.Linear(128, 1)

        self.ns_pre_proj = nn.Linear(128, ns_pre_proj_size)
        self.ns_gru = nn.GRU(input_size=80 + ns_pre_proj_size + bp_emb_dim + bn_emb_dim + diff_emb_dim,
                             hidden_size=ns_hidden_size,
                             num_layers=ns_num_layers,
                             batch_first=True,
                             bidirectional=False)
        self.ns_proj_1 = nn.Linear(ns_hidden_size, ns_hidden_size)
        self.ns_proj_2 = nn.Linear(ns_hidden_size, num_tokens)

    def forward(self, specs, beat_phases, beat_nums, difficulties, lengths):
        """Run both heads on a batch of variable-length sequences.

        Args:
            specs: Iterable with one spectrogram per batch item; each item is
                passed to ``self.stack`` on its own and must yield
                ``lengths[i]`` frames of 80 features.
            beat_phases: Padded integer tensor ``(batch, max_len)`` of
                beat-phase indices.
            beat_nums: Padded integer tensor ``(batch, max_len)`` of
                beat-number indices.
            difficulties: Integer tensor ``(batch,)`` of difficulty indices.
            lengths: Integer tensor ``(batch,)`` with the true length of each
                sequence.

        Returns:
            ``(np_pred, ns_logit)`` where ``np_pred`` has shape
            ``(batch, max(lengths))`` with values in ``[0, 1]`` and
            ``ns_logit`` has shape ``(batch, max(lengths), num_tokens)``.
            The batch axis is kept even for a batch of one.
        """
        seq_lens = [int(n) for n in lengths.tolist()]

        # Collapse only the leading (channel) axis. A bare ``squeeze()`` would
        # also drop the time axis of a one-frame sequence.
        conv_outs = []
        for spec in specs:
            feat = self.stack(spec)
            conv_outs.append(feat.reshape(-1, feat.shape[-1]))

        # Embed once on the module's device and trim the padding by slicing;
        # this avoids copying every embedding to the CPU and back.
        bp_all = self.beat_phase_emb(beat_phases)
        bn_all = self.beat_num_emb(beat_nums)
        diff_all = self.difficulty_emb(difficulties)
        bp_emb = [bp_all[i, :n] for i, n in enumerate(seq_lens)]
        bn_emb = [bn_all[i, :n] for i, n in enumerate(seq_lens)]
        # The difficulty is constant over time: repeat its embedding per frame.
        diff_emb = [diff_all[i].expand(n, -1) for i, n in enumerate(seq_lens)]

        # ========== Note Placement ========== #

        np_in = [torch.cat([conv_outs[i], bp_emb[i], bn_emb[i], diff_emb[i]], dim=-1)
                 for i in range(len(seq_lens))]
        np_in_packed = pack_sequence(np_in, enforce_sorted=False)
        np_out, _ = self.np_gru(np_in_packed)
        np_out_padded, _ = pad_packed_sequence(np_out, batch_first=True)

        np_proj_1_out = self.gelu(self.np_proj_1(np_out_padded))
        # Remove only the trailing size-1 feature axis so that a batch of one
        # still comes back as (1, max_len).
        np_pred = self.sigmoid(self.np_proj_2(np_proj_1_out)).squeeze(-1)

        # ========== Note Selection ========== #

        ns_pre_proj_padded = self.ns_pre_proj(np_proj_1_out)
        ns_in = [torch.cat([conv_outs[i],
                            ns_pre_proj_padded[i, :n],
                            bp_emb[i],
                            bn_emb[i],
                            diff_emb[i]],
                           dim=-1)
                 for i, n in enumerate(seq_lens)]
        ns_in_packed = pack_sequence(ns_in, enforce_sorted=False)
        ns_out, _ = self.ns_gru(ns_in_packed)
        ns_out_padded, _ = pad_packed_sequence(ns_out, batch_first=True)

        ns_proj_1_out = self.gelu(self.ns_proj_1(ns_out_padded))
        ns_logit = self.ns_proj_2(ns_proj_1_out)

        return np_pred, ns_logit

In [4]:
def infer(model, specs, beat_phases, beat_nums, difficulty, device):
    """Run ``model`` on a single un-batched track and decode both heads.

    The inputs are given a leading batch axis of size one and moved to
    ``device``. ``model`` is switched to eval mode (and left in it), and the
    forward pass runs under ``torch.inference_mode()``.

    Args:
        model: Module called as
            ``model(specs, beat_phases, beat_nums, difficulties, lengths)``
            and returning ``(np_pred, ns_logit)``.
        specs: Spectrogram tensor whose axis 1 is the frame axis.
        beat_phases: 1-D tensor with one beat-phase index per frame.
        beat_nums: 1-D tensor with one beat-number index per frame.
        difficulty: Single difficulty index used for the whole track.
        device: Device the inputs are moved to.

    Returns:
        ``(np_pred, ns_pred_idx)``: a 1-D boolean tensor holding the rounded
        placement scores and a 1-D integer tensor holding the arg-max token
        per frame. Both stay 1-D even for a one-frame track.

    Raises:
        AssertionError: If the frame counts of the three inputs disagree.
    """
    assert specs.shape[1] == beat_phases.shape[0] == beat_nums.shape[0], \
        'specs, beat_phases and beat_nums must have the same number of frames'
    specs = specs.unsqueeze(0).to(device)
    beat_phases = beat_phases.unsqueeze(0).to(device)
    beat_nums = beat_nums.unsqueeze(0).to(device)
    diff = torch.tensor([difficulty], device=device)
    lengths = torch.tensor([specs.shape[2]])

    model.eval()
    with torch.inference_mode():
        np_pred, ns_logit = model(specs, beat_phases, beat_nums, diff, lengths)
        # Flatten explicitly rather than with a bare squeeze(), which would
        # turn a one-frame result into a 0-d tensor.
        np_pred = np_pred.reshape(-1).round().bool()
        # Softmax is monotonic, so the arg-max is taken on the logits directly.
        ns_pred_idx = ns_logit.reshape(-1, ns_logit.shape[-1]).argmax(dim=-1)

        return np_pred, ns_pred_idx

In [50]:
# Build a StackModel on DEV and restore its weights from a saved checkpoint.
model = StackModel(DEV)
model.to(DEV)
# NOTE(review): torch.load unpickles the file; only load checkpoints and
# dataset files from trusted sources.
state = torch.load(Path('checkpoints/convstack/07-17-22-26-31-epoch14.pt'), map_location=DEV)
model.load_state_dict(state['model_state_dict'])
# The two dataset files are unpacked positionally from .values(), so the
# assignment relies on the insertion order of the saved dicts.
# NOTE(review): actions, onsets and difficulty are not used by the cells below.
actions, onsets, _, difficulty = torch.load(Path('osu_dataset/beatmap/4keys/1008060-202-4.pt')).values()
specs, beat_phase, beat_num = torch.load(Path('osu_dataset/audio/1008060.pt')).values()
# NOTE(review): the difficulty passed here is the literal 11, not the
# `difficulty` value loaded from the beatmap file above - confirm this is intended.
np_pred, ns_pred_idx = infer(model, specs, beat_phase, beat_num, 11, DEV)

In [51]:
# Number of frames flagged by the placement head (infer() returns np_pred as a boolean tensor).
(np_pred > 0).sum()

tensor(1014, device='cuda:0')

In [52]:
# Number of frames whose predicted token index is non-zero; only those frames
# produce entries when the beatmap list is built below.
(ns_pred_idx > 0).sum()

tensor(646, device='cuda:0')

In [51]:
# Spot-check the token decoder from utils: token 5 with 4 keys.
# NOTE(review): the result looks like the base-4 digits of the token, least
# significant first - confirm against utils.index_to_combination.
index_to_combination(5, 4)

(1, 1, 0, 0)

In [53]:
# Decode every non-zero predicted token into its four per-key states and store
# it as [time, key0, key1, key2, key3].
# The time is the frame index times 10 (presumably 10 ms per frame - TODO confirm).
beatmap = []
# tolist() copies the predictions to the host once, instead of forcing a
# device synchronisation through .item() on every frame.
for i, token in enumerate(ns_pred_idx.tolist()):
    if token > 0:
        key0, key1, key2, key3 = index_to_combination(token, 4)
        beatmap.append([i * 10, key0, key1, key2, key3])

In [54]:
# Emit one hit-object line for every key with a non-zero state in every action.
# The x position is the centre of the key's 128-wide column.
# NOTE(review): states 2 and 3 are only reported through print(); the line
# written for them is identical to the one written for any other non-zero state.
beatmap_str_list = []
for action in beatmap:
    time, *keys = action
    for idx, key in enumerate(keys):
        if key <= 0:
            continue
        if key == 2:
            print('long note start!!')
        elif key == 3:
            print('long note end!!')
        xpos = 64 + idx * 128
        beatmap_str_list.append(f'{xpos},192,{time},1,0,0:0:0:0:')

In [129]:
# Display the per-frame boolean placement predictions returned by infer().
np_pred

tensor([False, False, False,  ..., False, False, False], device='cuda:0')

In [100]:
# Alternative beatmap built from the placement head only: every flagged frame
# becomes one note in a randomly chosen column (0-3).
# NOTE(review): this rebinds `beatmap_str_list`; run top to bottom, it replaces
# the token-based list built earlier before the file is written.

# A dedicated generator seeded with RANDOM_SEED makes the column choice
# reproducible without touching the global `random` state.
column_rng = random.Random(RANDOM_SEED)

beatmap_str_list = []
# tolist() copies the predictions to the host once, instead of calling
# .item() on a device tensor for every frame.
for idx, action in enumerate(np_pred.tolist()):
    if action:
        xpos = 64 + column_rng.randint(0, 3) * 128
        beatmap_str_list.append(f'{xpos},192,{idx * 10},1,0,0:0:0:0:')

In [55]:
# Write the hit-object lines, newline-separated, to beatmap.txt.
# The context manager closes the file even if write() raises.
with open('beatmap.txt', 'w') as file:
    file.write('\n'.join(beatmap_str_list))