Commit

Merge branch 'dev' into main
nmichlo committed Feb 6, 2022
2 parents c4fb871 + 1817000 commit 5695747
Showing 5 changed files with 87 additions and 20 deletions.
50 changes: 35 additions & 15 deletions disent/frameworks/_framework.py
@@ -131,25 +131,33 @@ def configure_optimizers(self):
         # return the optimizer
         return optimizer_instance
 
+    @final
+    def _compute_loss_step(self, batch, batch_idx, update_schedules: bool):
+        # augment batch with GPU support
+        if self._batch_augment is not None:
+            batch = self._batch_augment(batch)
+        # update the config values based on registered schedules
+        if update_schedules:
+            # TODO: how do we handle this in the case of the validation and test step? I think this
+            #       might still give the wrong results as this is based on the trainer.global_step which
+            #       may be incremented by these steps.
+            self._update_config_from_schedules()
+        # compute loss
+        loss, logs_dict = self.do_training_step(batch, batch_idx)
+        # check returned values
+        assert 'loss' not in logs_dict
+        self._assert_valid_loss(loss)
+        # log returned values
+        logs_dict['loss'] = loss
+        self.log_dict(logs_dict)
+        # return loss
+        return loss
+
     @final
     def training_step(self, batch, batch_idx):
         """This is a pytorch-lightning function that should return the computed loss"""
         try:
-            # augment batch with GPU support
-            if self._batch_augment is not None:
-                batch = self._batch_augment(batch)
-            # update the config values based on registered schedules
-            self._update_config_from_schedules()
-            # compute loss
-            loss, logs_dict = self.do_training_step(batch, batch_idx)
-            # check returned values
-            assert 'loss' not in logs_dict
-            self._assert_valid_loss(loss)
-            # log returned values
-            logs_dict['loss'] = loss
-            self.log_dict(logs_dict)
-            # return loss
-            return loss
+            return self._compute_loss_step(batch, batch_idx, update_schedules=True)
         except Exception as e:  # pragma: no cover
             # call in all the child processes for the best chance of clearing this...
             # remove callbacks from trainer so we aren't stuck running forever!
@@ -160,6 +168,18 @@ def training_step(self, batch, batch_idx):
             # continue propagating errors
             raise e
 
+    def validation_step(self, batch, batch_idx):
+        """
+        TODO: how do we handle the schedule in this case?
+        """
+        return self._compute_loss_step(batch, batch_idx, update_schedules=False)
+
+    def test_step(self, batch, batch_idx):
+        """
+        TODO: how do we handle the schedule in this case?
+        """
+        return self._compute_loss_step(batch, batch_idx, update_schedules=False)
+
     @final
     def _assert_valid_loss(self, loss):
         if self.trainer.terminate_on_nan:
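For orientation: the refactored helper above assumes that each concrete framework implements `do_training_step`, returning a `(loss, logs_dict)` pair in which the key `'loss'` must not yet appear. A minimal sketch of that contract, assuming the base class exported by this module (here called `DisentFramework`); the subclass, its model attribute, and the batch layout are hypothetical placeholders, not disent's actual implementation:

import torch.nn.functional as F

class MinimalFramework(DisentFramework):  # hypothetical subclass, for illustration only
    def do_training_step(self, batch, batch_idx):
        x = batch['x']                                 # assumed batch layout
        recon = self._model(x)                         # hypothetical model attribute
        loss = F.mse_loss(recon, x, reduction='mean')
        # the returned dict must NOT contain the key 'loss' -- `_compute_loss_step` adds it
        return loss, {'recon_loss': loss.detach()}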
47 changes: 47 additions & 0 deletions docs/examples/overview_framework_train_val.py
@@ -0,0 +1,47 @@
import math
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.frameworks.vae import BetaVae
from disent.metrics import metric_dci, metric_mig
from disent.model import AutoEncoder
from disent.model.ae import DecoderConv64, EncoderConv64
from disent.dataset.transform import ToImgTensorF32
from disent.util import is_test_run # you can ignore and remove this

# make the ground-truth data
gt_data = XYObjectData()
# split the data using built-in functions (no longer ground-truth datasets, but subsets)
data_train, data_val = random_split(gt_data, [
    int(math.floor(len(gt_data)*0.7)),
    int(math.ceil(len(gt_data)*0.3)),
])
# create the disent datasets
gt_dataset = DisentDataset(gt_data, transform=ToImgTensorF32()) # .is_ground_truth == True
dataset_train = DisentDataset(data_train, transform=ToImgTensorF32()) # .is_ground_truth == False
dataset_val = DisentDataset(data_val, transform=ToImgTensorF32()) # .is_ground_truth == False
# create the data loaders
dataloader_train = DataLoader(dataset=dataset_train, batch_size=4, shuffle=True, num_workers=0)
dataloader_val = DataLoader(dataset=dataset_val, batch_size=4, shuffle=True, num_workers=0)

# create the pytorch lightning system
module: pl.LightningModule = BetaVae(
    model=AutoEncoder(
        encoder=EncoderConv64(x_shape=gt_data.x_shape, z_size=6, z_multiplier=2),
        decoder=DecoderConv64(x_shape=gt_data.x_shape, z_size=6),
    ),
    cfg=BetaVae.cfg(optimizer='adam', optimizer_kwargs=dict(lr=1e-3), loss_reduction='mean_sum', beta=4)
)

# train the model
trainer = pl.Trainer(logger=False, checkpoint_callback=False, fast_dev_run=is_test_run())
trainer.fit(module, dataloader_train, dataloader_val)

# compute metrics
# - we cannot guarantee which device the representation is on
get_repr = lambda x: module.encode(x.to(module.device))
# - we cannot compute disentanglement metrics over the split datasets `dataset_train` & `dataset_val`
#   because they are no longer ground-truth datasets; we can only use `gt_dataset`
print(metric_dci(gt_dataset, get_repr, num_train=10 if is_test_run() else 1000, num_test=5 if is_test_run() else 500))
print(metric_mig(gt_dataset, get_repr, num_train=20 if is_test_run() else 2000))
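Because `dataloader_val` is passed to `trainer.fit(...)`, this example exercises the new `validation_step` added in `_framework.py` above. If desired, the validation loop can also be run on its own with pytorch-lightning's standard API (an aside, not part of the committed example):

# run only the validation loop over the validation split
trainer.validate(module, dataloader_val)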
6 changes: 3 additions & 3 deletions experiment/config/config.yaml
@@ -13,10 +13,10 @@ defaults:
   - run_length: long
   # logs
   - run_callbacks: vis
-  - run_logging: wandb
+  - run_logging: none
   # runtime
-  - run_location: stampede_shr
-  - run_launcher: slurm
+  - run_location: local
+  - run_launcher: local
   - run_action: train
   # entries in this file override entries from default lists
   - _self_
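These defaults follow Hydra's composition scheme, so a single run can still restore the old behaviour via overrides instead of editing this file. A sketch using Hydra's compose API (an assumption: `hydra-core` is installed and `config_path` resolves relative to the calling file, per Hydra's rules):

from hydra import compose, initialize

# restore the previous wandb/slurm defaults for one composed config
with initialize(config_path='experiment/config'):
    cfg = compose(config_name='config', overrides=[
        'run_logging=wandb',
        'run_location=stampede_shr',
        'run_launcher=slurm',
    ])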
2 changes: 1 addition & 1 deletion experiment/config/run_location/local.yaml
@@ -2,7 +2,7 @@
 
 dsettings:
   trainer:
-    cuda: TRUE
+    cuda: NULL  # `NULL` tries to use CUDA if it is available. `TRUE` forces cuda to be used!
   storage:
     logs_dir: 'logs'
     data_root: '/tmp/${oc.env:USER}/datasets'
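The tri-state `cuda` flag lets the trainer pick the device at runtime. A sketch of how such a flag might be resolved (this mirrors the comment above and is an assumption, not necessarily disent's exact logic):

from typing import Optional
import torch

def resolve_cuda(cuda: Optional[bool]) -> str:
    # NULL/None -> use CUDA only if it is actually available
    if cuda is None:
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    # TRUE -> force CUDA, fail loudly if it is missing
    if cuda and not torch.cuda.is_available():
        raise RuntimeError('`cuda: TRUE` was set, but CUDA is not available!')
    return 'cuda' if cuda else 'cpu'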
2 changes: 1 addition & 1 deletion setup.py
@@ -48,7 +48,7 @@
     author="Nathan Juraj Michlo",
     author_email="NathanJMichlo@gmail.com",
 
-    version="0.3.3",
+    version="0.3.4",
     python_requires=">=3.8",  # we make use of standard library features only in 3.8
     packages=setuptools.find_packages(),
