In [None]:
# Automatically re-import edited modules (e.g. the osu.* package) before each
# cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
# Make CUDA kernel launches synchronous so GPU errors surface at the failing
# call. This has to be set before torch initialises CUDA (torch is first
# imported in a later cell). NOTE(review): this is a debugging aid that slows
# GPU training noticeably — consider dropping it for full training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import re
import math
import glob
from importlib import reload

import tqdm 
import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import osu.rulesets.beatmap as bm
import osu.rulesets.replay as rp
import osu.rulesets.hitobjects as hitobjects
import osu.dataset as dataset

import osu.preview.preview as preview

In [None]:
# Load the cached replay/beatmap mapping used to build the training data.
# NOTE(review): the meaning of 3250 is not visible here (presumably a cap on the
# number of replays) — confirm in osu.dataset and hoist it into a named constant.
obj_dataset = dataset.replay_mapping_from_cache(3250)

In [None]:
# Inspect the loaded replay mapping.
obj_dataset

In [None]:
# Build the model-input table and the training-target table from the same
# replay mapping, so the two describe the same sequence of samples.
input_data = dataset.input_data(obj_dataset, verbose=True)
output_data = dataset.target_data(obj_dataset, verbose=True)

In [None]:
# NOTE(review): `input` shadows Python's built-in input(). Later cells read
# `input`/`output`, so any rename (e.g. to input_frame/output_frame) must be
# applied to all of those cells together.
input = input_data
output = output_data

In [None]:
# Preview the first 500 entries of the model inputs.
input[0:500]

In [None]:
# Preview the first 500 entries of the training targets.
output[0:500]

In [None]:
import torch

# Cut the flat tables into fixed-length windows shaped
# (n_windows, BATCH_LENGTH, n_features). np.reshape raises if the table does not
# divide evenly into whole windows.
xs = np.reshape(input.values, (-1, dataset.BATCH_LENGTH, len(dataset.INPUT_FEATURES)))
ys = np.reshape(output.values, (-1, dataset.BATCH_LENGTH, len(dataset.OUTPUT_FEATURES)))

# The models consume xs and ys pairwise (load_data(xs, ys)). A different window
# count would silently pair inputs with the wrong targets, so fail loudly here.
assert xs.shape[0] == ys.shape[0], f"window count mismatch: xs {xs.shape} vs ys {ys.shape}"

xs.shape

In [None]:
# Target windows: (n_windows, BATCH_LENGTH, len(dataset.OUTPUT_FEATURES)).
ys.shape

In [None]:
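# Save the windowed arrays to .datasets/ so later sessions can reload them
# instead of rebuilding; uncomment to (re)write the on-disk cache.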

# os.makedirs(f'.datasets', exist_ok=True)
# 
# np.save(f'.datasets/xs.npy', xs)
# np.save(f'.datasets/ys.npy', ys)

In [None]:
# Reload the windowed arrays from the on-disk cache when it exists, so later
# sessions can skip preprocessing. On a fresh checkout the cache is missing
# (the save cell above is commented out). In that case, keep the arrays built
# above and write them to the cache instead of failing with FileNotFoundError.
XS_CACHE = os.path.join('.datasets', 'xs.npy')
YS_CACHE = os.path.join('.datasets', 'ys.npy')

if os.path.exists(XS_CACHE) and os.path.exists(YS_CACHE):
    xs = np.load(XS_CACHE)
    ys = np.load(YS_CACHE)
else:
    os.makedirs('.datasets', exist_ok=True)
    np.save(XS_CACHE, xs)
    np.save(YS_CACHE, ys)

In [None]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from random import randint
from sklearn.model_selection import train_test_split

# Mini-batch size shared by every model constructed below.
BATCH_SIZE = 128

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any visible cell; the model
# classes presumably pick their own device — confirm before removing.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# A manual train_test_split + DataLoader setup used to live here (commented out).
# It appears superseded by the models' own load_data(xs, ys). train_test_split,
# randint, DataLoader and TensorDataset were only used by that retired code.

In [None]:
from osu.rnn import OsuReplayRNN

In [None]:
print("Creating RNN model...")
# NOTE(review): noise_std=0.00 presumably disables noise injection — confirm in osu.rnn.
rnn = OsuReplayRNN(batch_size=BATCH_SIZE, noise_std=0.00)
# Hand the full windowed arrays to the model; any train/validation split
# presumably happens inside load_data — TODO confirm.
rnn.load_data(xs, ys)

In [None]:
RNN_EPOCHS = 25

# Train the RNN, plot its loss history, then persist it via rnn.save().
rnn.train(epochs=RNN_EPOCHS)
rnn.plot_losses()
rnn.save()

In [None]:
import torch
import gc

# Release the RNN before the GAN is built. Order matters: drop the reference
# first, then collect it, and only then ask the CUDA caching allocator to return
# freed blocks. The previous order (empty_cache -> gc -> del) emptied the cache
# while the model was still alive, so its GPU memory (if any) stayed reserved.
if 'rnn' in globals():  # idempotent: re-running this cell no longer raises NameError
    del rnn

# Force garbage collection
gc.collect()

# Clear GPU cache (a no-op when CUDA was never initialised)
torch.cuda.empty_cache()

In [None]:
from osu.gan import OsuReplayGAN

In [None]:
print("Creating GAN model...")
gan = OsuReplayGAN(batch_size=BATCH_SIZE)
# Same windowed arrays the RNN was trained on.
gan.load_data(xs, ys)

In [None]:
GAN_EPOCHS = 10  # epochs per training round
GAN_ROUNDS = 8   # number of rounds; a checkpoint is saved after each one

# Train the GAN for GAN_ROUNDS * GAN_EPOCHS epochs in total. Calling gan.save()
# after every round means an interrupted run still leaves recent weights saved.
for _ in range(GAN_ROUNDS):
    gan.train(epochs=GAN_EPOCHS)
    gan.save()

gan.plot_losses()

In [None]:
from osu.keys import OsuKeyModel

In [None]:
print("Creating keypress model...")
keys = OsuKeyModel(batch_size=BATCH_SIZE)
# Same windowed arrays as the RNN and GAN.
keys.load_data(xs, ys)

In [None]:
KEYS_EPOCHS = 12  # epochs per training round
KEYS_ROUNDS = 12  # number of rounds; a checkpoint is saved after each one

# Train the key model for KEYS_ROUNDS * KEYS_EPOCHS epochs in total, saving
# after every round.
for _ in range(KEYS_ROUNDS):
    keys.train(epochs=KEYS_EPOCHS)
    keys.save()

keys.plot_losses()

In [None]:
import gc
import torch

# Release the keypress model once training is done. As with any cleanup, drop
# the reference first, collect it, then let the CUDA caching allocator return
# the freed blocks (empty_cache is a no-op when CUDA was never initialised).
if 'keys' in globals():  # idempotent: re-running this cell no longer raises NameError
    del keys
gc.collect()
torch.cuda.empty_cache()

In [None]:
from osu.rulesets.mods import Mods
import osu.rulesets.beatmap as bm
import osu.rulesets.replay as rp
import osu.dataset as dataset

# Beatmap used for the end-to-end generation demo. Its .osu file and song share
# the `test_name` prefix under assets/. (rp is not used in this cell.)
test_name = '1hope'
test_mods = Mods.HARD_ROCK
test_map_path = f'assets/{test_name}_map.osu'
test_song = f'assets/{test_name}_song.mp3'

test_map = bm.load(test_map_path)
test_map.apply_mods(test_mods)

# Turn the map into the same (n_windows, BATCH_LENGTH, n_input_features) float
# tensor layout the models were trained on, without rebinding `data` through
# three different types along the way.
data = torch.FloatTensor(
    np.reshape(
        dataset.input_data(test_map).values,
        (-1, dataset.BATCH_LENGTH, len(dataset.INPUT_FEATURES)),
    )
)

In [None]:
GENERATED_DIR = '.generated'

# Run the trained generator on the test map's windows; inference needs no
# gradient tracking.
with torch.no_grad():
    replay_data = gan.generate(data)

# Join the pieces returned by gan.generate (presumably one array per window
# batch) into a single array.
replay_data = np.concatenate(replay_data)

# Build the path with os.path.join rather than a hard-coded '\\'. On POSIX a
# backslash is not a separator, so the old code wrote a file literally named
# '.generated\<name>.npy' into the working directory instead of into .generated/.
os.makedirs(GENERATED_DIR, exist_ok=True)
generated_path = os.path.join(GENERATED_DIR, test_name + '.npy')
np.save(generated_path, replay_data)

print(f"Generated replay data shape: {replay_data.shape}")
print(f"Saved to {generated_path}")

In [None]:
# Inspect the first 500 rows of the generated replay data.
replay_data[:500]

In [None]:
# NOTE(review): both imports repeat earlier ones, and with %autoreload 2 active
# the explicit reload is usually redundant. It is kept so this cell picks up
# edits to osu/preview even when run on its own.
import importlib
import osu.preview.preview as preview

importlib.reload(preview)

# Render a preview of the generated replay_data for the test map, mods and song.
preview.preview_replay_raw(replay_data, test_map_path, test_mods, test_song)