Skip to content

Commit

Permalink
Adding training code
Browse files Browse the repository at this point in the history
Testing creation of training set from raw Kaggle data
  • Loading branch information
garethjns committed Jan 8, 2017
1 parent 0d8a1a9 commit 626ae56
Show file tree
Hide file tree
Showing 5 changed files with 6,493 additions and 11 deletions.
3 changes: 1 addition & 2 deletions copyTestLeakToTrain.m
Expand Up @@ -3,8 +3,7 @@ function copyTestLeakToTrain(paths)
% Move old, leaked test data to new training set

% Load list of safe files
safe = readtable([paths.mainDataDir, ...
'train_and_test_data_labels_safe.csv']);
safe = readtable('train_and_test_data_labels_safe.csv');

% Convert to string
safe1 = string(safe{:,1});
Expand Down
98 changes: 91 additions & 7 deletions featuresObject.m
Expand Up @@ -29,6 +29,10 @@
methods
function obj = featuresObject(params, use)

if isfield(params, 'tt')
obj.tt = params.tt;
end

if isfield(params, 'divS')
obj.params = params;
% Get list of divS to use
Expand All @@ -55,6 +59,7 @@
obj = setFileLists(obj, {});
obj.use = use;


% Run compile now?
% obj = compileFeatures(obj);
end
Expand Down Expand Up @@ -182,9 +187,15 @@

% Set feature type
obj = setType(obj);

% Set keepIdx
obj = setkeepIdx(obj);

% Set safeIdx (if training)
switch obj.tt
case {'train', 'Train'}
obj = setSafeIdx(obj);
case {'test', 'Test'}
% Set keepIdx
obj = setkeepIdx(obj);
end

% Add subject numbers to sets if hybrid
if obj.hybrid
Expand Down Expand Up @@ -266,6 +277,79 @@
obj.keepIdx = keepIdx1;
end

% Load and apply safe idx from second release of data
function obj = setSafeIdx(obj)
    % SETSAFEIDX  Build obj.keepIdx for the training set.
    %
    % When obj.applyNewSafeIdx is true, the "safe" flags from
    % 'train_and_test_data_labels_safe.csv' (second release of the data)
    % are joined onto obj.fileLists{1} and obj.SSL{1}, and the resulting
    % .safe column is used as a mask. Otherwise an all-true mask is used.
    % In both cases the mask is ANDed with obj.newKeepIdx(obj.allTrain)
    % and stored in obj.keepIdx.
    %
    % Input/Output:
    %   obj - featuresObject; fileLists{1}, SSL{1} and keepIdx may be
    %         updated.

    if obj.applyNewSafeIdx

        % Load csv
        % Contains all training data and affected test data
        safe = readtable('train_and_test_data_labels_safe.csv');

        % File names currently in the training file list
        flStr = string(obj.fileLists{1}.File);

        % Drop references to test data: keep only csv rows whose image
        % name matches a file in the file list
        safeStr = string(safe.image);
        trIdx = safeStr.contains(flStr);
        safe = safe(trIdx, :);
        % Remove class column for safety
        safe.class = [];
        % Rename key variable so it matches fileLists{1}.File
        % (assumes the first csv column is the file name - TODO confirm)
        safe.Properties.VariableNames{1} = 'File';

        % Re-add the new files in the training set (with modified
        % names): files in fileList but not in safeStr.
        % NB: safeStr still holds ALL csv names (it was taken before the
        % test rows were dropped).
        missingIdx = ~flStr.contains(safeStr);
        % Add these to safe, along with safe tag
        nAdd = sum(missingIdx);
        add = table(cellstr(flStr(missingIdx)), ones(nAdd,1));
        add.Properties.VariableNames{1} = 'File';
        add.Properties.VariableNames{2} = 'safe';

        % NB: Union drops dupes, but there shouldn't be any
        safe = union(safe, add);

        % Join
        % To fileList
        nt = join(obj.fileLists{1}, safe);

        % Order as fileList?
        % NOTE(review): plot/keyboard are interactive debugging aids;
        % consider replacing with an error() for unattended runs.
        plot(nt.SubSegID); hold on; plot(obj.fileLists{1}.SubSegID)
        if ~all(nt.SubSegID == obj.fileLists{1}.SubSegID)
            keyboard
        else
            obj.fileLists{1} = nt;
        end

        % Join
        % To biggest subSegList
        % (.File is called .Files here)
        safe.Properties.VariableNames{1} = 'Files';
        nSSL = join(obj.SSL{1}, safe);

        % Order as subSegList?
        plot(nSSL.SubSegID); hold on; plot(obj.SSL{1}.SubSegID)
        if ~all(nSSL.SubSegID == obj.SSL{1}.SubSegID)
            keyboard
        else
            obj.SSL{1} = nSSL;
        end

        % Set newKeepIdx
        keepIdx2 = nSSL.safe;

    else
        % Don't apply new safeIdx: keep every row of the subSegList.
        % height() takes only the table; the column count (1) belongs
        % to true(), giving an nRows-by-1 logical mask.
        keepIdx2 = true(height(obj.SSL{1}), 1);
    end

    % Get oklist from training data and combine
    keepIdx1 = obj.newKeepIdx(obj.allTrain);

    obj.keepIdx = keepIdx1 & keepIdx2;
end


function divS = findDivS(obj)
% Find available features files
Expand Down Expand Up @@ -379,14 +463,14 @@

nSubs = numel(subs);
% Generate file list
tn=0;
tn = 0;
for s = 1:nSubs
% Create sub table for this subject
fileList = table(vars{:,1});
fileList.Properties.VariableNames = vars(:,2);
fileList.Properties.VariableDescriptions = vars(:,3);

sDir = [paths.new, str, '_', subs{s}, '\'];
sDir = [paths, str, '_', subs{s}, '\'];

% files = dir([sDir, subs{s}, '*']);
files = dir([sDir, '*.mat']);
Expand All @@ -399,7 +483,7 @@
'/', num2str(nFilesSub), ')']);

switch str
case {'train', 'Test'}
case {'train', 'T'}
Y = str2double(fn(end-4));
IDIdx = strfind(files(n).name, '_');
ID = files(n).name(IDIdx(1)+1:IDIdx(2)-1);
Expand Down Expand Up @@ -504,7 +588,7 @@

nSubs = numel(subs);
% Generate file list
tn=0;
tn = 0;
for s = 1:nSubs
% Create sub table for this subject
fileList = table(vars{:,1});
Expand Down
4 changes: 2 additions & 2 deletions predict.m
Expand Up @@ -15,8 +15,8 @@
% Set paths and prepare parameters

% params.paths = 'S:\EEG Data Mini\';
params.paths = 'S:\EEG Data\New\';
params.ModelPath = 'trainedModelsCompact.mat';
params.paths.dataDir = 'S:\EEG Data\New\';
params.paths.ModelPath = 'trainedModelsCompact.mat';

rng(1000) % Probably does nothing here
startTime = tic;
Expand Down
73 changes: 73 additions & 0 deletions train.m
@@ -0,0 +1,73 @@
%% Paths and parameters
% Configure data locations and global parameters for the training run.

% params.paths = 'S:\EEG Data Mini\';
params.paths.dataDir = 'S:\EEG Data\';
% Original Kaggle data
params.paths.or = [params.paths.dataDir 'Original\'];
% New training and test sets
params.paths.new = [params.paths.dataDir 'New\'];
% File used for the trained models
params.paths.ModelPath = 'trainedModelsCompact.mat';

% Seed the random number generator (probably has no effect here)
rng(1000);
startTime = tic;

% Version
params.master = 61;
params.nSubs = 3;

% Remaining parameters are filled in by setParams (edit them there)
params = setParams(params);

warning('off', 'MATLAB:table:RowsAddedExistingVars');


%% Prepare raw data
% Build the new training directory from the original Kaggle data,
% following the list of safe files.
% Also creates the singles.mat needed for this set.

copyTestLeakToTrain(params.paths);


%% Process training set

params.tt = 'train';

% Feature groups to use (1 = use, 0 = skip).
% This selection needs to be saved with the seizure model during training.
clear use
use = struct( ...
    'hillsBandsLog2D', 0, ...
    'hillsBandsLogAv', 1, ...
    'maxHills2D',      1, ...
    'maxHillsAv',      1, ...
    'summ32D',         1, ...
    'summ3Av',         1, ...
    'bandsLin2D',      1, ...
    'bandsLinAv',      1, ...
    'maxBands2D',      1, ...
    'maxBandsAv',      1, ...
    'mCorrsT',         1, ...
    'mCorrsF',         1);

% Create features object for the training set
disp('Creating basic features')

% Epoch window sizes to use
params.divS = [240 160 80];

% Create object
featuresTrain = featuresObject(params, use);

% Compile available features
featuresTrain = compileFeatures(featuresTrain);


%% Run training

% Set model parameters

% Run train function

% Run compare models

% Save models to disk (params.paths.ModelPath)

0 comments on commit 626ae56

Please sign in to comment.