Merge pull request #105 from mwalmsley/multiclass
Add multiclass example
mwalmsley committed Aug 1, 2023
2 parents 5e76bd6 + d4f466a commit 8e39a9a
Showing 11 changed files with 153 additions and 65 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -163,4 +163,6 @@ data/cosmic_dawn*.parquet

results

hparams.yaml
hparams.yaml

data/pretrained_models
15 changes: 12 additions & 3 deletions README.md
@@ -52,6 +52,8 @@ I share my install steps [here](#install_cuda). GPUs are optional - Zoobot will
## Quickstart
<a name="quickstart"></a>

The [Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing) is the quickest way to get started. Alternatively, the minimal example below illustrates how Zoobot works.

Let's say you want to find ringed galaxies and you have a small labelled dataset of 500 ringed or not-ringed galaxies. You can retrain Zoobot to find rings like so:

```python
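# NOTE: the quickstart snippet is collapsed in this diff view.
# As a hedged sketch, the code below is reconstructed from the finetuning
# example added in this PR; the actual README snippet may differ slightly,
# and the paths here are placeholders.
from zoobot.pytorch.training import finetune
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule

# your labelled galaxies: a dataframe with 'id_str', 'file_loc' and a binary 'ring' column
train_catalog = ...

datamodule = GalaxyDataModule(label_cols=['ring'], catalog=train_catalog, batch_size=32)

model = finetune.FinetuneableZoobotClassifier(
    checkpoint_loc='path/to/pretrained_checkpoint.ckpt',  # placeholder path
    num_classes=2,  # ring or not-ring
    n_layers=0  # retrain only the new head; set e.g. 1, 2 to finetune deeper
)
trainer = finetune.get_trainer('your/save/dir', max_epochs=100)
trainer.fit(model, datamodule)
```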
@@ -98,7 +100,7 @@ Zoobot includes many guides and working examples - see the [Getting Started](#ge
## Getting Started
<a name="getting_started"></a>

I suggest starting with the worked examples below, which you can copy and adapt.
I suggest starting with the [Colab notebook](https://colab.research.google.com/drive/17bb_KbA2J6yrIm4p4Ue_lEBHMNC1I9Jd?usp=sharing) or the worked examples below, which you can copy and adapt.

For context and explanation, see the [documentation](https://zoobot.readthedocs.io/).

@@ -147,7 +149,7 @@ CUDA 11.2 and CUDNN 8.1 for TensorFlow 2.10.0:
conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/ # add this environment variable

### Latest features (v1.0.0)
### Latest minor features (v1.0.4)

- Now supports multi-class finetuning. See the sketch below, or `pytorch/examples/finetuning/finetune_multiclass_classification.py` for a full example.
- Removed `simplejpeg` dependency due to M1 install issue.
- Pinned `timm` version to ensure MaX-ViT models load correctly. Models supporting the latest `timm` will follow.
- (internal until published) GZ Evo v2 now includes Cosmic Dawn (HSC). Significant performance improvement on HSC finetuning.
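
In practice, multi-class finetuning only needs the new `num_classes` argument. A minimal sketch (placeholder checkpoint path; see the example script linked above for a full run):

```python
from zoobot.pytorch.training import finetune

model = finetune.FinetuneableZoobotClassifier(
    checkpoint_loc='path/to/pretrained_checkpoint.ckpt',  # placeholder
    num_classes=3  # any number of classes, not just binary
)
```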

### Latest major features (v1.0.0)

v1.0.0 recognises that most of the complexity in this repo comes from training Zoobot from scratch, while most non-GZ users will probably simply want to load the pretrained Zoobot and finetune it on their data.

@@ -193,4 +202,4 @@ You might be interested in reading papers using Zoobot:
- [Towards Foundation Models for Galaxy Morphology](https://arxiv.org/abs/2206.11927) (adding contrastive learning)
- [Harnessing the Hubble Space Telescope Archives: A Catalogue of 21,926 Interacting Galaxies](https://arxiv.org/abs/2303.00366)

Many other works use Zoobot indirectly via the [Galaxy Zoo DECaLS](https://arxiv.org/abs/2102.08414) catalog.
Many other works use Zoobot indirectly via the [Galaxy Zoo DECaLS](https://arxiv.org/abs/2102.08414) catalog.
4 changes: 2 additions & 2 deletions docs/data_notes.rst
@@ -107,7 +107,7 @@ We also include a few additional ad-hoc models `on Dropbox <https://www.dropbox.
Which model should I use?
--------------------------

We suggest the PyTorch EfficientNetB0 single-channel 300-pixel model for most users.
We suggest the PyTorch EfficientNetB0 single-channel 224-pixel model for most users.

Zoobot will prioritise PyTorch going forward. For more, see here.
The TensorFlow models currently perform just as well as the PyTorch equivalents but will not benefit from any future updates.
@@ -119,7 +119,7 @@ Color information does not improve overall performance at predicting GZ votes.
This is a little surprising, but we're confident it's true for our datasets (see the benchmarks folder for our tests).
However, it might be useful to include for other tasks where color is critical, such as hunting certain anomalous galaxies.

Larger input images (300px vs 224px) provide a small boost in performance at predicting GZ votes.
Larger input images (300px vs 224px) provide a very small boost in performance at predicting GZ votes, on our benchmarks.
However, the models require more memory and train/finetune slightly more slowly.
You may want to start with a 224px model and experiment with "upgrading" once you're happy everything works.
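
For example, finetuning from the suggested 224px model might look like the sketch below (using the finetuning API in this repo; the local checkpoint path is an assumption):

.. code-block:: python

    from zoobot.pytorch.training import finetune

    model = finetune.FinetuneableZoobotClassifier(
        checkpoint_loc='pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt',  # assumed download location
        num_classes=2
    )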

1 change: 0 additions & 1 deletion docs/requirements.txt
@@ -9,7 +9,6 @@ torch == 1.10.1
torchvision == 0.11.2
torchaudio == 0.10.1
pytorch-lightning==1.6.5 # 1.7 requires protobuf version incompatible with tensorflow/tensorboard. Otherwise works.
simplejpeg
albumentations
pyro-ppl == 1.8.0
pytorch-galaxy-datasets == 0.0.1
28 changes: 21 additions & 7 deletions paper/paper.bib
@@ -42,6 +42,7 @@ @article{Walmsley2022decals
volume = {509},
url = {https://arxiv.org/abs/2102.08414},
year = {2022},
doi = {10.1093/mnras/stab2093}
}

@article{2011arXiv1110.3193L,
@@ -111,6 +112,7 @@ @article{HuertasCompany2022
title = {The DAWES review 10: The impact of deep learning for the analysis of galaxy surveys},
url = {http://arxiv.org/abs/2210.01813},
year = {2022},
doi = {10.1017/pasa.2022.55}
}

@article{LeCun2015,
@@ -157,6 +159,7 @@ @misc{Bommasani2021
title = {On the Opportunities and Risks of Foundation Models},
url = {http://arxiv.org/abs/2108.07258},
year = {2021},
doi = {10.48550/arXiv.2108.07258}
}

@misc{https://doi.org/10.48550/arxiv.2104.10972,
@@ -180,19 +183,30 @@ @misc{rw2019timm
howpublished = {\url{https://github.com/rwightman/pytorch-image-models}}
}

@article{bingham2018pyro,
author = {Bingham, Eli and Chen, Jonathan P. and Jankowiak, Martin and Obermeyer, Fritz and
Pradhan, Neeraj and Karaletsos, Theofanis and Singh, Rohit and Szerlip, Paul and
Horsfall, Paul and Goodman, Noah D.},
title = {{Pyro: Deep Universal Probabilistic Programming}},
journal = {Journal of Machine Learning Research},
year = {2018}
@article{bingham2019pyro,
author = {Eli Bingham and
Jonathan P. Chen and
Martin Jankowiak and
Fritz Obermeyer and
Neeraj Pradhan and
Theofanis Karaletsos and
Rohit Singh and
Paul A. Szerlip and
Paul Horsfall and
Noah D. Goodman},
title = {Pyro: Deep Universal Probabilistic Programming},
journal = {J. Mach. Learn. Res.},
volume = {20},
pages = {28:1--28:6},
year = {2019},
url = {http://jmlr.org/papers/v20/18-403.html}
}

@article{phan2019composable,
author = {Phan, Du and Pradhan, Neeraj and Jankowiak, Martin},
title = {Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro},
journal = {arXiv preprint arXiv:1912.11554},
doi = {10.48550/arXiv.1912.11554},
year = {2019}
}

2 changes: 1 addition & 1 deletion paper/paper.md
@@ -71,7 +71,7 @@ authors:
- name: Crisel Suárez
orcid: 0000-0001-5243-7659
corresponding: false
affiliation: ["11, 12"]
affiliation: "11, 12"
- name: Nicolás Guerra-Varas
orcid: 0000-0002-9718-6352
corresponding: false
37 changes: 0 additions & 37 deletions requirements.txt

This file was deleted.

17 changes: 7 additions & 10 deletions setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="zoobot",
version="1.0.2",
version="1.0.4",
author="Mike Walmsley",
author_email="walmsleymk1@gmail.com",
description="Galaxy morphology classifiers",
@@ -29,11 +29,11 @@
'torchvision == 0.13.1+cpu',
'torchaudio == 0.12.1',
'pytorch-lightning >= 2.0.0',
'simplejpeg',
# 'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
'timm == 0.6.12'
],
'pytorch_m1': [
# as above but without the +cpu (and the extra-index-url in readme has no effect)
@@ -42,11 +42,10 @@
'torchvision == 0.13.1',
'torchaudio == 0.12.1',
'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
'timm == 0.6.12'
],
# as above but without pytorch itself
# for GPU, you will also need e.g. cudatoolkit=11.3, 11.6
@@ -56,19 +55,17 @@
'torchvision == 0.13.1+cu113',
'torchaudio == 0.12.1',
'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl == 1.8.0',
'torchmetrics == 0.11.0',
'timm'
'timm == 0.6.12'
],
'pytorch_colab': [
'pytorch-lightning >= 2.0.0',
'simplejpeg',
'albumentations',
'pyro-ppl>=1.8.0',
'torchmetrics==0.11.0',
'timm'
'timm == 0.6.12'
],
'tensorflow': [
'tensorflow == 2.10.0', # 2.11.0 turns on XLA somewhere which then fails on multi-GPU...TODO
@@ -105,6 +102,6 @@
# for saving metrics to weights&biases (cloud service, free within limits)
'wandb',
'setuptools==59.5.0', # wandb logger incompatibility
'galaxy-datasets==0.0.12' # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
'galaxy-datasets==0.0.14' # for dataset loading in both TF and Torch (renamed from pytorch-galaxy-datasets)
]
)
94 changes: 94 additions & 0 deletions zoobot/pytorch/examples/finetuning/finetune_multiclass_classification.py
@@ -0,0 +1,94 @@
import logging
import os

from zoobot.pytorch.training import finetune
from galaxy_datasets import demo_rings
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule


if __name__ == '__main__':

    logging.basicConfig(level=logging.INFO)

    zoobot_dir = '/Users/user/repos/zoobot'  # TODO set to directory where you cloned Zoobot

    # load in catalogs of images and labels to finetune on
    # each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels
    # here I'm using galaxy-datasets to download some premade data - check it out for examples
    data_dir = '/Users/user/repos/galaxy-datasets/roots/demo_rings'  # TODO set to any directory. rings dataset will be downloaded here
    train_catalog, _ = demo_rings(root=data_dir, download=True, train=True)
    test_catalog, _ = demo_rings(root=data_dir, download=True, train=False)

    # wondering about "label_cols"?
    # This is a list of catalog columns which should be used as labels
    # To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.
    # TODO should use Galaxy MNIST as my example here
    label_cols = ['ring']

    # For multi-class classification, the label column should hold an integer class index (0, 1, ..., num_classes - 1) for each galaxy
    # The demo rings dataset only has binary (0 or 1) labels, so as a quick hack for this demo, overwrite them with random labels for three classes
    import numpy as np
    train_catalog['ring'] = np.random.randint(low=0, high=3, size=len(train_catalog))

    # load a pretrained checkpoint saved here
    checkpoint_loc = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch/effnetb0_greyscale_224px.ckpt')
    # checkpoint_loc = '/Users/user/repos/gz-decals-classifiers/results/benchmarks/pytorch/dr5/dr5_py_gr_15366/checkpoints/epoch=58-step=18939.ckpt'

    # save the finetuning results here
    save_dir = os.path.join(zoobot_dir, 'results/pytorch/finetune/finetune_multiclass_classification')

    datamodule = GalaxyDataModule(
        label_cols=label_cols,
        catalog=train_catalog,  # very small, as a demo
        batch_size=32
    )
    # datamodule.setup()
    # for images, labels in datamodule.train_dataloader():
    #     print(images.shape)
    #     print(labels.shape)
    #     exit()


    model = finetune.FinetuneableZoobotClassifier(
        checkpoint_loc=checkpoint_loc,
        num_classes=3,
        n_layers=0  # only updating the head weights. Set e.g. 1, 2 to finetune deeper.
    )
    # under the hood, this does:
    # encoder = finetune.load_pretrained_encoder(checkpoint_loc)
    # model = finetune.FinetuneableZoobotClassifier(encoder=encoder, ...)

    # retrain the head to classify our three (here, randomly-assigned) classes
    trainer = finetune.get_trainer(save_dir, accelerator='cpu', max_epochs=1)
    trainer.fit(model, datamodule)
    # can now use this model or saved checkpoint to make predictions on new data. Well done!

    # pretending we want to reload the finetuned model from its saved checkpoint (e.g. in a new script):
    best_checkpoint = trainer.checkpoint_callback.best_model_path
    finetuned_model = finetune.FinetuneableZoobotClassifier.load_from_checkpoint(best_checkpoint)

    from zoobot.pytorch.predictions import predict_on_catalog

    predict_on_catalog.predict(
        test_catalog,
        finetuned_model,
        n_samples=1,
        label_cols=label_cols,
        save_loc=os.path.join(save_dir, 'finetuned_predictions.csv')
        # trainer_kwargs={'accelerator': 'gpu'}
    )
    """
    Under the hood, this is essentially doing:

    import pytorch_lightning as pl
    predict_trainer = pl.Trainer(devices=1, max_epochs=-1)
    predict_datamodule = GalaxyDataModule(
        label_cols=None,  # important, else you will get "conv2d() received an invalid combination of arguments"
        predict_catalog=test_catalog,
        batch_size=32
    )
    preds = predict_trainer.predict(finetuned_model, predict_datamodule)
    print(preds)
    """
13 changes: 10 additions & 3 deletions zoobot/pytorch/training/finetune.py
@@ -269,9 +269,16 @@ def __init__(
        self.loss = partial(cross_entropy_loss,
                            weight=class_weights,
                            label_smoothing=self.label_smoothing)
        self.train_acc = tm.Accuracy(task='binary', average="micro")
        self.val_acc = tm.Accuracy(task='binary', average="micro")
        self.test_acc = tm.Accuracy(task='binary', average="micro")
        logging.info(f'num_classes: {num_classes}')
        if num_classes == 2:
            logging.info('Using binary classification')
            task = 'binary'
        else:
            logging.info('Using multi-class classification')
            task = 'multiclass'
        self.train_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
        self.val_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
        self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)

    def step_to_dict(self, y, y_pred, loss):
        y_class_preds = torch.argmax(y_pred, axis=1)
3 changes: 3 additions & 0 deletions zoobot/shared/schemas.py
@@ -268,6 +268,9 @@ def answers(self):
gz_candels_ortho_schema = Schema(label_metadata.candels_ortho_pairs, label_metadata.candels_ortho_dependencies)
gz_hubble_ortho_schema = Schema(label_metadata.hubble_ortho_pairs, label_metadata.hubble_ortho_dependencies)
cosmic_dawn_ortho_schema = Schema(label_metadata.cosmic_dawn_ortho_pairs , label_metadata.cosmic_dawn_ortho_dependencies)

# schemas without orthogonal question suffix (-cd, -dr8, etc)
cosmic_dawn_schema = Schema(label_metadata.cosmic_dawn_pairs , label_metadata.cosmic_dawn_dependencies)
gz_rings_schema = Schema(label_metadata.rings_pairs, label_metadata.rings_dependencies)
desi_schema = Schema(label_metadata.desi_pairs, label_metadata.desi_dependencies) # for DESI data release prediction users, not for ML training - no -dr5, -dr8, etc
# note that as this is a call to Schema (and Question and Answer), any logging within those will
