I'm also collaborating with brando on this issue and found a potential fix. I was wondering whether it might solve the problem.
To summarize the issue: we initially tried calling len(metadataset_pipeline), where metadataset_pipeline is an object created via metadataset_pipeline = pipeline.make_episodic_pipeline(...). However, as you may know, the pipeline is a PyTorch IterableDataset derived from TFRecords that only implements __iter__(), so we couldn't call len() on it.
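To make the failure concrete, here is a minimal sketch in plain Python. EpisodeStream is a hypothetical stand-in for the pipeline, not the library's class: it defines only __iter__(), so len() raises a TypeError while iteration still works.

```python
class EpisodeStream:
    """Hypothetical stand-in for an iterable-only pipeline (not the real class)."""

    def __init__(self, num_episodes):
        self.num_episodes = num_episodes

    def __iter__(self):
        # Yield items one by one; note that no __len__ is defined.
        for i in range(self.num_episodes):
            yield i


stream = EpisodeStream(3)

try:
    len(stream)
    len_error = None
except TypeError as exc:
    # Python refuses len() on objects that do not implement __len__.
    len_error = str(exc)

print(len_error)
print(list(stream))  # iterating is still fine
```

The size therefore has to come from somewhere other than the iterable itself, such as the dataset metadata.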
I just found a workaround using the dataset_spec.images_per_class metadata field of a given dataset, which is a dictionary that maps each class to the number of images in that class. I then sum the image counts of the classes belonging to a given split, which I get via dataset_spec.get_classes(Split[split]).
I was wondering if the following code would work according to your API. I have provided an example run of my snippet below:
import os
from argparse import Namespace

# config_lib, dataset_spec_lib, and Split come from the meta-dataset package;
# their import lines are omitted here.


def get_num_images(args, split: str = 'VALID') -> int:
    # first we want to get the sources to figure out which datasets we use
    data_config = config_lib.DataConfig(args)
    datasets = data_config.sources

    num_images = 0
    for dataset_name in datasets:
        dataset_records_path = os.path.join(data_config.path, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)
        all_class_sizes = dataset_spec.images_per_class
        # let's get only the class sizes of our split
        class_set = dataset_spec.get_classes(Split[split])
        for c in class_set:
            # ignore classes that have fewer images than needed for our n-way k-shot task
            if all_class_sizes[c] >= args.min_examples_in_class:
                num_images += all_class_sizes[c]
    return num_images


args: Namespace = parse_args_standard_sl()  # our meta-dataset configuration args are in this
args.sources = ['dtd', 'cu_birds']
print(get_num_images(args, 'TRAIN'))  # outputs 12199
print(get_num_images(args, 'VALID'))  # outputs 2619
print(get_num_images(args, 'TEST'))  # outputs 2610
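As a quick sanity check of the counting logic that does not need any TFRecords, here is a small self-contained sketch. Everything in it is a hypothetical stand-in: FakeDatasetSpec only mimics the two members my snippet relies on (images_per_class and get_classes), the Split enum is a local copy rather than the library's, and the class sizes are made-up toy numbers.

```python
from enum import Enum


class Split(Enum):
    # Hypothetical stand-in for the library's Split enum.
    TRAIN = 0
    VALID = 1
    TEST = 2


class FakeDatasetSpec:
    """Hypothetical spec exposing only images_per_class and get_classes()."""

    def __init__(self, images_per_class, classes_per_split):
        self.images_per_class = images_per_class
        self._classes_per_split = classes_per_split

    def get_classes(self, split):
        return self._classes_per_split[split]


def count_images(spec, split, min_examples_in_class=0):
    """Sum the image counts of the split's classes that are large enough."""
    sizes = spec.images_per_class
    return sum(
        sizes[c]
        for c in spec.get_classes(split)
        if sizes[c] >= min_examples_in_class
    )


# Toy data: four classes, two in TRAIN, one in VALID, one in TEST.
spec = FakeDatasetSpec(
    images_per_class={0: 10, 1: 3, 2: 7, 3: 5},
    classes_per_split={
        Split.TRAIN: range(0, 2),
        Split.VALID: range(2, 3),
        Split.TEST: range(3, 4),
    },
)

train_all = count_images(spec, Split["TRAIN"])
train_filtered = count_images(spec, Split["TRAIN"], min_examples_in_class=5)
valid_all = count_images(spec, Split["VALID"])
print(train_all, train_filtered, valid_all)
```

Note that with the min_examples_in_class filter, the count covers only the classes that can actually be sampled, so it can be smaller than the raw number of images in the split.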