# Sleep Stage Transition Analysis

This Jupyter notebook runs simulations of the thalamocortical network model described in the paper "Cellular and neurochemical basis of sleep stages in the thalamocortical network" by Krishnan et al. It guides you through setting up the environment, running the simulation, loading the output data, and analyzing the transitions between sleep stages.

## Model Source Code

Cybershuttle will help you run the simulation code on remote HPC resources. If you prefer, you can also build and run the code locally. The simulation code is implemented in C++ and uses OpenMP for parallelism, so you will need a C++ compiler with OpenMP support.

### Cloning the Repository

Clone the repository to your local machine using the following command:
```bash
git clone https://github.com/bazhlab-ucsd/sleep-stage-transition.git
```

### Building the Project

Navigate to the project directory and compile the code:
```bash
cd sleep-stage-transition
make network
make run
```

### Running Simulations

The simulation parameters can be modified in the `params.txt` file. Adjust this file to set different levels of neuromodulators and other parameters. Run the simulation with:
```bash
make run
```



## Running the simulation code using Cybershuttle (Recommended)

Cybershuttle's Cybertune library simplifies parameter-sweep simulations by automating input file transfer, remote execution, and output retrieval. We will first run a single simulation end to end.

In [None]:
from cybershuttle_tune.sdk import ExecutionContext
from cybershuttle_tune.sdk import ApplicationContext
from cybershuttle_tune.sdk import TuneConfig
from cybershuttle_tune.sdk import DiscreteParam
from cybershuttle_tune.sdk import run_grid_search
from cybershuttle_tune.sdk import get_sweep_status
from cybershuttle_tune.sdk import fetch_outputs
from cybershuttle_tune.sdk import authorize
from cybershuttle_tune.cli.auth import get_access_token
import os
import json
from pathlib import Path

In [None]:
# Review Input Parameters
! cat inputs/params.txt

### Execute a single simulation on Remote HPC Resource

In [None]:
# Configure input parameters 
params = [
    DiscreteParam('ha_awake', [0.6]),
    DiscreteParam('ach_cx_awake', [0.6]),
    DiscreteParam('ach_th_awake', [0.6]),
]

# Run the simulation on the ACCESS-CI resource Expanse, the San Diego Supercomputer Center's cluster
execution_context = ExecutionContext(resource = "Expanse", 
                                     project = "Default Project", 
                                     group_resource_profile = "Default", 
                                     cpu = 1, 
                                     memory = 1000, 
                                     queue = "shared")

input_file_mapping = {"Network Config File":"network.cfg", "Param File": "params.txt"}
app_context = ApplicationContext(app_name = "Sleep-Stage-Transition", input_dir = "inputs", input_file_mapping = input_file_mapping)

tune_config = TuneConfig(
    app_context = app_context,
    params = params, 
    execution_context = execution_context)

In [None]:
response = run_grid_search(tune_config = tune_config)

In [None]:
states, indexes = get_sweep_status(response['job_name'], response['working_dir'])
states

In [None]:
output_paths = fetch_outputs(response['job_name'], response['working_dir'])

In [None]:
output_paths

## Loading and Processing Output Data

After running the simulation, the output files will be saved in the `out` directory. These files contain the membrane voltage and other relevant data of different neuron types.

For simulations run through Cybershuttle, the Cybershuttle local agent can transparently fetch these output files.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

### Functions

#### Simple Plot Function
This function visualizes the simulation output of the thalamocortical network model. It creates image plots of the membrane potentials of cortical, thalamic, and reticular neurons, as well as time-series plots of specific neurons' membrane potentials.

In [None]:
def gen_cx_tc_re_imageplot(spath):
    cx = np.loadtxt(spath + 'time_cx')
    tc = np.loadtxt(spath + 'time_tc')
    re = np.loadtxt(spath + 'time_re')

    plt.figure(figsize=(10,5))
    plt.subplot(3,2,1)
    plt.imshow(cx[:,1:-1].T, aspect='auto',vmin=-80,vmax=-50)

    plt.subplot(3,2,3)
    plt.imshow(tc[:,1:-1].T, aspect='auto',vmin=-80,vmax=-50)

    plt.subplot(3,2,5)
    plt.imshow(re[:,1:-1].T, aspect='auto',vmin=-80,vmax=-50)

    plt.subplot(3,2,2)
    plt.plot(cx[:,200])

    plt.subplot(3,2,4)
    plt.plot(tc[:,50])

    plt.subplot(3,2,6)
    plt.plot(re[:,50])
    return cx,tc,re

#### Compute Mean as Local Field Potential (LFP)
This function computes an LFP proxy from the membrane potentials of the cortical neurons by averaging across cells at each time step, giving a view of the overall electrical activity in the simulated brain region.

In [None]:
def compute_mean_as_lfp(time_cx):
    # Average across cells (axis=1) to get one LFP value per time step.
    # Columns 1 to -1 are used, assuming column 0 is the cell number and the last column is predefined.
    return np.mean(time_cx[:, 1:-1], axis=1)
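As a quick check of the averaging step, the sketch below builds a small synthetic array in the column layout described in the comment above and averages across cells for each time step. The array `fake_cx` and all of its values are made up for illustration; they are not simulation output.

```python
import numpy as np

# Hypothetical data: 4 time steps (rows). Column 0 and the last column are
# placeholders that are excluded; columns 1-3 stand in for cell voltages (mV).
fake_cx = np.array([
    [0.0, -60.0, -62.0, -64.0, 0.0],
    [1.0, -70.0, -70.0, -70.0, 0.0],
    [2.0, -50.0, -55.0, -60.0, 0.0],
    [3.0, -65.0, -65.0, -62.0, 0.0],
])

# Average over cells (axis=1) to get one LFP-like value per time step.
lfp = np.mean(fake_cx[:, 1:-1], axis=1)
print(lfp.shape)
print(lfp)
```

Without `axis=1`, `np.mean` collapses the whole array to a single number, which loses the time course needed for the spectral analyses later in the notebook.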

#### Compute FFT
The function `fftc` applies a Fast Fourier Transform (FFT) to the LFP signal and returns the power, phase, and frequency values below the cutoff `h_freq`. The frequency content can reveal the underlying neuronal dynamics and oscillatory activity.

In [None]:
from scipy import signal

def fftc(data, fs, pad, h_freq):
    # Usage: power, phase, frequency = fftc(data, fs, pad, h_freq)
    #   data   --- 1-D np array
    #   fs     --- sampling rate (Hz)
    #   pad    --- FFT length; currently ignored and reset to len(data) below
    #   h_freq --- keep only frequencies below this value (Hz)
    # need to load numpy as np

    pad = data.shape[0]
    fft_out = np.fft.fft(data, n=pad)

    # Frequencies of the non-negative half of the spectrum: bin k is k*fs/pad.
    # Integer division is required here; a float length raises a TypeError.
    frequency = np.arange(pad // 2) * fs / pad

    power = np.absolute(fft_out)
    phase = np.angle(fft_out)

    # Keep bins below h_freq and normalize the magnitude by the number of samples
    h_freq_bin = np.where(frequency < h_freq)
    power = power[h_freq_bin] / data.shape[0]
    phase = phase[h_freq_bin]
    frequency = frequency[h_freq_bin]

    return power, phase, frequency
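To illustrate what the spectrum should look like for a known input, here is a minimal, self-contained sketch that applies NumPy's real-input FFT to a synthetic 10 Hz sine wave. It uses `np.fft.rfft` and `np.fft.rfftfreq` rather than `fftc`, and the sampling rate and signal are assumptions chosen for the example.

```python
import numpy as np

fs = 1000.0                        # sampling rate in Hz (assumed)
t = np.arange(2000) / fs           # 2 s of samples
x = np.sin(2 * np.pi * 10.0 * t)   # synthetic 10 Hz oscillation

spectrum = np.fft.rfft(x)                     # non-negative frequencies only
freqs = np.fft.rfftfreq(x.size, d=1.0 / fs)   # bin k corresponds to k * fs / N
amplitude = np.abs(spectrum) / x.size         # same normalization as fftc

peak_freq = freqs[np.argmax(amplitude)]
print(peak_freq)
```

The peak should land on the bin at 10 Hz, which is a convenient sanity check before applying the same analysis to simulated LFP data.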

### Morlet Spectrogram

This function performs time-frequency analysis on a signal using the Morlet wavelet. The implementation is based on the textbook 'The Illustrated Wavelet Transform Handbook' (Paul Addison), pp. 33ff., and the paper 'Comparison of the Hilbert transform and wavelet methods...' (Le Van Quyen, 2001).

In [None]:
def morlet_wav(x, srate, sigma, flo, fhi, deltaf):
    N_orig = len(x)
    #zero-pad x so that the number of entries is a power of 2, so that the fft will be computationally efficient
    N=int( 2**(  np.ceil(  np.log(N_orig) / np.log(2)  )  )  )
    x=np.concatenate([x,np.zeros(N-len(x))])
    Xk=np.fft.fft(x)

    #figure out number of total frequency values at which you will be sampling
    #for the time-frequency analysis, and allocate space in 'Transform' (first
    #row of 'Transform' contains the power as a function of time for the lowest frequency
    freqvals=np.arange(flo,fhi+deltaf,deltaf)
    num_freqvals=len(freqvals)
    Transform=np.zeros((num_freqvals,N), dtype=complex)

    freq_samples=srate*np.arange(-N/2,N/2)/N #construct array of frequency values at which you sample the Fourier Transform of the wavelet function (Addison Eq. 2.38); don't need '-1' (as in Matlab code) bc. of how arange works; also, can assume N is divisible by 2 because of above

    for i_f, freq in enumerate(freqvals):
        #construct fourier transform of the Morlet wavelet in such a form that we
        #can use Eq. 2.35 (p. 33, Addison) along with iFFT to determine Transform
        #for specific frequency band. Note that my normalization is not the
        #same as in Addison's textbook.
        W = np.sqrt(2*np.pi)*sigma*np.exp(-2*np.pi**2*sigma**2*(freq_samples-freq)**2)
        Transform[i_f:i_f+1, :] = np.fft.ifft(Xk * np.fft.ifftshift(W))

    #throw away the part of Transform that corresponded to zero-padded portion of 'x'
    #(Python indexing is 0-based, so the original samples are columns 0..N_orig-1)
    Transform=Transform[:,:N_orig]
    #compute phases and modulus 
    Phases = np.arctan2(np.imag(Transform), np.real(Transform))
    Modulus = np.abs(Transform)

    return Modulus, Phases, Transform
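The core of the method is a Gaussian window in the frequency domain, multiplied with the signal's FFT and transformed back. The sketch below isolates that step for a single frequency band and applies it to a synthetic signal that switches from 5 Hz to 20 Hz. The helper `band_modulus`, the sampling rate, and the test signal are hypothetical and only meant to show the idea; the Gaussian uses the same expression as `morlet_wav`.

```python
import numpy as np

def band_modulus(x, srate, sigma, freq):
    """Modulus of the Morlet-filtered signal at one center frequency."""
    n_orig = len(x)
    n = int(2 ** np.ceil(np.log2(n_orig)))           # zero-pad to a power of 2
    xk = np.fft.fft(x, n=n)
    # fftfreq returns frequencies in FFT order, so no ifftshift is needed.
    freq_samples = np.fft.fftfreq(n, d=1.0 / srate)
    w = np.sqrt(2 * np.pi) * sigma * np.exp(
        -2 * np.pi**2 * sigma**2 * (freq_samples - freq) ** 2
    )
    return np.abs(np.fft.ifft(xk * w))[:n_orig]      # drop the padded part

srate = 200.0                         # Hz (assumed)
t = np.arange(800) / srate            # 4 s of samples
# 5 Hz during the first 2 s, 20 Hz afterwards
x = np.where(t < 2.0, np.sin(2 * np.pi * 5 * t), np.sin(2 * np.pi * 20 * t))

sigma = 0.25                          # Gaussian width in seconds
m5 = band_modulus(x, srate, sigma, 5.0)
m20 = band_modulus(x, srate, sigma, 20.0)

first, second = t < 1.5, t > 2.5
print(m5[first].mean(), m5[second].mean())
print(m20[first].mean(), m20[second].mean())
```

The 5 Hz band should dominate in the first part of the signal and the 20 Hz band in the second, which is the time-frequency localization the full spectrogram relies on.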

### Plot results from one simulation

In [None]:
print(output_paths[0])
for outputfile in output_paths:
    out_tar_dir = outputfile + "/" + response['job_name'] + '_' + outputfile.split('/')[-2]
    out_tar_file_path = out_tar_dir + '/' + 'output.tar.gz'
    print(out_tar_dir)
    print(out_tar_file_path)

import tarfile
out_tar_file = tarfile.open(out_tar_file_path) 
out_tar_file.extractall(out_tar_dir) 
out_tar_file.close()
output_dir = out_tar_dir + '/' + 'output/'

cx,tc,re = gen_cx_tc_re_imageplot(output_dir)
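The archive handling above can also be written with context managers, so that the file is closed even if extraction fails. The following is a minimal sketch on a throwaway archive built in a temporary directory; the file name and contents are placeholders, not real simulation output.

```python
import os
import tarfile
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    # Build a tiny stand-in archive (hypothetical contents).
    src = os.path.join(tmp, "output")
    os.makedirs(src)
    with open(os.path.join(src, "time_cx"), "w") as fh:
        fh.write("0 -65.0 -64.8 0\n")
    archive = os.path.join(tmp, "output.tar.gz")
    with tarfile.open(archive, "w:gz") as tar:
        tar.add(src, arcname="output")

    # Extract it; the context manager closes the archive automatically.
    dest = os.path.join(tmp, "extracted")
    os.makedirs(dest)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest)

    extracted_files = sorted(os.listdir(os.path.join(dest, "output")))
    print(extracted_files)
```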

## FFT spectrogram

In [None]:
seltime=np.arange(10000,15000)

plt.subplot(2, 1, 1)
avg=signal.detrend(np.mean(cx[seltime,1:-1],axis=1))
plt.plot(avg)

plt.subplot(2, 1, 2)
power, phase, frequency = fftc(avg, 1000, 4, 200 )
plt.plot(frequency,(power))
plt.xlim(0, 20)
plt.ylim(0, 2000)

## Wavelet Spectrogram

Perform detailed time-frequency analysis of a signal, revealing the dynamics of its frequency components over time.

In [None]:
flo = 1
fhi = 100
deltaf = 0.1
freqvals=np.arange(flo,fhi+deltaf,deltaf)

dt=0.025 #ms
sigma = 1.0 #width of gaussian window (in seconds) for frequency-time analysis
cut_start=1000 #number of milliseconds to cut out of beginning
cut_end=1000 #number of milliseconds to cut out of end
dsample=100 #downsample by factor 'dsample'

temp= avg#np.loadtxt('lfp_nhost=10.txt')
data=temp[0:len(temp):dsample] #downsample data
time=dsample*dt*np.arange(0,len(data))
srate = 1000/(dsample*dt) #Hz

Modulus, Phases, Transform = morlet_wav(data,srate,sigma,flo,fhi,deltaf)

plt.pcolormesh(time[round(cut_start/(dsample*dt)):len(time)-round(cut_end/(dsample*dt))], freqvals, Modulus[:,round(cut_start/(dsample*dt)):len(time)-round(cut_end/(dsample*dt))], rasterized=True, cmap='jet')
plt.xlabel('Time (ms)')
plt.ylabel('Frequency (Hz)')
plt.colorbar()
#plt.clim((0,250))
plt.xlim([10000, 360000])
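The sampling-rate arithmetic in the cell above is easy to get wrong, so here it is spelled out as a short sketch. It reuses the values of `dt`, `dsample`, and `cut_start` from that cell; the variable names `step_ms`, `nyquist`, and `first_kept` are introduced here for illustration.

```python
dt = 0.025        # time step in ms
dsample = 100     # keep every 100th sample
cut_start = 1000  # ms to drop from the beginning

step_ms = dsample * dt                   # time between retained samples, in ms
srate = 1000 / step_ms                   # samples per second (Hz)
nyquist = srate / 2                      # highest representable frequency (Hz)
first_kept = round(cut_start / step_ms)  # index of the first sample after the cut

print(step_ms, srate, nyquist, first_kept)
```

The Nyquist frequency bounds the usable range of `fhi`, and `first_kept` shows that the signal must contain more samples than the two cuts remove for the plot to show any data.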

# Scale up Simulations

## Parameterizing histamine (HA) and acetylcholine (ACh)

We will run simulations that vary the level of histamine (HA) during the awake state in the thalamocortical network model, while keeping the awake-state levels of acetylcholine (ACh) in the cortical neurons (cx) and thalamic neurons (th) fixed.
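A grid search conventionally runs one simulation per combination of the listed parameter values. The sketch below shows that expansion with `itertools.product`. The parameter names come from this notebook, but the expansion logic is an assumption about what a grid search does in general, not a description of Cybertune's internals.

```python
from itertools import product

# Candidate values per parameter (a shortened, illustrative grid).
grid = {
    "ha_awake": [0.3, 0.6, 0.9],
    "ach_cx_awake": [0.6],
    "ach_th_awake": [0.6],
}

names = list(grid)
# One dict per point of the Cartesian product of all value lists.
combinations = [dict(zip(names, values)) for values in product(*grid.values())]

for combo in combinations:
    print(combo)
```

Because the number of runs is the product of the list lengths, adding a second multi-valued parameter multiplies the job count rather than adding to it.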

In [None]:
# Launch 10 simulations, sweeping the histamine level from 0.3 to 1.2

params = [
    DiscreteParam('ha_awake', [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]),
    DiscreteParam('ach_cx_awake', [0.6]),
    DiscreteParam('ach_th_awake', [0.6]),
]

# Run the simulations on the ACCESS-CI resource Expanse, the San Diego Supercomputer Center's cluster
execution_context = ExecutionContext(resource = "Expanse", 
                                     project = "Default Project", 
                                     group_resource_profile = "Default", 
                                     cpu = 1, 
                                     memory = 1000, 
                                     queue = "shared")

input_file_mapping = {"Network Config File":"network.cfg", "Param File": "params.txt"}
app_context = ApplicationContext(app_name = "Sleep-Stage-Transition", input_dir = "inputs", input_file_mapping = input_file_mapping)

tune_config = TuneConfig(
    app_context = app_context,
    params = params, 
    execution_context = execution_context)

In [None]:
# Launch the parameter sweep runs on remote HPC resources
response = run_grid_search(tune_config = tune_config)

In [None]:
# Query the Cybershuttle server to fetch the job status. Wait until all statuses are complete
states, indexes = get_sweep_status(response['job_name'], response['working_dir'])
states
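Rather than re-running the status cell by hand, the wait can be wrapped in a small polling loop. The sketch below is generic: `wait_for_sweep` is a hypothetical helper that takes any callable returning a list of states, and the state names used in the demonstration are made up. The actual return format and state names of `get_sweep_status` are not documented here, so adapt the `is_done` check to what it reports.

```python
import time

def wait_for_sweep(get_states, is_done, interval_s=30.0, max_checks=120):
    """Call get_states() until every state satisfies is_done, or give up."""
    for _ in range(max_checks):
        states = get_states()
        if all(is_done(s) for s in states):
            return states
        time.sleep(interval_s)
    raise TimeoutError("sweep did not finish within the allotted checks")

# Demonstration with a fake status source (hypothetical state names).
fake_history = iter([
    ["RUNNING", "RUNNING"],
    ["COMPLETED", "RUNNING"],
    ["COMPLETED", "COMPLETED"],
])
final_states = wait_for_sweep(
    lambda: next(fake_history),
    lambda s: s == "COMPLETED",
    interval_s=0.0,
)
print(final_states)
```

In the notebook, `get_states` could be a lambda that calls `get_sweep_status` with the job name and working directory and returns its first element, assuming that element is the list of states.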

In [None]:
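# get_remote_data_dirs is not imported in the import cell above; import it before running this cell
# (it is presumably exposed by cybershuttle_tune.sdk, which is an assumption).
remote_output_paths = get_remote_data_dirs(response['job_name'], response['working_dir'])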

In [None]:
remote_output_paths

### Remote Cell Execution

#### Load remote cell execution plugin

In [None]:
import sys
sys.path.append('/')
import airavata_magics
airavata_magics.load_ipython_extension(get_ipython())

#### Initialize the remote execution agent. Provide the computation requirements and target cluster name

In [None]:
%init_remote cluster=expanse cpu=2 memory=2024 queue=shared walltime=60

#### Wait for agent to come online

In [None]:
%status_remote

In [None]:
%status_remote

#### Run basic computations

In [None]:
%%run_remote
# Your code here
a = 10
print(a)

#### Run more complex analytics

In [None]:
%%run_remote
import matplotlib.pyplot as plt

# Sample data
x = [1, 2, 3, 4, 5]
y = [2, 3, 5, 7, 11]

# Create the plot
plt.plot(x, y)

# Add a title and labels
plt.title('Simple Line Plot')
plt.xlabel('X Axis')
plt.ylabel('Y Axis')

# Show the plot
plt.show()

#### Run command line 

In [None]:
%%run_remote
!ls /home/

#### Terminate the agent once the computations are completed

In [None]:
%terminate_remote