# Galaxy Zoo 2: ConvNeXt Transfer Learning

## 1. Import Required Libraries and Utilities

In [61]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure, show
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.models import ConvNeXt_Tiny_Weights
import torch
from torchvision import transforms
import glob
from PIL import Image
import os
import time
import concurrent.futures
%matplotlib inline

## 2. Define Image Preprocessing and Tensor Saving Functions

In [62]:
norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
def process_image(img_path_output_size):
    img_path, output_dir, size = img_path_output_size
    basename = os.path.splitext(os.path.basename(img_path))[0]
    out_path = os.path.join(output_dir, f"{basename}.pt")
    img = Image.open(img_path).convert('RGB')
    width, height = img.size
    left, top, right, bottom = 20, 20, width - 20, height - 20
    img_cropped = img.crop((left, top, right, bottom))
    img_resized = img_cropped.resize(size, Image.LANCZOS)
    tensor = transforms.ToTensor()(img_resized)
    tensor = norm(tensor)
    torch.save(tensor, out_path)

In [63]:
def save_tensor_images_threaded(input_dir, output_dir, size=(224, 224), num_workers=4):
    os.makedirs(output_dir, exist_ok=True)
    image_files = glob.glob(os.path.join(input_dir, '*.jpg'))
    print(f"Found {len(image_files)} images.")
    args = [(img_path, output_dir, size) for img_path in image_files]
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(executor.map(process_image, args))
    for res in results[:20]:
        print(res)
    print(f"Finished saving tensors for {len(image_files)} images.")

## 3. Dataset class creation
Can handle both training and test data

In [None]:
class GalaxyZooTensorDataset(Dataset):
    def __init__(self, csv_file, tensor_dir):
        self.tensor_dir = tensor_dir
        if csv_file is not None:
            self.df = pd.read_csv(csv_file)
            self.ids = self.df.iloc[:, 0].values
            self.labels = self.df.iloc[:, 1:].values.astype(np.float32)
            self.has_labels = True
        else:
            # For test set, infer IDs from tensor filenames
            self.df = None
            self.ids = [os.path.splitext(f)[0] for f in sorted(os.listdir(tensor_dir)) if f.endswith('.pt')]
            self.has_labels = False
    
    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        galaxy_id = int(self.ids[idx])
        tensor_path = os.path.join(self.tensor_dir, f"{galaxy_id}.pt")
        image = torch.load(tensor_path, weights_only=True)
        if self.has_labels:
            labels = self.labels[idx]
            return image, labels
        else:
            return image, galaxy_id

## 4. Create DataLoader

In [65]:
# Load training solutions and dataset
csv_file = './training_solutions_rev1/training_solutions_rev1.csv'
tensor_dir = './images_training_rev1/images_training_resized'
dataset = GalaxyZooTensorDataset(csv_file, tensor_dir)

batch_size = 128

full_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)

# For validation, split a portion from the full dataset
from torch.utils.data import random_split
total = len(dataset)
val_size = int(0.2 * total)
train_size = total - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

## 5. Load and Modify ConvNeXt Model
Load the ConvNeXt model, freeze all layers except the classifier, and modify the final layer to match the number of classes.

In [66]:
convnext = models.convnext_tiny(weights=ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
num_classes = dataset.labels.shape[1]

# Freeze all parameters
for param in convnext.parameters():
    param.requires_grad = False

# Replace the final layer
in_features = convnext.classifier[2].in_features
convnext.classifier[2] = nn.Linear(in_features, num_classes)

# Unfreeze only the last layer of the classifier
for param in convnext.classifier[2].parameters():
    param.requires_grad = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
convnext = convnext.to(device)

## 6. Set Up Loss Function and Optimizer

In [67]:
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, convnext.parameters()), lr=1e-4)

In [68]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
convnext = convnext.to(device)
print(f"Device: {device}")

Device: cuda


In [69]:
# reload a checkpointed model
if os.path.exists("checkpoint_last_layer.pth"):
    checkpoint = torch.load("checkpoint_last_layer.pth")
    convnext.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    loss = checkpoint['loss']

## 7. Train Model

In [74]:
def train(model, loader, optimizer, criterion, device, num_epochs=12):
    model.train()
    losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        epoch_start = time.time()
        for batch_idx, (images, labels) in enumerate(loader):
            batch_start = time.time()
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * images.size(0)
            if (batch_idx + 1) % 30 == 0:
                batch_time = time.time() - batch_start
                print(f"  Batch {batch_idx+1}/{len(loader)}: Loss={loss.item():.4f}, Batch Time={batch_time:.2f}s")
        avg_loss = running_loss / len(loader.dataset)
        print(f"Epoch {epoch+1}/{num_epochs}: Loss={avg_loss:.4f}, Time={time.time()-epoch_start:.2f}s")
        losses.append(avg_loss)
    return losses

## 8. Save Model Checkpoints

In [75]:
def save_checkpoint(model, optimizer, train_losses, filename):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved as {filename}")

In [None]:
file_name = "normal"
train_losses = train(convnext, full_loader, optimizer, criterion, device, num_epochs=20)
save_checkpoint(convnext, optimizer, train_losses, f"checkpoint_{file_name}.pth")

  Batch 30/482: Loss=0.3344, Batch Time=3.74s
  Batch 60/482: Loss=0.3292, Batch Time=3.60s
  Batch 90/482: Loss=0.3145, Batch Time=3.32s
  Batch 120/482: Loss=0.2944, Batch Time=4.39s
  Batch 150/482: Loss=0.2998, Batch Time=4.21s
  Batch 180/482: Loss=0.2885, Batch Time=2.66s
  Batch 210/482: Loss=0.3007, Batch Time=3.10s
  Batch 240/482: Loss=0.2976, Batch Time=3.61s


In [None]:
save_checkpoint(convnext, optimizer, train_losses, f"checkpoint_{file_name}.pth")

In [None]:
loss_df = pd.DataFrame({'epoch': list(range(1, len(train_losses)+1)), 'loss': train_losses})
loss_df.to_csv('train_losses.csv', index=False)
print('Training losses saved to train_losses.csv')

## 9. Evaluate Model on Validation Set

In [None]:
convnext.eval()
val_losses = []
all_preds = []
all_targets = []
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs = convnext(images)
        loss = criterion(outputs, labels)
        val_losses.append(loss.item() * images.size(0))
        preds = torch.sigmoid(outputs).cpu().numpy()
        all_preds.append(preds)
        all_targets.append(labels.cpu().numpy())

val_loss = np.sum(val_losses) / len(val_loader.dataset)
all_preds = np.concatenate(all_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

print(f"Validation Loss: {val_loss:.4f}")

rmse = np.sqrt(np.mean((all_preds - all_targets) ** 2))
print(f"Validation RMSE: {rmse:.4f}")

np.savez('val_predictions.npz', preds=all_preds, targets=all_targets)
print('Validation predictions and targets saved to val_predictions.npz')

## 10. Evaluate on test dataset and save submission

In [None]:
test_dataset = GalaxyZooTensorDataset(csv_dir=None, tensor_dir='./images_test_rev1/images_test_resized')
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=0, pin_memory=True)

convnext.eval()
all_predictions = []
all_galaxy_ids = []

with torch.no_grad():
    for images, galaxy_ids in test_loader:
        images = images.to(device, non_blocking=True)
        outputs = convnext(images)
        probs = torch.sigmoid(outputs)
        all_predictions.append(probs.cpu().numpy())
        all_galaxy_ids.extend(galaxy_ids)

predictions = np.concatenate(all_predictions, axis=0)

columns = ['GalaxyId']
questions = {1: 3, 2: 2, 3: 2, 4: 2, 5: 4, 6: 2, 7: 3, 8: 7, 9: 3, 10: 3, 11: 6}
for q, count in questions.items():
    for i in range(1, count + 1):
        columns.append(f'Class{q}.{i}')

submission_df = pd.DataFrame(predictions, columns=columns[1:])
submission_df.insert(0, 'GalaxyId', all_galaxy_ids)
print(submission_df.head())
submission_df.to_csv('submission_normal.csv', index=False)