In [1]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px

from nltk.corpus import reuters

In [2]:
# Every topic label known to the Reuters corpus reader.
categories = reuters.categories()
print(categories)

['acq', 'alum', 'barley', 'bop', 'carcass', 'castor-oil', 'cocoa', 'coconut', 'coconut-oil', 'coffee', 'copper', 'copra-cake', 'corn', 'cotton', 'cotton-oil', 'cpi', 'cpu', 'crude', 'dfl', 'dlr', 'dmk', 'earn', 'fuel', 'gas', 'gnp', 'gold', 'grain', 'groundnut', 'groundnut-oil', 'heat', 'hog', 'housing', 'income', 'instal-debt', 'interest', 'ipi', 'iron-steel', 'jet', 'jobs', 'l-cattle', 'lead', 'lei', 'lin-oil', 'livestock', 'lumber', 'meal-feed', 'money-fx', 'money-supply', 'naphtha', 'nat-gas', 'nickel', 'nkr', 'nzdlr', 'oat', 'oilseed', 'orange', 'palladium', 'palm-oil', 'palmkernel', 'pet-chem', 'platinum', 'potato', 'propane', 'rand', 'rape-oil', 'rapeseed', 'reserves', 'retail', 'rice', 'rubber', 'rye', 'ship', 'silver', 'sorghum', 'soy-meal', 'soy-oil', 'soybean', 'strategic-metal', 'sugar', 'sun-meal', 'sun-oil', 'sunseed', 'tea', 'tin', 'trade', 'veg-oil', 'wheat', 'wpi', 'yen', 'zinc']


In [3]:
# Training split: corpus file ids whose name contains "train".
train_data = [file_id for file_id in reuters.fileids() if "train" in file_id]
print(len(train_data))

7769


In [4]:
# Test split: corpus file ids whose name contains "test".
test_data = [file_id for file_id in reuters.fileids() if "test" in file_id]
print(len(test_data))

3019


# Training Data

In [6]:
# One zeroed counter per category; filled in by the next cell.
category_count = dict.fromkeys(categories, 0)

In [7]:
# Count documents per category for the training split only.
# reuters.fileids(c) is not limited to one split, so ids are checked against
# the training ids collected in `train_data` (set for O(1) membership tests).
train_ids = set(train_data)
for c in categories:
    category_count[c] = sum(f_id in train_ids for f_id in reuters.fileids(c))

In [8]:
# Two-column frame ("category", "counts"), sorted ascending by count and
# re-indexed 0..n-1 so that a row's index label equals its sorted position.
category_count_df = (
    pd.DataFrame.from_dict(category_count, orient="index")
    .reset_index()
    .rename(columns={"index": "category", 0: "counts"})
    .sort_values(by="counts")
    .reset_index(drop=True)
)

In [9]:
def _first_index_at_quantile(counts, q):
    """Return the index label of the first entry equal to the q-th quantile.

    Args:
        counts: numeric pandas Series of per-category counts.
        q: quantile in [0, 1].

    Returns:
        The index label of the first element of ``counts`` whose value equals
        the quantile. ``interpolation="nearest"`` picks an observed value
        rather than interpolating between two, so an exact match is expected.

    Raises:
        IndexError: if no element matches (e.g. ``counts`` is empty).
    """
    quantile_value = counts.quantile(q, interpolation="nearest")
    return counts.index[counts == quantile_value][0]


# Quantiles are computed on the numeric "counts" column only; the frame also
# holds the string "category" column, which DataFrame.quantile cannot handle
# when it is not restricted to numeric columns.
twenty_fifth = _first_index_at_quantile(category_count_df["counts"], 0.25)
median = _first_index_at_quantile(category_count_df["counts"], 0.5)
seventy_fifth = _first_index_at_quantile(category_count_df["counts"], 0.75)

In [46]:
# Bar chart of per-category counts. Only five bars are labelled on the x axis
# and annotated with their value: the smallest, the 25th/50th/75th percentile
# rows, and the largest.
display_arr = [
    category_count_df["counts"].idxmin(),
    twenty_fifth,
    median,
    seventy_fifth,
    category_count_df["counts"].idxmax(),
]

fig = px.bar(category_count_df, x="category", y="counts", title="Label Counts")
fig.update_traces(marker_color="#9467bd")

# Ticks are given as row positions on the categorical axis.
fig.update_xaxes(
    type="category", tickangle=60, tickmode="array", tickvals=display_arr
)
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="#dcdcdc")

for pos in display_arr:
    count = category_count_df.iloc[pos]["counts"]
    # Lift the label 100 units above the bar top so it does not overlap the bar.
    fig.add_annotation(x=pos, y=count + 100, text=str(count), showarrow=False)

# Transparent plot and paper backgrounds.
fig.update_layout(
    plot_bgcolor="rgba(0, 0, 0, 0)",
    paper_bgcolor="rgba(0, 0, 0, 0)",
)

fig.show()

In [11]:
# Square co-occurrence table keyed [category][other_category].
# Every cell, including the diagonal, starts at integer 0 so that each column
# of the DataFrame built from it holds a single numeric type instead of a
# mix of ints and a bool.
category_count_matrix = {c: dict.fromkeys(categories, 0) for c in categories}

In [12]:
# Tally label co-occurrence over training documents only (this section
# reports on the training split). Each document is visited once and every
# ordered pair of distinct labels it carries is incremented, so the table
# stays symmetric and the diagonal is never touched.
# NOTE: counts accumulate; re-run the cell that initialises
# `category_count_matrix` before re-running this one.
for f_id in train_data:
    doc_categories = reuters.categories(f_id)
    for c in doc_categories:
        for c_ in doc_categories:
            if c != c_:
                category_count_matrix[c][c_] += 1

In [13]:
# Outer dict keys become the row index; inner dict keys become the columns.
category_count_matrix_df = pd.DataFrame.from_dict(
    data=category_count_matrix,
    orient="index",
)

In [14]:
# Keep only the strictly-lower triangle: the diagonal and everything above it
# are replaced with NaN.
category_count_matrix_df = category_count_matrix_df.where(
    np.tril(np.ones(category_count_matrix_df.shape, dtype=bool), k=-1)
)

In [22]:
title = "Inter-Label Counts"

# Heatmap of the label co-occurrence table.
# Rows of z follow the frame's index, so the y labels are taken from the
# index and the x labels from the columns.
heat = go.Heatmap(
    z=category_count_matrix_df,
    x=category_count_matrix_df.columns,
    y=category_count_matrix_df.index,
    colorscale=px.colors.sequential.Purples,
)

layout = go.Layout(
    title_text=title,
    height=1000,
    xaxis_showgrid=False,
    yaxis_showgrid=False,
    # Reverse the y axis so the first row is drawn at the top, matrix-style.
    yaxis_autorange="reversed",
    paper_bgcolor="rgba(0,0,0,0)",
    plot_bgcolor="rgba(0,0,0,0)",
)

fig = go.Figure(data=[heat], layout=layout)

fig.show()