# This is a notebook for testing different models and architectures.
Checked into git because I'm too lazy.

In [2]:
# Reload edited project modules (models.*, osu.*) automatically before each cell runs,
# so changes to the .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from models.vae.vae_t import OsuReplayTVAE

In [3]:
# Transformer VAE checkpoint. NOTE: `model` is rebound to the LSTM variant in a
# later cell, so the generation cells further down do not use this instance.
model = OsuReplayTVAE.load("replaytvae_morerecent.pt")

ReplayTVAE initialized on cuda
decoder parameters: 2806210
encoder parameters: 2621952
Total parameters: 5428162
OsuReplayTVAE loaded from replaytvae_morerecent.pt


In [4]:
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm.auto as tqdm

import osu.dataset as dataset

from models.annealer import Annealer
from models.base import OsuModel
from models.model_utils import TransformerArgs
from models.vae.encoder import ReplayEncoder
from models.standard_encoder_t import MapEncoder

class ReplayEncoder_Lstm(nn.Module):
    """LSTM encoder mapping (beatmap features, cursor positions) to a latent Gaussian.

    Each timestep's feature vector is concatenated with its 2-d cursor position,
    the sequence is run through a 2-layer LSTM, and the final hidden state of
    the top LSTM layer is projected to the posterior mean and log-variance.

    Args:
        input_size: width of one (un-windowed) per-frame feature vector.
        latent_dim: size of the latent code.
        noise_std: std of Gaussian noise added to ``positions`` in training
            mode (0 disables it).
        past_frames, future_frames: window extent; the feature width expected
            by ``forward`` is ``input_size * (past_frames + 1 + future_frames)``.
    """

    def __init__(self, input_size, latent_dim=32, noise_std=0.0, past_frames=0, future_frames=0):
        super().__init__()
        self.past_frames = past_frames
        self.future_frames = future_frames
        # the current frame plus its past/future neighbours
        self.window_size = past_frames + future_frames + 1

        # windowed features followed by the (x, y) position
        # TODO! testing unconditional vae
        lstm_in_width = input_size * self.window_size + 2
        self.lstm = nn.LSTM(lstm_in_width, 256, num_layers=2, batch_first=True, dropout=0.2)

        self.dense1 = nn.Linear(256, 128)
        self.noise_std = noise_std
        self.dense2 = nn.Linear(128, 64)

        # separate heads feeding the reparameterization trick
        self.mu_layer = nn.Linear(64, latent_dim)
        self.logvar_layer = nn.Linear(64, latent_dim)

    def forward(self, beatmap_features, positions):
        """Encode a sequence into (mu, logvar), each of shape (B, latent_dim).

        ``beatmap_features`` is expected to be windowed already (e.g. map
        encoder embeddings); ``positions`` holds the cursor coordinates.
        """
        jitter = self.training and self.noise_std > 0
        if jitter:
            positions = positions + torch.randn_like(positions) * self.noise_std

        sequence = torch.cat([beatmap_features, positions], dim=-1)
        _, (final_hidden, _) = self.lstm(sequence)

        # final_hidden is (num_layers, B, 256); keep only the top layer
        summary = F.relu(self.dense1(final_hidden[-1]))
        summary = F.relu(self.dense2(summary))

        return self.mu_layer(summary), self.logvar_layer(summary)


class ReplayDecoder_Lstm(nn.Module):
    """Decode latent code + beatmap features to cursor positions.

    The latent code is repeated along the time axis, concatenated with the
    per-frame beatmap features, run through a 2-layer LSTM, and then through a
    small MLP that emits an (x, y) pair for every frame.

    Args:
        input_size: width of each per-frame feature vector passed to
            ``forward``. Any windowing must already be folded into it: the
            LSTM width does not depend on ``past_frames``/``future_frames``.
        latent_dim: size of the latent code.
        past_frames, future_frames: stored on the module (and used for
            ``window_size``) but they do not change any layer sizes.
    """

    def __init__(self, input_size, latent_dim=48, past_frames=0, future_frames=0):
        super().__init__()
        self.past_frames = past_frames
        self.future_frames = future_frames
        self.window_size = past_frames + 1 + future_frames

        combined_size = input_size + latent_dim

        # Same layer widths as the encoder (LSTM 256 -> 128 -> 64), with higher dropout.
        self.lstm = nn.LSTM(combined_size, 256, num_layers=2, batch_first=True, dropout=0.3)

        self.dense1 = nn.Linear(256, 128)
        self.dense2 = nn.Linear(128, 64)

        self.output_layer = nn.Linear(64, 2)  # x, y positions

    def forward(self, beatmap_features, latent_code):
        """Predict a position for every frame.

        Args:
            beatmap_features: (B, T, input_size) tensor, already windowed.
            latent_code: (B, latent_dim) tensor, shared across all T frames.

        Returns:
            (B, T, 2) tensor of predicted (x, y) positions.
        """
        # Unpacking also enforces a 3-D input (raises ValueError otherwise).
        _, seq_len, _ = beatmap_features.shape

        # Broadcast the per-sequence latent code to every timestep.
        latent_expanded = latent_code.unsqueeze(1).expand(-1, seq_len, -1)
        x = torch.cat([beatmap_features, latent_expanded], dim=-1)

        lstm_out, _ = self.lstm(x)

        features = F.relu(self.dense1(lstm_out))
        features = F.relu(self.dense2(features))

        return self.output_layer(features)


# LSTM variant of OsuReplayTVAE: a transformer MapEncoder embeds the beatmap,
# and LSTM encoder/decoder heads form the VAE around those embeddings.
# TODO! save hyperparam dict
class OsuReplayTVAE_Lstm(OsuModel):
    """Conditional VAE over cursor positions with LSTM encoder/decoder heads.

    ``forward`` runs:
      beatmap features --MapEncoder--> embeddings
      (embeddings, positions) --ReplayEncoder_Lstm--> (mu, logvar)
      z = mu + eps * exp(0.5 * logvar)             (reparameterization)
      (embeddings, z) --ReplayDecoder_Lstm--> reconstructed positions

    Args:
        annealer: schedule applied to the KL term in ``loss_function``;
            defaults to a cyclical Annealer over the range (0, 0.3).
        batch_size, device, compile: forwarded to ``OsuModel.__init__``.
        latent_dim: size of the latent code.
        transformer_args: MapEncoder configuration; defaults to
            ``TransformerArgs()``.
        noise_std: std of the Gaussian noise the encoder adds to positions
            while training.
        frame_window: (past_frames, future_frames) context given to the
            MapEncoder.
    """

    def __init__(
        self,
        annealer: Annealer = None,
        batch_size=64,
        device=None,
        latent_dim=64,
        transformer_args: TransformerArgs = None,
        noise_std=0.0,
        frame_window=(400, 900),
        compile: bool = True
    ):
        # Set before OsuModel.__init__, which presumably calls
        # _initialize_models / _initialize_optimizers (both read these fields).
        self.latent_dim = latent_dim
        self.transformer_args = transformer_args or TransformerArgs()
        self.past_frames = frame_window[0]
        self.future_frames = frame_window[1]
        self.noise_std = noise_std
        self.annealer = annealer or Annealer(
            total_steps=10, range=(0, 0.3), cyclical=True, stay_max_steps=5
        )

        super().__init__(
            batch_size=batch_size,
            device=device,
            compile=compile
        )

    def _initialize_models(self, **kwargs):
        """Build the map encoder and the LSTM encoder/decoder heads."""
        self.encoder = ReplayEncoder_Lstm(
            input_size=self.transformer_args.embed_dim,
            latent_dim=self.latent_dim,
            noise_std=self.noise_std,
            # map encoder already encodes information in a window
            past_frames=0,
            future_frames=0,
        )

        self.map_encoder = MapEncoder(
            input_size=len(dataset.INPUT_FEATURES),
            transformer_args=self.transformer_args,
            future_frames=self.future_frames,
            past_frames=self.past_frames
        )

        self.decoder = ReplayDecoder_Lstm(
            input_size=self.transformer_args.embed_dim,
            latent_dim=self.latent_dim,
            past_frames=self.past_frames,
            future_frames=self.future_frames,
        )

    def _vae_parameters(self):
        """All trainable parameters, in order: encoder, decoder, map encoder."""
        return (
            list(self.encoder.parameters())
            + list(self.decoder.parameters())
            + list(self.map_encoder.parameters())
        )

    def _initialize_optimizers(self):
        """A single AdamW optimizer over every sub-module, map encoder included."""
        self.optimizer = optim.AdamW(
            self._vae_parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.001
        )

    def _train_epoch(self, epoch, total_epochs, **kwargs):
        """Run one pass over ``self.train_loader`` and return the average losses.

        Returns:
            dict with per-batch averages of "total_loss", "recon_loss" and
            "kl_loss". The annealer is stepped once, after the pass.
        """
        num_batches = len(self.train_loader)
        epoch_total_loss = 0
        epoch_recon_loss = 0
        epoch_kl_loss = 0

        progress = tqdm.tqdm(
            self.train_loader,
            disable=True,
            position=1,
            desc=f"Epoch {epoch + 1}/{total_epochs} (Beta: {self.annealer.current()})",
        )
        for i, (batch_x, batch_y_pos) in enumerate(progress):
            self._set_custom_train_status(f"Batch {i}/{num_batches}")
            batch_x = batch_x.to(self.device)             # (B, T, features)
            batch_y_pos = batch_y_pos.to(self.device)     # (B, T, pos)

            self.optimizer.zero_grad()

            reconstructed, mu, logvar = self.forward(batch_x, batch_y_pos)

            total_loss, recon_loss, kl_loss = self.loss_function(
                reconstructed, batch_y_pos, mu, logvar
            )

            total_loss.backward()
            # Clip over exactly the parameters the optimizer updates (the map
            # encoder included) so no sub-module takes an unclipped step.
            torch.nn.utils.clip_grad_norm_(self._vae_parameters(), max_norm=1.0)
            self.optimizer.step()

            epoch_total_loss += total_loss.item()
            epoch_recon_loss += recon_loss.item()
            epoch_kl_loss += kl_loss.item()

        avg_total_loss = epoch_total_loss / num_batches
        avg_recon_loss = epoch_recon_loss / num_batches
        avg_kl_loss = epoch_kl_loss / num_batches

        # Advance the KL-weight schedule once per epoch.
        self.annealer.step()

        return {
            "total_loss": avg_total_loss,
            "recon_loss": avg_recon_loss,
            "kl_loss": avg_kl_loss,
        }

    # allow backpropagation through sampling
    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I), sigma = exp(0.5 * logvar)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, beatmap_features, positions):
        """Encode, sample and decode; returns (reconstructed, mu, logvar)."""
        embeddings = self.map_encoder(beatmap_features)

        mu, logvar = self.encoder(embeddings, positions)

        # Sample latent code
        z = self.reparameterize(mu, logvar)

        reconstructed = self.decoder(embeddings, z)

        return reconstructed, mu, logvar

    def loss_function(self, reconstructed, original, mu, logvar):
        """VAE objective: reconstruction MSE plus annealed KL divergence.

        Returns:
            (total_loss, recon_loss, kld) with total_loss = recon_loss + annealer(kld).

        NOTE(review): recon_loss is summed over the whole batch (every sample,
        frame and coordinate) while kld is averaged per sample, so the effective
        KL weight shrinks as batch size (and sequence length) grows. Left as-is
        because the annealer range was presumably tuned against this scaling.
        """
        recon_loss = F.mse_loss(reconstructed, original, reduction="sum")

        # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over latent
        # dims and batch, then divided by the batch size.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        kld /= original.shape[0]

        total_loss = recon_loss + self.annealer(kld)

        return total_loss, recon_loss, kld

    def _get_state_dict(self):
        """Checkpoint payload: weights of all three sub-modules plus the
        hyperparameters ``load`` needs to rebuild the model."""
        return {
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            # The map encoder is trained jointly with the heads, so its weights
            # must be saved too; otherwise a reload starts from a random MapEncoder.
            "map_encoder": self.map_encoder.state_dict(),
            "latent_dim": self.latent_dim,
            "transformer_args": self.transformer_args.to_dict(),
            "noise_std": self.noise_std,
            # This class never sets input_size itself; fall back to the
            # MapEncoder's input width if the base class does not provide it.
            "input_size": getattr(self, "input_size", len(dataset.INPUT_FEATURES)),
            "past_frames": self.past_frames,
            "future_frames": self.future_frames,
        }

    def _load_state_dict(self, checkpoint):
        """Restore sub-module weights from a checkpoint dict.

        Checkpoints without a "map_encoder" entry are still accepted; the
        MapEncoder then keeps its fresh initialization and a warning is printed.
        """
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])
        if "map_encoder" in checkpoint:
            self.map_encoder.load_state_dict(checkpoint["map_encoder"])
        else:
            print(
                "Warning: checkpoint has no 'map_encoder' weights; "
                f"{type(self).__name__}.map_encoder is left randomly initialized"
            )

    @classmethod
    def load(cls, path: str, device: Optional[torch.device] = None, **kwargs):
        """Rebuild a model from a checkpoint written via ``_get_state_dict``.

        Extra ``kwargs`` are forwarded to the constructor. The returned
        instance is in eval mode.
        """
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        checkpoint = torch.load(path, map_location=device)

        # Load transformer configuration from checkpoint
        transformer_args = TransformerArgs.from_dict(checkpoint["transformer_args"])

        vae_args = {
            "latent_dim": checkpoint.get("latent_dim", 64),
            "transformer_args": transformer_args,
            "noise_std": checkpoint.get("noise_std", 0.0),
            # NOTE(review): these fallbacks (40, 90) differ from the constructor
            # default frame_window=(400, 900); confirm which is intended for
            # checkpoints that lack the keys.
            "frame_window": (checkpoint.get("past_frames", 40), checkpoint.get("future_frames", 90)),
        }

        instance = cls(device=device, **kwargs, **vae_args)

        instance._load_state_dict(checkpoint)
        instance._set_eval_mode()

        print(f"{cls.__name__} loaded from {path}")
        return instance

    def generate(self, beatmap_data, num_samples=1):
        """Sample positions from the prior, conditioned on beatmap features.

        Args:
            beatmap_data: array-like or tensor of beatmap features whose first
                axis is the batch axis; converted to float32 on ``self.device``.
            num_samples: currently unused; kept for interface compatibility.

        Returns:
            numpy array with the decoder output for each batch entry.

        The model is switched to eval mode for sampling and then returned to
        train mode only if it was training beforehand (also on error).
        """
        was_training = self.decoder.training
        self._set_eval_mode()
        try:
            with torch.no_grad():
                beatmap_tensor = torch.as_tensor(
                    beatmap_data, dtype=torch.float32, device=self.device
                )

                batch_size = beatmap_tensor.shape[0]

                # sample from the prior N(0, I) instead of running the encoder
                z = torch.randn(batch_size, self.latent_dim, device=self.device)
                embeddings = self.map_encoder(beatmap_tensor)

                pos = self.decoder(embeddings, z)
        finally:
            if was_training:
                self._set_train_mode()

        return pos.cpu().numpy()


In [6]:
# Load the LSTM variant defined above. This rebinds `model`, replacing the
# transformer VAE loaded in an earlier cell.
model = OsuReplayTVAE_Lstm.load("replaytvae_lstm_most_recent.pt")
# Alternative for comparison (presumably the earlier RNN-based VAE):
# from models.vae.vae import OsuReplayVAE
# model = OsuReplayVAE.load("replayvae_most_recent.pt")

ReplayTVAE_Lstm initialized on cuda
decoder parameters: 1028418
encoder parameters: 973120
map_encoder parameters: 2572416
Total parameters: 4573954
OsuReplayTVAE_Lstm loaded from replaytvae_lstm_most_recent.pt


In [7]:
# Test the trained model on one map: load the .osu file, apply mods, turn it
# into time-series feature rows, and chunk those rows into BATCH_LENGTH-long
# sequences of shape (chunks, BATCH_LENGTH, n_features).
import numpy as np
import torch

import osu.dataset as dataset
import osu.rulesets.beatmap as bm
from osu.rulesets.mods import Mods

test_name = '1hope'
test_mods = Mods.NONE
test_map_path = f'assets/{test_name}_map.osu'
test_song = f'assets/{test_name}_song.mp3'

test_map = bm.load(test_map_path)
test_map.apply_mods(test_mods)

map_frames = dataset.input_data(test_map)
chunk_shape = (-1, dataset.BATCH_LENGTH, len(dataset.INPUT_FEATURES))
data = torch.FloatTensor(np.reshape(map_frames.values, chunk_shape))

data.shape

Turning 1HOPE SNIPER into time series data: 100%|████████████████████████████████| 1/1 [00:00<00:00,  4.27it/s]


torch.Size([5, 2048, 9])

In [8]:
import os

# Sample one trajectory per BATCH_LENGTH chunk from the prior, then stitch the
# chunks back into one continuous (frames, 2) sequence.
generated_chunks = model.generate(data)
stitched = np.concatenate(generated_chunks)

# Append two zero-filled columns so every row has 4 values.
# NOTE(review): presumably the row layout preview_replay_raw expects (x, y plus
# two extra channels) — confirm what the extra columns represent.
replay_data = np.pad(stitched, ((0, 0), (0, 2)), mode='constant', constant_values=0)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
os.makedirs('.generated', exist_ok=True)

print(f"Generated replay data shape: {replay_data.shape}")

replay_data

W0901 12:41:11.354000 5318 torch/_inductor/utils.py:1250] [1/0_1] Not enough SMs to use max_autotune_gemm mode


Generated replay data shape: (10240, 4)


array([[ 0.06962939, -0.06804687,  0.        ,  0.        ],
       [ 0.13007933, -0.00253613,  0.        ,  0.        ],
       [ 0.2266632 , -0.16867389,  0.        ,  0.        ],
       ...,
       [ 0.18648267, -0.15628761,  0.        ,  0.        ],
       [ 0.25100222, -0.10195369,  0.        ,  0.        ],
       [ 0.31378463,  0.03294045,  0.        ,  0.        ]],
      shape=(10240, 4), dtype=float32)

In [None]:
import osu.preview.preview as preview

# Render a preview of the generated replay against the test map, mods and song
# prepared in the earlier cells.
preview.preview_replay_raw(replay_data, test_map_path, test_mods, test_song)

pygame 2.6.1 (SDL 2.28.4, Python 3.12.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


  from pkg_resources import resource_stream, resource_exists
