In [1]:
# This code implements the model proposed by Foster, Morris and Dayan (Hippocampus, 2000).
# It simulates the delayed matching-to-place (DMP) task: every day a new platform location is drawn,
# and every trial a new start location of the rat is drawn. We can set the number of rats
# (= number of independent experiments, used for statistics), the number of days and the number
# of trials per day.

#MvR edits

In [2]:
using Polynomials

In [3]:
###################################################################################
###################################################################################
########################                                ###########################
########################       DEFINE CLASSES           ###########################

###################################################################################
###################################################################################

In [4]:
# One simulated trial: a single swim from a start position until the platform is
# reached or the trial times out. Fields, as filled in by the simulation loop:
#   Trajectory       - visited positions, one row [x y] per step
#   Latency          - value of the time counter when the trial ended (s)
#   SearchPreference - time (s) spent within `radiussearchpref` of the platform
#   ActionMap        - actor weights (place cells x action cells) at the end of the trial
#   Valuemap         - critic weights (place cells x 1) at the end of the trial
#   Error            - temporal-difference error of every step
# `mutable struct` replaces the `type` keyword (removed in Julia 1.0, valid since 0.6).
mutable struct Trial
    Trajectory
    Latency
    SearchPreference
    ActionMap
    Valuemap
    Error
end

In [5]:
# One simulated day: all of its trials share the same platform position.
#   trial    - vector of `Trial`s, in the order they were run (starts empty)
#   Platform - platform position for the day as a 1x2 [x y]; it is left
#              undefined by the constructor and must be assigned before use.
# `mutable struct` replaces the `type` keyword (removed in Julia 1.0, valid since 0.6).
mutable struct Day
    trial::Any
    Platform::Any
    # Partial initialisation: only `trial` is set, `Platform` stays #undef.
    Day() = new(Trial[])
end

In [6]:
# One simulated experiment (one rat, all of its days).
#   day        - vector of `Day`s, in the order they were run (starts empty)
#   PlaceCells - place-cell centres used for this experiment, an N x 2 [x y]
#                matrix; left undefined by the constructor until assigned.
# `mutable struct` replaces the `type` keyword (removed in Julia 1.0, valid since 0.6).
mutable struct Experiment
    day::Any
    PlaceCells::Any
    # Partial initialisation: only `day` is set, `PlaceCells` stays #undef.
    Experiment() = new(Day[])
end

In [7]:
# Container for the whole simulation.
#   experiment         - vector of `Experiment`s, one per simulated rat (starts empty)
#   parameters         - learning parameters, stored by the loop as
#                        [momentum, γ, actorLR, criticLR]
#   featuresexperiment - simulation size, stored by the loop as
#                        [numberofrats, numberofdays, numberoftrials]
# `parameters` and `featuresexperiment` are left undefined by the constructor.
# `mutable struct` replaces the `type` keyword (removed in Julia 1.0, valid since 0.6).
mutable struct Rat
    experiment::Any
    parameters
    featuresexperiment
    # Partial initialisation: only `experiment` is set.
    Rat() = new(Experiment[])
end

In [8]:
###################################################################################
###################################################################################
########################                                ###########################
########################       DEFINE FUNCTIONS         ###########################
########################                                ###########################
########################                                ##########################
########################                                ##########################
###################################################################################
###################################################################################

In [9]:
# Place-cell activity, scalar-position version (an alternative to `placecells`).
#
# Every place cell fires as an isotropic gaussian of the distance between the rat
# and the cell's centre:
#     activity[k] = exp(-((x - xpc[k])^2 + (y - ypc[k])^2) / (2σ^2))
#
# x, y     : the rat's position (scalars)
# xpc, ypc : x and y coordinates of the place-cell centres (one entry per cell)
# σ        : width of the place fields -- must be in the same unit as the positions
# Returns an N x 1 matrix of activities, N = length(xpc).
function place_activity(x,y,xpc,ypc,σ)
    twovar = 2σ^2
    activity = zeros(length(xpc), 1)
    for cell in 1:length(xpc)
        dx = x - xpc[cell]
        dy = y - ypc[cell]
        activity[cell] = exp(-(dx^2 + dy^2) / twovar)
    end
    return activity
end

place_activity (generic function with 1 method)

In [10]:
# Place-cell activity for a position vector and a matrix of centres
# (the version used by the simulation loop).

function placecells(position,centres,width)
# PLACECELLS Calculates the activity of the place cells in the simulation.
#
# PLACECELLS(POSITION,CENTRES,WIDTH) returns an N x 1 matrix F, where N is the
# number of place cells. F[k] is the firing rate of place cell k, modelled as a
# gaussian of the distance between POSITION and that cell's centre:
#     F[k] = exp(-|POSITION - CENTRES[:,k]|^2 / (2*WIDTH^2))
#
# POSITION : coordinates of the rat (a 2-element vector in this simulation).
# CENTRES  : matrix with one column per place cell (2 x N here); it must have
#            length(POSITION) rows, otherwise a DimensionMismatch is thrown.
# WIDTH    : common width of all place fields (scalar, same unit as POSITION).
#
# Written with explicit loops rather than repmat/sum(A,1) (both removed in
# Julia 1.0) so it runs on old and new Julia alike, and so the repeated
# position matrix is never allocated.
    dim = size(centres, 1)
    ncells = size(centres, 2)
    length(position) == dim || throw(DimensionMismatch(
        "position has $(length(position)) coordinates but centres has $dim rows"))

    F = zeros(ncells, 1)  # N x 1 column, as the callers expect
    for k in 1:ncells
        d2 = 0.0
        for i in 1:dim
            d2 += (position[i] - centres[i, k])^2
        end
        F[k] = exp(-d2 / (2 * width^2))
    end
    return F
end


placecells (generic function with 1 method)

In [11]:
# Reward delivered at a given position.
#
# x, y   : position of the rat
# xp, yp : centre of the platform
# r      : radius of the platform
# Returns the integer 1 when the rat is on the platform (distance to its centre
# at most r, boundary included) and 0 otherwise.
function reward(x,y,xp,yp,r)
    onplatform = (x - xp)^2 + (y - yp)^2 <= r^2
    return onplatform ? 1 : 0
end

reward (generic function with 1 method)

In [12]:
# Locate `x` in a cumulative distribution: return the smallest index i such that
# x <= Acum[i]. The bins are (-Inf, Acum[1]], (Acum[1], Acum[2]], ..., so with
# x = rand() and Acum the cumulative action probabilities this samples an action.
#
# Acum : non-decreasing vector of cumulative probabilities (last entry ~ 1)
# x    : the number to locate
# Returns length(Acum) when x > Acum[end], which can happen when rounding leaves
# the cumulative sum just below 1. Throws ArgumentError if Acum is empty.
#
# (The previous version started its scan at i = 1 and read Acum[0], raising a
# BoundsError for every x >= Acum[1].)
function indice(Acum,x)
    isempty(Acum) && throw(ArgumentError("indice: Acum must not be empty"))
    for i in 1:length(Acum)
        if x <= Acum[i]
            return i
        end
    end
    return length(Acum)
end

indice (generic function with 1 method)

In [13]:
###################################################################################
################## GENERAL PARAMETERS THAT DO NOT CHANGE BETWEEN TRIALS ###########
###################################################################################
# Global parameters read by the simulation loop.

# Lengths: the active values are rescaled from the commented-out centimetre values
# (pool radius 1 <-> 100 cm), so they look like metres. NOTE(review): confirm the unit.
# Every length must use the same unit as σ.
R=1 #100; # radius of the circular pool (commented-out value: 100 cm)
r=0.04; #5;# radius of the platform (commented-out value: 5 cm)
radiussearchpref=0.2;# 20; # radius of the disc around the platform in which search preference is measured

speed=0.3; #30; # swimming speed, length unit per second (commented-out value: 30 cm/s)


dt=0.1; # time step in s
# The 8 possible headings, in 45-degree steps from -135 degrees to 180 degrees
angles=[-3*pi/4, -2*pi/4, -pi/4, 0, pi/4, 2*pi/4, 3*pi/4, pi];


# Trial timing:
T=120; # maximal duration of a trial in s; at t == T the loop puts the rat on the platform
DeltaT=15; # interval between trials in s (not used by the simulation loop below)

# Place cells
N=493; # number of place cells
σ=0.30*R; # width (standard deviation) of every place field: 30% of the pool radius


# Place-cell centres are sampled uniformly over the pool, instead of the regular
# "sunflower" layout in which neighbouring centres are equidistant. Taking the
# square root of a uniform radius makes the density uniform over the disc's area
# rather than concentrated near the centre.
arguments= rand(1,N)*2*pi;   # polar angles, 1 x N
radii= sqrt.(rand(1,N))*R;   # polar radii, 1 x N
centres= [cos.(arguments).*radii; sin.(arguments).*radii];  # 2 x N, one column per cell
Xplacecell=centres[1,:];
Yplacecell=centres[2,:];


# Action cells:
n=8; # number of action cells, one per heading in `angles`


# Candidate platform positions (the loop draws one at random each day):
Xplatform=[0.3,0,-0.3,0,0.5,-0.5,0.5,-0.5].*R;
Yplatform=[0,0.3,0,-0.3,0.5,0.5,-0.5,-0.5].*R;

# Candidate start positions of the rat (the loop draws one at random each trial):
Xstart=[0.90,0,-0.90,0].*R; # East, North, West, South
Ystart=[0,0.90,0,-0.90].*R;

# Size of the simulation: days per rat, number of rats, trials per day
numberofdays=1;
numberofrats=10;
numberoftrials=20;


# Time grid 0, dt, ..., T+dt. The extra sample after T lets the loop advance its
# clock once more after the final step at t == T.
times=collect(0:dt:T+dt);

In [14]:
# Weight of the previous heading when smoothing the trajectory. The loop blends
# headings as (newdir + momentum*prevdir)/(1+momentum) and then normalises the result.
momentum=1.1;

# Learning parameters:
γ=0.98; # discount factor (the paper does not give its value)
actorLR=0.1; # actor learning rate
criticLR=0.01; # critic learning rate

In [21]:
#########################################################################
#############   SIMULATION: ALL RATS x ALL DAYS x ALL TRIALS   ##########
#########################################################################
# Runs the actor-critic model for `numberofrats` independent rats, each for
# `numberofdays` days of `numberoftrials` trials. Everything is collected in the
# global `rats`: one `Experiment` per rat, one `Day` per day, one `Trial` per trial.
# Reads the global parameters defined earlier in the notebook and calls `reward`,
# `placecells` and `indice`.
#
# Per step (see the weight updates below):
#   err           = re + γ*Cnext - C      (temporal-difference error)
#   criticweights = criticweights + criticLR*err*placecell activity
#   actorweights[:, chosen action] += actorLR*err*placecell activity
#                                    (skipped when the rat was put on the platform)

@time begin # report the run time of the whole simulation

rats = Rat();
# Store the learning parameters and the size of the simulation with the data
rats.parameters = [momentum, γ, actorLR, criticLR];
rats.featuresexperiment = [numberofrats, numberofdays, numberoftrials];


for indexrat = 1:numberofrats

    currentexperiment = Experiment(); # one experiment per rat
    currentexperiment.PlaceCells = hcat(Xplacecell, Yplacecell); # same place-cell centres for every rat

    # Every rat starts naive: all weights are zero.
    criticweights = zeros(N, 1); # place cells -> critic
    actorweights = zeros(N, n);  # place cells (rows) -> action cells (columns)

        ##########  ##########  ##########  ##########   ##########
    ##########  ##########  START EXPERIMENT  ##########  ##########
        ##########  ##########  ##########  ##########   ##########

    for indexday = 1:numberofdays

        # Each day the platform is placed at a randomly drawn candidate position.
        indexplatform = rand(1:length(Xplatform));
        xp = Xplatform[indexplatform];
        yp = Yplatform[indexplatform];

        currentday = Day();
        currentday.Platform = hcat(xp, yp); # platform position for this day

            ##########  ##########  ##########  ##########
        ##########  ##########  START DAY ##########  ##########
            ##########  ##########  ##########  ##########

        for indextrial = 1:numberoftrials

            # Each trial starts from a randomly drawn candidate start position
            # (1 East, 2 North, 3 West, 4 South).
            indexstart = rand(1:length(Xstart));
            position = [Xstart[indexstart] Ystart[indexstart]]; # 1x2 row [x y]

            re = 0;        # reward at the current position (1 on the platform)
            k = 1;         # step counter, indexes `times`
            t = times[k];  # current time in s

            # Trajectory of this trial
            historyX = Float64[];
            historyY = Float64[];

            errorhistory = Float64[]; # TD error of every step
            searchpref = 0;  # time (s) spent within radiussearchpref of the platform
            timeout = 0;     # set to 1 once the rat has been put on the platform at t == T
            prevdir = [0 0]; # previous heading, used to smooth the trajectory

            ##########  ##########  ##########  ##########   ##########
            ##########  ##########  START TRIAL ##########  ##########
            ##########  ##########  ##########  ##########   ##########

            while t <= T && re == 0 # until the platform is reached or time runs out

                if t == T # out of time: put the rat on the platform
                    X = xp;
                    Y = yp;
                    position = [X Y];
                    # The rat did not get there by itself, so only the critic
                    # (not the actor) is updated from now on.
                    timeout = 1;
                end

                # Record the position at the start of the step
                push!(historyX, position[1])
                push!(historyY, position[2])

                ### Reward at the current position ###
                re = reward(position[1], position[2], xp, yp, r);

                # Place-cell activity at the current position
                actplacecell = placecells([position[1], position[2]], centres, σ);

                ### Critic: current estimate of the future discounted reward ###
                C = dot(criticweights, actplacecell);

                ####### Choose an action and move #######
                # Action-cell activity (actorweights: place cells in rows, actions in columns)
                actactioncell = transpose(actorweights) * actplacecell;

                if maximum(actactioncell) >= 100 # rescale so exp() below cannot overflow
                    actactioncell = 100 .* actactioncell ./ maximum(actactioncell);
                end

                # Softmax over the actions; the factor 2 acts as an inverse temperature.
                expaction = exp.(2 .* actactioncell);
                Pactioncell = expaction ./ sum(expaction);

                # Cumulative distribution of the action probabilities
                SumPactioncell = [sum(Pactioncell[1:j]) for j = 1:length(Pactioncell)]

                # Sample an action by locating a uniform number in the cumulative distribution
                u = rand();
                indexaction = indice(SumPactioncell, u);
                argdecision = angles[indexaction];
                newdir = [cos(argdecision) sin(argdecision)];
                # Blend with the previous heading to avoid sharp turns, then normalise
                # so that the rat always moves exactly dt*speed per step.
                dir = (newdir ./ (1.0 + momentum) .+ momentum .* prevdir ./ (1.0 + momentum));
                dir = dir ./ norm(dir);
                prevdir = dir;

                formerposition = position;
                position = position .+ dt .* speed .* dir;

                X = position[1];
                Y = position[2];
                Xf = formerposition[1];
                Yf = formerposition[2];

                # The wall of the pool acts as a reflector:
                if X^2 + Y^2 > R^2 # this step left the pool
                    # |formerposition + λ*(position - formerposition)|^2 = R^2 as a polynomial
                    # in λ (coefficients in increasing degree); the root in (0, 1) is where
                    # the segment crosses the wall.
                    polynom = Poly([Xf^2+Yf^2-R^2, 2*X*Xf+2*Y*Yf-2*Xf^2-2*Yf^2, Xf^2+Yf^2+X^2+Y^2-2*X*Xf-2*Y*Yf]);
                    λ = maximum(filter(l -> 0 < l < 1, roots(polynom)));
                    Xlambda = λ*X + (1-λ)*Xf; # crossing point on the wall
                    Ylambda = λ*Y + (1-λ)*Yf;
                    delta = norm([Xlambda-X, Ylambda-Y]); # overshoot beyond the crossing point

                    # Move the outside point back along the inward normal -(Xlambda, Ylambda)/R
                    # by the distance `deplacement` that keeps it at distance delta from the
                    # crossing point: this mirrors the overshoot across the tangent to the wall.
                    poly2 = Poly([Y^2-2*Ylambda*Y+(Ylambda^2)+X^2-2*Xlambda*X+(Xlambda^2)-delta^2, -2*Ylambda*Y/R+2*Ylambda^2/R-2*Xlambda*X/R+2*Xlambda^2/R, Ylambda^2/R^2+Xlambda^2/R^2]);

                    # One root of poly2 is (numerically close to) 0; the largest positive
                    # root is the reflection distance.
                    deplacement = maximum(filter(l -> 0 < l, roots(poly2)));

                    Xnew = X - deplacement*Xlambda/R;
                    Ynew = Y - deplacement*Ylambda/R;
                    if Xnew^2 + Ynew^2 > R^2 # the reflection failed: abandon the trial
                        println("we are still out")
                        break
                    end

                    X = Xnew;
                    Y = Ynew;
                    position = [X Y];
                end

                # If the rat is now exactly on the wall, pull it slightly towards the centre
                if X^2 + Y^2 >= R^2
                    inset = 0.02;
                    position = position .* (1 - inset);
                end

                # Place-cell activity at the new position
                actplacecell = placecells([position[1], position[2]], centres, σ);

                if re == 1 # the trial ends here: no future reward after the platform
                    Cnext = 0;
                else
                    Cnext = dot(criticweights, actplacecell); # new estimate of the future discounted reward
                end

                #### Temporal-difference error ####
                err = re + γ*Cnext - C;
                push!(errorhistory, err);

                ######### Weight updates ########
                # The updates build new arrays on purpose (no in-place `.+=`): every stored
                # Trial keeps its own snapshot of the weights.
                if timeout == 0 # only reinforce actions the rat actually chose
                    G = zeros(n, 1); # selects the column of the chosen action
                    G[indexaction] = 1;
                    actorweights = actorweights + actorLR .* err .* actplacecell * transpose(G);
                end
                criticweights = criticweights + criticLR .* err .* actplacecell;

                ###### Search preference: time spent near the platform ######
                if (X - xp)^2 + (Y - yp)^2 <= radiussearchpref^2
                    searchpref = searchpref + 1*dt;
                end

                # Next step
                k = k + 1;
                t = times[k];
            end
            ########## ##########  END TRIAL ########## ##########

            push!(historyX, position[1]) # last position visited
            push!(historyY, position[2])

            ############### STORE THE TRIAL ################
            # Trial(trajectory, latency, search preference, actor weights, critic weights, TD errors)
            # NOTE(review): the latency is the clock after the last executed step, i.e. one
            # dt later than the step at which the platform was detected -- confirm intended.
            currenttrial = Trial(hcat(historyX, historyY), t, searchpref, actorweights, criticweights, errorhistory);
            push!(currentday.trial, currenttrial)

        end
        ########## ##########  END DAY ########## ##########

        push!(currentexperiment.day, currentday)

    end
    ########## ##########  END EXPERIMENT ########## ##########

push!(rats.experiment, currentexperiment)

end
########## ##########  END RATS ########## ###

end # end time

In [16]:
# Save all simulated data -- the `rats` object with its parameters, experiments,
# days and trials -- to ./experiment.jld under the name "rats". It can then be
# reloaded for analysis or plotting without re-running the simulation.
using JLD: save 

save("./experiment.jld", "rats", rats)