# Galaxy Classification using PyTorch
## Matthew Bartnik

### Importing Initial Modules

In [None]:
import pandas as pd
import os
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

### Inspecting Data

In [None]:
# Folder holding the training JPEGs; the model-saving cells also build paths under it.
working_folder = 'images_training_rev1'

# Load the label table and render it for a quick visual check.
labels_csv = 'labels.csv'
data_labels = pd.read_csv(labels_csv)
display(data_labels)

### Creating Dataset for Network

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision import transforms

class GalaxyDataset(Dataset):
  """Galaxy images paired with their per-class probability labels.

  Each row of the labels CSV holds a galaxy ID in column 0 followed by at
  least ``num_answers`` numeric label columns. The image for a row is read
  from ``<images_folder>/<galaxy ID>.jpg``.
  """

  def __init__(self, images_folder, labels_path, img_size, num_answers):
    """
    Args:
      images_folder: directory containing the ``<galaxy ID>.jpg`` files.
      labels_path: path to the labels CSV.
      img_size: target size passed to ``transforms.Resize``.
      num_answers: number of label columns (after the ID column) returned
        as each sample's label.
    """
    self.img_size = img_size
    self.image_labels = pd.read_csv(labels_path)
    self.images_folder = images_folder
    # Keep our own copy so __getitem__ does not depend on a notebook-global
    # ``num_answers`` that may differ from (or not exist alongside) this dataset.
    self.num_answers = num_answers

  def __len__(self):
    """Return the number of rows in the labels table."""
    return len(self.image_labels)

  def transform(self, image):
    """Convert an integer image tensor to float, normalize it, then resize it.

    NOTE(review): ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` is applied to
    raw 0-255 values (no division by 255), so outputs span roughly [-1, 509]
    rather than [-1, 1]. It also assumes a 3-channel image. Left unchanged
    because previously saved models were trained on exactly these inputs.
    """
    # convert from integer to float so Normalize can operate on it
    image = image.float()

    image = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(image)
    image = transforms.Resize(self.img_size)(image)

    return image

  def __getitem__(self, idx):
    """Return ``(image, label)`` for row ``idx`` of the labels table."""
    galaxy_ID = self.image_labels.iloc[idx, 0]
    image_path = os.path.join(self.images_folder, str(galaxy_ID) + '.jpg')

    # load the JPEG as a tensor, then normalize and resize it
    image = self.transform(read_image(image_path))

    return image, self._get_label(idx)

  def _get_label(self, idx):
    """Return the ``num_answers`` label columns of row ``idx`` as a float32 tensor.

    Raises:
      IndexError: if the CSV has fewer than ``num_answers`` label columns.
    """
    values = self.image_labels.iloc[idx, 1:self.num_answers + 1].to_numpy(dtype='float32')
    # A positional slice silently truncates; fail loudly like per-column
    # indexing would, rather than returning a short label.
    if len(values) != self.num_answers:
      raise IndexError(
          'labels table has {} label columns, expected {}'.format(len(values), self.num_answers))
    return torch.tensor(values)

### Splitting Data

In [None]:
# build dataset and split it into train and test sets

# 32 px should be a sufficient size for shape classification,
# but try messing around with it!
image_size = 32
# number of label columns to predict (37 = full Galaxy Zoo decision tree)
num_answers = 37

galaxy_dataset = GalaxyDataset(working_folder, 'labels.csv', image_size, num_answers)

# split full dataset into training set and (unseen) test set
train_fraction = 0.9
test_fraction = 1 - train_fraction  # informational; num_test below is what is used
dataset_size = len(galaxy_dataset)
print(dataset_size)

num_train = int(train_fraction * dataset_size)
num_test = dataset_size - num_train
print(num_train, num_test)

# Seed the split so the same galaxies land in the test set in every session.
# Without a fixed generator, re-running the notebook (e.g. before reloading a
# saved model) draws a fresh split, and the "test" set can then contain images
# the model was trained on.
split_seed = 42
train_dataset, test_dataset = torch.utils.data.random_split(
    galaxy_dataset,
    [num_train, num_test],
    generator=torch.Generator().manual_seed(split_seed),
)

### Looking at Data

In [None]:
# Small loader used only to eyeball a few random training examples.
example_batch_size = 5
example_loader = DataLoader(
    dataset=train_dataset,
    batch_size=example_batch_size,
    shuffle=True,  # images come out in random order
)


# helper function for plotting a batch of images
from torchvision.utils import make_grid

def plot_imgbatch(imgs):
    imgs = imgs.cpu()
    imgs = imgs.type(torch.IntTensor)
    plt.figure(figsize=(15, 3*(imgs.shape[0])))
    grid_img = make_grid(imgs, nrow=5)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.show()
    
# Show the first four random batches alongside their label tensors.
for i, (batch_imgs, batch_labels) in zip(range(4), example_loader):
  print('\n\n\nimage batch {}:'.format(i))
  plot_imgbatch(batch_imgs)
  print('labels of images from batch {}:'.format(i))
  print(batch_labels)

### Setting Up the GPU (the CPU also works, but training will take much longer)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Setting up Data Loaders

In [None]:
# Training batch size depends on your hardware. Larger images need smaller
# batches; about 500 is a reasonable starting point at image size 32.
train_batchsize = 1500
# Batch size used when evaluating on the test set; it can be small.
eval_batchsize = 100

# The training loader reshuffles every epoch; the test loader keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size=train_batchsize, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=eval_batchsize)

### Neural Network Setup

In [None]:
from torchvision.models import resnet18

# ResNet-18 built with default constructor arguments (no pretrained weights requested).
net = resnet18()

# Replace the final fully connected layer so the network emits one value per
# answer. Reading in_features from the existing layer avoids hard-coding 512.
net.fc = torch.nn.Linear(in_features=net.fc.in_features, out_features=num_answers, bias=True)

# Move the network's parameters to the computation device (GPU if available).
# The loss function is defined in the "Loss Function and Optimizer Setup" cell.
net = net.to(device)

### Loss Function and Optimizer Setup

In [None]:
# Mean squared error between predicted and target class probabilities,
# averaged over every element of the batch.
criterion = torch.nn.MSELoss(reduction='mean')

In [None]:
# Adam with a fixed learning rate. For comparison, plain SGD is
# torch.optim.SGD(net.parameters(), lr=learning_rate).
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)

# number of full passes over the training set
epochs = 10

### Preparing a File Path for Saved Models

In [None]:
from datetime import datetime

# Saved models live in a subfolder of the image folder. exist_ok=True replaces
# the check-then-create pattern and makes re-running this cell harmless.
saved_model_folder = os.path.join(working_folder, 'saved_models')
os.makedirs(saved_model_folder, exist_ok=True)

# Timestamped filename, so each run of this cell gets a distinct model file.
now = datetime.now()
date_time = now.strftime("%m_%d_%Y__%H_%M_%S")
saved_model_filename = 'trainedmodel_{}.pt'.format(date_time)
saved_model_path = os.path.join(saved_model_folder, saved_model_filename)

### Training the Network and Tracking Test Error (MSE)

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimization pass over ``loader``.

    Returns:
        The loss averaged over samples (so a smaller final batch is weighted
        correctly), or NaN if ``loader`` yields no batches.
    """
    # training mode, so that parameters can be changed
    model.train()
    total_loss = 0.0
    num_samples = 0
    for inputs, targets in tqdm(loader):
        # inputs: (batch size, channels, height, width) images
        # targets: (batch size, num_answers) class probabilities
        inputs = inputs.to(device)
        targets = targets.to(device)

        # reset gradients, predict, measure the error, and update parameters
        optimizer.zero_grad()
        predictions = model(inputs)
        loss = criterion(predictions, targets)
        loss.backward()
        optimizer.step()

        # .item() extracts a plain float; accumulating the tensor itself would
        # keep every batch's autograd graph alive for the whole epoch.
        batch_size = inputs.shape[0]
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    return total_loss / num_samples if num_samples else float('nan')


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Return the per-sample mean loss of ``model`` on ``loader`` (NaN if empty).

    Runs under ``torch.no_grad()``: evaluation needs no gradients.
    """
    # evaluating, not training
    model.eval()
    total_loss = 0.0
    num_samples = 0
    for inputs, targets in tqdm(loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        loss = criterion(model(inputs), targets)
        batch_size = inputs.shape[0]
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    return total_loss / num_samples if num_samples else float('nan')


# Start from an infinitely bad score so the first epoch's model is always saved.
best_test_error = float('inf')

for epoch in range(epochs):
    print("### Epoch {}:".format(epoch))

    # error in predicted probabilities on the training set
    avg_loss_train = _train_one_epoch(net, train_loader, optimizer, criterion, device)
    print("Average Training Error (MSE of class probabilities): %.4f" %(avg_loss_train))

    # error in predicted probabilities on the held-out test set
    avg_loss_test = _evaluate(net, test_loader, criterion, device)
    print("Average Test Error (MSE of class probabilities): %.4f" %(avg_loss_test))

    # save trained network as file only if the test error improved
    if avg_loss_test < best_test_error:
        print('better test performance, saving model...')
        torch.save(net.state_dict(), saved_model_path)
        best_test_error = avg_loss_test

### Loading a Saved Model and Evaluating It on the Test Set

In [None]:
saved_model_folder = os.path.join(working_folder, 'saved_models')

# choose model filename here (****.pt)
model_filename = 'trainedmodel_04_19_2023.pt'

saved_model_path = os.path.join(saved_model_folder, model_filename)
# map_location lets weights saved during a GPU run load on a CPU-only machine.
net.load_state_dict(torch.load(saved_model_path, map_location=device))

# Evaluate the loaded model on the test set.
# NOTE(review): this is only a fair held-out score if test_dataset is the same
# split the model was trained against (i.e. the split is reproducible).
net.eval()
test_error = 0.0
num_test_samples = 0
with torch.no_grad():  # inference only: no autograd graphs needed
    for inputs, targets in tqdm(test_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        loss = criterion(net(inputs), targets)
        # weight each batch by its size so a smaller last batch is not over-counted
        test_error += loss.item() * inputs.shape[0]
        num_test_samples += inputs.shape[0]

# error in predicted probabilities, averaged per sample
avg_loss_test = test_error / num_test_samples if num_test_samples else float('nan')
print("\n\nAverage Test Error (MSE of class probabilities): %.4f" %(avg_loss_test))