How to compute dataset size from TFRecords for MDS (within Python)? #26

Open
brando90 opened this issue Jan 26, 2023 · 1 comment

Comments

@brando90
Contributor

No description provided.

@patricks-lab

I'm collaborating with brando on this issue and found a potential fix. I was wondering whether it resolves the problem.

To summarize the issue: we initially tried calling len(metadataset_pipeline), where metadataset_pipeline is an object created via metadataset_pipeline = pipeline.make_episodic_pipeline(...). However, as you may know, the pipeline is a PyTorch IterableDataset derived from TFRecords that only implements __iter__(), so we couldn't call len() on it.
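The failure can be reproduced without TFRecords. The sketch below uses a hypothetical stand-in class (IterOnlyPipeline is not part of any library) that, like the pipeline, defines only __iter__(). The helper names try_len and count_by_iterating are also made up for illustration.

```python
class IterOnlyPipeline:
    """Hypothetical stand-in for an iterable-style dataset: defines __iter__ only."""

    def __init__(self, num_items):
        self._num_items = num_items

    def __iter__(self):
        return iter(range(self._num_items))


def try_len(obj):
    """Return len(obj), or None if the object does not define __len__."""
    try:
        return len(obj)
    except TypeError:
        return None


def count_by_iterating(obj):
    """Fallback: count items by consuming one full pass (linear time)."""
    return sum(1 for _ in obj)
```

Counting by iteration only works when a pass over the data terminates. An episodic pipeline may sample episodes indefinitely, in which case this fallback would never return, which is one reason to look for a metadata-based count instead.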

I did just find a solution around this issue using the dataset_spec.images_per_class metadata field for a given dataset, which returns a dictionary that maps a given class to the number of images in the class. Then, I sum up all the image counts belonging to a certain split via dataset_spec.get_classes(Split[split]).

I was wondering if the following code would work according to your API. I have provided an example run of my snippet below:

import os

# Import paths below are assumed from the pytorch-meta-dataset README; adjust to your install.
import pytorch_meta_dataset.config as config_lib
import pytorch_meta_dataset.dataset_spec as dataset_spec_lib
from pytorch_meta_dataset.utils import Split


def get_num_images(args, split: str = 'VALID'):
    # first we want to get the sources to figure out which datasets we use
    data_config = config_lib.DataConfig(args)
    datasets = data_config.sources
    num_images = 0

    for dataset_name in datasets:
        dataset_records_path = os.path.join(data_config.path, dataset_name)
        dataset_spec = dataset_spec_lib.load_dataset_spec(dataset_records_path)

        # maps each class to the number of images in that class
        all_class_sizes = dataset_spec.images_per_class

        # let's get only the classes that belong to our split
        class_set = dataset_spec.get_classes(Split[split])

        for c in class_set:
            # ignore classes that have fewer images than needed for our n-way k-shot task
            if all_class_sizes[c] >= args.min_examples_in_class:
                num_images += all_class_sizes[c]

    return num_images


from argparse import Namespace

args: Namespace = parse_args_standard_sl()  # our meta-dataset configuration args are in this
args.sources = ['dtd', 'cu_birds']

print(get_num_images(args, 'TRAIN'))  # outputs 12199
print(get_num_images(args, 'VALID'))  # outputs 2619
print(get_num_images(args, 'TEST'))  # outputs 2610
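
To check the counting logic in isolation, without the library or any TFRecords on disk, here is a minimal sketch with toy data. Everything in it is a stand-in: the local Split enum, the ToySpec class, and count_images are hypothetical and only mimic the two members the snippet above relies on (images_per_class and get_classes). The class counts are made up.

```python
from enum import Enum


class Split(Enum):
    """Local stand-in for the library's split enum (assumption)."""
    TRAIN = 0
    VALID = 1
    TEST = 2


class ToySpec:
    """Hypothetical dataset spec exposing images_per_class and get_classes()."""

    def __init__(self, images_per_class, classes_per_split):
        self.images_per_class = images_per_class
        self._classes_per_split = classes_per_split

    def get_classes(self, split):
        return self._classes_per_split[split]


def count_images(specs, split, min_examples_in_class=0):
    """Sum per-class image counts over the classes of one split, across specs."""
    total = 0
    for spec in specs:
        sizes = spec.images_per_class
        total += sum(
            sizes[c]
            for c in spec.get_classes(split)
            if sizes[c] >= min_examples_in_class  # skip classes that are too small
        )
    return total


# Made-up toy data: two datasets with a handful of classes each.
spec_a = ToySpec({0: 10, 1: 3, 2: 7},
                 {Split.TRAIN: [0, 1], Split.VALID: [2], Split.TEST: []})
spec_b = ToySpec({0: 5, 1: 5},
                 {Split.TRAIN: [0], Split.VALID: [], Split.TEST: [1]})
```

With these toy specs, the TRAIN total is 10 + 3 + 5 when no threshold is set, and the class with 3 images drops out once min_examples_in_class is raised above 3. Note that this counts images recorded in the metadata, not the number of episodes the pipeline can produce.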
