In [None]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure

import utilities_plot as up
from utilities_base import VAE, load_data

In [None]:
# Load the dataset lookup tables and the trained model weights.

data, data_id_dict, child_id_dict, word_dict, category_dict = load_data(
    ["data", "data_id_dict", "child_id_dict", "word_dict", "category_dict"]
)
word_count = len(word_dict)

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook can still be run on a CPU-only machine.
# NOTE(review): the utilities_plot helpers may themselves assume CUDA tensors —
# confirm before relying on the CPU path.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = VAE().to(device)
# map_location remaps tensors that were saved from a GPU onto `device`.
model.load_state_dict(torch.load("tmp/best_model.pth", map_location=device))
# Alternative checkpoint:
# model.load_state_dict(torch.load("tmp/model_state_dict.pth", map_location=device))

# The model is only used for inference in this notebook; eval() switches off
# training-only behaviour (dropout, batch-norm statistics updates) if the VAE
# contains such layers.
model.eval()

In [None]:
# Show every category key together with the length of its value
for category_name, members in category_dict.items():
    print(category_name, len(members))

In [None]:
# Plot the real age data and the real vocabulary data on the latent space.

# figs maps a figure name to the (Figure, Axes) pair returned by plt.subplots()
figs = {name: plt.subplots() for name in ("age", "vocabulary")}

data_ids = [*data_id_dict]

age_fig, age_ax = figs["age"]
up.plot_x_with_age(model, data_ids, age_fig, age_ax)

vocabulary_fig, vocabulary_ax = figs["vocabulary"]
up.plot_x_with_vocabulary(model, data_ids, vocabulary_fig, vocabulary_ax, ["all"])
# up.plot_x_with_vocabulary(
#     model,
#     data_ids,
#     *figs["vocabulary"],
#     [put the vocabulary categories here, using the output of the cell above as a guide (["all"] means the whole vocabulary)]
# )

In [None]:
# Visualise, category by category, how developed the vocabulary is across the
# latent space.

# Extent of the lattice along each latent axis and the step between points,
# as passed to up.make_lattice_points.
z1_start, z1_end = -6, 7
z2_start, z2_end = -3, 3
spacing = 0.1

z_meshgrid = up.make_lattice_points(z1_start, z1_end, z2_start, z2_end, spacing)

# savefig does not create missing directories, so make sure the target exists.
vocabulary_image_dir = Path("images/vocabulary")
vocabulary_image_dir.mkdir(parents=True, exist_ok=True)

for category in category_dict.keys():
    figs[category] = plt.subplots()
    category_fig, category_ax = figs[category]
    up.plot_vocabulary(model, z_meshgrid, category_fig, category_ax, [category])
    up.plot_x(model, data, category_ax, "tab:blue")
    category_fig.savefig(vocabulary_image_dir / f"{category}_vocabulary.png")
    # Close this specific figure instead of "whatever pyplot considers current",
    # so nothing is left open if a helper changes the current figure.
    plt.close(category_fig)

In [None]:
# Visualise how developed the vocabulary is across the latent space for all
# categories combined.

figs["all"] = plt.subplots()
all_fig, all_ax = figs["all"]
up.plot_vocabulary(model, z_meshgrid, all_fig, all_ax, ["all"])
up.plot_x(model, data, all_ax, "tab:blue")

# savefig does not create missing directories, so make sure the target exists.
Path("images/vocabulary").mkdir(parents=True, exist_ok=True)
# The file name is written out literally: nesting double quotes inside a
# double-quoted f-string is a SyntaxError on Python versions before 3.12.
all_fig.savefig("images/vocabulary/all_vocabulary.png")
# Close this specific figure rather than relying on pyplot's current figure.
plt.close(all_fig)

In [None]:
# Visualise longitudinal data (records of the same child at different ages).

figs["arrow"] = plt.subplots()
arrow_ax = figs["arrow"][1]
up.plot_x(model, data, arrow_ax)

# Minimum number of records a child must have to be selected
n = 3
# For each qualifying child, keep the first element of every record
data_ids = [
    [record[0] for record in records]
    for records in child_id_dict.values()
    if len(records) >= n
]
# print(len(data_ids))

# Randomly pick m of the children with at least n records and plot them
m = 1
datas = random.sample(data_ids, m)
# Resolve the first picked sequence back to its child_id_dict key and print
# that child's records
child_id = data_id_dict[datas[0][0]][0]
print(child_id_dict[child_id])
for id_sequence in datas:
    up.plot_arrow(model, id_sequence, arrow_ax)

In [None]:


from matplotlib.patches import Circle


def make_circle(r: float, fig: Figure, ax: Axes, n_words: int = 680) -> None:
    """Draw the two circles of radius ``r`` passing through two latent points.

    The two points are obtained by encoding an all-zeros and an all-ones input
    vector of length ``n_words`` with ``up.x_to_z`` and the module-level
    ``model``. For two distinct points there are two circles of a given radius
    through both of them (one centre on each side of the connecting segment);
    both are added to ``ax`` as unfilled black outlines.

    Args:
        r: Radius of the circles. Must be at least half the distance between
            the two latent points, otherwise no such circle exists.
        fig: Not used; accepted so callers can pass ``*plt.subplots()``.
        ax: Axes that receives the two circle patches.
        n_words: Length of the all-zeros / all-ones input vectors.

    Raises:
        ValueError: If the two latent points coincide, or if ``r`` is smaller
            than half the distance between them.
    """
    # Assumes up.x_to_z returns an array-like that flattens to a 2-D latent
    # point — TODO confirm against utilities_plot.
    z0 = up.x_to_z(model, np.zeros((1, n_words))).flatten()
    z1 = up.x_to_z(model, np.ones((1, n_words))).flatten()

    # Vector between the two points and its length
    chord = z1 - z0
    d = np.linalg.norm(chord)
    if d == 0:
        raise ValueError(
            "the two latent points coincide, so the circle centres are undefined"
        )

    # Squared distance from the midpoint of the two points to each centre;
    # negative means the radius is too small to reach both points.
    h_squared = r**2 - (d / 2) ** 2
    if h_squared < 0:
        raise ValueError(
            f"r={r} is smaller than half the distance between the two latent "
            f"points ({d / 2:.4g}); no circle of that radius passes through both"
        )
    h = np.sqrt(h_squared)

    # Unit vector perpendicular to the chord (a 90-degree rotation has the same
    # length as the chord itself, hence the division by d)
    unit_vec_perp = np.array([-chord[1], chord[0]]) / d

    # The centres lie at distance h from the midpoint, one on each side
    mid = (z0 + z1) / 2
    for center in (mid + h * unit_vec_perp, mid - h * unit_vec_perp):
        ax.add_patch(Circle(center, r, fill=False, color="black"))

In [None]:
# Draw the data with up.plot_x and overlay the radius-2.15 circles from make_circle
circle_fig, circle_ax = plt.subplots()
figs["tmp"] = (circle_fig, circle_ax)
up.plot_x(model, data, circle_ax)
make_circle(2.15, circle_fig, circle_ax)

In [None]:
# Draw the all-vocabulary map on the lattice and the data on the same axes.
# Note: this replaces whatever was stored under figs["tmp"] before.
overlay_fig, overlay_ax = plt.subplots()
figs["tmp"] = (overlay_fig, overlay_ax)
up.plot_vocabulary(model, z_meshgrid, overlay_fig, overlay_ax, ["all"])
up.plot_x(model, data, overlay_ax)