# Audio2Map
This is an encoder-decoder model based off of seq2seq.

It takes in an audio file for the music as an mp3 and outputs a fully functional map for the hit rhythm game Osu!

In [190]:
import librosa
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch
from functools import reduce

In [191]:
torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device(type='cuda', index=0)

## Preprocessing

Here, we transform the input to a Constant-Q spectrogram spanning C1 to roughly C7. Then, for training, we obtain the pkl file containing the output vector representing the target output map.

We also obtain the difficulty for our target output to feed into the decoder, when deployed, this will be input from the user.

In [192]:
import datetime
def convert_to_spectrogram(filename):
	try:
		targetSampleRate = 11025
		y, sr = librosa.load(filename, sr=targetSampleRate)
		C = np.abs(librosa.cqt(y, sr=targetSampleRate, n_bins=84, bins_per_octave=12))
		S = librosa.amplitude_to_db(C, ref=np.max)
		#plot the spectrogram
		
		'''plt.figure(figsize=(12, 4))
		librosa.display.specshow(S, sr=targetSampleRate, x_axis='time', y_axis='cqt_note')
		plt.colorbar(format='%+2.0f dB')
		plt.title('Constant-Q power spectrogram')
		plt.tight_layout()
		plt.show()'''
		return S
	except:
		tsprint("ERROR: cannot convert to spectrogram. Removed file " + filename + ".")

def get_pkl(filename):
	try:
		return pickle.load(open(filename, 'rb'))
	except:
		tsprint("ERROR: .pkl file does not exist.")
		return -1

def tsprint(s):
	print("[" + datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + "] " + s)

   
def parse_difficulty(filename):
	if(not os.path.isfile(filename)):
		tsprint("ERROR: map file does not exist. Removing.")
		os.remove("pickles/" + filename.split("/")[1].split(".")[0] + ".pkl")
		return -1

	with open(filename, "r") as f:
		try:
			lines = f.readlines()
		except:
			tsprint("ERROR: cannot read lines of .osu file.")


	difficulty = [-1,-1,-1,-1,-1,-1]

	for line in lines:
		#difficulty
		if line.startswith("HPDrainRate"): difficulty[0] = float(line.split(":", 1)[1])
		elif line.startswith("CircleSize"): difficulty[1] = float(line.split(":", 1)[1])
		elif line.startswith("OverallDifficulty"): difficulty[2] = float(line.split(":", 1)[1])
		elif line.startswith("ApproachRate"): difficulty[3] = float(line.split(":", 1)[1])
		elif line.startswith("SliderMultiplier"): difficulty[4] = float(line.split(":", 1)[1])
		elif line.startswith("SliderTickRate"): difficulty[5] = float(line.split(":", 1)[1])
		elif not (line.startswith("[Difficulty]")): break

	#check if all the difficulty stats are there
	for val in difficulty:
		if val == -1:
			tsprint("ERROR: Not a valid osu! map due to insufficient stats. Removed file " + filename + ".")
			os.remove(filename)
			return -1


	return torch.tensor(difficulty)

def load_data():
	inputs = []
	diffs = []
	targets = []

	curr_length = 0
	counter = 0

	if os.path.isfile("loaded_save.pkl"):
		inputs, diffs, targets = pickle.load(open("loaded_save.pkl", 'rb'))
		curr_length = len(inputs)


	for pickle_root, pickle_dirs, pickle_files in os.walk("pickles"):
		for pickle_file in pickle_files:
			counter += 1
			if counter < curr_length: continue

			tsprint("Parsing file " + pickle_file)
			inputs.append(convert_to_spectrogram(os.path.join("audio/", pickle_file.split("_")[0] + ".mp3")))
			diffs.append(parse_difficulty("maps/" + pickle_file.split(".")[0] + ".osu"))
			targets.append(get_pkl("pickles/" + pickle_file))

			if counter % 100 == 0:
				pickle.dump([inputs, diffs, targets], open("loaded_save.pkl", 'wb'))
				tsprint("Saved progress.")
				tsprint("Parsed " + str(counter) + " files.")
	
	return inputs, diffs, targets

In [266]:
# Now get the data >:D
inputs, diffs, targets = load_data() #pickle.load(open("loaded_save.pkl", 'rb'))

[2024-05-30 22:54:51] Parsing file 629136_6.pkl
[2024-05-30 22:54:51] Parsing file 629136_4.pkl
[2024-05-30 22:54:52] Parsing file 629136_5.pkl
[2024-05-30 22:54:52] Parsing file 998593_1.pkl
[2024-05-30 22:54:52] Parsing file 822575_1.pkl
[2024-05-30 22:54:52] Parsing file 998593_0.pkl
[2024-05-30 22:54:53] Parsing file 910271_1.pkl
[2024-05-30 22:54:53] Parsing file 629136_0.pkl
[2024-05-30 22:54:53] Parsing file 822575_2.pkl
[2024-05-30 22:54:54] Parsing file 629136_1.pkl
[2024-05-30 22:54:54] Parsing file 910271_3.pkl
[2024-05-30 22:54:54] Parsing file 888479_2.pkl
[2024-05-30 22:54:54] Parsing file 550491_0.pkl
[2024-05-30 22:54:55] Parsing file 543109_2.pkl
[2024-05-30 22:54:55] Parsing file 351153_0.pkl
[2024-05-30 22:54:55] Parsing file 562857_0.pkl
[2024-05-30 22:54:56] Parsing file 776808_0.pkl
[2024-05-30 22:54:56] Parsing file 543109_0.pkl
[2024-05-30 22:54:56] Parsing file 910271_0.pkl
[2024-05-30 22:54:57] Parsing file 81254_0.pkl
[2024-05-30 22:54:57] Parsing file 81254_

In [267]:
for i in range(len(inputs)):
	if inputs[i] is not None:
		inputs[i] = torch.tensor(inputs[i].T)
		diffs[i] = torch.t(diffs[i])
		targets[i] = torch.t(targets[i][0].to_dense())

torch.Size([2195, 84])
torch.Size([2195, 84])
torch.Size([2195, 84])
torch.Size([6633, 84])
torch.Size([3289, 84])
torch.Size([6633, 84])
torch.Size([3316, 84])
torch.Size([2195, 84])
torch.Size([3289, 84])
torch.Size([2195, 84])
torch.Size([3316, 84])
torch.Size([4180, 84])
torch.Size([8356, 84])
torch.Size([5715, 84])
torch.Size([1933, 84])
torch.Size([3030, 84])
torch.Size([5386, 84])
torch.Size([5715, 84])
torch.Size([3316, 84])
torch.Size([5627, 84])
torch.Size([5627, 84])
torch.Size([3737, 84])
torch.Size([6633, 84])
torch.Size([4559, 84])
torch.Size([4180, 84])
torch.Size([3316, 84])
torch.Size([4180, 84])
torch.Size([5715, 84])
torch.Size([4245, 84])
torch.Size([1919, 84])
torch.Size([2195, 84])
torch.Size([5212, 84])
torch.Size([1933, 84])
torch.Size([3289, 84])
torch.Size([4180, 84])
torch.Size([8356, 84])
torch.Size([1933, 84])
torch.Size([2195, 84])


In [257]:
from sklearn.model_selection import train_test_split

train_x, test_x, train_diffs, test_diffs, train_y, test_y = train_test_split(inputs, diffs, targets, test_size=0.1)
train_x, val_x, train_diffs, val_diffs, train_y, val_y = train_test_split(train_x, train_diffs, train_y, test_size=0.1)

In [258]:
from torch import nn
from torch.nn import functional as F
"""
- Given a song, we can generate a spectrogram
- Take the spectrogram and produce a list of times (rythmic beats)
"""
# Encoder
audio_dim = 84
hidden_dim = 64

# Use LSTM to predict the note timings of the song
class Encoder(nn.Module):
    def __init__(self, dropout=0.2):
        super(Encoder, self).__init__()
        self.lstm = nn.LSTM(audio_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        out, hidden = self.lstm(x)
        out = self.dropout(out)
        return out, hidden

In [289]:
# Decoder

num_features = 8

class Decoder(nn.Module):
    def __init__(self, dropout=0.2):
        super(Decoder, self).__init__()
        self.lstm = nn.LSTM(num_features + 6, hidden_dim, num_layers=2, batch_first=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_out, encoder_hc, difficulty, target=None):
        decoder_input = torch.zeros((1, num_features + 6))
        decoder_hidden = encoder_hc
        decoder_outputs = []

        for i in range(encoder_out.shape[1]):
            decoder_output, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target is not None:
                decoder_input = torch.cat((target[i], torch.reshape(difficulty, (1, -1))), 1)
            else:
                print(decoder_output.shape)
                decoder_input = torch.cat((decoder_output, torch.reshape(difficulty, (1, -1))), 1)
        print(decoder_outputs.shape)
        decoder_outputs = torch.cat(decoder_outputs, 1)
        return decoder_outputs, decoder_hidden, None
        
    def forward_step(self, x, hc):
        x, hc = self.lstm(x, hc)
        x = self.dropout(x)
        return x, hc

## Training Time!

In [290]:
from torch.optim import Adam
enc = Encoder(0.4)
dec = Decoder(0.4)

In [291]:
import time 
def train_epoch(data, encoder, decoder, encoder_opt, decoder_opt, lossfunc):
    # For now, data is a tuple of (x, diff, y)
    total_loss = 0
    for sample in data:
        x = sample[0]
        diff = sample[1]
        y = sample[2]
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()

        encoder_outputs, encoder_hc = encoder(x)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hc, diff)

        loss = lossfunc(decoder_outputs, y)
        loss.backward()

        encoder_opt.step()
        decoder_opt.step()

        total_loss += loss.item()
    return total_loss/len(data)

def train(data, encoder, decoder, epochs=10, learning_rate=0.01):
    start = time.time()
    losshistory = []

    lossfunc = nn.MSELoss()

    enc_opt = Adam(enc.parameters(), lr=0.01)
    dec_opt = Adam(dec.parameters(), lr=0.01)
    for epoch in range(epochs):
        loss = train_epoch(data, encoder, decoder, enc_opt, dec_opt, lossfunc)
        curr_time = time.time()
        losshistory.append(loss)
        print(f"Epoch {epoch+1} Loss: {loss} Time: {curr_time - start}")
    return losshistory

In [292]:
train_loss = train(list(zip(train_x, train_diffs, train_y)), enc, dec)

torch.Size([1, 64])


RuntimeError: input.size(-1) must be equal to input_size. Expected 14, got 70

In [None]:
plt.plot(np.arange(len(train_loss)), train_loss, 'b', label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

## Make some predictions :O

Do some fine tuning for the model as necessary

In [None]:
def decode_audio(audio):
	states = encoder.predict(audio)
	
	target_seq = np.zeros((1, 1, 84))

	decoded_map = []

	for i in range(audio.shape[0]):
		output, h, c = decoder.predict([target_seq] + states)
		
		decoded_map.append(output)

		states = [h, c]

	return decoded_map

In [None]:
for i in range(10):
	decoded_map = decode_audio(test_x[i])	
	print("Actual: ")
	print(test_y[i])
	print("Predicted: ")
	print(decoded_map)
	print(keras.losses.MSE(test_y[i], decoded_map))

## Evaluate that bish B)

In [None]:
decoder_targets = np.zeros(targets.shape)
decoder_targets[:, 0:-1] = decoder_targets[:, 1:]

audio2map.compile(optimizer='adam', loss='mean_squared_error', metrics=['loss', 'accuracy', 'precision', 'recall', 'f1'])
history = audio2map.fit([inputs, targets, diffs], decoder_targets, epochs=10, batch_size=32, validation_split=0.15)
audio2map.save('audio2map_full.h5') # This may not work well, but just in case we can

In [None]:
print("Final Loss: ")
print(history.history['loss'])
print("Final Accuracy: ")
print(history.history['accuracy'])
print("Final Precision: ")
print(history.history['precision'])
print("Final Recall: ")
print(history.history['recall'])
print("Final F1: ")
print(history.history['f1'])