<a href="https://colab.research.google.com/github/josvalen/practicetemp/blob/main/Zoobot_Finetune_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import logging

logging.basicConfig(level=logging.INFO)

This notebook demonstrates finetuning Zoobot.

Finetuning means adapting a model pretrained on a large amount of data (here, many Galaxy Zoo answers) to solve a new problem using a small amount of new data.

We follow these steps:
- Install Zoobot (PyTorch version)
- Download a pretrained checkpoint
- Download the data to finetune on (here, GZ CANDELS images and their catalog of volunteer vote counts)
- Configure and run the finetuning, starting from that checkpoint

For standalone script examples, see `zoobot/pytorch/examples/finetuning`.



---



*Retraining will be slow unless you use a GPU. In the top toolbar, choose Runtime -> Change runtime type -> GPU.*

---

## Install Zoobot

In [3]:
!git clone https://github.com/mwalmsley/zoobot.git # places the cloned repo into zoobot_dir

# there's an identical notebook I use for testing the pre-release versions of zoobot and galaxy-datasets here, if useful
# https://colab.research.google.com/drive/1A_-M3Sz5maQmyfW2A7rEu-g_Zi0RMGz5?usp=sharing

Cloning into 'zoobot'...
remote: Enumerating objects: 8328, done.
remote: Counting objects: 100% (2316/2316), done.
remote: Compressing objects: 100% (760/760), done.
remote: Total 8328 (delta 1601), reused 2255 (delta 1546), pack-reused 6012
Receiving objects: 100% (8328/8328), 334.16 MiB | 29.27 MiB/s, done.
Resolving deltas: 100% (5192/5192), done.


In [4]:
!pip install -e /content/zoobot[pytorch_colab]

Obtaining file:///content/zoobot
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Collecting wandb (from zoobot==1.0.3)
  Downloading wandb-0.15.7-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 11.4 MB/s eta 0:00:00
Collecting setuptools==59.5.0 (from zoobot==1.0.3)
  Downloading setuptools-59.5.0-py3-none-any.whl (952 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 952.4/952.4 kB 20.2 MB/s eta 0:00:00
Collecting galaxy-datasets==0.0.12 (from zoobot==1.0.3)
  Downloading galaxy_datasets-0.0.12-py3-none-any.whl (51 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 51.3/51.3 kB 6.5 MB/s eta 0:00:00
Collecting pyto



If you later get the error below:
> AttributeError: module 'pkg_resources._vendor.packaging' has no attribute 'requirements'

then restart the runtime and run the notebook again, using the button that appears just above.

In [5]:
# google colab needs this hack to make git-cloned packages importable
# it's not needed locally
import os
import sys
zoobot_dir = '/content/zoobot'
os.chdir(zoobot_dir)
sys.path.append(zoobot_dir)
!git pull  # just to make sure we're up to date

Already up to date.


Now we're set up and can start using Zoobot.

## Download Pretrained Checkpoint

In [6]:


# make a directory to place the checkpoint
# this could be anywhere, but Zoobot has this folder already
checkpoint_dir = os.path.join(zoobot_dir, 'data/pretrained_models/pytorch')
os.makedirs(checkpoint_dir, exist_ok=True)  # does nothing if the folder already exists

Download the pretrained model checkpoint from Dropbox.

The pretrained models are described and linked from the [Data Notes](https://zoobot.readthedocs.io/en/latest/data_notes.html) docs.

Outside Colab, you can just download them with a browser.
On Colab, the files need to end up on the Colab machine rather than your own computer, so we download them with this one-liner.

In [7]:
!wget --no-check-certificate 'https://dl.dropboxusercontent.com/s/7ixwo59imjfz4ay/effnetb0_greyscale_224px.ckpt?dl=0' -O $checkpoint_dir/checkpoint.ckpt

--2023-07-27 18:09:29--  https://dl.dropboxusercontent.com/s/7ixwo59imjfz4ay/effnetb0_greyscale_224px.ckpt?dl=0
Resolving dl.dropboxusercontent.com (dl.dropboxusercontent.com)... 162.125.1.15, 2620:100:6016:15::a27d:10f
Connecting to dl.dropboxusercontent.com (dl.dropboxusercontent.com)|162.125.1.15|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 17384753 (17M) [application/octet-stream]
Saving to: ‘/content/zoobot/data/pretrained_models/pytorch/checkpoint.ckpt’


2023-07-27 18:09:32 (54.1 MB/s) - ‘/content/zoobot/data/pretrained_models/pytorch/checkpoint.ckpt’ saved [17384753/17384753]
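If you prefer to stay in Python (for example, in a standalone script), the same download can be sketched with the standard library. This is a minimal sketch, not part of Zoobot; `download_if_missing` is a hypothetical helper name. Unlike `wget -O`, it skips the download when the file already exists, so re-running the cell does not fetch the 17 MB checkpoint again (but it also will not replace a partial or stale file).

```python
import os
import urllib.request


def download_if_missing(url: str, dest_path: str) -> str:
    """Download url to dest_path unless a file is already there; return dest_path."""
    dest_dir = os.path.dirname(dest_path)
    if dest_dir:
        # create the target folder if needed (like os.makedirs above)
        os.makedirs(dest_dir, exist_ok=True)
    if not os.path.isfile(dest_path):
        # urlretrieve streams the response body into dest_path
        urllib.request.urlretrieve(url, dest_path)
    return dest_path
```

In this notebook, the call would pass the Dropbox URL from the `wget` line and `os.path.join(checkpoint_dir, 'checkpoint.ckpt')` as the destination.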



## Download Catalogs of Images and Labels

Each catalog should be a dataframe with columns of "id_str", "file_loc", and any labels.
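For your own images, a catalog in this format can be built with pandas. A minimal sketch, assuming you have a list of image paths and one binary label per image; `make_catalog`, the example paths, and the `ring` column name are hypothetical, and taking `id_str` from the filename is just one convenient way to get a unique string per galaxy, not a Zoobot requirement.

```python
import os

import pandas as pd


def make_catalog(image_paths, labels, label_col='ring'):
    """Build a minimal finetuning catalog: one row per image, plus one label column."""
    return pd.DataFrame({
        # a unique string id per galaxy, here taken from the image filename
        'id_str': [os.path.splitext(os.path.basename(path))[0] for path in image_paths],
        # path to each image on disk
        'file_loc': list(image_paths),
        # the label(s) to finetune on
        label_col: list(labels),
    })


catalog = make_catalog(
    ['/content/images/galaxy_001.jpg', '/content/images/galaxy_002.jpg'],  # hypothetical paths
    labels=[1, 0],
)
print(catalog)
```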

Here I'm using galaxy-datasets to download some premade data - check it out for examples.

In [8]:
# galaxy-datasets is a dependency of Zoobot.
# It has code handling downloading and loading data.
from galaxy_datasets import demo_rings
from galaxy_datasets import gz_candels

In [9]:
data_dir = os.path.join(zoobot_dir, 'data/gz_candels')

In [11]:
train_and_val_catalog, _  = gz_candels(root=data_dir, download=True, train=True)
test_catalog, _ = gz_candels(root=data_dir, download=True, train=False)

Using downloaded and verified file: /content/zoobot/data/gz_candels/candels_ortho_train_catalog.parquet
Using downloaded and verified file: /content/zoobot/data/gz_candels/candels_ortho_test_catalog.parquet
Downloading https://dl.dropboxusercontent.com/s/d67we9xsn8vyr5k/candels_images.tar.gz to /content/zoobot/data/gz_candels/candels_images.tar.gz


100%|██████████| 6446950787/6446950787 [05:08<00:00, 20878551.50it/s]


Extracting /content/zoobot/data/gz_candels/candels_images.tar.gz to /content/zoobot/data/gz_candels


In [14]:
train_and_val_catalog.head()

| | smooth-or-featured-candels_smooth | smooth-or-featured-candels_features | smooth-or-featured-candels_artifact | how-rounded-candels_completely | how-rounded-candels_in-between | how-rounded-candels_cigar-shaped | clumpy-appearance-candels_yes | clumpy-appearance-candels_no | clump-count-candels_1 | clump-count-candels_2 | ... | spiral-arm-count-candels_cant-tell | bulge-size-candels_none | bulge-size-candels_obvious | bulge-size-candels_dominant | merging-candels_merger | merging-candels_tidal-debris | merging-candels_both | merging-candels_neither | filename | file_loc |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7.0 | 0.0 | 17.0 | 6.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 7.0 | UDS_14439.jpg | /content/zoobot/data/gz_candels/images/UDS_144... |
| 1 | 14.0 | 4.0 | 16.0 | 5.0 | 9.0 | 0.0 | 4.0 | 0.0 | 2.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 1.0 | 14.0 | COS_89.jpg | /content/zoobot/data/gz_candels/images/COS_89.jpg |
| 2 | 5.0 | 5.0 | 8.0 | 2.0 | 3.0 | 0.0 | 3.0 | 2.0 | 1.0 | 0.0 | ... | 1.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 10.0 | GDS_14317.jpg | /content/zoobot/data/gz_candels/images/GDS_143... |
| 3 | 17.0 | 7.0 | 12.0 | 4.0 | 13.0 | 0.0 | 2.0 | 5.0 | 2.0 | 0.0 | ... | 0.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 18.0 | UDS_9017.jpg | /content/zoobot/data/gz_candels/images/UDS_901... |
| 4 | 8.0 | 2.0 | 12.0 | 4.0 | 3.0 | 1.0 | 2.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 10.0 | COS_7477.jpg | /content/zoobot/data/gz_candels/images/COS_747... |


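Wondering about "label_cols"?

This is a list of catalog columns which should be used as labels.

Here, `label_cols` comes from the GZ CANDELS schema: there is one column per answer, each holding the number of volunteers who gave that answer. (For a binary task such as finding rings, it would be `label_cols = ['ring']`.)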
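Each CANDELS column name has the form `<question>_<answer>`, so the columns can be grouped back into questions. A minimal sketch, assuming the answer is always the text after the last underscore (true for every column in the list printed below, where multi-word answers such as `cant-tell` use hyphens); `group_by_question` is a hypothetical helper, not part of Zoobot.

```python
from collections import defaultdict


def group_by_question(label_cols):
    """Map each question to its answer columns, assuming '<question>_<answer>' names."""
    questions = defaultdict(list)
    for col in label_cols:
        # split on the last underscore only; answers such as '5-plus' contain hyphens
        question, _answer = col.rsplit('_', 1)
        questions[question].append(col)
    return dict(questions)


example_cols = [
    'smooth-or-featured-candels_smooth',
    'smooth-or-featured-candels_features',
    'smooth-or-featured-candels_artifact',
    'bar-candels_yes',
    'bar-candels_no',
]
print(group_by_question(example_cols))
```

Applied to the full `label_cols`, this should give one entry per CANDELS question.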



In [22]:
from zoobot.shared.schemas import gz_candels_ortho_schema

schema = gz_candels_ortho_schema
label_cols = schema.label_cols

For binary classification, the label column should have binary (0 or 1) labels for your classes.

To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.
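Before finetuning, it can help to check that the label columns look as described above: a list of names, all present in the catalog, and (for binary classification) containing only 0 or 1. A minimal sketch; `check_label_cols` is a hypothetical helper, and Zoobot itself may raise different errors.

```python
import pandas as pd


def check_label_cols(catalog: pd.DataFrame, label_cols, binary: bool = False) -> None:
    """Raise an informative error if label_cols does not match the catalog."""
    # Zoobot expects a list of columns, even for a single label
    if not isinstance(label_cols, (list, tuple)):
        raise TypeError(f'label_cols should be a list of column names, e.g. ["ring"], not {label_cols!r}')
    missing = [col for col in label_cols if col not in catalog.columns]
    if missing:
        raise KeyError(f'label columns not found in catalog: {missing}')
    if binary:
        # for binary classification, every label must be exactly 0 or 1
        not_binary = [col for col in label_cols if not catalog[col].isin([0, 1]).all()]
        if not_binary:
            raise ValueError(f'expected only 0/1 labels in columns: {not_binary}')
```

For example, `check_label_cols(train_and_val_catalog, label_cols)` for the CANDELS vote counts, or `check_label_cols(catalog, ['ring'], binary=True)` for a rings catalog.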


## Configure Finetuning

In [16]:
import pandas as pd

from zoobot.pytorch.training import finetune
from galaxy_datasets.pytorch.galaxy_datamodule import GalaxyDataModule
from zoobot.shared.schemas import gz_candels_ortho_schema
from sklearn.model_selection import train_test_split

GZDESI and GZRings not available from galaxy_datasets.pytorch.datasets - skipping


In [139]:
# TODO you can update these paths to suit your own data
import numpy as np

# hold out 30% of the training catalog for validation
train_catalog, val_catalog = train_test_split(train_and_val_catalog, test_size=0.3)

# the pretrained checkpoint downloaded above
# (to use a different one, see the Zoobot README for download links)
checkpoint_loc = os.path.join(checkpoint_dir, 'checkpoint.ckpt')

# where to save the finetuned model; change repo_dir to any writeable directory
repo_dir = '/Users/user/repos'
save_dir = os.path.join(
        repo_dir, f'gz-decals-classifiers/results/finetune_{np.random.randint(1e8)}')


In [140]:
#label_col = 'ring'  # name of column in catalog with binary (0 or 1) labels for your classes
#label_cols = [label_col]  # To support more complicated labels, Zoobot expects a list of columns. A list with one element works fine.
label_cols=schema.label_cols
label_cols

['smooth-or-featured-candels_smooth',
 'smooth-or-featured-candels_features',
 'smooth-or-featured-candels_artifact',
 'how-rounded-candels_completely',
 'how-rounded-candels_in-between',
 'how-rounded-candels_cigar-shaped',
 'clumpy-appearance-candels_yes',
 'clumpy-appearance-candels_no',
 'clump-count-candels_1',
 'clump-count-candels_2',
 'clump-count-candels_3',
 'clump-count-candels_4',
 'clump-count-candels_5-plus',
 'clump-count-candels_cant-tell',
 'disk-edge-on-candels_yes',
 'disk-edge-on-candels_no',
 'edge-on-bulge-candels_yes',
 'edge-on-bulge-candels_no',
 'bar-candels_yes',
 'bar-candels_no',
 'has-spiral-arms-candels_yes',
 'has-spiral-arms-candels_no',
 'spiral-winding-candels_tight',
 'spiral-winding-candels_medium',
 'spiral-winding-candels_loose',
 'spiral-arm-count-candels_1',
 'spiral-arm-count-candels_2',
 'spiral-arm-count-candels_3',
 'spiral-arm-count-candels_4',
 'spiral-arm-count-candels_5-plus',
 'spiral-arm-count-candels_cant-tell',
 'bulge-size-candels_n

In [141]:
#datamodule = GalaxyDataModule(
#  label_cols=label_cols,
#  catalog=train_catalog,
#  batch_size=32,
#  resize_after_crop=224,  # the size of the images input to the model
#  num_workers=2  # sets the parallelism for loading data. 2 works well on colab.
#)

resize_after_crop = 224  # must match how checkpoint below was trained
datamodule = GalaxyDataModule(
        label_cols=schema.label_cols,
        train_catalog=train_catalog,
        val_catalog=val_catalog,
        test_catalog=test_catalog,
        batch_size=32,
        # uses default_augs
        resize_after_crop=resize_after_crop
)

## Check Images Load

Optional - check that all images load correctly.

Worth checking once, and especially if you get errors about missing or unreadable image files when running the model below.
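When an `assert` on `os.path.isfile` fails, it does not tell you which file is the problem. A minimal sketch that lists the offending rows instead, assuming each catalog has a `file_loc` column as described above; `find_missing_images` is a hypothetical helper, and it only checks that each file exists, not that it can be decoded as an image.

```python
import os


def find_missing_images(catalogs):
    """Return (catalog_name, file_loc) for every row whose image file does not exist."""
    missing = []
    for name, catalog in catalogs.items():
        for loc in catalog['file_loc']:
            if not os.path.isfile(loc):
                missing.append((name, loc))
    return missing
```

Here it would be called as `find_missing_images({'train': train_catalog, 'val': val_catalog, 'test': test_catalog})`; an empty list means every image was found.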

In [142]:
for catalog in [train_catalog, val_catalog, test_catalog]:
    assert all([os.path.isfile(loc) for loc in catalog['file_loc']])

## Now the Actual Finetuning

In [143]:
#model = finetune.FinetuneableZoobotClassifier(
#  checkpoint_loc=checkpoint_loc,
#  num_classes=2,
#  n_layers=2  # only updating the head weights. Set 0 for only output layer. Set e.g. 1, 2 to finetune deeper.
#)
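# don't worry about any "automatically upgraded" INFO message below
schema = gz_candels_ortho_schema
# finetune on the full CANDELS decision tree, starting from the pretrained checkpoint
model = finetune.FinetuneableZoobotTree(checkpoint_loc=checkpoint_loc, schema=schema)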




In [144]:
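#trainer = finetune.get_trainer(save_dir, accelerator='auto', devices='auto', max_epochs=130)
#trainer.fit(model, datamodule)
# the "about 80% accuracy, loss of about 0.4" figure applies to the binary rings classifier
# (commented out above), not to the CANDELS tree model trained here
from pytorch_lightning.loggers import WandbLogger  # optional: pass a WandbLogger to log to Weights & Biases
logger = None  # no experiment logging
# max_epochs=2 keeps this demo quick; use more epochs for a real run
trainer = finetune.get_trainer(save_dir=save_dir, logger=logger, accelerator='auto', max_epochs=2)
trainer.fit(model, datamodule)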

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name              | Type         | Params
---------------------------------------------------
0 | encoder           | EfficientNet | 4.0 M 
1 | train_loss_metric | MeanMetric   | 0     
2 | val_loss_metric   | MeanMetric   | 0     
3 | test_loss_metric  | MeanMetric   | 0     
4 | head              | Sequential   | 48.7 K
---------------------------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.223    Total estimated model params size (MB)



INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 868: 'finetuning/val_loss' reached 2.26360 (best 2.26360), saving model to '/Users/user/repos/gz-decals-classifiers/results/finetune_8309475/checkpoints/0.ckpt' as top 1



INFO:pytorch_lightning.utilities.rank_zero:Epoch 1, global step 1736: 'finetuning/val_loss' reached 1.93284 (best 1.93284), saving model to '/Users/user/repos/gz-decals-classifiers/results/finetune_8309475/checkpoints/1.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


In [176]:
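# now we can load the best checkpoint (lowest validation loss) and make predictions
import torch

best_checkpoint = trainer.checkpoint_callback.best_model_path
# reload the finetuned weights into the model we just trained
# (a PyTorch Lightning checkpoint stores the weights under 'state_dict',
# alongside non-tensor metadata, hence weights_only=False)
checkpoint = torch.load(best_checkpoint, map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['state_dict'])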


In [175]:
# save predictions on the test set so we can evaluate performance
# predict_on_catalog.predict builds its own prediction datamodule and trainer internally
# (it reads image ids from the catalog's 'id_str' column)
from zoobot.pytorch.predictions import predict_on_catalog

accelerator = 'auto'  # uses the GPU if one is available
datamodule_kwargs = {'batch_size': 32, 'resize_after_crop': resize_after_crop}
trainer_kwargs = {'devices': 1, 'accelerator': accelerator}
predict_on_catalog.predict(
    test_catalog,
    model,
    n_samples=1,
    label_cols=schema.label_cols,
    save_loc=os.path.join(save_dir, 'finetuned_predictions.csv'),
    datamodule_kwargs=datamodule_kwargs,
    trainer_kwargs=trainer_kwargs
)


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Let's quickly check if they're any good:
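One quick check is to compare the observed vote fractions for each question with the fractions the model expects. Zoobot's tree models are trained with a Dirichlet-multinomial loss, so this sketch assumes the predictions are Dirichlet concentrations, one per answer, and that you have already loaded them into a DataFrame with the same column names as `label_cols` and the same row order as the catalog (how `finetuned_predictions.csv` names its columns is not shown here, so check the file first). `vote_fraction_errors` is a hypothetical helper; `questions` is a dict mapping each question to its answer columns.

```python
import numpy as np
import pandas as pd


def vote_fraction_errors(votes, concentrations, questions):
    """Per-question mean absolute error between observed and expected vote fractions.

    votes: observed vote counts, one column per answer.
    concentrations: predicted Dirichlet concentrations, same columns and row order (assumed).
    questions: dict mapping each question to its list of answer columns.
    """
    errors = {}
    for question, cols in questions.items():
        counts = votes[cols].to_numpy(dtype=float)
        totals = counts.sum(axis=1)
        answered = totals > 0  # ignore galaxies where nobody answered this question
        if not answered.any():
            continue
        observed = counts[answered] / totals[answered][:, None]
        alpha = concentrations[cols].to_numpy(dtype=float)[answered]
        # the mean of a Dirichlet distribution is each concentration divided by their sum
        expected = alpha / alpha.sum(axis=1, keepdims=True)
        errors[question] = float(np.abs(observed - expected).mean())
    return pd.Series(errors)
```

Lower is better; a model that ignores the images entirely gives a baseline to compare against.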