In [1]:
import os

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

username = "fnx11"
# Directory where rendered distribution figures are written. Defaults to the author's
# checkout; set the COSDEFENCE_DATA_DISTS environment variable to override it on another
# machine instead of editing an absolute path. os.path.join(..., "") guarantees a
# trailing separator, because figure paths are built by plain string concatenation
# (save_image_loc + "<name>.png").
save_image_loc = os.path.join(
    os.environ.get("COSDEFENCE_DATA_DISTS", f'/home/{username}/thesis/codes/CosDefence/data_dists/'),
    "",
)

In [51]:
# Size of the smallest class in each supported training set: MNIST classes are
# unbalanced (the smallest has 5421 images); Fashion-MNIST has 6000 images per class
# and CIFAR10 has 5000 images per class.
MIN_CLASS_SIZE = {"mnist": 5421, "fmnist": 6000, "cifar10": 5000}


def get_minor_share(dataset_name, class_ratio):
    """Return how many images of each *minority* class a single client receives.

    Clients are split into 10 groups of 10. Every client in group ``g`` holds class
    ``g`` as its majority class with ``class_ratio`` times as many images as each of
    the 9 minority classes. Any one class is therefore the majority class for 10 clients
    and a minority class for 90 clients, so it must supply
    ``minor * (10 * class_ratio + 90)`` images. The returned share is the largest integer
    that fits inside the dataset's smallest class.

    For ratios 10, 4 and 1 this yields the same shares as the previously hard-coded
    table (e.g. MNIST 10:1 -> 28, FMNIST 4:1 -> 46, CIFAR10 1:1 -> 50). It also works for
    any other positive ratio instead of silently reusing the 1:1 share, which would
    request more images than a class contains.

    Args:
        dataset_name: "mnist", "fmnist" or "cifar10"; any other name falls back to the
            CIFAR10 class size.
        class_ratio: majority-to-minority class ratio (k for a k:1 split); must be > 0.

    Returns:
        int: minority-class image count per client.

    Raises:
        ValueError: if ``class_ratio`` is not positive.
    """
    if class_ratio <= 0:
        raise ValueError(f"class_ratio must be positive, got {class_ratio!r}")
    # Unknown dataset names are treated as CIFAR10, as the original else-branch did.
    class_size = MIN_CLASS_SIZE.get(dataset_name, MIN_CLASS_SIZE["cifar10"])
    groups, clients_per_group = 10, 10
    # Images demanded from one class per unit of minor share:
    #   clients_per_group * (class_ratio  [its own group, as majority]
    #                        + groups - 1 [the other groups, as minority]).
    # e.g. MNIST 10:1 -> 5421 // 190 = 28, i.e. 28*10*10 + 28*9*10 = 5320 <= 5421.
    demand_per_unit = clients_per_group * (class_ratio + (groups - 1))
    return int(class_size // demand_per_unit)

def get_image_dists(dataset_name, class_ratio):
    """Return a 10x10 list of per-class image counts, one row per client group.

    Row ``g`` holds ``minor * class_ratio`` images of class ``g`` (its majority class)
    and ``minor`` images of every other class, where ``minor`` comes from
    ``get_minor_share(dataset_name, class_ratio)``.
    """
    minor = get_minor_share(dataset_name, class_ratio)
    major = minor * class_ratio
    return [
        [major if label == group else minor for label in range(10)]
        for group in range(10)
    ]

def get_fig(dataset_name, class_labels, class_ratio, hspace):
    """Build a 2x5 grid of horizontal bar charts, one panel per client group.

    Panel ``g`` (filled row by row) is titled ``clients_{10g}-{10g+9}`` and plots row
    ``g`` of ``get_image_dists(dataset_name, class_ratio)``: the number of images of
    each class for that client group.

    Args:
        dataset_name: dataset key forwarded to ``get_image_dists``.
        class_labels: one y-axis label per class (10 expected).
        class_ratio: majority-to-minority ratio forwarded to ``get_image_dists``.
        hspace: horizontal spacing between subplot columns, in normalized plot
            coordinates (passed to ``make_subplots``).

    Returns:
        plotly.graph_objects.Figure with one bar trace per panel.
    """
    image_dists = get_image_dists(dataset_name, class_ratio)
    n_rows, n_cols = 2, 5
    sub_titles = [f'clients_{i*10}-{10*(i+1)-1}' for i in range(n_rows * n_cols)]
    fig = make_subplots(rows=n_rows, cols=n_cols, horizontal_spacing=hspace,
                        vertical_spacing=0.15, subplot_titles=tuple(sub_titles))
    for idx, counts in enumerate(image_dists):
        row, col = divmod(idx, n_cols)  # row-major placement, 0-based
        # add_trace(row=, col=) replaces the deprecated append_trace(trace, row, col).
        fig.add_trace(
            go.Bar(
                y=class_labels,
                x=counts,
                # a list of colours gives each bar (class) its own colour
                marker=go.bar.Marker(color=px.colors.qualitative.Plotly),
                orientation="h",
                cliponaxis=False,
            ),
            row=row + 1,
            col=col + 1,
        )
    fig.update_layout(height=550, showlegend=False)
    return fig

# 1. MNIST Dataset
## a) Class Ratio 10:1

In [52]:
dataset_name = "mnist"
class_labels = [f'digit_{digit}' for digit in range(10)]
class_ratio = 10
hspace = 0.07
mnist_fig_cr10 = get_fig(dataset_name, class_labels, class_ratio=class_ratio, hspace=hspace)
# Save a static PNG copy under save_image_loc, then render the figure inline.
mnist_fig_cr10.write_image(save_image_loc + "mnist_fig_cr10.png")
mnist_fig_cr10.show()

## b) Class Ratio 4:1

In [53]:
# dataset_name, class_labels and hspace carry over from the MNIST 10:1 cell.
class_ratio = 4
mnist_fig_cr4 = get_fig(dataset_name, class_labels, class_ratio=class_ratio, hspace=hspace)
mnist_fig_cr4.show()

## c) Class Ratio 1:1

In [54]:
# dataset_name, class_labels and hspace carry over from the MNIST 10:1 cell.
class_ratio = 1
mnist_fig_cr1 = get_fig(dataset_name, class_labels, class_ratio=class_ratio, hspace=hspace)
mnist_fig_cr1.show()

# 2. FMNIST Dataset
## a) Class Ratio 10:1

In [55]:
dataset_name = "fmnist"
class_labels = ['T-shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
hspace = 0.08  # NOTE(review): presumably wider than MNIST's 0.07 to fit longer labels
class_ratio = 10
# Dataset-specific variable name so the MNIST figure object is not overwritten.
fmnist_fig_cr10 = get_fig(dataset_name, class_labels, class_ratio, hspace)
fmnist_fig_cr10.show()

## b) Class Ratio 4:1

In [56]:
class_ratio = 4
# Dataset-specific variable name so the MNIST 4:1 figure object is not overwritten.
fmnist_fig_cr4 = get_fig(dataset_name, class_labels, class_ratio, hspace)
fmnist_fig_cr4.show()

## c) Class Ratio 1:1

In [57]:
class_ratio = 1
# Dataset-specific variable name so the MNIST 1:1 figure object is not overwritten.
fmnist_fig_cr1 = get_fig(dataset_name, class_labels, class_ratio, hspace)
fmnist_fig_cr1.show()

# 3. CIFAR10 Dataset
## a) Class Ratio 10:1

In [58]:
dataset_name = "cifar10"
class_labels = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']
hspace = 0.09  # NOTE(review): presumably widened to fit labels such as 'Automobile'
class_ratio = 10
# Dataset-specific variable name so the MNIST figure object is not overwritten.
cifar10_fig_cr10 = get_fig(dataset_name, class_labels, class_ratio, hspace)
cifar10_fig_cr10.show()

## b) Class Ratio 4:1

In [59]:
class_ratio = 4
# Dataset-specific variable name so the MNIST 4:1 figure object is not overwritten.
cifar10_fig_cr4 = get_fig(dataset_name, class_labels, class_ratio, hspace)
cifar10_fig_cr4.show()

## c) Class Ratio 1:1

In [60]:
class_ratio = 1
# Dataset-specific variable name so the MNIST 1:1 figure object is not overwritten.
cifar10_fig_cr1 = get_fig(dataset_name, class_labels, class_ratio, hspace)
cifar10_fig_cr1.show()