In [None]:
"""
    Copyright 2023 by Michał Stolarz <michal.stolarz@h-brs.de>

    This file is part of dlbm_binary.
    It is used to plot statistics of the binary dataset.

    dlbm_binary is free software: you can redistribute it and/or modify
    it under the terms of the GNU Affero General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    dlbm_binary is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU Affero General Public License for more details.
    You should have received a copy of the GNU Affero General Public License
    along with dlbm_binary. If not, see <http://www.gnu.org/licenses/>.
"""

import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# Global plot theme: white-grid style on wide, very short figures
# (11.7 x 1.1 inches), applied to every figure created afterwards.
sns.set(
    style="whitegrid",
    rc={'figure.figsize': (11.7, 1.1)},
)

# Root directory of the dataset. The notebook reads image files from
# <DATASET_PATH>/<split>/<class>/. The location can be overridden with the
# DLBM_DATASET_PATH environment variable so the notebook also runs on machines
# other than the author's; when the variable is unset the original path is used.
DATASET_PATH = os.environ.get("DLBM_DATASET_PATH",
                              "/home/michal/thesis/interaction_dataset")

# Sub-directories of DATASET_PATH, one per dataset split.
SPLITS = ["train", "valid", "test"]

# Sub-directories of every split, one per class.
CLASSES = ["diff", "feedback"]

# User identifiers as they appear in the image file names. The position of an
# identifier in this list is used as the numeric user label in the statistics
# table and on the plot axes, so the order must not be changed.
USERS = ["1MBU59SJ", "Z7U8NLC9", "U3L9LFS0", "M4OE3RP5", "J0YH72SI",
         "03DEQR1O", "Q4GTE6L4", "PTEM0K27", "6XTLNK55", "5J7PWO3G",
         "1PE38CJI", "25NQFBB2", "1CZ1CL1P",
         "6ZN36CQR", "6RGY40ES", "3UDT4XN8", "3G4MPE2W", "76HKXYD3",
         "A9XL9U1N", "COT085MQ", "F41CCF9W", "Q4ABT87L", "SYBO5F61"]

In [None]:
# Some file names carry a mistyped user id (digit 0 instead of the letter O);
# map such ids back to the identifier listed in USERS.
USER_ID_ALIASES = {"03DEQR10": "03DEQR1O"}

# Raw file counts are divided by this value before being reported.
# NOTE(review): presumably every sample is stored as 8 image files - confirm
# against the dataset layout.
FILES_PER_SAMPLE = 8


def count_files_per_user(file_names):
    """Count how many of the given file names belong to each user.

    The user id is taken from the second '_'-separated field of the file name
    with everything after the first '.' removed, e.g. 'xyz_<USER>_001.png'.

    Args:
        file_names: iterable of file names (not full paths).

    Returns:
        dict mapping user id -> number of matching file names. Names without a
        second '_'-separated field (e.g. hidden or stray files) are skipped
        instead of raising IndexError.
    """
    counts = {}
    for file_name in file_names:
        fields = file_name.split('.')[0].split('_')
        if len(fields) < 2:
            # Not a '<prefix>_<user>...' file name; it cannot be attributed to a user.
            continue
        user_id = USER_ID_ALIASES.get(fields[1], fields[1])
        counts[user_id] = counts.get(user_id, 0) + 1
    return counts


# Collect one record per (split, user, class) and build the frame once at the
# end; growing a DataFrame with pd.concat inside the loop is quadratic.
records = []
for split in SPLITS:
    # Each class directory is listed once per split rather than once per user.
    counts_by_class = {
        cat: count_files_per_user(os.listdir(os.path.join(DATASET_PATH, split, cat)))
        for cat in CLASSES
    }
    # Rows are emitted in (split, user, class) order; a user is identified in
    # the table by its position in USERS.
    for user_index, user in enumerate(USERS):
        for cat in CLASSES:
            records.append({
                'Split': split,
                'User': user_index,
                'Class': cat,
                'Count': counts_by_class[cat].get(user, 0) / FILES_PER_SAMPLE,
            })

results_frame = pd.DataFrame(records, columns=['Split', 'User', 'Class', 'Count'])
results_frame

In [None]:
# Directory the per-split PDF figures are written to, relative to the current
# working directory.
PLOT_DIR = "plots/data"
# savefig does not create missing directories, so make sure the target exists
# before the first figure is written.
os.makedirs(PLOT_DIR, exist_ok=True)

for split in SPLITS:
    # One explicit figure per split, so successive bar plots never share axes
    # and every figure can be released once it has been saved and shown.
    fig, ax = plt.subplots()
    sns.barplot(x="User", y='Count', hue="Class",
                data=results_frame.loc[results_frame['Split'] == split], ax=ax)

    # The train split uses a coarser y-axis tick scale than valid/test.
    ax.set_yticks([0, 500, 1000] if split == 'train' else [0, 100, 200])

    ax.set_title(f"{split.capitalize()} split")
    # Place the legend just outside the right edge of the axes so it does not
    # cover any bars.
    ax.legend(bbox_to_anchor=(1.01, 0.85), loc='upper left', borderaxespad=0)
    fig.savefig(f"{PLOT_DIR}/{split}.pdf", format="pdf", bbox_inches="tight")
    plt.show()
    plt.close(fig)