In [1]:
# Debug
# %load_ext autoreload
# %autoreload 2
# import sys
# from pathlib import Path
# sys.path.append(str(Path.cwd().parent))
# from utils.DatasetGenerator import get_X, get_y, get_pos_sequences, read_sequences
# path = '../Data/train_pos/annoy'
# df_y_pos, hr = get_y(path, 'pos')
# df_X = get_X(path, hr)
# get_pos_sequences([df_X, df_y_pos, path], 10, 500)
# input_seq, target_seq, target_obj = read_sequences(path, 'pos')
# pass

In [2]:
# # debug
# %load_ext autoreload
# %autoreload 2
# import sys
# from pathlib import Path
# sys.path.append(str(Path.cwd().parent))
# from utils.Visualizer import visualize
# from utils.DataParser import parse_map, parse_replay_pos
# path = '../Data/train_pos/shink/'
# replay_path = list(Path(path).glob('*.osr'))[0]
# map_path = list(Path(path).glob('*.osu'))[0]
# replay_data, hr = parse_replay_pos(replay_path)
# hit_objects, _ = parse_map(map_path, hr)
# song = list(Path(path).glob('audio.*'))[0]
# visualize(hit_objects, replay_data, song)

In [3]:
# Data paths
%load_ext autoreload
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
data_paths = {
    'train': '../Data/train_pos',
    'valid': '../Data/valid_pos'
}
REGENERATE = True

In [4]:
%autoreload 2
# Data parsing
from utils.DatasetGenerator import get_set
datasets = {
    'train': None,
    'valid': None
}

for idx, set in enumerate(data_paths):
    print(f'Converting {set} data')
    data = get_set(data_paths[set], 'pos', REGENERATE)
    
    datasets[set] = data

Converting train data


100%|██████████| 35/35 [00:04<00:00,  7.32it/s]


Converting valid data


100%|██████████| 3/3 [00:01<00:00,  1.90it/s]


In [5]:
# Sequence generation: write windowed sequences for each parsed split to disk,
# then read them back for dataset construction.
import warnings
# Silence the torch FutureWarning about weights_only=False, presumably emitted by
# torch.load inside the project's sequence loaders (files are produced locally).
warnings.filterwarnings("ignore", category=FutureWarning, message=".*weights_only=False.*")
from utils.DatasetGenerator import create_set_sequences, get_set_sequences
from itertools import chain

# Both are forwarded to create_set_sequences; their exact meaning (and the unit
# of `window`) is defined in utils.DatasetGenerator — TODO document there.
N = 5
window = 450
# `split` (not `set`) so the builtin `set` is not shadowed.
for split, split_data in datasets.items():
    # Skip splits with no parsed data (also tolerates a None result).
    if split_data is not None and len(split_data) > 0:
        print(f'Generating {split} sequences')
        create_set_sequences(split_data, N, window, 'pos')

combined_train_pos_input = None
combined_train_pos_target = None
combined_train_pos_object = None
valid_sequences = None
for split, split_path in data_paths.items():
    print(f'Retrieving {split} sequences')
    pos_input_sequences, pos_target_sequences, pos_target_objects = get_set_sequences(split_path, 'pos')
    if split == 'train':
        # Training uses one flat pool: the per-source lists are concatenated.
        combined_train_pos_input = list(chain.from_iterable(pos_input_sequences))
        combined_train_pos_target = list(chain.from_iterable(pos_target_sequences))
        combined_train_pos_object = list(chain.from_iterable(pos_target_objects))
    else:
        # Validation keeps the per-source grouping; the dataset cell indexes
        # these lists by validation sub-folder to build one loader each.
        valid_sequences = [pos_input_sequences, pos_target_sequences, pos_target_objects]

Generating train sequences


100%|██████████| 35/35 [00:14<00:00,  2.40it/s]


Generating valid sequences


100%|██████████| 3/3 [00:03<00:00,  1.27s/it]


Retrieving train sequences


100%|██████████| 35/35 [00:51<00:00,  1.48s/it]


Retrieving valid sequences


100%|██████████| 3/3 [00:04<00:00,  1.51s/it]


In [6]:
# Build (or load from cache) the torch datasets and their DataLoaders.
# `data_paths` and `REGENERATE` come from the configuration cell; they are not
# redefined here so the two definitions cannot silently drift apart.
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from Model.OSUDataset import OsuDataset, pos_collate_fn

# Cached datasets live here. Create the directory up front so torch.save does
# not fail with FileNotFoundError on a fresh checkout.
DATASET_CACHE_DIR = './pos_dataset'
os.makedirs(DATASET_CACHE_DIR, exist_ok=True)

# NOTE: the cache files are full pickled OsuDataset objects, not state_dicts, so
# torch.load needs weights_only=False explicitly (its default is moving to True,
# which would refuse them). Unpickling runs arbitrary code — only load cache
# files produced by this notebook.
BATCH_SIZE_TRAIN = 1024
train_cache = f'{DATASET_CACHE_DIR}/train_dataset.pth'
if os.path.exists(train_cache) and not REGENERATE:
    train_dataset = torch.load(train_cache, weights_only=False)
else:
    train_dataset = OsuDataset(combined_train_pos_input, combined_train_pos_target, combined_train_pos_object)
    torch.save(train_dataset, train_cache)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, drop_last=True, collate_fn=pos_collate_fn, pin_memory=True)

BATCH_SIZE_VALID = 64
valid_loaders = []
# One validation dataset/loader per sub-directory of the validation root.
valid_folders = os.listdir(data_paths['valid'])
valid_data_size = sum(os.path.isdir(os.path.join(data_paths['valid'], folder)) for folder in valid_folders)
for idx in range(valid_data_size):
    valid_cache = f'{DATASET_CACHE_DIR}/valid_dataset_{idx}.pth'
    if os.path.exists(valid_cache) and not REGENERATE:
        valid_dataset = torch.load(valid_cache, weights_only=False)
    else:
        valid_dataset = OsuDataset(valid_sequences[0][idx], valid_sequences[1][idx], valid_sequences[2][idx])
        torch.save(valid_dataset, valid_cache)

    # NOTE(review): drop_last=True discards each validation set's final partial
    # batch, and a set with fewer than BATCH_SIZE_VALID sequences yields an empty
    # loader. Kept as-is in case the model requires full batches — confirm.
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE_VALID, shuffle=False, drop_last=True, collate_fn=pos_collate_fn, pin_memory=True)
    valid_loaders.append(valid_loader)

In [13]:
%autoreload 2
from Model.OSUModel import PositionEncoder, PositionDecoder, OSUModelPos
from torch.nn.utils.rnn import pad_packed_sequence
pos_input_size = 12

pos_hidden_size = 24
pos_num_layers = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoders take (input_size, hidden_size, num_layers)
pos_encoder = PositionEncoder(pos_input_size, pos_hidden_size, pos_num_layers).to(device)

# Decoders takes (hidden_size, num_layers)
pos_decoder = PositionDecoder(pos_hidden_size, pos_num_layers).to(device)

osuModelPos = OSUModelPos(pos_encoder, pos_decoder, device).to(device)

pos_criterion = nn.L1Loss(reduction='none')

pos_optimizer = torch.optim.Adam(
    osuModelPos.parameters(),
    lr = 0.01,
    weight_decay=1e-5
)


pos_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(pos_optimizer, mode='min',
                                                    factor=0.1, patience=15,
                                                    threshold=1e-4)

  
def get_pos_loss(pos_outputs, targets):
    """Mean L1 loss between predicted and target next positions.

    Args:
        pos_outputs: model predictions; must broadcast against
            ``targets[:, 1:, :]`` (presumably (batch_size, 1, 2) — TODO confirm).
        targets: target tensor whose first entry along dim 1 is the previous
            position; it is dropped before comparison.

    Returns:
        Scalar tensor: absolute errors summed over the last (coordinate) axis,
        then averaged over every remaining entry.
    """
    next_positions = targets[:, 1:, :]
    per_coord_error = pos_criterion(pos_outputs, next_positions)
    return per_coord_error.sum(dim=2).mean()

def get_hybrid_loss(pos_outputs, targets, train_objects, time_window=0.05, max_weight=3.0):
    """Next-position L1 loss plus a time-weighted L1 pull toward the hit object.

    Args:
        pos_outputs: predicted positions; presumably (batch_size, 1, 2) — TODO
            confirm against OSUModelPos.
        targets: target tensor whose first entry along dim 1 is the previous
            position (dropped); the rest is compared with ``pos_outputs``.
        train_objects: (batch_size, >=4) tensor. Columns 0-1 are the object
            position, column 2 a signed time offset (its absolute value is
            used), and column 3 a multiplier that enables the object term
            (0 disables it).
        time_window: the object term is active while |time offset| is below
            this value. It ramps linearly from 0 at the window edge to
            ``max_weight`` at offset 0. Must be > 0.
        max_weight: weight of the object term at time offset 0.

    Returns:
        Scalar tensor: mean next-position loss + mean weighted object loss.

    Raises:
        ValueError: if ``time_window`` is not positive.
    """
    if time_window <= 0:
        raise ValueError(f'time_window must be positive, got {time_window}')

    # Next-position term: drop the previous position, L1 summed over coordinates.
    pos_targets = targets[:, 1:, :]
    pos_loss = pos_criterion(pos_outputs, pos_targets).sum(dim=2).mean()

    # Object position, shaped (batch_size, 1, 2) to match pos_outputs.
    object_pos = train_objects[:, :2].unsqueeze(1)
    time_values = torch.abs(train_objects[:, 2])
    valid = train_objects[:, 3]

    # Linear ramp: 0 at |t| >= time_window, up to max_weight at t == 0.
    weights = torch.clamp((time_window - time_values) / time_window, min=0, max=1) * max_weight

    # Per-sample object distance: sum over coordinates, drop the length-1 step axis.
    object_loss = pos_criterion(pos_outputs, object_pos).sum(dim=2).squeeze(1)
    object_loss_mean = (object_loss * weights * valid).mean()

    return pos_loss + object_loss_mean

In [None]:

# Training loop: optimise on the hybrid loss, validate on the plain position
# loss, and stop once ReduceLROnPlateau has decayed the LR below MIN_LR.
num_epochs = 1000
MIN_LR = 1e-5
total_train_len = len(train_loader)
total_valid_len = sum(len(loader) for loader in valid_loaders)

# Both epoch averages divide by these batch counts; fail fast with a clear
# message instead of a ZeroDivisionError after a full training epoch.
if total_train_len == 0:
    raise ValueError('train_loader yields no batches (training set smaller than BATCH_SIZE_TRAIN with drop_last=True?)')
if total_valid_len == 0:
    raise ValueError('valid_loaders yield no batches (no validation folders, or each set smaller than BATCH_SIZE_VALID with drop_last=True?)')

for epoch in range(num_epochs):
    # ---- Training ----
    osuModelPos.train()
    train_loss_pos = 0

    for train_inputs, train_targets, input_lengths, train_objects in train_loader:
        train_inputs, train_targets, train_objects = train_inputs.to(device), train_targets.to(device), train_objects.to(device)

        pos_optimizer.zero_grad()

        train_pos_outputs = osuModelPos(train_inputs, train_targets)

        # Next-position L1 plus the time-weighted L1 to the hit-object position.
        loss_pos = get_hybrid_loss(train_pos_outputs, train_targets, train_objects)
        train_loss_pos += loss_pos.item()

        loss_pos.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(osuModelPos.parameters(), max_norm=1.0)
        pos_optimizer.step()

    # Mean of the per-batch training losses.
    train_loss_pos /= total_train_len

    # ---- Validation ----
    # NOTE: validation uses the plain position loss (no object term), so the
    # reported V loss is not directly comparable with the T loss.
    osuModelPos.eval()
    total_valid_loss_pos = 0

    with torch.no_grad():
        for valid_loader in valid_loaders:
            valid_loss_pos = 0
            for valid_inputs, valid_targets, input_lengths, valid_objects in valid_loader:
                valid_inputs, valid_targets = valid_inputs.to(device), valid_targets.to(device)

                valid_pos_outputs = osuModelPos(valid_inputs, valid_targets)
                loss_pos = get_pos_loss(valid_pos_outputs, valid_targets)
                valid_loss_pos += loss_pos.item()

            total_valid_loss_pos += valid_loss_pos

    # Mean over all validation batches across every validation loader.
    total_valid_loss_pos /= total_valid_len

    pos_scheduler.step(total_valid_loss_pos)
    pos_lr = pos_optimizer.param_groups[0]['lr']

    print(
        f"Epoch [{epoch+1}/{num_epochs}] | "
        f"T Loss Pos: {train_loss_pos:.4f} | "
        f"V Loss Pos: {total_valid_loss_pos:.4f} | "
        f"LR Pos: {pos_lr}"
    )

    # Early stop once the scheduler has decayed the LR below the floor.
    if pos_lr < MIN_LR:
        break

Epoch [1/1000]T Loss Pos: 0.1038 | V Loss Pos: 0.0503 | LR Pos: 0.01 | 
Epoch [2/1000]T Loss Pos: 0.0499 | V Loss Pos: 0.0507 | LR Pos: 0.01 | 
Epoch [3/1000]T Loss Pos: 0.0428 | V Loss Pos: 0.0317 | LR Pos: 0.01 | 
Epoch [4/1000]T Loss Pos: 0.0404 | V Loss Pos: 0.0255 | LR Pos: 0.01 | 
Epoch [5/1000]T Loss Pos: 0.0375 | V Loss Pos: 0.0262 | LR Pos: 0.01 | 
Epoch [6/1000]T Loss Pos: 0.0340 | V Loss Pos: 0.0251 | LR Pos: 0.01 | 
Epoch [7/1000]T Loss Pos: 0.0356 | V Loss Pos: 0.0329 | LR Pos: 0.01 | 
Epoch [8/1000]T Loss Pos: 0.0323 | V Loss Pos: 0.0275 | LR Pos: 0.01 | 
Epoch [9/1000]T Loss Pos: 0.0340 | V Loss Pos: 0.0257 | LR Pos: 0.01 | 
Epoch [10/1000]T Loss Pos: 0.0330 | V Loss Pos: 0.0246 | LR Pos: 0.01 | 
Epoch [11/1000]T Loss Pos: 0.0324 | V Loss Pos: 0.0250 | LR Pos: 0.01 | 
Epoch [12/1000]T Loss Pos: 0.0321 | V Loss Pos: 0.0242 | LR Pos: 0.01 | 
Epoch [13/1000]T Loss Pos: 0.0336 | V Loss Pos: 0.0250 | LR Pos: 0.01 | 
Epoch [14/1000]T Loss Pos: 0.0312 | V Loss Pos: 0.0256 | LR 

In [10]:
# Save only the position model's learned weights (its state_dict). The file name
# embeds VERSION. Note: these are the weights after the last epoch the training
# loop ran, not necessarily the epoch with the best validation loss.
VERSION = '0.9'
checkpoint_path = f'./OSU_model_pos_{VERSION}.pth'
torch.save(osuModelPos.state_dict(), checkpoint_path)