<a href="https://colab.research.google.com/github/harshanand9891/Galaxy-Morphology-using-AIML/blob/main/galaxy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import itertools

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from tensorflow import keras
from tensorflow.keras import layers, models, utils
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
!pip install astroNN

## Load data

In [None]:
# Load the Galaxy10 dataset (images plus one class id per image) via astroNN.
from astroNN.datasets import load_galaxy10

(images, labels) = load_galaxy10()

/root/.astroNN/datasets/Galaxy10_DECals.h5 was found!


# Cast images and labels to float32


In [None]:
# Convert both arrays to float32 in one pass so downstream Keras layers see a
# single floating-point dtype.
# NOTE(review): a float32 copy of the full image set can be very large in RAM;
# confirm it fits on the target runtime.
images, labels = (array.astype(np.float32) for array in (images, labels))

# Split data into training and test sets

In [None]:
# Hold out 10% of the samples for testing. Indices are split (not the arrays
# themselves) and then used to slice images and labels consistently.
# random_state makes the split reproducible across kernel restarts;
# stratify keeps each class's share the same in the train and test sets.
train_x, test_x = train_test_split(
    np.arange(labels.shape[0]),
    test_size=0.1,
    random_state=42,
    stratify=labels,
)
train_images, train_labels = images[train_x], labels[train_x]
test_images, test_labels = images[test_x], labels[test_x]


# Define image labels

In [None]:
# Human-readable class names; position i is the name for integer label i.
imageLabel = list(
    (
        "Disturbed",
        "Merging",
        "Round Smooth",
        "In-between Round Smooth",
        "Cigar Shaped Smooth",
        "Barred Spiral",
        "Unbarred Tight Spiral",
        "Unbarred Loose Spiral",
        "Edge-on Galaxies without Bulge",
        "Edge-on Galaxies with Bulge",
    )
)

Plot sample images

In [None]:
# Preview the first 100 images in a 10x10 grid (row-major order), each titled
# with its class name.
fig, axes = plt.subplots(ncols=10, nrows=10, figsize=(20, 20))
for index, ax in enumerate(axes.flat):
    ax.set_title(imageLabel[int(labels[index])])
    # Images were cast to float32 above; imshow needs uint8 for 0-255 RGB data.
    ax.imshow(images[index].astype(np.uint8))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()


Vision Transformer parameters

In [None]:
# --- Optimisation ---
learning_rate = 1e-3
weight_decay = 1e-4
batch_size = 256
num_epochs = 70

# --- Patching: inputs are resized to image_size x image_size, then cut into
# non-overlapping patch_size x patch_size tiles. ---
image_size = 72
patch_size = 6
num_patches = (image_size // patch_size) ** 2  # 12 * 12 = 144 patches

# --- Transformer encoder ---
projection_dim = 64
num_heads = 4
transformer_units = [2 * projection_dim, projection_dim]  # MLP widths inside each encoder layer
transformer_layers = 8

# --- Classification head MLP widths ---
mlp_head_units = [2048, 1024]

Data augmentation

In [None]:
# In-model preprocessing pipeline: normalise, resize to image_size, then apply
# light random flips/rotations/zooms. The Normalization layer must stay at
# index 0 because a later cell adapts it via data_augmentation.layers[0].
data_augmentation = models.Sequential(name="data_augmentation")
data_augmentation.add(layers.Normalization())
data_augmentation.add(layers.Resizing(image_size, image_size))
data_augmentation.add(layers.RandomFlip("horizontal"))
data_augmentation.add(layers.RandomRotation(factor=0.02))
data_augmentation.add(layers.RandomZoom(height_factor=0.2, width_factor=0.2))

Adapt normalization layer to the training data

In [None]:
# Fit the Normalization layer's statistics on the training images only, so no
# test-set information leaks into preprocessing.
normalization_layer = data_augmentation.layers[0]
normalization_layer.adapt(train_images)


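Data generator for augmentation (note: `datagen` is configured and fitted here but is not passed to `model.fit` in `run_experiment` below)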

In [None]:
# NOTE(review): `datagen` is never used by run_experiment (model.fit receives
# the raw arrays), so these augmentations do not reach training; the in-model
# `data_augmentation` pipeline is what is actually applied. Also, none of the
# featurewise_* / zca options are enabled, so fit() presumably has no
# statistics to compute — confirm before relying on it.
_datagen_options = {
    "rotation_range": 30,
    "zoom_range": 0.2,
    "width_shift_range": 0.1,
    "height_shift_range": 0.1,
    "horizontal_flip": True,
    "vertical_flip": False,
}
datagen = ImageDataGenerator(**_datagen_options)
datagen.fit(train_images)


MLP function

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Append a stack of Dense(GELU) + Dropout layers to a Keras tensor.

    Args:
        x: Keras (functional-API) tensor to build on.
        hidden_units: iterable of layer widths, applied in order.
        dropout_rate: dropout rate used after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer (``x`` unchanged if
        ``hidden_units`` is empty).
    """
    for units in hidden_units:
        # The "gelu" activation string is the exact (non-approximate) GELU,
        # the same function as tf.nn.gelu's default, but it does not need a
        # separate `tf` import.
        x = layers.Dense(units, activation="gelu")(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

Patches and PatchEncoder layers

In [None]:
class Patches(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Output shape is ``(batch, num_patches, patch_size * patch_size * channels)``.
    Because extraction uses ``padding="VALID"``, edge pixels that do not fill a
    whole patch are dropped.
    """

    def __init__(self, patch_size, **kwargs):
        # Forward standard Layer kwargs (name, dtype, ...) so Keras can rebuild
        # this layer from get_config() when a saved model is reloaded.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # Window size == stride gives non-overlapping patches.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Static last dim = flattened pixels of one patch.
        patch_dims = patches.shape[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

class PatchEncoder(layers.Layer):
    """Linearly project each patch and add a learned position embedding.

    Args:
        num_patches: number of patches per image (one position id each).
        projection_dim: width of the projected patch embeddings.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        # Forward standard Layer kwargs so the layer can be rebuilt from
        # get_config() when a saved model is reloaded.
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        # One position id per patch; the (num_patches, projection_dim)
        # embedding broadcasts over the batch dimension of the projection.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

Create Vision Transformer classifier

In [None]:
def _vit_encoder_block(tokens):
    """Apply one pre-norm Transformer encoder layer (attention + MLP, each with a residual)."""
    normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
    attended = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim, dropout=0.1
    )(normed, normed)  # self-attention: query and value are the same tensor
    residual = layers.Add()([attended, tokens])
    mlp_input = layers.LayerNormalization(epsilon=1e-6)(residual)
    mlp_output = mlp(mlp_input, hidden_units=transformer_units, dropout_rate=0.1)
    return layers.Add()([mlp_output, residual])


def create_vit_classifier(input_shape, num_classes):
    """Build an (uncompiled) Vision Transformer that outputs raw class logits.

    The pipeline is: augmentation -> patch extraction -> patch encoding ->
    ``transformer_layers`` encoder layers -> LayerNorm / Flatten / Dropout ->
    MLP head -> Dense(num_classes). Architecture sizes come from the
    module-level hyperparameters.

    Args:
        input_shape: shape of one input image, without the batch dimension.
        num_classes: number of output logits.

    Returns:
        A ``keras.Model`` mapping images to unnormalised logits.
    """
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patch_tokens = Patches(patch_size)(augmented)
    tokens = PatchEncoder(num_patches, projection_dim)(patch_tokens)

    for _ in range(transformer_layers):
        tokens = _vit_encoder_block(tokens)

    head = layers.LayerNormalization(epsilon=1e-6)(tokens)
    head = layers.Flatten()(head)
    head = layers.Dropout(0.5)(head)
    head = mlp(head, hidden_units=mlp_head_units, dropout_rate=0.5)
    logits = layers.Dense(num_classes)(head)

    return models.Model(inputs=inputs, outputs=logits)

Plot confusion matrix

In [None]:
def plot_confusion_matrix(cm, class_names):
    """Draw ``cm`` as an annotated heat-map on a new figure.

    Args:
        cm: square 2-D array of counts. Rows are drawn as true labels and
            columns as predicted labels, matching the layout of
            ``sklearn.metrics.confusion_matrix(y_true, y_pred)``.
        class_names: tick labels, one per row/column of ``cm``.

    Returns:
        The created ``matplotlib.figure.Figure``; the caller decides when to
        show or save it.
    """
    cm = np.asarray(cm)
    figure = plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion Matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    # Right-align rotated labels so each one ends under its own column.
    plt.xticks(tick_marks, class_names, rotation=45, ha="right")
    plt.yticks(tick_marks, class_names)

    # Use white text on dark cells so the counts stay legible.
    threshold = cm.max() / 2.0 if cm.size else 0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(
            j, i, cm[i, j],
            horizontalalignment="center",
            verticalalignment="center",
            color="white" if cm[i, j] > threshold else "black",
        )

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    # Lay out after the axis labels exist so they are not clipped.
    plt.tight_layout()
    return figure

Run experiment

In [None]:
def _as_one_hot(label_array, num_label_classes):
    """Return labels as one-hot rows; 1-D integer/float class ids are converted."""
    label_array = np.asarray(label_array)
    if label_array.ndim == 1:
        return utils.to_categorical(label_array, num_label_classes)
    # Already 2-D: assumed to be one-hot rows.
    return label_array


def _plot_history(history, metric, title, ylabel):
    """Plot the training and validation curves for one metric of a Keras History."""
    plt.plot(history.history[metric])
    plt.plot(history.history[f'val_{metric}'])
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    # fit() uses validation_split, so the second curve is a validation slice
    # of the training data, not the held-out test set.
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()


def _plot_sample_predictions(sample_images, actual_label, pred_label, nrows=3, ncols=7):
    """Show up to nrows*ncols images titled with their actual and predicted class."""
    _, axes = plt.subplots(
        ncols=ncols, nrows=nrows, sharex=False, sharey=True, figsize=(17, 8), squeeze=False
    )
    for index, ax in enumerate(axes.flat):
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Leave surplus panels blank if there are fewer images than panels.
        if index >= len(sample_images):
            continue
        ax.set_title(
            f'Actual: {imageLabel[actual_label[index]]}\n'
            f'Predicted: {imageLabel[pred_label[index]]}'
        )
        ax.imshow(sample_images[index].astype(np.uint8), cmap='gray')
    plt.show()


def run_experiment(model, X_train, y_train, X_test, y_test):
    """Compile, train and evaluate ``model``, then plot diagnostics.

    Args:
        model: Keras model whose outputs are unnormalised logits, one per
            entry of ``imageLabel``.
        X_train, X_test: image arrays fed directly to ``model``.
        y_train, y_test: labels as integer class ids of shape ``(n,)`` or as
            one-hot rows of shape ``(n, len(imageLabel))``.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    num_label_classes = len(imageLabel)
    class_ids = np.arange(num_label_classes)
    # CategoricalCrossentropy and the argmax below need one-hot targets.
    y_train = _as_one_hot(y_train, num_label_classes)
    y_test = _as_one_hot(y_test, num_label_classes)

    # TensorFlow Addons (tfa) is end-of-life and was never imported; Keras
    # provides AdamW natively.
    optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, weight_decay=weight_decay)

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.CategoricalAccuracy(name="accuracy")],
    )

    history = model.fit(
        x=X_train, y=y_train, batch_size=batch_size, epochs=num_epochs, validation_split=0.1,
    )

    # Predict once and reuse for the report, confusion matrix and sample grid.
    pred_label = np.argmax(model.predict(X_test, batch_size=batch_size), axis=1)
    actual_label = np.argmax(y_test, axis=1)

    print('Confusion Matrix')
    # sklearn convention: confusion_matrix(y_true, y_pred), so rows are true
    # labels, matching plot_confusion_matrix's axis labels.
    cm = confusion_matrix(actual_label, pred_label, labels=class_ids)
    print(cm)

    # Passing labels keeps target_names aligned even if a class is absent
    # from the test split.
    print(classification_report(actual_label, pred_label, labels=class_ids, target_names=imageLabel))

    _plot_history(history, 'accuracy', 'Model Accuracy', 'Accuracy')
    _plot_history(history, 'loss', 'Model Loss', 'Loss')

    plot_confusion_matrix(cm, imageLabel)
    plt.show()

    # Use the X_test argument, not the global test_images.
    _plot_sample_predictions(X_test, actual_label, pred_label)

    return history

Define input shape and number of classes

In [None]:
# Derive the model input shape from the loaded images instead of hard-coding
# it. data_augmentation's Resizing layer rescales to image_size internally, so
# the raw image size is what the Input layer must accept.
input_shape = train_images.shape[1:]
num_classes = len(imageLabel)

Create and run the Vision Transformer classifier

In [None]:
# Build the ViT for the loaded images and train/evaluate it end-to-end.
vit_classifier = create_vit_classifier(input_shape=input_shape, num_classes=num_classes)
history = run_experiment(
    model=vit_classifier,
    X_train=train_images,
    y_train=train_labels,
    X_test=test_images,
    y_test=test_labels,
)
