Skip to content

Commit

Permalink
smal restructurization of the project
Browse files Browse the repository at this point in the history
+ new docstrings and documentation updated
  • Loading branch information
dokato committed Aug 18, 2015
1 parent 1658e81 commit c58f136
Show file tree
Hide file tree
Showing 31 changed files with 639 additions and 157 deletions.
42 changes: 1 addition & 41 deletions connectivipy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,6 @@
from conn import conn_estim_dc
from mvarmodel import Mvar
from mvar.fitting import mvar_gen, mvar_gen_inst
from plot import plot_conn

__version__ = '0.34'

# plain plotting from values
def plot_conn(values, name='', fs=1, ylim=None, xlim=None, show=True):
'''
Plot connectivity estimation results. Allows to plot your results
without using *Data* class.
Args:
*values* : numpy.array
connectivity estimation values in shape (fq, k, k) where fq -
frequency, k - number of channels
*name* = '' : str
title of the plot
*fs* = 1 : 'int
sampling frequency
*ylim* = None : list
range of y-axis values shown, e.g. [0,1]
*None* means that default values of given estimator are taken
into account
*xlim* = None : list [from (int), to (int)]
range of y-axis values shown, if None it is from 0 to Nyquist frequency
*show* = True : boolean
show the plot or not
'''
fq, k, k = values.shape
fig, axes = plt.subplots(k, k)
freqs = np.linspace(0, fs//2, fq)
if not xlim:
xlim = [0, np.max(freqs)]
if not ylim:
ylim = [np.min(values), np.max(values)]
for i in xrange(k):
for j in xrange(k):
axes[i, j].fill_between(freqs, values[:, i, j], 0)
axes[i, j].set_xlim(xlim)
axes[i, j].set_ylim(ylim)
plt.suptitle(name,y=0.98)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
if show:
plt.show()
65 changes: 54 additions & 11 deletions connectivipy/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def spectrum(acoef, vcoef, fs=1, resolution=100):
S_z[e] = np.dot(np.dot(H_z[e],vcoef), H_z[e].T.conj())
return A_z, H_z, S_z

def spectrum_inst(acoef, vcoef, fs=1, resolution=None):
def spectrum_inst(acoef, vcoef, fs=1, resolution=100):
"""
Generating data point from matrix *A* with MVAR coefficients taking
into account zero-lag effects.
Expand Down Expand Up @@ -99,10 +99,7 @@ def spectrum_inst(acoef, vcoef, fs=1, resolution=None):
Int. J. Bioelectromagn. 11, 74–79 (2009).
"""
p, k, k = acoef.shape
if resolution == None:
freqs=np.linspace(0,fs/2,512)
else:
freqs=np.linspace(0,fs/2,resolution)
freqs=np.linspace(0,fs/2,resolution)
A_z=np.zeros((len(freqs),k,k),complex)
B_z=np.zeros((len(freqs),k,k),complex)

Expand Down Expand Up @@ -133,7 +130,8 @@ class Connect(object):
__metaclass__ = ABCMeta

def __init__(self):
self.values_range = [None, None]
self.values_range = [None, None] # normalization bands
self.two_sided = False # only positive, or also negative values

@abstractmethod
def calculate(self):
Expand All @@ -154,11 +152,13 @@ def short_time(self, data, nfft=None, no=None, **params):
*no* = None : int
overlap length (if None it's N/10)
*params* :
additional parameters
additional parameters specific for chosen estimator
Returns:
*A_z* : numpy.array
z-transformed A(f) complex matrix in shape (*resolution*, k, k)
*stvalues* : numpy.array
short time values (time points, frequency, k, k), where k
is number of channels
"""
assert nfft>no, "overlap must be smaller than window"
if len(data.shape)>2:
k, N, trls = data.shape
else:
Expand Down Expand Up @@ -187,6 +187,7 @@ def short_time(self, data, nfft=None, no=None, **params):

def short_time_significance(self, data, Nrep=10, alpha=0.05,\
nfft=None, no=None, **params):
assert nfft>no, "overlap must be smaller than window"
if len(data.shape)>2:
k, N, trls = data.shape
else:
Expand Down Expand Up @@ -217,10 +218,16 @@ def significance(self, data, Nrep=10, alpha=0.05, **params):
return signific

def levels(self, signi, alpha, k):
ficance = np.zeros((k,k))
if self.two_sided:
ficance = np.zeros((2, k, k))
else:
ficance = np.zeros((k, k))
for i in range(k):
for j in range(k):
ficance[i][j] = np.max(st.scoreatpercentile(signi[:,:,i,j], alpha*100, axis=1))
if self.two_sided:
ficance[i][j] = np.max(st.scoreatpercentile(signi[:,:,i,j], alpha*100, axis=1))
else:
ficance[i][j] = np.max(st.scoreatpercentile(signi[:,:,i,j], alpha*100, axis=1))
return ficance

def __calc_multitrial(self, data, **params):
Expand Down Expand Up @@ -281,6 +288,33 @@ def __init__(self):

def short_time(self, data, nfft=None, no=None, mvarmethod='yw',\
order=None, resol=None, fs=1):
"""
It overloads :class:`ConnectAR` method :func:`Connect.short_time`.
Short-tme version of estimator, where data is windowed into parts
of length *nfft* and overlap *no*. *params* catch additional
parameters specific for estimator.
Args:
*data* : numpy.array
data matrix
*nfft* = None : int
window length (if None it's N/5)
*no* = None : int
overlap length (if None it's N/10)
*mvarmethod* = 'yw' :
MVAR parameters estimation method
*order* = None:
MVAR model order; it None, it is set automatically basing
on default criterion.
*resol* = None:
frequency resolution; if None, it is 100.
*fs* = 1 :
sampling frequency
Returns:
*stvalues* : numpy.array
short time values (time points, frequency, k, k), where k
is number of channels
"""
assert nfft>no, "overlap must be smaller than window"
if len(data.shape)>2:
k, N, trls = data.shape
else:
Expand Down Expand Up @@ -311,6 +345,7 @@ def short_time(self, data, nfft=None, no=None, mvarmethod='yw',\
def short_time_significance(self, data, Nrep=100, alpha=0.05, method='yw',\
order=None, fs=1, resolution=None,\
nfft=None, no=None, **params):
assert nfft>no, "overlap must be smaller than window"
if len(data.shape)>2:
k, N, trls = data.shape
else:
Expand Down Expand Up @@ -725,6 +760,7 @@ def calculate(self, data, cnfft=None, cno=None, window=np.hanning, im=False):
.. [1] M. B. Priestley Spectral Analysis and Time Series.
Academic Press Inc. (London) LTD., 1981
"""
assert cnfft>cno, "overlap must be smaller than window"
k, N = data.shape
if not cnfft:
cnfft = int(N/5)
Expand Down Expand Up @@ -759,6 +795,9 @@ class PSI(Connect):
PSI - class inherits from :class:`Connect` and overloads
:func:`Connect.calculate` method.
"""
def __init__(self):
self.two_sided = True

def calculate(self, data, band_width=4, psinfft=None, psino=0, window=np.hanning):
"""
Phase Slope Index calculation using FFT mehtod.
Expand All @@ -781,6 +820,7 @@ def calculate(self, data, band_width=4, psinfft=None, psino=0, window=np.hanning
.. [1] Nolte G. et all, Comparison of Granger Causality and
Phase Slope Index. 267–276 (2009).
"""
assert psinfft>psino, "overlap must be smaller than window"
k, N = data.shape
if not psinfft:
psinfft = int(N/4)
Expand All @@ -798,6 +838,9 @@ class GCI(Connect):
GCI - class inherits from :class:`Connect` and overloads
:func:`Connect.calculate` method.
"""
def __init__(self):
self.two_sided = True

def calculate(self, data, method='yw', order=None):
"""
Granger Causality Index calculation from MVAR model.
Expand Down
2 changes: 1 addition & 1 deletion connectivipy/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def resample(self, fs_new):
self.__data = ss.resample(self.__data, new_nr_samples, axis=1)
self.__fs = fs_new

def fit_mvar(self, p = None, method = 'yw'):
def fit_mvar(self, p=None, method='yw'):
'''
Fitting MVAR coefficients.
Expand Down
5 changes: 0 additions & 5 deletions connectivipy/mvar/orders.py

This file was deleted.

46 changes: 46 additions & 0 deletions connectivipy/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
#! /usr/bin/env python

import numpy as np
import matplotlib.pyplot as plt

# plain plotting from values
def plot_conn(values, name='', fs=1, ylim=None, xlim=None, show=True):
'''
Plot connectivity estimation results. Allows to plot your results
without using *Data* class.
Args:
*values* : numpy.array
connectivity estimation values in shape (fq, k, k) where fq -
frequency, k - number of channels
*name* = '' : str
title of the plot
*fs* = 1 : 'int
sampling frequency
*ylim* = None : list
range of y-axis values shown, e.g. [0,1]
*None* means that default values of given estimator are taken
into account
*xlim* = None : list [from (int), to (int)]
range of y-axis values shown, if None it is from 0 to Nyquist frequency
*show* = True : boolean
show the plot or not
'''
fq, k, k = values.shape
fig, axes = plt.subplots(k, k)
freqs = np.linspace(0, fs//2, fq)
if not xlim:
xlim = [0, np.max(freqs)]
if not ylim:
ylim = [np.min(values), np.max(values)]
for i in xrange(k):
for j in xrange(k):
axes[i, j].fill_between(freqs, values[:, i, j], 0)
axes[i, j].set_xlim(xlim)
axes[i, j].set_ylim(ylim)
plt.suptitle(name,y=0.98)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
if show:
plt.show()

0 comments on commit c58f136

Please sign in to comment.