Skip to content

Commit

Permalink
Change prediction batch_size from 32 to 4096 for better performance
Browse files Browse the repository at this point in the history
  • Loading branch information
timodonnell committed Nov 28, 2017
1 parent cbb649d commit 38adfb4
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
11 changes: 8 additions & 3 deletions mhcflurry/class1_affinity_prediction/class1_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ def fit(
break
self.fit_seconds = time.time() - start

def predict(self, peptides, allele_pseudosequences=None):
def predict(self, peptides, allele_pseudosequences=None, batch_size=4096):
"""
Predict affinities
Expand All @@ -490,6 +490,9 @@ def predict(self, peptides, allele_pseudosequences=None):
allele_pseudosequences : EncodableSequences or list of string, optional
Only required when this model is a pan-allele model
batch_size : int
batch_size passed to Keras
Returns
-------
numpy.array of nM affinity predictions
Expand All @@ -501,8 +504,10 @@ def predict(self, peptides, allele_pseudosequences=None):
pseudosequences_input = self.pseudosequence_to_network_input(
allele_pseudosequences)
x_dict['pseudosequence'] = pseudosequences_input
(predictions,) = numpy.array(
self.network(borrow=True).predict(x_dict), dtype="float64").T

network = self.network(borrow=True)
raw_predictions = network.predict(x_dict, batch_size=batch_size)
predictions = numpy.array(raw_predictions, dtype = "float64")[:,0]
return to_ic50(predictions)

def compile(self):
Expand Down
1 change: 0 additions & 1 deletion test/test_class1_affinity_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def test_class1_affinity_predictor_a0205_memorize_training_data():
dense_layer_l1_regularization=0.0,
dropout_probability=0.0)

# First test a Class1NeuralNetwork, then a Class1AffinityPredictor.
allele = "HLA-A*02:05"

df = pandas.read_csv(
Expand Down
4 changes: 2 additions & 2 deletions test/test_train_allele_specific_models_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def test_run():
args = [
"--data", get_path("data_curated", "curated_training_data.csv.bz2"),
"--hyperparameters", hyperparameters_filename,
"--min-measurements-per-allele", "9000",
"--allele", "HLA-A*02:01", "HLA-A*01:01", "HLA-A*03:01",
"--out-models-dir", models_dir,
"--percent-rank-calibration-num-peptides-per-length", "1000",
"--percent-rank-calibration-num-peptides-per-length", "10000",
]
print("Running with args: %s" % args)
train_allele_specific_models_command.run(args)
Expand Down

0 comments on commit 38adfb4

Please sign in to comment.