In [17]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
from tqdm import tqdm
import torch
from OSUDataset import OsuDataset, pos_collate_fn, key_collate_fn
from utils.DatasetGenerator import get_X, get_pos_sequences, get_key_sequences
from torch.utils.data import DataLoader

In [None]:
import warnings
# Silence the FutureWarning torch.load emits about its weights_only=False default.
warnings.filterwarnings("ignore", category=FutureWarning, message=".*weights_only=False.*")
from OSUModel import OSUModelPos, OSUModelKey, PositionEncoder, KeypressEncoder, PositionDecoder, KeypressDecoder

# Checkpoint versions; they select which .pth files are loaded below and are
# reused later when naming the exported replay.
pos_version = '0.14'
key_version = '0.3'

# Constructor arguments for the position encoder/decoder.
# NOTE(review): presumably these must match the values used when the checkpoint
# was trained, otherwise load_state_dict will reject it -- confirm.
pos_input_size = 13
pos_hidden_size = 32
pos_num_layers = 2

# Constructor arguments for the keypress encoder/decoder (same caveat as above).
key_input_size = 4
key_hidden_size = 16
key_num_layers = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
# If these files are plain state_dicts, consider passing weights_only=True.
# map_location=device lets a checkpoint saved on a GPU load on the CPU fallback
# selected above instead of failing while deserialising CUDA tensors.
eval_pos_encoder = PositionEncoder(pos_input_size, pos_hidden_size, pos_num_layers).to(device)
eval_pos_decoder = PositionDecoder(pos_hidden_size, pos_num_layers).to(device)
eval_model_pos = OSUModelPos(eval_pos_encoder, eval_pos_decoder, device).to(device)
eval_model_pos.load_state_dict(
    torch.load(f'./OSU_model_pos_{pos_version}.pth', map_location=device)
)
eval_model_pos.eval()

eval_key_encoder = KeypressEncoder(key_input_size, key_hidden_size, key_num_layers).to(device)
eval_key_decoder = KeypressDecoder(key_hidden_size, key_num_layers).to(device)
eval_model_key = OSUModelKey(eval_key_encoder, eval_key_decoder, device).to(device)
eval_model_key.load_state_dict(
    torch.load(f'./OSU_model_key_{key_version}.pth', map_location=device)
)
# Last expression of the cell: switches to eval mode and displays the module repr.
eval_model_key.eval()

OSUModelKey(
  (key_encoder): KeypressEncoder(
    (lstm): LSTM(4, 16, num_layers=2, batch_first=True)
  )
  (key_decoder): KeypressDecoder(
    (lstm): LSTM(2, 16, num_layers=2, batch_first=True)
    (key): Linear(in_features=16, out_features=2, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [19]:
# Position Dataset
# Builds the position-model inputs for a single beatmap folder via PositionData.
from utils.dataparser.PositionData import PositionData

react_time = 500
path = '../Data/unseen/Candyyyland'
hr = False
set_paths = dict(
    train='../Data/train_pos',
    valid='../Data/valid_pos',
)

# NOTE(review): the positional arguments 8 / 500 / True are defined by
# PositionData, which is not visible here -- confirm their meaning there.
pdata = PositionData(set_paths, 8, 500, True)
X_test = pdata.get_X(path, hr)
start_time = X_test['time'].min() - react_time
pos_input_seq, pos_time_steps, _ = pdata.generate_one(path)

# No ground-truth targets at inference time: every sample shares one zero
# placeholder tensor of shape (1, 2).
dummy_pos_target = torch.zeros((1, 2))
pos_target_seq = [dummy_pos_target] * len(pos_input_seq)

pos_dataset = OsuDataset(pos_input_seq, pos_target_seq)
pos_loader = DataLoader(
    pos_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=pos_collate_fn,
    pin_memory=True,
)

In [20]:
# Position prediction
# Sequential rollout: each step receives the previous step's model output, both
# as the first decoder target and as the reference point for the extra
# distance features appended to the input.
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

position_pred = []
with torch.no_grad():
    # Seed the rollout with the cursor at the centre (0.5, 0.5); shape (1, 1, 2).
    prev_pred = torch.tensor([[[0.5, 0.5]]], device=device)

    progress = tqdm(pos_loader, total=len(pos_loader), desc="Predicting")
    for packed_inputs, placeholder_target, _input_length, _target_object in progress:
        packed_inputs = packed_inputs.to(device)
        placeholder_target = placeholder_target.to(device)

        # Unpack so the per-step features can be extended.
        # padded: (batch_size, seq_len, feature_size)
        padded, lengths = pad_packed_sequence(packed_inputs, batch_first=True)

        # Distance of feature columns 0 and 1 from the previous prediction.
        # Slicing with 0:1 / 1:2 keeps the trailing dimension, so both offsets
        # are already (batch_size, seq_len, 1).
        last_xy = prev_pred.squeeze(0).squeeze(0)  # (2,)
        offset_x = padded[:, :, 0:1] - last_xy[0]
        offset_y = padded[:, :, 1:2] - last_xy[1]

        # (batch_size, seq_len, feature_size + 2), re-packed for the model.
        augmented = torch.cat((padded, offset_x, offset_y), dim=2)
        repacked = pack_padded_sequence(augmented, lengths, batch_first=True, enforce_sorted=False)

        # Previous prediction goes in front of the placeholder target.
        decoder_target = torch.cat((prev_pred, placeholder_target), dim=1)

        # Third argument is 0, as in every inference call in this notebook.
        pos_output = eval_model_pos(repacked, decoder_target, 0)
        prev_pred = pos_output

        position_pred.append(pos_output.flatten().tolist())

Predicting: 100%|██████████| 12499/12499 [00:19<00:00, 649.02it/s]


In [21]:
# Keypress Dataset
# Builds keypress-model inputs for the same beatmap (X_test / path from the
# position dataset cell).
N = 10
react_time = 500
key_input_sequence, key_time_steps, key_end_times = get_key_sequences(
    [X_test, None, path], N, react_time
)

# Targets are unknown at inference time; share one zero placeholder of shape (1, 2).
dummy_key_target = torch.zeros((1, 2))
key_target_sequence = [dummy_key_target] * len(key_input_sequence)

key_dataset = OsuDataset(key_input_sequence, key_target_sequence)
key_loader = DataLoader(
    key_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=key_collate_fn,
    pin_memory=True,
)

In [22]:
# Keypress prediction
# Sequential rollout: the previous model output is prepended to the placeholder
# target of the next step.
keypress_pred = []
with torch.no_grad():
    # Seed with key1 selected: (1.0, 0.0), shape (1, 1, 2).
    prev_pred = torch.tensor([[[1.0, 0.0]]], device=device)

    progress = tqdm(key_loader, total=len(key_loader), desc="Predicting")
    for key_inputs, placeholder_target, _input_length, _target_length in progress:
        key_inputs = key_inputs.to(device)
        placeholder_target = placeholder_target.to(device)

        decoder_target = torch.cat((prev_pred, placeholder_target), dim=1)
        key_output = eval_model_key(key_inputs, decoder_target, 0)
        prev_pred = key_output

        # Keep only the index of the largest output along dim 2.
        keypress_pred.append(key_output.argmax(dim=2).flatten().tolist())

Predicting: 100%|██████████| 1105/1105 [00:01<00:00, 981.10it/s]


In [23]:
# Post Processing
import numpy as np
from scipy.interpolate import interp1d

# Scale the first two prediction columns by 512 and 384.
# NOTE(review): presumably this maps normalised coordinates to the osu!
# playfield size -- confirm.
positions = np.array(position_pred)
positions[:, :2] *= (512, 384)

# positions_matrix columns: [time, x, y]; times are truncated to integers.
pos_time_column = np.array(pos_time_steps).astype(int).reshape(-1, 1)
positions_matrix = np.hstack((pos_time_column, positions))

# Shift predicted key indices by one so that 0 stays free for "no key".
keypresses = np.array(keypress_pred) + 1

# keypresses_matrix columns: [start time, end time, key code].
key_time_column = np.array(key_time_steps).astype(int).reshape(-1, 1)
key_end_time_column = np.array(key_end_times).astype(int).reshape(-1, 1)
keypresses_matrix = np.hstack((key_time_column, key_end_time_column, keypresses))


def interpolate_data(data, gap_threshold=25, sampling_interval=8):
    """Resample a ``[time, x, y]`` track onto a regular integer time grid.

    The track is split wherever two consecutive samples are more than
    ``gap_threshold`` apart; each segment is interpolated on its own, so no
    values are invented inside a gap.

    Parameters
    ----------
    data : array-like, shape (N, >=3)
        Column 0 is time (milliseconds), columns 1 and 2 are x and y.
    gap_threshold : float, optional
        Time difference (milliseconds) above which a new segment starts.
    sampling_interval : int, optional
        Spacing of the output grid in milliseconds (8 ms is roughly 120 Hz).

    Returns
    -------
    numpy.ndarray, shape (M, 3)
        Rows of ``[time, x, y]``. An empty ``(0, 3)`` array is returned when
        the input is empty or no segment contains a grid point.

    Raises
    ------
    ValueError
        If ``sampling_interval`` is not positive.
    """
    if sampling_interval <= 0:
        raise ValueError("sampling_interval must be positive")

    data = np.asarray(data)
    if data.shape[0] == 0:
        return np.empty((0, 3))

    time = data[:, 0]
    x = data[:, 1]
    y = data[:, 2]

    # Index i in gap_indices means the gap lies between samples i and i + 1.
    gap_indices = np.where(np.diff(time) > gap_threshold)[0]
    segment_starts = np.insert(gap_indices + 1, 0, 0)
    segment_ends = np.append(gap_indices, len(time) - 1)

    pieces = []
    for start_idx, end_idx in zip(segment_starts, segment_ends):
        raw_time = time[start_idx:end_idx + 1]

        # Integer grid from the segment's first to last timestamp (inclusive).
        grid_start = int(np.ceil(raw_time[0]))
        grid_end = int(np.floor(raw_time[-1]))
        new_time = np.arange(grid_start, grid_end + 1, sampling_interval)
        if len(new_time) == 0:
            continue  # no grid point falls inside this segment

        # Drop repeated timestamps (keeping the first sample of each): cubic
        # interpolation rejects duplicate x values.
        segment_time, first_idx = np.unique(raw_time, return_index=True)
        segment_x = x[start_idx:end_idx + 1][first_idx]
        segment_y = y[start_idx:end_idx + 1][first_idx]

        if len(segment_time) == 1:
            # A lone sample cannot be interpolated; emit it as-is.
            new_x = np.full(len(new_time), segment_x[0], dtype=float)
            new_y = np.full(len(new_time), segment_y[0], dtype=float)
        else:
            # Cubic needs at least four points; fall back to linear otherwise.
            kind = 'cubic' if len(segment_time) >= 4 else 'linear'
            new_x = interp1d(segment_time, segment_x, kind=kind)(new_time)
            new_y = interp1d(segment_time, segment_y, kind=kind)(new_time)

        pieces.append(np.column_stack((new_time, new_x, new_y)))

    if not pieces:
        return np.empty((0, 3))
    return np.vstack(pieces)

def merge_positions_keypresses(positions_matrix, keypresses_matrix):
    """Label position rows with key codes, inserting rows for unmatched presses.

    Parameters
    ----------
    positions_matrix : numpy.ndarray, shape (N, 4)
        Rows of ``[time, x, y, keycode]`` sorted by time. The input is not
        modified; a copy is labelled and returned.
    keypresses_matrix : iterable of ``(keypress_time, end_time, keycode)``
        ``end_time == -1`` marks an immediate release. Otherwise every row
        with ``keypress_time <= time < end_time`` receives ``keycode``.

    Returns
    -------
    numpy.ndarray
        The labelled copy. When no row exists at exactly ``keypress_time`` a
        new row is inserted there, reusing the x/y of the preceding row (or of
        the first row when there is none). The result is sorted by time.
        Events are applied in order, so a later event overrides an earlier one
        on the rows they share.
    """
    # Work on a copy so the caller's array is never partially updated.
    merged = np.array(positions_matrix, copy=True)
    times = merged[:, 0]
    new_rows = []

    for keypress_time, end_time, keycode in keypresses_matrix:
        keypress_time = float(keypress_time)
        end_time = float(end_time)
        keycode = int(keycode)

        # First row whose time is >= keypress_time.
        idx_start = np.searchsorted(times, keypress_time, side='left')
        has_exact_row = idx_start < len(times) and times[idx_start] == keypress_time

        if end_time == -1:
            # Immediate release: label only an exactly matching row.
            if has_exact_row:
                merged[idx_start, 3] = keycode
        else:
            # Held key: label every row in [keypress_time, end_time).
            idx_end = np.searchsorted(times, end_time, side='left')
            merged[idx_start:idx_end, 3] = keycode

        if not has_exact_row:
            # No sample at the press time: synthesise one from the previous row.
            prev_idx = max(idx_start - 1, 0)
            new_rows.append([keypress_time, merged[prev_idx, 1], merged[prev_idx, 2], keycode])

    if new_rows:
        merged = np.vstack((merged, np.array(new_rows)))
        # Stable sort keeps the relative order of rows that share a timestamp.
        merged = merged[np.argsort(merged[:, 0], kind='stable')]

    return merged

def _merge_pos_key(positions_matrix, keypresses_matrix):
    """Write key codes into column 3 of ``positions_matrix`` without adding rows.

    Parameters
    ----------
    positions_matrix : numpy.ndarray, shape (N, 4)
        Rows of ``[time, x, y, keycode]`` sorted by time. Modified in place.
    keypresses_matrix : iterable of ``(keypress_time, end_time, keycode)``
        ``end_time == -1`` marks a single press, which is assigned to the row
        closest in time (a tie goes to the later row). Otherwise every row
        with ``keypress_time <= time <= end_time`` is labelled; if no row
        falls in that range the row closest to ``keypress_time`` is used.

    Returns
    -------
    numpy.ndarray
        The same ``positions_matrix`` object, after labelling. Events are
        applied in the order given, so when events overlap the later one wins.
        If there are no position rows the array is returned unchanged.
    """
    times_array = positions_matrix[:, 0]
    if len(times_array) == 0:
        return positions_matrix  # nothing to label

    def find_nearest_index(target_time):
        """Return the index of the row whose time is closest to target_time."""
        idx = np.searchsorted(times_array, target_time)
        if idx == 0:
            return 0
        if idx == len(times_array):
            return len(times_array) - 1
        closer_to_prev = (
            abs(times_array[idx - 1] - target_time) < abs(times_array[idx] - target_time)
        )
        return idx - 1 if closer_to_prev else idx

    # Apply updates immediately and in event order. Collecting them in a set
    # first would lose that order and make overlapping events resolve arbitrarily.
    for keypress_time, end_time, keycode in keypresses_matrix:
        keypress_time = float(keypress_time)
        end_time = float(end_time)
        keycode = int(keycode)

        if end_time != -1:
            idx_start = np.searchsorted(times_array, keypress_time, side='left')
            idx_end = np.searchsorted(times_array, end_time, side='right')
            if idx_end > idx_start:
                positions_matrix[idx_start:idx_end, 3] = keycode
                continue

        # Single press, or a hold that contains no row: use the nearest row.
        positions_matrix[find_nearest_index(keypress_time), 3] = keycode

    return positions_matrix

# Resample the predicted cursor track onto a regular grid.
# (To skip resampling, use positions_matrix directly instead.)
positions_matrix_interpolated = interpolate_data(positions_matrix)

# Append a zero-filled key column, then let the keypress predictions fill it.
num_rows = len(positions_matrix_interpolated)
dummy_keys_column = np.zeros((num_rows, 1))
positions_matrix_interpolated = np.hstack((positions_matrix_interpolated, dummy_keys_column))
replay_predictions = _merge_pos_key(positions_matrix_interpolated, keypresses_matrix)


In [24]:
# Exporting replay
from osrparse import Replay
from osrparse.utils import ReplayEventOsu, Key
from pathlib import Path

# Pick the template replay from the beatmap folder. Files written by earlier
# runs of this cell ("predictions_*.osr") live in the same folder, so they are
# excluded; sorting makes the choice deterministic when several replays exist.
template_replays = sorted(
    candidate
    for candidate in Path(path).glob('*.osr')
    if not candidate.name.startswith('predictions_')
)
if not template_replays:
    raise FileNotFoundError(f"No template .osr replay found in {path}")
replay_path = template_replays[0]
replay = Replay.from_path(replay_path)

# Each event stores the time elapsed since the previous event; prepending 0
# makes the first delta equal to the first absolute timestamp.
time_deltas = np.diff(replay_predictions[:, 0], prepend=0)

# Maps the value in column 3 of replay_predictions to the raw value passed to
# Key(...). NOTE(review): index 0 is "no key" and 1/2 come from the keypress
# predictions; index 3 is not produced by the cells visible here -- confirm.
keymap = [0, 1, 2, 11]
replay_data = [
    ReplayEventOsu(int(delta), float(row[1]), float(row[2]), Key(keymap[int(row[3])]))
    for delta, row in zip(time_deltas, replay_predictions)
]

# Two fixed leading events placed ahead of the predicted ones.
# NOTE(review): presumably the marker frames osu! expects at the start of
# replay data -- confirm.
markers = [
    ReplayEventOsu(0, 256.0, -500.0, Key(0)),
    ReplayEventOsu(-1, 256.0, -500.0, Key(0))
]

replay.replay_data = markers + replay_data
replay.username = f'Model_p{pos_version}_k{key_version}'

replay.write_path(f"{path}/predictions_{pos_version}_{key_version}.osr")