# Audio2Map
This is an encoder-decoder model based on seq2seq.

It takes an audio file of a song (MP3) as input and outputs a fully functional map for the hit rhythm game osu!

In [None]:
import librosa
import os
import numpy as np
import matplotlib.pyplot as plt
import pickle
import torch
from functools import reduce

In [None]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Width of each row of the target map tensors: the pickled targets are
# concatenated with STOP along dim 0 (see Audio2Map), so they must match.
num_features = 8
# Sentinel row appended to every target sequence; the decoder treats it as
# "end of map".
STOP = -torch.ones((1, num_features), dtype=torch.float32)

## Preprocessing

Here, we transform the input audio into a Constant-Q spectrogram spanning C1 to B7 (84 bins at 12 bins per octave). Then, for training, we load the pkl file containing the target output map.

We also obtain the difficulty settings of the target map to feed into the decoder; when deployed, these will be provided by the user.

In [None]:
import datetime
def convert_to_spectrogram(filename, maps_dir="maps/", pickles_dir="pickles/"):
    """Load an audio file and return its Constant-Q spectrogram in dB.

    The audio is resampled to 11025 Hz, transformed with 84 CQT bins at
    12 bins per octave, and converted to decibels relative to the peak
    magnitude.

    If loading or transforming fails, the audio file is deleted together with
    every map in ``maps_dir`` and every pickle in ``pickles_dir`` derived from
    it, and ``None`` is returned so callers can skip the sample.

    Args:
        filename: path to the audio file, e.g. ``"audio/<id>.mp3"``.
        maps_dir: directory holding the ``<id>_*.osu`` maps for this audio.
        pickles_dir: directory holding the ``<id>_*.pkl`` targets for this audio.

    Returns:
        ``np.ndarray`` of shape ``(84, n_frames)``, or ``None`` on failure.
    """
    target_sample_rate = 11025
    try:
        y, _ = librosa.load(filename, sr=target_sample_rate)
        magnitude = np.abs(librosa.cqt(y, sr=target_sample_rate, n_bins=84, bins_per_octave=12))
        return librosa.amplitude_to_db(magnitude, ref=np.max)
    except Exception:  # never a bare except: Ctrl-C must not trigger file deletion
        tsprint("ERROR: cannot convert to spectrogram. Removed file " + filename + ".")
        _remove_audio_and_derived(filename, maps_dir, pickles_dir)
        return None


def _remove_audio_and_derived(filename, maps_dir, pickles_dir):
    """Delete an unusable audio file plus every map/pickle sharing its id."""
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass  # audio already gone (often the reason loading failed)
    # Sample files are named "<id>_<rest>"; the id is the audio file's stem.
    # Compare ids exactly so e.g. "12" does not also match "123_...".
    stem = os.path.splitext(os.path.basename(filename))[0]
    for directory in (maps_dir, pickles_dir):
        for name in os.listdir(directory):
            path = os.path.join(directory, name)
            if os.path.isfile(path) and os.path.splitext(name)[0].split("_", 1)[0] == stem:
                os.remove(path)

def get_pkl(filename):
    """Unpickle and return the object stored in ``filename``.

    Returns ``-1`` (after logging) if the file is missing or cannot be
    unpickled. Callers such as ``Audio2Map`` index the result, so they must
    check for ``-1`` first.

    Security: ``pickle.load`` can execute arbitrary code. Only use this on
    pickles produced by this project's own preprocessing, never on files
    from an untrusted source.
    """
    try:
        with open(filename, "rb") as fh:  # closed even if unpickling fails
            return pickle.load(fh)
    except FileNotFoundError:
        tsprint("ERROR: .pkl file does not exist.")
    except Exception as err:  # corrupt/truncated pickle, permission error, ...
        tsprint(f"ERROR: cannot load .pkl file {filename}: {err!r}")
    return -1

def tsprint(s):
    """Print the string ``s`` prefixed with a ``[YYYY-mm-dd HH:MM:SS]`` local timestamp."""
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print("".join(("[", stamp, "] ", s)))

def parse_difficulty(filename, pickles_dir="pickles/"):
    """Read the six [Difficulty] settings of an osu! beatmap.

    Lines before the ``[Difficulty]`` header (file-format line, ``[General]``,
    ...) are skipped, and parsing stops at the next section header. Files
    that contain only the difficulty lines, with or without the header, also
    work.

    Args:
        filename: path to the ``.osu`` file.
        pickles_dir: directory holding the target pickle with the same base
            name as the map; that pickle is deleted when the map is missing.

    Returns:
        A float tensor ``[HPDrainRate, CircleSize, OverallDifficulty,
        ApproachRate, SliderMultiplier, SliderTickRate]``, or ``-1`` when the
        map is missing, unreadable, or lacks a setting. In the last case the
        map file is deleted.
    """
    if not os.path.isfile(filename):
        tsprint("ERROR: map file does not exist. Removing.")
        # Without its map the sample is unusable, so drop the matching pickle.
        stem = os.path.splitext(os.path.basename(filename))[0]
        try:
            os.remove(os.path.join(pickles_dir, stem + ".pkl"))
        except FileNotFoundError:
            pass  # already removed; nothing left to clean up
        return -1

    try:
        # utf-8-sig: .osu files are UTF-8 and frequently start with a BOM.
        with open(filename, "r", encoding="utf-8-sig") as f:
            lines = f.readlines()
    except (OSError, UnicodeDecodeError):
        tsprint("ERROR: cannot read lines of .osu file.")
        return -1

    keys = ("HPDrainRate", "CircleSize", "OverallDifficulty",
            "ApproachRate", "SliderMultiplier", "SliderTickRate")
    difficulty = [-1] * len(keys)
    in_difficulty = False
    for raw_line in lines:
        line = raw_line.strip()
        if line.startswith("["):
            # Section header: leave once past [Difficulty]; skip earlier sections.
            if in_difficulty:
                break
            in_difficulty = line == "[Difficulty]"
            continue
        for idx, key in enumerate(keys):
            if line.startswith(key):
                difficulty[idx] = float(line.split(":", 1)[1])
                break

    # Every setting must have been found for the map to be usable.
    if -1 in difficulty:
        tsprint("ERROR: Not a valid osu! map due to insufficient stats. Removed file " + filename + ".")
        os.remove(filename)
        return -1

    return torch.tensor(difficulty)

# TODO: Deprecate
def load_data():
    """Eagerly load every sample listed under ``pickles/`` (superseded by Audio2Map).

    Progress is checkpointed to ``loaded_save.pkl`` every 10 files and once
    more at the end. A later call resumes after the files already stored
    there. NOTE(review): resuming assumes ``os.walk`` yields the files in the
    same order as in the previous run.

    Failed samples are still appended (as ``None`` / ``-1`` from the helpers),
    which keeps list positions aligned with the resume counter.

    Returns:
        ``(inputs, diffs, targets)`` lists with one entry per pickle file.
    """
    inputs, diffs, targets = [], [], []
    curr_length = 0
    counter = 0

    if os.path.isfile("loaded_save.pkl"):
        with open("loaded_save.pkl", "rb") as fh:
            inputs, diffs, targets = pickle.load(fh)
        curr_length = len(inputs)

    def _save():
        """Write the current lists to the checkpoint file."""
        with open("loaded_save.pkl", "wb") as fh:
            pickle.dump([inputs, diffs, targets], fh)

    for pickle_root, _, pickle_files in os.walk("pickles"):
        for pickle_file in pickle_files:
            counter += 1
            # Files 1..curr_length were loaded by a previous run.
            if counter <= curr_length:
                continue

            tsprint("Parsing file " + pickle_file)
            # Same naming scheme as Audio2Map: "<id>_<rest>.pkl".
            stem = os.path.splitext(pickle_file)[0]
            inputs.append(convert_to_spectrogram(os.path.join("audio/", stem.split("_", 1)[0] + ".mp3")))
            diffs.append(parse_difficulty("maps/" + stem + ".osu"))
            targets.append(get_pkl(os.path.join(pickle_root, pickle_file)))

            if counter % 10 == 0:
                _save()
                tsprint("Saved progress.")
                tsprint("Parsed " + str(counter) + " files.")

    # Persist whatever was parsed after the last multiple of 10.
    if counter > curr_length:
        _save()

    return inputs, diffs, targets

In [None]:
from torch.utils.data import Dataset

class Audio2Map(Dataset):
    """Map-style dataset pairing each target pickle with its audio and map.

    Every file ``<id>_<rest>.pkl`` in ``target_dir`` is one sample. Its audio
    is ``input_dir/<id>.mp3`` and its map is ``maps_dir/<id>_<rest>.osu``.
    ``__getitem__`` returns ``(spectrogram, difficulty, target)``:

    - ``spectrogram``: the transposed CQT (frames x bins).
    - ``difficulty``: the map's six difficulty settings.
    - ``target``: the dense target rows followed by ``STOP``.
    """

    def __init__(self, input_dir, maps_dir, target_dir):
        self.in_dir = input_dir
        self.maps_dir = maps_dir
        self.tar_dir = target_dir
        self._files = None  # sorted target file names, listed lazily once

    def _file_list(self):
        """Return the sorted target file names, listing ``tar_dir`` only once.

        Caching keeps index -> file stable (``random_split`` stores indices)
        even after unusable samples are deleted from disk during training.
        """
        if self._files is None:
            self._files = sorted(
                name for name in os.listdir(self.tar_dir)
                if os.path.isfile(os.path.join(self.tar_dir, name))
            )
        return self._files

    def __len__(self):
        return len(self._file_list())

    def __getitem__(self, idx):
        files = self._file_list()
        # range() indexing validates idx (IndexError when out of range) and
        # resolves negative indices like list indexing does.
        start = range(len(files))[idx]
        # Unusable samples are skipped by moving to the next file, wrapping at
        # the end. NOTE(review): a skipped index may therefore return a sample
        # that belongs to the other random_split subset.
        for offset in range(len(files)):
            sample = self._load(files[(start + offset) % len(files)][:-4])
            if sample is not None:
                return sample
        raise RuntimeError(f"No loadable sample found in {self.tar_dir!r}")

    def _load(self, currfile):
        """Load the sample with base name ``currfile``; return ``None`` on any failure."""
        spec = convert_to_spectrogram(os.path.join(self.in_dir, currfile.split('_', 1)[0] + ".mp3"))
        if spec is None:
            return None
        diff = parse_difficulty(os.path.join(self.maps_dir, currfile + ".osu"))
        if not isinstance(diff, torch.Tensor):
            return None  # parse_difficulty returned -1
        target = get_pkl(os.path.join(self.tar_dir, currfile + ".pkl"))
        if isinstance(target, int) and target == -1:
            return None  # get_pkl failed
        spectrogram = torch.tensor(spec.T).float()
        out = torch.cat((target[0].to_dense().float(), STOP), 0)
        return spectrogram, diff.float(), out

In [None]:
from torch.utils.data import random_split
# Build the lazily-loaded dataset: one sample per file in pickles/.
# (load_data() is the older eager loader that cached into loaded_save.pkl.)
a2m_data = Audio2Map("audio/", "maps/", "pickles/")

# Hold out 20% of the samples for testing. Fractional lengths are accepted by
# torch.utils.data.random_split since PyTorch 1.13.
# NOTE(review): no generator is passed, so the split changes on every run; pass
# generator=torch.Generator().manual_seed(...) if it must be reproducible.
test_split = 0.2
train_data, test_data = random_split(a2m_data, [1-test_split, test_split])

In [None]:
from torch.utils.data import DataLoader

# Every sample has its own sequence lengths, so automatic batching is disabled
# (batch_size=None): each iteration yields one unbatched sample, i.e. a
# 3-element (input, difficulty, target) collection.
train_dl = DataLoader(train_data, batch_size=None, shuffle=True)
test_dl = DataLoader(test_data, batch_size=None, shuffle=True)

In [None]:
from torch import nn
from torch.nn import functional as F

# - Given a song, we can generate a spectrogram.
# - Take the spectrogram and produce a list of times (rhythmic beats).

# Encoder dimensions
audio_dim = 84  # one input feature per CQT bin (n_bins=84 in convert_to_spectrogram)
hidden_dim = 64  # LSTM hidden size, shared by the Encoder and Decoder

# Use LSTM to predict the note timings of the song
class Encoder(nn.Module):
    """Bidirectional LSTM over the spectrogram frames.

    ``forward`` returns the per-frame outputs (``2 * hidden_dim`` features,
    after dropout) and the LSTM's final ``(h_n, c_n)`` state. The decoder uses
    that state as its initial state.
    """

    def __init__(self, dropout=0.2):
        super().__init__()
        # Attribute names are the state_dict keys saved to encoder.pth; keep them.
        # NOTE(review): h_n has one entry per direction, and the 2-layer decoder
        # consumes it as one entry per layer (forward -> layer 1, backward -> layer 2).
        self.lstm = nn.LSTM(audio_dim, hidden_dim, batch_first=True,
                            bidirectional=True, device=device)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        features, state = self.lstm(x)
        return self.dropout(features), state

In [None]:
# Decoder
from math import floor
class Decoder(nn.Module):
    """Autoregressive 2-layer LSTM that emits one ``num_features`` row per step.

    Each step's input is the previous row concatenated with the 6 difficulty
    settings. During training the previous row comes from the ground truth
    (teacher forcing); otherwise it is the model's own prediction. Decoding
    stops at the ``STOP`` row.
    """

    def __init__(self, dropout=0.2):
        super().__init__()
        # Attribute names are the state_dict keys saved to decoder.pth; keep them.
        self.lstm = nn.LSTM(num_features + 6, hidden_dim, num_layers=2, batch_first=True, device=device)
        self.dropout = nn.Dropout(dropout)
        self.hiddenfc = nn.Linear(hidden_dim, hidden_dim//2, device=device)
        self.outputfc = nn.Linear(hidden_dim//2, num_features, device=device)

    def forward(self, encoder_out, encoder_hc, difficulty, target=None,
                max_steps=None, stop_atol=0.1):
        """Decode a full map.

        Args:
            encoder_out: encoder outputs. Only their length (``shape[-2]``) is
                used, as the default inference step limit. Assumes
                ``(frames, features)`` or ``(batch, frames, features)``.
            encoder_hc: ``(h, c)`` used as the decoder's initial state.
            difficulty: 1-D tensor with the 6 difficulty settings.
            target: optional ``(T, num_features)`` ground-truth rows ending in
                ``STOP``. When given, teacher forcing is used and exactly one
                output is produced per target row up to the STOP row.
            max_steps: inference-only cap on generated rows. Defaults to
                ``encoder_out.shape[-2]`` so generation always terminates.
            stop_atol: inference-only tolerance. Generation stops once every
                feature of a predicted row is within ``stop_atol`` of ``STOP``;
                a regression output essentially never equals it exactly.
                Heuristic; tune as needed.

        Returns:
            ``(outputs, hidden, None)``. ``outputs`` stacks the predicted rows
            along dim 0, including the final STOP-like row.
        """
        stop_row = STOP.to(device)
        step_input = torch.zeros((1, num_features + 6), device=device)
        hidden = encoder_hc
        outputs = []

        if target is not None:
            limit = target.shape[0]
        elif max_steps is not None:
            limit = max_steps
        else:
            limit = encoder_out.shape[-2]

        prev_percent = 0
        while len(outputs) < limit:
            step_out, hidden = self.forward_step(step_input, hidden)
            outputs.append(step_out)
            steps = len(outputs)

            if target is not None:
                curr_percent = floor(((steps + 1) / target.shape[0]) * 100)
                if curr_percent > prev_percent:
                    prev_percent = curr_percent
                    print(f"Training...{curr_percent}%")

                prev_row = target[steps - 1]
                # The ground-truth STOP row marks the end of the map.
                if bool(torch.all(prev_row == stop_row[0])):
                    break
                step_input = torch.cat((prev_row, difficulty), 0).unsqueeze(0)
            else:
                if (steps + 1) % 100 == 0:
                    print(f"Timestep: {steps + 1}")
                if torch.allclose(step_out, stop_row, rtol=0.0, atol=stop_atol):
                    break
                # Feed the prediction back without backpropagating through it.
                step_input = torch.cat((step_out, difficulty.unsqueeze(0)), 1).detach()

        if not outputs:
            return torch.zeros((0, num_features), device=device), hidden, None
        return torch.cat(outputs, 0), hidden, None

    def forward_step(self, x, hc):
        """Run one LSTM step on ``x`` from state ``hc`` and project to ``num_features``."""
        x, hc = self.lstm(x, hc)
        drp = self.dropout(x)
        # NOTE(review): there is no activation between hiddenfc and outputfc, so
        # the two layers compose to a single affine map.
        hidden = self.hiddenfc(drp)
        out = self.outputfc(hidden)
        return out, hc

## Training Time!

In [None]:
from torch.optim import Adam
# Instantiate both models with dropout 0.4 and move them to `device`
# (the GPU when available, otherwise the CPU).
enc = Encoder(0.4).to(device)
dec = Decoder(0.4).to(device)

def print_model_size(model):
    """Print the memory taken by ``model``'s parameters and buffers, in MB."""
    tensors = list(model.parameters()) + list(model.buffers())
    total_bytes = sum(t.nelement() * t.element_size() for t in tensors)
    print(f"model size: {total_bytes / 1024**2:.3f}MB")

# Report the parameter + buffer memory of each model.
print_model_size(enc)
print_model_size(dec)

In [None]:
import time
def train_epoch(data, encoder, decoder, encoder_opt, decoder_opt, lossfunc):
    """Run one teacher-forced training pass over ``data``; return the mean loss.

    Args:
        data: iterable of ``(x, diff, y)`` samples, one unbatched sample per item.
        encoder: encoder model; switched to training mode.
        decoder: decoder model; switched to training mode.
        encoder_opt: optimizer over the encoder's parameters.
        decoder_opt: optimizer over the decoder's parameters.
        lossfunc: callable ``lossfunc(prediction, target)`` returning a scalar tensor.

    Returns:
        Average per-sample loss, or ``nan`` if ``data`` yields no samples.
    """
    # Evaluation code may have left the models in eval mode (dropout off).
    encoder.train()
    decoder.train()
    total_loss = 0.0
    num_samples = 0
    for i, sample in enumerate(data):
        tsprint(f"Current sample {i+1}")
        # Move the sample to the GPU if available, otherwise it stays on the CPU.
        x = sample[0].to(device)
        diff = sample[1].to(device)
        y = sample[2].to(device)

        encoder_opt.zero_grad()
        decoder_opt.zero_grad()

        encoder_outputs, encoder_hc = encoder(x)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hc, diff, target=y)

        loss = lossfunc(decoder_outputs, y)
        loss.backward()

        encoder_opt.step()
        decoder_opt.step()

        total_loss += loss.item()
        num_samples += 1
    if num_samples == 0:
        return float("nan")
    return total_loss / num_samples

def train(data, encoder, decoder, epochs=10, learning_rate=0.01):
    """Train ``encoder``/``decoder`` on ``data`` with Adam and MSE loss.

    Args:
        data: iterable of ``(x, diff, y)`` samples passed to ``train_epoch``.
        encoder: the encoder model to optimize.
        decoder: the decoder model to optimize.
        epochs: number of passes over ``data``.
        learning_rate: Adam learning rate for both models.

    Returns:
        List with the mean training loss of each epoch.
    """
    start = time.time()
    losshistory = []

    lossfunc = nn.MSELoss()

    # Optimize the models passed in (not module globals) at the requested rate.
    enc_opt = Adam(encoder.parameters(), lr=learning_rate)
    dec_opt = Adam(decoder.parameters(), lr=learning_rate)
    for epoch in range(epochs):
        tsprint(f"Epoch: {epoch+1}")
        loss = train_epoch(data, encoder, decoder, enc_opt, dec_opt, lossfunc)
        curr_time = time.time()
        losshistory.append(loss)
        print(f"Loss: {loss} Time: {curr_time - start}")
    return losshistory

In [None]:
def get_mem_req(enc, dec):
    """Roughly estimate the memory (MB) needed for one training step.

    Runs a forward/backward pass on the first sample of ``train_dl`` (input
    truncated to its first 10000 rows, target to its last 10000). Temporary
    forward hooks record each module's output during the pass. The estimate
    adds element counts for parameters, gradients, the input, Adam state and
    activations, assuming 4 bytes per element. It is printed and returned.

    The models' weights are not updated; the hooks are removed and the
    gradients cleared afterwards.
    """
    acts = []

    def record(_module, _inputs, output):
        # LSTMs and the top-level models return tuples; keep the main output.
        main = output[0] if isinstance(output, tuple) else output
        acts.append(main.detach())

    handles = [m.register_forward_hook(record) for model in (enc, dec) for m in model.modules()]
    try:
        X, diff, y_true = next(iter(train_dl))
        X = X[:10000].to(device)
        diff = diff.to(device)
        # Keep the tail of the target so the STOP row is still present.
        y_true = y_true[-10000:].to(device)
        enc_out, hc = enc(X)
        y_hat, _, _ = dec(enc_out, hc, diff, target=y_true)
        loss = nn.MSELoss()(y_hat, y_true)
        loss.backward()
    finally:
        for handle in handles:
            handle.remove()
    enc.zero_grad(set_to_none=True)
    dec.zero_grad(set_to_none=True)

    model_param_size = sum(p.nelement() for p in enc.parameters()) + sum(p.nelement() for p in dec.parameters())
    grad_size = model_param_size
    input_size = X.nelement()
    optimizer_size = 2 * model_param_size  # Adam keeps exp_avg and exp_avg_sq per parameter
    act_size = sum(a.nelement() for a in acts)

    total_nb_elements = model_param_size + grad_size + input_size + optimizer_size + act_size
    total_mb = total_nb_elements * 4 / 1024**2
    print(total_mb)
    return total_mb
#get_mem_req(enc, dec)

In [None]:
train_loss = train(train_dl, enc, dec, epochs=5)  # per-epoch mean training losses

In [None]:
torch.save(enc.state_dict(), 'encoder.pth')  # weights only; rebuild Encoder before load_state_dict
torch.save(dec.state_dict(), 'decoder.pth')  # weights only; rebuild Decoder before load_state_dict

In [None]:
# Plot the per-epoch training loss (x axis: 0-based epoch index).
plt.plot(np.arange(len(train_loss)), train_loss, 'b', label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

## Make some predictions

Fine-tune the model as necessary.

In [None]:
def decode_audio(encoder, decoder, audio, diff):
    """Generate a map for ``audio`` at difficulty ``diff`` (inference only).

    Both modules run in eval mode, so dropout is disabled, and under
    ``torch.no_grad()``, so no autograd graph is kept across the decoding
    steps. Their previous train/eval modes are restored afterwards.

    Returns:
        The decoder's output rows (the first element of its return tuple).
    """
    modes = (encoder.training, decoder.training)
    encoder.eval()
    decoder.eval()
    try:
        with torch.no_grad():
            enc_out, states = encoder(audio)
            out, _, _ = decoder(enc_out, states, diff)
    finally:
        encoder.train(modes[0])
        decoder.train(modes[1])
    return out

In [None]:
# Evaluate enc/dec on the held-out split. Predicted and actual maps can differ
# in length, so the MSE is computed over their common prefix.
test_loss = 0.0
num_samples = 0

with torch.no_grad():
    for sample in test_dl:
        x = sample[0].to(device)
        diff = sample[1].to(device)
        y = sample[2].to(device)
        decoded_map = decode_audio(enc, dec, x, diff)
        print("Actual: ")
        print(y)
        print("Predicted: ")
        print(decoded_map)
        print(f"Rows: actual {y.shape[0]}, predicted {decoded_map.shape[0]}")
        n_rows = min(y.shape[0], decoded_map.shape[0])
        curr_loss = nn.MSELoss()(y[:n_rows], decoded_map[:n_rows]).item()
        print(curr_loss)
        test_loss += curr_loss
        num_samples += 1
test_loss /= max(num_samples, 1)  # average over the samples actually evaluated
print(f"Average Testing Loss: {test_loss}")

## Train on the full dataset and evaluate

In [None]:
# Fresh models, presumably meant to be trained on the full dataset.
full_enc = Encoder(0.4).to(device)
full_dec = Decoder(0.4).to(device)
# Loader over every sample (both train and test splits), one unbatched sample per step.
a2m_dl = DataLoader(a2m_data, batch_size = None, batch_sampler = None, shuffle = True)

# NOTE(review): full_enc/full_dec are never used. This call keeps training the
# already-trained enc/dec on ALL data, including test_data, and the evaluation
# cell below also scores enc/dec on test_dl, whose samples are now part of the
# training set. Confirm which models were intended.
full_train_loss = train(a2m_dl, enc, dec, epochs=10)

In [None]:
# Plot the per-epoch training loss of the full-data run (x axis: 0-based epoch index).
plt.plot(np.arange(len(full_train_loss)), full_train_loss, 'b', label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Evaluate enc/dec on test_dl after the full-data run. Predicted and actual
# maps can differ in length, so the MSE is computed over their common prefix.
full_test_loss = 0.0
num_samples = 0

with torch.no_grad():
    for sample in test_dl:
        x = sample[0].to(device)
        diff = sample[1].to(device)
        y = sample[2].to(device)
        decoded_map = decode_audio(enc, dec, x, diff)
        print("Actual: ")
        print(y)
        print("Predicted: ")
        print(decoded_map)
        print(f"Rows: actual {y.shape[0]}, predicted {decoded_map.shape[0]}")
        n_rows = min(y.shape[0], decoded_map.shape[0])
        curr_loss = nn.MSELoss()(y[:n_rows], decoded_map[:n_rows]).item()
        print(curr_loss)
        full_test_loss += curr_loss
        num_samples += 1
full_test_loss /= max(num_samples, 1)  # average over the samples actually evaluated
print(f"Average Testing Loss: {full_test_loss}")