Mean metric function #32

Merged: 7 commits, merged on Sep 2, 2021
4 changes: 2 additions & 2 deletions fs_mol/models/abstract_torch_fsmol_model.py
@@ -37,7 +37,7 @@
from fs_mol.utils.logging import PROGRESS_LOG_LEVEL
from fs_mol.utils.metric_logger import MetricLogger
from fs_mol.utils.metrics import (
-    avg_metrics_list,
+    avg_task_metrics_list,
    compute_metrics,
    BinaryEvalMetrics,
    BinaryMetricType,
@@ -301,7 +301,7 @@ def validate_on_data_iterable(
    if not quiet:
        logger.info(f" Validation loss: {valid_loss:.5f}")
    # If our data_iterable had more than one task, we'll have one result per task - average them:
-    mean_valid_metrics = avg_metrics_list(list(valid_metrics.values()))
+    mean_valid_metrics = avg_task_metrics_list(list(valid_metrics.values()))
    if metric_to_use == "loss":
        return -valid_loss  # We are maximising things elsewhere, so flip the sign on the loss
    else:
5 changes: 2 additions & 3 deletions fs_mol/multitask_train.py
@@ -1,5 +1,4 @@
import argparse
-import itertools
import logging
import os
import pdb
@@ -42,7 +41,7 @@
)
from fs_mol.utils.cli_utils import add_train_cli_args, set_up_train_run, str2bool
from fs_mol.utils.metrics import (
-    avg_metrics_list,
+    avg_metrics_over_tasks,
    BinaryEvalMetrics,
)
from fs_mol.utils.test_utils import eval_model
@@ -102,7 +101,7 @@ def test_model_fn(
        seed=seed,
    )

-    mean_metrics = avg_metrics_list(list(itertools.chain(*task_to_results.values())))
+    mean_metrics = avg_metrics_over_tasks(task_to_results)
    if aml_run is not None:
        for metric_name, (metric_mean, _) in mean_metrics.items():
            aml_run.log(f"valid_task_test_{metric_name}", float(metric_mean))
4 changes: 2 additions & 2 deletions fs_mol/utils/maml_utils.py
@@ -20,7 +20,7 @@
from fs_mol.utils.metrics import (
    BinaryEvalMetrics,
    BinaryMetricType,
-    avg_metrics_list,
+    avg_metrics_over_tasks,
    compute_binary_task_metrics,
)
from fs_mol.utils.test_utils import eval_model
@@ -235,7 +235,7 @@ def test_model_fn(
        seed=seed,
    )

-    mean_metrics = avg_metrics_list(list(itertools.chain(*task_to_results.values())))
+    mean_metrics = avg_metrics_over_tasks(task_to_results)
    if aml_run is not None:
        for metric_name, (metric_mean, _) in mean_metrics.items():
            aml_run.log(f"valid_task_test_{metric_name}", float(metric_mean))
22 changes: 20 additions & 2 deletions fs_mol/utils/metrics.py
@@ -1,5 +1,6 @@
import dataclasses
-from typing import Dict, Tuple, List
+import itertools
+from typing import Dict, Tuple, List, Union
from typing_extensions import Literal
from dataclasses import dataclass

@@ -57,7 +58,24 @@ def compute_binary_task_metrics(predictions: List[float], labels: List[float]) -
    )


-def avg_metrics_list(results: List[BinaryEvalMetrics]) -> Dict[str, Tuple[float, float]]:
+def avg_metrics_over_tasks(
+    task_results: Dict[str, BinaryEvalMetrics]
+) -> Dict[str, Tuple[float, float]]:
+    # average results over all tasks in input dictionary
+    # the average over each task is first created
+    # technically input is Dict[str, FSMolTaskSampleEvalResults], but everything
+    # not in BinaryEvalMetrics is unused here.
+    aggregated_metrics = {}
+    for (task, results) in task_results.items():
+        # this returns, for each task, a dictionary of aggregated results
+        aggregated_metrics[task] = avg_task_metrics_list(results)
+
+    return avg_task_metrics_list(list(itertools.chain(*aggregated_metrics.values())))
+
+
+def avg_task_metrics_list(
+    results: List[Union[BinaryEvalMetrics, Dict[str, float]]]
+) -> Dict[str, Tuple[float, float]]:
    aggregated_metrics = {}

    # Compute mean/std:
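For context on how the two new helpers are meant to combine, here is a standalone toy sketch of the behaviour change (it does not import fs_mol; the task names, metric values, and the mean/std convention are made up for illustration): averaging within each task first and then across tasks gives every task equal weight, whereas the old avg_metrics_list(list(itertools.chain(*task_to_results.values()))) call pooled all samples, so tasks with more evaluation runs dominated the reported mean.

```python
import statistics

# Hypothetical per-task results: "CHEMBL_A" was evaluated three times, "CHEMBL_B" once.
task_to_results = {
    "CHEMBL_A": [{"avg_precision": 0.90}, {"avg_precision": 0.80}, {"avg_precision": 0.70}],
    "CHEMBL_B": [{"avg_precision": 0.40}],
}

def mean_std(values):
    # Toy aggregation into (mean, std) tuples, mirroring the (metric_mean, metric_std) pairs in the diff.
    values = list(values)
    return statistics.mean(values), statistics.pstdev(values)

# Old behaviour (flat itertools.chain over all tasks): a sample-weighted mean.
flat = [r["avg_precision"] for results in task_to_results.values() for r in results]
print(mean_std(flat))      # mean = 0.70 -- "CHEMBL_A" contributes three samples

# New behaviour (avg_metrics_over_tasks): average within each task, then across tasks.
per_task = [statistics.mean(r["avg_precision"] for r in results)
            for results in task_to_results.values()]
print(mean_std(per_task))  # mean = 0.60 -- each task contributes exactly once
```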
15 changes: 10 additions & 5 deletions fs_mol/utils/protonet_utils.py
@@ -1,5 +1,4 @@
import logging
-import itertools
import os
import sys
from dataclasses import dataclass
@@ -21,7 +20,12 @@
)
from fs_mol.models.protonet import PrototypicalNetwork, PrototypicalNetworkConfig
from fs_mol.models.abstract_torch_fsmol_model import MetricType
-from fs_mol.utils.metrics import BinaryEvalMetrics, compute_binary_task_metrics, avg_metrics_list
+from fs_mol.utils.metrics import (
+    BinaryEvalMetrics,
+    compute_binary_task_metrics,
+    avg_metrics_over_tasks,
+    avg_task_metrics_list,
+)
from fs_mol.utils.metric_logger import MetricLogger
from fs_mol.utils.test_utils import eval_model, FSMolTaskSampleEvalResults

@@ -158,7 +162,8 @@ def validate_by_finetuning_on_tasks(
        aml_run=aml_run,
    )

-    mean_metrics = avg_metrics_list(list(itertools.chain(*task_results.values())))
+    # take the dictionary of task_results and return correct mean over all tasks
+    mean_metrics = avg_metrics_over_tasks(task_results)
    if aml_run is not None:
        for metric_name, (metric_mean, _) in mean_metrics.items():
            aml_run.log(f"valid_task_test_{metric_name}", float(metric_mean))
@@ -282,7 +287,7 @@ def train_loop(self, out_dir: str, dataset: FSMolDataset, aml_run=None):
            self.optimizer.step()

            task_batch_mean_loss = np.mean(task_batch_losses)
-            task_batch_avg_metrics = avg_metrics_list(task_batch_metrics)
+            task_batch_avg_metrics = avg_task_metrics_list(task_batch_metrics)
            metric_logger.log_metrics(
                loss=task_batch_mean_loss,
                avg_prec=task_batch_avg_metrics["avg_precision"][0],
@@ -292,7 +297,7 @@

            if step % self.config.validate_every_num_steps == 0:

-                valid_metric = validate_by_finetuning_on_tasks(self, dataset)
+                valid_metric = validate_by_finetuning_on_tasks(self, dataset, aml_run=aml_run)

                if aml_run:
                    # printing some measure of loss on all validation tasks.
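A short sketch of the pattern behind the final hunk (standalone; DummyRun is a made-up stand-in that only mimics the log(name, value) call appearing in this diff, not the real AML Run API): validate_by_finetuning_on_tasks only logs its per-metric means when a run object is attached, so the training loop has to forward aml_run for those values to reach AML at all.

```python
from typing import Dict, Optional, Tuple

class DummyRun:
    """Stand-in for an AML run; only the log(name, value) call used in the diff is mimicked."""
    def log(self, name: str, value: float) -> None:
        print(f"{name}: {value:.3f}")

def validate_sketch(
    mean_metrics: Dict[str, Tuple[float, float]], aml_run: Optional[DummyRun] = None
) -> float:
    # Same guard as in validate_by_finetuning_on_tasks: without a run object, nothing is logged.
    if aml_run is not None:
        for metric_name, (metric_mean, _) in mean_metrics.items():
            aml_run.log(f"valid_task_test_{metric_name}", float(metric_mean))
    return mean_metrics["avg_precision"][0]

fake_metrics = {"avg_precision": (0.61, 0.12)}     # made-up numbers
validate_sketch(fake_metrics)                      # before: aml_run defaults to None, nothing is logged
validate_sketch(fake_metrics, aml_run=DummyRun())  # after: the forwarded run receives the means
```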