Commit

… into placeholder

changsiyao committed Dec 12, 2015
2 parents 40788a2 + af5c9ba commit 3ec9533
Showing 7 changed files with 156 additions and 40 deletions.
27 changes: 24 additions & 3 deletions code/scripts/logistic.py
@@ -2,13 +2,24 @@
import matplotlib.pyplot as plt
from sklearn.cross_validation import cross_val_score
from sklearn.linear_model import LogisticRegression
import sys

# Path to function
pathtofunction = '../utils'
# Append path to sys
sys.path.append(pathtofunction)

from logistic_function import create_confusion, getMin_thrs, plot_roc

pathtofolder = '../../data/'

nsub = 16
beh_lambda = np.array([])
beh_score = np.array([])
val_score = np.array([])
Min_thrs = np.array([])
AUC_smr = np.array([])
fig = plt.figure(figsize=(20,20))
for i in np.arange(1, nsub+1):
run1 = np.loadtxt(pathtofolder + 'ds005/sub0'+ str(i).zfill(2)+
'/behav/task001_run001/behavdata.txt', skiprows = 1)
@@ -18,7 +29,7 @@
'/behav/task001_run003/behavdata.txt', skiprows = 1)
behav = np.concatenate((run1, run2, run3), axis=0)
behav = behav[np.logical_or.reduce([behav[:,5] == x for x in [0,1]])]
X = zip(np.ones(len(behav)), behav[:, 1], behav[:, 2])
y = behav[:, 5]
logreg = LogisticRegression(C=1e5)
# C is the inverse regularization strength; C=1e5 means very weak regularization
@@ -34,8 +45,18 @@
scores = cross_val_score(LogisticRegression(), X, y,
scoring='accuracy', cv=10)
val_score = np.append(val_score, scores.mean())

# calculate the AUC and plot ROC curve for each subject
logreg_proba = logreg.predict_proba(X)
confusion = create_confusion(logreg_proba, y)
addsub = fig.add_subplot(4, 4, i)
addsub, AUC = plot_roc(confusion, addsub, i)
Min_thrs = np.append(Min_thrs, getMin_thrs(confusion))
AUC_smr = np.append(AUC_smr, AUC)

np.savetxt(pathtofolder + 'ds005/models/lambda.txt', beh_lambda)
np.savetxt(pathtofolder + 'ds005/models/reg_score.txt', beh_score)
np.savetxt(pathtofolder + 'ds005/models/cross_val_score.txt', val_score)
np.savetxt(pathtofolder + 'ds005/models/Min_thrs.txt', Min_thrs.reshape(16,3))
np.savetxt(pathtofolder + 'ds005/models/AUC_smr.txt', AUC_smr)
fig.savefig(pathtofolder + 'ds005/models/roc_curve')

101 changes: 101 additions & 0 deletions code/utils/logistic_function.py
@@ -0,0 +1,101 @@
import numpy as np
import matplotlib.pyplot as plt

def create_confusion(logreg_proba, y, thrs_inc=0.01):
"""
Creates the confusion matrix based on various levels of discriminate
probability thresholds
Parameters
----------
actual: Actual responses, 1-d array with values 0 or 1
fitted: Fitted probabilities, 1-d array with values between 0 and 1
thrs_inc: increment of threshold probability (default 0.05)
Returns
-------
Confusion Matrix : Array of dim (X, 5) where X is the number of different
thresholds
Column 1: Threshold value between 0, 1
Columns 2-5 show counts for:
Column 2: True postive
Column 3: True negative
Column 4: False postive
Column 5: False negative
"""
thrs_array = np.linspace(0, 1, int(1/thrs_inc) + 1)
confusion = np.ones((len(thrs_array), 5))
confusion[:,0] = thrs_array
for i in range(len(thrs_array)):
t = thrs_array[i]
# Classifier / label agree and disagreements for current threshold.
TP_t = np.logical_and( logreg_proba[:,1] > t, y==1 ).sum()
TN_t = np.logical_and( logreg_proba[:,1] <=t, y==0 ).sum()
FP_t = np.logical_and( logreg_proba[:,1] > t, y==0 ).sum()
FN_t = np.logical_and( logreg_proba[:,1] <=t, y==1 ).sum()
confusion[i, 1:5] = [TP_t, TN_t, FP_t, FN_t]
return confusion
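
For illustration, a minimal standalone sketch of calling create_confusion on invented inputs (the probabilities and labels below are made up, not from ds005):

import numpy as np
from logistic_function import create_confusion

# Invented predict_proba-style output: columns are P(y=0), P(y=1).
proba = np.array([[0.9, 0.1],
                  [0.4, 0.6],
                  [0.2, 0.8],
                  [0.7, 0.3]])
y = np.array([0, 1, 1, 0])

# Coarse threshold grid (0, 0.25, ..., 1) keeps the output small.
confusion = create_confusion(proba, y, thrs_inc=0.25)
# Each row: [threshold, TP, TN, FP, FN] counted at that threshold.
print(confusion)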


def getMin_thrs(confusion):
"""
Returns the threshold with the smallest number of wrong predictions
Parameters:
-----------
Confustion matrix: 2-d array with 5 columns
Returns:
--------
thrs: min threshold that gives minimum wrong predictions: columns 3 +
column 4
false_pos: number of incorrect trues
false_neg: number of incorrect falses
"""
thrs_min = np.argmin(confusion[:,3]+ confusion[:,4])
col_out = confusion[thrs_min, :]
thrs = col_out[0]
false_pos = col_out[3]
false_neg = col_out[4]
return thrs, false_pos, false_neg
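
Continuing the sketch above, getMin_thrs reads the best threshold off that matrix:

from logistic_function import getMin_thrs

# Pick the row whose FP + FN count is smallest; np.argmin returns the
# first minimizer, so ties resolve to the lowest threshold.
thrs, false_pos, false_neg = getMin_thrs(confusion)
print('best threshold: %.2f (FP=%d, FN=%d)' % (thrs, false_pos, false_neg))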


def plot_roc(confusion, fig, sub_i):
"""
function to plot the ROC (receiver operating characteristic) curve and
calculate the corresponding AUC (Area Under Curve).
Parameters:
-----------
Confustion matrix: 2-d array with 5 columns
Returns:
--------
fig: The ROC curve
AUC: Correspong AUC value
"""
ROC = np.zeros((confusion.shape[0],2))
for i in range(confusion.shape[0]):
# Compute false positive rate for current threshold.
FPR_t = confusion[i, 3] / float(confusion[i, 3] + confusion[i, 2])
ROC[i,0] = FPR_t

# Compute true positive rate for current threshold.
TPR_t = confusion[i, 1] / float(confusion[i, 1] + confusion[i, 4])
ROC[i,1] = TPR_t

# Plot the ROC curve.
plt.plot(ROC[:,0], ROC[:,1], lw=2)
plt.xlim(-0.1,1.1)
plt.ylim(-0.1,1.1)
plt.xlabel('$FPR(t)$')
plt.ylabel('$TPR(t)$')
plt.grid()

AUC = 0.
for i in range(confusion.shape[0]-1):
AUC += (ROC[i+1,0]-ROC[i,0]) * (ROC[i+1,1]+ROC[i,1])
AUC *= -0.5

plt.title('subject '+ str(sub_i)+', AUC = %.4f'%AUC)
return fig, AUC
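
Because the thresholds sweep from 0 to 1, FPR decreases down the rows of ROC, so each trapezoid term in the sum above is negative and the -0.5 factor flips the sign. A quick sanity check of that computation against numpy's trapezoidal rule (standalone, with made-up curve points):

import numpy as np

# An ROC traced with decreasing FPR: the signed trapezoidal area is
# negative, so negate it to recover the AUC.
fpr = np.array([1.0, 0.6, 0.3, 0.0])
tpr = np.array([1.0, 0.9, 0.7, 0.0])
auc_manual = -0.5 * np.sum((fpr[1:] - fpr[:-1]) * (tpr[1:] + tpr[:-1]))
auc_trapz = -np.trapz(tpr, fpr)
assert np.isclose(auc_manual, auc_trapz)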
36 changes: 9 additions & 27 deletions data/Makefile
@@ -1,35 +1,17 @@
data:
wget http://openfmri.s3.amazonaws.com/tarballs/ds005_raw.tgz
<<<<<<< HEAD

wget http://nipy.bic.berkeley.edu/rcsds/ds005_mnifunc.tar
wget http://nipy.bic.berkeley.edu/rcsds/mni_icbm152_nlin_asym_09c_2mm.tar.gz

validate:
python data.py

unzip:
tar -xvzf ds005_raw.tgz
for i in {1..9}
do
for j in {1..3}
do
gunzip ds005/sub00${i}/BOLD/task001_run00${j}/bold.nii.gz
done
done

for i in {10..16}
do
for j in {1..3}
do
gunzip ds005/sub0${i}/BOLD/task001_run00${j}/bold.nii.gz
done
done

=======
tar -xvzf ds005_raw.tgz
#wget http://nipy.bic.berkeley.edu/rcsds/ds005_mnifunc.tar
#tar -xvf ds005_mnifunc.tar
tar -xvf ds005_mnifunc.tar
tar -xvzf mni_icbm152_nlin_asym_09c_2mm.tar.gz
rm ds005_raw.tgz
#rm ds005_mnifunc.tar

validate:
python data.py
>>>>>>> a96098ccbb47c304f972e54f9161165806dc04f1
rm ds005_mnifunc.tar
rm mni_icbm152_nlin_asym_09c_2mm.tar.gz
mv mni_icbm152_nlin_asym_09c_2mm templates
mv templates/mni_icbm152_t1_tal_nlin_asym_09c_2mm.nii templates/mni_standard.nii
18 changes: 18 additions & 0 deletions data/README.md
@@ -0,0 +1,18 @@
The ds005 dataset, filtered ds005 dataset, and mni templates are stored here. The makefile is written such that:

- 'make data' will pull in the appropriate data
- 'make unzip' will unzip, remove, and rename certain files
- 'make validate' will run data.py to check the hashes of each downloaded file against a master hash list included here, ensuring all downloaded data is correct

THE COMMANDS SHOULD BE RUN IN THIS ORDER to validate successfully. The ds005 folder contains subfolders for each subject, the most relevant of which are:

- BOLD: raw data of the fMRI scans for each of the subject's three runs, as well as displacement/variance data
- behav: a file for each run containing the onsets, potential gains, potential losses, and response for each trial
- model: filtered, processed data of the fMRI scans for each of the subject's three runs, and the onset files
11 changes: 3 additions & 8 deletions data/data.py
@@ -1,5 +1,4 @@
from __future__ import print_function, division
import os
import hashlib
import json

@@ -48,8 +47,7 @@ def check_hashes(d):
ex: make_hash_list("ds005", "temp") makes hashlist for all of ds005
including subdirectories
"""

"""
"""
file_paths = []
for path, subdirs, files in os.walk(directory):
for name in files:
@@ -58,12 +56,9 @@ def check_hashes(d):
with open(title, 'w') as outfile:
json.dump(dictionary, outfile)
return dictionary
"""
"""

if __name__ == "__main__":
with open('hashList.txt', 'r') as hl:
with open('total_hash.txt', 'r') as hl:
d = json.load(hl)
check_hashes(d)
#with open('new_hashList.txt', 'r') as hl2:
# data = json.load(hl2)
#check_hashes(data)
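
For context, check_hashes(d) receives a dict parsed from total_hash.txt. Below is a minimal sketch of the kind of per-file comparison such a check performs; the helper name, the SHA-1 choice, and the dict layout are assumptions for illustration, not the module's actual code:

import hashlib

def sha1_of(path):
    # Hypothetical helper: hash a file's bytes the way the hash list was built.
    with open(path, 'rb') as f:
        return hashlib.sha1(f.read()).hexdigest()

# d is assumed to map relative file paths to expected hex digests;
# the check passes only if every file on disk matches its recorded hash.
# all_ok = all(sha1_of(p) == h for p, h in d.items())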
1 change: 0 additions & 1 deletion data/hashList.txt

This file was deleted.

2 changes: 1 addition & 1 deletion data/new_hashList.txt → data/total_hash.txt

Large diffs are not rendered by default.
