In [1]:
import numpy as np
import matplotlib.pyplot as plt 
from IPython.display import display, HTML
from tqdm import tqdm
import os
import pandas as pd
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import plotly.graph_objects as go
# import imageio
import mpl_toolkits.mplot3d.axes3d as p3
import random

class Record:
    """Keeps a history of note-to-note transition matrices.

    Parameters
    ----------
    T_0 : array-like
        Initial transition matrix; stored as ``self.transitions[0]``.
    notes_to_choose : int, default 5
        Number of distinct notes; estimated matrices have shape
        ``(notes_to_choose, notes_to_choose)``.
    """

    def __init__(self, T_0, notes_to_choose = 5):
        self.notes_to_choose = notes_to_choose
        self.transitions = [T_0]
        self.trajectories = []

    def give_transition_matrix(self, melody):
        """Estimate the empirical transition matrix of ``melody`` and record it.

        Entry ``[i, j]`` is the fraction of transitions leaving note ``i`` that
        go to note ``j``, so every row with at least one observed transition
        sums to 1. Notes that are never followed by another note get an
        all-zero row (instead of NaN from 0/0).

        Parameters
        ----------
        melody : sequence of int
            Note indices in ``range(self.notes_to_choose)``.

        Returns
        -------
        numpy.ndarray
            Row-stochastic ``(notes_to_choose, notes_to_choose)`` matrix, also
            appended to ``self.transitions``.
        """
        melody = np.asarray(melody, dtype=int)
        size = self.notes_to_choose
        counts = np.zeros((size, size))
        # counts[a, b] = number of times note a is immediately followed by b.
        # np.add.at is unbuffered, so repeated (a, b) pairs all accumulate.
        np.add.at(counts, (melody[:-1], melody[1:]), 1)
        # Normalise each row (source note) over its destinations; rows with no
        # outgoing transitions are left at zero.
        row_totals = counts.sum(axis=1, keepdims=True)
        probs = np.divide(counts, row_totals,
                          out=np.zeros_like(counts), where=row_totals > 0)
        self.transitions.append(probs)
        return probs
    

class Evaluator:
    """Scores guesses against labels and keeps logs filled in by the caller."""

    def __init__(self):
        # Per-step rewards, appended by the caller.
        self.reward_history = []
        # Action sequences, appended by the caller.
        self.trajectories = []
        # Greedy melodies, appended by the caller at evaluation points.
        self.trajectories_eval = []

    def give_reward(self, guess, label):
        """Return 1 if ``guess == label`` else 0."""
        return 1 if guess == label else 0
        
class Agent():
    """Tabular Q-learning agent with epsilon-greedy action selection.

    States and actions are both integers in ``range(notes_to_choose)``.
    ``q[s, a]`` is the action-value estimate and ``n[s, a]`` counts the
    updates applied to that entry; ``1 / n[s, a]`` is the step size.

    Parameters
    ----------
    notes_to_choose : int
        Number of states and number of actions.
    gamma : float, default 0.
        Discount factor applied to the bootstrapped next-state value.
    ep : float, default 0.01
        Initial exploration probability (epsilon).
    """

    def __init__(self, notes_to_choose, gamma=0., ep=0.01):
        self.n_actions = notes_to_choose
        self.states = notes_to_choose
        self.q = np.zeros((notes_to_choose, notes_to_choose))
        self.n = np.zeros((notes_to_choose, notes_to_choose))

        self.epsilon = ep
        self.gamma = gamma

    def give_action(self, state, greedy=False):
        """Choose an action for ``state``.

        With probability ``epsilon`` (ignored when ``greedy`` is true) a
        uniformly random action is returned. Otherwise an action maximising
        ``q[state]`` is returned, with ties broken uniformly at random.
        """
        # The uniform draw happens first even when greedy, so the RNG stream
        # is consumed identically in both modes.
        if np.random.random() < self.epsilon and not greedy:
            return np.random.choice(range(self.n_actions), 1)[0]
        q_row = self.q[state, :]
        best_actions = np.flatnonzero(q_row == q_row.max())
        return np.random.choice(best_actions, 1)[0]

    def q_learn(self, ts):
        """Apply one Q-learning update per transition in ``ts``.

        Parameters
        ----------
        ts : iterable of (s, a, r, ns)
            State, action, reward and next state. ``ns == -1`` marks a
            terminal transition, whose target is just ``r`` (no bootstrap).
        """
        for s, a, r, ns in ts:
            self.n[s, a] += 1
            if ns == -1:
                target = r
            else:
                target = r + self.gamma * np.max(self.q[ns, :])
            # Incremental-mean update with step size 1 / n[s, a].
            self.q[s, a] += (target - self.q[s, a]) / self.n[s, a]

    def decrease_ep(self, ind, decay=100, min_ep=0.01):
        """Shrink epsilon by ``exp(-(ind + 1) / decay)``, floored at ``min_ep``.

        The factor is applied to the *current* epsilon, so calling this with
        ``ind = 0, 1, ..., k-1`` yields ``eps0 * exp(-k (k + 1) / (2 decay))``
        (before flooring), i.e. a decay much faster than exponential in ``k``.
        Defaults reproduce the original hard-coded constants.
        """
        self.epsilon = max(self.epsilon * np.exp(-(ind + 1) / decay), min_ep)

    def give_melody(self, melody):
        """Return the greedy action for each note of ``melody`` (treated as states)."""
        return [self.give_action(note, greedy=True) for note in melody]

In [2]:
evaluator = Evaluator()
agent = Agent(12, gamma=0., ep=1)
# Episodes after which a greedy rollout of the current policy is recorded.
episode_evaluate = [0, 1, 2, 3, 10, 20, 30]

# Target melody: at each step the agent sees a note and is rewarded for
# answering with that same note.
melody = [0, 1, 2, 3]


def generate_run(n_episodes=10**3, seed=None):
    """Train the module-level ``agent`` on ``melody``; return checkpoint rollouts.

    Each episode the agent plays one epsilon-greedy action per note,
    ``evaluator`` rewards it for matching the note, and the episode's
    transitions are passed to ``agent.q_learn`` before epsilon is decayed.
    After every episode listed in ``episode_evaluate`` the greedy melody is
    appended to ``evaluator.trajectories_eval`` and that episode's actions to
    ``evaluator.trajectories``.

    Parameters
    ----------
    n_episodes : int, default 1000
        Number of training episodes.
    seed : int or None, default None
        If given, seeds NumPy's global RNG (used by ``agent``) for a
        reproducible run.

    Returns
    -------
    list
        ``evaluator.trajectories_eval`` itself. Because ``agent`` and
        ``evaluator`` are module-level, repeated calls keep training the same
        agent and keep appending to this list.
    """
    if seed is not None:
        np.random.seed(seed)
    last_index = len(melody) - 1
    for episode in tqdm(range(n_episodes)):
        ts = []
        actions = []
        for ind, note in enumerate(melody):
            action = agent.give_action(note)
            actions.append(action)
            reward = evaluator.give_reward(action, note)
            evaluator.reward_history.append(reward)
            # The agent's state is the note it is shown, so the next state is
            # the next note of the melody (not the action just taken);
            # -1 marks the end of the episode.
            next_state = melody[ind + 1] if ind != last_index else -1
            ts.append([note, action, reward, next_state])

        agent.q_learn(ts)
        agent.decrease_ep(episode)

        if episode in episode_evaluate:
            evaluator.trajectories_eval.append(agent.give_melody(melody))
            evaluator.trajectories.append(actions)

    return evaluator.trajectories_eval

In [3]:
# Greedy melodies recorded at the evaluation episodes, one row per recorded
# checkpoint; every entry is shifted by +1 for visualization purposes.
test = np.asarray(generate_run()) + 1

100%|██████████| 1000/1000 [00:00<00:00, 2211.66it/s]


In [5]:
%matplotlib tk
plt.rcParams["figure.figsize"] = [7.50, 3.50]
plt.rcParams["figure.autolayout"] = True
plt.rcParams['figure.facecolor'] = 'white'
# plt.zlim ([0, 1])

N = 50
fps = 250
frn = 75

x = np.linspace(-4, 4, N + 1)
x, y = np.meshgrid(x, x)
zarray = np.zeros((N + 1, N + 1, frn))

f = lambda x, y, sig: 1 / np.sqrt(sig) * np.exp(-(x ** 2 + y ** 2) / sig ** 2)

for i in range(frn):
   zarray[:, :, i] = f(x, y, 1.5 + np.sin(i * 2 * np.pi / frn))

# Colormaps from which one is drawn at random on every animation frame.
color_opts = ["pink_r", "afmhot_r", "ocean_r", "gist_stern_r"]


def change_plot(frame_number, zarray, plot1, plot2, plot3, plot4):
    """FuncAnimation callback: redraw all four surfaces for ``frame_number``.

    Each ``plotK`` is a one-element list holding the surface currently drawn
    on the module-level axes ``axK``. The old surface is removed and the list
    is updated in place, so the next call can remove the new one. All four
    panels share a colormap picked at random from ``color_opts``.
    """
    # NOTE(review): plt.pause runs the GUI event loop from inside an animation
    # callback; consider FuncAnimation(interval=...) to throttle instead.
    plt.pause(0.1)
    new_col = color_opts[np.random.randint(len(color_opts))]
    surface = zarray[:, :, frame_number]
    for ax, holder in zip((ax1, ax2, ax3, ax4), (plot1, plot2, plot3, plot4)):
        holder[0].remove()
        holder[0] = ax.plot_surface(x, y, surface, cmap=new_col)


fig = plt.figure()


def _make_panel(position):
    """Add the 3D subplot at ``position`` (1-4) of a 1x4 grid, z-range [-5, 5]."""
    panel = fig.add_subplot(1, 4, position, projection='3d')
    panel.set_zlim(-5, 5)
    return panel


ax1, ax2, ax3, ax4 = (_make_panel(k) for k in range(1, 5))

# Initial frame: uniform grey surfaces. Each is held in a one-element list so
# the animation callback can swap it in place.
plot1, plot2, plot3, plot4 = (
    [panel.plot_surface(x, y, zarray[:, :, 0], color='0.75', rstride=1, cstride=1)]
    for panel in (ax1, ax2, ax3, ax4)
)

plots = [plot1, plot2, plot3, plot4]

def turn_black(ax):
    """Make the panes, axis and grid of the 3D axes ``ax`` invisible.

    Each pane loses its fill and gets a white edge (invisible on a white
    figure). The axis is then switched off and the grid disabled. Despite the
    name, nothing is painted black here.
    """
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('w')

    ax.axis('off')
    ax.grid(False)

for panel in (ax1, ax2, ax3, ax4):
    turn_black(panel)

# One animation frame per precomputed slice of zarray. Requesting more frames
# than zarray.shape[2] (previously a hard-coded 100 vs. 75 slices) made
# change_plot index past the end of zarray and raise IndexError.
# The FuncAnimation must stay referenced (here by `ani`) or it can be
# garbage-collected and stop.
ani = animation.FuncAnimation(fig, change_plot, frames=zarray.shape[2],
                              fargs=(zarray, plot1, plot2, plot3, plot4), repeat=False)

plt.show()

In [6]:
# !brew install imagemagick
# ani.save('~/Desktop/animation.html', writer='imagemagick', fps=60)

Updating Homebrew...
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/ca-certificates/manifests/2021-[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/ca-certificates/blobs/sha256:47[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mPouring ca-certificates--2021-09-30.all.bottle.1.tar.gz[0m
[34m==>[0m [1mRegenerating CA certificate bundle from keychain, this may take a while...[0m
🍺  /usr/local/Cellar/ca-certificates/2021-09-30: 3 files, 203.5KB
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/portable-ruby/portable-ruby/blobs/sha256:0cb1cc7af109437fe0e020c9f3b7b95c3c709b140bde9f991ad2c1433496dd42[0m
######################################################################### 100.0%
[34m==>[0m [1mPouring portable-ruby

[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/aom/manifests/3.2.0_1[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/aom/blobs/sha256:4ccf3a3b28fa2f[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libde265/manifests/1.0.8[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libde265/blobs/sha256:774fe5c9c[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
###########

[34m==>[0m [1mPouring webp--1.2.1.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/webp/1.2.1: 39 files, 2.4MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mjpeg-xl[39m[0m
[34m==>[0m [1mPouring jpeg-xl--0.5_1.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/jpeg-xl/0.5_1: 63 files, 17.4MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mlibvmaf[39m[0m
[34m==>[0m [1mPouring libvmaf--2.3.0.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/libvmaf/2.3.0: 16 files, 2.6MB
[32m==>[0m [1mInstalling imagemagick dependency: [32maom[39m[0m
[34m==>[0m [1mPouring aom--3.2.0_1.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/aom/3.2.0_1: 23 files, 13.3MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mlibde265[39m[0m
[34m==>[0m [1mPouring libde265--1.0.8.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/libde265/1.0.8: 22 files, 2.3MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mlibffi[39m[0m
[34m==>[0m [1mPouring libffi--3.4.2.catali

[34m==>[0m [1mPouring shared-mime-info--2.1.catalina.bottle.tar.gz[0m
[34m==>[0m [1m/usr/local/Cellar/shared-mime-info/2.1/bin/update-mime-database /usr/local/s[0m
🍺  /usr/local/Cellar/shared-mime-info/2.1: 86 files, 4.5MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mx265[39m[0m
[34m==>[0m [1mPouring x265--3.5.catalina.bottle.1.tar.gz[0m
🍺  /usr/local/Cellar/x265/3.5: 11 files, 35.8MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mlibheif[39m[0m
[34m==>[0m [1mPouring libheif--1.12.0.catalina.bottle.tar.gz[0m
[34m==>[0m [1m/usr/local/opt/shared-mime-info/bin/update-mime-database /usr/local/share/mi[0m
🍺  /usr/local/Cellar/libheif/1.12.0: 25 files, 2.8MB
[32m==>[0m [1mInstalling imagemagick dependency: [32mliblqr[39m[0m
[34m==>[0m [1mPouring liblqr--0.4.2_1.catalina.bottle.1.tar.gz[0m
🍺  /usr/local/Cellar/liblqr/0.4.2_1: 24 files, 134.5KB
[32m==>[0m [1mInstalling imagemagick dependency: [32mlibomp[39m[0m
[34m==>[0m [1mPouri

######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/lzo/blobs/sha256:c8f55ba0de8527[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/pixman/manifests/0.40.0[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/pixman/blobs/sha256:1862e6826a4[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/cairo/manifests/1.16.0_5[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloadin

[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libidn2/manifests/2.3.2[0m
Already downloaded: /Users/albasala/Library/Caches/Homebrew/downloads/b8f2405de653b6eec7b67d66be89a8aa5babeb4a79fefd07d1998040d99b02cb--libidn2-2.3.2.bottle_manifest.json
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libidn2/blobs/sha256:71c5f183ae[0m
Already downloaded: /Users/albasala/Library/Caches/Homebrew/downloads/70319ccf886ed70c19cdf34a3e4727cdba161e8c96b60f0de2817996a69d03c6--libidn2--2.3.2.catalina.bottle.tar.gz
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libtasn1/manifests/4.17.0[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/libtasn1/blobs/sha256:0b0b6a4b1[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDo

Already downloaded: /Users/albasala/Library/Caches/Homebrew/downloads/a9c864a1cb51235e8890239dcb57c468e765b33a6f69ac1f026f1313f19c7eda--gnutls-3.6.16_1.bottle_manifest.json
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/gnutls/blobs/sha256:464f68e7e6f[0m
Already downloaded: /Users/albasala/Library/Caches/Homebrew/downloads/89db6804a3ae65f3eef43f62a4a3f7d2101bbabe29c96f8c5cd51272a0004b5e--gnutls--3.6.16_1.catalina.bottle.tar.gz
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/jansson/manifests/2.14[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/jansson/blobs/sha256:ddf25d8386[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/emacs/manifests/27.2[0m
#####################

[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/cask/manifests/0.8.7[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/cask/blobs/sha256:bd85befe31659[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/gstreamer/manifests/1.18.4[0m
######################################################################## 100.0%
[34m==>[0m [1mDownloading https://ghcr.io/v2/homebrew/core/gstreamer/blobs/sha256:9fcc5eb5[0m
[34m==>[0m [1mDownloading from https://pkg-containers.githubusercontent.com/ghcr1/blobs/sh[0m
######################################################################## 100.0%
[32m==>[0m [1mUpgrading [32mwget[39m
  1.20.3_2 -> 1.21.2 
[0m
[32m==>[0m [1mInstalling dependencie

[34m==>[0m [1mPouring libnghttp2--1.46.0.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/libnghttp2/1.46.0: 13 files, 674.1KB
[32m==>[0m [1mInstalling [32munbound[39m[0m
[34m==>[0m [1mPouring unbound--1.13.2_1.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/unbound/1.13.2_1: 57 files, 5.7MB
Removing: /usr/local/Cellar/unbound/1.12.0... (57 files, 5.4MB)
Removing: /usr/local/Cellar/unbound/1.9.3_1... (56 files, 4.8MB)
[32m==>[0m [1mUpgrading [32mgnutls[39m
  3.6.15 -> 3.6.16_1 
[0m
[32m==>[0m [1mInstalling dependencies for gnutls: [32mlibtasn1[39m and [32mnettle[39m[0m
[32m==>[0m [1mInstalling gnutls dependency: [32mlibtasn1[39m[0m
[34m==>[0m [1mPouring libtasn1--4.17.0.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/libtasn1/4.17.0: 61 files, 639.6KB
[32m==>[0m [1mInstalling gnutls dependency: [32mnettle[39m[0m
[34m==>[0m [1mPouring nettle--3.7.3.catalina.bottle.tar.gz[0m
🍺  /usr/local/Cellar/nettle/3.7.3: 89 files, 2.7MB
[32m==>[0m [1mIn