
[Q]: What's the native way to split datasets into train, validation and test? #22

Closed
akekic opened this issue Feb 1, 2022 · 4 comments
Labels: documentation (Improvements or additions to documentation), enhancement (New feature or request), proposal (Discussion on a new proposed feature), question (Further information is requested), refactoring (Internal changes to the code that do not change the outputs.)


@akekic

akekic commented Feb 1, 2022

I'm trying to train a VAE on Cars3dData, and I was wondering how to split an instance of DisentDataset into train, validation and test sets. Is there a dedicated sampler that does this?

Here is the backbone of what I am trying to run:

import pytorch_lightning as pl
import torch

from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run  # you can ignore and remove this


# prepare the data
data = Cars3dData()
size = 64
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
dataset_train = DisentDataset(data, transform=ToImgTensorF32(size=64, mean=vis_mean, std=vis_std))
# dataset_val = ?
# dataset_test = ?

dataloader_train = DataLoader(
    dataset=dataset_train,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6),
        decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
    ),
    cfg=Ae.cfg(
        optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum"
    ),
)

# train the model
trainer = pl.Trainer(
    max_steps=10,
    checkpoint_callback=False,
    fast_dev_run=is_test_run(),
    gpus=1 if torch.cuda.is_available() else None,
)
trainer.fit(module, dataloader_train)

# compute disentanglement metrics
# - we cannot guarantee which device the representation is on
# - this will take a while to run
get_repr = lambda x: module.encode(x.to(module.device))

metrics = {
    **metric_dci(
        dataset_train, get_repr, num_train=1000, num_test=500, show_progress=True
    ),
    **metric_mig(dataset_train, get_repr, num_train=2000),
}

# evaluate
print("metrics:", metrics)

Any hints are highly appreciated. Thank you for providing this package!

Best regards
Armin

@akekic akekic added the question Further information is requested label Feb 1, 2022
@nmichlo
Owner

nmichlo commented Feb 5, 2022

Hi there, sorry for the delayed response!

Unfortunately this is something that I will need to add to the roadmap. I am just not entirely sure myself how to approach this problem when it comes to samplers/metrics that require ground-truth datasets.

  • As soon as you randomly split the dataset into train/test/validation portions, it is no longer a ground-truth dataset because of the way the factors are stored, and some of the samplers may no longer function, since they assume all the ground-truth factors are still available. This is also where the metrics fail: they require full access to the original datasets.

  • There may be a future workaround if new samplers are created that can handle this problem, although the flexibility of these samplers might be reduced, especially for weakly or strongly supervised methods. However, I am still not sure how this problem would be solved for the metrics themselves.

  • I realise the frameworks do not yet implement the PyTorch Lightning validation step. I'll add this to the roadmap.
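
To make the first point above concrete: a ground-truth dataset is typically stored as a full grid over its factors, so every flat index decodes to a factor tuple and every factor tuple has a valid index. The sketch below uses a toy grid; the `FACTOR_SIZES`, the helper names and the 60% keep ratio are illustrative assumptions, not disent's API. It counts how many "change exactly one factor" partners, the kind a pair sampler would ask for, have disappeared after a random split. On the full grid the count is zero; on any proper random subset it is positive.

```python
import math
import random

# Toy factor grid (an assumption for illustration): 4 x 3 x 5 = 60 observations.
FACTOR_SIZES = (4, 3, 5)


def pos_to_factors(pos, sizes=FACTOR_SIZES):
    """Decode a flat index into a factor tuple (row-major, last factor fastest)."""
    factors = []
    for size in reversed(sizes):
        pos, f = divmod(pos, size)
        factors.append(f)
    return tuple(reversed(factors))


def factors_to_pos(factors, sizes=FACTOR_SIZES):
    """Encode a factor tuple back into its flat index."""
    pos = 0
    for f, size in zip(factors, sizes):
        pos = pos * size + f
    return pos


def missing_partners(kept, sizes=FACTOR_SIZES):
    """Count (item, partner) pairs where the partner differs from the item in
    exactly one factor but is not among the kept indices."""
    kept = list(kept)
    kept_set = set(kept)
    missing = 0
    for pos in kept:
        factors = pos_to_factors(pos, sizes)
        for i, size in enumerate(sizes):
            for value in range(size):
                if value == factors[i]:
                    continue
                partner = factors[:i] + (value,) + factors[i + 1:]
                if factors_to_pos(partner, sizes) not in kept_set:
                    missing += 1
    return missing


num_items = math.prod(FACTOR_SIZES)
indices = list(range(num_items))
random.Random(0).shuffle(indices)
train_indices = indices[: num_items * 6 // 10]  # a random 60% "train" split

print("full grid missing partners:", missing_partners(range(num_items)))
print("train split missing partners:", missing_partners(train_indices))
```

Samplers that rely on these partners, and metrics that enumerate factor combinations, therefore cannot assume the grid is complete once the data has been split.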

A workaround for your current code may be:

import math

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import Cars3dData
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.ae import Ae
from disent.metrics import metric_dci
from disent.metrics import metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64


# normalise the data
vis_mean = [0.8976676149976628, 0.8891658020067508, 0.885147515814868]
vis_std = [0.22503195531503034, 0.2399461278981261, 0.24792106319684404]
data = Cars3dData(transform=ToImgTensorF32(size=64, mean=vis_mean, std=vis_std))

# SOLUTION:
# -- split the data using built-in functions (no longer ground-truth datasets, but subsets)
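# -- the lengths must sum exactly to len(data), so the test split takes the remainder
num_train = int(math.floor(len(data) * 0.6))
num_val = int(math.floor(len(data) * 0.2))
num_test = len(data) - num_train - num_val
data_train, data_val, data_test = torch.utils.data.random_split(data, [num_train, num_val, num_test])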
# -- create multiple disent datasets
dataset_train = DisentDataset(data_train)
dataset_val   = DisentDataset(data_val)
dataset_test  = DisentDataset(data_test)
# -- create dataloaders
dataloader_train = DataLoader(dataset=dataset_train, batch_size=4, shuffle=True, num_workers=0)
# -- no need to shuffle the evaluation splits
dataloader_val   = DataLoader(dataset=dataset_val, batch_size=4, shuffle=False, num_workers=0)
dataloader_test  = DataLoader(dataset=dataset_test, batch_size=4, shuffle=False, num_workers=0)

# create the pytorch lightning system
module: pl.LightningModule = Ae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=(3, 64, 64), z_size=6),
        decoder=DecoderConv64(x_shape=(3, 64, 64), z_size=6),
    ),
    cfg=Ae.cfg(
        optimizer="adam", optimizer_kwargs=dict(lr=1e-3)
    ),
)

# PROBLEM: unfortunately the framework does not yet implement the pytorch-lightning validation step
#          I'll add this to the roadmap and this should work in future.
trainer = pl.Trainer(max_steps=10000, checkpoint_callback=False, gpus=1 if torch.cuda.is_available() else None)
trainer.fit(module, dataloader_train, dataloader_val)

# PROBLEM: unfortunately the metrics will no longer work with the subsets
#          of data. You could instead pass the original full dataset to the
#          metrics, but this may be considered an information leak?
#          -- This will crash!
get_repr = lambda x: module.encode(x.to(module.device))
metrics = {
    **metric_dci(dataset_test, get_repr, num_train=1000, num_test=500, show_progress=True),
    **metric_mig(dataset_test, get_repr, num_train=2000),
}
print("metrics:", metrics)
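
When choosing split sizes by hand, `torch.utils.data.random_split` expects lengths that sum exactly to `len(dataset)`, and rounding each fraction independently can over- or under-count by one for some dataset sizes. A minimal, stdlib-only sketch of a seeded split is below; `split_indices` is a hypothetical helper, not part of disent or PyTorch. It gives each part `floor(fraction * n)` items and assigns the leftover items to the first (training) part, so the sizes always add up.

```python
import random


def split_indices(n, fractions, seed=0):
    """Hypothetical helper: shuffle range(n) with a fixed seed and cut it into
    len(fractions) parts. Leftover items from rounding go to the first part."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    sizes = [int(f * n) for f in fractions]  # floor for non-negative values
    sizes[0] += n - sum(sizes)  # give the rounding remainder to the first split
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same split every run
    parts, start = [], 0
    for size in sizes:
        parts.append(indices[start:start + size])
        start += size
    return parts


train_idx, val_idx, test_idx = split_indices(11, [0.6, 0.2, 0.2], seed=42)
print(len(train_idx), len(val_idx), len(test_idx))
```

The resulting index lists can be wrapped with `torch.utils.data.Subset(data, train_idx)` and so on before being passed to `DisentDataset`. Recent PyTorch versions (1.13 and later) should also accept fractions directly in `random_split`, together with `generator=torch.Generator().manual_seed(...)` for reproducibility.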

@nmichlo nmichlo added documentation Improvements or additions to documentation enhancement New feature or request refactoring Internal changes to the code that do not change the outputs. labels Feb 5, 2022
@nmichlo nmichlo added this to the v0.3.0 milestone Feb 5, 2022
@nmichlo nmichlo self-assigned this Feb 5, 2022
@nmichlo nmichlo modified the milestones: roadmap, v0.4.0 Feb 5, 2022
@nmichlo nmichlo added the proposal Discussion on a new proposed feature label Feb 5, 2022
@nmichlo nmichlo modified the milestones: v0.4.0, v0.3.4 Feb 6, 2022
@nmichlo
Owner

nmichlo commented Feb 6, 2022

This has been fixed in 5695747 (release v0.3.4).

Frameworks now support basic validation and testing by reusing the code from the training step; however, schedules might break if these are used.

A new example has been added to the docs: https://github.com/nmichlo/disent/blob/5695747c1e94420c024f1505d9b8a4b3c81ad610/docs/examples/overview_framework_train_val.py

@ema-marconato

It does not seem possible to use the train/val/test partition for AdaVAE training. Is there any way around this?

@nmichlo
Owner

nmichlo commented Jul 11, 2023

@ema-marconato The original paper describes different sampling strategies that can be used in different cases.

Unfortunately, only the fully random sampling strategies work with the training and validation splits:

from disent.dataset.sampling import RandomSampler

The other strategies need more information and use the ground-truth factor information to enforce certain characteristics:

from disent.dataset.sampling import GroundTruthPairOrigSampler
from disent.dataset.sampling import GroundTruthPairSampler

It is possible that a random sampler could be written that tries to enforce the constraints provided by these ground-truth samplers. Unfortunately these are not implemented.
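
One possible direction for such a sampler, sketched under assumptions: if each item in a split keeps its ground-truth factor tuple (for example decoded from its original index before splitting), a sampler can search only within the split for a partner that differs in exactly `k` factors, and fall back to a random partner when no such item survived the split. `SubsetPairSampler` and its behaviour are hypothetical and not part of disent, and this does not address the metrics problem.

```python
import random


class SubsetPairSampler:
    """Hypothetical pair sampler restricted to the items of one split.

    `factors` holds one ground-truth factor tuple per item in the split. For an
    anchor item it looks for a partner differing in exactly `k` factors and
    falls back to a uniformly random different item if none exists.
    """

    def __init__(self, factors, seed=0):
        self.factors = [tuple(f) for f in factors]
        self.rng = random.Random(seed)

    @staticmethod
    def _num_diff(a, b):
        # number of factors in which the two tuples disagree
        return sum(x != y for x, y in zip(a, b))

    def sample_pair(self, anchor, k=1):
        target = self.factors[anchor]
        candidates = [
            i for i, f in enumerate(self.factors)
            if i != anchor and self._num_diff(target, f) == k
        ]
        if not candidates:
            # no constrained partner survived the split: fall back to random
            candidates = [i for i in range(len(self.factors)) if i != anchor]
        return anchor, self.rng.choice(candidates)


# toy split: factor tuples of the items that ended up in the training set
subset_factors = [(0, 0), (0, 1), (1, 1), (2, 0)]
sampler = SubsetPairSampler(subset_factors, seed=0)
print(sampler.sample_pair(0, k=1))
```

The linear scan keeps the sketch short; a real implementation would more likely index the split by factor values so lookups do not grow with the dataset size.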


3 participants