## Convolutional Neural Network (CNN)

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing import image  # image helpers: load_img, img_to_array

import numpy as np
```

### Architecture Overview 
The goal of our network is to learn which features to extract from our images and then decide which class those features belong to. To start, we will attempt to identify which features are stars. As the project progresses, we may extend the network to classify more types of objects in our images.

To do this, we will start from the standard CNN architecture, which consists of two main parts:
* a Convolutional Base
* a Dense Head

In a CNN, the convolutional base extracts features from the images, and the dense head uses those features to predict the correct class.
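The base/head split can be illustrated with a small NumPy sketch. It is not the model used in this notebook. The 8x8 image, the four random 3x3 filters and the head weights are all hypothetical stand-ins. The "base" applies a convolution, a ReLU and 2x2 max-pooling for each filter and stacks the results as channels. The "head" flattens those feature maps and feeds them to one sigmoid unit that outputs a binary class probability.

```python
import numpy as np

def conv2d_valid(img, kernel):
    """2D cross-correlation (what CNN 'convolution' layers compute) with no padding."""
    kh, kw = kernel.shape
    out_h, out_w = img.shape[0] - kh + 1, img.shape[1] - kw + 1
    out = np.empty((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(img[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool_2x2(fmap):
    """Keep the largest value in each non-overlapping 2x2 window (odd edges are dropped)."""
    h, w = fmap.shape[0] // 2 * 2, fmap.shape[1] // 2 * 2
    return fmap[:h, :w].reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

def conv_base(img, kernels):
    """Feature extractor: conv -> ReLU -> 2x2 max-pool per filter, stacked as channels."""
    maps = [max_pool_2x2(np.maximum(conv2d_valid(img, k), 0.0)) for k in kernels]
    return np.stack(maps, axis=-1)

def dense_head(features, weights, bias):
    """Classifier: flatten the feature maps, then a single sigmoid unit."""
    z = features.reshape(-1) @ weights + bias
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
img = rng.random((8, 8))                    # hypothetical 8x8 single-channel image
kernels = rng.standard_normal((4, 3, 3))    # hypothetical 3x3 filters (learned in a real CNN)
features = conv_base(img, kernels)          # 8x8 -> conv 6x6 -> pool 3x3, 4 channels
weights = rng.standard_normal(features.size) * 0.1
prob = dense_head(features, weights, 0.0)   # probability of the positive class
print(features.shape, prob)
```

In a trained network, both the filters and the head weights are learned. Here they are random, so `prob` is meaningful only as a shape and range check.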

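Many modern image classifiers use transfer learning. They reuse a convolutional base that was pretrained on a large dataset such as ImageNet, freeze it, and train only a new head for their own classes. Our CNN follows this common approach.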
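The transfer-learning idea can be illustrated without TensorFlow. A frozen base computes features once, and only the head's weights are updated. The NumPy sketch below is a toy built on several assumptions:

* a fixed random projection plus ReLU stands in for a pretrained base such as VGG16;
* the dataset and labels are synthetic;
* the head is plain logistic regression trained by gradient descent, not the Keras/Adam setup used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(42)

# Stand-in for a pretrained base (assumption): a fixed random projection + ReLU.
W_base = rng.standard_normal((64, 16))
base_before = W_base.copy()

def frozen_base(x):
    return np.maximum(x @ W_base, 0.0)

# Hypothetical toy data: 200 flattened 8x8 "images"; label 1 if mean pixel value > 0.
X = rng.standard_normal((200, 64))
y = (X.mean(axis=1) > 0).astype(float)

feats = frozen_base(X)  # computed once: the base is never updated

def predict(f, w, b):
    return 1.0 / (1.0 + np.exp(-(f @ w + b)))

def bce(p, labels, eps=1e-12):
    """Binary cross-entropy averaged over the batch."""
    return -np.mean(labels * np.log(p + eps) + (1 - labels) * np.log(1 - p + eps))

# Only the head's parameters are trained.
w_head, b_head, lr = np.zeros(16), 0.0, 0.01
initial_loss = bce(predict(feats, w_head, b_head), y)
for _ in range(500):
    p = predict(feats, w_head, b_head)
    w_head -= lr * feats.T @ (p - y) / len(y)  # gradient of BCE w.r.t. head weights
    b_head -= lr * np.mean(p - y)              # gradient of BCE w.r.t. head bias
final_loss = bce(predict(feats, w_head, b_head), y)
print(initial_loss, final_loss)
```

The head's loss falls while `W_base` stays exactly as it started. Freezing the VGG16 layers is intended to achieve the same thing in Keras.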


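```python
old_base = keras.applications.VGG16(
    include_top=False,           # drop VGG16's own ImageNet classifier; our head replaces it
    weights='imagenet',          # start from ImageNet-pretrained filters (transfer learning)
    input_shape=(300, 300, 3),   # must have 3 channels with ImageNet weights; update to the actual image size
)
# `classes` and `classifier_activation` only apply when include_top=True,
# so the number of classes and the activation are set on our own head below.

old_base.trainable = False  # freeze all pretrained base layers; only the head is trained

AstrID = models.Sequential([
    old_base,                               # output shape: (batch, 9, 9, 512) for 300x300 input
    layers.GlobalAveragePooling2D(),        # collapse the feature maps to one 512-vector per image
    layers.Dense(300, activation='relu'),
    layers.Dropout(rate=0.3),
    layers.Dense(1, activation='sigmoid'),  # one probability for binary classification; use N units + softmax if more classes are added
])

AstrID.compile(
    optimizer='adam',
    loss='binary_crossentropy',  # matches a single sigmoid output with 0/1 labels
    metrics=['accuracy'],
)
```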
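VGG16 with `include_top=False` outputs a 4D feature map, not one vector per image. Suppose there is no `Flatten` or `GlobalAveragePooling2D` layer between the base and the head. A Keras `Dense` layer then acts on the last (channel) axis at every spatial position, so the model would emit a 9x9 grid of predictions instead of one prediction per image. The sketch below computes that shape. It relies on one property of VGG16: its 3x3 convolutions use `'same'` padding, so only its five 2x2, stride-2 max-pooling layers change the spatial size.

```python
def vgg16_feature_shape(height, width):
    """Shape of VGG16's final feature map when include_top=False.

    Each of the five 2x2 / stride-2 max-pooling layers halves the spatial
    size (rounding down); the convolutions keep it unchanged.
    """
    for _ in range(5):
        height, width = height // 2, width // 2
    return height, width, 512  # the last convolutional block has 512 filters

h, w, c = vgg16_feature_shape(300, 300)
print((h, w, c))                        # (9, 9, 512)
print("Flatten ->", h * w * c)          # 41472 values per image
print("GlobalAveragePooling2D ->", c)   # 512 values per image
```

There are two ways to collapse the map. `Flatten` keeps every spatial value, and here that would mean roughly 41k inputs to the first dense layer. Global average pooling keeps one average per filter, so it gives the head far fewer weights to learn.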

```python
# Used this code to check the channels in the .fits files.
# The images look greyscale but load as 3-channel arrays. Note that load_img
# defaults to color_mode='rgb', which converts any image to 3 channels, so this
# shape alone does not show how many channels are stored in the file itself.
# If the images are 1-channel greyscale, a common approach is to add a channel
# axis and duplicate the image into all 3 channels, because ImageNet-pretrained
# models such as VGG16 expect 3-channel input.

img = image.load_img('../space_test/images/Crab_Nebula/Crab_Nebula.fits')  # color_mode='rgb' by default
img_array = image.img_to_array(img)

print("Image shape:", img_array.shape)
```

Image shape: (300, 300, 3)
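The greyscale-to-3-channel approach described in the comments above can be sketched in NumPy. The helper name `to_three_channels` and the random 300x300 array are hypothetical. TensorFlow also provides `tf.image.grayscale_to_rgb`, which performs the same replication on tensors whose last dimension is 1.

```python
import numpy as np

def to_three_channels(img):
    """Replicate a single-channel image into 3 identical channels.

    Accepts (H, W) or (H, W, 1) arrays; 3-channel input is returned unchanged.
    """
    if img.ndim == 2:
        img = img[..., np.newaxis]        # (H, W) -> (H, W, 1)
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)  # (H, W, 1) -> (H, W, 3)
    if img.shape[-1] != 3:
        raise ValueError(f"expected 1 or 3 channels, got {img.shape[-1]}")
    return img

grey = np.random.default_rng(0).random((300, 300))  # hypothetical greyscale frame
rgb = to_three_channels(grey)
print(rgb.shape)  # (300, 300, 3)
```

Because the three channels are identical, the image carries no extra information. The duplication only matches the input shape that the pretrained filters were built for.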
