# Imports

In [None]:
# imports

import os
import numpy as np
import random
import pandas as pd
import json
import pickle
import matplotlib.pyplot as plt
from utils import *

# Import and split data

In [None]:
# root folder holding one sub-directory per class
directory = '../data/interim/PatternNet/images'

# every sub-directory of ``directory`` is a candidate class
all_classes = [
    entry for entry in os.listdir(directory)
    if os.path.isdir(os.path.join(directory, entry))
]

# the subset of classes used in this project
classes = [
    'beach', 'chaparral', 'dense_residential', 'forest', 'freeway',
    'harbor', 'overpass', 'parking_space', 'river', 'swimming_pool',
]

# one [class name, number of directory entries] row per selected class
data = [[name, len(os.listdir(os.path.join(directory, name)))] for name in classes]
image_count_df = pd.DataFrame(data, columns=['Class', 'Total Image Count'])
print("The following classes were selected for evaluation:")
display(image_count_df)

# split the selected classes' files into train / val / test subsets
# (generate_splits comes from the project's ``utils`` module)
train_files, val_files, test_files = generate_splits(classes, directory)

# Inspect frequency spectra of example images

In [None]:
def fft_image(image):
    """Return the centred magnitude spectrum of an image's 2-D FFT.

    Parameters
    ----------
    image : array_like
        A 2-D grayscale image ``(H, W)`` or a channel-last image ``(H, W, C)``.
        For multi-channel input the channels are averaged into a grayscale
        image. A trailing alpha channel (``C`` of 2 or 4, i.e. gray+alpha or
        RGBA) is dropped first, so opacity does not leak into the intensity.

    Returns
    -------
    numpy.ndarray
        ``abs(fftshift(fft2(gray)))``, with the same height and width as
        ``image``. The zero-frequency component sits at ``(H // 2, W // 2)``.
    """
    image = np.asarray(image)

    # Convert to grayscale if the image has a channel axis
    if image.ndim > 2:
        if image.shape[2] in (2, 4):
            # last channel is alpha (gray+alpha / RGBA): not part of intensity
            image = image[..., :-1]
        image_gray = np.mean(image, axis=2)
    else:
        image_gray = image

    # 2-D FFT, then move the zero-frequency term to the centre of the array
    spectrum = np.fft.fft2(image_gray)
    spectrum_shifted = np.fft.fftshift(spectrum)

    # magnitude spectrum (phase discarded)
    return np.abs(spectrum_shifted)


In [None]:
# define a function to show a grid of spectrum images in a directory (given a file subset)
def generate_freq_spectrum(files, directory, images_per_class=3):
    """Plot log-magnitude FFT spectra of randomly chosen images, one row per class.

    Parameters
    ----------
    files : iterable of (class_name, file_name) pairs
        Candidate images; each is read from ``directory/class_name/file_name``.
    directory : str
        Root directory containing one sub-directory per class.
    images_per_class : int, default 3
        Number of spectra shown per class. A class with fewer files shows all
        of its files and leaves the remaining cells blank.

    The first column of each row shows the class name. Rows follow the order
    in which classes first appear in ``files``. Nothing is drawn when
    ``files`` is empty.
    """
    # group file names by class; only the sampled images are read from disk
    # later, instead of loading every file in ``files`` into memory
    class_files = {}
    for class_name, file_name in files:
        class_files.setdefault(class_name, []).append(file_name)

    num_classes = len(class_files)
    if num_classes == 0:
        return

    # squeeze=False keeps ``axes`` 2-D even for a single class, so the
    # ``axes[i, j]`` indexing below works for any grid size
    _, axes = plt.subplots(num_classes, images_per_class + 1,
                           figsize=(12, 3 * num_classes), squeeze=False)
    for i, (class_name, file_names) in enumerate(class_files.items()):
        # display class name in the first column
        axes[i, 0].text(0.5, 0.5, class_name, fontsize=16, ha='center', va='center')
        axes[i, 0].axis('off')

        # up to images_per_class distinct random files from this class
        chosen = random.sample(file_names, min(images_per_class, len(file_names)))
        for j, file_name in enumerate(chosen, start=1):
            img = plt.imread(os.path.join(directory, class_name, file_name))
            magnitude_spectrum = fft_image(img)
            # log(1 + |F|) compresses the spectrum's dynamic range for display
            axes[i, j].imshow(np.log1p(magnitude_spectrum), cmap='gray')

        # blank out cells left empty when the class has too few files
        for j in range(len(chosen) + 1, images_per_class + 1):
            axes[i, j].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# plot the FFT spectra of 3 randomly chosen training images for every class
generate_freq_spectrum(train_files, directory, images_per_class=3)

In [None]:
# define a function to show a grid of spectrum images in a directory (given a file subset)
# NOTE(review): this cell re-defines generate_freq_spectrum from an earlier
# cell; keeping a single definition would avoid the two copies drifting apart.
def generate_freq_spectrum(files, directory, images_per_class=3):
    """Plot log-magnitude FFT spectra of randomly chosen images, one row per class.

    Parameters
    ----------
    files : iterable of (class_name, file_name) pairs
        Candidate images; each is read from ``directory/class_name/file_name``.
    directory : str
        Root directory containing one sub-directory per class.
    images_per_class : int, default 3
        Number of spectra shown per class. A class with fewer files shows all
        of its files and leaves the remaining cells blank.

    The first column of each row shows the class name. Rows follow the order
    in which classes first appear in ``files``. Nothing is drawn when
    ``files`` is empty.
    """
    # group file names by class; only the sampled images are read from disk
    # later, instead of loading every file in ``files`` into memory
    class_files = {}
    for class_name, file_name in files:
        class_files.setdefault(class_name, []).append(file_name)

    num_classes = len(class_files)
    if num_classes == 0:
        return

    # squeeze=False keeps ``axes`` 2-D even for a single class, so the
    # ``axes[i, j]`` indexing below works for any grid size
    _, axes = plt.subplots(num_classes, images_per_class + 1,
                           figsize=(12, 3 * num_classes), squeeze=False)
    for i, (class_name, file_names) in enumerate(class_files.items()):
        # display class name in the first column
        axes[i, 0].text(0.5, 0.5, class_name, fontsize=16, ha='center', va='center')
        axes[i, 0].axis('off')

        # up to images_per_class distinct random files from this class
        chosen = random.sample(file_names, min(images_per_class, len(file_names)))
        for j, file_name in enumerate(chosen, start=1):
            img = plt.imread(os.path.join(directory, class_name, file_name))
            magnitude_spectrum = fft_image(img)
            # log(1 + |F|) compresses the spectrum's dynamic range for display
            axes[i, j].imshow(np.log1p(magnitude_spectrum), cmap='gray')

        # blank out cells left empty when the class has too few files
        for j in range(len(chosen) + 1, images_per_class + 1):
            axes[i, j].axis('off')

    plt.tight_layout()
    plt.show()