In [1]:
from highdim.chap16.redunet import SimpleReduNet
from highdim.chap16.coding_rate import coding_rate
import time
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
%matplotlib inline

In [6]:
# Model and Input Setup
#
# Defines the hyper-parameters handed to the model and draws the synthetic
# train / test data: N_CLASSES Gaussian blobs in the first SPACE_DIM-1
# coordinates, a constant last coordinate, and every sample rescaled to unit
# L2 norm.

# Fixed seed so that "Restart Kernel -> Run All" regenerates the same
# MU / STD and the same train / test samples.
SEED = 0

SPACE_DIM = 3
N_SAMPLES_PER_CLASS = 200
N_SAMPLES_PER_CLASS_TEST = 20
N_CLASSES = 3
N_SAMPLES = N_SAMPLES_PER_CLASS * N_CLASSES
N_SAMPLES_TEST = N_SAMPLES_PER_CLASS_TEST * N_CLASSES
L = 500
epsilon = torch.tensor([0.1])
eta = torch.tensor([0.5])
temperature = torch.tensor([0.01])  # author's note: if this is large, P-hat becomes uniform and the modes collapse


# Seed right before the first random draw; only torch's RNG is used in this cell.
torch.manual_seed(SEED)
MU = torch.rand([N_CLASSES, SPACE_DIM-1]) * 20
STD = torch.rand([N_CLASSES, SPACE_DIM-1])
LAST_DIM_VALUE = [torch.tensor([[2]])]*N_CLASSES


def make_dataset(n_per_class, mu, std, last_dim_value):
    """Draw one labelled dataset, laid out class by class.

    Args:
        n_per_class: number of samples per class (must be >= 1).
        mu: (n_classes, space_dim-1) tensor of per-class, per-coordinate means.
        std: (n_classes, space_dim-1) tensor of per-class, per-coordinate
            standard deviations (non-negative).
        last_dim_value: sequence of n_classes tensors of shape (1, 1); the
            constant appended as the last coordinate of every sample of
            that class.

    Returns:
        A tuple ``(membership, labels, raw, unit)``:
        membership -- (n_classes, n_samples) one-hot class membership matrix.
        labels     -- (n_samples,) class index of each sample.
        raw        -- (space_dim, n_samples) un-normalised samples.
        unit       -- ``raw`` with every column divided by its L2 norm.

    Raises:
        ValueError: if ``n_per_class`` is smaller than 1.
    """
    if n_per_class < 1:
        raise ValueError('n_per_class must be >= 1')

    n_classes = mu.shape[0]
    n_samples = n_per_class * n_classes

    membership = torch.zeros([n_classes, n_samples])  # n_classes x n_samples
    for c in range(n_classes):
        membership[c, n_per_class*c:n_per_class*(c+1)] = 1
    assert (membership.sum(0) == 1).all(), 'prob normalization failed.'

    labels = torch.argmax(membership, dim=0)

    # One vectorised draw instead of a Python loop over samples:
    # element (d, s) ~ Normal(mu[labels[s], d], std[labels[s], d]).
    gaussian_part = torch.normal(mu[labels].t(), std[labels].t())  # (space_dim-1) x n_samples
    # Constant last coordinate; cast explicitly because the stored constants are integer tensors.
    const_part = torch.cat([last_dim_value[c] for c in labels.tolist()], dim=1).to(gaussian_part.dtype)  # 1 x n_samples

    raw = torch.cat([gaussian_part, const_part], dim=0)  # space_dim x n_samples
    unit = raw / torch.norm(raw, p='fro', dim=0, keepdim=True)
    return membership, labels, raw, unit


# Training data
P, class_ids, Z_2d, Z = make_dataset(N_SAMPLES_PER_CLASS, MU, STD, LAST_DIM_VALUE)


a0 = torch.ones(1) * 1000  # author's note: with a plain float 1 the classes collapse
a1 = torch.ones([N_CLASSES])  # num_classes
gamma = torch.ones([N_CLASSES])


# Test data: same class means / spreads, fewer samples per class
P_test, class_ids_test, Z_2d_test, Z_test = make_dataset(N_SAMPLES_PER_CLASS_TEST, MU, STD, LAST_DIM_VALUE)


In [7]:
# visualize input
# Only drawn for 3-dimensional data: first the raw samples projected on their
# first two coordinates, then the unit-normalised samples Z in 3D.
if SPACE_DIM == 3:
    # Both views colour each point by its class id on the Viridis scale.
    shared_style = dict(color=class_ids, colorscale='Viridis', opacity=0.8)

    planar_trace = go.Scatter(
        x=Z_2d[0],
        y=Z_2d[1],
        mode='markers',
        marker=dict(size=2, **shared_style),
    )
    fig = go.Figure(data=[planar_trace])
    fig.show()

    sphere_trace = go.Scatter3d(
        x=Z[0],
        y=Z[1],
        z=Z[2],
        mode='markers',
        marker=dict(size=1, **shared_style),
    )
    fig = go.Figure(data=[sphere_trace])
    fig.show()


# ReduNet

In [8]:
# Build the network from the hyper-parameters and the training data (Z, P).
# Arguments are passed positionally in the order the constructor is called
# with here; their exact meaning lives in SimpleReduNet -- TODO confirm there.
model = SimpleReduNet(
    L, epsilon, a0, a1, gamma, eta, temperature, Z, P,
    use_fast=True,
)

# Sanity check: shapes of the first entries of model.Elist and model.Cslist.
for tag, entries in (('E0: ', model.Elist), ('Cs0: ', model.Cslist)):
    print(tag, entries[0].shape)

E0:  torch.Size([3, 3])
Cs0:  torch.Size([3, 3, 3])


# TRAIN

In [9]:
# Run the training samples through the model and convert the result to a
# NumPy array for plotting.
# NOTE(review): the 2nd/3rd arguments (None, 500) are presumably the class
# memberships and the number of layers to apply (500 == L) -- confirm against
# SimpleReduNet.predict. Alternative call noted by the author: model.predict(Z, P)
fin_Z = model.predict(Z, None, 500).detach().cpu().numpy()

# 3D scatter of the transformed training samples, coloured by class id.
train_points = go.Scatter3d(
    x=fin_Z[0],
    y=fin_Z[1],
    z=fin_Z[2],
    mode='markers',
    marker={'size': 5, 'color': class_ids, 'colorscale': 'Viridis', 'opacity': 0.8},
)
fig = go.Figure(data=[train_points])
fig.show()

# Coding rate of the inputs (before) versus the transformed features (after).
cr_before, cr_after = (
    coding_rate(features, P, a0, a1, gamma, use_fast=True)
    for features in (Z, fin_Z)
)
print(f'coding rate changed as {cr_before} -> {cr_after}')

coding rate changed as 6.639529228210449 -> 10.346418380737305


# TEST

In [12]:
# Run the held-out samples through the model and convert the result to a
# NumPy array for plotting.
# NOTE(review): both optional arguments are None here; presumably predict then
# falls back to its defaults -- confirm against SimpleReduNet.predict.
fin_Z = model.predict(Z_test, None, None).detach().cpu().numpy()

# 3D scatter of the transformed test samples, coloured by class id.
test_points = go.Scatter3d(
    x=fin_Z[0],
    y=fin_Z[1],
    z=fin_Z[2],
    mode='markers',
    marker={'size': 5, 'color': class_ids_test, 'colorscale': 'Viridis', 'opacity': 0.8},
)
fig = go.Figure(data=[test_points])
fig.show()

# Coding rate of the test inputs (before) versus their transformed features (after).
cr_before, cr_after = (
    coding_rate(features, P_test, a0, a1, gamma, use_fast=True)
    for features in (Z_test, fin_Z)
)
print(f'coding rate changed as {cr_before} -> {cr_after}')

coding rate changed as 6.9936723709106445 -> 10.28071403503418
