In [None]:
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt

In [None]:
# Load the model's predictions and the test features they were produced for.
preds = pd.read_csv("results/result.csv")
X = pd.read_csv("data/X_test.csv")

# Plain column assignment aligns on the index, so a row-count mismatch between
# the two files would silently yield NaN predictions (or drop extra ones)
# instead of failing loudly.
assert len(preds) == len(X), (
    f"prediction count ({len(preds)}) does not match test rows ({len(X)})"
)
# Assign positionally: row i of result.csv is the prediction for row i of X_test.csv.
X["prediction"] = preds["prediction"].to_numpy()

## Find and plot the 5 most frequent notes in the `review` column for each predicted class

In [None]:
# For every predicted class, split each review on commas, normalise the
# resulting notes, and keep the 5 most common ones.
rows = []
# dropna(): NaN is not a real class and never compares equal to itself, so it
# would only produce an empty group.
for cls in sorted(X["prediction"].dropna().unique()):
    counter = Counter()
    reviews = X.loc[X["prediction"] == cls, "review"].dropna().astype(str)
    for review in reviews:
        notes = (note.strip().lower() for note in review.split(","))
        # Skip empty fragments left by trailing or doubled commas ("a, b,").
        counter.update(note for note in notes if note)
    for note, cnt in counter.most_common(5):
        rows.append({"class": cls, "note": note, "count": cnt})

# Explicit columns keep the schema even when no notes were found, so the
# plotting cell's column lookups do not raise KeyError.
top_notes_df = pd.DataFrame(rows, columns=["class", "note", "count"])

In [None]:
# One bar chart per predicted class showing its most frequent notes.
if "class" in top_notes_df:
    classes = sorted(top_notes_df["class"].unique())
else:
    classes = []

if not classes:
    print("No notes found for any class; nothing to plot.")
else:
    # squeeze=False always returns a 2-D array of Axes, so a single class still
    # gives an iterable row rather than a bare Axes object.
    fig, axes_grid = plt.subplots(
        1, len(classes), figsize=(6 * len(classes), 5), sharey=True, squeeze=False
    )
    axes = axes_grid[0]

    for ax, cls in zip(axes, classes):
        dfc = top_notes_df[top_notes_df["class"] == cls]
        ax.bar(dfc["note"], dfc["count"])
        ax.set_title(f"Top 5 Frequent Notes for Class {cls}")
        ax.set_xlabel("Note")
        ax.tick_params(axis="x", rotation=45)
    # The y-axis is shared, so labelling only the leftmost panel is enough.
    axes[0].set_ylabel("Count")

    plt.tight_layout()
    plt.show()

## Count predicted samples per class for each country of origin

In [None]:
# Tabulate how many samples of each predicted class come from every origin;
# countries are ordered by their total number of samples.
counts = (
    X.groupby(["origin", "prediction"])
    .size()
    .unstack(fill_value=0)
)
counts["total"] = counts.sum(axis=1)
counts = counts.sort_values(by="total", ascending=False)

# Grouped bar chart (one bar per class) drawn through an explicit Axes handle.
origin_ax = counts.drop(columns="total").plot(kind="bar", figsize=(12, 6))
origin_ax.set_ylabel("Number of samples")
origin_ax.set_title("Absolute counts of class 0 vs class 1 by Origin (sorted by total)")
origin_ax.legend(title="Class", loc="upper right")
plt.setp(origin_ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
plt.show()

In [None]:
# 1. Build a multi-line bar label "<roaster>'s \n <roast> \n (<origin>)".
#    Vectorised concatenation replaces the row-wise apply: it is faster, works
#    on an empty frame (apply(axis=1) on an empty frame returns a DataFrame,
#    which cannot be assigned to one column), and avoids nesting same-type
#    quotes inside an f-string (a SyntaxError before Python 3.12).
X["roaster_label"] = (
    X["roaster"].astype(str) + "'s \n "
    + X["roast"].astype(str) + " \n ("
    + X["origin"].astype(str) + ")"
)

# 2. Number of samples per (roaster_label, predicted class).
counts = X.groupby(["roaster_label", "prediction"]).size().unstack(fill_value=0)
# Guarantee that both class columns exist, so sorting by either one cannot
# raise KeyError when the model never predicted that class.
counts = counts.reindex(columns=counts.columns.union([0, 1]), fill_value=0)

# 3. Top 5 labels by class-1 count and top 5 by class-0 count.
top1 = counts.sort_values(by=1, ascending=False).head(5)
top0 = counts.sort_values(by=0, ascending=False).head(5)

# 4. Side-by-side panels that share one y-axis.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
panels = [
    (axes[0], top1, "Top 5 Roasters with Outstanding Roasts"),
    (axes[1], top0, "Top 5 Roasters with Mid Roasts"),
]
for ax, top, title in panels:
    top.plot(kind="bar", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("")  # clear the index-name label that pandas sets
    ax.legend(title="Class", loc="upper right")
    ax.tick_params(axis="x", rotation=45)
axes[0].set_ylabel("Number of samples")

plt.tight_layout()
plt.show()