[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/comp-neural-circuits/plasticity-workshop/blob/dev/rate_based.ipynb)

# Rate-based Plasticity Rules

## Hebbian Plasticity

**Goals**
+ Show that, under a covariance-based Hebbian learning rule, the weight vector aligns with the first principal component of the input activity


### Initialization

In [21]:
!pip install numpy scipy matplotlib ipywidgets scikit-learn --quiet
import numpy as np
import scipy.linalg as lin
from numpy.random import default_rng
rng = default_rng()
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
plt.style.use("https://github.com/comp-neural-circuits/plasticity-workshop/raw/dev/plots_style.txt")

### Utility Functions

In [18]:
#@title Utility functions
def ornstein_uhlenbeck(mean,cov,dt,Ttot,dts=1E-2,generator=None):
  """
  Generates a multi-dimensional Ornstein-Uhlenbeck process.

  Euler scheme with unit relaxation time:
      r(t+dts) = r(t) + dts*(mean - r(t)) + sqrt(2*dts) * L^T xi ,  xi ~ N(0, I)
  where L = scipy.linalg.cholesky(cov) is the upper factor (cov = L^T L).
  The simulation starts from r = 0, not from `mean`, so the first samples
  contain a transient towards `mean`.

  Parameters :
  mean (numpy vector) : desired mean, length n
  cov  (matrix)   : n x n covariance matrix (symmetric, positive definite)
  dt   (real)     : timestep of the returned samples (must be a multiple of dts)
  Ttot (real)     : total time
  dts = 1E-2 (real) : simulation timestep
  generator (numpy Generator, optional) : random generator to draw from;
      when None, the notebook-level `rng` is used.

  Returns :
  times (numpy vector)
  rates (numpy matrix)  :  rates[i,j] is the rate of unit i at time times[j]
  """
  mean = np.asarray(mean, dtype=float)
  gen = rng if generator is None else generator
  n = len(mean)
  # Round instead of truncating. Float ratios such as 0.3/0.1 evaluate to
  # 2.9999999999999996, and int() would turn that into 2.
  n_out = int(round(Ttot/dt))
  nTs = int(round(Ttot/dts))
  nskip = int(round(dt/dts))
  assert nskip >= 1 and np.isclose(dts*nskip, dt), "dt must be multiple of  " + str(dts)
  times = np.linspace(0.0,Ttot-dt,num=n_out)
  L = lin.cholesky(cov)
  # Draw all noise increments at once. Row t-1 is the increment for step t,
  # taken from the same random stream as one draw of size n per step.
  noise = np.sqrt(2*dts) * (gen.standard_normal((nTs-1, n)) @ L)
  rates_all = np.empty((n,nTs))
  rates_all[:,0] = 0
  for t in range(1,nTs):
    dr = dts*(mean-rates_all[:,t-1])
    rates_all[:,t] = rates_all[:,t-1] + dr + noise[t-1]
  # Subsample to the output timestep, keeping exactly one column per entry of `times`.
  rates = rates_all[:,::nskip][:,:n_out]
  return times,rates
  
def twodimensional_UL(mean1,var1,mean2,var2,corr,dt,Ttot,dts=1E-2):
  """
  Generates samples from a 2D Ornstein-Uhlenbeck process.

  The covariance matrix passed to ornstein_uhlenbeck is
      [[var1,                  corr*sqrt(var1*var2)],
       [corr*sqrt(var1*var2),  var2                ]]
  so `corr` is the correlation coefficient between the two components.

  Parameters :
  mean1 (real) : mean on first dimension
  var1  (real) : variance on first dimension, > 0
  mean2 (real) : mean on second dimension
  var2  (real) : variance on second dimension, > 0
  corr  (real) : correlation coefficient, strictly inside (-1,1)
  dt    (real) : timestep output (must be a multiple of dts)
  Ttot  (real) : total time
  dts = 1E-2 (real) : simulation timestep

  Returns :
  times  (numpy vector)
  rates1 (numpy vector)
  rates2 (numpy vector)
  """
  # |corr| == 1 gives a singular covariance matrix, which the Cholesky
  # factorisation inside ornstein_uhlenbeck rejects.
  assert -1<corr<1, "correlation must be in (-1,1) interval"
  assert var1>0 and var2>0, "variances must be positive"
  # Covariance = corr * std1 * std2. The old corr*var1*var2 gave the wrong
  # correlation and could make the matrix non positive definite.
  var12 = corr*np.sqrt(var1*var2)
  (times, rates) = ornstein_uhlenbeck(
      np.array([mean1,mean2]),
      np.array([[var1,var12],[var12,var2]]),
      dt,Ttot,dts)
  return times, rates[0,:],rates[1,:]


In [None]:
# Sample a correlated 2D OU process and inspect the time courses and the joint distribution.
times,rates1,rates2 = twodimensional_UL(mean1=0.0, var1=1.0, mean2=0.0, var2=2.0,
                                        corr=0.6, dt=1.0, Ttot=500.0)
fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(times,rates1,label="unit 1")
ax1.plot(times,rates2,label="unit 2")
ax1.set(xlabel="time", ylabel="rate")
ax1.legend()
ax2.scatter(rates1,rates2)
# The realized correlation should be close to the requested corr=0.6.
ax2.set(xlabel="rate of unit 1", ylabel="rate of unit 2",
        title=f"sample correlation = {np.corrcoef(rates1, rates2)[0, 1]:.2f}")
plt.show()

In [None]:
# Minimal ipywidgets `interact` demo: controls are built from the keyword
# defaults and the function's return value is shown as the controls change.
# NOTE(review): g and f are not referenced by any other cell shown here; this
# looks like leftover scratch and could be removed.

# Decorator form: interact(x=True, y=1.0) wraps g.
@interact(x=True, y=1.0)
def g(x, y):
    # Echo the current control values.
    return (x, y)

# Function-call form: same mechanism, with a single control for x.
def f(x):
    return x
interact(f,x=10)

In [None]:
def mySTDP_plot(A_plus, A_minus, tau_plus, tau_minus, Delta_t, dW):
  '''
  Draws a pairwise STDP window in a new figure: dW against Delta_t.

  A_plus : maximum amount of potentiation (LTP); top of the vertical guide at 0
  A_minus: maximum amount of depression (LTD); bottom of the vertical guide at 0
  tau_plus: LTP time constant; the horizontal guide extends to 5 * tau_plus
  tau_minus: LTD time constant; the horizontal guide starts at -5 * tau_minus
  Delta_t : array with the time differences t_post - t_pre (ms)
  dW : synaptic change for each entry of Delta_t

  Entries with Delta_t <= 0 are drawn in red, entries with Delta_t > 0 in blue.
  '''
  _, ax = plt.subplots()

  # Dotted black guide lines: the dW = 0 axis and the Delta_t = 0 axis.
  guide_style = dict(color='k', linestyle=':')
  ax.plot([-5 * tau_minus, 5 * tau_plus], [0, 0], **guide_style)
  ax.plot([0, 0], [-A_minus, A_plus], **guide_style)

  depression = Delta_t <= 0
  potentiation = Delta_t > 0
  ax.plot(Delta_t[depression], dW[depression], 'r')
  ax.plot(Delta_t[potentiation], dW[potentiation], 'b')

  ax.set_xlabel(r'$\Delta t=$ t$_{\mathrm{post}}$ - t$_{\mathrm{pre}}$ (ms)')
  ax.set_ylabel(r'$\Delta $W', fontsize=14)
  ax.set_title('Pairwise STDP rule', fontsize=12, fontweight='bold')
  plt.show()

In [None]:
# Two units, both with mean rate 2 and unit variance, with covariance 0.5 between them.
mean = np.full(2, 2.0)
cov_mat = np.array([[1.0, 0.5],
                    [0.5, 1.0]])
times, rates = ornstein_uhlenbeck(mean, cov_mat, dt=1.0, Ttot=50.0)



In [None]:
# Time courses of the two OU units sampled above.
fig, ax = plt.subplots()
ax.plot(times, rates[0, :], label="unit 0")
ax.plot(times, rates[1, :], label="unit 1")
ax.set(xlabel="time", ylabel="rate", title="Ornstein-Uhlenbeck rates")
ax.legend();

In [None]:
rates[1].var()  # sample variance of unit 1 (compare with cov_mat[1, 1])

In [None]:
times[len(times) - 1]  # time of the last retained sample (Ttot - dt)

In [None]:
print(rates[0].mean(), rates[0].var())  # sample mean and variance of unit 0

In [None]:
def Delta_W(A_plus, A_minus, tau_plus, tau_minus, Delta_t):
  """
  Calculates the instantaneous change in weights dW due to the STDP pairwise rule:

      dW = +A_plus  * exp(-Delta_t / tau_plus)    if Delta_t > 0   (LTP)
      dW = -A_minus * exp(+Delta_t / tau_minus)   if Delta_t <= 0  (LTD)

  A_plus : maximum amount of potentiation (LTP)
  A_minus: maximum amount of depression (LTD)
  tau_plus: LTP time constant
  tau_minus: LTD time constant
  Delta_t : time differences t_post - t_pre (array-like or scalar)

  Returns :
  dW (numpy array with the same shape as Delta_t)
  """
  Delta_t = np.asarray(Delta_t, dtype=float)
  # Initialize the STDP change
  dW = np.zeros(Delta_t.shape)
  # Evaluate each exponential only on its own branch. This avoids overflow in
  # exp(-Delta_t / tau_plus) for large negative Delta_t.
  ltp = Delta_t > 0
  ltd = Delta_t <= 0
  # Calculate dW for LTP
  dW[ltp] = A_plus * np.exp(-Delta_t[ltp] / tau_plus)
  # Calculate dW for LTD (uses the argument Delta_t, not a global)
  dW[ltd] = -A_minus * np.exp(Delta_t[ltd] / tau_minus)

  return dW

In [None]:
rates[:, :5]  # first 5 samples of each unit; avoids dumping the whole array

In [None]:
rates.shape  # (number of units, number of time samples)

In [None]:
# Ten independent standard-normal draws from the shared generator.
vals = rng.standard_normal(size=10)
vals

In [None]:
# define the STDP rule parameters
A_plus = 1
A_minus = 1
tau_plus = 20  #[ms]
tau_minus = 10 #[ms]

# Spike-time differences t_post - t_pre, spanning 5 time constants on each side.
# The variable must be named Delta_t (it was defined as delta_t but used as Delta_t).
# 301 points give a 0.5 ms step, so Delta_t = 0 is on the grid and the jump is sharp.
Delta_t = np.linspace(-5 * tau_minus, 5 * tau_plus, 301)

dW = Delta_W(A_plus, A_minus, tau_plus, tau_minus, Delta_t)

mySTDP_plot(A_plus, A_minus, tau_plus, tau_minus, Delta_t, dW)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=074c8271-80e9-4d9f-94a8-13db082db696' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>