In [None]:
import h5py
import tensorflow as tf
import keras as keras
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import utils
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping
from skimage.transform import resize
from tensorflow.keras.utils import to_categorical
from PIL import Image
import cv2

In [None]:
# NOTE(review): presumably the number of images in the HDF5 file — TODO confirm.
# Not referenced by any later cell shown here.
length = 17_736

In [None]:
# Classes:
"""
Irregular: 0
Merging: 1
Smooth/Round: 2
Spiral: 3
Edge-On: 4
"""

In [None]:
with h5py.File('/Path/to/File', 'r') as h5_file:
    # Materialise both datasets as in-memory NumPy arrays so they remain usable
    # after the file is closed at the end of the `with` block.
    images = np.asarray(h5_file['images'])
    labels = np.asarray(h5_file['ans'])

In [None]:
# Gray Scaling: collapse every colour image to a single intensity channel.
# NOTE(review): COLOR_BGR2GRAY weights the channels in B, G, R order; if the HDF5
# images are stored RGB, the red and blue weights are swapped. GalaxyPredict uses
# the same conversion, so training and inference agree — change both together.
images = np.array([cv2.cvtColor(colour_img, cv2.COLOR_BGR2GRAY) for colour_img in images])

In [None]:
# Normalize pixel intensities to [0, 1] (assumes 8-bit input, i.e. values 0-255).
# Cast to float32 first: plain `/ 255` produces float64, which doubles the memory
# of the whole image stack even though Keras computes in float32 anyway.
# NOTE: rebinds `images`; re-running this cell without reloading would rescale twice.
images = images.astype(np.float32) / 255.0

In [None]:
# Resize images to 128x128 and append a trailing channel axis, so the array has
# shape (N, 128, 128, 1) and matches the model's input_shape of (128, 128, 1).
images_resized = np.array([resize(img, (128, 128)) for img in images])[..., np.newaxis]

In [None]:
# Reducing number of classifications: map the original labels 0-9 onto 5 coarse classes
#   0 -> 0 (Irregular), 1 -> 1 (Merging), 2/3/4 -> 2 (Smooth/Round),
#   5/6/7 -> 3 (Spiral), 8/9 -> 4 (Edge-On).
# Position i of LABEL_REDUCTION holds the coarse class for original label i.
# NOTE: this rebinds `labels`. Re-running the cell without reloading the data would
# remap already-reduced labels (e.g. 3 -> 2), so re-run the load cell first.
LABEL_REDUCTION = np.array([0, 1, 2, 2, 2, 3, 3, 3, 4, 4])

labels = np.asarray(labels)
# Fail loudly on unexpected labels instead of passing them through unchanged; a stray
# label >= 5 would otherwise only surface later as a one-hot / Dense(5) shape mismatch.
if labels.size and (labels.min() < 0 or labels.max() >= LABEL_REDUCTION.size):
    raise ValueError(
        f"expected labels in [0, {LABEL_REDUCTION.size - 1}], "
        f"got range [{labels.min()}, {labels.max()}]"
    )
# Vectorised lookup; keep the original dtype so downstream cells see the same type.
labels = LABEL_REDUCTION[labels.astype(np.intp)].astype(labels.dtype)

In [None]:
# Split the data into a training set (70%) and a temporary set (30%) that is split
# again below. Stratifying on the labels keeps each class's share the same in every
# split, so rare classes are not under-represented in validation/test by chance.
images_train, images_temp, labels_train, labels_temp = train_test_split(
    images_resized, labels, test_size=0.3, random_state=42, stratify=labels
)

In [None]:
# Split the temporary set 50/50 into a validation set and a test set, again
# stratified so both keep the same class proportions as the temporary set.
images_val, images_test, labels_val, labels_test = train_test_split(
    images_temp, labels_temp, test_size=0.5, random_state=42, stratify=labels_temp
)

In [None]:
# Network Architecture: three convolutional stages with growing filter counts and
# dropout rates, followed by a dense head ending in a 5-way softmax.
input_shape = (128, 128, 1)


def _conv_stage(n_filters, drop_rate, **conv_kwargs):
    """Return [Conv2D(3x3, ReLU), BatchNorm, MaxPool(2x2), Dropout] as a list of layers."""
    return [
        Conv2D(n_filters, kernel_size=(3, 3), activation='relu', **conv_kwargs),
        BatchNormalization(),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(drop_rate),
    ]


# Stages are built left to right, so layers are created in the same order as a flat list.
model = Sequential(
    _conv_stage(32, 0.2, input_shape=input_shape)
    + _conv_stage(64, 0.3)
    + _conv_stage(128, 0.4)
    + [
        Flatten(),
        Dense(128, activation='relu'),
        BatchNormalization(),
        Dropout(0.5),
        Dense(5, activation='softmax'),
    ]
)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
# Print the compiled model's layers with their output shapes and parameter counts.
model.summary()

In [None]:
# One-hot encoding the labels. num_classes is pinned to the model's 5 outputs:
# without it, to_categorical infers the width from the largest label present, so a
# split that happens to lack class 4 would get fewer columns than Dense(5) produces.
NUM_CLASSES = 5
labels_train_encoded = to_categorical(labels_train, num_classes=NUM_CLASSES)
labels_val_encoded = to_categorical(labels_val, num_classes=NUM_CLASSES)
labels_test_encoded = to_categorical(labels_test, num_classes=NUM_CLASSES)

In [None]:
# Early stopping: halt once validation loss has not improved for 3 epochs, and roll
# the model back to the weights of its best validation epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Training the model. The callback must be passed to fit(); otherwise it never runs.
history = model.fit(images_train, labels_train_encoded,
                    validation_data=(images_val, labels_val_encoded),
                    epochs=10, batch_size=32,
                    callbacks=[early_stopping])

In [None]:
# Testing: score the model on the held-out test split, which is not passed to fit().
# evaluate() returns the loss followed by each compiled metric (here: accuracy).
test_metrics = model.evaluate(x=images_test, y=labels_test_encoded)
loss, accuracy = test_metrics
print("Test Accuracy: ", accuracy)
print("Loss", loss)

In [None]:
# New image testing

def GalaxyPredict(file):
    """Classify one galaxy image with the trained global `model` and display it.

    The image gets the same preprocessing as the training data: grayscale via
    cv2.COLOR_BGR2GRAY, scaling to [0, 1], skimage `resize` to 128x128, and
    then batch and channel axes for the model input.

    Parameters
    ----------
    file : str, path-like, or file object
        Anything `PIL.Image.open` accepts.

    Returns
    -------
    int
        Index of the most probable class (0-4).
    """
    # Force 3 channels so grayscale/palette images do not break the colour
    # conversion below; the `with` block closes the underlying file handle.
    with Image.open(file) as pil_img:
        img = np.array(pil_img.convert('RGB'))

    # Same conversion as the training cell, so channel weighting matches training.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Normalize to [0, 1], as done for the training images.
    img = img / 255.0

    # Use skimage `resize` like the training pipeline (it anti-aliases when
    # downsampling); cv2.resize's default bilinear does not, which skews inputs.
    img = resize(img, (128, 128))

    # Model expects (batch, height, width, channels) = (1, 128, 128, 1).
    img_array = img[np.newaxis, :, :, np.newaxis]

    # Predict galaxy
    predictions = model.predict(img_array)
    predicted_class = int(np.argmax(predictions[0]))
    print("The predicted class is:", predicted_class)
    plt.imshow(img, cmap='gray')
    plt.show()
    return predicted_class

GalaxyPredict('/Path/to/Image')
