Skip to content
This repository has been archived by the owner on Mar 7, 2022. It is now read-only.

Commit

Permalink
Change json for pickle in binarization
Browse files Browse the repository at this point in the history
  • Loading branch information
lukassnoek committed Dec 30, 2016
1 parent b2f396a commit d14f1f6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
13 changes: 6 additions & 7 deletions skbold/core/mvp_between.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import re
import json
import pickle
import warnings
import os.path as op
import pandas as pd
Expand Down Expand Up @@ -418,9 +418,9 @@ def add_y(self, file_path, col_name, sep='\t', index_col=0,
self._undersample_majority()

def apply_binarization_params(self, param_file, ensure_balanced=False):

with open(param_file, 'r') as fin:
params = json.load(fin, encoding='utf-8')
""" Applies binarization-parameters to y. """
with open(param_file, 'rb') as fin:
params = pickle.load(fin)

if params['type'] == 'zscore':
y_norm = (self.y - params['mean']) / params['std']
Expand Down Expand Up @@ -471,9 +471,8 @@ def binarize_y(self, params, save_path=None, ensure_balanced=False):
self._undersample_majority()

if save_path is not None:
with open(op.join(save_path, 'binarize_params.json'), 'w',
encoding="utf-8") as w:
json.dump(labb.binarize_params, w, indent=4)
with open(op.join(save_path, 'binarize_params.pkl'), 'wb') as w:
pickle.dump(labb.binarize_params, w)

def split(self, file_path, col_name, target, sep='\t', index_col=0):
""" Splits an MvpBetween object based on some external index.
Expand Down
6 changes: 3 additions & 3 deletions skbold/core/tests/test_mvp_between.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ def test_mvp_between_binarize_y(mvp1c, params):
mvp1c.add_y(fpath, col_name='var_continuous', index_col=0)
mvp1c.binarize_y(params, ensure_balanced=True, save_path=testdata_path)
assert((np.unique(mvp1c.y) == [0, 1]).all())
assert(op.isfile(op.join(testdata_path, 'binarize_params.json')))
assert(op.isfile(op.join(testdata_path, 'binarize_params.pkl')))


def test_mvp_between_apply_binarization_params(mvp1c):
fpath = op.join(testdata_path, 'sample_behav.tsv')
mvp1c.add_y(fpath, col_name='var_continuous', index_col=0)
mvp1c.apply_binarization_params(op.join(testdata_path,
'binarize_params.json'))
os.remove(op.join(testdata_path, 'binarize_params.json'))
'binarize_params.pkl'))
os.remove(op.join(testdata_path, 'binarize_params.pkl'))


@pytest.mark.parametrize("cols", ['confound_categorical',
Expand Down
6 changes: 3 additions & 3 deletions skbold/data/test_data/sample_behav.tsv
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
var_categorical var_continuous var_multinomial confound_categorical confound_continuous group
sub001 1 23.123 2 0 384.123 train
sub002 1 24.129 2 0 819.32 train
sub003 1 20.001 2 1 381.33
sub003 1 20.001 2 1 381.33 train
sub004 0 19.583 2 1 921.939 train
sub005 9.209 2 1 201.193 train
sub005 0 9.209 2 1 201.193 train
sub006 0 58.284 1 0 403.201 train
sub007 0 132.123 1 1 683.863 test
sub008 999 492.593 1 1 290.193 test
sub009 1 394.288 1 0 993.201 test
sub010 1 918.13 3 0 381.382 test
sub011 1 382.111 1 998.281 test
sub011 1 382.111 3 1 998.281 test
sub012 0 399.11 3 0 888.133 test
sub013 1 393.99 3 0 133.281 test

0 comments on commit d14f1f6

Please sign in to comment.