In [20]:
import lightning as L
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData, Shapes3dData, SpritesData
from disent.dataset.transform import ToImgTensorF32
from disent.frameworks.vae import BetaVae
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64
from disent.model.ae import EncoderConv64
from disent.util import is_test_run  # you can ignore and remove this

# prepare the data
data = Shapes3dData(prepare=False)
data = SpritesData(prepare=True)

dataset = DisentDataset(data, transform=ToImgTensorF32())
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True, num_workers=10)

# create the pytorch lightning system
module: L.LightningModule = BetaVae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=data.x_shape, z_size=8, z_multiplier=2),
        decoder=DecoderConv64(x_shape=data.x_shape, z_size=8),
    ),
    cfg=BetaVae.cfg(optimizer="adam", optimizer_kwargs=dict(lr=1e-3), loss_reduction="mean_sum", beta=4),
)



In [26]:
# train the model
trainer = L.Trainer(logger=True, fast_dev_run=is_test_run(), max_epochs=3, )
                    # enable_checkpointing=True, default_root_dir='.data/checkpoints')
trainer.fit(module, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/3
[W socket.cpp:426] [c10d] The server socket has failed to bind to [::]:55326 (errno: 98 - Address already in use).
  rank_zero_warn(
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/3
  rank_zero_warn(
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/3
[W socket.cpp:426] [c10d] The server socket has failed to bind to 0.0.0.0:55326 (errno: 98 - Address already in use).
[E socket.cpp:462] [c10d] The server socket has failed to listen on any local network address.


ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 147, in _wrapping_function
    results = function(*args, **kwargs)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 570, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 933, in _run
    self.strategy.setup_environment()
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/lightning/pytorch/strategies/ddp.py", line 143, in setup_environment
    self.setup_distributed()
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/lightning/pytorch/strategies/ddp.py", line 192, in setup_distributed
    _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/lightning/fabric/utilities/distributed.py", line 246, in _init_dist_connection
    torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 900, in init_process_group
    store, rank, world_size = next(rendezvous_iterator)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 245, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
  File "/mnt/data/emanuele.marconato/miniconda3/envs/disent-py38/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 176, in _create_c10d_store
    return TCPStore(
RuntimeError: The server socket has failed to listen on any local network address. The server socket has failed to bind to [::]:55326 (errno: 98 - Address already in use). The server socket has failed to bind to 0.0.0.0:55326 (errno: 98 - Address already in use).


In [4]:
from disent.metrics import metric_dci, metric_mig
from disent.util import is_test_run

get_repr = lambda x: module.encode(x.to(module.device))

# evaluate
{
    **metric_dci(
        dataset,
        get_repr,
        num_train=10 if is_test_run() else 1000,
        num_test=5 if is_test_run() else 500,
        boost_mode="sklearn",
    ),
    **metric_mig(dataset, get_repr, num_train=20 if is_test_run() else 2000),
}


{'dci.informativeness_train': 0.96025,
 'dci.informativeness_test': 0.24649999999999997,
 'dci.disentanglement': 0.02376654453393539,
 'dci.completeness': 0.02985560412256319,
 'mig.discrete_score': 0.035976604417420434}