In [4]:
import numpy as np
from scipy.special import softmax

In [5]:
# Genre catalogue: integer genre id -> genre name.
# The ids are also the first element of the keys of `albums` and the keys of `sub`.
gen = {
    1: 'rock',
    2: 'pop',
    3: 'electrónica',
    4: 'hip hop'
}

# Subgenres per genre: genre id -> list of subgenre names.
# Subgenres are addressed elsewhere by their 1-based position in the list
# (see `sub[gen_idx][sub_idx - 1]` in get_user_feedback).
sub = {
    1: ['rock clásico', 'grunge', 'punk rock'],
    2: ['pop contemporáneo', 'k-pop', 'synthpop'],
    3: ['house', 'techno', 'trance'],
    4: ['rap', 'trap', 'lo-fi']
}

# Album catalogue: (genre id, subgenre position) -> list of "Artist - Album" strings.
# The subgenre position is 1-based and refers to the matching list in `sub`,
# e.g. (1, 2) is genre 1 ('rock'), second subgenre ('grunge').
albums = {
    # Genre 1: rock
    (1, 1): ['The Beatles - Abbey Road', 'Led Zeppelin - IV', 'Pink Floyd - The Dark Side of the Moon', 'The Rolling Stones - Let It Bleed', 'Queen - A Night at the Opera'],
    (1, 2): ['Nirvana - Nevermind', 'Pearl Jam - Ten', 'Soundgarden - Superunknown', 'Alice in Chains - Dirt', 'Stone Temple Pilots - Core'],
    (1, 3): ['The Ramones - Rocket to Russia', 'Sex Pistols - Never Mind the Bollocks', 'The Clash - London Calling', 'Dead Kennedys - Fresh Fruit for Rotting Vegetables', 'Green Day - Dookie'],

    # Genre 2: pop
    (2, 1): ['Taylor Swift - 1989', 'Ariana Grande - Sweetener', 'Dua Lipa - Future Nostalgia', 'Billie Eilish - When We All Fall Asleep, Where Do We Go?', 'Ed Sheeran - Divide'],
    (2, 2): ['BTS - Map of the Soul: 7', 'Blackpink - The Album', 'EXO - Don’t Mess Up My Tempo', 'Twice - Eyes Wide Open', 'Stray Kids - Go Live'],
    (2, 3): ['Depeche Mode - Violator', 'The Weeknd - After Hours', 'CHVRCHES - The Bones of What You Believe', 'Ladytron - Velocifero', 'M83 - Hurry Up, We’re Dreaming'],

    # Genre 3: electrónica
    (3, 1): ['Daft Punk - Random Access Memories', 'Calvin Harris - Motion', 'David Guetta - Nothing But the Beat', 'Disclosure - Settle', 'Avicii - True'],
    (3, 2): ['Carl Cox - F.A.C.T.', 'Richie Hawtin - DE9: Closer to the Edit', 'Jeff Mills - The Bells', 'Charlotte de Witte - Return to Nowhere', 'Adam Beyer - Drumcode Radio'],
    (3, 3): ['Armin van Buuren - A State of Trance 2013', 'Tiesto - Elements of Life', 'Paul van Dyk - In Between', 'Above & Beyond - Group Therapy', 'Ferry Corsten - Blueprint'],

    # Genre 4: hip hop
    (4, 1): ['Kendrick Lamar - To Pimp a Butterfly', 'Nas - Illmatic', 'J. Cole - 2014 Forest Hills Drive', 'Jay-Z - The Blueprint', 'Tupac - All Eyez on Me'],
    (4, 2): ['Travis Scott - Astroworld', 'Future - DS2', 'Young Thug - So Much Fun', 'Lil Uzi Vert - Eternal Atake', '21 Savage - i am > i was'],
    (4, 3): ['Joji - Ballads 1', 'Chillhop Music - Essentials Summer 2020', 'Jinsang - Life', 'Nujabes - Modal Soul', 'Brock Berrigan - The Scenic Route']
}

In [6]:
def select_random_album(gen, sub, albums):
    '''
    Pick a random album from the catalogue in three stages.

    A genre id is drawn from the keys of ``gen``, then a 1-based subgenre
    position is drawn from the subgenres listed for that genre in ``sub``,
    and finally one album is drawn from ``albums[(genre id, position)]``.
    Each draw uses ``np.random.choice`` with its default probabilities.

    Returns:
        (gen_idx, sub_idx, selected_album)
    '''
    genre_ids = list(gen.keys())
    gen_idx = np.random.choice(genre_ids)

    # Subgenres are numbered from 1 to match the keys of `albums`.
    subgenre_positions = range(1, len(sub[gen_idx]) + 1)
    sub_idx = np.random.choice(subgenre_positions)

    candidates = albums[(gen_idx, sub_idx)]
    selected_album = np.random.choice(candidates)

    return gen_idx, sub_idx, selected_album

def get_user_feedback(selected_album, gen, sub, gen_idx, sub_idx):
    '''
    Ask the user whether they like the selected album.

    Prints the album together with its genre and subgenre names, then reads
    answers from standard input until a valid one is given. Surrounding
    whitespace in the answer is ignored.

    Args:
        selected_album: album text to show to the user.
        gen: mapping of genre id to genre name.
        sub: mapping of genre id to list of subgenre names.
        gen_idx: genre id of the album.
        sub_idx: 1-based position of the subgenre within ``sub[gen_idx]``.

    Returns:
        1 if the user answers '1' (like), -1 if the user answers '0' (dislike).
    '''
    print(f"Te gusta el albúm {selected_album}? del género {gen[gen_idx]} y subgénero {sub[gen_idx][sub_idx - 1]}")
    while True:
        response = input("¿Te gusta? 1: Sí, 0: No\n").strip()
        if response == '1':
            return 1
        if response == '0':
            return -1
        # Anything else (empty line, typo) must not be recorded as a dislike:
        # it would be fed to the learner as a reward of -1. Ask again instead.
        print("Respuesta no válida. Escribe 1 (Sí) o 0 (No).")

def update_reward(reward, rwrd, learning_rate):
    '''
    Move the running reward estimate one step towards the latest reward.

    Args:
        reward: current running estimate.
        rwrd: reward just observed.
        learning_rate: fraction of the gap between the two to close.

    Returns:
        The new running estimate.
    '''
    error = rwrd - reward
    step = learning_rate * error
    return reward + step

def update_parameters(gen_params, sub_params, gen_idx, sub_idx, rwrd, reward, learning_rate):
    '''
    Update the genre scores and the chosen genre's subgenre scores in place.

    For each of the two score vectors the step direction is
    ``one_hot(chosen) - softmax(scores)``, scaled by ``learning_rate`` and by
    ``rwrd - reward`` (observed reward minus the running reward estimate).

    Args:
        gen_params: array of genre scores; modified in place.
        sub_params: mapping of genre id to array of subgenre scores; only the
            entry for ``gen_idx`` is modified (in place).
        gen_idx: 1-based genre id that was shown.
        sub_idx: 1-based subgenre position that was shown.
        rwrd: reward just observed.
        reward: running reward estimate used as baseline.
        learning_rate: step size.
    '''
    def direction(scores, chosen):
        # One-hot vector of the chosen (1-based) entry minus the softmax of the scores.
        indicator = np.zeros(len(scores))
        indicator[chosen - 1] = 1
        return indicator - softmax(scores)

    chosen_sub_scores = sub_params[gen_idx]

    # Both directions are computed before either vector is modified.
    gradgen = direction(gen_params, gen_idx)
    gradsub = direction(chosen_sub_scores, sub_idx)

    advantage = rwrd - reward

    gen_params += learning_rate * gradgen * advantage
    sub_params[gen_idx] += learning_rate * gradsub * advantage

def main(gen, sub, albums, gen_params, sub_params, learning_rate, epoch):
    '''
    Run the interactive training loop and print a final report.

    On every iteration a random album is drawn, the user is asked whether
    they like it, the running reward estimate is updated and the genre and
    subgenre scores are adjusted in place. Progress is printed after each
    iteration, and the final scores with their softmax values at the end.

    Args:
        gen: mapping of genre id to genre name.
        sub: mapping of genre id to list of subgenre names.
        albums: mapping of (genre id, subgenre position) to list of albums.
        gen_params: array of genre scores; modified in place.
        sub_params: mapping of genre id to array of subgenre scores; modified in place.
        learning_rate: step size used for both the reward estimate and the scores.
        epoch: number of albums to ask the user about.
    '''
    separator = "-----------------------------------"

    # Running reward estimate; it is the baseline in the score update.
    reward = 0

    for _ in range(epoch):
        # Draw an album and collect the user's verdict on it.
        gen_idx, sub_idx, selected_album = select_random_album(gen, sub, albums)
        rwrd = get_user_feedback(selected_album, gen, sub, gen_idx, sub_idx)

        # The estimate is refreshed first, so the score update sees the new value.
        reward = update_reward(reward, rwrd, learning_rate)
        update_parameters(gen_params, sub_params, gen_idx, sub_idx, rwrd, reward, learning_rate)

        print(f"Reward: {reward}")
        print(f"Params Gen: {gen_params}")
        print(f"Params Sub: {sub_params[gen_idx]}")

    # Final report: raw scores next to their softmax values.
    print(separator)
    print("Parámetros finales generos:", gen_params, softmax(gen_params))

    print("Parámetros finales subgeneros:")
    for genre_id, genre_sub_params in sub_params.items():
        print(f"{genre_id}: {genre_sub_params}, softmax: {softmax(genre_sub_params)}")

    print(separator)
    print("Recompensa promedio final:", reward)


In [7]:
# All scores start at zero: one entry per genre, and for every genre one
# entry per subgenre.
gen_params = np.zeros(len(gen))
sub_params = {genre_id: np.zeros(len(sub[genre_id])) for genre_id in gen}

# Step size and number of albums the user will be asked about.
learning_rate = 0.1
epoch = 20

main(
    gen,
    sub,
    albums,
    gen_params=gen_params,
    sub_params=sub_params,
    learning_rate=learning_rate,
    epoch=epoch,
)

Te gusta el albúm Joji - Ballads 1? del género hip hop y subgénero lo-fi
Reward: -0.1
Params Gen: [ 0.0225  0.0225  0.0225 -0.0675]
Params Sub: [ 0.03  0.03 -0.06]
Te gusta el albúm M83 - Hurry Up, We’re Dreaming? del género pop y subgénero synthpop
Reward: -0.19
Params Gen: [ 0.04319531 -0.03780469  0.04319531 -0.04858592]
Params Sub: [ 0.027  0.027 -0.054]
Te gusta el albúm 21 Savage - i am > i was? del género hip hop y subgénero trap
Reward: -0.271
Params Gen: [ 0.06220692 -0.02027231  0.06220692 -0.10414153]
Params Sub: [ 0.05501775 -0.01788225 -0.0371355 ]
Te gusta el albúm Nujabes - Modal Soul? del género hip hop y subgénero lo-fi
Reward: -0.34390000000000004
Params Gen: [ 0.07962127 -0.00423664  0.07962127 -0.15500591]
Params Sub: [ 0.07810637  0.0035831  -0.08168947]
Te gusta el albúm Young Thug - So Much Fun? del género hip hop y subgénero trap
Reward: -0.40951000000000004
Params Gen: [ 0.09553561  0.01039758  0.09553561 -0.2014688 ]
Params Sub: [ 0.09934317 -0.0357542  -0.063