### **Import Libraries**


In [1]:
import albumentations
import cv2
import pytorch_lightning as pl
import torch
import wandb
from albumentations.pytorch import ToTensorV2
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from rgc.models import Classifier, DSteerableLeNet
from rgc.utils.datasets import BentLightningDataModule

### **Run Experiments**


In [2]:
# Define hyperparameters
hparams = {
    "model_name": "dstreeablelenet",
    "image_size": 151,
    "kernel_size": 5,
    "N": 16,
    "model_path": "../results/models/byol/run_0/best.pt",
    "learning_rate": 3e-4,
    "weight_decay": 1e-3,
    "batch_size": 16,
    "num_workers": 4,
    "num_classes": 2,
    "max_epochs": 500,
    "train_transform": albumentations.Compose([
        albumentations.PadIfNeeded(min_height=151, min_width=151, border_mode=cv2.BORDER_CONSTANT, fill=0),
        albumentations.CenterCrop(height=151, width=151),
        albumentations.Affine(rotate=(-360, 360), interpolation=cv2.INTER_LINEAR),
        albumentations.Normalize(mean=(0.0032), std=(0.0376)),
        ToTensorV2(),
    ]),
    "test_transform": albumentations.Compose([
        albumentations.PadIfNeeded(min_height=151, min_width=151, border_mode=cv2.BORDER_CONSTANT, fill=0),
        albumentations.CenterCrop(height=151, width=151),
        albumentations.Normalize(mean=(0.0032), std=(0.0376)),
        ToTensorV2(),
    ]),
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

# Instantiate the data module
data_module = BentLightningDataModule(
    data_dir="../data",
    batch_size=hparams["batch_size"],
    num_workers=hparams["num_workers"],
    transform=hparams["train_transform"],
    test_transform=hparams["test_transform"],
)

# Instantiate the selected model
model = DSteerableLeNet(
    imsize=hparams["image_size"],
    kernel_size=hparams["kernel_size"],
    N=hparams["N"],
    pre_training=False,
    num_classes=hparams["num_classes"],
).to(hparams["device"])
model.eval()
model.load_state_dict(torch.load(hparams["model_path"])["model_state_dict"], strict=False)
model.train()

# Instantiate the VisionClassifier Lightning module
vision_classifier = Classifier(
    model=model,
    num_classes=hparams["num_classes"],
    learning_rate=hparams["learning_rate"],
    weight_decay=1e-3,
)

# W&B Login
wandb.login()

# Initialize W&B logger with hyperparameters
wandb_logger = WandbLogger(name=model.__class__.__name__, project="vision-classifier", log_model="all")
wandb_logger.experiment.config.update(hparams)  # Save hyperparameters

# Checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor="val/acc",
    dirpath="checkpoints",
    filename="vision-classifier-{epoch:02d}-{val/acc:.2f}",
    save_top_k=2,
    mode="max",
)

# Learning rate monitor
lr_monitor = LearningRateMonitor(logging_interval="step")

# Trainer
trainer = pl.Trainer(
    max_epochs=hparams["max_epochs"],
    logger=wandb_logger,
    callbacks=[checkpoint_callback, lr_monitor],
    accelerator="auto",
    log_every_n_steps=1,
)

# Train the model
trainer.fit(vision_classifier, datamodule=data_module)

# Test the model with the best checkpoint
trainer.test(ckpt_path="best", datamodule=data_module)

# Close
wandb.finish()

  full_mask[mask] = norms.to(torch.uint8)
  model.load_state_dict(torch.load(hparams["model_path"])["model_state_dict"], strict=False)
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmirsazzathossain[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/media/drive1/sazzat/radio-galaxy-classifier/.venv/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /media/drive1/sazzat/radio-galaxy-classifier/Examples/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                      | Params | Mode 
-----------------------------------------------------------------------
0 | model            | DSteerableLeNet           | 45.0 M | train
1 | criterion        | CrossEntropyLoss          | 0      | train
2 | train_accuracy   | MulticlassAccuracy        | 0      | train
3 | val_accuracy     | MulticlassAccuracy        | 0      | train
4 | test_accuracy    | MulticlassAccuracy        | 0      | train
5 | precision        | MulticlassPrecision       | 0      | train
6 | recall           | MulticlassRecall          | 0      | tr

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7fcaf9a132d0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fcaf997af50, execution_count=2 error_before_exec=None error_in_exec=name 'exit' is not defined info=<ExecutionInfo object at 7fcaf99780d0, raw_cell="# Define hyperparameters
hparams = {
    "model_na.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Brtx-sazzat/media/drive1/sazzat/radio-galaxy-classifier/Examples/finetune.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe