In [None]:
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input vector to ``2 * latent_dim`` values that are
    split into the mean and the log-variance of a diagonal Gaussian. The
    decoder maps a latent vector back to ``input_dim`` values squashed into
    ``(0, 1)`` by a sigmoid.

    The default sizes reproduce the original fixed architecture and the
    layer layout is unchanged, so the ``state_dict`` keys stay the same and
    previously saved checkpoints still load into ``VAE()``.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent space.

    Raises:
        ValueError: If any of the sizes is not a positive integer.
    """

    def __init__(self, input_dim: int = 680, hidden_dim: int = 512, latent_dim: int = 2):
        super().__init__()

        if min(input_dim, hidden_dim, latent_dim) < 1:
            raise ValueError("input_dim, hidden_dim and latent_dim must all be >= 1")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            # Twice the latent size: first half is mu, second half is log_var.
            nn.Linear(hidden_dim, 2 * latent_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            A tuple ``(reconstruction, mu, log_var)``.
        """
        encoded = self.encoder(x)
        # Split on the feature axis so inputs with extra leading dimensions
        # work too; for the usual (batch, features) input this is axis 1.
        mu, log_var = encoded.chunk(2, dim=-1)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var


def loss_function(recon_x, x, mu, log_var):
    """Return the VAE objective: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: Reconstruction with values in ``[0, 1]``.
        x: Target tensor of the same shape as ``recon_x``.
        mu: Mean of the approximate posterior.
        log_var: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor; both terms are summed (not averaged) over all elements.
    """
    reconstruction_term = nn.functional.binary_cross_entropy(
        recon_x, x, reduction="sum"
    )
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), accumulated over every element.
    kl_per_element = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * kl_per_element.sum()
    return reconstruction_term + kl_term


def train(model, data_loader, optimizer):
    """Run one training epoch and report the mean per-sample loss.

    Args:
        model: Module whose forward pass returns ``(reconstruction, mu, log_var)``.
        data_loader: Loader yielding 1-tuples ``(batch,)``. The length of its
            ``dataset`` is used as the divisor for the average.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        The summed loss over the epoch divided by ``len(data_loader.dataset)``.
        The same value is printed.
    """
    model.train()

    # Send batches to wherever the model lives instead of assuming a GPU, so
    # the same loop works on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    train_loss = 0.0
    for (batch,) in data_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon_batch, mu, log_var = model(batch)
        loss = loss_function(recon_batch, batch, mu, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # loss_function sums over the batch, so divide by the sample count.
    average_loss = train_loss / len(data_loader.dataset)
    print(f"Average loss: {average_loss:.4f}")
    return average_loss

In [None]:
def _load_pickle(path: str):
    """Load and return the single pickled object stored at ``path``."""
    with open(path, "rb") as f:
        return pickle.load(f)


# SECURITY: unpickling can execute arbitrary code, so these files must come
# from a trusted source (e.g. this project's own preprocessing step).
data = _load_pickle("tmp/data.pkl")
data_id_dict = _load_pickle("tmp/data_id_dict.pkl")
child_id_dict = _load_pickle("tmp/child_id_dict.pkl")
word_dict = _load_pickle("tmp/word_dict.pkl")
category_dict = _load_pickle("tmp/category_dict.pkl")

tensor_data = torch.tensor(data.astype(np.float32))
dataset = TensorDataset(tensor_data)
data_loader = DataLoader(dataset, batch_size=64)

# Prefer the GPU, but fall back to the CPU so this cell also runs on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VAE().to(device)
# map_location lets a checkpoint saved from a CUDA model load on a CPU-only
# machine.
model.load_state_dict(torch.load("tmp/best_model.pth", map_location=device))
# Alternative checkpoint:
# model.load_state_dict(torch.load("tmp/model_state_dict.pth", map_location=device))

word_count = len(word_dict)

In [None]:
# Quick look at the loaded array: its contents first, then its shape.
print(data, data.shape, sep="\n")

In [None]:
# Convert word data into latent variables
def x_to_z(model: nn.Module, xs: np.ndarray) -> np.ndarray:
    """Encode ``xs`` and return the latent means as a NumPy array.

    Only the mean half of the encoder output is used; no sampling takes
    place, so the result is deterministic.

    Args:
        model: Module exposing an ``encoder`` whose output holds the mean in
            the first half of the last axis and the log-variance in the
            second half. It is switched to eval mode.
        xs: Input array whose last axis is the feature axis. It is cast to
            float32 before encoding.

    Returns:
        The latent means, with the same leading shape as ``xs``.
    """
    model.eval()
    # Run on whatever device the model lives on rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        inputs = torch.tensor(xs.astype(np.float32)).to(device)
        # Split on the feature axis; for 2-D input this is the same as dim=1.
        mu, _ = model.encoder(inputs).chunk(2, dim=-1)
    return mu.cpu().numpy()


# Convert latent variables back into word data
def z_to_x(model: nn.Module, z: np.ndarray) -> np.ndarray:
    """Decode latent vectors ``z`` and return the decoder output as NumPy.

    Args:
        model: Module exposing a ``decoder``. It is switched to eval mode.
        z: Latent array whose last axis is the latent axis; any leading
            shape the decoder accepts is allowed (e.g. a grid of points).
            It is cast to float32 before decoding.

    Returns:
        The decoder output as a NumPy array.
    """
    model.eval()
    # Run on whatever device the model lives on rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    with torch.no_grad():
        latent = torch.tensor(z.astype(np.float32)).to(device)
        decoded = model.decoder(latent)
    return decoded.cpu().numpy()


def category_to_num(categories: list[str]) -> list[int]:
    """Translate category names into the word indices they cover.

    Args:
        categories: Category names to look up in the notebook-level
            ``category_dict``. The exact list ``["all"]`` selects every index
            in ``range(word_count)``.

    Returns:
        Word indices in the order the categories were given. Each index is
        the first element of an entry of ``category_dict[name]``.
    """
    if categories == ["all"]:
        return list(range(word_count))
    return [entry[0] for name in categories for entry in category_dict[name]]


def get_vocabulary(xs: np.ndarray, categories: list[str] = ["all"]) -> np.ndarray:
    """Sum the word values of ``xs`` that belong to ``categories``.

    Args:
        xs: Array whose last axis indexes words.
        categories: Category names passed to ``category_to_num``. The default
            list is only read, never mutated.

    Returns:
        ``xs`` reduced over its last axis, restricted to the selected words.
    """
    word_indices = category_to_num(categories)
    selected = xs[..., word_indices]
    return selected.sum(axis=-1)

In [None]:
def plot_origin(model: VAE, ax: Axes) -> None:
    """Mark where the all-zeros and all-ones inputs land in latent space.

    Args:
        model: Model used to encode the two reference inputs; its encoder
            must start with a layer exposing ``in_features``.
        ax: Axes that receives the two markers and the legend.
    """
    # Read the input width from the encoder's first layer instead of
    # repeating the 680 literal, so the reference inputs always match the model.
    n_features = model.encoder[0].in_features
    z0 = x_to_z(model, np.zeros((1, n_features)))
    z1 = x_to_z(model, np.ones((1, n_features)))
    ax.scatter(z0[:, 0], z0[:, 1], color="blue", label="all 0s", marker="x")
    ax.scatter(z1[:, 0], z1[:, 1], color="red", label="all 1s", marker="*")
    ax.legend()


def set_labels(ax: Axes, title: str = "") -> None:
    """Label the two latent axes as z_1 / z_2 and set the axes title."""
    ax.set(xlabel=r"$z_{1}$", ylabel=r"$z_{2}$", title=title)


def plot_x(model: VAE, xs: np.ndarray, ax: Axes, color: str = "tab:blue") -> None:
    """Scatter the latent means of ``xs`` on ``ax`` and add the reference markers.

    Args:
        model: Model used to encode ``xs``.
        xs: Samples to encode.
        ax: Axes to draw on.
        color: Colour of the sample points.
    """
    latent = x_to_z(model, xs)
    first_coord, second_coord = latent[:, 0], latent[:, 1]
    ax.scatter(first_coord, second_coord, s=0.2, color=color)

    # Overlay the all-zeros / all-ones markers.
    plot_origin(model, ax)


def plot_x_with_age(model: VAE, data_ids: list[int], fig: Figure, ax: Axes) -> None:
    """Scatter the selected samples in latent space, coloured by age.

    Args:
        model: Model used to encode the samples.
        data_ids: Row indices into the notebook-level ``data`` array, which
            are also keys of ``data_id_dict``.
        fig: Figure that receives the colour bar.
        ax: Axes to draw on.
    """
    samples = data[data_ids]
    # The age is read from position 1 of each data_id_dict entry.
    ages = np.array([data_id_dict[data_id][1] for data_id in data_ids])
    latent = x_to_z(model, samples)

    points = ax.scatter(latent[:, 0], latent[:, 1], c=ages, cmap="turbo", s=0.2)
    fig.colorbar(points, ax=ax).set_label("age")

    plot_origin(model, ax)
    set_labels(ax, "age")


def plot_x_with_vocabulary(
    model: VAE,
    data_ids: list[int],
    fig: Figure,
    ax: Axes,
    categories: list[str] = ["all"],
) -> None:
    """Scatter the selected samples in latent space, coloured by vocabulary.

    Args:
        model: Model used to encode the samples.
        data_ids: Row indices into the notebook-level ``data`` array.
        fig: Figure that receives the colour bar.
        ax: Axes to draw on.
        categories: Categories whose words are summed for the colour value;
            also used, comma-joined, as the axes title.
    """
    samples = data[data_ids]
    vocabulary_totals = get_vocabulary(samples, categories)
    latent = x_to_z(model, samples)

    points = ax.scatter(
        latent[:, 0],
        latent[:, 1],
        c=vocabulary_totals,
        cmap="turbo",
        s=0.2,
    )
    fig.colorbar(points, ax=ax).set_label("vocabulary")

    plot_origin(model, ax)
    set_labels(ax, ", ".join(categories))

In [None]:
# List every category together with the number of entries it holds.
for category_name, members in category_dict.items():
    print(category_name, len(members))

In [None]:
# One (figure, axes) pair per plot, kept in a dict so later cells can reuse them.
figs = {
    "age": plt.subplots(),
    "vocabulary": plt.subplots(),
}

# Plot every sample twice: coloured by age, then by its "locations" vocabulary.
data_ids = list(data_id_dict)
age_fig, age_ax = figs["age"]
plot_x_with_age(model, data_ids, age_fig, age_ax)
vocabulary_fig, vocabulary_ax = figs["vocabulary"]
plot_x_with_vocabulary(model, data_ids, vocabulary_fig, vocabulary_ax, ["locations"])

In [None]:
# Lattice (grid) points in the latent space
import itertools


def _lattice_axis(start: float, end: float, spacing: float) -> np.ndarray:
    """Return ``start, start + spacing, ...`` up to the first value reaching ``end``."""
    # Compute the number of steps explicitly. np.arange with a fractional step
    # can gain or lose its last element depending on floating-point rounding;
    # the small tolerance absorbs that rounding when ``end`` lies on the grid.
    n_steps = int(np.ceil((end - start) / spacing - 1e-9))
    return start + spacing * np.arange(max(n_steps + 1, 0))


def make_lattice_points(
    z1_start: float,
    z1_end: float,
    z2_start: float,
    z2_end: float,
    spacing: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Build a regular 2-D grid covering ``[z1_start, z1_end] x [z2_start, z2_end]``.

    Args:
        z1_start: First value on the z1 axis.
        z1_end: Value the z1 axis must reach (included when it lies on the grid).
        z2_start: First value on the z2 axis.
        z2_end: Value the z2 axis must reach (included when it lies on the grid).
        spacing: Distance between neighbouring grid points on both axes.

    Returns:
        The pair ``(z1_grid, z2_grid)`` produced by ``np.meshgrid``; both
        arrays have shape ``(len(z2_axis), len(z1_axis))``.

    Raises:
        ValueError: If ``spacing`` is not positive.
    """
    if spacing <= 0:
        raise ValueError(f"spacing must be positive, got {spacing}")

    z1 = _lattice_axis(z1_start, z1_end, spacing)
    z2 = _lattice_axis(z2_start, z2_end, spacing)

    z1_grid, z2_grid = np.meshgrid(z1, z2)
    return z1_grid, z2_grid


def plot_vocabulary(
    model: VAE,
    z_mashgrid,
    fig: Figure,
    ax: Axes,
    categories: list[str] = ["all"],
) -> None:
    """Draw the decoded vocabulary over a latent grid as a colour mesh.

    Args:
        model: Model whose decoder is evaluated at every grid point.
        z_mashgrid: Pair ``(z1, z2)`` of coordinate grids, e.g. the result of
            ``make_lattice_points``.
        fig: Figure that receives the colour bar.
        ax: Axes to draw on.
        categories: Categories whose decoded word values are summed; also
            used, comma-joined, as the axes title.
    """
    grid_z1, grid_z2 = z_mashgrid
    # Stack the two coordinate grids so the last axis holds (z1, z2) pairs.
    latent_grid = np.dstack((grid_z1, grid_z2))
    decoded = z_to_x(model, latent_grid)
    vocabulary_totals = get_vocabulary(decoded, categories)

    mesh = ax.pcolormesh(grid_z1, grid_z2, vocabulary_totals, cmap="turbo")
    fig.colorbar(mesh, ax=ax).set_label("vocabulary")
    set_labels(ax, ", ".join(categories))


def plot_arrow(model: VAE, data_ids: list[int], ax: Axes) -> None:
    """Connect consecutive samples in latent space with black arrows.

    Args:
        model: Model used to encode the samples.
        data_ids: Row indices into the notebook-level ``data`` array, in the
            order the arrows should follow.
        ax: Axes to draw on.
    """
    latent = x_to_z(model, data[data_ids])
    # Pair every point with its successor: (p0, p1), (p1, p2), ...
    for tail, head in zip(latent[:-1], latent[1:]):
        ax.annotate(
            "",
            xy=head,
            xytext=tail,
            arrowprops={"arrowstyle": "->", "color": "black"},
        )

In [None]:
# Extent and resolution of the latent grid used for the vocabulary maps.
z1_start, z1_end = -6, 7
z2_start, z2_end = -3, 3
spacing = 0.1

# savefig does not create missing folders, so make sure the target exists
# before the loop instead of failing on the first figure.
vocabulary_image_dir = Path("images/vocabulary")
vocabulary_image_dir.mkdir(parents=True, exist_ok=True)

z_meshgrid = make_lattice_points(z1_start, z1_end, z2_start, z2_end, spacing)
for category in category_dict:
    # Keep each (figure, axes) pair in `figs` so it stays available afterwards.
    figs[category] = plt.subplots()
    category_fig, category_ax = figs[category]
    plot_vocabulary(model, z_meshgrid, category_fig, category_ax, [category])
    plot_x(model, data, category_ax, "tab:blue")
    category_fig.savefig(vocabulary_image_dir / f"{category}_vocabulary.png")

In [None]:
import random

figs["arrow"] = plt.subplots()
arrow_ax = figs["arrow"][1]
plot_x(model, data, arrow_ax)

# Collect, per child, the list of data ids (first element of each entry),
# keeping only children that have at least two entries.
data_ids = [
    [entry[0] for entry in entries]
    for entries in child_id_dict.values()
    if len(entries) >= 2
]
print(len(data_ids))

# Draw the trajectory of `n` randomly chosen children.
n = 1
datas = random.sample(data_ids, n)
# Position 0 of a data_id_dict entry is the child id; show that child's entries.
child_id = data_id_dict[datas[0][0]][0]
print(child_id_dict[child_id])
for trajectory_ids in datas:
    plot_arrow(model, trajectory_ids, arrow_ax)

In [None]:
from matplotlib.patches import Circle


def make_circle(r: float, fig: Figure, ax: Axes) -> None:
    """Draw the two circles of radius ``r`` that pass through both reference points.

    The reference points are the latent means of the all-zeros and all-ones
    inputs, encoded with the notebook-level ``model``.

    Args:
        r: Radius of the circles.
        fig: Unused; kept so existing ``make_circle(r, *fig_and_ax)`` calls work.
        ax: Axes that receives the two circle patches.

    Raises:
        ValueError: If the two reference points coincide, or if ``r`` is
            smaller than half the distance between them (no such circle
            exists; previously this produced NaN centres and drew nothing).
    """
    # Match the reference inputs to the model's input width.
    n_features = model.encoder[0].in_features
    z0 = x_to_z(model, np.zeros((1, n_features))).flatten()
    z1 = x_to_z(model, np.ones((1, n_features))).flatten()
    mid = (z0 + z1) / 2

    # Distance between the two points
    d = np.linalg.norm(z1 - z0)
    if d == 0:
        raise ValueError("the two reference points coincide; the circles are not defined")
    if r < d / 2:
        raise ValueError(
            f"r={r} is smaller than half the distance between the reference points ({d / 2})"
        )

    # Distance from the midpoint of the two points to each circle centre
    h = np.sqrt(r**2 - (d / 2) ** 2)

    # Unit vector perpendicular to the segment, used to locate the centres
    vec = z1 - z0
    vec_perp = np.array([-vec[1], vec[0]])
    unit_vec_perp = vec_perp / np.linalg.norm(vec_perp)

    # The two possible centres, one on each side of the segment
    C1 = mid + h * unit_vec_perp
    C2 = mid - h * unit_vec_perp

    # Draw the circles
    circle1 = Circle(C1, r, fill=False, color="black")
    circle2 = Circle(C2, r, fill=False, color="black")
    ax.add_patch(circle1)
    ax.add_patch(circle2)

In [None]:
# All samples in latent space, plus the two radius-2.15 circles through the
# all-zeros / all-ones reference points.
tmp_fig, tmp_ax = plt.subplots()
figs["tmp"] = (tmp_fig, tmp_ax)
plot_x(model, data, tmp_ax)
make_circle(2.15, tmp_fig, tmp_ax)

In [None]:
# Vocabulary map over all words, with every sample drawn on top.
figs["tmp"] = plt.subplots()
overview_fig, overview_ax = figs["tmp"]
plot_vocabulary(model, z_meshgrid, overview_fig, overview_ax, ["all"])
plot_x(model, data, overview_ax)