In [6]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
import os
import numpy as np
from tensorflow.keras.utils import to_categorical
from PIL import Image

# extract the quantum data from the .png filename and label the data pre-training
def load_quantum_state_data(image_folder_path, n_values):
    images = []
    labels = []
    skipped_files = []

    for filename in os.listdir(image_folder_path):
        if filename.endswith('.png'):
            try:
                # extract n, l, m from the filename
                n, l, m = map(int, filename.rstrip('.png').split('_'))
                
                # if n is in n_values, process the image
                if n in n_values:
                    image_file = os.path.join(image_folder_path, filename)
                    image = Image.open(image_file).convert('RGB')  # Ensure image is RGB
                    image = image.resize((128, 128))  # Resize to 128x128 pixels
                    images.append(np.array(image))
                    labels.append((n, l, m))
                else:
                    skipped_files.append(filename)
            except ValueError:
                skipped_files.append(filename)
    
    # debugging
    print(f"Loaded {len(images)} images and {len(labels)} labels.")
    print(f"Skipped {len(skipped_files)} files due to format issues or filtering: {skipped_files}")

    # extract only the principal quantum number n for labels
    n_labels = np.array([n for n, l, m in labels])
    
    return np.array(images), n_labels

# load training data for n = 1 to n = 4
training_image_folder_path = 'training-data'
n_train_values = range(1, 6)
x_train, y_train = load_quantum_state_data(training_image_folder_path, n_train_values)

# debugging
print(f"x_train shape: {x_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"Unique n values found: {np.unique(y_train)}")

# normalize images
x_train = x_train / 255.0

# set number of n classes and encode data labels
num_n_classes = np.max(y_train) + 1
y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_n_classes)

# build model
base_model = tf.keras.applications.VGG16(include_top=False, input_shape=(128, 128, 3))
base_model.trainable = False  # Freeze the base model layers

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_n_classes, activation="softmax")
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy()]
)

model.summary()

# begin model training
model.fit(
    x_train, y_train,
    batch_size=16,
    epochs=200,
    validation_split=0.2
)

# load test data for n = 5
test_image_folder_path = 'classification-data'
n_test_values = [3]
x_test, y_test = load_quantum_state_data(test_image_folder_path, n_test_values)

# normalize test images
x_test = x_test / 255.0

# Decode predictions to the n values
predictions = model.predict(x_test)
predicted_n = np.argmax(predictions, axis=1)

print(f"Predicted principal quantum numbers (n): {predicted_n}")

if len(predicted_n) != len(y_test):
    print(f"Warning: Mismatch in predictions. Expected {len(y_test)} predictions but got {len(predicted_n)}")

Loaded 35 images and 35 labels.
Skipped 0 files due to format issues or filtering: []
x_train shape: (35, 128, 128, 3)
y_train shape: (35,)
Unique n values found: [1 2 3 4 5]


Epoch 1/200
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 695ms/step - categorical_accuracy: 0.3661 - loss: 2.5287 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.8823
Epoch 2/200
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 589ms/step - categorical_accuracy: 0.5476 - loss: 1.2149 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.5264
Epoch 3/200
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 585ms/step - categorical_accuracy: 0.5685 - loss: 1.3971 - val_categorical_accuracy: 0.0000e+00 - val_loss: 4.1852
Epoch 4/200
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 605ms/step - categorical_accuracy: 0.6786 - loss: 0.9999 - val_categorical_accuracy: 0.0000e+00 - val_loss: 3.6448
Epoch 5/200
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 603ms/step - categorical_accuracy: 0.6131 - loss: 0.9022 - val_categorical_accuracy: 0.0000e+00 - val_loss: 2.9130
Epoch 6/200
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[