Skip to content

Commit

Permalink
implement Class1AffinityPredictor.merge() method
Browse files Browse the repository at this point in the history
  • Loading branch information
timodonnell committed Nov 28, 2017
1 parent 2f02695 commit c699066
Showing 1 changed file with 58 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,67 @@ def __init__(
columns=["model_name", "allele", "config_json", "model"])
self.manifest_df = manifest_df

if allele_to_percent_rank_transform is None:
if not allele_to_percent_rank_transform:
allele_to_percent_rank_transform = {}
self.allele_to_percent_rank_transform = allele_to_percent_rank_transform

@property
def neural_networks(self):
"""
List of the neural networks in the ensemble.
Returns
-------
list of Class1NeuralNetwork
"""
result = []
for models in self.allele_to_allele_specific_models.values():
result.extend(models)
result.extend(self.class1_pan_allele_models)
return result

@classmethod
def merge(cls, predictors):
"""
Merge the ensembles of two or more Class1AffinityPredictor instances.
Note: the resulting merged predictor will NOT have calibrated percentile
ranks. Call calibrate_percentile_ranks() on it if these are needed.
Parameters
----------
predictors : sequence of Class1AffinityPredictor
Returns
-------
Class1AffinityPredictor
"""
assert len(predictors) > 0
if len(predictors) == 1:
return predictors[0]

allele_to_allele_specific_models = collections.defaultdict(list)
class1_pan_allele_models = []
allele_to_pseudosequence = predictors[0].allele_to_pseudosequence

for predictor in predictors:
for (allele, networks) in (
predictor.allele_to_allele_specific_models.items()):
allele_to_allele_specific_models[allele].extend(networks)
class1_pan_allele_models.extend(
predictor.class1_pan_allele_models)

return Class1AffinityPredictor(
allele_to_allele_specific_models=allele_to_allele_specific_models,
class1_pan_allele_models=class1_pan_allele_models,
allele_to_pseudosequence=allele_to_pseudosequence
)

@property
def num_networks(self):
return self.manifest_df.shape[0]


@property
def supported_alleles(self):
Expand Down

0 comments on commit c699066

Please sign in to comment.