In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import pandas as pd
import matplotlib as mpl
import os
from matplotlib.offsetbox import TextArea, DrawingArea, OffsetImage, AnnotationBbox
import matplotlib.image as mpimg
from scipy.stats import binned_statistic
from my_utils import do_label
import tqdm

# model equations

In [None]:
# When True, the figure cells in this notebook also write PDFs under figs/python/
# (every savefig call is guarded by this flag).
VERBOSE=False

In [None]:
# Constants of the dopamine (d) / GABA (g) ODE model used by load_eqs:
#   d0_    level of d at which g stops changing (dg/dt = 0 when d == d0_)
#   C_     constant drive on d
#   mu_    gain on log R(t)
#   alpha_ strength of g's inhibition of d
d0_, C_, mu_, alpha_ = 5, 15, 6, 0.7

def load_eqs(omega_d,omega,C=None,mu=None,alpha=None,d0=None):
    """Build the right-hand side of the dopamine (d) / GABA (g) ODE model.

        dd/dt = omega_d * (C + mu*log R(t) - alpha*g - d)
        dg/dt = omega   * (d/d0 - 1)

    Parameters
    ----------
    omega_d, omega : float
        Rate constants of the d and g equations.
    C, mu, alpha, d0 : float, optional
        Model constants. When omitted, the current module-level values of
        C_, mu_, alpha_ and d0_ are used. In both cases the values are
        captured *now*, so later rebinding of those globals (mu_ is
        reassigned in the matching-law section) cannot silently change a
        model that was already built.

    Returns
    -------
    callable
        eqs(y, t, R) -> [dd/dt, dg/dt] with y = (d, g) and R a scalar
        function of time; suitable for scipy.integrate.odeint(..., args=(R,)).
    """
    # Resolve defaults eagerly: the globals are only read when no override is given.
    C = C_ if C is None else C
    mu = mu_ if mu is None else mu
    alpha = alpha_ if alpha is None else alpha
    d0 = d0_ if d0 is None else d0

    def eqs(y,t,R):
        d,g = y
        dydt = [omega_d*(C + mu*np.log(R(t)) - alpha*g - d),
                omega*(d/d0 - 1),
                ]
        return dydt
    return eqs



def get_y0(r0=1,tmax = 100):
    """Return the model state (d, g) after relaxing under a constant input r0.

    Integrates the module-level ``eqs`` from (d, g) = (1, 1) on the grid
    np.arange(1, tmax, 0.01) and returns the last state; ``run`` uses it as
    the initial condition.
    """
    relax_times = np.arange(1, tmax, 0.01)
    trajectory = odeint(eqs, [1, 1], relax_times, args=(lambda _t: r0,))
    return trajectory[-1]

def run(r,t,r0=1):
    """Integrate the module-level ``eqs`` with input function r over time grid t.

    The initial state is the state reached after relaxing under the constant
    input r0 (get_y0). Tight tolerances (rtol = atol = 1e-12) are used.
    Returns odeint's array of shape (len(t), 2) holding (d, g).
    """
    initial_state = get_y0(r0)
    return odeint(eqs, initial_state, t, args=(r,), rtol=1e-12, atol=1e-12)

def generate_stim_reward(tcue = 1,treward = None,tmax = 4.5,
                         reward_baseline = 1,reward_size=5,give_reward=1):
    """Simulate the model's response to a step change in expected reward.

    Parameters
    ----------
    tcue : float
        Time at which the input leaves ``reward_baseline``.
    treward : float or None
        If None, cue and reward coincide at ``tcue``. Otherwise the input
        steps to 0.5*reward_size at ``tcue`` and to the delivered value at
        ``treward``.
    tmax : float
        End of the simulated window; the grid is np.arange(0, tmax, 0.01).
    reward_baseline : float
        Input before the cue; also the constant input used to set the
        initial state (passed to ``run`` as r0).
    reward_size : float
        Input after the reward when it is delivered.
    give_reward : float
        Linear blend for the post-reward value: 1 gives ``reward_size``,
        0 gives 0.1*reward_baseline (presumably the omission case).

    Returns
    -------
    (sol, r, t, rdashed)
        sol: odeint trajectory of (d, g) on t; r: the input that was used;
        t: the time grid; rdashed: a variant of r that drops to
        0.1*reward_baseline at ``treward`` (identical to r when treward is None).
    """
    t = np.arange(0, tmax,0.01)
    # value of the input after the reward time
    delivered = give_reward*reward_size + (1-give_reward)*0.1*reward_baseline
    if treward is None:
        # cue and reward happen at the same time
        def r(s):
            return reward_baseline if s < tcue else delivered
        rdashed = r
    else:
        def r(s):
            if s < tcue:
                return reward_baseline
            if s < treward:
                return 0.5*reward_size
            return delivered

        def rdashed(s):
            return r(s) if s < treward else 0.1*reward_baseline
    return run(r,t,reward_baseline),r,t,rdashed


# phasic responses (primates)

In [None]:
# load with primate parameters (rate constants omega_d=100, omega=30)
eqs = load_eqs(omega=30, omega_d=100)

In [None]:
def preprocess_tobler_for_plotting(f):
    """Load a digitized trace from tobler_reward_2005/<f> and tidy it for plotting.

    Steps: shift y so the smallest sample is 0 and drop that sample; sort by
    x and shift so x starts at 0; snap x to a 1/50 grid (x*5 rounded to one
    decimal), keep the maximum y per grid point, then rescale x back.
    Returns a DataFrame with columns "x" and "y".
    """
    FACT = 5
    raw = pd.read_csv("tobler_reward_2005/" + f, header=None, names=("x", "y"))
    by_y = raw.sort_values("y")
    by_y = by_y.assign(y=by_y.y - by_y.y.iloc[0]).iloc[1:]
    by_x = by_y.sort_values("x")
    by_x = by_x.assign(x=((by_x.x - by_x.x.iloc[0]) * FACT).round(1))
    binned = by_x.groupby("x").max().reset_index()
    return binned.assign(x=binned.x / FACT)


# Digitized population responses (tobler_reward_2005 directory) for three reward volumes.
# NOTE(review): files prefixed "pop_cue_" go on the panel titled "reward" and
# "pop_reward_" on the panel titled "reward preceded by cue" — confirm this pairing is intended.
# NOTE(review): out_pop is not used in this cell.
out_pop = []
# One colour per volume; also reused by later cells.
colors=("gray","brown","black")
fig,axs=plt.subplots(ncols=2,figsize=(15,5))
# Filename prefix and y-axis upper limit for each panel.
slist=("pop_cue_","pop_reward_")
mxs=[20,15]
for i in range(2):
    # Legend labels, rebuilt per panel; the final list is used below and by later cells.
    clrs=[]
    for amount,clr in zip(["0.05","0.15","0.5"],colors):
        clrs.append(amount + " ml")
        # e.g. "pop_cue_005.csv" for 0.05 ml
        z=preprocess_tobler_for_plotting(slist[i]+amount.replace(".","")+".csv")
        axs[i].plot(z.x,z.y,c=clr,lw=3)
        axs[i].set_ylim([0,mxs[i]])

axs[0].set_title("reward")
axs[1].set_title("reward preceded by cue")
axs[0].set_yticks([0,5,10,15,20])
axs[1].set_yticks([0,5,10,15])
for ax in axs:
    ax.set_xticks([0,0.5,1,1.5,2])
    do_label(ax,fon=1,ylabel='dopaminergic activity (spike/s)')
    ax.legend(clrs)

fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_responsesDATA.pdf")


In [None]:
# Model counterpart of the Tobler population figure: dopamine traces (fig/axs)
# and the reward inputs that produced them (figb/axsb), for three volumes.
DBASE  = 1

fig,axs=plt.subplots(ncols=2,figsize=(15,5),)
figb,axsb=plt.subplots(ncols=2,figsize=(15,3))


def plot_r_rdsahsed(ax,r,rdashed,tt,c,t_eval=None):
    """Plot an input signal r (solid) and its variant rdashed (dashed) on ax.

    Parameters
    ----------
    ax : matplotlib Axes
    r, rdashed : callable
        Scalar functions of time.
    tt : array-like
        x-coordinates of the plotted points; callers pass a shifted copy of
        the time grid so overlapping traces stay distinguishable.
    c : colour
    t_eval : array-like, optional
        Times at which r and rdashed are evaluated (same length as tt).
        Defaults to the module-level ``t`` (legacy behaviour); pass it
        explicitly so the plot does not depend on notebook state.
    """
    if t_eval is None:
        t_eval = t
    ax.plot(tt,[r(s) for s in t_eval],c=c,lw=3)
    ax.plot(tt,[rdashed(s) for s in t_eval],c=c,lw=3,ls='dashed')
    ax.set_ylim(0,8)

for stim_mag,c,i in zip((0.05,0.15,0.5),colors,range(3)):
    # Left panels: cue and reward coincide (treward=None) at the default tcue.
    sol,r,t,rdashed=generate_stim_reward(reward_baseline=DBASE,reward_size=2+stim_mag*10,tmax=2)
    SHIFT=i/80
    TMAX=2
    axs[0].plot(t[t<TMAX],sol[t<TMAX,0],c=c,lw=3)
    axs[0].set_ylim((0,20))
    # Inputs are drawn at t-SHIFT (small per-volume offset so traces don't overlap)
    # but evaluated at the unshifted t.
    plot_r_rdsahsed(axsb[0],r,rdashed,t-SHIFT,c,t_eval=t)

    # Right panels: reward at 2.9 after the cue; show the window [TMIN, TMAX) re-zeroed at TMIN.
    TMIN=2
    TMAX=4
    sol,r,t,rdashed=generate_stim_reward(reward_baseline=DBASE,reward_size=2+stim_mag*10,treward=2.9,tmax=TMAX)

    ind = (t>=TMIN) & (t<TMAX)
    SHIFT=i/30
    # NOTE(review): activity traces are offset downward by i/7, presumably only to separate them visually.
    axs[1].plot(t[ind]-TMIN,sol[ind,0]-i/7,c=c,lw=3)
    axs[1].set_ylim((0,15))
    plot_r_rdsahsed(axsb[1],r,rdashed,t-TMIN-SHIFT,c,t_eval=t)


axs[0].set_title("reward")
axs[1].set_title("reward preceded by cue")
axsb[0].set_title("reward")
axsb[1].set_title("reward preceded by cue")

axs[0].set_yticks([0,5,10,15,20])
axs[1].set_yticks([0,5,10,15])
for ax in axs:
    do_label(ax,fon=1,ylabel='dopaminergic activity (spike/s)')
    ax.set_xticks([0,0.5,1,1.5,2])
    ax.legend(clrs)

axsb[0].set_xticks([0,0.5,1,1.5,2])
axsb[1].set_xticks([-2,-1,0,1])

for ax in axsb:
    do_label(ax,fon=1,ylabel='expected reward')
    # each volume contributes a solid and a dashed line
    ax.legend([clrs[0],clrs[0],clrs[1],clrs[1],clrs[2],clrs[2]])

fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_responses.pdf")
figb.tight_layout()
if VERBOSE:
    figb.savefig("figs/python/da_responsesA.pdf")

In [None]:
# Model vs digitized data for five reward volumes, one panel per volume.
amounts = np.array([0,0.025,0.075,0.15,0.25])
fig,axs=plt.subplots(ncols=len(amounts),figsize=(5*len(amounts),5))
maxs=[]
for i in range(len(amounts)):
    # Digitized data (dashed, scaled by 1.5), e.g. "fig1_cue_0025.csv" for 0.025 ml.
    z=preprocess_tobler_for_plotting("fig1_cue_"+str(amounts[i]).replace(".","")+".csv")
    axs[i].plot(z.x,1.5*z.y,c='k',lw=2,linestyle='dashed')
    axs[i].set_ylim(0,20)
    axs[i].set_xticks([0,1,2])
    # Model: reward step at tcue=1.05, size growing linearly with volume.
    sol,r,t,rdashed=generate_stim_reward(reward_baseline=1,reward_size=1+amounts[i]*20,tcue=1.05,tmax=2)
    # Explicit colour instead of the loop variable `c` leaked from another cell
    # (whose final value there is "black"); the cell now runs on its own.
    axs[i].plot(t,sol[:,0],c='black',lw=3)
    axs[i].set_title("%.3f ml" % amounts[i])
    do_label(axs[i],fon=1,ylabel='dopaminergic activity (spike/s)')

fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_magnitudes.pdf")


# phasic responses (mice)

In [None]:
# load with mice parameters (rate constants omega_d=50, omega=15)
eqs = load_eqs(omega=15, omega_d=50)

In [None]:
# Digitized traces from the cohen_2012 directory: two conditions ("Small"/"Big"
# per the filenames) for each signal; one figure per signal.
config = {u'Δdopamine' : ["cohen_2012/cohen_dopSmall.csv","cohen_2012/cohen_dopBig.csv"],
          u'ΔGABA' : ["cohen_2012/cohenGABASmall.csv","cohen_2012/cohenGABABig.csv"]}
figs = []
for key, paths in config.items():
    signals = [pd.read_csv(path, header=None, names=("x","y")) for path in paths]
    # Shift each trace so its minimum sits at zero.
    for i, sig in enumerate(signals):
        sig.y = sig.y - sig.y.min()
    fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(3.5,2))
    for sig, shade in zip(signals, ('gray', 'k')):
        sig.sort_values("x").plot(ax=ax, x='x', y='y', c=shade, lw=3, legend=False)
    ax.set_ylim((0,15))
    fig.tight_layout()
    do_label(ax, fon=1, ylabel=key)
    ax.set_xticks((-1,0,1,2))
    fig.tight_layout()
    if VERBOSE:
        fig.savefig("figs/python/da_data{key}.pdf".format(key=key))

In [None]:
# Model d (dopamine) and g (GABA) for a big (8) and small (2) reward step at the
# default tcue=1, shown as change from the t=0 value with time re-zeroed at the step.
solA,r,t,rdashed=generate_stim_reward(reward_baseline=1,reward_size=8,tmax=4)
solB,r,t,rdashed=generate_stim_reward(reward_baseline=1,reward_size=2,tmax=4)
t=t-1
figA,axs = plt.subplots(ncols=1,nrows=2,figsize=(7,5),sharex=True)
# row 0 plots state column 0 (d), row 1 plots state column 1 (g)
for row in (0, 1):
    for trace, shade in ((solA, 'k'), (solB, 'gray')):
        axs[row].plot(t, trace[:,row] - trace[0,row], c=shade, lw=3)

for ax in axs:
    ax.set_ylim(0,17)
    ax.set_xticks([-1,0,1,2])
    ax.set_xlim([-1.75,2])
do_label(axs[0],fon=1,ylabel=u'Δdopamine',xlabel="")
do_label(axs[1],fon=1,ylabel=u'ΔGABA')
figA.tight_layout()
if VERBOSE:
    figA.savefig("figs/python/da_gaba.pdf")

# dopamine ramps 

In [None]:
from scipy.interpolate import interp1d
from scipy.signal import savgol_filter
import glob
def plot_dop_ramps(idx,xlim,normy=0,normx=0,adjx=1,adjy=1):
    """Plot digitized dopamine-ramp traces for one figure panel.

    Reads every file matching ``unified_framework/unified_{idx}_*``. The token
    after the last underscore of the filename (before the extension) is the
    line colour, with "orange" and "gold" mapped to specific hex shades.
    Each trace is rescaled as x -> adjx*(x-normx), y -> adjy*(y-normy),
    linearly resampled on a 0.2-step grid and smoothed with a Savitzky-Golay
    filter (window 11, polynomial order 3).

    Parameters
    ----------
    idx : int
        Panel index used in the filename pattern and in the output PDF name.
    xlim : (float, float)
        x-axis limits; five evenly spaced x ticks are placed between them.
    normy, normx, adjx, adjy : float
        Offset and scale applied to the digitized coordinates.

    Axis limits/ticks, labels, layout and (when VERBOSE) the PDF are applied
    once, after all traces are drawn. If no file matches, an empty figure is
    left and nothing is saved.
    """
    fig,ax = plt.subplots(figsize=(5,5))
    # glob order depends on the filesystem; sort so line order (and overlap) is reproducible.
    fnames = sorted(glob.glob("unified_framework/unified_{idx}_*".format(idx=idx)))
    if not fnames:
        return
    shade_overrides = {"orange": "#dd571c", "gold": "#ffb302"}
    for fname in fnames:
        f = pd.read_csv(fname,header=None,names=("x","y"))
        f.y = adjy*(f.y-normy)
        f.x = adjx*(f.x-normx)
        # drop repeated x so the interpolant is well defined
        f = f.sort_values("x").drop_duplicates("x")
        f2 = interp1d(f.x, f.y, kind='linear')
        # arange excludes f.x.max(), so every sample lies inside the interpolation range
        x = np.arange(f.x.min(),f.x.max(),0.2)
        y = savgol_filter(x=f2(x),window_length=11,polyorder=3)
        col = fname.split("_")[-1].split('.')[0]
        col = shade_overrides.get(col, col)
        ax.plot(x,y,c=col,lw=5)

    # Figure-level setup, done once after every trace has been drawn
    # (previously repeated, and the PDF rewritten, once per input file).
    ax.set_xlim(xlim[0],xlim[1])
    ax.set_xticks(np.linspace(xlim[0],xlim[1],5))
    ylim = ax.get_ylim()
    ax.set_yticks(np.arange(int(ylim[0]),int(ylim[1])+1,0.5))
    do_label(ax,fon=1,fontsize=24,ylabel=u'ΔF/F (GCaMP) (z)')
    fig.tight_layout()
    if VERBOSE:
        fig.savefig("figs/python/da_uchida_{idx}_data.pdf".format(idx=idx))
# Digitized ramp panels 1-4 with their x-axis ranges.
for ramp_idx, ramp_xlim in ((1,(-2,2)), (2,(0,8)), (3,(0,16)), (4,(0,8))):
    plot_dop_ramps(ramp_idx, ramp_xlim)


In [None]:
# Ramp-model constants:
#   TMAX   time to reach the end of the track at unit speed
#   LENGTH track length in the units used by run_teleport/run_pause
#          (positions are converted to time via TMAX/LENGTH)
#   LAMBDA steepness of the reward gradient
TMAX, LENGTH, LAMBDA = 8, 97, 0.04


def gradient(t):
    """Reward input at track time t: exp(LAMBDA * t**1.5)."""
    return np.exp(LAMBDA*(t**1.5))

In [None]:
def run_variable_speed(v):
    """Simulate traversing the whole track to the reward at speed v.

    Track position at time t is t*v, so the end (TMAX) is reached after
    TMAX/v; the input is gradient(position). The grid has ~100 samples per
    unit time. The exact duration TMAX/v is used; the previous int(TMAX/v)
    truncation stopped short of the end for speeds that don't divide TMAX.
    For divisor speeds (e.g. 0.5, 1, 2) the grid is unchanged.

    Returns (t, sol) with sol the odeint trajectory of (d, g).
    """
    tmax = TMAX/v
    # keep at least two points so the grid always spans [0, tmax]
    n_points = max(int(round(100*tmax)), 2)
    t = np.linspace(0, tmax, n_points)
    r = lambda t: gradient(t*v)
    return t,run(r,t)

# Model ramps for three running speeds (slower speed -> longer approach, TMAX/v).
fig,ax = plt.subplots(figsize=(5,5))
for v, c in ((0.5,'red'), (1,'black'), (2,'#ffb302')):
    t, sol = run_variable_speed(v)
    # dopamine relative to the set point d0_
    ax.plot(t, sol[:,0] - d0_, color=c, lw=4)
    ax.set_ylim(0,1.5)

ax.set_xticks([0,4,8,12,16])
do_label(ax)
fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_uchida_3.pdf")

In [None]:
def run_teleport(tteleport_s=3,tteleport_e=5,normalize=True):
    """Simulate a teleport that skips the track span [tteleport_s, tteleport_e].

    Positions are in track units (LENGTH) and converted to time via
    TMAX/LENGTH. The input follows gradient(t) until the teleport start,
    holds gradient(start) for DELAY = 0.1, then follows gradient(t + delta),
    where delta is the skipped span. The simulation covers [0, TMAX - delta].

    If ``normalize``, the returned time axis is shifted so its last sample
    is at 2; otherwise it is the raw grid. Returns (time, sol).
    """
    NORM_FACT = TMAX/LENGTH
    start = tteleport_s*NORM_FACT
    end = tteleport_e*NORM_FACT
    delta = end - start
    t = np.linspace(0, TMAX-delta, 100*TMAX)
    DELAY = 0.1

    def r(s):
        if s < start:
            return gradient(s)
        if s > start + DELAY:
            return gradient(s + delta)
        return gradient(start)

    tret = t - TMAX + delta + 2 if normalize else t
    return tret, run(r, t)
    
def run_pause(pause_t):
    """Simulate stopping at track position pause_t (track units) for the rest of the run.

    The input follows gradient(t) until the pause time, then stays at
    gradient(pause time) up to TMAX. The returned time axis is shifted so
    t = TMAX maps to 2. Returns (time, sol).
    """
    NORM_FACT = TMAX/LENGTH
    pause_time = pause_t*NORM_FACT
    t = np.linspace(0, TMAX, 100*TMAX)

    def r(s):
        if s < pause_time:
            return gradient(s)
        return gradient(pause_time)

    return t - TMAX + 2, run(r, t)

In [None]:
# Teleport / pause manipulations; each time axis is shifted so the end of the run sits at t=2.
fig,ax = plt.subplots(figsize=(5,5))
# (0, 0): no skipped span (control)
t,sol=run_teleport(0,0)
ax.plot(t,sol[:,0]-d0_,c='k',lw=4)
# long teleport ending at track position 70
t,sol=run_teleport(40,70)
ax.plot(t,sol[:,0]-d0_,c='red',lw=4)
# short teleport ending at the same position
t,sol=run_teleport(65,70)
ax.plot(t,sol[:,0]-d0_,c='#dd571c',lw=4)
# stop at position 70 and stay there
t,sol=run_pause(70)
ax.plot(t,sol[:,0]-d0_,c='#ffb302',lw=4)
ax.set_xlim(-2,2)
ax.set_ylim(0,2)
ax.set_yticks([0,0.5,1,1.5])
ax.set_xticks([-2,-1,0,1,2])

do_label(ax)
fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_uchida_1.pdf")

In [None]:
# Teleports of equal span (30 track units) starting at different positions, on the
# raw time axis (normalize=False); (0, 0) is the no-teleport control.
fig,ax = plt.subplots(figsize=(5,5))
for start, end, col in ((0,0,'k'), (5,35,'r'), (25,55,'#dd571c'), (45,75,'#ffb302')):
    t, sol = run_teleport(start, end, normalize=False)
    ax.plot(t, sol[:,0] - d0_, c=col, lw=4)
ax.set_xlim(0,8)
ax.set_ylim(0,2.5)

do_label(ax)
ax.set_xticks([0,2,4,6,8])
fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_uchida_2.pdf")

In [None]:
def fexp(EXP):
    """Return r(t) = gradient(TMAX * t**EXP / TMAX**EXP).

    This warps track position by a power law that still maps 0 -> 0 and
    TMAX -> TMAX.
    """
    def warped_reward(t):
        return gradient(TMAX*(t**EXP)/ (TMAX**EXP))
    return warped_reward
# Model ramps for power-law-warped gradients (exponents 0.5, 1, 2).
fig,ax = plt.subplots(figsize=(5,5))
t = np.linspace(0, TMAX,100*TMAX)
exps = [0.5,1,2]
cols = ['red','k','#ffb302']
for i, (exponent, col) in enumerate(zip(exps, cols)):
    trace = run(fexp(exponent), t)
    ax.plot(t, trace[:,0] - d0_, c=col, lw=4)

ax.set_ylim(0,1)
do_label(ax)
ax.set_xticks([0,2,4,6,8])
fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/da_uchida_4.pdf")

# matching law

In [None]:
class GaussianTwoMode:
    """Two Gaussian bumps: height h centred at +a and height 1 centred at -a, both of width b."""

    def __init__(self, h=1, a=0.3, b=0.2):
        self.h = h
        self.a = a
        self.b = b

    def field(self, X):
        """Evaluate the two-bump field at X (scalar or numpy array)."""
        right = np.exp(-np.power((X - self.a)/(self.b), 2))
        left = np.exp(-np.power((X + self.a)/(self.b), 2))
        return self.h*right + left


In [None]:
def get_FCDCircuit(rho,N,alpha,omega,d0,tau,H):
    """Return a circuit class with the given parameters closed over (first version).

    NOTE(review): a later cell redefines ``get_FCDCircuit`` with a different
    signature; after that cell runs, this version is shadowed.

    Instances hold state {"d", "g"} (both start at 1). Each ``run(u)`` call:
      - sets d = max(rho + N*log(u) - alpha*g_old, 0.01),
      - takes one forward-Euler step g += dt*omega*(d_old/d0 - 1),
      - returns the rate (1/tau)*(d0/d_old)**H.
    ``v`` is stored on the instance but not used by ``run``.
    """
    class FCDCircuit():
        def __init__(self,v,dt):
            self.vals = {"d" : 1,"g" : 1}
            self.v = v
            self.dt = dt

        def run(self,u):
            d_old = self.vals["d"]
            g_old = self.vals["g"]
            self.vals["d"] = max(rho + N*np.log(u) - alpha*g_old, 0.01)
            self.vals["g"] = g_old + self.dt*(omega*(d_old/d0 - 1))
            return (1/tau)*np.power(d0/d_old, H)
    return FCDCircuit



class FastBacteria(object):
    """1D run-and-tumble walker whose tumble rate comes from an FCD circuit (first version).

    NOTE(review): a later cell redefines ``FastBacteria`` with a different
    constructor (f, circuit, v, dt); after that cell runs, this class is shadowed.

    Parameters
    ----------
    f : object exposing ``field(x)`` (e.g. GaussianTwoMode); sampled at the walker position.
    tau, rho, N, alpha, omega, d0, H : circuit parameters forwarded to get_FCDCircuit.
    v : float, distance moved per unit time.
    dt : float, time step.
    """
    def __init__(self,f,tau,rho,N,alpha,omega,d0,H,v,dt):
        self.circuit = get_FCDCircuit(rho=rho,omega=omega,N=N,tau=tau,alpha=alpha,d0=d0,H=H)(v=v,dt=dt)
        self.v = v
        self.dt = dt
        self.f = f
        self.coord = self.init_coord()
        self.f0 = f.field(self.coord)
        self.d0=d0
        self.tau = tau
        # Let the circuit adapt to the starting input before the walk begins.
        for _ in range(1000):
            self.circuit.run(self.f0)
        self.head = np.random.choice([1,-1])
        self.traj = []

    def init_coord(self):
        """Starting position."""
        return 0

    def reor(self):
        """New heading after a tumble, drawn uniformly from {-1, +1}."""
        return np.random.choice([-1,1])

    def bac_run(self,iterations=1000,bound=1,do_tqdm=True):
        """Advance the walk ``iterations`` steps, recording each position in self.traj.

        Each step: sample the field at the current position, feed it to the
        circuit, tumble with probability clip(dt*rate, 0, 1), then move
        v*dt along the heading. ``bound`` is accepted but unused.
        """
        if do_tqdm:
            # tqdm.tqdm_notebook is deprecated; tqdm.auto uses the notebook widget when available.
            from tqdm.auto import tqdm as progress
        else:
            progress = lambda x: x
        for _ in progress(range(iterations)):
            self.traj.append(self.coord)
            self.f0 = self.f.field(self.coord)
            rdr_prob = np.clip(self.dt*self.circuit.run(self.f0),0,1)
            if np.random.rand() < rdr_prob:
                self.head = self.reor()
            self.coord = self.coord+self.v*self.dt*self.head
            
def iteration(f,thresh=0.025,iterations=1000000):
    """Run one walk on field f with fixed circuit parameters (first version).

    NOTE(review): redefined with v and dt parameters in a later cell.

    Discards the first 20000 recorded positions and returns
    (fraction within thresh of +f.a, fraction within thresh of -f.a).
    """
    walker = FastBacteria(f=f,alpha=1,N=5,omega=15,tau=0.1,rho=15,d0=5,H=1,dt=0.0025,v=0.3)
    walker.bac_run(iterations=iterations,do_tqdm=False)
    positions = np.array(walker.traj[20000:])
    near_plus = np.abs(positions - f.a) < thresh
    near_minus = np.abs(positions + f.a) < thresh
    return near_plus.mean(), near_minus.mean()
def matching_run(h,iterations=50):
    """Estimate the occupancy ratio for peak heights h : 1 (first version).

    NOTE(review): redefined in a later cell.

    Runs ``iterations`` independent walks on GaussianTwoMode(h=h), averages
    the per-walk fractions near +a (ra) and -a (rb), prints ra/rb next to h
    and returns ra/rb.
    """
    # tqdm.tqdm_notebook is deprecated; tqdm.auto uses the notebook widget when available.
    from tqdm.auto import tqdm as progress
    sol = np.array([iteration(GaussianTwoMode(h=h)) for _ in progress(range(iterations))])
    ra,rb=np.mean(sol,axis=0)
    print("result: %.2f, expected: %.2f, base ratio: %.2f" % (ra/rb,h,h))
    return ra/rb

            



In [None]:
# Matching-law circuit parameters; tau sets the baseline tumble rate 1/tau
# (the rate returned by the circuit when d equals d0_).
# NOTE: this rebinds the module-level mu_ (6 in the model-constants cell) to 5;
# anything that reads the global mu_ after this cell runs sees the new value.
mu_, tau = 5, 0.1


def get_FCDCircuit(mu_,C_,alpha_,omega_,d0_,tau_):
    """Return a circuit class with the given parameters closed over.

    Instances hold state {"d", "g"} (both start at 1). Each ``run(u)`` call:
      - sets d algebraically: d = max(C_ + mu_*log(u) - alpha_*g_old, 0.01),
      - takes one forward-Euler step: g += dt*omega_*(d_old/d0_ - 1),
      - returns the rate (1/tau_)*(d0_/d_old).
    """
    class FCDCircuit():
        def __init__(self,dt):
            self.vals = {"d" : 1,"g" : 1}
            self.dt = dt

        def run(self,u):
            min_dop = 0.01  # lower bound that keeps d positive
            d_old = self.vals["d"]
            g_old = self.vals["g"]
            self.vals["d"] = max(C_ + mu_*np.log(u) - alpha_*g_old, min_dop)
            self.vals["g"] = g_old + self.dt*(omega_*(d_old/d0_ - 1))
            return (1/tau_)*np.power(d0_/d_old, 1)
    return FCDCircuit


class FastBacteria(object):
    """1D run-and-tumble walker on field ``f`` driven by an FCD circuit.

    Parameters
    ----------
    f : object exposing ``field(x)`` (e.g. GaussianTwoMode); sampled at the walker position.
    circuit : class (as returned by get_FCDCircuit), instantiated here with ``dt``.
    v : float, distance moved per unit time.
    dt : float, time step.
    """
    def __init__(self,f,circuit,v,dt):
        self.circuit = circuit(dt=dt)
        self.v = v
        self.dt = dt
        self.f = f
        self.coord = self.init_coord()
        self.f0 = f.field(self.coord)
        self.d0=d0_
        # Let the circuit adapt to the starting input before the walk begins.
        for _ in range(1000):
            self.circuit.run(self.f0)
        self.head = np.random.choice([1,-1])
        self.traj = []

    def init_coord(self):
        """Starting position."""
        return 0

    def reor(self):
        """New heading after a tumble, drawn uniformly from {-1, +1}."""
        return np.random.choice([-1,1])

    def bac_run(self,iterations=1000,bound=1,do_tqdm=True):
        """Advance the walk ``iterations`` steps, recording each position in self.traj.

        Each step: sample the field at the current position, feed it to the
        circuit, tumble with probability clip(dt*rate, 0, 1), then move
        v*dt along the heading. ``bound`` is accepted but unused.
        """
        if do_tqdm:
            # tqdm.tqdm_notebook is deprecated; tqdm.auto uses the notebook widget when available.
            from tqdm.auto import tqdm as progress
        else:
            progress = lambda x: x
        for _ in progress(range(iterations)):
            self.traj.append(self.coord)
            self.f0 = self.f.field(self.coord)
            rdr_prob = np.clip(self.dt*self.circuit.run(self.f0),0,1)
            if np.random.rand() < rdr_prob:
                self.head = self.reor()
            self.coord = self.coord+self.v*self.dt*self.head
            
def iteration(f,thresh=0.025,iterations=1000000,v=0.25,dt=0.0025):
    """Run one walk on field f and report how often it sits near each peak.

    Builds the circuit from the module-level mu_, C_, alpha_, d0_ and tau
    (with omega_=15), runs ``iterations`` steps, discards the first 20000
    recorded positions, and returns
    (fraction within thresh of +f.a, fraction within thresh of -f.a).
    """
    circuit = get_FCDCircuit(mu_=mu_,C_=C_,alpha_=alpha_,omega_=15,d0_=d0_,tau_=tau)
    walker = FastBacteria(f=f,circuit=circuit,dt=dt,v=v)
    walker.bac_run(iterations=iterations,do_tqdm=False)
    positions = np.array(walker.traj[20000:])
    near_plus = np.abs(positions - f.a) < thresh
    near_minus = np.abs(positions + f.a) < thresh
    return near_plus.mean(), near_minus.mean()

def matching_run(h,iterations=50):
    """Estimate the occupancy ratio for peak heights h : 1.

    Runs ``iterations`` independent walks on GaussianTwoMode(h=h), averages
    the per-walk fractions near +a (ra) and -a (rb), prints ra/rb next to h
    and returns ra/rb.
    """
    # tqdm.tqdm_notebook is deprecated; tqdm.auto uses the notebook widget when available.
    from tqdm.auto import tqdm as progress
    sol = np.array([iteration(f=GaussianTwoMode(h=h)) for _ in progress(range(iterations))])
    ra,rb=np.mean(sol,axis=0)
    print("result: %.2f, expected: %.2f, base ratio: %.2f" % (ra/rb,h,h))
    return ra/rb

            



In [None]:
# Matching-law sweep over reward ratios h (height of the +a peak; the -a peak has height 1).
# Heavy: len(ratios) x 100 walks, each running iteration()'s default number of steps.
# Seed numpy's legacy global RNG (the walkers draw via np.random.*) so the sweep is reproducible.
MATCHING_SEED = 0
np.random.seed(MATCHING_SEED)
ratios=np.linspace(1,4,6)
sol = [matching_run(j,iterations=100) for j in ratios]

In [None]:
fig,ax = plt.subplots(figsize=(5.5,5.5))
# identity line: perfect matching (response ratio == reward ratio)
ax.plot(ratios,ratios,c='k',lw=2,ls='dashed')
ax.set_ylim([1,5.5])
ax.set_yticks([1,2,3,4,5])
# sol[k] = matching_run(ratios[k]), so plot each result against the ratio it was
# computed for (reversing ratios would pair every result with the wrong ratio).
ax.scatter(ratios,sol,c='k',s=100,marker='o')
ax.legend((r"matching, $\beta=1$","simulation"))
do_label(ax,fon=1,fontsize=16,ylabel=r'response ratio $P(r_1) / P(r_2)$',xlabel=r'reward ratio $ r_1 / r_2$')
fig.tight_layout()
if VERBOSE:
    fig.savefig("figs/python/matching.pdf")