# t-SNE initial state

Code to reproduce Figure 3.

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from plearning.phone_pair import AMERICAN_ENGLISH_CONSONANTS, AMERICAN_ENGLISH_SONORITY, TIPA

# Apply the Matplotlib style sheet stored next to this notebook (path is relative
# to the current working directory).
plt.style.use("./paper.mplstyle")

# Base directory of the t-SNE outputs; one sub-directory per initial state is read from it.
root = Path("./tsne")

In [None]:
def get_sonority(phone: str) -> str | None:
    category = [k for k, v in AMERICAN_ENGLISH_SONORITY.items() if phone in v]
    assert len(category) in (0, 1)
    return category[0] if category else None


def load_phone(root: Path, pair: str | None = None, rotate: bool = False) -> pd.DataFrame:
    phone_infos = pd.read_csv(root / "phone_infos.csv", index_col="idx")
    phone_infos["sonority"] = phone_infos["phone"].apply(get_sonority)
    if pair is None:
        tsne = np.load(root / "tsne.npy")
        if rotate:
            tsne = tsne @ np.array([[-1, 0], [0, -1]])  # Rotation for better figure
    else:
        tsne = np.load(root / f"tsne_{pair.lower()}.npy")
        phone_infos = phone_infos[phone_infos.phone.isin([pair[0].upper(), pair[1].upper()])]
    phone_infos["x"] = tsne[:, 0]
    phone_infos["y"] = tsne[:, 1]
    phone_infos.dropna(inplace=True)  # Remove SIL phone
    return phone_infos

In [None]:
# Build the list of frames consumed by the figure. For each initial state the
# order is fixed: full embedding (1000 rows per sonority class), then the
# "RL" and "WY" pair embeddings (1000 rows per phone).
data = []
for init in ("untrained", "noise"):
    init_root = root / init
    # Only the "noise" embedding is loaded with rotation enabled.
    by_sonority = load_phone(init_root, rotate=(init == "noise")).groupby("sonority")
    data.append(by_sonority.sample(1000, random_state=0))
    for phone_pair in ("RL", "WY"):
        by_phone = load_phone(init_root, phone_pair).groupby("phone")
        data.append(by_phone.sample(1000, random_state=0))

In [None]:
# Figure layout: one column per initial state. Each column stacks the sonority
# embedding (row a) above the two phone-pair embeddings (row b).
fig = plt.figure(figsize=(6, 5), layout="constrained")
subfigs = fig.subfigures(nrows=1, ncols=2)

# RGB triplets scaled to [0, 1]; used below for the phones R, L, W, Y in that order.
colors = np.array([(0, 158, 115), (213, 94, 0), (240, 228, 66), (204, 121, 167)]) / 255

titles = ["No pretraining", "Pretrained on ambient sounds"]
title_colors = ["#0072B2", "#E69F00"]
pair_palettes = [
    {"R": colors[0], "L": colors[1]},
    {"W": colors[2], "Y": colors[3]},
]
# Keyword arguments shared by every scatter plot in the figure.
scatter_style = dict(x="x", y="y", s=3, alpha=0.5, rasterized=True)

for k, subfig in enumerate(subfigs):
    first = 3 * k  # `data` holds three frames per column: sonority, R/L, W/Y
    show_legend = k == 1  # legends are drawn once, from the second column
    legend_mode = "auto" if show_legend else False

    subsub = subfig.subfigures(nrows=2, ncols=1, height_ratios=[1, 1], hspace=0.5)
    top_row, bottom_row = subsub

    # Row a: full embedding coloured by sonority class.
    sonority_ax = top_row.subplots()
    sonority_ax.axis("off")
    sonority_ax.set_title(titles[k], color=title_colors[k])
    sns.scatterplot(
        data[first],
        hue="sonority",
        hue_order=AMERICAN_ENGLISH_SONORITY.keys(),
        palette="magma",
        ax=sonority_ax,
        legend=legend_mode,
        **scatter_style,
    )
    if show_legend:
        # Re-create the legend at subfigure level, then drop the axes-level one.
        handles, labels = sonority_ax.get_legend_handles_labels()
        top_row.legend(handles, labels, title="Sonority", ncols=3, loc="upper center", bbox_to_anchor=(0, 0))
        sonority_ax.get_legend().remove()

    # Row b: one panel per phone pair, side by side with no horizontal gap.
    bottom_row.suptitle(titles[k], color=title_colors[k])
    pair_axes = bottom_row.subplots(ncols=2, gridspec_kw=dict(wspace=0))
    for offset, (pair_ax, palette) in enumerate(zip(pair_axes, pair_palettes), start=1):
        pair_ax.axis("off")
        sns.scatterplot(
            data[first + offset],
            hue="phone",
            palette=palette,
            ax=pair_ax,
            legend=legend_mode,
            **scatter_style,
        )
    if show_legend:
        left_ax, right_ax = pair_axes
        handles0, labels0 = left_ax.get_legend_handles_labels()
        handles1, labels1 = right_ax.get_legend_handles_labels()
        # Entries of the first pair are listed in reverse order; labels are
        # mapped through AMERICAN_ENGLISH_CONSONANTS and then TIPA for display.
        merged_labels = labels0[::-1] + labels1
        new_labels = [TIPA[AMERICAN_ENGLISH_CONSONANTS[label]] for label in merged_labels]
        bottom_row.legend(
            handles0[::-1] + handles1, new_labels, title="Phone", ncols=2, loc="upper center", bbox_to_anchor=(0, 0)
        )
        left_ax.get_legend().remove()
        right_ax.get_legend().remove()
    if k == 0:
        # Panel letters, added once from the first column; both use the top row's transFigure.
        label_transform = top_row.transFigure
        top_row.text(0, 0.96, r"\textbf{a)}", fontsize="x-large", transform=label_transform)
        bottom_row.text(0, 0.335, r"\textbf{b)}", fontsize="x-large", transform=label_transform)

plt.show()