In [None]:
import glob
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from gensim.models import FastText

from note_reader import get_notes_as_pitch_classes
from wv_utils import reduce_dimensions, create_angle_matrix

In [None]:
# Notebook-wide settings.
# Use a "Times" font family for all figure text.
# NOTE(review): confirm a font with this name is installed on the machine that
# renders the figures; otherwise matplotlib substitutes a fallback font.
plt.rcParams["font.family"] = ["Times"]
# When True, the cells below also write their figures/tables under results/.
SAVE_FIGS = True
# Placeholder token that replaces pitch classes outside the THRESHOLD most common.
UNK_TAG = "<UNK>"

In [None]:
# Collect the input note files. glob.glob() returns matches in arbitrary,
# filesystem-dependent order, so sort them: the document order then stays the same
# across machines and runs, which matters for the seeded training further down.
file_list = sorted(glob.glob("data/txt/*.txt"))

In [None]:
# Parse every note file. `docs` is iterated below as one token sequence per
# document, and `counts` is queried with .most_common().
# NOTE(review): `counts` is presumably a collections.Counter of pitch-class
# occurrences -- confirm against note_reader.
docs, counts = get_notes_as_pitch_classes(file_list)

In [None]:
# Take the (pitch, count) pairs in most_common() order and split them into two
# parallel tuples for the bar chart below.
# NOTE: if `counts` is empty, zip(*...) yields nothing and the unpacking raises
# ValueError.
pitches_sorted = counts.most_common()
pitches, pitch_counts = zip(*pitches_sorted)

In [None]:
# Number of most frequent pitch classes to keep. Every other pitch class is replaced
# by UNK_TAG before training, and the bar chart marks this cutoff with a dashed line.
THRESHOLD = 20

In [None]:
# Bar chart of the pitch-class frequencies (in most_common() order) with the
# THRESHOLD cutoff marked by a dashed red line.
fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(pitches, pitch_counts)
ax.set_title("Pitch-class counts", fontdict={"fontsize": 14})
ax.set_xlabel("Pitch classes")
ax.set_ylabel("Counts")
# tick_params() sets the rotation on the axis itself, so it applies to every tick
# label no matter when the ticks are created.
ax.tick_params(axis="x", labelrotation=-90)
# Thousands separators on the count axis.
ax.yaxis.set_major_formatter(
    plt.matplotlib.ticker.StrMethodFormatter("{x:,.0f}"))
# Categorical bars sit at x = 0, 1, 2, ..., so THRESHOLD - 0.5 falls between the
# last kept pitch class and the first discarded one.
ax.axvline(x=THRESHOLD - 0.5, ymax=0.95, color="r",
           linestyle="--", linewidth=0.5)
# NOTE(review): y=100000 is in data units (counts), i.e. tuned to this corpus;
# revisit if the dataset size changes.
ax.text(THRESHOLD - 1.2, 100000, "cutoff", rotation=90, color="r")
# One layout pass once every artist is in place.
fig.tight_layout()
if SAVE_FIGS:
    # Create the output directory on a fresh checkout instead of failing in savefig.
    os.makedirs("results", exist_ok=True)
    # Write the file before show(), so the export does not depend on how the
    # active backend treats the figure once it has been displayed.
    fig.savefig("results/overall_cutoff.pdf")
plt.show()

In [None]:
# Keep only the THRESHOLD most frequent pitch classes; every other token becomes
# UNK_TAG so the embedding vocabulary stays small.
common_pitches = [pitch for pitch, _count in counts.most_common(THRESHOLD)]
# Set lookup instead of scanning the list once per note. The pitch classes are keys
# of `counts`, hence hashable.
common_pitch_set = frozenset(common_pitches)
common_docs = [
    [note if note in common_pitch_set else UNK_TAG for note in doc]
    for doc in docs
]

# FastText model


In [None]:
# Train FastText embeddings on the thresholded documents: 10-dimensional vectors,
# a context window of 5 tokens, 500 epochs.
# NOTE(review): per the gensim documentation, `seed` alone does not make training
# fully deterministic; that additionally requires workers=1 and a fixed
# PYTHONHASHSEED. Confirm whether exact reproducibility is needed here.
ft_model = FastText(sentences=common_docs, vector_size=10,
                    epochs=500, window=5, seed=42)

In [None]:
# Keyed vectors of the trained model; all similarity and angle queries below use them.
wv = ft_model.wv
# Analogy-style sanity check: "C" and "A" contribute positively, "G" negatively.
# The trailing "# D" presumably records the expected top match -- confirm against
# the cell output.
wv.most_similar(positive=["C", "A"], negative=["G"])  # D

In [None]:
# Pairwise table built from the embeddings. The heatmap cell below labels these
# values "Angle (in degrees)".
angles_dict = create_angle_matrix(wv)
df = pd.DataFrame.from_dict(angles_dict)
if SAVE_FIGS:
    # Create the output directory on a fresh checkout instead of failing in to_excel.
    os.makedirs("results", exist_ok=True)
    df.to_excel("results/angles.xlsx")

In [None]:
# Scatter plot of the reduced embedding coordinates, one annotated point per label.
x_vals, y_vals, labels = reduce_dimensions(wv)
# Explicit figure/axes instead of the implicit pyplot state; separate names keep the
# `fig`/`ax` of the other cells untouched.
scatter_fig, scatter_ax = plt.subplots(figsize=(6, 6))
scatter_ax.scatter(x_vals, y_vals)

# Walk the three parallel sequences together rather than indexing by position.
for point_label, point_x, point_y in zip(labels, x_vals, y_vals):
    scatter_ax.annotate(point_label, (point_x, point_y))

plt.show()

In [None]:
# Annotated heatmap of the pairwise angle table `df`.
fig, ax = plt.subplots(figsize=(11, 8))
# Pass the axes explicitly so the heatmap cannot land on another "current" axes.
hm = sns.heatmap(df, cmap="coolwarm", annot=True, fmt=",.0f",
                 linewidths=0.5, cbar_kws={"shrink": 0.8}, ax=ax)
hm.set(title="Pitch class correlation for the entire dataset")
# The heatmap mesh is the first collection on the axes; its colorbar gets the unit label.
ax.collections[0].colorbar.set_label("Angle (in degrees)")
# Column labels along the top edge.
ax.xaxis.tick_top()
if SAVE_FIGS:
    # Create the output directory on a fresh checkout instead of failing in savefig.
    os.makedirs("results", exist_ok=True)
    # Write the file before show(), so the export does not depend on how the
    # active backend treats the figure once it has been displayed.
    fig.savefig("results/overall_corr.pdf")
plt.show()