Skip to content

Commit

Permalink
Add force CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Feb 14, 2019
1 parent df79e53 commit 1f0064e
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions examples/talos_3d_classification.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from talos.utils.gpu_utils import force_cpu

# Force CPU use on a GPU system
force_cpu()

import talos as ta
from keras.activations import relu, elu, softmax, hard_sigmoid, tanh
from keras.layers import Flatten, ConvLSTM2D, Dense, Conv3D, MaxPooling3D, Dropout
Expand Down Expand Up @@ -35,7 +40,9 @@
'stride_1': [1, 2],
'layer_drop': [0.0, 0.8, 4],
'layers': [2,3,4],
'pool': [0]
'pool': [0],
'rebin': [50],
'time': [10],

}
'''
Expand All @@ -56,7 +63,7 @@ def input_model(x_train, y_train, x_val, y_val, params):

model.add(Conv3D(params['neuron_1'], kernel_size=params['kernel_1'], strides=params['stride_1'],
padding='same',
input_shape=(100, 100, 10, 1),
input_shape=(params['rebin'],params['rebin'],params['time'], 1),
activation=params['activation']))
if params['pool']:
model.add(MaxPooling3D())
Expand Down Expand Up @@ -125,17 +132,17 @@ def input_model(x_train, y_train, x_val, y_val, params):
gamma_dir = [directory + "gammaFeature/no_clean/"]
proton_dir = [directory + "protonFeature/no_clean/"]

x, y = get_chunk_of_data(directory=gamma_dir, proton_directory=proton_dir, indicies=(30, 129, 10), rebin=100,
x, y = get_chunk_of_data(directory=gamma_dir, proton_directory=proton_dir, indicies=(30, 129, params['time'][0]), rebin=params['rebin'][0],
chunk_size=args['size'], as_channels=True)
x = x.reshape(-1, 100,100,10,1)
x = x.reshape(-1, params['rebin'][0],params['rebin'][0],params['time'][0],1)

print("Got data")
print("X Shape", x.shape)
print("Y Shape", y.shape)
history = ta.Scan(x, y,
params=params,
dataset_name='3d_separation_test',
experiment_no='1',
experiment_no='2',
model=input_model,
search_method='random',
grid_downsample=args['grid'])

0 comments on commit 1f0064e

Please sign in to comment.