In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

from src.data_generation.datasets import gummy_worm_dataset_family, exclamation_mark_dataset_family

In [None]:
# --- Configuration -----------------------------------------------------------
# Sample-count argument handed to every variant's `generate_data(...)` call in
# the plotting cells below.
num_samples_per_dist = 10_000

# Build both dataset families once; the plotting cells iterate over these and
# draw one panel per variant.
# NOTE(review): `gummmy` (three m's) is a typo, but the plotting cell reads the
# variable under this exact name, so it is left unchanged here.
gummmy_worm_data_generations = gummy_worm_dataset_family()
exclamation_mark_data_generations = exclamation_mark_dataset_family()

In [None]:
# Scatter-plot every variant of the gummy worm dataset family, one panel per
# variant on a shared-axis grid, and save the figure as a PNG.

n_cols = 4
# Indexed by class label when colouring points; assumes `y` holds the integer
# labels 0 and 1 -- TODO confirm against the dataset generators.
colormap = np.array(['red', 'blue'])

# Materialise the family so the grid can be sized to the real number of
# variants instead of assuming exactly 20 (the old fixed 5x4 grid raised
# IndexError for larger families and left blank framed panels for smaller ones).
variants = list(gummmy_worm_data_generations)
n_rows = max(1, (len(variants) + n_cols - 1) // n_cols)  # ceil division, at least one row

fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(20, 4.8 * n_rows),  # 4.8 per row reproduces the original 20x24 for 5 rows
    dpi=200,
    sharex=True,
    sharey=True,
    constrained_layout=True,
    squeeze=False,  # always a 2-D array of axes, even for a single row
)
axs = axs.flatten()

# Use the figure object directly rather than pyplot's implicit "current figure".
fig.suptitle("Gummy Worm Dataset Family")

for i, (ax, data_generation) in enumerate(zip(axs, variants)):
    X, y = data_generation.generate_data(num_samples_per_dist)

    print("Gummy Worm Variant", i, ":", X.shape, y.shape)

    # First two feature columns, coloured by class label.
    ax.scatter(X[:, 0], X[:, 1], c=colormap[y], s=0.8)
    ax.grid(True, linestyle='--', alpha=0.6)

    # A scatter with per-point colours yields no per-class legend entries, so
    # build one proxy marker for each class present in this variant.
    legend_elements = [
        Line2D([0], [0], marker='o', color='white', markerfacecolor=colormap[label], label=str(label))
        for label in np.unique(y)
    ]
    ax.legend(handles=legend_elements, title="Classes", loc='upper left')

    # Tie each panel to the index printed above so the saved image stands alone.
    ax.set_title("Variant " + str(i))
    ax.set_xlabel("feature 0")
    ax.set_ylabel("feature 1")

# Hide the grid cells that received no variant.
for unused_ax in axs[len(variants):]:
    unused_ax.set_visible(False)

fig.savefig("./gummy_worm_dataset_family.png")

In [None]:
# Scatter-plot every variant of the exclamation mark dataset family, one panel
# per variant on a shared-axis grid, and save the figure as a PNG.

n_cols = 4
# Indexed by class label when colouring points; assumes `y` holds the integer
# labels 0 and 1 -- TODO confirm against the dataset generators.
colormap = np.array(['red', 'blue'])

# Materialise the family so the grid can be sized to the real number of
# variants instead of assuming exactly 20 (the old fixed 5x4 grid raised
# IndexError for larger families and left blank framed panels for smaller ones).
variants = list(exclamation_mark_data_generations)
n_rows = max(1, (len(variants) + n_cols - 1) // n_cols)  # ceil division, at least one row

fig, axs = plt.subplots(
    n_rows,
    n_cols,
    figsize=(20, 4.8 * n_rows),  # 4.8 per row reproduces the original 20x24 for 5 rows
    dpi=200,
    sharex=True,
    sharey=True,
    constrained_layout=True,
    squeeze=False,  # always a 2-D array of axes, even for a single row
)
axs = axs.flatten()

# Use the figure object directly rather than pyplot's implicit "current figure".
fig.suptitle("Exclamation Mark Family")

for i, (ax, data_generation) in enumerate(zip(axs, variants)):
    X, y = data_generation.generate_data(num_samples_per_dist)

    print("Exclamation Mark Variant", i, ":", X.shape, y.shape)

    # First two feature columns, coloured by class label.
    ax.scatter(X[:, 0], X[:, 1], c=colormap[y], s=0.8)
    ax.grid(True, linestyle='--', alpha=0.6)

    # A scatter with per-point colours yields no per-class legend entries, so
    # build one proxy marker for each class present in this variant.
    legend_elements = [
        Line2D([0], [0], marker='o', color='white', markerfacecolor=colormap[label], label=str(label))
        for label in np.unique(y)
    ]
    ax.legend(handles=legend_elements, title="Classes", loc='upper left')

    # Tie each panel to the index printed above so the saved image stands alone.
    ax.set_title("Variant " + str(i))
    ax.set_xlabel("feature 0")
    ax.set_ylabel("feature 1")

# Hide the grid cells that received no variant.
for unused_ax in axs[len(variants):]:
    unused_ax.set_visible(False)

fig.savefig("./exclamation_mark_dataset_family.png")

