In [None]:
# Install necessary packages
!pip install torch torchvision pyyaml


In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:
!git clone https://github.com/gj33/IJEPAGalaxyZoo.git

In [4]:
import torch
import torch.nn as nn
import sys
sys.path.append('/content/IJEPAGalaxyZoo')  # Add the IJepa src directory to the Python path

import src.models.vision_transformer as vit




In [6]:
import torch
import torch.nn as nn
import sys
sys.path.append('/content/IJEPAGalaxyZoo/src')  # Ensure the path is correct

from models.vision_transformer import vit_tiny


In [9]:
import torch.nn as nn

class GalaxyClassifier(nn.Module):
    def __init__(self, backbone, embed_dim=192, num_classes=8):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        patch_embeddings = self.backbone(x)     # [B, 256, 192]
        pooled = patch_embeddings.mean(dim=1)   # [B, 192]
        logits = self.classifier(pooled)        # [B, 8]
        return logits

In [10]:

base_vit_model = vit_tiny(
    img_size=[224],
    patch_size=14,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    use_checkpoint=False,
)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GalaxyClassifier(backbone=base_vit_model, embed_dim=192, num_classes=8).to(device)


# Load the checkpoint from after fine-tuning
checkpoint_path = '/content/drive/MyDrive/GalaxyZoo/checkpoints2/models_epoch_10.pth'

checkpoint = torch.load(checkpoint_path, map_location='cpu')

print("Checkpoint keys:", checkpoint.keys())

state_dict = checkpoint

state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}

model.load_state_dict(state_dict, strict=False)

print("Checkpoint loaded successfully into the GalaxyClassifier model.")

Checkpoint keys: odict_keys(['backbone.pos_embed', 'backbone.patch_embed.proj.weight', 'backbone.patch_embed.proj.bias', 'backbone.blocks.0.norm1.weight', 'backbone.blocks.0.norm1.bias', 'backbone.blocks.0.attn.qkv.weight', 'backbone.blocks.0.attn.qkv.bias', 'backbone.blocks.0.attn.proj.weight', 'backbone.blocks.0.attn.proj.bias', 'backbone.blocks.0.norm2.weight', 'backbone.blocks.0.norm2.bias', 'backbone.blocks.0.mlp.fc1.weight', 'backbone.blocks.0.mlp.fc1.bias', 'backbone.blocks.0.mlp.fc2.weight', 'backbone.blocks.0.mlp.fc2.bias', 'backbone.blocks.1.norm1.weight', 'backbone.blocks.1.norm1.bias', 'backbone.blocks.1.attn.qkv.weight', 'backbone.blocks.1.attn.qkv.bias', 'backbone.blocks.1.attn.proj.weight', 'backbone.blocks.1.attn.proj.bias', 'backbone.blocks.1.norm2.weight', 'backbone.blocks.1.norm2.bias', 'backbone.blocks.1.mlp.fc1.weight', 'backbone.blocks.1.mlp.fc1.bias', 'backbone.blocks.1.mlp.fc2.weight', 'backbone.blocks.1.mlp.fc2.bias', 'backbone.blocks.2.norm1.weight', 'backbone

In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GalaxyClassifier(backbone=model, embed_dim=192, num_classes=8).to(device)

In [None]:
# Number of classes in TinyImageNet
num_classes = 8

# Replace the head with a new classification head
model.head = nn.Linear(model.embed_dim, num_classes)



In [17]:
from torchvision import datasets, transforms

# Data transformations
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])

transform_val = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])



In [18]:
import os
import subprocess
import time

import numpy as np

from logging import getLogger

import torch
import torchvision
import pandas as pd

from PIL import Image
from torch.utils.data import Dataset

from torchvision import transforms

_GLOBAL_SEED = 0
logger = getLogger()

dir_cat = "/content/drive/MyDrive/GalaxyZoo/"
dir_image = '/content/ijepa/data/GalaxyZoo/images_gz2/images'
df = pd.read_csv(dir_cat+'gz2_valid.csv')
class GalaxyZooDataset(Dataset):
    '''Galaxy Zoo 2 image dataset
        Args:
            dataframe : pd.dataframe, outputs from the data_split function
                e.g. df_train / df_valid / df_test
            dir_image : str, path where galaxy images are located
            label_tag : str, class label system to be used for training
                e.g. label_tag = 'label1' / 'label2' / 'label3' / 'label4'
    '''

    def __init__(self, dataframe, dir_image, label_tag='label1', transform=None):
        self.df = dataframe
        self.transform = transform
        self.dir_image = dir_image
        self.label_tag = label_tag


    def __getitem__(self, index):
        galaxyID = self.df.iloc[[index]].galaxyID.values[0]
        file_img = os.path.join(self.dir_image, str(galaxyID) + '.jpg')
        image = Image.open(file_img)

        if self.transform:
            image = self.transform(image)

        label = self.df.iloc[[index]][self.label_tag].values[0]

        return image, label

    def __len__(self):
        return len(self.df)


collator=None
pin_mem=True
num_workers=8
world_size=1
rank=0
root_path=None
image_folder=None
training=True
copy_data=False
drop_last=True
subset_file=None

if transform is None:

    transform = transforms.ToTensor()
dir_image = '/content/ijepa/data/GalaxyZoo/images_gz2/images'
dataset_valid = GalaxyZooDataset(df, dir_image, label_tag='label1', transform=transform)
logger.info('ImageNet dataset created')
dist_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset=dataset_valid,
    num_replicas=world_size,
    rank=rank)
valid_data_loader = torch.utils.data.DataLoader(
    dataset_valid,
    collate_fn=collator,
    sampler=dist_sampler,
    batch_size=32,
    drop_last=drop_last,
    pin_memory=pin_mem,
    num_workers=num_workers,
    persistent_workers=False)

In [20]:
!mkdir -p /content/ijepa/data/GalaxyZoo

In [None]:
!unzip /content/drive/MyDrive/GalaxyZoo/archive.zip -d /content/ijepa/data/GalaxyZoo/

In [None]:
import torch
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix


model.eval()
correct_top1 = 0
correct_top3 = 0
total = 0

all_preds = []
all_targets = []

with torch.no_grad():
     for images, labels in valid_data_loader:
         images = images.to(device, non_blocking=True)
         labels = labels.to(device, non_blocking=True)
         outputs = model(images)
         _, preds = outputs.topk(3, dim=1, largest=True, sorted=True)

         total += labels.size(0)
         correct_top1 += (preds[:, 0] == labels).sum().item()
         correct_top3 += (preds == labels.view(-1, 1)).sum().item()


         all_preds.extend(preds[:, 0].cpu().numpy())
         all_targets.extend(labels.cpu().numpy())

top1_acc = 100 * correct_top1 / total
top3_acc = 100 * correct_top3 / total
print(f'Validation Top-1 Accuracy: {top1_acc:.2f}%, Top-3 Accuracy: {top3_acc:.2f}%')

    # Confusion Matrix
num_classes = 8
conf_mat = confusion_matrix(all_targets, all_preds, labels=list(range(num_classes)))
conf_mat_normalized = conf_mat.astype(float) / conf_mat.sum(axis=1, keepdims=True)

conf_df = pd.DataFrame(conf_mat_normalized,
                        index=[f"True {i}" for i in range(num_classes)],
                        columns=[f"Pred {i}" for i in range(num_classes)])

print("\nNormalized Confusion Matrix (% per true class):")
print(conf_df.to_string(float_format="%.2f"))

Validation Top-1 Accuracy: 70.03%, Top-3 Accuracy: 94.68%

Normalized Confusion Matrix (% per true class):
        Pred 0  Pred 1  Pred 2  Pred 3  Pred 4  Pred 5  Pred 6  Pred 7
True 0    0.80    0.13    0.00    0.00    0.01    0.04    0.01    0.00
True 1    0.15    0.71    0.02    0.01    0.02    0.08    0.01    0.00
True 2    0.01    0.14    0.41    0.33    0.01    0.09    0.01    0.00
True 3    0.02    0.04    0.03    0.84    0.01    0.05    0.00    0.00
True 4    0.02    0.06    0.00    0.02    0.71    0.17    0.01    0.00
True 5    0.06    0.10    0.01    0.02    0.12    0.66    0.02    0.00
True 6    0.08    0.12    0.02    0.06    0.12    0.24    0.34    0.03
True 7    0.08    0.14    0.00    0.06    0.18    0.16    0.10    0.28
