In [None]:
# Debug
%load_ext autoreload
%autoreload 2
from utils.DatasetGenerator import get_X, get_y, get_key_sequences, read_sequences
path = '../Data/train_key/beyond'
df_X = get_X(path)
df_y_key = get_y(path, 'key')
get_key_sequences([df_X, df_y_key, path], 10, 500)
input_seq, target_seq = read_sequences(path, 'key')
pass
# sequences = read_sequences(path)

In [None]:
%load_ext autoreload
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
data_paths = {
    'train': '../Data/train_key',
    'valid': '../Data/valid_key'
}

In [2]:
%autoreload 2
# Data parsing
from utils.DatasetGenerator import get_set
datasets = {
    'train': None,
    'valid': None
}

for idx, set in enumerate(data_paths):
    print(f'Converting {set} data')
    data = get_set(data_paths[set], 'key', regenerate=True)
    
    datasets[set] = data

Converting train data


100%|██████████| 11/11 [00:03<00:00,  3.59it/s]


Converting valid data


100%|██████████| 4/4 [00:01<00:00,  2.54it/s]


In [3]:
# Sequence generation: build key sequences for every non-empty split with
# create_set_sequences(N, window), then retrieve them per split with
# get_set_sequences.
import warnings
# Silence torch's FutureWarning about the implicit weights_only default of
# torch.load (presumably raised by loaders in utils.DatasetGenerator and by
# the dataset cache loads later in the notebook).
warnings.filterwarnings("ignore", category=FutureWarning, message=".*weights_only=False.*")
from utils.DatasetGenerator import create_set_sequences, get_set_sequences
from itertools import chain

N = 10
window = 500
# `split` rather than `set`, so the builtin set() stays usable in later cells.
for split, split_data in datasets.items():
    if len(split_data) > 0:
        print(f'Generating {split} sequences')
        create_set_sequences(split_data, N, window, 'key')

combined_train_key_input = None
combined_train_key_target = None
valid_sequences = None
for split, split_path in data_paths.items():
    print(f'Retrieving {split} sequences')
    key_input_sequences, key_target_sequences = get_set_sequences(split_path, 'key')
    if split == 'train':
        # Training sequences from all sources are flattened into one pool.
        combined_train_key_input = list(chain.from_iterable(key_input_sequences))
        combined_train_key_target = list(chain.from_iterable(key_target_sequences))
    else:
        # Validation sequences are kept grouped (the loader cell indexes them
        # per source folder) so each group gets its own DataLoader.
        valid_sequences = [key_input_sequences, key_target_sequences]

Generating train sequences


100%|██████████| 11/11 [00:03<00:00,  3.12it/s]


Generating valid sequences


100%|██████████| 4/4 [00:01<00:00,  2.43it/s]


Retrieving train sequences


100%|██████████| 11/11 [00:01<00:00,  5.77it/s]


Retrieving valid sequences


100%|██████████| 4/4 [00:00<00:00, 15.62it/s]


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from Model.OSUDataset import OsuDataset, collate_fn
# Re-declared so this cell can run on its own when the dataset caches exist.
data_paths = {
    'train': '../Data/train_key',
    'valid': '../Data/valid_key'
}

# Built OsuDataset objects are cached here. Set REBUILD_DATASET_CACHE = True
# (or delete the files) after regenerating sequences, otherwise stale cached
# datasets are loaded.
DATASET_CACHE_DIR = './key_dataset'
REBUILD_DATASET_CACHE = False
os.makedirs(DATASET_CACHE_DIR, exist_ok=True)

# Pinned host memory only helps (and only avoids a warning) when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()


def _load_or_build_dataset(cache_path, build):
    """Return the dataset cached at ``cache_path``, or build, cache and return it.

    Args:
        cache_path: file used for torch.save / torch.load of the dataset object.
        build: zero-argument callable creating the dataset; only called on a
            cache miss, so its inputs need not exist when the cache is present.

    The cache is a pickled Python object, so torch.load needs
    weights_only=False. Only load cache files this notebook produced itself.
    """
    if os.path.exists(cache_path) and not REBUILD_DATASET_CACHE:
        return torch.load(cache_path, weights_only=False)
    dataset = build()
    torch.save(dataset, cache_path)
    return dataset


BATCH_SIZE_TRAIN = 512
# The existence check and the load/save now use the same path (previously the
# check looked for 'train_dataset.pth' in the cwd, so the cache was never hit).
train_dataset = _load_or_build_dataset(
    os.path.join(DATASET_CACHE_DIR, 'train_dataset.pth'),
    lambda: OsuDataset(combined_train_key_input, combined_train_key_target),
)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, drop_last=True, collate_fn=collate_fn, pin_memory=PIN_MEMORY)

BATCH_SIZE_VALID = 64
valid_loaders = []
valid_folders = os.listdir(data_paths['valid'])
# One validation dataset per sub-directory of the valid split.
# NOTE(review): assumes valid_sequences[i] corresponds to the i-th folder in
# get_set_sequences' ordering — confirm.
valid_data_size = sum(os.path.isdir(os.path.join(data_paths['valid'], folder)) for folder in valid_folders)
for idx in range(valid_data_size):
    valid_dataset = _load_or_build_dataset(
        os.path.join(DATASET_CACHE_DIR, f'valid_dataset_{idx}.pth'),
        lambda: OsuDataset(valid_sequences[0][idx], valid_sequences[1][idx]),
    )
    # NOTE(review): drop_last=True discards up to BATCH_SIZE_VALID - 1 trailing
    # samples per folder (all of them if a folder has fewer than 64); kept in
    # case the model needs a fixed batch size — confirm.
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE_VALID, shuffle=False, drop_last=True, collate_fn=collate_fn, pin_memory=PIN_MEMORY)
    valid_loaders.append(valid_loader)

In [5]:
%autoreload 2
from Model.OSUModel import KeypressEncoder, KeypressDecoder, OSUModelKey
key_input_size = 4
key_hidden_size = 16
key_num_layers = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

key_encoder = KeypressEncoder(key_input_size, key_hidden_size, key_num_layers).to(device)
key_decoder = KeypressDecoder(key_hidden_size, key_num_layers).to(device)

osuModelKey = OSUModelKey(key_encoder, key_decoder, device)
key_criterion = nn.CrossEntropyLoss(reduction='none')

key_optimizer = torch.optim.Adam(
    osuModelKey.parameters(),
    lr=0.01,
    weight_decay=1e-5
)

key_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(key_optimizer, mode='min',
                                                    factor=0.1, patience=25,
                                                    threshold=1e-4)

def masked_loss_key(key_outputs, targets, target_lengths, criterion=None):
    """Mean per-step cross entropy over the non-padded steps of a batch.

    Args:
        key_outputs: logits indexed (batch, step, class); every step along
            dim 1 is scored.
        targets: class indicators along dim 2 (argmax is taken). The first
            step is skipped, so targets[:, 1:, :] must line up step-for-step
            with key_outputs.
        target_lengths: 1-D tensor; step t of sample b contributes only when
            t < target_lengths[b].
        criterion: element-wise loss with reduction='none'. Defaults to the
            notebook-level ``key_criterion``.

    Returns:
        Scalar tensor: sum of unmasked losses / number of unmasked steps.
        Returns 0 when no step is unmasked, instead of 0/0 = NaN.
    """
    if criterion is None:
        criterion = key_criterion

    # Drop the first target step so targets align with the decoder outputs.
    key_targets = targets[:, 1:, :]

    # (batch, max_len) mask, True for steps inside each sample's target length.
    # Built on target_lengths' device so the comparison never mixes devices.
    max_len = key_outputs.size(1)
    steps = torch.arange(max_len, device=target_lengths.device)
    mask = (steps.unsqueeze(0) < target_lengths.unsqueeze(1)).to(key_outputs.device)

    # CrossEntropyLoss expects the class dimension second: (batch, class, step).
    key_targets_idx = torch.argmax(key_targets, dim=2)
    key_loss = criterion(key_outputs.permute(0, 2, 1), key_targets_idx)

    # clamp(min=1): an all-padding batch would otherwise give NaN, which
    # backward() would propagate into the model weights.
    return (key_loss * mask).sum() / mask.sum().clamp(min=1)

In [6]:
# Training loop: teacher forcing decays linearly from 1 to 0 over the first
# TF_DECAY_EPOCHS epochs; the LR scheduler is stepped with validation loss.
num_epochs = 1000
TF_DECAY_EPOCHS = 20
total_train_len = len(train_loader)
total_valid_len = sum(len(loader) for loader in valid_loaders)
if total_train_len == 0 or total_valid_len == 0:
    # Fail fast instead of a ZeroDivisionError after a full training epoch.
    raise ValueError(
        f'Empty loaders (train batches={total_train_len}, valid batches={total_valid_len}); '
        'check dataset sizes against batch size / drop_last.'
    )

for epoch in range(num_epochs):
    teacher_forcing_ratio = max(1 - (epoch / TF_DECAY_EPOCHS), 0)

    # --- Training ---
    osuModelKey.train()
    train_loss_key = 0.0
    for train_inputs, train_targets, input_lengths, target_lengths in train_loader:
        train_inputs, train_targets = train_inputs.to(device), train_targets.to(device)

        key_optimizer.zero_grad()
        train_key_outputs = osuModelKey(train_inputs, train_targets, teacher_forcing_ratio)
        loss_key = masked_loss_key(train_key_outputs, train_targets, target_lengths)
        train_loss_key += loss_key.item()

        loss_key.backward()
        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(osuModelKey.parameters(), max_norm=1.0)
        key_optimizer.step()

    # Mean of per-batch losses.
    train_loss_key /= total_train_len

    # --- Validation ---
    osuModelKey.eval()
    total_valid_loss_key = 0.0
    with torch.no_grad():
        for valid_loader in valid_loaders:
            for valid_inputs, valid_targets, input_lengths, target_lengths in valid_loader:
                valid_inputs, valid_targets = valid_inputs.to(device), valid_targets.to(device)
                # NOTE(review): relies on OSUModelKey's default teacher-forcing
                # ratio (not visible here; presumably 0 for free-running eval).
                valid_key_outputs = osuModelKey(valid_inputs, valid_targets)
                loss_key = masked_loss_key(valid_key_outputs, valid_targets, target_lengths)
                total_valid_loss_key += loss_key.item()

    # Mean of per-batch losses over all validation folders (folders with more
    # batches weigh more).
    total_valid_loss_key /= total_valid_len

    key_scheduler.step(total_valid_loss_key)
    key_lr = key_optimizer.param_groups[0]['lr']

    print(f"Epoch [{epoch+1}/{num_epochs}] | "
          f"T Loss Key: {train_loss_key:.4f} | "
          f"V Loss Key: {total_valid_loss_key:.4f} | "
          f"LR Key: {key_lr} | "
          f"TFR: {teacher_forcing_ratio:.2f}")


Epoch [1/1000]T Loss Key: 0.4726 | V Loss Key: 0.4438 | LR Key: 0.01 | TFR: 1.00
Epoch [2/1000]T Loss Key: 0.2352 | V Loss Key: 0.4173 | LR Key: 0.01 | TFR: 0.95
Epoch [3/1000]T Loss Key: 0.1970 | V Loss Key: 0.4096 | LR Key: 0.01 | TFR: 0.90
Epoch [4/1000]T Loss Key: 0.1694 | V Loss Key: 0.4696 | LR Key: 0.01 | TFR: 0.85
Epoch [5/1000]T Loss Key: 0.1470 | V Loss Key: 0.4887 | LR Key: 0.01 | TFR: 0.80
Epoch [6/1000]T Loss Key: 0.1349 | V Loss Key: 0.5573 | LR Key: 0.01 | TFR: 0.75
Epoch [7/1000]T Loss Key: 0.1280 | V Loss Key: 0.3795 | LR Key: 0.01 | TFR: 0.70
Epoch [8/1000]T Loss Key: 0.1210 | V Loss Key: 0.5004 | LR Key: 0.01 | TFR: 0.65
Epoch [9/1000]T Loss Key: 0.1221 | V Loss Key: 0.4353 | LR Key: 0.01 | TFR: 0.60
Epoch [10/1000]T Loss Key: 0.1164 | V Loss Key: 0.3981 | LR Key: 0.01 | TFR: 0.55
Epoch [11/1000]T Loss Key: 0.1161 | V Loss Key: 0.3985 | LR Key: 0.01 | TFR: 0.50
Epoch [12/1000]T Loss Key: 0.1111 | V Loss Key: 0.3093 | LR Key: 0.01 | TFR: 0.45
Epoch [13/1000]T Loss Key

KeyboardInterrupt: 

In [7]:
# Persist the key model's weights (state_dict only) under a versioned filename.
VERSION = '0.3'
key_model_path = f'./OSU_model_key_{VERSION}.pth'
torch.save(osuModelKey.state_dict(), key_model_path)