### Import Packages

In [1]:
import warnings
warnings.filterwarnings('ignore')

import os
import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from skimage.transform import resize
from skimage.io import imread, imsave
from tensorflow.keras.models import load_model
from sklearn.model_selection import train_test_split

from astroNN.datasets import load_galaxy10sdss
from astroNN.datasets.galaxy10sdss import galaxy10cls_lookup, galaxy10_confusion

### Save 10 Test Images

In [2]:
# Load Data
# Load the Galaxy10 SDSS images and their class labels. The output below shows astroNN
# reusing a cached Galaxy10.h5; presumably it downloads the file when no cache exists.
pre_images, labels = load_galaxy10sdss()

C:\Users\weigfan\.astroNN\datasets\Galaxy10.h5 was found!


In [3]:
# Train/Test Split
# Split sample *indices* rather than the arrays themselves so images and labels stay
# aligned; 10% of samples go to the test set, and the fixed random_state makes the
# split reproducible across kernel restarts.
train_idx, test_idx = train_test_split(
    np.arange(labels.shape[0]),
    test_size=0.1,
    random_state=20050531,
)
train_images, test_images = pre_images[train_idx], pre_images[test_idx]
train_labels, test_labels = labels[train_idx], labels[test_idx]

In [46]:
# Save image in high resolution
# Write the first N_SAVE_IMAGES test images, upscaled to 500x500, as PNG files.
# File image_{idx}.png comes from test_images[idx], so its true class is test_labels[idx].
N_SAVE_IMAGES = 10
save_dir = "./images"
os.makedirs(save_dir, exist_ok=True)  # the PNG writes below fail if the directory is missing

# Guard against a test set smaller than N_SAVE_IMAGES.
n_to_save = min(N_SAVE_IMAGES, test_images.shape[0])
for idx in tqdm(range(n_to_save), desc="Save"):
    # preserve_range keeps the original pixel scale (presumably uint8 0-255) instead of
    # rescaling to [0, 1], so the uint8 cast below does not collapse the image to ~black.
    high_res_image = resize(test_images[idx], (500, 500), anti_aliasing=False, preserve_range=True).astype(np.uint8)
    imsave(os.path.join(save_dir, f"image_{idx}.png"), high_res_image, check_contrast=False)

Save:   0%|          | 0/10 [00:00<?, ?it/s]

### Classify

In [14]:
# Load image
# Read one saved PNG back and rebuild a single-sample batch in the model's input format.
# The PNG is an upscaled copy, so shrinking it again will generally not reproduce the
# original test pixels exactly.
img_path = "./images/image_0.png"
img = imread(img_path)
# Downscale to 69x69 (presumably the model's training input size; TODO confirm) and keep
# only the first three channels, which drops an alpha channel if the PNG has one.
resized_img = resize(img, (69, 69), anti_aliasing=False, preserve_range=True).astype(np.uint8)[:, :, :3]
# Prepend a batch axis and scale 0-255 pixel values into [0, 1] as float32.
image = resized_img[np.newaxis].astype(np.float32) / 255.0

In [15]:
# Load model
# Restore the Keras model stored in cnn_model.h5 (path is relative to the current working
# directory); presumably it was trained and saved in a separate notebook.
cnn_model = load_model("cnn_model.h5")

In [16]:
# Predict
# The model yields one row of class scores per sample in the batch (presumably 10 Galaxy10
# classes); arg-max along axis 1 picks the highest-scoring class index for each sample.
predicted_labels = cnn_model.predict(image)
prediction_class = predicted_labels.argmax(axis=1)
prediction_class



array([8], dtype=int64)

In [20]:
# Lookup
# Translate the predicted class index for the single image into Galaxy10's
# human-readable class name.
galaxy10cls_lookup(prediction_class[0])

'Disk, Face-on, Medium Spiral'

In [18]:
# Test Labels
# Ground-truth labels of the first ten test images. Entry idx is the true class of
# images/image_{idx}.png, so entry 0 is the answer for the image classified above.
test_labels[:10]

array([8, 2, 8, 1, 1, 0, 1, 2, 1, 0], dtype=uint8)