
Is it possible to use l2l's meta-train to do standard supervised learning? #301

Closed
brando90 opened this issue Feb 4, 2022 · 12 comments

@brando90

brando90 commented Feb 4, 2022

E.g., I don't want to do episodic meta-learning, but I want to use the data in the meta-train set and sample it randomly. Is that possible? I.e., just normal SL, but on the meta-train set.

@seba-1511
Member

Yes, you can access the underlying dataset of a TaskDataset with taskset.dataset.

@brando90
Author

brando90 commented Feb 5, 2022

> Yes, you can access the underlying dataset of a TaskDataset with taskset.dataset.

then pass that to a standard dataloader?

Thanks again for your really nice framework :) Hope to contribute in a more meaningful way soon!

@seba-1511
Member

Correct, that gives you a standard PyTorch dataset.
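
A minimal sketch of that workflow: pass the underlying dataset to a standard torch.utils.data.DataLoader. Here a TensorDataset stands in for taskset.dataset; the CIFAR-FS-like shapes and the 64-class label range are illustrative assumptions, not the library's actual values.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for taskset.dataset: any map-style PyTorch dataset works here.
dataset = TensorDataset(torch.randn(100, 3, 32, 32), torch.randint(0, 64, (100,)))

# Standard supervised-learning loader with random sampling over the meta-train data.
loader = DataLoader(dataset, batch_size=8, shuffle=True)
x, y = next(iter(loader))
print(x.shape, y.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])
```

The key point is that nothing episodic is involved: once you have the plain dataset, shuffling and batching work exactly as in ordinary supervised training.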

@brando90
Author

brando90 commented Feb 8, 2022

> Correct, that gives you a standard PyTorch dataset.

amazing! Your library is so nicely designed. Kudos.

@brando90
Author

Decided to do it slightly differently, but it seems right:

def get_sl_l2l_datasets(root,
                        data_augmentation: str = 'rfs2020',
                        device=None
                        ) -> tuple:
    if data_augmentation is None:
        train_data_transforms = transforms.ToTensor()
        test_data_transforms = transforms.ToTensor()
    elif data_augmentation == 'normalize':
        train_data_transforms = Compose([
            lambda x: x / 255.0,
        ])
        test_data_transforms = train_data_transforms
    elif data_augmentation == 'rfs2020':
        train_data_transforms = get_transform(True)
        test_data_transforms = get_transform(False)
    else:
        raise ValueError(f'Invalid data_augmentation argument: {data_augmentation=}')

    import learn2learn
    train_dataset = learn2learn.vision.datasets.CIFARFS(root=root,
                                                        transform=train_data_transforms,
                                                        mode='train',
                                                        download=True)
    # note: validation reuses the train transforms here; use test_data_transforms to evaluate without augmentation
    valid_dataset = learn2learn.vision.datasets.CIFARFS(root=root,
                                                        transform=train_data_transforms,
                                                        mode='validation',
                                                        download=True)
    test_dataset = learn2learn.vision.datasets.CIFARFS(root=root,
                                                       transform=test_data_transforms,
                                                       mode='test',
                                                       download=True)
    if device is not None:
        train_dataset = learn2learn.data.OnDeviceDataset(
            dataset=train_dataset,
            device=device,
        )
        valid_dataset = learn2learn.data.OnDeviceDataset(
            dataset=valid_dataset,
            device=device,
        )
        test_dataset = learn2learn.data.OnDeviceDataset(
            dataset=test_dataset,
            device=device,
        )
    return train_dataset, valid_dataset, test_dataset


def get_sl_l2l_cifarfs_dataloader(args: Namespace) -> dict:
    train_dataset, valid_dataset, test_dataset = get_sl_l2l_datasets(root=args.data_path)

    from uutils.torch_uu.dataloaders.common import get_serial_or_distributed_dataloaders
    train_loader, val_loader = get_serial_or_distributed_dataloaders(
        train_dataset=train_dataset,
        val_dataset=valid_dataset,
        batch_size=args.batch_size,
        batch_size_eval=args.batch_size_eval,
        rank=args.rank,
        world_size=args.world_size
    )
    _, test_loader = get_serial_or_distributed_dataloaders(
        train_dataset=test_dataset,
        val_dataset=test_dataset,
        batch_size=args.batch_size,
        batch_size_eval=args.batch_size_eval,
        rank=args.rank,
        world_size=args.world_size
    )
    dataloaders: dict = {'train': train_loader, 'val': val_loader, 'test': test_loader}
    return dataloaders


# - tests

def l2l_sl_dl():
    print('starting...')
    args = Namespace(data_path=Path('~/data/l2l_data/').expanduser(), batch_size=8, batch_size_eval=2, rank=-1, world_size=1)
    dataloaders = get_sl_l2l_cifarfs_dataloader(args)
    max_val = torch.tensor(-1)
    for batch in dataloaders['train']:
        max_val = max(max_val, batch[1].max())
    print(f'--> TRAIN FINAL: {max_val=}')
    max_val = torch.tensor(-1)
    for batch in dataloaders['val']:
        max_val = max(max_val, batch[1].max())
    print(f'--> VAL FINAL: {max_val=}')
    max_val = torch.tensor(-1)
    for batch in dataloaders['test']:
        max_val = max(max_val, batch[1].max())
    print(f'--> TEST FINAL: {max_val=}')

@brando90
Author

brando90 commented Apr 8, 2022

Actually do:

taskset.dataset

@brando90
Author

brando90 commented Apr 8, 2022

@seba-1511 I'm confused, what is the type of taskset in taskset.dataset supposed to be? Is it BenchmarkTasksets?

@brando90
Author

brando90 commented Apr 8, 2022

> @seba-1511 I'm confused, what is the type of taskset in taskset.dataset supposed to be? Is it BenchmarkTasksets?

hmmm... this doesn't seem to be the right type:

args.tasksets.train.dataset
<learn2learn.data.meta_dataset.MetaDataset object at 0x7f89f3a44460>
type(args.tasksets.train.dataset)
<class 'learn2learn.data.meta_dataset.MetaDataset'>

@brando90
Author

brando90 commented Apr 8, 2022

Perhaps it's this:

args.tasksets.train.dataset
<learn2learn.data.meta_dataset.MetaDataset object at 0x7f89f3a44460>
type(args.tasksets.train.dataset)
<class 'learn2learn.data.meta_dataset.MetaDataset'>
type(args.tasksets.train.dataset.dataset)
<class 'learn2learn.vision.datasets.cifarfs.CIFARFS'>

@brando90
Author

brando90 commented Apr 8, 2022

I think this is better:

def get_standard_pytorch_dataset_from_l2l_taskdatasets(tasksets: BenchmarkTasksets, split: str) -> Dataset:
    """
    Trying to do:
        type(args.tasksets.train.dataset.dataset)
        <class 'learn2learn.vision.datasets.cifarfs.CIFARFS'>

    :param tasksets:
    :param split:
    :return:
    """
    # e.g. tasksets.train
    taskset: TaskDataset = getattr(tasksets, split)
    # unwrap TaskDataset -> MetaDataset -> underlying Dataset,
    # e.g. type(tasksets.train.dataset.dataset) is learn2learn.vision.datasets.cifarfs.CIFARFS
    dataset: MetaDataset = taskset.dataset
    dataset: Dataset = dataset.dataset
    assert isinstance(dataset, Dataset), f'Expected dataset to be of type Dataset but got {type(dataset)=}.'
    return dataset
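
To illustrate the unwrapping path the function relies on without needing learn2learn installed, here is a toy sketch: SimpleNamespace objects play the roles of BenchmarkTasksets, TaskDataset, and MetaDataset (these stand-ins are assumptions for illustration, not the real l2l classes), and a TensorDataset plays the underlying dataset.

```python
import torch
from types import SimpleNamespace
from torch.utils.data import Dataset, TensorDataset

# Toy stand-ins for l2l's nesting: BenchmarkTasksets -> TaskDataset -> MetaDataset -> Dataset.
base = TensorDataset(torch.zeros(4, 3), torch.zeros(4))
meta_dataset = SimpleNamespace(dataset=base)     # role of MetaDataset
taskset = SimpleNamespace(dataset=meta_dataset)  # role of TaskDataset
tasksets = SimpleNamespace(train=taskset)        # role of BenchmarkTasksets

# Same unwrapping path as get_standard_pytorch_dataset_from_l2l_taskdatasets:
dataset = getattr(tasksets, 'train').dataset.dataset
assert isinstance(dataset, Dataset)
print(type(dataset).__name__)  # TensorDataset
```

The double .dataset access is the whole trick: the first strips the TaskDataset wrapper, the second strips the MetaDataset wrapper.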

@brando90 brando90 changed the title Is it possible to use the meta-train set for l2l's task to do standard supervised learning? Is it possible to use l2l's meta-train to do standard supervised learning? Apr 15, 2022
@brando90
Author

>     if device is not None:
>         train_dataset = learn2learn.data.OnDeviceDataset(
>             dataset=train_dataset,
>             device=device,
>         )
>         valid_dataset = learn2learn.data.OnDeviceDataset(
>             dataset=valid_dataset,
>             device=device,
>         )
>         test_dataset = learn2learn.data.OnDeviceDataset(
>             dataset=test_dataset,
>             device=device,
>         )

@seba-1511 is this really needed?

@brando90
Author

#385
