In [1]:
# # Debug
# %load_ext autoreload
# %autoreload 2
# import sys
# from pathlib import Path
# sys.path.append(str(Path.cwd().parent))
# from utils.DatasetGenerator import get_X, get_y, get_pos_sequences, read_sequences
# path = '../Data/train_pos/annoy'
# df_y_pos, hr = get_y(path, 'pos')
# df_X = get_X(path, hr)
# get_pos_sequences([df_X, df_y_pos, path], 10, 500)
# input_seq, target_seq = read_sequences(path, 'pos')
# pass

In [2]:
# # debug
# %load_ext autoreload
# %autoreload 2
# import sys
# from pathlib import Path
# sys.path.append(str(Path.cwd().parent))
# from utils.Visualizer import visualize
# from utils.DataParser import parse_map, parse_replay_pos
# path = '../Data/train_pos/shink/'
# replay_path = list(Path(path).glob('*.osr'))[0]
# map_path = list(Path(path).glob('*.osu'))[0]
# replay_data, hr = parse_replay_pos(replay_path)
# hit_objects, _ = parse_map(map_path, hr)
# song = list(Path(path).glob('audio.*'))[0]
# visualize(hit_objects, replay_data, song)

In [3]:
# Data paths
%load_ext autoreload
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
data_paths = {
    'train': '../Data/train_pos',
    'valid': '../Data/valid_pos'
}
REGENERATE = True

In [4]:
%autoreload 2
# Data parsing
from utils.DatasetGenerator import get_set
datasets = {
    'train': None,
    'valid': None
}

for idx, set in enumerate(data_paths):
    print(f'Converting {set} data')
    data = get_set(data_paths[set], 'pos', REGENERATE)
    
    datasets[set] = data

Converting train data


100%|██████████| 35/35 [00:04<00:00,  8.55it/s]


Converting valid data


100%|██████████| 3/3 [00:01<00:00,  1.91it/s]


In [5]:
# Sequence generation
import warnings
# Silence torch.load's FutureWarning about weights_only=False
# (presumably emitted by the sequence loaders in utils.DatasetGenerator).
warnings.filterwarnings("ignore", category=FutureWarning, message=".*weights_only=False.*")
from utils.DatasetGenerator import create_set_sequences, get_set_sequences
from itertools import chain

# Forwarded to create_set_sequences.
# NOTE(review): document what N and window control (and their units).
N = 3
window = 400
for split, split_data in datasets.items():
    # Only build sequences for splits where get_set produced data.
    if len(split_data) > 0:
        print(f'Generating {split} sequences')
        create_set_sequences(split_data, N, window, 'pos')

combined_train_key_input = None
combined_train_key_target = None
valid_sequences = None
for split, split_path in data_paths.items():
    print(f'Retrieving {split} sequences')
    key_input_sequences, key_target_sequences = get_set_sequences(split_path, 'pos')
    if split == 'train':
        # Flatten the nested sequence lists into one list for a single training dataset.
        combined_train_key_input = list(chain.from_iterable(key_input_sequences))
        combined_train_key_target = list(chain.from_iterable(key_target_sequences))
    else:
        # Keep validation sequences nested; the DataLoader cell indexes them per folder.
        valid_sequences = [key_input_sequences, key_target_sequences]

Generating train sequences


100%|██████████| 35/35 [00:08<00:00,  4.21it/s]


Generating valid sequences


100%|██████████| 3/3 [00:02<00:00,  1.15it/s]


Retrieving train sequences


100%|██████████| 35/35 [00:33<00:00,  1.06it/s]


Retrieving valid sequences


100%|██████████| 3/3 [00:02<00:00,  1.29it/s]


In [6]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from Model.OSUDataset import OsuDataset, pos_collate_fn

# torch.save does not create parent directories; make sure the cache dir exists.
os.makedirs('./pos_dataset', exist_ok=True)

BATCH_SIZE_TRAIN = 1024
if os.path.exists('./pos_dataset/train_dataset.pth') and not REGENERATE:
    # The cache holds a pickled OsuDataset object (not a state_dict), so it needs
    # weights_only=False; torch >= 2.6 defaults to True and would refuse to load it.
    # Only load cache files written by this notebook.
    train_dataset = torch.load('./pos_dataset/train_dataset.pth', weights_only=False)
else:
    train_dataset = OsuDataset(combined_train_key_input, combined_train_key_target)
    torch.save(train_dataset, './pos_dataset/train_dataset.pth')
# drop_last=True: every training batch holds exactly BATCH_SIZE_TRAIN samples.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True,
                          drop_last=True, collate_fn=pos_collate_fn, pin_memory=True)

BATCH_SIZE_VALID = 64
valid_loaders = []
# One validation loader per sub-directory of the validation folder.
valid_folders = os.listdir(data_paths['valid'])
valid_data_size = sum(os.path.isdir(os.path.join(data_paths['valid'], folder)) for folder in valid_folders)
for idx in range(valid_data_size):
    cache_path = f'./pos_dataset/valid_dataset_{idx}.pth'
    if os.path.exists(cache_path) and not REGENERATE:
        valid_dataset = torch.load(cache_path, weights_only=False)
    else:
        valid_dataset = OsuDataset(valid_sequences[0][idx], valid_sequences[1][idx])
        torch.save(valid_dataset, cache_path)

    # NOTE(review): drop_last=True discards up to BATCH_SIZE_VALID-1 samples per
    # folder, and a folder with fewer samples contributes no validation batches.
    valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE_VALID, shuffle=False,
                              drop_last=True, collate_fn=pos_collate_fn, pin_memory=True)
    valid_loaders.append(valid_loader)

In [None]:
%autoreload 2
from Model.OSUModel import PositionEncoder, PositionDecoder, OSUModelPos
from torch.nn.utils.rnn import pad_packed_sequence
pos_input_size = 10

pos_hidden_size = 24
pos_num_layers = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoders take (input_size, hidden_size, num_layers)
pos_encoder = PositionEncoder(pos_input_size, pos_hidden_size, pos_num_layers).to(device)

# Decoders takes (hidden_size, num_layers)
pos_decoder = PositionDecoder(pos_hidden_size, pos_num_layers).to(device)

osuModelPos = OSUModelPos(pos_encoder, pos_decoder, device).to(device)

pos_criterion = nn.L1Loss(reduction='none')

pos_optimizer = torch.optim.Adam(
    osuModelPos.parameters(),
    lr = 0.01,
    weight_decay=1e-5
)


pos_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(pos_optimizer, mode='min',
                                                    factor=0.1, patience=15,
                                                    threshold=1e-4)

  
# def masked_loss_pos(pos_outputs, targets, target_lengths):
#     # Position targets (batch_size, sequence_len - 1, 2)
#     pos_targets = targets[:, 1:, :]

#     # Creating mask for variable length sequences
#     max_len = pos_outputs.size(1)
#     #(batch_size, max_len). True when value is less than target length
#     mask = torch.arange(max_len).unsqueeze(0) < target_lengths.unsqueeze(1) 
#     mask = mask.to(pos_outputs.device)

#     # Position loss with absolute error.
#     pos_loss = pos_criterion(pos_outputs, pos_targets)
#     pos_loss = pos_loss.sum(dim=2)
#     pos_loss = (pos_loss * mask).sum() / mask.sum()

#     return pos_loss
  
def get_pos_loss(pos_outputs, targets, criterion=None):
    """Mean position loss between predicted and target positions.

    Args:
        pos_outputs: model predictions, shaped like ``targets[:, 1:, :]``.
        targets: 3-D target tensor (assumed (batch, steps, 2)); step 0 holds the
            previous position and is dropped before comparison.
        criterion: element-wise loss with ``reduction='none'``. Defaults to the
            notebook-level ``pos_criterion``.

    Returns:
        Scalar tensor: loss summed over the last (coordinate) axis, then averaged
        over batch and predicted steps.
    """
    if criterion is None:
        criterion = pos_criterion

    # Remove the leading "previous position" step from the targets.
    pos_targets = targets[:, 1:, :]

    pos_loss = criterion(pos_outputs, pos_targets)

    # Sum over the coordinate axis, then mean over everything else.
    return pos_loss.sum(dim=2).mean()

def get_weighted_pos_loss(pos_outputs, targets, pos_inputs, criterion=None,
                          time_threshold=0.2, extra_weight=3.0):
    """Per-sample position loss, up-weighted near a type-1 object.

    For each sample, the first (up to) two input objects are considered. The one
    whose time value is closest to zero is selected. If its type value equals 1,
    the sample's loss is scaled by a weight that ramps linearly from 1
    (time >= time_threshold) up to ``extra_weight`` (time <= 0).

    Args:
        pos_outputs: model predictions, shaped like ``targets[:, 1:, :]``.
        targets: 3-D target tensor; step 0 is the previous position and is dropped.
        pos_inputs: PackedSequence of model inputs; unpacked batch-first.
        criterion: element-wise loss with ``reduction='none'``. Defaults to the
            notebook-level ``pos_criterion``.
        time_threshold: time value at or above which no extra weight is applied.
        extra_weight: maximum weight multiplier.

    Returns:
        Scalar tensor: weighted per-sample losses averaged over the batch.
    """
    if criterion is None:
        criterion = pos_criterion
    device = pos_outputs.device

    # Remove the leading "previous position" step from the targets.
    pos_targets = targets[:, 1:, :]
    pos_loss = criterion(pos_outputs, pos_targets)
    # Per-sample loss: sum over coordinates, mean over predicted steps. For one
    # predicted step this equals the former `.squeeze(1)`; for more steps it keeps
    # shape (batch,) instead of mis-broadcasting against `weights`.
    pos_loss = pos_loss.sum(dim=2).mean(dim=1)

    inputs_padded, lengths = pad_packed_sequence(pos_inputs, batch_first=True)
    batch_size = inputs_padded.size(0)

    # Look at the first two objects, or fewer if every sequence is shorter.
    n_first = min(2, inputs_padded.size(1))
    # NOTE(review): feature column 7 is treated as time and column 2 as object type — confirm.
    times_padded = inputs_padded[:, :n_first, 7].to(device)
    object_types_padded = inputs_padded[:, :n_first, 2].to(device)

    # Padding positions (index >= sequence length) get +inf so they never win argmin.
    positions = torch.arange(n_first, device=device).unsqueeze(0)   # (1, n_first)
    valid_mask = positions < lengths.to(device).unsqueeze(1)        # (batch, n_first)
    times_padded = times_padded.masked_fill(~valid_mask, float('inf'))

    # Per sample, pick the object whose time is closest to zero.
    min_time_indices = times_padded.abs().argmin(dim=1)             # (batch,)
    batch_idx = torch.arange(batch_size, device=device)
    times = times_padded[batch_idx, min_time_indices]
    object_types = object_types_padded[batch_idx, min_time_indices]

    # Linear ramp in [0, 1]: 0 at/after the threshold, 1 at/before time 0.
    # NOTE(review): `times` is signed, so any object with negative time gets the
    # full extra weight however far in the past it is — confirm this is intended.
    time_weight = ((time_threshold - times).clamp(min=0.0) / time_threshold).clamp(max=1.0)
    weights = 1.0 + time_weight * (object_types == 1).float() * (extra_weight - 1.0)

    return (pos_loss * weights).mean()

In [8]:
num_epochs = 1000
total_train_len = len(train_loader)
# Validation loss is averaged over every batch of every per-folder loader.
total_valid_len = sum(len(loader) for loader in valid_loaders)
# With drop_last=True a split smaller than one batch yields no batches at all;
# fail fast here instead of a ZeroDivisionError after the first training epoch.
if total_train_len == 0 or total_valid_len == 0:
    raise ValueError(
        f'Empty data loader(s): {total_train_len} train batches, '
        f'{total_valid_len} validation batches'
    )

for epoch in range(num_epochs):
    # Teacher forcing decays linearly from 1.0 to 0.0 over the first 20 epochs.
    teacher_forcing_ratio = max(1 - (epoch / 20), 0)

    # Training
    osuModelPos.train()
    train_loss_pos = 0
    # for train_inputs, train_targets, input_lengths, target_lengths in train_loader:
    for train_inputs, train_targets, _input_lengths in train_loader:
        train_inputs, train_targets = train_inputs.to(device), train_targets.to(device)

        pos_optimizer.zero_grad()

        # Forward propagate
        train_pos_outputs = osuModelPos(train_inputs, train_targets, teacher_forcing_ratio)

        # Alternative losses:
        # loss_pos = get_weighted_pos_loss(train_pos_outputs, train_targets, train_inputs)
        # loss_pos = masked_loss_pos(train_pos_outputs, train_targets, target_lengths)
        loss_pos = get_pos_loss(train_pos_outputs, train_targets)
        train_loss_pos += loss_pos.item()

        loss_pos.backward()
        # Clip the global gradient norm to 1.0 before updating the weights.
        torch.nn.utils.clip_grad_norm_(osuModelPos.parameters(), max_norm=1.0)
        pos_optimizer.step()

    train_loss_pos /= total_train_len

    # Validation: teacher forcing ratio 0, no gradient tracking.
    osuModelPos.eval()
    total_valid_loss_pos = 0
    with torch.no_grad():
        for valid_loader in valid_loaders:
            # for valid_inputs, valid_targets, input_lengths, target_lengths in valid_loader:
            for valid_inputs, valid_targets, _input_lengths in valid_loader:
                valid_inputs, valid_targets = valid_inputs.to(device), valid_targets.to(device)

                valid_pos_outputs = osuModelPos(valid_inputs, valid_targets, 0)

                # loss_pos = get_weighted_pos_loss(valid_pos_outputs, valid_targets, valid_inputs)
                # loss_pos = masked_loss_pos(valid_pos_outputs, valid_targets, target_lengths)
                loss_pos = get_pos_loss(valid_pos_outputs, valid_targets)
                total_valid_loss_pos += loss_pos.item()

    total_valid_loss_pos /= total_valid_len

    # ReduceLROnPlateau monitors the mean validation loss.
    pos_scheduler.step(total_valid_loss_pos)
    pos_lr = pos_optimizer.param_groups[0]['lr']

    # Printing info
    print(f"Epoch [{epoch+1}/{num_epochs}] | "
          f"T Loss Pos: {train_loss_pos:.4f} | "
          f"V Loss Pos: {total_valid_loss_pos:.4f} | "
          f"LR Pos: {pos_lr} | "
          f"TFR: {teacher_forcing_ratio:.2f}")


Epoch [1/1000]T Loss Pos: 0.0832 | V Loss Pos: 0.0422 | LR Pos: 0.01 | TFR: 1.00
Epoch [2/1000]T Loss Pos: 0.0337 | V Loss Pos: 0.0453 | LR Pos: 0.01 | TFR: 0.95
Epoch [3/1000]T Loss Pos: 0.0265 | V Loss Pos: 0.0298 | LR Pos: 0.01 | TFR: 0.90
Epoch [4/1000]T Loss Pos: 0.0234 | V Loss Pos: 0.0277 | LR Pos: 0.01 | TFR: 0.85
Epoch [5/1000]T Loss Pos: 0.0238 | V Loss Pos: 0.0243 | LR Pos: 0.01 | TFR: 0.80
Epoch [6/1000]T Loss Pos: 0.0225 | V Loss Pos: 0.0239 | LR Pos: 0.01 | TFR: 0.75
Epoch [7/1000]T Loss Pos: 0.0223 | V Loss Pos: 0.0237 | LR Pos: 0.01 | TFR: 0.70
Epoch [8/1000]T Loss Pos: 0.0214 | V Loss Pos: 0.0230 | LR Pos: 0.01 | TFR: 0.65
Epoch [9/1000]T Loss Pos: 0.0207 | V Loss Pos: 0.0235 | LR Pos: 0.01 | TFR: 0.60
Epoch [10/1000]T Loss Pos: 0.0219 | V Loss Pos: 0.0237 | LR Pos: 0.01 | TFR: 0.55
Epoch [11/1000]T Loss Pos: 0.0201 | V Loss Pos: 0.0257 | LR Pos: 0.01 | TFR: 0.50
Epoch [12/1000]T Loss Pos: 0.0200 | V Loss Pos: 0.0226 | LR Pos: 0.01 | TFR: 0.45
Epoch [13/1000]T Loss Pos

KeyboardInterrupt: 

In [9]:
# Persist only the trained weights (the state_dict), tagged with a version string.
VERSION = '0.8'
torch.save(obj=osuModelPos.state_dict(), f=f'./OSU_model_pos_{VERSION}.pth')