In [None]:
# Install necessary packages
!pip install torch torchvision pyyaml


In [None]:
# Mount Google Drive at /content/drive (Colab-only helper)
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
!git clone https://github.com/gj33/IJEPAGalaxyZoo.git

Cloning into 'IJEPAGalaxyZoo'...
remote: Enumerating objects: 53, done.
remote: Counting objects: 100% (53/53), done.
remote: Compressing objects: 100% (42/42), done.
remote: Total 53 (delta 15), reused 27 (delta 6), pack-reused 0 (from 0)
Receiving objects: 100% (53/53), 63.17 KiB | 10.53 MiB/s, done.
Resolving deltas: 100% (15/15), done.


In [None]:
import torch
import torch.nn as nn
import sys
sys.path.append('/content/IJEPAGalaxyZoo')  # Add the IJepa src directory to the Python path

import src.models.vision_transformer as vit




In [None]:
import torch
import torch.nn as nn
import sys
sys.path.append('/content/IJEPAGalaxyZoo/src')  # Ensure the path is correct

from models.vision_transformer import vit_tiny


In [None]:
import torch.nn as nn

class GalaxyClassifier(nn.Module):
    """Backbone followed by mean pooling and a single linear classification layer.

    Args:
        backbone: module applied to the input batch; its output is averaged over
            dim 1, so it is expected to return [B, num_tokens, embed_dim].
        embed_dim: size of the last dimension of the backbone output
            (input width of the linear layer).
        num_classes: number of logits produced per sample.
    """

    def __init__(self, backbone, embed_dim=192, num_classes=8):
        super().__init__()
        self.backbone = backbone
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return logits of shape [B, num_classes] for the input batch ``x``."""
        tokens = self.backbone(x)  # [B, num_tokens, embed_dim]
        # Average over the token axis, then classify the pooled vector.
        return self.classifier(tokens.mean(dim=1))

In [None]:

# Build the ViT-Tiny backbone (224x224 input, 14x14 patches, no dropout / stochastic depth).
base_vit_model = vit_tiny(
    img_size=[224],
    patch_size=14,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.0,
    use_checkpoint=False,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GalaxyClassifier(backbone=base_vit_model, embed_dim=192, num_classes=8).to(device)

checkpoint_path = '/content/drive/MyDrive/ijepa_logs/GalaxyZooOutput1/jepa-ep30.pth.tar'

# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print("Checkpoint keys:", checkpoint.keys())

# The file is a full training snapshot whose top-level keys are 'encoder', 'predictor',
# 'target_encoder', 'opt', ... -- it is NOT itself a state dict.  Handing the whole dict to
# model.load_state_dict(strict=False) matches no parameter name and silently leaves the
# backbone randomly initialised, so pick the encoder weights explicitly and load them into
# the backbone only (the classification head has no pretrained counterpart).
# NOTE(review): 'target_encoder' is the other backbone stored in the checkpoint -- confirm
# which of the two is intended for fine-tuning.
ENCODER_KEY = 'encoder'
_PREFIX = 'module.'
state_dict = {
    (k[len(_PREFIX):] if k.startswith(_PREFIX) else k): v
    for k, v in checkpoint[ENCODER_KEY].items()
}

# Fail loudly instead of printing a success message when nothing lines up.
expected_keys = set(model.backbone.state_dict().keys())
matched_keys = expected_keys & set(state_dict.keys())
if not matched_keys:
    raise RuntimeError(
        f"No key of checkpoint['{ENCODER_KEY}'] matches the backbone; "
        "pretrained weights were not loaded."
    )

load_result = model.backbone.load_state_dict(state_dict, strict=False)

print(f"Loaded {len(matched_keys)}/{len(expected_keys)} backbone tensors from checkpoint['{ENCODER_KEY}'].")
if load_result.missing_keys:
    print("Backbone keys missing from the checkpoint:", load_result.missing_keys)
if load_result.unexpected_keys:
    print("Checkpoint keys not used by the backbone:", load_result.unexpected_keys)

Checkpoint keys: dict_keys(['encoder', 'predictor', 'target_encoder', 'opt', 'scaler', 'epoch', 'loss', 'batch_size', 'world_size', 'lr'])
Checkpoint loaded successfully into the GalaxyClassifier model.


In [None]:
# Number of target classes for the Galaxy Zoo labels used in this notebook
num_classes = 8

# Replace the head with a freshly initialised classification layer.
# nn.Linear is created on the CPU, so move it to the device the rest of the model lives on;
# otherwise the forward pass mixes devices unless the whole model is moved again later.
model.classifier = nn.Linear(model.backbone.embed_dim, num_classes).to(device)

In [None]:
from torchvision import datasets, transforms

# Training-time pipeline: random 224x224 crop and horizontal flip, then tensor conversion
# and normalisation with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet mean
        std=[0.229, 0.224, 0.225],   # ImageNet std
    ),
])

# Deterministic evaluation pipeline.  It uses the same 224x224 resolution as the training
# crops and the backbone configuration (img_size=[224]); the previous 64x64 setting did not
# match the resolution the model is trained at.
transform_val = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])



In [None]:
import os
import subprocess
import time

import numpy as np

from logging import getLogger

import torch
import torchvision
import pandas as pd

from PIL import Image
from torch.utils.data import Dataset

from torchvision import transforms

# Module-level settings: root logger and seed constant
logger = getLogger()
_GLOBAL_SEED = 0

# Catalogue CSVs are read from Drive; images are read from the local Colab disk
dir_cat = "/content/drive/MyDrive/GalaxyZoo/"
dir_image = '/content/ijepa/data/GalaxyZoo/images_gz2/images'

# Training catalogue (df) and validation catalogue (df2)
df = pd.read_csv(os.path.join(dir_cat, 'gz2_train.csv'))
df2 = pd.read_csv(os.path.join(dir_cat, 'gz2_valid.csv'))
class GalaxyZooDataset(Dataset):
    '''Galaxy Zoo 2 image dataset
        Args:
            dataframe : pd.dataframe, outputs from the data_split function
                e.g. df_train / df_valid / df_test
                (must provide a ``galaxyID`` column and the ``label_tag`` column)
            dir_image : str, path where galaxy images are located
                (each image is read from ``<dir_image>/<galaxyID>.jpg``)
            label_tag : str, class label system to be used for training
                e.g. label_tag = 'label1' / 'label2' / 'label3' / 'label4'
            transform : optional callable applied to the PIL image before it is returned
    '''

    def __init__(self, dataframe, dir_image, label_tag='label1', transform=None):
        self.df = dataframe
        self.transform = transform
        self.dir_image = dir_image
        self.label_tag = label_tag

    def __getitem__(self, index):
        '''Return ``(image, label)`` for the row at positional ``index``.'''
        # Index the column first, then the row: this keeps the column's own dtype (an integer
        # ID stays an integer when building the file name) and avoids materialising a
        # one-row DataFrame for every lookup.
        galaxyID = self.df['galaxyID'].iloc[index]
        file_img = os.path.join(self.dir_image, str(galaxyID) + '.jpg')

        # convert() loads the pixel data and returns an independent image, so the file
        # handle can be closed right away; forcing RGB guarantees three channels.
        with Image.open(file_img) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.df[self.label_tag].iloc[index]

        return image, label

    def __len__(self):
        '''Number of rows in the underlying dataframe.'''
        return len(self.df)


# Loader settings (single process: one replica, rank 0)
collator = None
pin_mem = True
num_workers = 8
world_size = 1
rank = 0
drop_last = True

# Further settings kept at module level
training = True
copy_data = False
root_path = None
image_folder = None
subset_file = None

# Fall back to plain tensor conversion when no transform pipeline has been defined
transform = transforms.ToTensor() if transform is None else transform

# Training dataset built from the training catalogue `df`
dir_image = '/content/ijepa/data/GalaxyZoo/images_gz2/images'
dataset = GalaxyZooDataset(df, dir_image, label_tag='label1', transform=transform)
logger.info('ImageNet dataset created')

# Sampler restricted to this process's share of the data (all of it with one replica)
dist_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset=dataset, num_replicas=world_size, rank=rank)

# Batches of 32; incomplete final batches are dropped when drop_last is True
train_data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    sampler=dist_sampler,
    collate_fn=collator,
    num_workers=num_workers,
    pin_memory=pin_mem,
    drop_last=drop_last,
    persistent_workers=False)

In [None]:
import os
import subprocess
import time

import numpy as np

from logging import getLogger

import torch
import torchvision
import pandas as pd

from PIL import Image
from torch.utils.data import Dataset

from torchvision import transforms

# Module-level settings: root logger and seed constant
logger = getLogger()
_GLOBAL_SEED = 0

# Catalogue CSVs are read from Drive; images are read from the local Colab disk
dir_cat = "/content/drive/MyDrive/GalaxyZoo/"
dir_image = '/content/ijepa/data/GalaxyZoo/images_gz2/images'

# Validation catalogue; note that this rebinds the module-level name `df`
df = pd.read_csv(os.path.join(dir_cat, 'gz2_valid.csv'))
class GalaxyZooDataset(Dataset):
    '''Galaxy Zoo 2 image dataset
        Args:
            dataframe : pd.dataframe, outputs from the data_split function
                e.g. df_train / df_valid / df_test
                (must provide a ``galaxyID`` column and the ``label_tag`` column)
            dir_image : str, path where galaxy images are located
                (each image is read from ``<dir_image>/<galaxyID>.jpg``)
            label_tag : str, class label system to be used for training
                e.g. label_tag = 'label1' / 'label2' / 'label3' / 'label4'
            transform : optional callable applied to the PIL image before it is returned
    '''

    def __init__(self, dataframe, dir_image, label_tag='label1', transform=None):
        self.df = dataframe
        self.transform = transform
        self.dir_image = dir_image
        self.label_tag = label_tag

    def __getitem__(self, index):
        '''Return ``(image, label)`` for the row at positional ``index``.'''
        # Index the column first, then the row: this keeps the column's own dtype (an integer
        # ID stays an integer when building the file name) and avoids materialising a
        # one-row DataFrame for every lookup.
        galaxyID = self.df['galaxyID'].iloc[index]
        file_img = os.path.join(self.dir_image, str(galaxyID) + '.jpg')

        # convert() loads the pixel data and returns an independent image, so the file
        # handle can be closed right away; forcing RGB guarantees three channels.
        with Image.open(file_img) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        label = self.df[self.label_tag].iloc[index]

        return image, label

    def __len__(self):
        '''Number of rows in the underlying dataframe.'''
        return len(self.df)


collator=None
pin_mem=True
num_workers=8
world_size=1
rank=0
root_path=None
image_folder=None
training=True
copy_data=False
drop_last=True  # kept for the training configuration; the validation loader below keeps every sample
subset_file=None

if transform is None:
    # If no transform provided, at least convert images to Tensors
    transform = transforms.ToTensor()

# Validation must be deterministic: use a fixed resize + centre crop at the training
# resolution instead of the random crop/flip augmentation stored in `transform`.
transform_val = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

dir_image = '/content/ijepa/data/GalaxyZoo/images_gz2/images'
dataset_valid = GalaxyZooDataset(df, dir_image, label_tag='label1', transform=transform_val)
logger.info('GalaxyZoo validation dataset created')

# No shuffling for evaluation, so every pass visits the samples in the same order.
dist_sampler = torch.utils.data.distributed.DistributedSampler(
    dataset=dataset_valid,
    num_replicas=world_size,
    rank=rank,
    shuffle=False)

# drop_last=False: the final partial batch is evaluated too, so the reported accuracy
# covers the whole validation set.
valid_data_loader = torch.utils.data.DataLoader(
    dataset_valid,
    collate_fn=collator,
    sampler=dist_sampler,
    batch_size=32,
    drop_last=False,
    pin_memory=pin_mem,
    num_workers=num_workers,
    persistent_workers=False)

In [None]:
# Fine-tune the whole network: every parameter of the model is made trainable
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
model.requires_grad_(True)

# Single AdamW parameter group: all parameters share lr=1e-4 and weight_decay=0.05
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)

In [None]:
# Create the local data directory (-p: create parents, no error if it already exists)
!mkdir -p /content/ijepa/data/GalaxyZoo

In [None]:
!unzip /content/drive/MyDrive/GalaxyZoo/archive.zip -d /content/ijepa/data/GalaxyZoo/

In [None]:
import os
# Directory on Drive where the model checkpoints are stored;
# exist_ok=True makes the cell safe to re-run.
checkpoint_dir = '/content/drive/MyDrive/GalaxyZoo/galaxyzoocheckpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
num_epochs = 10  # Adjust as needed


def _evaluate_topk(net, loader, device, k=3):
    """Return ``(top-1 %, top-k %)`` accuracy of ``net`` over all batches of ``loader``.

    The network is switched to eval mode and run without gradient tracking.
    Raises ValueError when the loader yields no samples.
    """
    net.eval()
    hits_top1 = 0
    hits_topk = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            _, preds = net(images).topk(k, dim=1, largest=True, sorted=True)
            seen += labels.size(0)
            hits_top1 += (preds[:, 0] == labels).sum().item()
            # A sample counts as a top-k hit when its label appears in any of the k columns.
            hits_topk += (preds == labels.view(-1, 1)).sum().item()
    if seen == 0:
        raise ValueError('evaluation loader produced no samples')
    return 100 * hits_top1 / seen, 100 * hits_topk / seen


# Move the model once, up front, instead of on every batch.
model.to(device)

for epoch in range(num_epochs):
    # A DistributedSampler derives its shuffle order from the epoch number; without
    # set_epoch() every epoch iterates the data in exactly the same order.
    train_sampler = getattr(train_data_loader, 'sampler', None)
    if hasattr(train_sampler, 'set_epoch'):
        train_sampler.set_epoch(epoch)

    model.train()
    running_loss = 0.0
    samples_seen = 0

    for images, labels in train_data_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    # Average over the samples actually processed: the loader may drop the last partial
    # batch, so dividing by the dataset length would under-report the loss.
    epoch_loss = running_loss / max(samples_seen, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Evaluate on validation set
    top1_acc, top3_acc = _evaluate_topk(model, valid_data_loader, device, k=3)
    print(f'Validation Top-1 Accuracy: {top1_acc:.2f}%, Top-3 Accuracy: {top3_acc:.2f}%')

    # Evaluate on training set too to test for overfitting
    top1_acc, top3_acc = _evaluate_topk(model, train_data_loader, device, k=3)
    print(f'Training Top-1 Accuracy: {top1_acc:.2f}%, Top-3 Accuracy: {top3_acc:.2f}%')

    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'models_epoch_{epoch+1}.pth'))

Epoch [1/10], Loss: 1.4429
Validation Top-1 Accuracy: 52.21%, Top-3 Accuracy: 86.59%
Training Top-1 Accuracy: 52.78%, Top-3 Accuracy: 86.76%
Epoch [2/10], Loss: 1.1501
Validation Top-1 Accuracy: 58.64%, Top-3 Accuracy: 89.87%
Training Top-1 Accuracy: 59.32%, Top-3 Accuracy: 89.89%
Epoch [3/10], Loss: 1.0691
Validation Top-1 Accuracy: 60.81%, Top-3 Accuracy: 90.85%
Training Top-1 Accuracy: 60.84%, Top-3 Accuracy: 90.75%
Epoch [4/10], Loss: 1.0346
Validation Top-1 Accuracy: 60.69%, Top-3 Accuracy: 90.65%
Training Top-1 Accuracy: 60.66%, Top-3 Accuracy: 90.97%
Epoch [5/10], Loss: 1.0078
Validation Top-1 Accuracy: 63.37%, Top-3 Accuracy: 91.90%
Training Top-1 Accuracy: 63.26%, Top-3 Accuracy: 91.89%
Epoch [6/10], Loss: 0.9828
Validation Top-1 Accuracy: 63.44%, Top-3 Accuracy: 91.93%
Training Top-1 Accuracy: 63.31%, Top-3 Accuracy: 92.01%
Epoch [7/10], Loss: 0.9660
Validation Top-1 Accuracy: 64.71%, Top-3 Accuracy: 92.44%
Training Top-1 Accuracy: 65.13%, Top-3 Accuracy: 92.71%
Epoch [8/10],