## Final Code

In [None]:
import logging
import os
from functools import partial

import pandas as pd
import torch
import torchmetrics as tm
from galaxy_datasets.pytorch.galaxy_datamodule import CatalogDataModule
from galaxy_datasets.transforms import default_view_config, get_galaxy_transform
from zoobot.pytorch.training.finetune import (
    FinetuneableZoobotAbstract,
    FinetuneableZoobotClassifier,
    LinearHead,
    cross_entropy_loss,
    get_trainer,
)

In [None]:
class FinetuneableZoobotMetadataAbstract(FinetuneableZoobotAbstract):
    """Zoobot finetuning base whose ``forward`` receives the whole batch.

    ``run_step_through_model`` passes the entire batch (not just the image) to
    ``forward``, so subclasses can read extra metadata entries from it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run_step_through_model(self, batch):
        """Shared train/val/test step.

        Args:
            batch: batch from the dataloader; handed unchanged to ``forward``.

        Returns:
            Tuple ``(y, y_pred, loss)``: the targets from
            ``batch_to_supervised_tuple``, the model output, and the loss from
            ``self.loss`` (which subclasses must define).
        """
        # Only the targets are used here; forward() pulls the image out of the batch itself.
        _, y = self.batch_to_supervised_tuple(batch)
        y_pred = self.forward(batch)

        loss = self.loss(y_pred, y)  # type:ignore  # must be defined by subclasses
        # Tensor.float() returns a new tensor rather than converting in place,
        # so the result has to be assigned for the cast to take effect.
        loss = loss.float()
        return y, y_pred, loss

In [None]:
class FinetuneableZoobotMetadataClassifier(FinetuneableZoobotMetadataAbstract, FinetuneableZoobotClassifier):
    """Zoobot classifier whose head sees encoder features plus per-sample metadata.

    In ``forward`` the encoder embedding of each image is concatenated with the
    float values of ``metadata_cols`` (read from the batch), and the result is
    fed to a linear head with ``encoder_dim + len(metadata_cols)`` inputs.

    Args:
        num_classes: number of output classes; 2 selects binary accuracy metrics.
        label_col: name of the label column, forwarded to the base classifier.
        label_smoothing: label smoothing for the cross-entropy loss.
        class_weights: optional per-class loss weights (anything ``torch.Tensor`` accepts).
        metadata_cols: batch keys whose values are appended to the encoder features.
        run_linear_sanity_check: stored as ``self.run_linear_sanity_check``.
        **super_kwargs: forwarded to the Zoobot base classes (e.g. ``name``, ``training_mode``).
    """

    def __init__(
            self,
            num_classes: int,
            label_col: str = 'label',
            label_smoothing=0.,
            class_weights=None,
            metadata_cols=None,

            run_linear_sanity_check: bool = False,
            **super_kwargs) -> None:

        super().__init__(
            num_classes=num_classes,
            label_col=label_col,
            label_smoothing=label_smoothing,
            class_weights=class_weights,
            **super_kwargs
        )

        self.label_col = label_col
        self.label_smoothing = label_smoothing
        self.run_linear_sanity_check = run_linear_sanity_check
        # Copy so later mutation of the caller's list cannot change the model's inputs.
        self.metadata_cols = list(metadata_cols or [])

        logging.info("Using classification head and cross-entropy loss")
        # Build the head once at its final width: one extra input per metadata column.
        self.head = LinearHead(
            input_dim=self.encoder_dim + len(self.metadata_cols),  # type: ignore
            output_dim=num_classes,
            head_dropout_prob=self.head_dropout_prob,
        )

        if class_weights is not None:
            # https://lightning.ai/docs/pytorch/stable/accelerators/accelerator_prepare.html#init-tensors-using-tensor-to-and-register-buffer
            # A buffer follows the module when it is moved to another device.
            self.register_buffer("class_weights", torch.Tensor(class_weights))
            logging.info(f'class_weights: {self.class_weights} on {self.class_weights.device}')  # type: ignore
        else:
            self.class_weights = None

        # Do NOT bind the weights with functools.partial: Module.to(device) replaces
        # buffer objects, so a partial would keep pointing at the original (CPU)
        # tensor. The bound method below looks up self.class_weights on every call.
        self.loss = self._weighted_cross_entropy
        logging.info(f'num_classes: {num_classes}')

        if num_classes == 2:
            logging.info("Using binary classification")
            task = "binary"
        else:
            logging.info("Using multi-class classification")
            task = "multiclass"
        self.train_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
        self.val_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)
        self.test_acc = tm.Accuracy(task=task, average="micro", num_classes=num_classes)

    def _weighted_cross_entropy(self, y_pred, y):
        """Cross-entropy using the class weights currently registered on the module."""
        return cross_entropy_loss(
            y_pred,
            y,
            weight=self.class_weights,
            label_smoothing=self.label_smoothing,
        )

    def forward(self, batch):
        """Score a batch from image features concatenated with metadata.

        Args:
            batch: mapping with an ``'image'`` entry plus one entry per name in
                ``self.metadata_cols``; each metadata entry is assumed to hold one
                scalar per sample (TODO confirm against the datamodule).

        Returns:
            Output of ``self.head`` (``num_classes`` scores per sample).
        """
        # as_tensor reuses tensors the dataloader already produced instead of
        # copy-constructing them (torch.tensor(tensor) warns and copies).
        images = torch.as_tensor(batch['image'], dtype=torch.float, device=self.device)
        features = self.encoder(images)

        if self.metadata_cols:
            # One column per metadata field -> (batch_size, len(metadata_cols)).
            metadata = torch.stack(
                [
                    torch.as_tensor(batch[col], dtype=torch.float, device=features.device).reshape(-1)
                    for col in self.metadata_cols
                ],
                dim=1,
            )
            features = torch.cat([features, metadata], dim=1)

        return self.head(features)


In [None]:
# Per-split catalogues; the preview below shows their columns (label, id_str, X, y, ...).
csv_paths = {
    'train': '../../imgs/train_dataset.csv',
    'test': '../../imgs/test_dataset.csv',
}
train_metadata = pd.read_csv(csv_paths['train'])
test_metadata = pd.read_csv(csv_paths['test'])

In [None]:
# First five rows of the raw training catalogue.
train_metadata.head(n=5)

Unnamed: 0,image,label,id_str,X,y
0,<PIL.JpegImagePlugin.JpegImageFile image mode=...,0,a17e8ab7a7de9c79d3cf960af591bfd113669e405a9803...,3.955509,9.051425
1,<PIL.JpegImagePlugin.JpegImageFile image mode=...,1,2acac60ba3f744afb21f2e6c59d876949c8305e4f519cf...,9.231839,23.476162
2,<PIL.JpegImagePlugin.JpegImageFile image mode=...,0,50f422c99ac53ecdfb5559fb71e758c7052f3f4ad74e58...,1.518943,2.27675
3,<PIL.JpegImagePlugin.JpegImageFile image mode=...,0,12fba589f6f7b0a3ce5c17039760b3c98e1cd6b69a9c1f...,9.937435,25.216443
4,<PIL.JpegImagePlugin.JpegImageFile image mode=...,1,f69e6578faa676168e96e9e4595fb79fff28aee64fd6bf...,5.620438,16.640429


In [None]:
# Each image lives at <split image dir>/<id_str>.jpg; record that path per row.
for split_frame, image_dir in (
    (train_metadata, "../../imgs/images/train/"),
    (test_metadata, "../../imgs/images/test/"),
):
    split_frame["file_loc"] = image_dir + split_frame["id_str"].astype(str) + ".jpg"

In [None]:
# Same preview again, now including the new file_loc column.
train_metadata.head(n=5)

Unnamed: 0,image,label,id_str,X,y,file_loc
0,<PIL.JpegImagePlugin.JpegImageFile image mode=...,0,a17e8ab7a7de9c79d3cf960af591bfd113669e405a9803...,3.955509,9.051425,../../imgs/images/train/a17e8ab7a7de9c79d3cf96...
1,<PIL.JpegImagePlugin.JpegImageFile image mode=...,1,2acac60ba3f744afb21f2e6c59d876949c8305e4f519cf...,9.231839,23.476162,../../imgs/images/train/2acac60ba3f744afb21f2e...
2,<PIL.JpegImagePlugin.JpegImageFile image mode=...,0,50f422c99ac53ecdfb5559fb71e758c7052f3f4ad74e58...,1.518943,2.27675,../../imgs/images/train/50f422c99ac53ecdfb5559...
3,<PIL.JpegImagePlugin.JpegImageFile image mode=...,0,12fba589f6f7b0a3ce5c17039760b3c98e1cd6b69a9c1f...,9.937435,25.216443,../../imgs/images/train/12fba589f6f7b0a3ce5c17...
4,<PIL.JpegImagePlugin.JpegImageFile image mode=...,1,f69e6578faa676168e96e9e4595fb79fff28aee64fd6bf...,5.620438,16.640429,../../imgs/images/train/f69e6578faa676168e96e9...


In [None]:
# Same galaxy-datasets view config for training and evaluation.
transform_cfg = default_view_config()
transform = get_galaxy_transform(transform_cfg)

# "X" is listed alongside "label" presumably so each batch carries batch['X'],
# which FinetuneableZoobotMetadataClassifier.forward reads via metadata_cols=['X'].
# The targets come from the model's label_col ('label') — confirm the base class
# ignores the extra column when building (inputs, targets).
datamodule = CatalogDataModule(
    label_cols=["label", "X"],
    catalog=train_metadata,
    train_transform=transform,
    test_transform=transform,
    batch_size=8,
)

In [None]:
# Pretrained EfficientNet-B0 Zoobot encoder from the Hugging Face hub.
encoder_name = 'hf_hub:mwalmsley/zoobot-encoder-efficientnet_b0'

model = FinetuneableZoobotMetadataClassifier(
    num_classes=2,
    label_col='label',
    metadata_cols=['X'],
    name=encoder_name,
    training_mode="full",
    learning_rate=5e-5,
    layer_decay=0.65,
)

In [None]:
# The head must accept the encoder embedding plus one input per metadata column.
assert model.head.linear.in_features == model.encoder_dim + len(model.metadata_cols)
model.head

LinearHead(
  (dropout): Dropout(p=0.5, inplace=False)
  (linear): Linear(in_features=1281, out_features=2, bias=True)
)

In [None]:
# Checkpoints are written under <save_dir>/checkpoints/.
save_dir = './save_dir/'

# No manual datamodule.setup("fit"): Trainer.fit calls it itself, so calling it
# here would just run the setup (and any catalogue split) twice.
trainer = get_trainer(save_dir, accelerator="auto", devices=1, strategy="auto", max_epochs=10)
trainer.fit(model, datamodule)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name              | Type           | Params | Mode 
-------------------------------------------------------------
0 | encoder           | EfficientNet   | 4.0 M  | train
1 | train_loss_metric | MeanMetric     | 0      | train
2 | val_loss_metric   | MeanMetric     | 0      | train
3 | test_loss_metric  | MeanMetric     | 0      | train
4 | head              | LinearHead     | 2.6 K  | train
5 | train_acc         | BinaryAccuracy | 0      | train
6 | val_acc           | BinaryAccuracy | 0      | train
7 | test_acc          | BinaryAccuracy | 0      | train
-------------------------------------------------------------
4.0 M     Trainable params
0         Non-trainable params
4.0 M     Total params
16.040    Total estimated model params size (MB)
346       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  x = torch.tensor(batch['image'], dtype=torch.float, device=self.device)
  torch.tensor(batch[col], dtype=torch.float, device=x.device).unsqueeze(1)


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 176: 'finetuning/val_loss' reached 0.55278 (best 0.55278), saving model to 'C:\\Users\\Anna\\Downloads\\Projects\\NeuralNetworks_KMA\\lab2 and 3\\zoobot_metadata\\zoobot\\zoobot_metadata\\save_dir\\checkpoints\\0.ckpt' as top 1


In [None]:
# Reload the best checkpoint written by the trainer above (under save_dir, not a
# hard-coded path). Use the metadata subclass: its head has an extra input per
# metadata column, which the plain FinetuneableZoobotClassifier head is presumably
# not sized for. metadata_cols is passed explicitly so the head width is correct
# even if it was not stored in the checkpoint's hyperparameters.
best_checkpoint = trainer.checkpoint_callback.best_model_path
finetuned_model = FinetuneableZoobotMetadataClassifier.load_from_checkpoint(
    best_checkpoint,
    metadata_cols=['X'],
)