From 25328502ea0dd0aa0209b76309b9f03637d2b51c Mon Sep 17 00:00:00 2001 From: VictorSanh Date: Fri, 8 Oct 2021 17:44:20 -0400 Subject: [PATCH 1/2] catch get_dataset_infos error for community datasets without dataset_info.json --- promptsource/app.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/promptsource/app.py b/promptsource/app.py index 768eba864..698eb3588 100644 --- a/promptsource/app.py +++ b/promptsource/app.py @@ -1,6 +1,6 @@ import argparse import textwrap -from multiprocessing import Manager, Pool +from multiprocessing import Manager, Pool, cpu_count import pandas as pd import plotly.express as px @@ -133,7 +133,7 @@ def show_text(t, width=WIDTH): def get_infos(d_name): all_infos[d_name] = get_dataset_infos(d_name) - pool = Pool(processes=len(all_datasets)) + pool = Pool(processes=cpu_count()) pool.map(get_infos, all_datasets) pool.close() pool.join() @@ -146,12 +146,18 @@ def get_infos(d_name): all_infos[dataset_name] = infos else: infos = all_infos[dataset_name] - if subset_name is None: - subset_infos = infos[list(infos.keys())[0]] - else: - subset_infos = infos[subset_name] + if infos: + if subset_name is None: + subset_infos = infos[list(infos.keys())[0]] + else: + subset_infos = infos[subset_name] - split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()} + split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()} + else: + # Zaid/coqa_expanded and Zaid/quac_expanded don't have dataset_infos.json + # so infos is an empty dic, and `infos[list(infos.keys())[0]]` raises an error + # For simplicity, just filling `split_sizes` with nothing, so the displayed split sizes will be 0. + split_sizes = {} # Collect template counts, original task counts and names dataset_templates = template_collection.get_dataset(dataset_name, subset_name) From 6e262e66dd2954e2be00b6c4405593d23d77b518 Mon Sep 17 00:00:00 2001 From: Victor SANH Date: Wed, 13 Oct 2021 13:38:29 -0400 Subject: [PATCH 2/2] rm the pool fixes --- promptsource/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/promptsource/app.py b/promptsource/app.py index 698eb3588..425f212ac 100644 --- a/promptsource/app.py +++ b/promptsource/app.py @@ -1,6 +1,6 @@ import argparse import textwrap -from multiprocessing import Manager, Pool, cpu_count +from multiprocessing import Manager, Pool import pandas as pd import plotly.express as px @@ -133,7 +133,7 @@ def show_text(t, width=WIDTH): def get_infos(d_name): all_infos[d_name] = get_dataset_infos(d_name) - pool = Pool(processes=cpu_count()) + pool = Pool(processes=len(all_datasets)) pool.map(get_infos, all_datasets) pool.close() pool.join()