In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib.ticker import ScalarFormatter
from tqdm import tqdm

from google.colab import drive
drive.mount('/content/drive')

# Seed every random number generator this notebook imports so that a fresh
# "Restart & Run All" reproduces the same trajectories.
SEED = 42
random.seed(SEED)     # Python's built-in RNG
np.random.seed(SEED)  # NumPy global RNG (drives reward draws and direction flips)
torch.manual_seed(SEED)

In [None]:
# When True, each plotting cell also saves its figure as a PNG under
# drive/MyDrive/Dopamine/graphs/ (the flag gates plt.savefig, not console output).
VERBOSE = False

In [None]:
# Parameters of the dopamine (d) / GABA (g) ODE evaluated by `ode_system`.
# `explicit_euler` reads them as module-level globals on every call.
omega_d = 50  # factor multiplying the whole d_dot expression
omega = 15  # factor multiplying the g_dot expression
C = 15  # constant term inside d_dot
mu = 6  # weight of log(R) inside d_dot
alpha = 0.7  # weight of the GABA term (-alpha * g) inside d_dot
d_0 = 5  # reference dopamine level: g_dot is zero when d == d_0
R_0 = 1  # baseline input R used for the steady-state computation

In [None]:
# Steady state used as the initial condition of every episode.
R = R_0
d_st = d_0

# GABA steady state: solve d_dot = 0 from `ode_system` for g at d = d_0, R = R_0,
# i.e. C + mu*log(R) - alpha*g - d_0 = 0. Uses the `alpha` bound when this cell runs.
g_st = (C - d_0 + mu * np.log(R))/alpha

# Bare expression so the notebook displays the value.
g_st

In [None]:
def ode_system(R, d_val, g_val, omega_d, C, mu, alpha, omega, d0):
    """Evaluate the right-hand side of the coupled dopamine/GABA ODE.

    Args:
        R: input signal; enters through ``log(R)``.
        d_val: current dopamine level.
        g_val: current GABA level.
        omega_d: factor scaling the dopamine derivative.
        C: constant term of the dopamine derivative.
        mu: weight of ``log(R)``.
        alpha: weight of the GABA term.
        omega: factor scaling the GABA derivative.
        d0: reference dopamine level at which the GABA derivative vanishes.

    Returns:
        A two-element list ``[d_dot, g_dot]``.
    """
    input_drive = C + mu * np.log(R)
    dopamine_rate = omega_d * (input_drive - alpha * g_val - d_val)
    gaba_rate = omega * (d_val / d0 - 1)
    return [dopamine_rate, gaba_rate]

In [None]:
# using Euler method
def explicit_euler(old_d_val, old_g_val, R_input, h=0.01):
    """Advance the dopamine/GABA state by one explicit (forward) Euler step.

    The ODE parameters (omega_d, C, mu, alpha, omega, d_0) are read from
    module-level globals at call time, so reassigning any of them in another
    cell changes the dynamics of every subsequent call.

    Args:
        old_d_val: dopamine level before the step.
        old_g_val: GABA level before the step.
        R_input: input signal passed to ``ode_system`` as ``R``.
        h: step size.

    Returns:
        Tuple ``(new_d_val, new_g_val)``.
    """
    # Evaluate the right-hand side once; both components share the same inputs.
    d_dot, g_dot = ode_system(R_input, old_d_val, old_g_val, omega_d, C, mu, alpha, omega, d_0)

    new_d_val = old_d_val + h * d_dot
    new_g_val = old_g_val + h * g_dot

    return new_d_val, new_g_val


In [None]:
# Parameters of the dopamine-dependent re-orientation probability (`calculate_phi`).
# The exponent coefficient is named `alpha_phi`, not `alpha`: `explicit_euler`
# reads the module-level `alpha` of the dopamine/GABA ODE on every call, so
# binding `alpha` here would silently change the ODE dynamics and invalidate the
# steady state `g_st` computed from the ODE value.
m0 = 1
alpha_phi = 0.1
n = 1
h = 1
a0 = 1/3
tau = 3


def calculate_a_l_m(l, m):
    """Return the activity a(l, m) = 1 / (1 + (l / K_m)**n).

    K_m = exp(alpha_phi * (m - m0)) is the m-dependent half-saturation level.
    """
    K_m = np.exp(alpha_phi * (m - m0))
    a = 1 / (1 + (l / K_m)**n)
    return a


def calculate_phi(l, m):
    """Return the direction-change probability for signal `l` at level `m`.

    The signal is shifted by +1 before the activity is evaluated. The scaled
    activity (1/tau) * (a/a0)**h is clipped to [0, 0.9] and subtracted from 1,
    so the result always lies in [0.1, 1].
    """
    l += 1
    a_lm = calculate_a_l_m(l, m)
    value = 1/tau * (a_lm/a0)**h
    return 1 - max(0, min(0.9, value))

In [None]:
class LineEnvironment:
    """One-dimensional track of `length` positions with stochastic rewards.

    `reward_pos_size_prob` is a triple of parallel sequences
    ``(positions, sizes, probabilities)``: arriving at ``positions[i]`` pays
    ``sizes[i]`` with probability ``probabilities[i]``.
    """

    def __init__(self, length, reward_pos_size_prob):
        self.length = length
        self.reward_pos, self.reward_size, self.reward_prob = reward_pos_size_prob
        self.visits = [0] * length  # visit counter per position

    def step(self, action, state):
        """Move by `action` from `state` and draw the reward at the new position.

        The new position is clamped to ``[0, length - 1]`` and its visit
        counter is incremented.

        Returns:
            Tuple ``(new_state, reward)``; ``reward`` is 0 when nothing is paid.
        """
        state = max(0, min(state + action, self.length - 1))
        self.visits[state] += 1

        reward = 0
        for pos, size, prob in zip(self.reward_pos, self.reward_size, self.reward_prob):
            # `and` short-circuits, so a random number is drawn only when the
            # position matches. np.random.random() lies in [0, 1), hence the
            # strict `<`: prob == 0 never pays and prob == 1 always pays.
            if pos == state and np.random.random() < prob:
                reward = size
                break

        return state, reward

def update_dopamine(d_list, g_list, reward_input):
    """Take one Euler step from the latest dopamine/GABA values.

    The input handed to the integrator is ``reward_input + 1``. Neither list
    is modified; the caller appends the returned values.

    Returns:
        Tuple ``(dopamine, gaba)`` after the step.
    """
    current_d, current_g = d_list[-1], g_list[-1]
    next_d, next_g = explicit_euler(current_d, current_g, reward_input + 1)
    return next_d, next_g

In [None]:
class TDLearningAgent:
    """Tabular TD agent on a line whose actions are -1 (left) and +1 (right).

    Column 0 of `q_table` holds the value of action -1; column 1 holds the
    value of any other action.
    """

    def __init__(self, num_states, num_actions, learning_rate, gamma):
        self.q_table = np.zeros((num_states, num_actions)) # agent expect no reward
        self.learning_rate = learning_rate
        self.gamma = gamma

    @staticmethod
    def _action_column(action):
        """Map an action (-1 or +1) to its column in `q_table`."""
        return 0 if action == -1 else 1

    def choose_action(self, prev_action, state, epsilon, last_epi=False):
        """Keep or reverse the previous direction.

        With probability `epsilon` the direction is reversed, otherwise it is
        kept. When `last_epi` is True the agent always moves right (+1).
        `state` is accepted but not used.
        """
        if last_epi:
            return 1

        if np.random.rand() < epsilon:
            return -1*prev_action  # change direction with prob epsilon
        else:
            return prev_action # else keep direction

    def learn(self, state, action, reward):
        """Apply one TD update to ``q_table[state, action]``.

        The bootstrap state is ``state + action`` clamped to the valid row
        range, matching how the environment clamps positions. Without the
        clamp, -1 would silently index the last row and ``num_states`` would
        raise IndexError. The stored value is floored at 0.
        """
        q_action = self._action_column(action)

        last_state = self.q_table.shape[0] - 1
        next_state = min(max(state + action, 0), last_state)

        predict = self.q_table[state, q_action]
        target = reward + self.gamma * np.max(self.q_table[next_state])
        updated_value = predict + self.learning_rate * (target - predict)
        self.q_table[state, q_action] = max(0, updated_value)

In [None]:
def plot_visits(visits, reward_pos):
    """Bar-plot the visit count per position, highlighting reward positions.

    The count at position 0 is displayed as 0. The plot works on a copy, so
    the caller's `visits` list is left untouched.

    Args:
        visits: sequence of visit counts, one per position.
        reward_pos: positions whose bars are redrawn in red.
    """
    counts = list(visits)  # copy: do not mutate the caller's list
    if counts:
        counts[0] = 0
    positions = list(range(len(counts)))

    plt.figure(figsize=(10, 6))
    plt.bar(positions, counts, color='gray', label='Visits')

    reward_visits = [counts[pos] for pos in reward_pos]
    plt.bar(reward_pos, reward_visits, color='red', label='Reward Positions')

    plt.xlabel('Position', fontsize=14)
    plt.ylabel('Number of Visits', fontsize=14)

    plt.xticks(positions, [str(pos) for pos in positions])

    plt.legend()
    plt.grid(True)

    if VERBOSE:
        directory = 'drive/MyDrive/Dopamine/graphs/'
        name = 'cumu_freq_sto_rew_dop_reori_td'
        plt.savefig(directory + name +'.png', format='png',bbox_inches='tight',dpi=500)

    plt.show()

In [None]:
# Initialization
# rew_list = [reward positions, reward sizes, reward probabilities] (parallel lists).
rew_list = [[5, 9, 13], [1,2,4], [1, 0.5, 0.25]]
env = LineEnvironment(length=20, reward_pos_size_prob=rew_list)
agent = TDLearningAgent(num_states=20, num_actions=2, learning_rate=0.1, gamma=0.3)
epsilon = 0

# Per-snapshot histories, filled every 5th episode (see the end of the loop).
d_list_whole = []
g_list_whole = []
q_values_over_time = []

# Training Loop
for episode in tqdm(range(10000)):
    # Every episode restarts at position 0 with dopamine/GABA at steady state.
    state = 0
    done = False
    total_reward = 0  # NOTE(review): never read in this cell
    d_list = [d_0]
    g_list = [g_st]
    pred_error_list=[]  # NOTE(review): never read in this cell
    action = 1 #init action

    reward = 0
    dop_spike=0

    while not done:
        # Learn from the reward collected by the previous step (0 on the first pass).
        agent.learn(state, action, reward)

        # Probability of reversing direction, driven by the last dopamine deviation.
        epsilon = calculate_phi(dop_spike, d_0)

        action = agent.choose_action(action, state, epsilon)
        next_state, reward = env.step(action, state)

        # Q-table column of the chosen action: 0 for -1, 1 otherwise.
        if action == -1:
                q_action = 0
        else:
                q_action = 1

        # Update Dopamine based on prediction error
        # NOTE(review): `next_state + action` is not clamped. With next_state == 0 and
        # action == -1 the index is -1, which NumPy resolves to the last row.
        prediction_error = abs(reward + agent.gamma * np.max(agent.q_table[next_state + action]) - agent.q_table[next_state, q_action])
        dop, gaba = update_dopamine(d_list, g_list, prediction_error)
        dop_spike = dop - d_list[0]  # deviation from the episode's initial dopamine level
        d_list.append(dop)
        g_list.append(gaba)

        state = next_state

        # The episode ends at index length - 2, one before the last position.
        if state == env.length - 2:  # Agent reaches the end of the line
            done = True
        if state == 0: # Agent go back to the start of the line, go right again
            action = 1

    # Keep a snapshot of the traces and a copy of the Q-table every 5th episode.
    if episode % 5 == 0:
        d_list_whole.append(d_list)
        g_list_whole.append(g_list)
        q_values_over_time.append(agent.q_table[:].copy())


In [None]:
# Run the simulation
# Free-running simulation with the trained agent on a fresh environment: the
# agent bounces between position 0 and position length - 2 until more than
# 10000 wall hits have been counted, then the visit histogram is plotted.
# The agent keeps learning during the simulation.
env_simu = LineEnvironment(length=20, reward_pos_size_prob=rew_list)

# Dopamine/GABA traces owned by this simulation, started at steady state.
# (Do not reuse `d_list` / `g_list`: those hold the last training episode.)
d_list_simu = [d_0]
g_list_simu = [g_st]
q_values_over_simu = []

hit_time = 0
state = 0
finish = False
action = 1 #init action
reward = 0
dop_spike=0

# Simulation Loop
while not finish:

    # Learn from the reward collected by the previous step (0 on the first pass).
    agent.learn(state, action, reward)

    # Probability of reversing direction, driven by the last dopamine deviation.
    epsilon = calculate_phi(dop_spike, d_0)

    action = agent.choose_action(action, state, epsilon)

    # Walls override the chosen action and count as a hit.
    if state == env_simu.length - 2:
        action = -1  # Agent reaches the end of the line
        hit_time += 1

    if state == 0: # Agent go back to the start of the line, go right again
        action = 1
        hit_time += 1

    next_state, reward = env_simu.step(action, state)

    # Q-table column of the executed action: 0 for -1, 1 otherwise.
    q_action = 0 if action == -1 else 1

    # Update Dopamine based on prediction error
    # NOTE(review): `next_state + action` is not clamped. With next_state == 0 and
    # action == -1 the index is -1, which NumPy resolves to the last row.
    prediction_error = abs(reward + agent.gamma * np.max(agent.q_table[next_state + action]) - agent.q_table[next_state, q_action])
    dop, gaba = update_dopamine(d_list_simu, g_list_simu, prediction_error)
    dop_spike = dop - d_list_simu[0]
    d_list_simu.append(dop)
    g_list_simu.append(gaba)

    state = next_state

    if hit_time > 10000:
        finish = True

plot_visits(env_simu.visits, env_simu.reward_pos)


In [None]:
# Final episode: a single sweep in which the agent is forced to move right on
# every step (last_epi=True). Only the reward at position 5 is kept (size 1,
# probability 1); the rewards at 9 and 13 have size 0 and probability 0.
rew_list_final = [[5, 9, 13], [1,0,0], [1, 0, 0]]
env_final = LineEnvironment(length=20, reward_pos_size_prob=rew_list_final)

# Start at position 0 with dopamine/GABA at steady state.
state, action, reward = 0, 1, 0
done = False
d_list, g_list = [d_0], [g_st]
state_list = [0]

while not done:
    agent.learn(state, action, reward)

    action = agent.choose_action(action, state, epsilon=0, last_epi=True)
    next_state, reward = env_final.step(action, state)

    # Q-table column of the executed action: 0 for -1, 1 otherwise.
    q_action = 0 if action == -1 else 1

    # Only positive prediction errors drive dopamine here (max(0, ...), not abs).
    td_target = reward + agent.gamma * np.max(agent.q_table[next_state + action])
    prediction_error = max(0, td_target - agent.q_table[next_state, q_action])
    dop, gaba = update_dopamine(d_list, g_list, prediction_error)
    dop_spike = dop - d_list[0]
    d_list.append(dop)
    g_list.append(gaba)

    state = next_state

    # Stop at index length - 2; force a rightward move whenever back at 0.
    done = state == env_final.length - 2
    if state == 0:
        action = 1

# Store this episode's traces and Q-table as the last snapshot.
d_list_whole.append(d_list)
g_list_whole.append(g_list)
q_values_over_time.append(agent.q_table.copy())

In [None]:
# Stack the list of Q-table snapshots into one array indexed
# [snapshot, state, action column]; the plotting cells below index it this way.
q_values_over_time = np.array(q_values_over_time)

In [None]:
# Q-values of action +1 (column 1) along the track, one curve per 400th snapshot.
plt.figure(figsize=(10, 6))

for snap_idx in range(0, len(q_values_over_time), 400):
    # Snapshots are taken every 5 episodes, hence the factor 5 in the label.
    plt.plot(q_values_over_time[snap_idx, :, 1], label=f'Q table at epoch {snap_idx*5}')

# Only the first reward position carries a legend label, so each entry shows once.
first_reward_pos = rew_list[0][0]

# Vertical markers at the reward positions.
for rew_pos, rew_size, rew_prob in zip(*rew_list):
    marker_label = 'Reward Position' if rew_pos == first_reward_pos else None
    plt.axvline(x=rew_pos, color='red', linestyle='--', linewidth=0.8, label=marker_label)

# Horizontal segments from x=0 to each reward position at height size * probability.
for rew_pos, rew_size, rew_prob in zip(*rew_list):
    level_label = 'Expectation of Reward' if rew_pos == first_reward_pos else None
    plt.hlines(y=rew_size*rew_prob, xmin=0, xmax=rew_pos, color='blue', linestyle='--', linewidth=0.8, label=level_label)

plt.xlabel('Position', fontsize=14)
plt.ylabel('Q-values', fontsize=14)
plt.legend()
plt.tight_layout(rect=[0, 0, 0.75, 1])

if VERBOSE:
    directory = 'drive/MyDrive/Dopamine/graphs/'
    name = 'Q_Evolution1_sto_rew_dop_reori_td'
    plt.savefig(directory + name +'.png', format='png',bbox_inches='tight',dpi=500)

plt.show()


In [None]:
# Q-values of action -1 (column 0) along the track, one curve per 400th snapshot.
plt.figure(figsize=(10, 6))

for snap_idx in range(0, len(q_values_over_time), 400):
    # Snapshots are taken every 5 episodes, hence the factor 5 in the label.
    plt.plot(q_values_over_time[snap_idx, :, 0], label=f'Q table at epoch {snap_idx*5}')

# Only the first reward position carries a legend label, so each entry shows once.
first_reward_pos = rew_list[0][0]

# Vertical markers at the reward positions.
for rew_pos, rew_size, rew_prob in zip(*rew_list):
    marker_label = 'Reward Position' if rew_pos == first_reward_pos else None
    plt.axvline(x=rew_pos, color='red', linestyle='--', linewidth=0.8, label=marker_label)

# Horizontal segments from x=0 to each reward position at height size * probability.
for rew_pos, rew_size, rew_prob in zip(*rew_list):
    level_label = 'Expectation of Reward' if rew_pos == first_reward_pos else None
    plt.hlines(y=rew_size*rew_prob, xmin=0, xmax=rew_pos, color='blue', linestyle='--', linewidth=0.8, label=level_label)

plt.xlabel('Position', fontsize=14)
plt.ylabel('Q-values', fontsize=14)
plt.legend(loc='upper left')
plt.tight_layout(rect=[0, 0, 0.75, 1])

if VERBOSE:
    directory = 'drive/MyDrive/Dopamine/graphs/'
    name = 'Q_Evolution-1_sto_rew_dop_reori_td'
    plt.savefig(directory + name +'.png', format='png',bbox_inches='tight',dpi=500)

plt.show()


In [None]:
# Both Q-table columns from the second-to-last snapshot (index -2), i.e. the
# snapshot stored just before the final deterministic episode's one.
plt.figure(figsize=(10, 6))

penultimate_q = q_values_over_time[-2]
plt.plot(penultimate_q[:, 1], label='Q table for action 1')
plt.plot(penultimate_q[:, 0], label='Q table for action -1')

# Only the first reward position carries a legend label, so each entry shows once.
first_reward_pos = rew_list[0][0]

# Vertical markers at the reward positions.
for rew_pos, rew_size, rew_prob in zip(*rew_list):
    marker_label = 'Reward Position' if rew_pos == first_reward_pos else None
    plt.axvline(x=rew_pos, color='red', linestyle='--', linewidth=0.8, label=marker_label)

# Horizontal segments from x=0 to each reward position at height size * probability.
for rew_pos, rew_size, rew_prob in zip(*rew_list):
    level_label = 'Expectation of Reward' if rew_pos == first_reward_pos else None
    plt.hlines(y=rew_size*rew_prob, xmin=0, xmax=rew_pos, color='blue', linestyle='--', linewidth=0.8, label=level_label)

plt.xlabel('Position', fontsize=14)
plt.ylabel('Q-values', fontsize=14)
plt.legend()
plt.tight_layout(rect=[0, 0, 0.75, 1])

if VERBOSE:
    directory = 'drive/MyDrive/Dopamine/graphs/'
    name = 'last_Q_1-1_sto_rew_dop_reori_td'
    plt.savefig(directory + name +'.png', format='png',bbox_inches='tight',dpi=500)

plt.show()


In [None]:
# Histograms, over all snapshots, of the Q-values at the first three reward
# positions: one panel per position, both actions overlaid.
fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharey=True, sharex=True)

# Common bin edges, derived from the action +1 values at every reward position.
all_q_values = [
    q_value
    for rew_pos in rew_list[0]
    for q_value in q_values_over_time[:, rew_pos, 1]
]
min_q_value = min(all_q_values)
max_q_value = max(all_q_values)

bin_width = 0.05
bins = np.arange(min_q_value, max_q_value + bin_width, bin_width)

for index, rew_pos in enumerate(rew_list[0][:3]):
    ax = axes[index]
    rew_size, rew_prob = rew_list[1][index], rew_list[2][index]
    q_values_at_pos0 = q_values_over_time[:, rew_pos, 0]
    q_values_at_pos1 = q_values_over_time[:, rew_pos, 1]

    ax.hist(q_values_at_pos0, bins=bins, color='red', alpha=0.7, label='Action -1')
    ax.hist(q_values_at_pos1, bins=bins, color='blue', alpha=0.7, label='Action 1')
    ax.set_xlabel('Q-value intervals', fontsize=14)
    ax.set_xlim(min_q_value, max_q_value)
    ax.set_title(f'Reward at position {rew_pos} with size {rew_list[1][index]} and probability {rew_list[2][index]}')
    # Dashed line at the expected reward (size * probability).
    ax.axvline(x=rew_size*rew_prob, color='blue', linestyle='--', linewidth=1, label=f'Expectation = {rew_size*rew_prob}')
    ax.legend()

# Single shared y-label placed on the figure instead of on each panel.
fig.text(0.02, 0.5, 'Frequency', va='center', rotation='vertical', fontsize=12)

plt.tight_layout(rect=[0.03, 0, 1, 0.95])

if VERBOSE:
    directory = 'drive/MyDrive/Dopamine/graphs/'
    plt.savefig(directory + 'q_freq_sto_rew_dop_reori_td.png', format='png', bbox_inches='tight', dpi=500)

plt.show()


In [None]:
# Dopamine and GABA traces of the most recently stored episode (last entry of
# d_list_whole / g_list_whole), drawn against their steady-state levels.
last_episode_dopamine = d_list_whole[-1]
last_episode_gaba = g_list_whole[-1]

fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True)

axes[0].plot(last_episode_dopamine, color='blue', label='Dopamine Level at last epoch')
axes[0].set_ylabel('Dopamine Level', fontsize=14)
axes[0].set_ylim([4.95, 5.15])
for rew_pos in rew_list[0]:
    axes[0].axvline(x=rew_pos, color='red', linestyle='--', linewidth=0.8, label='Reward Position' if rew_pos == rew_list[0][0] else None)
axes[0].axhline(y=d_st, color='blue', linestyle='--', linewidth=0.8, label='Dopamine Steady state')

axes[0].legend()

axes[1].plot(last_episode_gaba, color='green', label='GABA Level at last epoch')
axes[1].set_ylabel('GABA Level', fontsize=14)
axes[1].set_ylim([14.28, 14.30])
axes[1].set_xlabel('Position', fontsize=14)
for rew_pos in rew_list[0]:
    axes[1].axvline(x=rew_pos, color='red', linestyle='--', linewidth=0.8, label='Reward Position' if rew_pos == rew_list[0][0] else None)
axes[1].axhline(y=g_st, color='blue', linestyle='--', linewidth=0.8, label='GABA Steady state')

axes[1].legend()

# Plain tick labels (no offset, no scientific notation) on both y-axes.
# A formatter keeps a reference to the axis it is attached to, so each axis
# gets its own instance instead of sharing one.
for ax in axes:
    formatter = ScalarFormatter(useOffset=False)
    formatter.set_scientific(False)
    ax.yaxis.set_major_formatter(formatter)

plt.tight_layout(pad=1.0)
fig.subplots_adjust(right=0.75)

if VERBOSE:
    directory = 'drive/MyDrive/Dopamine/graphs/'
    name = 'last_dop_gaba_sto_rew_dop_reori_td'
    plt.savefig(directory + name +'.png', format='png',bbox_inches='tight',dpi=500)

plt.show()
