```python
import numpy as np
from sklearn.metrics import pairwise_distances


def calc_distances(inputs, anchors, pipeline, anchor_pipeline=None, metric="cosine", aggregate=np.max, n_jobs=None):
    """
    Shortcut to compare a sequence of inputs to a set of anchors.

    The available metrics are: `cityblock`, `cosine`, `euclidean`, `haversine`, `l1`, `l2`, `manhattan` and `nan_euclidean`.
    You can read a verbose description of the metrics [here](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise.distance_metrics.html#sklearn.metrics.pairwise.distance_metrics).

    Arguments:
        - inputs: sequence of inputs to calculate scores for
        - anchors: set/list of anchors to compare against
        - pipeline: the pipeline to use to calculate the embeddings
        - anchor_pipeline: the pipeline to apply to the anchors, meant to be used if the anchors should use a different pipeline
        - metric: the distance metric to use
        - aggregate: you'll want to aggregate the distances to the different anchors down to a single metric;
          numpy functions that accept `axis=1`, like `np.max` and `np.mean`, can be used
        - n_jobs: set to -1 to use all cores for calculation
    """
    # Embed the inputs with the main pipeline.
    X_input = pipeline.transform(inputs)
    # Anchors may use their own pipeline; otherwise reuse the main one.
    if anchor_pipeline is not None:
        X_anchors = anchor_pipeline.transform(anchors)
    else:
        X_anchors = pipeline.transform(anchors)
    # Distance matrix of shape (n_inputs, n_anchors).
    X_dist = pairwise_distances(X_input, X_anchors, metric=metric, n_jobs=n_jobs)
    # Collapse each row (one input vs. all anchors) down to a single score.
    return aggregate(X_dist, axis=1)
```
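The core of `calc_distances` is a `pairwise_distances` call followed by an aggregation along `axis=1`. Below is a minimal sketch of that step. `ToyEmbedder` is a hypothetical stand-in for an embetter pipeline (anything with a `.transform` method that returns one vector per input would do); it maps three words to fixed 2-d vectors so the cosine distances are easy to check by hand.

```python
import numpy as np
from sklearn.metrics import pairwise_distances


class ToyEmbedder:
    """Hypothetical pipeline stand-in: looks up a fixed vector per string."""

    lookup = {
        "cat": [1.0, 0.0],
        "dog": [0.9, 0.1],
        "car": [0.0, 1.0],
    }

    def transform(self, X):
        return np.array([self.lookup[x] for x in X])


pipe = ToyEmbedder()
inputs = ["cat", "dog", "car"]
anchors = ["cat", "car"]

# One row per input, one column per anchor.
X_dist = pairwise_distances(pipe.transform(inputs), pipe.transform(anchors), metric="cosine")
print(X_dist.shape)  # (3, 2)

# Aggregate over the anchors, as calc_distances does with `aggregate`.
farthest = np.max(X_dist, axis=1)  # distance to the farthest anchor
nearest = np.min(X_dist, axis=1)   # distance to the nearest anchor
print(farthest)
print(nearest)
```

Since these are distances rather than similarities, the default `np.max` reports how far each input is from its *farthest* anchor; `np.min` gives the distance to the *nearest* anchor, which may be closer to what you want when any single anchor matching is enough.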
Something like this: