About metrics #48

Closed
Rainbow-Six66 opened this issue Apr 19, 2024 · 3 comments
Labels
question Further information is requested

Comments

@Rainbow-Six66

Hi, great package!

How should I interpret the num_train and batch_size arguments of metric_dci and metric_mig? Also, are there any examples of using the factor-VAE metric?
Thank you!

Rainbow-Six66 added the question label on Apr 19, 2024
@Rainbow-Six66 (Author)

This is my code:

import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import DSpritesData
from disent.dataset.transform import ToImgTensorF32
from disent.metrics import metric_dci, metric_mig, metric_factor_vae

from model.β_vae import BetaVAE_H


def train():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data = DSpritesData()
    dataset = DisentDataset(data, transform=ToImgTensorF32(), augment=None)
    dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

    checkpoint = torch.load('./checkpoints/beta_vae50.pt', map_location=device)
    model = BetaVAE_H().to(device)
    model.load_state_dict(checkpoint['model'])
    model.eval()

    # we cannot guarantee which device the representation is on
    get_repr = lambda x: model.mlp_encoder(x.to(device))

    # evaluate
    return {
        **metric_dci(dataset, get_repr, num_train=20, boost_mode='sklearn'),
        **metric_mig(dataset, get_repr, num_train=20),
        **metric_factor_vae(dataset, get_repr, num_train=20),
    }


a_results = train()
print('beta=4: ', a_results)

@nmichlo (Owner) commented Apr 25, 2024

Hi there, and thank you!

Unfortunately the docs for this are sparse. I understand this is not ideal, and I would gladly accept PRs to fix it.

For context, the MIG, DCI and factor-VAE scores are largely based on those from https://github.com/google-research/disentanglement_lib (the default values should be similar). From what I remember, without looking at the code, num_train and batch_size control how many samples of the underlying data are used to compute the metrics. Too little data and the metrics will be inaccurate; too much and the processing time becomes excessive. During training I would often lower these values, and then do a final, larger computation at the end with the default values.
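
A rough sketch of that workflow, assuming dataset and get_repr are set up as in the snippet above (the num_train values here are purely illustrative, not recommendations; check the metric signatures in disent for the real defaults):

# during training: cheap, low-sample estimates just for monitoring
quick_scores = {
    **metric_mig(dataset, get_repr, num_train=500),
    **metric_dci(dataset, get_repr, num_train=500, boost_mode='sklearn'),
}

# at the end: leave num_train / batch_size at their defaults for an accurate final score
final_scores = {
    **metric_mig(dataset, get_repr),
    **metric_dci(dataset, get_repr, boost_mode='sklearn'),
    **metric_factor_vae(dataset, get_repr),
}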

@nmichlo (Owner) commented Apr 25, 2024

hydra config for the experiment metrics:
https://github.com/nmichlo/disent/tree/8f061a87076adeae8d6e5b0fa984b660cd40e026/experiment/config/metrics

actual code that selects these:

disent/experiment/run.py (lines 208 to 209 at 8f061a8):

train_metric = [R.METRICS[name].compute_fast] if settings.get("on_train", default_on_train) else None
final_metric = [R.METRICS[name].compute] if settings.get("on_final", default_on_final) else None

metric wrapper:

  • see compute and compute_fast

class Metric(Generic[T]):
    def __init__(
        self,
        name: str,
        metric_fn: T,  # Callable[[...], Dict[str, Number]]
        default_kwargs: Optional[Dict[str, Any]] = None,
        fast_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self._name = name
        self._orig_fn = metric_fn
        self._metric_fn_default = wrapped_partial(self._orig_fn, **(default_kwargs if default_kwargs else {}))
        self._metric_fn_fast = wrapped_partial(self._orig_fn, **(fast_kwargs if fast_kwargs else {}))

    # How do we get a type hint for `__call__` so that its signature matches `T`?
    def __call__(self, *args, **kwargs) -> Dict[str, Number]:
        return self._metric_fn_default(*args, **kwargs)

    @property
    def compute(self) -> T:
        return self._metric_fn_default

    @property
    def compute_fast(self) -> T:
        return self._metric_fn_fast

    @property
    def unwrap(self) -> T:
        return self._orig_fn

    @property
    def name(self) -> str:
        return self._name

    def __str__(self):
        return f"metric-{self.name}"


def make_metric(
    name: str,
    default_kwargs: Optional[Dict[str, Any]] = None,
    fast_kwargs: Optional[Dict[str, Any]] = None,
) -> Callable[[T], Union[Metric[T], T]]:
    """
    Metrics should be decorated using this function to set defaults!
    Two versions of the metric should exist.
    1. Recommended settings
       - This should give reliable results, but may be very slow: multiple minutes to half an
         hour or more for some metrics, depending on the underlying model, data and ground-truth factors.
    2. Faster settings
       - This should give decent results, but should be decently fast: a few seconds/minutes at most.
         This is not used for testing.
    """
    # `Union[Metric[T], T]` is a hack to get a type hint on `__call__`
    def _wrap_fn_as_metric(metric_fn: T) -> Union[Metric[T], T]:
        return Metric(name=name, metric_fn=metric_fn, default_kwargs=default_kwargs, fast_kwargs=fast_kwargs)
    return _wrap_fn_as_metric

fast version kwargs:

NOTE: kwargs for fast versions were arbitrarily chosen. The standard versions should follow kwargs from disentanglement_lib.

NOTE: batch_size is like the batch size of a dataset loader; the model is often called inside these metrics and is run on the GPU if possible.
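
Putting the wrapper together, a wrapped metric can be called in three ways. A minimal sketch, assuming the functions imported from disent.metrics are wrapped with make_metric as above, and that dataset / get_repr are set up as in the first snippet:

from disent.metrics import metric_mig

# calling the wrapper directly uses the default (recommended) kwargs
scores = metric_mig(dataset, get_repr)

# the same thing, explicitly, via the `compute` property
scores = metric_mig.compute(dataset, get_repr)

# faster settings, what the experiment runner selects when `on_train` is enabled
fast_scores = metric_mig.compute_fast(dataset, get_repr)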

nmichlo closed this as completed May 22, 2024