In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Seed NumPy's global RNG once, up front, so every np.random draw made after
# this point is repeatable from a fresh kernel.
random_state = 12227
np.random.seed(seed=random_state)

In [None]:
# Dataset selection. `part` chooses one of the numbered `_part{N}` archives;
# leave it as None to load the archive without a part suffix.
subset_name = 'imagenet30'
part = None # None or 1, 2, 3, 4, 5

# Test for None by identity (`is None`), not equality, and build the file name
# from a single suffix so the two variants cannot drift apart.
part_suffix = '' if part is None else f'_part{part}'
data_path = f'../out_files/{subset_name}_cls{part_suffix}.npz'
metadata_path = f'../out_files/{subset_name}_cls_metadata.csv'

# np.load on an .npz returns a lazy NpzFile that keeps the file open until it
# is closed. Read every needed array inside the `with` block so the handle is
# released right after; the extracted arrays stay valid once it is closed.
with np.load(data_path) as data:
    X_train, y_train, train_id = data['X_train'], data['y_train'], data['train_id']
    X_val, y_val, val_id = data['X_val'], data['y_val'], data['val_id']

In [None]:
# Read the class metadata table and derive a lookup from each 'subset_label'
# value to the 'human_label' value found on the same row.
metadata_df = pd.read_csv(metadata_path)
subset_labels = metadata_df['subset_label']
human_labels = metadata_df['human_label']
label_to_human = {subset: human for subset, human in zip(subset_labels, human_labels)}

In [None]:
def show_random_samples(images, labels, ids, split, n_samples=3):
    """Display `n_samples` randomly chosen images from one data split.

    For every drawn index, prints the label and id stored at that index, then
    shows the image in a single-row figure titled with its human-readable
    label taken from `label_to_human`.

    Parameters
    ----------
    images, labels, ids : indexable sequences
        Parallel containers: position k of each one describes sample k.
    split : str
        Split name used in the printed line (e.g. 'train' -> 'y_train[...]',
        'train_id[...]').
    n_samples : int, default 3
        Number of images to draw (with replacement) and display.

    Raises
    ------
    ValueError
        Raised by np.random.randint if `images` is empty.
    """
    # squeeze=False keeps `axs` two-dimensional even when n_samples == 1.
    fig, axs = plt.subplots(1, n_samples, figsize=(15, 5), squeeze=False)
    for ax in axs[0]:
        # randint's upper bound is exclusive, so len(images) (not len - 1) is
        # what makes the last sample reachable.
        idx = np.random.randint(0, len(images))
        print(f"y_{split}[{idx}]={labels[idx]}, {split}_id[{idx}]={ids[idx]}")
        ax.imshow(images[idx])
        ax.axis('off')
        ax.set_title(f'Label: {label_to_human[labels[idx]]}')
    plt.show()


print(f"Unique classes: {np.unique(y_train)}")

# Train
print(f"Train shape: {X_train.shape}")
show_random_samples(X_train, y_train, train_id, 'train')

# Validation
print(f"Validation shape: {X_val.shape}")
show_random_samples(X_val, y_val, val_id, 'val')