# Burst Detection - R library wrapped using rpy2
To run properly, this notebook requires a working install of R with special packages **pracma, sjemea, and e1071**.
To test it out, you can use my custom python install located on the network: */allen/aibs/mat/Peter/renv/*, which can be started using **source /allen/aibs/mat/Peter/renv/bin/activate**

This script has been modified from it's original version by Marcus Blackburn.

**Readme for the R-code used in this notebook:**

**Burst-analysis**

This is the repository to accompany our article:

    A comparison of computational methods for detecting bursts in
    neuronal spike trains and their application to human stem
    cell-derived neuronal network

    Ellese Cotterill, Paul Charlesworth, Christopher W. Thomas, Ole
    Paulsen, Stephen J. Eglen
	
    J. Neurophysiol. (2016).


[Journal web page](http://jn.physiology.org/content/116/2/306)

You are free to use any of the data or resources in this repository.
We do request however that if you use this material, you cite the
above paper in any work that you publish.

Code to implement the eight burst detection methods in the paper are located in [Burst_detection_methods](Burst_detection_methods).

Data for mouse RGC analysis from Demas et al., 2003 are available from http://www.gigasciencejournal.com/content/3/1/3

MEA recordings from hiPSC-derived neuronal networks are located in [hiPSC_recordings](hiPSC_recordings)

In [1]:
#rpy2 generates a bunch of annoying warnings
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import numpy as np

In [2]:
import rpy2.ipython
%load_ext rpy2.ipython

In [3]:
#make some test spike trains (or import real data)
spikes = np.random.poisson(0.1*np.ones(10000))
spike_times =  np.where(spikes>0)[0]
spike_times = spike_times/200.0


In [4]:
%R require(pracma)
%R require(sjemea)
%R require(e1071)

array([1], dtype=int32)

The two most useful methods are LogISI (adaptive burst-detection with a single tuning parameter) and MI (5 human-interpretable parameters)

In [5]:
%%R
#Function to run logISI method - works
logisi.pasq.method<-function(spike.train, cutoff=0.1){
  cutoff<-ifelse(is.null(cutoff), 0.1, cutoff)
  if (length(spike.train)>3) {
      isi.low <- logisi.break.calc(spike.train, cutoff) #Calculates threshold as isi.low
    if (is.null(isi.low) || isi.low>=1 ){
      logisi.par <- list(min.ibi=0,   min.durn=0, min.spikes=3,
      isi.low=cutoff) #If no value for isi.low found, or isi.low above 1 second, find bursts using threshold equal to cutoff (default 100ms)
      result<-logisi.find.burst(spike.train, logisi.par)
    } else if (isi.low<0) {
      result<-NA
    } else if (isi.low>cutoff & isi.low <1) {
      logisi.par <- list(min.ibi=isi.low,   min.durn=0, min.spikes=3,
                         isi.low=cutoff) #If isi.low >cutoff, find bursts using threshold equal to cutoff (default 100ms)
      bursts<-logisi.find.burst(spike.train, logisi.par)
      if (!is.na(bursts)[1]){
        logisi.par2 <- list(min.ibi=0,   min.durn=0, min.spikes=3,
        isi.low=isi.low) #If bursts have been found, add burst related spikes using threshold of isi.low
        brs<-logisi.find.burst(spike.train, logisi.par2)
        result<-add.brs(bursts, brs, spike.train)
      } else {
        result<-bursts
      }
    } else {
      logisi.par <- list(min.ibi=0,   min.durn=0, min.spikes=3,
      isi.low=isi.low) #If isi.low<cutoff, find bursts using a threshold equal to isi.low
      result<-logisi.find.burst(spike.train, logisi.par)
    }
    
  } else {
    result<-NA
  }
  result
}

#Finds peaks in logISI histogram
get.peaks<-function(h, Pd=2, Th=0, Np=NULL){
  m<-0
  L<-length(h$density)
  j<-0
  Np<-ifelse(is.null(Np), L, Np)
  pks<-NULL
  locs<-NULL
  void.th<-0.7
  while((j<L)&&(m<Np)){
    j<-j+1
    endL<-max(1,j-Pd)
    if (m>0 && j<min(c(locs[m]+Pd, L-1))){
      j<-min(c(locs[m]+Pd, L-1))
      endL<-j-Pd
    }
    endR<-min(L, j+Pd)
    temp<-h$density[endL:endR]
    aa<-which(j==endL:endR)
    temp[aa]<--Inf
    if (Pd>1){
      idx1<-max(1, aa-2)
      idx2<-min(aa+2, length(temp))
      idx3<-max(1, aa-1)
      idx4<-min(aa+1, length(temp))
      if (sum((h$density[j]>(temp[c(1:idx1, idx2:length(temp))]+Th))==FALSE)==0 && sum((h$density[j]>(temp[idx3:idx4]))==FALSE)==0 && j!=1 && j!=L){
        m<-m+1
        pks[m]<-h$density[j]
        locs[m]<-j
      } } else if (sum((h$density[j]>(temp+Th))==FALSE)==0 ) {
        m<-m+1
        pks[m]<-h$density[j]
        locs[m]<-j
      }
    
  }
  ret<-data.frame(pks=pks, locs=locs)
}

#Function to find cutoff threshold.
find.thresh<-function(h, ISITh=100){
  void.th<-0.7
  gp<-get.peaks(h)
  num.peaks<-length(gp$pks)
  pkx<-h$breaks[gp$locs]
  intra.indx<-which(pkx<ISITh)
  if(length(intra.indx)>=1){
    max.intra<-max(gp$pks[intra.indx])
    max.idx<-which.max(gp$pks[intra.indx])
  } else {
    return(-1000)
  }
  x1<-pkx[max.idx]
  y1<-max.intra
  locs1<-gp$locs[max.idx]
  num.peaks.after.burst<-num.peaks-max.idx
  if (num.peaks.after.burst==0){
    return(NULL)
  } else {
    gp2<-gp[(max.idx+1):num.peaks,]
    ymin<-sapply(gp2$locs, function(x) min(h$density[locs1:x]))
    xmin<-sapply(gp2$locs, function(x) which.min(h$density[locs1:x]))+locs1-1
    voidParameter<-1-(ymin/sqrt(y1*gp2$pks))
  }
  indxvoid<-suppressWarnings(min(which(voidParameter>=void.th)))
  if(is.infinite(indxvoid)) {
    flags<-c(1,0)
    return(NULL)
  } else {
    ISImax<-h$breaks[xmin[indxvoid]]
    return(ISImax)
  }
}

#Calculates cutoff for burst detection
logisi.break.calc<-function(st, cutoff){
  isi<-diff(st)*1000
  max.isi<-ceiling(log10(max(isi)))
  isi<-isi[isi>=1]
  br<-logspace(0, max.isi, 10*max.isi)
  h<-hist(isi, breaks=br, plot=FALSE)
  h$density<-h$counts/sum(h$counts)
  h$density<-lowess(h$density, f=0.05)$y
  thr<-find.thresh(h, cutoff*1000)
  if(!is.null(thr)){
    thr<-thr/1000
  }
  thr
}




###Function to add burst related spikes to edges of bursts
add.brs<-function(bursts, brs, spike.train){
is.between<-function(x,a,b){ betw<-0
                             if(x>=a &x<=b)
                             {
                               betw<-1
                             }
                             betw
}
                             
burst.adj<-data.frame(beg=rep(0, dim(bursts)[1]), end=rep(0, dim(bursts)[1]) )
for (i in 1:dim(bursts)[1]) {
  for (j in 1:dim(brs)[1]) {
    if(is.between(bursts[i,1], brs[j,1], brs[j,2]) | is.between(bursts[i,2], brs[j,1], brs[j,2]))
    {
      burst.adj$beg[i]<-min(bursts[i,1], brs[j,1])
      burst.adj$end[i]<-max(bursts[i,2], brs[j,2])
      break
    } else {
      burst.adj$beg[i]<-bursts[i,1]
      burst.adj$end[i]<-bursts[i,2]
    }
    if(brs[j,2]>bursts[i,2]) {
      break
    }
  }
}

diff.begs<-diff(burst.adj[,"beg"])
rep.bursts.begs<-which(diff.begs==0)
if (any(rep.bursts.begs)) {
  burst.adj<-burst.adj[-rep.bursts.begs,]
}
diff.ends<-diff(burst.adj[,"end"])
rep.bursts.end<-which(diff.ends==0)+1
if (any(rep.bursts.end)) {
  burst.adj<-burst.adj[-rep.bursts.end,]
}
start.times<-spike.train[burst.adj$beg]
end.times<- spike.train[burst.adj$end]
durn<-end.times-start.times
len<-burst.adj$end-burst.adj$beg+1
mean.isis<-durn/(len-1)
N.burst<-dim(burst.adj)[1]
IBI<-c(NA, start.times[-1]-end.times[-N.burst])
result<-cbind(beg=burst.adj$beg, end=burst.adj$end, IBI=IBI, len=len, durn=durn, mean.isis=mean.isis, SI=rep(1, N.burst))
result

}


##Function for finding bursts, taken from sjemea
logisi.find.burst<- function(spikes, par, debug=FALSE) {
  
  ## For one spike train, find the burst using log isi method.
  ## e.g.
  ## find.bursts(s$spikes[[5]])
  ## init.
  ## params currently in LOGISI.PAR
  ##
  
  no.bursts = NA;                       #value to return if no bursts found.
  
  
  ##beg.isi =    par$beg.isi
  ##end.isi =    par$end.isi
  min.ibi =      par$min.ibi
  min.durn =     par$min.durn
  min.spikes =   par$min.spikes
  isi.low =      par$isi.low
  
  nspikes = length(spikes)
  
  ## Create a temp array for the storage of the bursts.  Assume that
  ## it will not be longer than Nspikes/2 since we need at least two
  ## spikes to be in a burst.
  
  max.bursts <- floor(nspikes/2)
  bursts <- matrix(NA, nrow=max.bursts, ncol=3)
  colnames(bursts) = c("beg", "end", "IBI")
  burst <- 0                            #current burst number
  
  ## Phase 1 -- burst detection. Each interspike interval of the data
  ## is compared with the threshold THRE. If the interval is greater
  ## than the threshold value, it can not be part of a burst; if the
  ## interval is smaller or equal to the threhold, the interval may be
  ## part of a burst.
  
  
  
  ## LAST.END is the time of the last spike in the previous burst.
  ## This is used to calculate the IBI.
  ## For the first burst, this is no previous IBI
  last.end = NA;                        #for first burst, there is no IBI.
  
  eps<-10^(-10)
  n = 2
  in.burst = FALSE
  
  while ( n < nspikes) {
    
    next.isi = spikes[n] - spikes[n-1]
    if (in.burst) {
      if (next.isi - isi.low>eps) {
        ## end of burst
        end = n-1; in.burst = FALSE
        
        
        ibi =  spikes[beg] - last.end; last.end = spikes[end]
        res = c(beg, end, ibi)
        burst = burst + 1
        if (burst > max.bursts) {
          print("too many bursts!!!")
          browser()
        }
        bursts[burst,] <- res
      }
    } else {
      ## not yet in burst.
      if (next.isi - isi.low <=eps) {
        ## Found the start of a new burst.
        beg = n-1; in.burst = TRUE
      }
    }
    n = n+1
  }
  
  ## At the end of the burst, check if we were in a burst when the
  ## train finished.
  if (in.burst) {
    end = nspikes
    ibi =  spikes[beg] - last.end
    res = c(beg, end, ibi)
    burst = burst + 1
    if (burst > max.bursts) {
      print("too many bursts!!!")
      browser()
    }
    bursts[burst,] <- res
  }
  
  ## Check if any bursts were found.
  if (burst > 0 ) {
    ## truncate to right length, as bursts will typically be very long.
    bursts = bursts[1:burst,,drop=FALSE]
  } else {
    ## no bursts were found, so return an empty structure.
    return(no.bursts)
  }
  
  if (debug) {
    print("End of phase1\n")
    print(bursts)
  }
  
  
  ## Phase 2 -- merging of bursts.  Here we see if any pair of bursts
  ## have an IBI less than MIN.IBI; if so, we then merge the bursts.
  ## We specifically need to check when say three bursts are merged
  ## into one.
  
  
  ibis = bursts[,"IBI"]
  merge.bursts = which(ibis < min.ibi)
  
  if (any(merge.bursts)) {
    ## Merge bursts efficiently.  Work backwards through the list, and
    ## then delete the merged lines afterwards.  This works when we
    ## have say 3+ consecutive bursts that merge into one.
    
    for (burst in rev(merge.bursts)) {
      bursts[burst-1, "end"] = bursts[burst, "end"]
      bursts[burst, "end"] = NA         #not needed, but helpful.
    }
    bursts = bursts[-merge.bursts,,drop=FALSE] #delete the unwanted info.
  }
  
  if (debug) {
    print("End of phase 2\n")
    print(bursts)
  }
  
  
  ## Phase 3 -- remove small bursts: less than min duration (MIN.DURN), or
  ## having too few spikes (less than MIN.SPIKES).
  ## In this phase we have the possibility of deleting all spikes.
  
  ## LEN = number of spikes in a burst.
  ## DURN = duration of burst.
  len = bursts[,"end"] - bursts[,"beg"] + 1
  durn = spikes[bursts[,"end"]] - spikes[bursts[,"beg"]]
  bursts = cbind(bursts, len, durn)
  
  rejects = which ( (durn < min.durn) | ( len < min.spikes) )
  
  if (any(rejects)) {
    bursts = bursts[-rejects,,drop=FALSE]
  }
  
  if (nrow(bursts) == 0) {
    ## All the bursts were removed during phase 3.
    bursts = no.bursts
  } else {
    ## Compute mean ISIS
    len = bursts[,"end"] - bursts[,"beg"] + 1
    durn = spikes[bursts[,"end"]] - spikes[bursts[,"beg"]]
    mean.isis = durn/(len-1)
    
    ## Recompute IBI (only needed if phase 3 deleted some cells).
    if (nrow(bursts)>1) {
      ibi2 = c(NA, calc.ibi(spikes, bursts))
    } else {
      ibi2 = NA
    }
    bursts[,"IBI"] = ibi2
    
    SI = rep(1, length(mean.isis ))
    bursts = cbind(bursts, mean.isis, SI)
  }
  
  ## End -- return burst structure.
  bursts
  
}

In [6]:
%%R
##Poisson Surprise (PS) method - WORKS

PS.method<-function(spike.train, si.thresh=5) {
    si.thresh<-ifelse(is.null(si.thresh), 5, si.thresh)
    burst <- si.find.bursts.thresh(spike.train)
    if (is.null(dim(burst))){
        result<-NA
    } else {
        burst.rem<-which(burst[,"SI"]<si.thresh)
        if (length(burst.rem)) {
            burst<-burst[-burst.rem,]
        }
        if (length(dim(burst))<1) {
            burst<-data.frame(beg=burst[1], len=burst[2], SI=burst[3], durn=burst[4], mean.isis=burst[5])
        } else if (dim(burst)[1]==0){
            return(NA)
        }
        beg<-burst[,"beg"]
        len<-burst[,"len"]
        N.burst<-length(beg)
        end<-beg+len-1
        IBI<-c(NA, spike.train[beg[-1]]-spike.train[end[-N.burst]])
        result<-cbind(beg=beg, end=end, IBI=IBI, len=len, durn=burst[,"durn"], mean.isis=burst[,"mean.isis"], SI=burst[,"SI"])
        rownames(result)<-NULL
    }
    result
}




si.find.bursts.thresh<- function (spikes, debug = FALSE)
{
    nspikes = length(spikes)
    mean.isi = mean(diff(spikes))
    threshold = mean.isi/2
    n = 1
    max.bursts <- floor(nspikes/3)
    bursts <- matrix(NA, nrow = max.bursts, ncol = burst.info.len)
    burst <- 0
    while (n < nspikes - 2) {
        if (debug)
        print(n)
        if (((spikes[n + 1] - spikes[n]) < threshold) && ((spikes[n +
        2] - spikes[n + 1]) < threshold)) {
            res <- si.find.burst.thresh2(n, spikes, nspikes, mean.isi,
            burst.isi.max, debug)
            if (is.na(res[1])) {
                n <- n + 1
            }
            else {
                burst <- burst + 1
                if (burst > max.bursts) {
                    print("too many bursts")
                    browser()
                }
                bursts[burst, ] <- res
                n <- res[1] + res[2]
                names(n) <- NULL
            }
        }
        else {
            n = n + 1
        }
    }
    if (burst > 0) {
        res <- bursts[1:burst, , drop = FALSE]
        colnames(res) <- burst.info
    }
    else {
        res <- NA
    }
    res
}



si.find.burst.thresh2<-function(n, spikes, nspikes, mean.isi, threshold=NULL,
debug=FALSE) {
    ## Find a burst starting at spike N.
    ## Include a better phase 1.
    
    
    ## Determine ISI threshold.
    if (is.null(threshold))
    isi.thresh = 2 * mean.isi
    else
    isi.thresh = threshold
    
    if (debug)
    cat(sprintf("** find.burst %d\n", n))
    
    i=3  ## First three spikes are in burst.
    s = surprise(n, i, spikes, nspikes, mean.isi)
    
    ## Phase 1 - add spikes to the train.
    phase1 = TRUE
    ##browser()
    
    ## in Phase1, check that we still have spikes to add to the train.
    while( phase1 ) {
        
        ##printf("phase 1 s %f\n", s);
        
        i.cur = i;
        
        ## CHECK controls how many spikes we can look ahead until SI is maximised.
        ## This is normally 10, but will be less at the end of the train.
        check = min(10, nspikes-(i+n-1))
        
        looking = TRUE; okay = FALSE;
        while (looking) {
            
            if (check==0) {
                ## no more spikes left to check.
                looking=FALSE;
                break;
            }
            check=check-1; i=i+1
            s.new = surprise(n, i, spikes, nspikes, mean.isi)
            if (debug)
            printf("s.new %f s %f n %d i %d check %d\n", s.new, s, n, i, check)
            
            if (s.new > s) {
                okay=TRUE; looking=FALSE;
            } else {
                ## See if we should keep adding spikes?
                if ( (spikes[i] - spikes[i-1]) > isi.thresh ) {
                    looking = FALSE;
                }
                
            }
        }
        ## No longer checking, see if we found an improvement.
        if (okay) {
            if (s > s.new) {
                ## This should not happen.
                printf("before s %f s.new %f\n", s, s.new)
                browser()
            }
            s = s.new
        } else {
            ## Could not add more spikes onto the end of the train.
            phase1 = FALSE
            i = i.cur
        }
    }
    
    
    ## start deleting spikes from the start of the burst.
    phase2 = TRUE
    while(phase2) {
        if (i==3) {
            ## minimum length of a burst must be 3.
            phase2=FALSE
        } else {
            s.new = surprise(n+1, i-1, spikes, nspikes, mean.isi)
            if (debug)
            cat(sprintf("phase 2: n %d i %d s.new %.4f\n", n, i, s.new))
            if (s.new > s) {
                if (debug)
                print("in phase 2 acceptance\n")
                n = n+1; i = i-1
                s = s.new
            } else {
                ## removing front spike did not improve SI.
                phase2 = FALSE
            }
        }
    }
    
    
    ## End of burst detection; accumulate result.
    
    
    ## compute the ISIs, and then the mean ISI.
    
    ## Fencepost issue: I is the number of spikes in the burst, so if
    ## the first spike is N, the last spike is at N+I-1, not N+I.
    isis = diff(spikes[n+(0:(i-1))])
    mean.isis = mean(isis)
    
    durn = spikes[n+i-1] - spikes[n]
    res <- c(n=n, i=i, s=s, durn=durn, mean.isis=mean.isis)
    
    ##browser()
    res
    
}

In [7]:
%%R 
#Applies mi.find.bursts from sjemea to single spike train
MI.method<- function(spike.train){
  burst<-mi.find.bursts(spike.train)
  if (dim(burst)[1]<1) {
    burst<-NA
  }
  burst
}

In [8]:
%%R
# WORKS!!!
##Function to calculate and plot bursts, based on CMA method (Kapucu et al., 2012).
##Input is spike train. If brs.incl is set to true, method will include burst related
##spikes. min.val is the minumum number of spikes on a spikes train for the method to be run.
##If plot = TRUE, spike train with bursts labelled is plotted.
CMA.method<-function(spike.train, brs.incl=TRUE, min.val=3, plot=FALSE) {
  #Do not perform burst detection if less than min.val spikes in the spike train
  if (length(spike.train)<min.val) {
    result<-NA
    return(result)
  }
  isi<-diff(spike.train)
  isi.range<-max(isi)-min(isi)
  eps<-isi.range/1000
  if (isi.range<0.001){
    breaks1<-seq(0, max(isi)+isi.range/10, isi.range/10)
    hist.isi<-hist(isi, breaks =breaks1, plot=FALSE) #if ISI range very small (<0.001), use smaller bins
  } else {
    hist.isi<-hist(isi, breaks =seq(0, max(isi)+eps, eps), plot=FALSE) #Create histogram with approx 1000 bins
  }
  CMA<-cumsum(hist.isi$counts)/seq(1,length(hist.isi$counts),1)
  CMAm<-max(CMA)
  m<-min(which(CMA==CMAm))
  alpha.values<-data.frame(max=c(1, 4, 9, 1000), a1=c(1, 0.7, 0.5, 0.3), a2=c(0.5, 0.5, 0.3,0.1)) #Alpha value scale
  skew<-skewness(CMA)
  if (is.na(skew)){
    result<-NA
    return(result)
  }
  diff.skew<-alpha.values[,"max"]-skew
  alpha.indx<-which(diff.skew==min(diff.skew[diff.skew>0]))
  alpha1<-alpha.values[alpha.indx, "a1"]
  alpha2<-alpha.values[alpha.indx, "a2"]
  cutoff<-which.min(abs(CMA[m:length(CMA)] - (alpha1*CMAm)))+(m-1) #Cutoff set at bin closest in value to alpha1*CMAm
  xt<-hist.isi$mids[cutoff] #maxISI
  cutoff2<-which.min(abs(CMA[m:length(CMA)] - (alpha2*CMAm)))+(m-1) #Burst related spikes cutoff set at bin closest in value to alpha2*CMAm
  xt2<-hist.isi$mids[cutoff2] #maxISI for burst related spikes
  bursts<-find.bursts(spike.train,xt) #Find burst cores
  #If brs.incl=TRUE, then extend bursts to include burst related spikes
  if (brs.incl && !is.null(dim(bursts)[1])){
    brs<-find.bursts(spike.train, xt2)
    burst.adj<-NULL
    for (i in 1:dim(bursts)[1]){
    burst.between<-apply(brs[,1:2], 1, function(x) between.bursts(bursts[i,1:2], x)) #Find burst related spikes which surround bursts
    which.between<-which(unlist(sapply(burst.between, function(x) sum(!is.na(x))))>0)
    if (which.between) {
      burst.adj<-rbind(burst.adj, burst.between[[which.between]])
    } else {
      burst.adj<-rbind(burst.adj, bursts[i, 1:2])
    }
  }
  
  burst.adj<-unique(burst.adj) #Remove any repeated bursts
  N<-dim(burst.adj)[1]
  beg<-burst.adj[,1]
  end<-burst.adj[,2]
  ibi<-c(NA, spike.train[beg[-1]]-spike.train[end[-N]])
  len<-end-beg+1
  durn<-spike.train[end]-spike.train[beg]
  bursts<-cbind(beg=beg, end=end, IBI=ibi, len=len, durn=durn, mean.isis=durn/len, SI=1)
  }

  if (is.null(dim(bursts)[1])){
    bursts<-NA
  }
  bursts
}
  
    
    between.bursts<-function(burst1, burst2) {
      if (burst2[1]<=burst1[1] & burst2[2]>=burst1[2]) {
        burst<-c(min(burst1[1], burst2[1]), max(burst2[1], burst2[2]))
      } else {
        burst<-NA
      }
      burst
    }
    
    #Add back in IBIs

find.bursts<-function(spike.train, xt){
isi<-diff(spike.train)
indxs<-which(isi<xt)
burst.breaks<- c(0, which(diff(indxs) >1), length(indxs))
isi.list<-sapply(seq(length(burst.breaks) - 1), function(i) indxs[(burst.breaks[i] + 1):burst.breaks[i+1]])
burst.indx<-which(sapply(isi.list, length)>1)
if(length(burst.indx)){
beg<-sapply(isi.list[burst.indx], function(x) min(x))
end<-sapply(isi.list[burst.indx], function(x) max(x))+1
N<-length(beg)
ibi<-c(NA, spike.train[beg[-1]]-spike.train[end[-N]])
len<-end-beg+1
durn<-spike.train[end]-spike.train[beg]
bursts<-cbind(beg=beg, end=end, IBI=ibi, len=len, durn=durn, mean.isis=durn/(len-1), SI=1)
} else {
  bursts<-NA
}
bursts
}

In [9]:
%%R 
# WORKS!!!
##
# An R implementation of Gourevitch & Eggermont (2007) Rank Surprise Method for 
# identifying bursts in spike trains.
#
# Arguments:
# spike.train = vector of spike timings
# RS.thresh   = significance threshold for accepting a burst
#
# Returns:
# A list of 2 column matrices of the start and end times of identified bursts for
# each value of RS.thresh provided. 
# A 2x2 matrix of -1 is returned whenever no bursts are found.
##

RS.method <- function(spike.train, RS.thresh){
  
  ISI <- diff(spike.train)
  N <- length(ISI)
  Results<-list()
  if (N>1) {
  # Burst size at which Gaussian approximation is used
  q.lim <- 30
  # Minimum number of spikes acceptable as a burst
  l.min <- 3
  # Maximum ISI identifying spikes as possible bursts for subsequent analysis
  limit <- quantile(ISI,0.75)
  
  # Convert ISI values to ranks
  order1 <- sort(ISI,index.return=T)$ix
  order2 <- sort(-ISI,index.return=T)$ix
  rk <- rep(0,N)
  rk2 <- rep(0,N)
  rk[order1] <- (1:N)
  rk2[order2] <- (1:N)
  R=(N+1-rk2+rk)/2  # ensures equal values given mean rank
  
  #Identify start and end points of spikes sequences below limit
  ISI.limit <- diff(ISI<limit)
  begin.int <- which(ISI.limit==1)+1
  end.int <- which(ISI.limit==-1)
  #Include first ISI if under limit
  if(ISI[1] < limit){
    begin.int <- c(1,begin.int)
  }
  #Include last ISI if spikes are below limit at end of spike train
  if(length(end.int) < length(begin.int)){
    end.int <- c(end.int,N)
  }
  #Number of spikes in each putative burst
  length.int <- end.int-begin.int+1
  
  # Create stores for final burst information
  burst.RS <-  numeric()
  burst.length <- numeric()
  burst.start <- numeric()
  
  # Create solutions to -1^k 
  alternate <- rep(c(1,-1),200)
  
  # Create solutions to log factorials
  log.fac <- cumsum(log(1:q.lim))
  
  for (index in 1:length(begin.int)){  # Repeat for all clusters of short ISIs
    n.j <- begin.int[index]
    p.j <- length.int[index];
    subseq.RS <- numeric()
    if (p.j >= (l.min-1)){		 # Proceed only if there are enough spikes
      for (i in 0:(p.j-(l.min-1))){  # Repeat for all possible first spikes
        q <- l.min-2
        while (q < p.j-i){       # Repeat for increasing burst lengths
          q <- q+1
          rr <- seq(n.j+i, n.j+i+q-1)
          u <- sum(R[rr])
          u <- floor(u)
          # Calculate RS probability exactly, if q is small
          # or approximately if q is large
          if (q < q.lim){
            k <- seq(0,(u-q)/N,1)
            length.k <- length(k)
            mat1 <- matrix(rep(k,q), q, length.k, byrow=T)*N
            mat2 <- matrix(rep(0:(q-1), length.k), q, length.k)
            p <- exp((colSums(log(u - mat1 - mat2)) - 
                        log.fac[c(1,k[-1])] - log.fac[q-k]) - 
                       q*log(N))%*%alternate[1:length.k]                 
          }else{
            p <- pnorm((u-q*(N+1)/2)/sqrt(q*(N^2-1)/12));
          }
          RS <- -log(p)
          subseq.RS <- rbind(subseq.RS, c(RS,i,q))
        }
      }
      # Extract the highest rank surprise bursts that are non-overlapping 
      subseq.RS <- matrix(subseq.RS,ncol=3)
      if (length(subseq.RS) > 0){  
        subseq.RS <- subseq.RS[order(subseq.RS[ ,1], decreasing=T), ]
        while (length(subseq.RS) > 0){
          subseq.RS <- matrix(subseq.RS, ncol=3)
          current.burst <- subseq.RS[1, ]
          burst.RS <- rbind(burst.RS,current.burst[1])
          burst.start <- rbind(burst.start, n.j+current.burst[2])
          burst.length <- rbind(burst.length, current.burst[3]+1)
          subseq.RS <- subseq.RS[-1, ]
          if (length(subseq.RS) > 0){ 
            subseq.RS <- matrix(subseq.RS, ncol=3)  
            keep <- which(subseq.RS[ ,2] + subseq.RS[ ,3] - 1 < 
                            current.burst[2] | subseq.RS[,2] >  
                            current.burst[2] + current.burst[3] -1)
            subseq.RS=subseq.RS[keep, ]
          }
        }
      }
    }
  }
  
 
  # Convert length into end position and positions into times
  for (x in 1:length(RS.thresh)){
  above.thresh<-which(burst.RS>=RS.thresh[x])
  N.burst<-length(above.thresh)
  if (N.burst<1) {
    result<-NA
  } else {
    bursts<-cbind(burst.start[above.thresh], burst.length[above.thresh])
   bursts.ord<-cbind(bursts[ order(bursts[,1]),1], bursts[ order(bursts[,1]),2])
  beg<-bursts.ord[,1]
  len<-bursts.ord[,2]
  end<-beg+len-1
  start.times<-spike.train[beg]
  end.times<-spike.train[end]
  IBI<-c(NA, start.times[-1]-end.times[-N.burst])
  durn<-end.times-start.times
  mean.isis<-durn/(len-1)
  result<-cbind(beg=beg, end=end, IBI=IBI, len=len, durn=durn, mean.isis=mean.isis, SI=rep(RS.thresh[x], N.burst))
  }
  Results[[x]]<-result
  }
  } else {
    Results<-rep(list(NA), length(RS.thresh))
  }
  
  
  return(Results)
  
}


In [10]:
%%R
#ISI Rank threshold method WORKS
hennig.method<-function(st, cutoff.prob=0.05) {
  bursts<-NULL
  result<-NULL
  allisi<-diff(st)
  if (length(allisi)<1) {
    result<- NA
  } else {
  isi.rank<-rank(allisi) #1 is smallest ISI
  st.length<-ceiling(max(st))
  spike.counts<-NULL
  for (i in 0:st.length-1) {
    spike.counts[i+1]<-sum((st>=i)*(st<(i+1))) #calculate spike count of 1s intervals
  }
  sc.hist<-hist(spike.counts, nclass=200, plot=FALSE)
  p.dist<-1-cumsum(sc.hist$counts/sum(sc.hist$counts)) 
  cutoff.indx<-sum(p.dist>cutoff.prob)
  theta.c<-max(c(2, ceiling(sc.hist$mids[cutoff.indx]))) #set theta_C to value where probability of spike counts is equal to cutoff.index (default 0.05)
  theta.c.end<-theta.c*0.5 #cutoff to end a burst
  isi.rel.rank<-isi.rank/max(isi.rank) #calculate relative rank of each isi
  
  t<-st
  j<-1
  burst.on<-0
  bc<-1
  dt<-1
  burst.time<-NULL
  burst.end<-NULL
  burst.dur<-NULL
  burst.size<-NULL
  burst.beg<-NULL
  while (j<length(allisi)-theta.c) {
    if (burst.on==0 && isi.rel.rank[j]<0.5) { #burst begins when rank of isi<0.5
      if (t[j+theta.c]<t[j]+dt){ 
        burst.on<-1
        burst.time[bc]<-t[j]
        burst.beg[bc]<-j
        brc<-j
      }
    } else if (burst.on==1) { 
      if (t[j+theta.c.end]>t[j]+dt) {
        burst.end[bc]<-t[j]
        burst.dur[bc]<-t[j]-burst.time[bc]
        burst.size[bc]<-j-brc
        bc<-bc+1
        burst.on<-0
      }
    }
    j<-j+1
  }
  if (burst.on==1) {
    tmp<-t[j]-burst.time[bc]
    burst.end[bc]<-burst.time[bc]+tmp
    burst.dur[bc]<-t[j]-burst.time[bc]
    burst.size[bc]<-j-brc
    bc<-bc+1
  }
  N.burst<-length(burst.time)
  if (N.burst<1) {
    result<-NA
  } else {
  end<-burst.beg+burst.size
  IBI<-c(NA, burst.time[-1]-burst.end[-N.burst])
  len<-burst.size+1
  mean.isis<-burst.dur/(len-1)
  result<-cbind(beg=burst.beg, end=end, IBI=IBI, len=len, durn=burst.dur, mean.isis=mean.isis, SI=rep(1, N.burst))
  }
  }
  result
}

In [11]:
import rpy2.robjects as ro
from rpy2.robjects.packages import importr
from rpy2.robjects import pandas2ri
from rpy2.robjects.conversion import localconverter

from io import StringIO
import sys

#function converting output matrix to data frame

def Matrix2DF(mat):   
    r_columns = [c for c in mat.colnames]
    columns = [c.replace('.', '_')  for c in r_columns]
    df = {}
    for i,c in enumerate(columns):
        column =  mat.rx(True, r_columns[i])
        df[c] = [x for x in column]
    return pd.DataFrame(df)

## Helper functions that simplify calling the burst detection methods

def Henning(sp_times, cutoff=0.05):
    print('Rank Thresholding Method')
    %R -i sp_times
    %R -i cutoff
    %R bursts = hennig.method(sp_times, cutoff)
    burst_frame = %Rget bursts
    return Matrix2DF(burst_frame)

def RankSurprise(sp_times, threshold=5): #works sometimes, but is generally problematic!
    print('Rank Surprise Method')
    old_stdout = sys.stdout
    result = StringIO()
    sys.stdout = result
    %R -i sp_times
    %R -i threshold
    %R bursts = RS.method(sp_times, threshold)
    burst_frame = %Rget bursts
    print(burst_frame)
    sys.stdout = old_stdout   
    return pd.read_csv(StringIO(result.getvalue()), header=1, delim_whitespace = True)
    #Strange formatting issue on the output - had to parse PRINT output

    
def CMA(sp_times, includeBorders=True, minimum=3, plot=False):
    #print('Cummulative Moving Average')
    %R -i sp_times
    %R -i includeBorders
    %R -i minimum
    %R -i plot
    %R bursts = CMA.method(sp_times, includeBorders, minimum, plot)
    burst_frame = %Rget bursts
    return Matrix2DF(burst_frame)


def LogISI(sp_times, cutoff=0.1):
    #print('LogISI method')
    %R -i sp_times
    %R -i cutoff
    %R bursts = logisi.pasq.method(sp_times, cutoff)
    burst_frame = %Rget bursts
    return Matrix2DF(burst_frame)

def PoissonSurprise(sp_times, threshold=5):
    #print('Poisson Surprise Method')
    %R -i sp_times
    %R -i threshold
    %R bursts = PS.method(sp_times, threshold)
    burst_frame = %Rget bursts
    return Matrix2DF(burst_frame)

# Testing functions implemented above on a synthetic spiketrain.

In [12]:
#Henning(spike_times, 0.05)

In [13]:
#RankSurprise(spike_times)

In [14]:
#CMA(spike_times, True, 3, False)

In [15]:
#LogISI(spike_times, 0.1)

In [16]:
#PoissonSurprise(spike_times, 5)

In [17]:
# TODO: Parameterize this so that you can pick which burst detection method and what paramters to feed it

In [18]:
# This cell and the cells immediately below it are my (Marcus') code. Other than commenting out
# some unneeded lines, these cells are the only modifications I made to this script.
# In order to get these cells to run, you have to run all the cells above them first.
# -Marcus

# Imports
import src.SessionNavigator as SessionNavigator
import src.SessionProcessor as SessionProcessor
import numpy as np
import pickle as pkl

# Insert the file path to your data (e.g. "C:\Users\MickeyMouse\MouseBrainData\manifest.json")
data_root = "C:/Users/Demogorgon/Documents/College/Marcus/Boston University PhD/Ocker Lab"
manifest_path = f"{data_root}/AllenSDK_Data/manifest.json"
save_path = f"{data_root}/correlations_and_bursts/data"
SAVE = True

# Create a SessionNavigator, which will open and navigate the data in manifest_path
navigator = SessionNavigator.SessionNavigator(manifest_path)

# The SessionNavigator has simple search functions to find sessions with specific criteria.
# Sessions can be sorted by visual areas observed (acronyms), mouse genotype, and 
# session type (either "functional_connectivity" or "brain_observatory_1.1")
acronyms = ['VISp', 'VISl', 'VISal', 'VISrl', 'VISam', 'VISpm', 'LGd']
session_type = "functional_connectivity"
genotype = "wt/wt"
session_ids = navigator.find_sessions(acronyms, genotype=genotype, session_type=session_type)
#sessions = [navigator.load_session(session_id) for session_id in session_ids]

# Load the session and open a SessionProcessor
current_session = session_ids[0]

session = navigator.load_session(current_session)
processor = SessionProcessor.SessionProcessor(session)

# We'll look at just one stimulus type here
stim = 'drifting_gratings_contrast'
stim_table = session.get_stimulus_table(stim)
stim_presentation_ids = stim_table.index.values
#stim_presentation_ids = stim_presentation_ids[:10]
spike_trains = session.presentationwise_spike_times(stim_presentation_ids, processor.all_units)
num_units = len(processor.all_units)
#one_unit_trains = spike_trains[spike_trains['unit_id']==951032492]

In [23]:
# This cell is the "work horse." All the burst calculations are made here.
# The bursts are then sorted by stim presentation ID in the next cell

import warnings
warnings.filterwarnings('ignore')
import time
import math
import copy
import pandas as pd

# Cutoff parameter for LogISI
CUT_OFF = 0.1

# The start and end times for this particular stimulus epoch
# (we just want the spikes associated with the stimulus in "stim")
start_times = list(stim_table["start_time"])
start_time = start_times[0]
stop_times = list(stim_table["stop_time"])
stop_time = stop_times[-1]

# Get the start and end times for each stim presentation
presentation_start_times = list(stim_table['start_time'])
presentation_end_times = list(stim_table['stop_time'])
num_presentations = len(stim_presentation_ids)

# The entire spike trains of every unit
spike_times = session.spike_times

burst_calculation_start_time = time.time()
unit_update_count = 0
num_failed_units = 0
unitwise_burst_times = {}
unitwise_single_times = {}
for unit_id, unit_spikes in spike_times.items():
    
    # Trim the spike train so that only spikes associated with the stimulus of interest
    # are included, assuming one stimulus epoch for that stim type
    range_spikes = unit_spikes[(unit_spikes >= start_time)*(unit_spikes <= stop_time)] 
    try:
        #################################################
        # First
        # Try calculating bursts. For some difficult-to-trace reason, some cells cause
        # LogISI and the other burst detection functions to crash. Debugging RPY2 is 
        # really challenging, so this try-except block should be considdered a kludge
        bursts = LogISI(range_spikes, CUT_OFF)

        
        #################################################
        # Second, adjust indicies and add absolute beginning and end times
        
        # LogISI returns the indices of the first and last spike of every detected burst.
        # We need those indices, but we also need the time at which those spikes occured
        bursts.rename(columns={'beg':'beg_idx', 'end':'end_idx'}, inplace=True)
        bursts['beg_idx'] -= 1
        bursts['end_idx'] -= 1
        num_bursts = len(bursts)
        abs_burst_beg_time = np.zeros(num_bursts)
        abs_burst_end_time = np.zeros(num_bursts)
        
        for k in range(num_bursts):
            # bursts.loc[k, 'beg_idx'] -> the index of the time of the 
            # corresponding spike in range_spikes
            # e.g. range_spikes[idx] = 1643.43
            # the idx^th spike occurred 1643 seconds after the beginning of
            # the session
            abs_burst_beg_time[k] = range_spikes[int(bursts.loc[k, 'beg_idx'])]
            abs_burst_end_time[k] = range_spikes[int(bursts.loc[k, 'end_idx'])]
        
        bursts['absolute_beg_time'] = abs_burst_beg_time
        bursts['absolute_end_time'] = abs_burst_end_time
        
        
        #################################################
        # Third, separate bursts from non-bursts (refered to as single spikes or singles)
        begginning_indices = bursts["beg_idx"]
        ending_indices = bursts["end_idx"]
        burst_indices = []
        for idx in range(len(begginning_indices)):
            burst_index_range = range(int(begginning_indices[idx]), int(ending_indices[idx]+1))
            burst_indices += list(burst_index_range)
        
        # Collect every spike that wasn't counted as a member of a burst
        absolute_single_times = [spike_time for idx, spike_time in enumerate(range_spikes) if idx not in burst_indices]
        num_singles = len(absolute_single_times)        
        singles = pd.DataFrame({"absolute_spike_time": absolute_single_times})
        
        
        #################################################
        # Fourth, associate every burst and every single with a stimulus presentation id
        
        # burst_associated_presentation = np.zeros(num_bursts)
        # for presentation_idx in range(num_presentations):
        #     presentation_start_time = presentation_start_times[presentation_idx]
        #     presentation_end_time = presentation_end_times[presentation_idx]
        #     stim_presentation_id = stim_presentation_ids[presentation_idx]
        #     burst_range = bursts.loc[(bursts["absolute_beg_time"] >= presentation_start_time)*(bursts["absolute_end_time"] < presentation_end_time)]
            
        
        burst_associated_presentation = np.zeros(num_bursts)
        resume_idx = 0
        for time_idx in range(num_bursts):
            burst_start_time = abs_burst_beg_time[time_idx]
            for presentation_idx in range(resume_idx, num_presentations):
                presentation_start_time = presentation_start_times[presentation_idx]
                #presentation_end_time = presentation_end_times[presentation_idx]
                stim_presentation_id = stim_presentation_ids[presentation_idx]
                if burst_start_time >= presentation_start_time:# and burst_start_time < presentation_end_time:
                    burst_associated_presentation[time_idx] = stim_presentation_id
                else:
                    resume_idx = presentation_idx - 1
                    break
        
        bursts["stimulus_presentation_id"] = burst_associated_presentation
        
        single_associated_presentation = np.zeros(num_singles)
        resume_idx = 0
        for time_idx in range(num_singles):
            single_time = absolute_single_times[time_idx]
            for presentation_idx in range(resume_idx, num_presentations):
                presentation_start_time = presentation_start_times[presentation_idx]
                #presentation_end_time = presentation_end_times[presentation_idx]
                stim_presentation_id = stim_presentation_ids[presentation_idx]
                if single_time >= presentation_start_time:# and single_time < presentation_end_time:
                    single_associated_presentation[time_idx] = stim_presentation_id
                else:
                    resume_idx = presentation_idx - 1
                    break
        
        singles["stimulus_presentation_id"] = single_associated_presentation
        
        
        #################################################
        # Fifth, relativize the burst train and the single train
        # "If the burst start time is less than the stim end time, subtract the start time
        # else go to the next start and stop times"

        rel_burst_beg_time = np.zeros(num_bursts)
        rel_burst_end_time = np.zeros(num_bursts)
        for idx, row in bursts.iterrows():
            stim_presentation_id = int(row["stimulus_presentation_id"])
            stim_info = stim_table.loc[(stim_table.index==stim_presentation_id)]
            stim_beg_time = stim_info["start_time"]
            #stim_end_time = stim_info["stop_time"]
            rel_burst_beg_time[idx] = row["absolute_beg_time"] - stim_beg_time
            rel_burst_end_time[idx] = row["absolute_end_time"] - stim_beg_time
            
        bursts["relative_beg_time"] = rel_burst_beg_time
        bursts["relative_end_time"] = rel_burst_end_time
        
        rel_single_time = np.zeros(num_singles)
        for idx, row in singles.iterrows():
            stim_presentation_id = int(row["stimulus_presentation_id"])
            stim_info = stim_table.loc[(stim_table.index==stim_presentation_id)]
            stim_beg_time = stim_info["start_time"]
            rel_single_time[idx] = row["absolute_spike_time"] - stim_beg_time
        
        singles["relative_spike_time"] = rel_single_time
            
        
        #################################################
        # Sixth, store burst and singles     
        unitwise_burst_times[unit_id] = bursts
        unitwise_single_times[unit_id] = singles
        
    except AttributeError as e:
        # If an exception was caught, log it, store None, and keep going
        num_failed_units += 1
        unitwise_burst_times[unit_id] = None
        unitwise_single_times[unit_id] = None
    
    # Print an update every 100 cells (sometimes these can take a while, up to an hour on my machine)
    unit_update_count += 1
    if unit_update_count % 100 == 0:
        elapsed_time = time.time() - burst_calculation_start_time
        minutes = math.floor(elapsed_time/60)
        seconds = math.floor(elapsed_time%60)
        print(f"Time elapsed: {minutes}m {seconds}s. {unit_update_count}/{num_units} burst trains separated from single trains.")

elapsed_time = time.time() - burst_calculation_start_time
minutes = math.floor(elapsed_time/60)
seconds = math.floor(elapsed_time%60)
print(f"{unit_update_count-num_failed_units}/{num_units} burst trains separated from single trains. Total time elapsed: {minutes}m {seconds}s.")
print(f"{num_failed_units}/{num_units} separations could not be completed.")
        
if SAVE:
    print(f"Saving whole unit burst trains to:\n{save_path}/")#{stim}_whole_burst_trains__sesssion_{session_ids[0]}")
    print("...")
    with open(f"{save_path}/{stim}__whole_burst_trains__session_{current_session}.pkl", 'wb') as f:
        pkl.dump(unitwise_burst_times, f)
    print("Done")
    print(f"Saving whole unit single trains to:\n{save_path}/")
    print("...")
    with open(f"{save_path}/{stim}__whole_single_trains__session_{current_session}.pkl", 'wb') as f:
        pkl.dump(unitwise_single_times, f)
    print("Done")
else:
    print("Data not saved")

Time elapsed: 1m 8s. 100/784 burst trains separated from single trains.
Time elapsed: 2m 29s. 200/784 burst trains separated from single trains.
Time elapsed: 3m 48s. 300/784 burst trains separated from single trains.
Time elapsed: 5m 12s. 400/784 burst trains separated from single trains.
Time elapsed: 6m 40s. 500/784 burst trains separated from single trains.
Time elapsed: 8m 11s. 600/784 burst trains separated from single trains.
Time elapsed: 9m 35s. 700/784 burst trains separated from single trains.
708/784 burst trains separated from single trains. Total time elapsed: 11m 3s.
76/784 separations could not be completed.
Saving whole unit burst trains to:
C:/Users/Demogorgon/Documents/College/Marcus/Boston University PhD/Ocker Lab/correlations_and_bursts/data/
...
Done
Saving whole unit single trains to:
C:/Users/Demogorgon/Documents/College/Marcus/Boston University PhD/Ocker Lab/correlations_and_bursts/data/
...
Done


In [22]:
negative_count = 0
for unit_id, current_unit_bursts in unitwise_burst_times.items():
    if current_unit_bursts is not None and (current_unit_bursts < 0).any().any():
        negative_count += 1
        #print(current_unit_bursts.shape)
negative_count

0

In [20]:
# This sorts the whole spike trains by stimulus presentation.

# Get relevant times
num_presentations = len(stim_presentation_ids)
presentation_start_times = list(stim_table['start_time'])
presentation_end_times = list(stim_table['stop_time'])

# Make variables for reporting time
unit_update_count = 0
burst_sorting_start_time = time.time()

# Start sorting
fully_organized_bursts = {}
for unit_id in spike_times.keys():
    bursts = unitwise_burst_times[unit_id]
    
    if bursts is not None:
        current_unit_presentationwise_bursts = {}
        k = 0
        for stim_id in stim_presentation_ids:
            # Get the start and stop times
            current_beg_time = presentation_start_times[k]
            current_end_time = presentation_end_times[k]
            
            # Take out everything that happened before the beginning of this stim
            absolute_bursts = bursts.loc[(bursts["beg_time"]>=current_beg_time)]
            
            # Take out everything that happened after the end of this stim
            absolute_bursts = absolute_bursts.loc[(absolute_bursts["end_time"]<=current_end_time)]
            
            # beg_time and end_time are measured absolutely from the start of the session.
            # We also want to know the relative time of the bursts after the particular presentation
            relative_starts = absolute_bursts["beg_time"] - current_beg_time
            relative_ends = absolute_bursts["end_time"] - current_end_time
            
            # Name/rename everything appropriately
            absolute_bursts.rename(columns={'beg_time':'abs_beg_time', 'end_time':'abs_end_time'}, inplace=True)
            absolute_bursts["relative_beg_time"] = relative_starts
            absolute_bursts["relative_end_time"] = relative_ends
            
            # Store and increment
            current_unit_presentationwise_bursts[stim_id] = absolute_bursts
            k += 1
    else:
        current_unit_presentationwise_bursts = None
    
    fully_organized_bursts[unit_id] = current_unit_presentationwise_bursts
    
    # Print an update every 100 cells (sometimes these can take a while, up to an hour on my machine)
    unit_update_count += 1
    if unit_update_count % 100 == 0:
        elapsed_time = time.time() - burst_sorting_start_time
        minutes = math.floor(elapsed_time/60)
        seconds = math.floor(elapsed_time%60)
        print(f"Time elapsed: {minutes}m {seconds}s. {unit_update_count}/{num_units} burst trains sorted.")

if SAVE:
    print(f"Saving whole unit burst trains to:\n{save_path}/")#{stim}_whole_burst_trains__sesssion_{session_ids[0]}")
    print("...")
    with open(f"{save_path}/{stim}__sorted_burst_trains__session_{session_ids[0]}.pkl", 'wb') as f:
        pkl.dump(fully_organized_bursts, f)
    print("Done")
else:
    print("Data not saved")

Time elapsed: 0m 56s. 100/784 burst trains sorted.
Time elapsed: 1m 56s. 200/784 burst trains sorted.
Time elapsed: 2m 58s. 300/784 burst trains sorted.
Time elapsed: 3m 58s. 400/784 burst trains sorted.
Time elapsed: 5m 0s. 500/784 burst trains sorted.
Time elapsed: 6m 2s. 600/784 burst trains sorted.
Time elapsed: 7m 5s. 700/784 burst trains sorted.
Saving whole unit burst trains to:
C:/Users/Demogorgon/Documents/College/Marcus/Boston University PhD/Ocker Lab/correlations_and_bursts/data/
...
Done


In [59]:
for unit_id, current_unit_bursts in bursts_by_unit.items():
    for stim_id in stim_presentation_ids:
        if current_unit_bursts[stim_id] is not None:
            print(unit_id)
            break

KeyError: 3798

In [18]:
with open(f"{save_path}/{stim}_bursts__session_{session_ids[0]}.pkl", 'rb') as f:
    test_load = pkl.load(f)
#test_load