
Commit

added exception handling for string parameters. fixed bug in cv init parameter

adelavega committed Jun 12, 2013
1 parent 6dd7e7d commit 44a9b2d
Showing 2 changed files with 12 additions and 17 deletions.
16 changes: 8 additions & 8 deletions examples/classify_regions_using_features.ipynb
@@ -110,7 +110,7 @@
"output_type": "pyout",
"prompt_number": 4,
"text": [
"0.75743738623534018"
"0.74216300940438873"
]
}
],
@@ -124,7 +124,7 @@
"\n",
"classify_regions(dataset, masks, method='SVM', threshold=0.001, remove_overlap=True, regularization='scale', output='summary', studies=None, features=None, class_weight=True, classifier=None, cross_val=None):\n",
"\n",
"Let's focus on the first few arguments: by default, a support vector machine is used as the classifier method, studies that have more than 0.1% of their peaks (i.e. any) in a mask are included, studies that activate both regions are removed, the feature vectors are regularized to having unit variance and a summary dictionary is returned.\n",
"Let's focus on the first few arguments: by default, a support vector machine is used as the classifier method, studies that activate more than 0.1% of voxels within a mask are included, studies that activate both regions are removed, the feature vectors are regularized to having unit variance and a summary dictionary is returned.\n",
"\n",
"Let's try changing the method. Importantly, we need a baseline to compare to. For that, we can use a *Dummy classifier*.\n",
"Dummy classifier use very simple strategies to classify (such as picking the most frequent class) and serve as a good baseline that is specific to your data. We can try this type of classifier by specifying `method = \"Dummy\"`"
@@ -144,7 +144,7 @@
"output_type": "pyout",
"prompt_number": 5,
"text": [
"0.76175525909541264"
"0.74020376175548597"
]
}
],
@@ -163,7 +163,7 @@
"cell_type": "code",
"collapsed": false,
"input": [
"classify.classify_regions(dataset, [\"../neurosynth/tests/data/medial_motor.nii.gz\", \"../neurosynth/tests/data/vmPFC.nii.gz\"], threshold=0.3)"
"classify.classify_regions(dataset, [\"../neurosynth/tests/data/medial_motor.nii.gz\", \"../neurosynth/tests/data/vmPFC.nii.gz\"], threshold=0.1)"
],
"language": "python",
"metadata": {},
@@ -172,7 +172,7 @@
"output_type": "pyout",
"prompt_number": 6,
"text": [
"{'n': {0: 84, 1: 47}, 'score': 0.6795278365045806}"
"{'n': {0: 875, 1: 262}, 'score': 0.69570051890289097}"
]
}
],
@@ -182,14 +182,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Our accuracy is now 68% and we only have around 150 total observations. Let's compare to a Dummy classifier to make sense of this"
"Our accuracy is now 69% and we only have around 1100 total observations. Let's compare to a Dummy classifier to make sense of this"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"classify.classify_regions(dataset, [\"../neurosynth/tests/data/medial_motor.nii.gz\", \"../neurosynth/tests/data/vmPFC.nii.gz\"], threshold=0.3, method=\"Dummy\")"
"classify.classify_regions(dataset, [\"../neurosynth/tests/data/medial_motor.nii.gz\", \"../neurosynth/tests/data/vmPFC.nii.gz\"], threshold=0.1, method=\"Dummy\")"
],
"language": "python",
"metadata": {},
@@ -198,7 +198,7 @@
"output_type": "pyout",
"prompt_number": 7,
"text": [
"{'n': {0: 84, 1: 47}, 'score': 0.53453136011275548}"
"{'n': {0: 875, 1: 262}, 'score': 0.57870953792933033}"
]
}
],
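
The notebook cells above walk through comparing the default SVM region classifier against a Dummy baseline. As a rough sketch of that workflow (not part of this commit), the snippet below runs classify_regions twice with the mask paths and threshold shown in the diff and prints both summary scores; the pickled-dataset path is an assumption, standing in for the Neurosynth Dataset loaded earlier in the notebook.

# Hedged sketch of the SVM-vs-Dummy comparison described in the notebook cells above.
from neurosynth.base.dataset import Dataset
from neurosynth.analysis import classify

# Assumed: a Dataset pickled earlier in the session (the path here is hypothetical).
dataset = Dataset.load("dataset.pkl")

masks = ["../neurosynth/tests/data/medial_motor.nii.gz",
         "../neurosynth/tests/data/vmPFC.nii.gz"]

svm_summary = classify.classify_regions(dataset, masks, threshold=0.1)
dummy_summary = classify.classify_regions(dataset, masks, threshold=0.1,
                                          method="Dummy")

# Each call returns a summary dict of the form {'n': {...}, 'score': ...}.
print("SVM score:  ", svm_summary["score"])
print("Dummy score:", dummy_summary["score"])
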
13 changes: 4 additions & 9 deletions neurosynth/analysis/classify.py
@@ -10,7 +10,7 @@ def classify_by_features(dataset, features, studies=None, method='SVM', scikit_c

def classify_regions(dataset, masks, method='SVM', threshold=0.001, remove_overlap=True,
regularization='scale', output='summary', studies=None, features=None,
- class_weight='auto', classifier=None, cross_val='4-fold'
+ class_weight='auto', classifier=None, cross_val='4-Fold'

'''
Args:
@@ -22,10 +22,6 @@ def classify_regions(dataset, masks, method='SVM', threshold=0.001, remove_overl
import nibabel as nib
import os

- # Get base file names for masks
- # Sci-kit learn does not support numbered masks
- # mask_names = [os.path.basename(file).split('.')[0] for file in masks

# Load masks using NiBabel
try:
loaded_masks = [nib.load(os.path.relpath(m)) for m in masks]
@@ -99,8 +95,7 @@ def __init__(self, clf_method='SVM', classifier=None, class_weight=None):
from sklearn.dummy import DummyClassifier
self.clf = DummyClassifier(strategy="stratified")
else:
- # Error handling?
- self.clf = None
+ raise Exception("Unrecognized classification method")


def fit(self, X, y):
@@ -129,7 +124,7 @@ def cross_val_fit(self, X, y, cross_val='4-Fold'):
if cross_val == '4-Fold':
self.cver = cross_validation.KFold(len(self.y),4,indices=False,shuffle=True)
else:
- self.cver = None
+ raise Exception("Unrecognized cross validation method")
else:
self.cver = cross_val

Expand All @@ -154,5 +149,5 @@ def regularize(self, X, method='scale'):
from sklearn import preprocessing
return preprocessing.scale(X,with_mean=False)
else:
- return X
+ raise Exception("Unrecognized regularization method")
