In [13]:
import matplotlib.pyplot as plt
import geoopt
import torch
import itertools
import torch.nn as nn
import numpy as np
import tqdm
%matplotlib inline

In [5]:
#@title
class SyntheticDataset(torch.utils.data.Dataset):
    '''
    Synthetic hierarchical dataset generated by branching diffusion on a manifold.

    Adapted from https://github.com/emilemathieu/pvae/blob/ca5c4997a90839fc8960ec812df4cbf83da55781/pvae/datasets/datasets.py

    A binary tree of depth ``depth`` is grown from ``ball.origin``. Every node spawns two
    children, obtained by moving along ``ball.geodesic_unit`` for time ``dist_children``
    in opposite directions. Those directions are a random vector made orthogonal to the
    parent's position. Every node (internal or leaf) then emits ``numberOfsiblings`` noisy
    observations drawn with ``ball.random`` around it; these observations are the items of
    the dataset.

    Args:
    :param ball: manifold object providing ``origin``, ``geodesic_unit`` and ``random``
        (e.g. ``geoopt.PoincareBall``); all geometry is computed with this object
    :param int dim: dimension of each sample
    :param int depth: depth of the tree; the root corresponds to depth 0
    :param int numberOfChildren: number of children of each node (only 2 is supported)
    :param float dist_children: time ``t`` passed to ``ball.geodesic_unit`` to place each child
    :param float sigma_sibling: variance of the noisy observations (``std = sigma_sibling ** 0.5``)
    :param param: unused; kept so the signature stays compatible with existing callers
    :param int numberOfsiblings: number of noisy observations drawn around each node

    Attributes set by ``__init__``:
        origin_data   -- (num_nodes, dim) noise-free node positions in BFS order
        origin_labels -- (num_nodes, depth) label vectors of the nodes (last column dropped)
        data          -- (numberOfsiblings * num_nodes, dim) noisy observations
        labels        -- (numberOfsiblings * num_nodes, depth + 1) label vector of each observation
        num_classes   -- number of distinct class ids (one id is issued per tree node)

    Label vectors hold integer ids issued in BFS order. Unused entries are -1. A child
    writes its id into column ``current_depth`` of its parent's vector. As a result, the
    root's own id (column 0) is overwritten in every non-root node's vector.
    '''
    def __init__(self, ball, dim, depth, numberOfChildren=2, dist_children=1, sigma_sibling=2, param=1, numberOfsiblings=1):
        assert numberOfChildren == 2, "only binary trees (numberOfChildren == 2) are supported"
        self.dim = int(dim)
        self.ball = ball
        self.root = ball.origin(self.dim)
        self.sigma_sibling = sigma_sibling
        self.depth = int(depth)
        self.dist_children = dist_children
        self.numberOfChildren = int(numberOfChildren)
        self.numberOfsiblings = int(numberOfsiblings)
        self.__class_counter = itertools.count()
        self.origin_data, self.origin_labels, self.data, self.labels = map(torch.detach, self.bst())
        # Use the full observation labels: origin_labels has zero columns when depth == 0,
        # and .max() on an empty tensor raises. For depth > 0 both give the same maximum id.
        self.num_classes = int(self.labels.max().item()) + 1
        # NOTE(review): optional recentering, currently disabled:
        #self.data = ball.mobius_add(self.data, -ball.weighted_midpoint(self.data)).detach()

    def __len__(self):
        '''
        Return the number of noisy observations (numberOfsiblings per tree node).
        '''
        return len(self.data)

    def __getitem__(self, idx):
        '''
        Return one sample as ``(observation, label_vector, class_id)``.

        ``class_id`` is the largest id in the label vector. Ids are issued in BFS order and
        unused entries are -1, so this is the id most recently assigned along the path.
        '''
        data, labels = self.data[idx], self.labels[idx]
        return data, labels, labels.max(-1).values

    def get_children(self, parent_value, parent_label, current_depth, offspring=True):
        '''
        Generate the children (``offspring=True``) or noisy observations (``offspring=False``) of a node.

        :param 1d-array parent_value: position of the parent node
        :param 1d-array parent_label: label vector of the parent node (length depth + 1)
        :param int current_depth: depth of the parent node
        :param Boolean offspring: if True the parent gives birth to numberOfChildren (= 2) nodes;
                                  if False it gives birth to numberOfsiblings noisy observations
        :return: list of 2-tuples ``(value, label)``, one per child / observation
        '''
        if not offspring:
            # Noisy observations keep the parent's label unchanged.
            std = self.sigma_sibling ** .5
            observations = []
            for _ in range(self.numberOfsiblings):
                value = self.ball.random(self.dim, mean=parent_value, std=std)
                observations.append((value, parent_label.clone()))
            return observations

        # Random direction made orthogonal to the parent's position. At the origin the
        # norm is 0; clamp_min keeps the division finite (parent_unit becomes zero) and
        # the direction is left unchanged.
        direction = torch.randn_like(parent_value)
        parent_unit = parent_value / parent_value.norm().clamp_min(1e-15)
        direction = direction - (parent_unit @ direction) * parent_unit
        t = torch.as_tensor(self.dist_children, dtype=parent_value.dtype, device=parent_value.device)

        # Two children (numberOfChildren is asserted to be 2) on opposite sides of the parent.
        children = []
        for sign in (1, -1):
            child_value = self.ball.geodesic_unit(t, parent_value, sign * direction)
            child_label = parent_label.clone()
            child_label[current_depth] = next(self.__class_counter)
            children.append((child_value, child_label))
        return children

    def bst(self):
        '''
        Grow the tree breadth-first (all nodes of a level before the next level) and draw
        the noisy observations of every node.

        :return: ``(images, labels_visited, values_clones, labels_clones)`` with shapes
                 (num_nodes, dim), (num_nodes, depth),
                 (numberOfsiblings * num_nodes, dim), (numberOfsiblings * num_nodes, depth + 1)
        '''
        from collections import deque

        root_label = torch.full((self.depth + 1,), -1, dtype=torch.long)
        root_label[0] = next(self.__class_counter)
        queue = deque([(self.root, root_label, 0)])
        node_values, node_labels = [], []
        obs_values, obs_labels = [], []
        while queue:
            value, label, node_depth = queue.popleft()
            node_values.append(value)
            node_labels.append(label)
            if node_depth < self.depth:
                for child_value, child_label in self.get_children(value, label, node_depth):
                    queue.append((child_value, child_label, node_depth + 1))
            # Every node, internal or leaf, emits numberOfsiblings noisy observations.
            for obs_value, obs_label in self.get_children(value, label, node_depth, False):
                obs_values.append(obs_value)
                obs_labels.append(obs_label)

        # Number of nodes in a full k-ary tree: (k^(depth+1) - 1) / (k - 1).
        # The reshapes below also check that the expected number of nodes was generated.
        num_nodes = (self.numberOfChildren ** (self.depth + 1) - 1) // (self.numberOfChildren - 1)
        num_obs = self.numberOfsiblings * num_nodes
        images = torch.cat(node_values).reshape(num_nodes, self.dim)
        # Column `depth` is never written by get_children (current_depth < depth), so drop it.
        labels_visited = torch.cat(node_labels).reshape(num_nodes, self.depth + 1)[:, :self.depth]
        values_clones = torch.cat(obs_values).reshape(num_obs, self.dim)
        labels_clones = torch.cat(obs_labels).reshape(num_obs, self.depth + 1)
        return images, labels_visited, values_clones, labels_clones

In [2]:
# Hyperbolic space: Poincaré ball model (geoopt curvature parameter c = 1.0).
ball = geoopt.PoincareBall(c=1.0)

# Number of random points to embed in the hyperbolic space.
num_points = 100
dim = 2  # two dimensions so the points are easy to visualise on the Poincaré disc

# Fix the RNG so the sampled points (and any synthetic dataset) are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [11]:
# Initialise points in the hyperbolic space as a ManifoldParameter.
# Project onto the ball *before* wrapping the tensor. Calling .proj_() on the parameter
# itself is an in-place op on a leaf tensor that requires grad, which raises
# "RuntimeError: a leaf Variable that requires grad is being used in an in-place operation."
# NOTE(review): with unit-variance Gaussian samples in 2-D, most points fall outside the
# unit disc and are clipped to its boundary by projx; consider ball.expmap0(...) if an
# interior spread is wanted.
points = geoopt.ManifoldParameter(ball.projx(torch.randn(num_points, dim)), manifold=ball)

In [None]:

def plot_poincare_disc(points, ax=None, radius=1.0, title="Points on the Poincaré disc"):
    """Scatter 2-D points inside the circle bounding the Poincaré disc.

    :param points: (N, 2) tensor or array of points; gradients are detached before plotting
    :param ax: optional matplotlib Axes to draw on; a new figure is created when None
    :param radius: radius of the boundary circle (1.0 for c = 1)
    :param title: axes title
    :return: the matplotlib Axes that was drawn on
    """
    # Parameters that require grad cannot be converted to numpy directly.
    coords = torch.as_tensor(points).detach().cpu().numpy()
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))
    ax.add_patch(plt.Circle((0, 0), radius, fill=False, color="black", linewidth=1))
    ax.scatter(coords[:, 0], coords[:, 1], s=10)
    margin = 1.05 * radius
    ax.set_xlim(-margin, margin)
    ax.set_ylim(-margin, margin)
    ax.set_aspect("equal")
    ax.set(title=title, xlabel="$x_1$", ylabel="$x_2$")
    return ax


# Visualise the hyperbolic points.
plot_poincare_disc(points)
plt.show()
