# AstroCLIP

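This notebook fine-tunes the pretrained AstroCLIP model on the dataset provided by `VLASSLoader`. The AstroCLIP backbone is kept frozen and only a linear classification head is trained on top of its image embeddings (a linear probe).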

## Load model

In [2]:
from astroclip.models import AstroClipModel

# Restore the pretrained AstroCLIP weights from disk, then move the model to the GPU.
model = AstroClipModel.load_from_checkpoint("./pretrained/astroclip.ckpt")
model = model.cuda()

## Data

In [3]:
from torchvision import transforms

# Per-image preprocessing, applied in order:
#   1. convert the raw sample to a PIL image,
#   2. make it grayscale, replicated over 3 channels,
#   3. resize to 144x144 pixels,
#   4. convert back to a tensor.
preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((144, 144)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

In [4]:
from datamodule import VLASSLoader
# Data module reading from ./data; every sample goes through `transform`.
datamodule = VLASSLoader(
    root="./data",
    batch_size=32,
    pin_memory=True,
    transform=transform,
)

## LinearProbe

In [7]:
from torch import nn
from models.vision import VisionModel

class AstroClip(VisionModel):
    """Linear-probe classifier on top of a pretrained AstroCLIP backbone.

    Inputs are passed through the backbone (called with
    ``input_type='image'``) and the result is mapped to class scores by a
    single linear layer.

    Args:
        embed_dim: Input width of the linear head. Assumed to match the size
            of the backbone's image embedding -- TODO confirm against the
            checkpoint.
        freeze: If True, gradients are disabled for every backbone parameter
            so that only ``head`` is trained.
        backbone: Encoder to wrap. When omitted, the notebook-level ``model``
            (the pretrained checkpoint loaded earlier) is used, so
            ``AstroClip()`` behaves exactly as before.

    Note:
        ``self.num_classes`` is not defined in this class; it is presumably
        provided by ``VisionModel`` -- verify in ``models.vision``.
    """

    def __init__(self, embed_dim=1024, freeze=True, backbone=None):
        super().__init__()
        # Prefer an explicitly supplied encoder; fall back to the globally
        # loaded checkpoint so existing no-argument calls keep working.
        self.backbone = model if backbone is None else backbone
        self.head = nn.Linear(embed_dim, self.num_classes)

        if freeze:
            # Freezing acts on the wrapped module itself (no copy is made),
            # so a shared `model` object is frozen for all of its users.
            self.backbone.requires_grad_(False)

    def forward(self, x):
        """Run ``x`` through the backbone in image mode, then the linear head."""
        features = self.backbone(x, input_type='image')
        return self.head(features)

In [27]:
# Build the classifier with its default settings, then move it to the GPU.
clip_model = AstroClip()
clip_model = clip_model.cuda()

In [28]:
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Checkpoints are written to ./outputs/AstroCLIP and tracked on `val_loss`
# (presumably logged by VisionModel's validation step -- verify).
checkpointer = ModelCheckpoint(
    dirpath='./outputs/AstroCLIP',
    monitor='val_loss',
)

# Train for 10 epochs on a single GPU.
# For a quick smoke run, add `max_steps=100` to the options below.
trainer_options = dict(
    max_epochs=10,
    accelerator='gpu',
    devices=1,
    callbacks=[checkpointer],
)
trainer = Trainer(**trainer_options)
trainer.fit(clip_model, datamodule)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type           | Params | Mode 
----------------------------------------------------
0 | backbone | AstroClipModel | 370 M  | eval 
1 | head     | Linear         | 4.1 K  | train
----------------------------------------------------
4.1 K     Trainable params
370 M     Non-trainable params
370 M     Total params
1,483.210 Total estimated model params size (MB)
1         Modules in train mode
525       Modules in eval mode
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## Test

The metrics below are computed on the validation dataloader (`datamodule.val_dataloader()`).

In [29]:
# Make sure the trained model is on the GPU before evaluating it.
clip_model = clip_model.cuda()

# Evaluation uses the validation dataloader of the data module.
val_dataloader = datamodule.val_dataloader()

In [32]:
from utils import test_model
# Evaluate the trained model on the validation dataloader.
result = test_model(clip_model, val_dataloader)

# Headline metrics first, then the longer per-class reports.
accuracy, f1 = result["accuracy"], result["f1"]
print(f'Accuracy = {accuracy}, F1 = {f1}')
for title, key in (
    ('Classification report:\n', 'classification_report'),
    ('Confusion Matrix:\n', 'confusion_matrix'),
):
    print(title, result[key])

In [31]:
# Show the raw metrics dictionary returned by `test_model` as the cell output.
result

{'accuracy': 0.7811684073107049,
 'f1': 0.6263537675883039,
 'confusion_matrix': array([[2349,   24,  224,   31],
        [  50, 4777,  342,    3],
        [ 207,  676, 2400,   39],
        [ 242,   69,  775,   48]]),
 'classification_report': '              precision    recall  f1-score   support\n\n           0       0.82      0.89      0.86      2628\n           1       0.86      0.92      0.89      5172\n           2       0.64      0.72      0.68      3322\n           3       0.40      0.04      0.08      1134\n\n    accuracy                           0.78     12256\n   macro avg       0.68      0.65      0.63     12256\nweighted avg       0.75      0.78      0.75     12256\n'}