In [None]:
import pickle
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.patches import Circle
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

import utilities_plot as up
from utilities_model import VAE

In [None]:
# NOTE: pickle.load can execute arbitrary code while unpickling. These files are
# expected to be this project's own preprocessing outputs; never point the
# loader at untrusted data.
def _load_pickle(path: str):
    """Unpickle and return the single object stored at ``path``."""
    with open(path, "rb") as f:
        return pickle.load(f)


data = _load_pickle("tmp/data.pkl")
data_id_dict = _load_pickle("tmp/data_id_dict.pkl")
child_id_dict = _load_pickle("tmp/child_id_dict.pkl")
word_dict = _load_pickle("tmp/word_dict.pkl")
category_dict = _load_pickle("tmp/category_dict.pkl")

tensor_data = torch.tensor(data.astype(np.float32))
dataset = TensorDataset(tensor_data)
data_loader = DataLoader(dataset, batch_size=64)

# Fall back to the CPU when no GPU is visible instead of crashing in `.to("cuda")`.
# NOTE(review): utilities_plot helpers may move tensors to a fixed device;
# confirm they follow the model's device before relying on the CPU path.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Alternative checkpoint: "tmp/model_state_dict.pth".
CHECKPOINT_PATH = "tmp/best_model.pth"

model = VAE().to(DEVICE)
# map_location lets a checkpoint saved on one device load on whichever device
# is available now.
# NOTE(review): the model is not switched to eval mode here. Call model.eval()
# if VAE uses dropout/batch-norm and utilities_plot does not do it itself.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))

word_count = len(word_dict)

In [None]:
# Show the loaded data matrix followed by its shape, one per line.
print(data, data.shape, sep="\n")

In [None]:
# Print each vocabulary category name alongside the number of entries it holds.
for category_name in category_dict:
    print(category_name, len(category_dict[category_name]))

In [None]:
# One (fig, ax) pair per overview plot. Both are created in the same order as
# before: "age" first, then "vocabulary".
figs = {plot_name: plt.subplots() for plot_name in ("age", "vocabulary")}

# Every data id known to data_id_dict, in its iteration order.
data_ids = list(data_id_dict)
up.plot_x_with_age(model, data_ids, figs["age"][0], figs["age"][1])
# The vocabulary overlay is restricted to the "locations" category.
up.plot_x_with_vocabulary(
    model, data_ids, figs["vocabulary"][0], figs["vocabulary"][1], ["locations"]
)

In [None]:
# Bounds (z1, z2) and step of the latent lattice on which each category's
# vocabulary is evaluated by up.plot_vocabulary.
z1_start, z1_end = -6, 7
z2_start, z2_end = -3, 3
spacing = 0.1

VOCAB_IMAGE_DIR = Path("images/vocabulary")
# savefig does not create missing directories. Create the target up front so
# the cell also runs on a fresh checkout.
VOCAB_IMAGE_DIR.mkdir(parents=True, exist_ok=True)

z_meshgrid = up.make_lattice_points(z1_start, z1_end, z2_start, z2_end, spacing)
for category in category_dict.keys():
    # One figure per category: vocabulary map with the data points overlaid.
    figs[category] = plt.subplots()
    up.plot_vocabulary(model, z_meshgrid, *figs[category], [category])
    up.plot_x(model, data, figs[category][1], "tab:blue")
    figs[category][0].savefig(VOCAB_IMAGE_DIR / f"{category}_vocabulary.png")

In [None]:
# Trajectory arrows for randomly chosen children who have at least two records.
# A seeded RNG keeps the chosen children identical on every Restart & Run All.
ARROW_SEED = 0
N_ARROW_CHILDREN = 1  # how many children to draw arrows for

figs["arrow"] = plt.subplots()
up.plot_x(model, data, figs["arrow"][1])

# For each child with two or more records, collect the data id of every record
# (the first element of each entry in child_id_dict[child_id]).
data_ids = [
    [record[0] for record in records]
    for records in child_id_dict.values()
    if len(records) >= 2
]
print(len(data_ids))

# Never request more children than exist; random.sample would raise ValueError.
n = min(N_ARROW_CHILDREN, len(data_ids))
datas = random.Random(ARROW_SEED).sample(data_ids, n)
for ids in datas:
    # Presumably data_id_dict[data_id][0] is the owning child's id; it is used
    # as a key into child_id_dict below.
    child_id = data_id_dict[ids[0]][0]
    print(child_id_dict[child_id])
    up.plot_arrow(model, ids, figs["arrow"][1])

In [None]:
from matplotlib.patches import Circle


def make_circle(r: float, fig: Figure, ax: Axes) -> None:
    all_0s = np.zeros((1, 680))
    z0 = up.x_to_z(model, all_0s).flatten()
    all_1s = np.ones((1, 680))
    z1 = up.x_to_z(model, all_1s).flatten()
    mid = (z0 + z1) / 2
    # 2点間の距離
    d = np.linalg.norm(z1 - z0)

    # 2点間の中心からの距離
    h = np.sqrt(r**2 - (d / 2) ** 2)

    # 中心点を見つけるための単位ベクトルの計算
    vec = z1 - z0
    vec_perp = np.array([-vec[1], vec[0]])
    unit_vec_perp = vec_perp / np.linalg.norm(vec_perp)

    # 中心点Cの計算
    C1 = mid + h * unit_vec_perp
    C2 = mid - h * unit_vec_perp

    # 円を描く
    circle1 = Circle(C1, r, fill=False, color="black")
    circle2 = Circle(C2, r, fill=False, color="black")
    ax.add_patch(circle1)
    ax.add_patch(circle2)

In [None]:
# Data points in latent space with the two reference circles through the
# all-zeros / all-ones encodings. The figure is stored under its own key so
# that the next cell (which uses figs["tmp"]) does not overwrite it.
CIRCLE_RADIUS = 2.15  # radius, in the same coordinates as the plotted points

figs["circle"] = plt.subplots()
up.plot_x(model, data, figs["circle"][1])
make_circle(CIRCLE_RADIUS, *figs["circle"])

In [None]:
# Vocabulary map for the "all" category over the latent lattice
# (up.plot_vocabulary), with the data points drawn on the same axes.
figs["tmp"] = plt.subplots()
up.plot_vocabulary(model, z_meshgrid, figs["tmp"][0], figs["tmp"][1], ["all"])
up.plot_x(model, data, figs["tmp"][1])