In [15]:
# # Debug
# %load_ext autoreload
# %autoreload 2
# import sys
# from pathlib import Path
# sys.path.append(str(Path.cwd().parent))
# from utils.DatasetGenerator import get_X, get_y, get_pos_sequences, read_sequences
# path = '../Data/train_pos/annoy'
# df_y_pos, hr = get_y(path, 'pos')
# df_X = get_X(path, hr)
# get_pos_sequences([df_X, df_y_pos, path], 10, 500)
# input_seq, target_seq = read_sequences(path, 'pos')
# pass

In [16]:
# debug
# %load_ext autoreload
# %autoreload 2
# import sys
# from pathlib import Path
# sys.path.append(str(Path.cwd().parent))
# from utils.Visualizer import visualize
# from utils.DataParser import parse_map, parse_replay_pos
# path = '../Data/train_pos/ty/'
# replay_path = list(Path(path).glob('*.osr'))[0]
# map_path = list(Path(path).glob('*.osu'))[0]
# replay_data, hr = parse_replay_pos(replay_path)
# hit_objects, _ = parse_map(map_path, hr)
# song = list(Path(path).glob('audio.*'))[0]
# visualize(hit_objects, replay_data, song)

In [17]:
# Data paths
%load_ext autoreload
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
data_paths = {
    'train': '../Data/train_pos',
    'valid': '../Data/valid_pos'
}
REGENERATE = True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
%autoreload 2
# Data parsing
from utils.DatasetGenerator import get_set
datasets = {
    'train': None,
    'valid': None
}

for idx, set in enumerate(data_paths):
    print(f'Converting {set} data')
    data = get_set(data_paths[set], 'pos', REGENERATE)
    
    datasets[set] = data

Converting train data


100%|██████████| 31/31 [00:04<00:00,  6.83it/s]


Converting valid data


100%|██████████| 3/3 [00:01<00:00,  1.70it/s]


In [19]:
# Sequence generation
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message=".*weights_only=False.*")
from utils.DatasetGenerator import create_set_sequences, get_set_sequences
from itertools import chain

N = 3
window = 400
for idx, set in enumerate(datasets):
    if len(datasets[set]) > 0:
        print(f'Generating {set} sequences')
        create_set_sequences(datasets[set], N, window, 'pos')

combined_train_key_input = None
combined_train_key_target = None
valid_sequences = None
for idx, set in enumerate(data_paths):
    print(f'Retrieving {set} sequences')
    key_input_sequences, key_target_sequences = get_set_sequences(data_paths[set], 'pos')
    if set == 'train':
        combined_train_key_input = list(chain.from_iterable(key_input_sequences))
        combined_train_key_target = list(chain.from_iterable(key_target_sequences))
    else:
        valid_sequences = [key_input_sequences, key_target_sequences]

Generating train sequences


100%|██████████| 31/31 [00:07<00:00,  4.00it/s]


Generating valid sequences


100%|██████████| 3/3 [00:02<00:00,  1.07it/s]


Retrieving train sequences


100%|██████████| 31/31 [00:30<00:00,  1.02it/s]


Retrieving valid sequences


100%|██████████| 3/3 [00:02<00:00,  1.20it/s]


In [20]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from Model.OSUDataset import OsuDataset, collate_fn
data_paths = {
    'train': '../Data/train_pos',
    'valid': '../Data/valid_pos'
}
BATCH_SIZE_TRAIN = 1024
if os.path.exists('./pos_dataset/train_dataset.pth') and not REGENERATE:
    train_dataset = torch.load('./pos_dataset/train_dataset.pth')
else:
    train_dataset = OsuDataset(combined_train_key_input, combined_train_key_target)
    torch.save(train_dataset, './pos_dataset/train_dataset.pth')
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True, drop_last=True, collate_fn=collate_fn, pin_memory=True)

BATCH_SIZE_VALID = 64
valid_loaders = []
valid_folders = os.listdir(data_paths['valid'])
valid_data_size = sum(os.path.isdir(os.path.join(data_paths['valid'], folder)) for folder in valid_folders)
for idx in range(valid_data_size):
    if os.path.exists(f'./pos_dataset/valid_dataset_{idx}.pth') and not REGENERATE:
        valid_dataset = torch.load(f'./pos_dataset/valid_dataset_{idx}.pth')
    else:
        valid_dataset = valid_dataset = OsuDataset(valid_sequences[0][idx], valid_sequences[1][idx])
        torch.save(valid_dataset, f'./pos_dataset/valid_dataset_{idx}.pth')
        
    valid_loader =  DataLoader(valid_dataset, batch_size=BATCH_SIZE_VALID, shuffle=False, drop_last=True, collate_fn=collate_fn, pin_memory=True)
    valid_loaders.append(valid_loader)

In [21]:
%autoreload 2
from Model.OSUModel import PositionEncoder, PositionDecoder, OSUModelPos
from torch.nn.utils.rnn import pad_packed_sequence
pos_input_size = 10

pos_hidden_size = 24
pos_num_layers = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Encoders take (input_size, hidden_size, num_layers)
pos_encoder = PositionEncoder(pos_input_size, pos_hidden_size, pos_num_layers).to(device)

# Decoders takes (hidden_size, num_layers)
pos_decoder = PositionDecoder(pos_hidden_size, pos_num_layers).to(device)

osuModelPos = OSUModelPos(pos_encoder, pos_decoder, device).to(device)

pos_criterion = nn.L1Loss(reduction='none')

pos_optimizer = torch.optim.Adam(
    osuModelPos.parameters(),
    lr = 0.01,
    weight_decay=1e-5
)


pos_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(pos_optimizer, mode='min',
                                                    factor=0.1, patience=15,
                                                    threshold=1e-4)

# def masked_loss_pos(pos_inputs, pos_outputs, targets, target_lengths, threshold = 0.20, base_scaling_factor = 2.0):
#   # Position targets: skip the first target to align with pos_outputs
#     pos_targets = targets[:, 1:, :]  # Shape: (batch_size, seq_len - 1, 2)
#     target_lengths = target_lengths.to(pos_outputs.device)
#     # Create mask for variable length sequences
#     max_len = pos_outputs.size(1)  # seq_len - 1
#     mask = torch.arange(max_len, device=pos_outputs.device).unsqueeze(0) < target_lengths.unsqueeze(1)
#     # Shape: (batch_size, seq_len - 1)

#     # Compute the position loss using L1 loss without reduction
#     pos_loss = pos_criterion(pos_outputs, pos_targets)  # Shape: (batch_size, seq_len - 1, 2)
#     pos_loss = pos_loss.sum(dim=2)  # Sum over position dimensions (x and y), shape: (batch_size, seq_len - 1)

#     # Unpack pos_inputs to get time_values
#     padded_inputs, input_lengths = pad_packed_sequence(pos_inputs, batch_first=True)
#     # padded_inputs shape: (batch_size, seq_len, 10)

#     # Extract time values from column index 7 (8th column)
#     # We are interested in the first two entries of each sequence
#     time_values = padded_inputs[:, :2, 7]  # Shape: (batch_size, 2)

#     # Compute absolute time values
#     abs_time_values = torch.abs(time_values)  # Shape: (batch_size, 2)
    
#     # Find the smallest absolute time value for each sample in the batch
#     min_abs_time_values, _ = torch.min(abs_time_values, dim=1)  # Shape: (batch_size,)
    
#     # Compute scaling factors based on time_values
#     scaling_factor = torch.ones_like(time_values, device=pos_outputs.device)  # Initialize with ones
#     time_mask = time_values <= threshold  # Boolean mask where time_value <= threshold

#     # Compute scaling factors for time_values <= threshold
#     scaling_factor[time_mask] = 1.0 + (threshold - time_values[time_mask]) * ((base_scaling_factor - 1.0) / threshold)
#     # Ensure scaling_factor is at least 1.0

#     # Apply scaling factors to pos_loss
#     pos_loss = pos_loss * scaling_factor

#     # Apply the sequence mask
#     pos_loss = (pos_loss * mask).sum() / mask.sum()

#     return pos_loss
  
def masked_loss_pos(pos_outputs, targets, target_lengths):
    # Position targets (batch_size, sequence_len - 1, 2)
    pos_targets = targets[:, 1:, :]

    # Creating mask for variable length sequences
    max_len = pos_outputs.size(1)
    #(batch_size, max_len). True when value is less than target length
    mask = torch.arange(max_len).unsqueeze(0) < target_lengths.unsqueeze(1) 
    mask = mask.to(pos_outputs.device)

    # Position loss with absolute error.
    pos_loss = pos_criterion(pos_outputs, pos_targets)
    pos_loss = pos_loss.sum(dim=2)
    pos_loss = (pos_loss * mask).sum() / mask.sum()

    return pos_loss
  
# def get_pos_loss(pos_outputs, targets):
#     # Position targets (batch_size, 2, 2)
#     # Removing previous positions from targets
#     pos_targets = targets[:, 1:, :]  

#     # Position loss with absolute error
#     pos_loss = pos_criterion(pos_outputs, pos_targets)
    
#     # Sum the loss along the last dimension and compute the mean over the batch
#     pos_loss = pos_loss.sum(dim=2).mean()

#     return pos_loss

In [22]:
num_epochs = 1000
total_train_len = len(train_loader)
total_valid_len = 0
for valid_loader in valid_loaders:
    total_valid_len += len(valid_loader)

for epoch in range(num_epochs):
    osuModelPos.train()
    train_loss_pos = 0
    teacher_forcing_ratio = max(1 - (epoch / 20), 0)
    
    # Training
    for train_inputs, train_targets, input_lengths, target_lengths in train_loader:
        train_inputs, train_targets = train_inputs.to(device), train_targets.to(device)
        
        pos_optimizer.zero_grad()
        
        # Forward propagate
        train_pos_outputs = osuModelPos(train_inputs, train_targets, teacher_forcing_ratio)
        
        # Compute losses
        # loss_pos = get_pos_loss(train_pos_outputs, train_targets)
        loss_pos = masked_loss_pos(train_pos_outputs, train_targets, target_lengths)
        
        # Accumulate losses
        train_loss_pos += loss_pos.item()
        
        # Position encoder decoder back propagation
        loss_pos.backward()
        
        # Gradient clipping 
        torch.nn.utils.clip_grad_norm_(osuModelPos.parameters(), max_norm=1.0)
        # Updating weights
        pos_optimizer.step()
        
    train_loss_pos /= total_train_len
    
    # Validation
    osuModelPos.eval()
    total_valid_loss_pos = 0
    
    with torch.no_grad():
        for valid_loader in valid_loaders:
            valid_loss_pos = 0
            valid_loss_key = 0
            for valid_inputs, valid_targets, input_lengths, target_lengths in valid_loader:
                valid_inputs, valid_targets = valid_inputs.to(device), valid_targets.to(device)

                valid_pos_outputs = osuModelPos(valid_inputs, valid_targets, 0)
                
                # loss_pos = get_pos_loss(valid_pos_outputs, valid_targets)
                loss_pos = masked_loss_pos(valid_pos_outputs, valid_targets, target_lengths)
                valid_loss_pos += loss_pos.item()
                
            total_valid_loss_pos += valid_loss_pos
            
    total_valid_loss_pos /= total_valid_len
    
    pos_scheduler.step(total_valid_loss_pos)
    
    pos_lr = pos_optimizer.param_groups[0]['lr']
    
    # Printing info
    print(f"Epoch [{epoch+1}/{num_epochs}]"
    f"T Loss Pos: {train_loss_pos:.4f} | "
    f"V Loss Pos: {total_valid_loss_pos:.4f} | "
    f"LR Pos: {pos_lr} | "
    f"TFR: {teacher_forcing_ratio:.2f}")


Epoch [1/1000]T Loss Pos: 0.0807 | V Loss Pos: 0.0367 | LR Pos: 0.01 | TFR: 1.00
Epoch [2/1000]T Loss Pos: 0.0309 | V Loss Pos: 0.0311 | LR Pos: 0.01 | TFR: 0.95
Epoch [3/1000]T Loss Pos: 0.0254 | V Loss Pos: 0.0336 | LR Pos: 0.01 | TFR: 0.90
Epoch [4/1000]T Loss Pos: 0.0238 | V Loss Pos: 0.0250 | LR Pos: 0.01 | TFR: 0.85
Epoch [5/1000]T Loss Pos: 0.0215 | V Loss Pos: 0.0251 | LR Pos: 0.01 | TFR: 0.80
Epoch [6/1000]T Loss Pos: 0.0211 | V Loss Pos: 0.0257 | LR Pos: 0.01 | TFR: 0.75
Epoch [7/1000]T Loss Pos: 0.0210 | V Loss Pos: 0.0244 | LR Pos: 0.01 | TFR: 0.70
Epoch [8/1000]T Loss Pos: 0.0192 | V Loss Pos: 0.0233 | LR Pos: 0.01 | TFR: 0.65
Epoch [9/1000]T Loss Pos: 0.0198 | V Loss Pos: 0.0235 | LR Pos: 0.01 | TFR: 0.60
Epoch [10/1000]T Loss Pos: 0.0203 | V Loss Pos: 0.0265 | LR Pos: 0.01 | TFR: 0.55
Epoch [11/1000]T Loss Pos: 0.0202 | V Loss Pos: 0.0230 | LR Pos: 0.01 | TFR: 0.50
Epoch [12/1000]T Loss Pos: 0.0182 | V Loss Pos: 0.0243 | LR Pos: 0.01 | TFR: 0.45
Epoch [13/1000]T Loss Pos

KeyboardInterrupt: 

In [23]:
VERSION = '0.7'
torch.save(osuModelPos.state_dict(), f'./OSU_model_pos_{VERSION}.pth')     