In [1]:
from __future__ import unicode_literals, print_function, division

import logging
import math
import multiprocessing
import os
import random
import re
import string
import time
import time
import unicodedata
from io import open

import networkx as nx
import numpy as np
import scipy.sparse.csgraph as csg
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from joblib import Parallel, delayed

import learning_util as lu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Distortion calculations

def acosh(x):
    """Inverse hyperbolic cosine, computed as log(x + sqrt(x^2 - 1)).

    Intended for x >= 1; every x < 1 yields NaN.
    """
    radicand = x ** 2 - 1
    return torch.log(x + torch.sqrt(radicand))

def dist_h(u, v):
    """Poincaré-ball geodesic distance between points ``u`` and ``v``.

    d(u, v) = acosh(1 + 2 * |u - v|^2 / ((1 - |u|^2) * (1 - |v|^2)))
    The formula is only a distance for points strictly inside the unit ball.
    """
    gap_sq = torch.norm(u - v, 2) ** 2
    scale = (1 - torch.norm(u, 2) ** 2) * (1 - torch.norm(v, 2) ** 2)
    return acosh(1. + torch.div(2 * gap_sq, scale))

def distance_matrix_euclidean(input):
    """Pairwise *squared* Euclidean distances between the rows of ``input``.

    Entry [i, j] is sum_k (input[j, k] - input[i, k]) ** 2; no square root is
    taken. A trailing ``squeeze`` drops size-1 dimensions, so a single-row input
    yields a 0-d tensor.
    """
    diffs = input.unsqueeze(0) - input.unsqueeze(1)
    return (diffs ** 2).sum(2).squeeze()

def distance_matrix_hyperbolic(input):
    """Pairwise Poincaré distances (via ``dist_h``) between the rows of ``input``.

    Returns an (n, n) tensor on ``device`` (default dtype) with a zero diagonal.
    Each off-diagonal entry is a separate ``dist_h`` call, so the cost is O(n^2)
    Python-level calls.
    """
    size = input.shape[0]
    dist_mat = torch.zeros(size, size, device=device)
    off_diagonal = ((r, c) for r in range(size) for c in range(size) if r != c)
    for r, c in off_diagonal:
        dist_mat[r, c] = dist_h(input[r, :], input[c, :])
    return dist_mat

def entry_is_good(h, h_rec):
    """Whether a (target ``h``, reconstructed ``h_rec``) pair should be scored.

    Rejects NaN/inf reconstructions and zero values on either side (a zero ``h``
    would make ``distortion_entry`` divide by zero). Callers use the result only
    for its truthiness.
    """
    return not (torch.isnan(h_rec) or torch.isinf(h_rec)) and h_rec != 0 and h != 0

def distortion_entry(h, h_rec):
    """Relative error |h_rec - h| / h of a reconstructed distance against target ``h``."""
    error = abs(h_rec - h)
    return error / h

def distortion_row(H1, H2, n, row):
    """Mean relative distortion of one row of a reconstructed distance matrix.

    Args:
        H1: row ``row`` of the target distance matrix (indexable by column).
        H2: the matching row of the reconstructed distance matrix.
        n: number of columns to inspect.
        row: index of this row; its diagonal column is skipped.

    Returns:
        ``(avg, good)``: the mean ``distortion_entry`` over columns accepted by
        ``entry_is_good``, and how many columns were accepted. If none qualify,
        both are 0-d zero tensors on ``device`` with ``requires_grad=True``.
    """
    terms = [distortion_entry(H1[col], H2[col])
             for col in range(n)
             if col != row and entry_is_good(H1[col], H2[col])]
    if not terms:
        return (torch.tensor(0., device=device, requires_grad=True),
                torch.tensor(0., device=device, requires_grad=True))
    good = len(terms)
    avg = sum(terms)
    avg /= good
    return (avg, good)

def distortion(H1, H2, n, jobs=16):
    """Average distortion between target ``H1`` and reconstruction ``H2``.

    Runs ``distortion_row`` on each of the first ``n`` rows and divides the sum
    of the per-row means by ``n``. Rows with no valid entries contribute 0.
    ``jobs`` is unused; it belongs to a disabled joblib-parallel variant.
    """
    row_means = [distortion_row(H1[r, :], H2[r, :], n, r)[0] for r in range(n)]
    return torch.stack(row_means).sum() / n


#Loading the graph and getting the distance matrix.

def load_graph(file_name, directed=False):
    """Read a whitespace-separated edge list into a networkx graph.

    Each data line is ``u v [weight]`` with integer node ids; an optional third
    column becomes the edge's ``weight`` attribute. Blank lines and lines whose
    first token starts with ``#`` are skipped, so a trailing empty line or a
    comment header no longer raises IndexError/ValueError.

    Args:
        file_name: path to the edge-list file.
        directed: build a ``DiGraph`` instead of an undirected ``Graph``.

    Returns:
        The populated graph. Only nodes that appear in some edge are present.
    """
    G = nx.DiGraph() if directed else nx.Graph()
    with open(file_name, "r") as f:
        for line in f:
            tokens = line.split()
            if not tokens or tokens[0].startswith("#"):
                continue
            u, v = int(tokens[0]), int(tokens[1])
            if len(tokens) > 2:
                G.add_edge(u, v, weight=float(tokens[2]))
            else:
                G.add_edge(u, v)
    return G


def compute_row(i, adj_mat):
    """Hop-count shortest-path distances from node ``i`` to every node.

    ``adj_mat`` is treated as undirected and unweighted. Because ``indices`` is a
    one-element list, the result is a (1, n) array.
    """
    sources = [i]
    return csg.dijkstra(csgraph=adj_mat, directed=False, indices=sources, unweighted=True)

def get_dist_mat(G):
    """All-pairs hop-count distance matrix of ``G``.

    Nodes are assumed to be labelled ``0 .. n-1``; row/column ``i`` is node ``i``.
    Edge weights are ignored and edges are treated as undirected, matching
    ``compute_row``.

    Returns:
        (n, n) ndarray of shortest-path lengths (``inf`` for unreachable pairs).
    """
    n = G.order()
    nodelist = list(range(n))
    # networkx 3.0 removed to_scipy_sparse_matrix; prefer the sparse-array API
    # when it exists and fall back for older networkx versions.
    to_sparse = getattr(nx, "to_scipy_sparse_array", None) or nx.to_scipy_sparse_matrix
    adj_mat = to_sparse(G, nodelist=nodelist)
    # One all-sources Dijkstra call yields exactly the rows compute_row would
    # produce one by one, without starting a joblib worker pool for every graph.
    return csg.dijkstra(adj_mat, unweighted=True, directed=False)


def asMinutes(s):
    """Format a duration in seconds as e.g. '2m 5s' (seconds are truncated)."""
    whole_minutes = math.floor(s / 60)
    remainder = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, remainder)


def timeSince(since, percent):
    """Elapsed time since ``since`` plus a linear estimate of the time remaining.

    Args:
        since: start timestamp from ``time.time()``.
        percent: fraction of work completed (callers pass ``iter / n_iters``);
            0 raises ZeroDivisionError.

    Returns:
        'elapsed (- remaining)', with both parts formatted by ``asMinutes``.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))


def showPlot(points):
    """Plot a sequence of values (e.g. losses) with a y-axis tick every 0.2.

    NOTE(review): relies on ``plt`` (matplotlib.pyplot) and ``ticker``
    (matplotlib.ticker) being bound in this namespace. Neither is imported in
    the visible notebook, so calling this currently raises NameError.
    """
    # plt.subplots() creates its own figure; the former extra plt.figure() call
    # left an empty, orphaned figure behind on every call.
    fig, ax = plt.subplots()
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))
    ax.plot(points)

In [3]:
class Vocab:
    """Word/index vocabulary with occurrence counts.

    Indices are assigned densely, in order of first appearance, starting at 0.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}   # word -> index
        self.word2count = {}   # word -> number of occurrences seen
        self.index2word = {}   # index -> word
        self.n_words = 0       # next free index (== vocabulary size)

    def addSentence(self, sentence):
        """Add the ``'form'`` field of every token in ``sentence``."""
        for form in (token['form'] for token in sentence):
            self.addWord(form)

    def addWord(self, word):
        """Count ``word``, assigning it the next index the first time it is seen."""
        if word in self.word2index:
            self.word2count[word] += 1
            return
        new_index = self.n_words
        self.word2index[word] = new_index
        self.index2word[new_index] = word
        self.word2count[word] = 1
        self.n_words = new_index + 1
            

def unicodeToAscii(s):
    """Drop combining marks after NFD decomposition (e.g. 'café' -> 'cafe').

    Characters without a decomposition (such as 'ß') are kept unchanged, so the
    result is not guaranteed to be pure ASCII despite the name.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    """Lowercase, trim and strip accents, then put a space before each ``.!?``.

    Finally, every run of characters other than ASCII letters and ``.!?``
    (including whitespace and digits) is replaced by a single space.
    """
    ascii_text = unicodeToAscii(s.lower().strip())
    spaced = re.sub(r"([.!?])", r" \1", ascii_text)
    return re.sub(r"[^a-zA-Z.!?]+", r" ", spaced)

In [4]:
from conllu import parse_tree, parse_tree_incr, parse, parse_incr
from io import open
import scipy.sparse.csgraph as csg
import networkx as nx
from collections import defaultdict
import json
import string


def unroll(node, G):
    """Add every parent-child edge of the tree rooted at ``node`` to ``G``.

    Edges join ``token['id']`` values and are added in depth-first pre-order: a
    child's whole subtree is added before the edge to its next sibling.
    Returns ``G``.
    """
    # Explicit stack of (parent, child) pairs, pushed in reverse so that popping
    # visits children left-to-right, exactly like the recursive traversal.
    pending = [(node, child) for child in reversed(node.children)]
    while pending:
        parent, child = pending.pop()
        G.add_edge(parent.token['id'], child.token['id'])
        pending.extend((child, grandchild) for grandchild in reversed(child.children))
    return G

# Read every sentence of the UD English-EWT training split. The context manager
# closes the file once parsing is done (it used to stay open for the session).
with open("UD_English-EWT/en_ewt-ud-train.conllu", "r", encoding="utf-8") as data_file:
    sentences = list(parse_incr(data_file))

# Sentences are kept only if MIN_LENGTH < token count < MAX_LENGTH (see check_length).
MIN_LENGTH = 10
MAX_LENGTH = 50

def check_length(sentence):
    """True when the sentence has strictly between MIN_LENGTH and MAX_LENGTH tokens."""
    return MIN_LENGTH < len(sentence) < MAX_LENGTH

def filterSentences(sentences):
    """Return a new list holding only the sentences accepted by ``check_length``."""
    return list(filter(check_length, sentences))

# Build the vocabulary, then write each kept sentence's dependency tree as an edge
# list twice. "train/" uses token id - 1 as the node label; pairfromidx reads
# these files back. "ewt_train/" relabels nodes 0..n-1 by decreasing degree.
input_vocab = Vocab("ewt_train_trimmed")
filtered_sentences = filterSentences(sentences)

sentences_text = []
for sent in filtered_sentences:
    input_vocab.addSentence(sent)
    sentences_text.append(sent.metadata['text'])

# write_edgelist does not create directories, so make sure both exist on a fresh run.
os.makedirs("train", exist_ok=True)
os.makedirs("ewt_train", exist_ok=True)

# NOTE(review): despite its name, dev_dict holds edges of the *training* split.
dev_dict = {}
for idx, sent in enumerate(filtered_sentences):
    curr_tree = sent.to_tree()
    G_curr = unroll(curr_tree, nx.Graph())
    G = nx.relabel_nodes(G_curr, lambda x: x - 1)
    nx.write_edgelist(G, "train/" + str(idx) + ".edges", data=False)
    G_final = nx.convert_node_labels_to_integers(G_curr, ordering="decreasing degree")
    nx.write_edgelist(G_final, "ewt_train/" + str(idx) + ".edges", data=False)
    dev_dict[idx] = list(G_final.edges)



In [5]:
def indexesFromSentence(vocab, sentence):
    """Map each token's ``'form'`` to its vocabulary index (KeyError if unknown)."""
    lookup = vocab.word2index
    return list(map(lambda token: lookup[token['form']], sentence))

def tensorFromSentence(vocab, sentence):
    """Vocabulary indices of ``sentence`` as a (len, 1) long tensor on ``device``."""
    return torch.tensor(
        indexesFromSentence(vocab, sentence), dtype=torch.long, device=device
    ).view(-1, 1)

def pairfromidx(idx):
    """Assemble the training example for filtered sentence ``idx``.

    Returns:
        ``(input_tensor, target_tensor, n, text)``:
        - ``input_tensor``: token indices from ``tensorFromSentence``.
        - ``target_tensor``: float hop-distance matrix of the sentence's tree,
          read from ``train/{idx}.edges`` and placed on ``device``.
        - ``n``: number of nodes in that tree.
        - ``text``: the raw sentence text.
    """
    input_tensor = tensorFromSentence(input_vocab, filtered_sentences[idx])
    graph = load_graph("train/" + str(idx) + ".edges")
    target_tensor = torch.from_numpy(get_dist_mat(graph)).float().to(device)
    # Targets are constants; state that explicitly.
    target_tensor.requires_grad = False
    return (input_tensor, target_tensor, graph.order(), sentences_text[idx])


In [6]:
class EncoderLSTM(nn.Module):
    """Single-layer word encoder: an embedding followed by a recurrent cell.

    Despite the class name, the cell is an ``nn.GRU``. The attribute stays named
    ``gru`` so existing state_dicts still load.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Creation order (embedding, then GRU) fixes the parameter-init order.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Encode one step.

        The embedding is viewed as (1, 1, -1), so ``input`` is expected to hold a
        single token index. ``hidden`` is the (1, 1, hidden_size) state.
        Returns the GRU's ``(output, hidden)``.
        """
        step = self.embedding(input).view(1, 1, -1)
        return self.gru(step, hidden)

    def initHidden(self):
        """All-zero initial hidden state of shape (1, 1, hidden_size) on ``device``."""
        return torch.zeros(1, 1, self.hidden_size, device=device)
    
    
class HyperbolicEncoderLSTM(nn.Module):
    """Word encoder: an embedding followed by a ``HyperbolicLSTM``.

    NOTE(review): ``HyperbolicLSTM`` and ``HyperbolicLSTMCell`` are not defined
    in this notebook; they must be provided before this class is instantiated.
    """

    def __init__(self, input_size, hidden_size):
        super(HyperbolicEncoderLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        # The recurrent layer consumes the embedding output, whose feature size is
        # hidden_size. Passing input_size (the vocabulary size) mis-sized its input.
        self.lstm = HyperbolicLSTM(cell_class=HyperbolicLSTMCell, input_size=hidden_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Embed one token (viewed as (1, 1, hidden_size)) and run one recurrent step.

        Returns ``(output, hidden)`` from ``self.lstm``.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.lstm(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """All-zero initial hidden state of shape (1, 1, hidden_size) on ``device``."""
        return torch.zeros(1, 1, self.hidden_size, device=device)
    

class Attention(nn.Module):
    """Attention over encoder outputs, producing one refined vector per token.

    For a token index and its encoder hidden state, the module scores each of the
    ``max_length`` encoder-output slots, takes the softmax-weighted sum of
    ``encoder_outputs`` and fuses that with the token's embedding.
    """

    def __init__(self, input_size, hidden_size, max_length=MAX_LENGTH):
        super(Attention, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)

    def forward(self, input, hidden, encoder_outputs):
        """
        Args:
            input: a single token index (its embedding is viewed as (1, 1, hidden_size)).
            hidden: 1-D (hidden_size,) encoder hidden state for that token.
            encoder_outputs: (max_length, hidden_size) encoder outputs.

        Returns:
            (1, 1, hidden_size) tensor.
        """
        embedded = self.embedding(input).view(1, 1, -1)
        # (1, 2*hidden_size) -> (1, max_length) scores.
        attention_scores = self.attn(torch.cat((embedded[0], hidden.unsqueeze(0)), 1))
        # Normalise across the max_length slots (dim 1). dim=0 ran the softmax over
        # the size-1 batch axis, which returns all ones: every slot got weight 1
        # and self.attn never received a gradient.
        attn_weights = F.softmax(attention_scores, dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        return output

In [7]:
#Hyperbolic modules.

class HypLinear(nn.Module):
    """Hyperbolic counterpart of ``nn.Linear``.

    Computes ``lu.torch_mv_mul_hyp(W^T, x)`` and, when a bias exists, combines it
    with the bias via ``lu.torch_hyp_add`` (both helpers come from ``learning_util``).
    NOTE(review): the helper names suggest Möbius matrix-vector multiplication and
    Möbius addition; confirm against learning_util.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(HypLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))

        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(1, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)] for weight and bias."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input_):
        result = lu.torch_mv_mul_hyp(torch.transpose(self.weight, 0, 1), input_)  # (batch, input) x (input, output)
        # With bias=False, self.bias is None; skip the addition instead of passing
        # None into torch_hyp_add.
        if self.bias is not None:
            result = lu.torch_hyp_add(result, self.bias)
        return result

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


In [8]:
def trainVanilla(input_tensor, ground_truth, n, encoder, encoder_optimizer, max_length=MAX_LENGTH):
    """One optimisation step of the plain (no-attention) Euclidean encoder.

    Encodes the sentence token by token, then compares the pairwise *squared*
    Euclidean distances of the encoder outputs with ``ground_truth`` over the
    first ``n`` rows/columns. The distortion is back-propagated and
    ``encoder_optimizer`` takes a step.

    Returns:
        The distortion loss as a Python float.
    """
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()

    # Fixed-size buffer; rows past the sentence length stay zero. Sentences must
    # have fewer than max_length tokens.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    for ei in range(input_tensor.size(0)):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    dist_recovered = distance_matrix_euclidean(encoder_outputs)
    loss = distortion(ground_truth, dist_recovered, n)
    loss.backward()
    encoder_optimizer.step()

    return loss.item()

In [9]:
def trainWAttention(input_tensor, ground_truth, n, encoder, encoder_optimizer, attention, attention_optimizer, iter, max_length=MAX_LENGTH):
    """One optimisation step of the encoder + attention Euclidean model.

    Each token's attention output becomes its embedding. The pairwise squared
    Euclidean distances of these embeddings are scored against ``ground_truth``,
    and both optimizers take a step.

    ``iter`` is accepted for call compatibility but not used.

    Returns:
        ``(loss, final_embeddings)``: the distortion loss as a float and the
        (input_length, hidden_size) embeddings, still attached to autograd.
    """
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()
    attention_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    # encoder_outputs is padded to max_length because Attention scores exactly
    # max_length slots.
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    encoder_hiddens = torch.zeros(input_length, encoder.hidden_size, device=device)
    final_embeddings = torch.zeros(input_length, encoder.hidden_size, device=device)

    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]
        encoder_hiddens[ei] = encoder_hidden[0, 0]

    for idx in range(input_length):
        output = attention(input_tensor[idx], encoder_hiddens[idx], encoder_outputs)
        final_embeddings[idx] = output[0, 0]  # output is (1, 1, hidden_size)

    dist_recovered = distance_matrix_euclidean(final_embeddings)
    loss = distortion(ground_truth, dist_recovered, n)
    loss.backward()
    encoder_optimizer.step()
    attention_optimizer.step()

    return loss.item(), final_embeddings

In [31]:
def trainEuclidean(encoder, attention, n_iters=7600, print_every=100, plot_every=100, learning_rate=0.01):
    """Train encoder + attention for one pass over the first ``n_iters`` sentences.

    The (detached) embeddings produced for sentence ``k`` are saved to
    ``saved_tensors/{k}.pt`` and collected in the returned dict under key ``k``.
    Prints the running average loss every ``print_every`` iterations.
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    attention_optimizer = optim.SGD(attention.parameters(), lr=learning_rate)
    training_pairs = [pairfromidx(idx) for idx in range(n_iters)]
    os.makedirs("saved_tensors", exist_ok=True)

    euclidean_emb_dict = {}
    for iter in range(1, n_iters + 1):
        # iter is 1-based for progress reporting; the sentence index is iter - 1.
        # Indexing training_pairs[iter] skipped sentence 0, stored sentence k's
        # embeddings under k - 1 and raised IndexError on the final iteration.
        sent_idx = iter - 1
        input_tensor, target_matrix, n, _ = training_pairs[sent_idx]
        loss, final_embeddings = trainWAttention(input_tensor, target_matrix, n, encoder, encoder_optimizer, attention, attention_optimizer, sent_idx)
        # Detach so neither the saved file nor the dict keeps autograd history.
        final_embeddings = final_embeddings.detach()
        torch.save(final_embeddings, "saved_tensors/" + str(sent_idx) + ".pt")
        euclidean_emb_dict[sent_idx] = final_embeddings
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    return euclidean_emb_dict

# Train the Euclidean encoder + attention model. Per-sentence embeddings are also
# written to saved_tensors/ for the hyperbolic-mapping cells below.
hidden_size = 100
encoder = EncoderLSTM(input_size=input_vocab.n_words, hidden_size=hidden_size).to(device)
attention = Attention(input_size=input_vocab.n_words, hidden_size=hidden_size).to(device)
euclidean_emb_dict = trainEuclidean(encoder=encoder, attention=attention)


In [10]:
# Reload the per-sentence Euclidean embeddings written by trainEuclidean, keyed
# by sentence index.
euclidean_embeddings = {}
saved_tensors = os.listdir("saved_tensors/")
for file in saved_tensors:
    stem, ext = os.path.splitext(file)
    # Skip anything that is not "{index}.pt" (e.g. .ipynb_checkpoints, .DS_Store);
    # such entries used to crash the int() parse.
    if ext != ".pt" or not stem.isdigit():
        continue
    idx = int(stem)
    # torch.load unpickles; only point it at files this notebook wrote itself.
    euclidean_embeddings[idx] = torch.load(os.path.join("saved_tensors", file), map_location=torch.device('cpu'))

In [11]:
# #Riemannian SGD

# import glob
# from torch.optim import Optimizer

# class RiemannianSGD(Optimizer):
#     """Riemannian stochastic gradient descent.
#     Args:
#         params (iterable): iterable of parameters to optimize or dicts defining
#             parameter groups
#         lr (float): learning rate
#     """

#     def __init__(self, params, lr):
#         # if lr is not required and lr < 0.0:
#         #     raise ValueError("Invalid learning rate: {}".format(lr))
#         defaults = dict(lr=lr)
#         super(RiemannianSGD, self).__init__(params, defaults)

#     def step(self):
#         """Performs a single optimization step.
#         Arguments:
#             lr (float): learning rate for the current update.
#         """
#         loss = None

#         for group in self.param_groups:
#             for p in group['params']:
#                 if p.grad is None:
#                     continue
#                 d_p = p.grad.data
#                 lr = group['lr']
       
#             if torch.all(p.grad > -1e4) and torch.all(p.grad < 1e4):
#                 p.data.add_(hyperbolic_step(p.data, d_p, lr))

#         return loss

# def batch_dot(u, v):
#     return torch.sum(u * v, dim=-1, keepdim=True)

# def natural_grad(v, dv):
#     vnorm_squared = batch_dot(v, v)
#     dv = dv * ((1 - vnorm_squared) ** 2 / 4).expand_as(dv)
#     return dv

# def batch_add(u, v, c=1):
#     numer = 1 + 2 * batch_dot(u, v) + batch_dot(v, v) * u + (1 - batch_dot(u, u)) * v
#     denom = 1 + 2 * batch_dot(u, v) + batch_dot(v, v) * batch_dot(u, u)

#     return numer/denom

# def batch_exp_map(x, v, c=1):
#     term = torch.tanh((torch.norm(v, dim=-1, keepdim=True) / (1 - torch.norm(x, dim=-1, keepdim=True).pow(2)))) * \
#                  (v/(torch.norm(v, dim=-1, keepdim=True)))
#     return batch_add(x, term, c)

# def hyperbolic_step(param, grad, lr):
#     ngrad = natural_grad(param, grad)
#     return batch_exp_map(param, -lr * ngrad, c=1)
from torch.optim.optimizer import Optimizer, required
# Shorthand for the sparse float tensor constructor; poincare_grad calls it as
# spten_t(indices, values, size) to rebuild a rescaled sparse gradient.
# NOTE(review): torch.sparse.FloatTensor is a legacy constructor;
# torch.sparse_coo_tensor is the modern equivalent.
spten_t = torch.sparse.FloatTensor


def poincare_grad(p, d_p):
    r"""
    Convert a Euclidean gradient into the Riemannian gradient of the Poincaré ball.

    The gradient is rescaled by (1 - ||p||^2)^2 / 4, with ||p|| taken over the
    last dimension (per row). For a sparse ``d_p``, only the stored values are
    rescaled, using the rows of ``p`` selected by the first index dimension.

    Args:
        p (Tensor): current point(s) in the ball
        d_p (Tensor): Euclidean gradient at p (dense or sparse COO)

    Returns:
        The rescaled gradient, sparse if ``d_p`` is sparse, dense otherwise.
    """
    if not d_p.is_sparse:
        row_sqnorm = torch.sum(p.data ** 2, dim=-1, keepdim=True)
        return d_p * ((1 - row_sqnorm) ** 2 / 4).expand_as(d_p)

    indices = d_p._indices()
    values = d_p._values()
    row_sqnorm = torch.sum(
        p.data[indices[0].squeeze()] ** 2, dim=1, keepdim=True
    ).expand_as(values)
    scaled_values = values * ((1 - row_sqnorm) ** 2) / 4
    return spten_t(indices, scaled_values, d_p.size())


def euclidean_grad(p, d_p):
    """Identity "Riemannian" gradient for Euclidean parameters: returns ``d_p`` unchanged."""
    del p  # unused; present only to match the rgrad(p, d_p) interface
    return d_p


def retraction(p, d_p, lr):
    """First-order retraction: in-place SGD step ``p -= lr * d_p``.

    The step is skipped entirely (``p`` unchanged) unless every component of
    ``d_p`` lies strictly inside (-100, 100); NaN components also cause a skip.
    The result is not projected back onto any manifold.
    """
    if torch.all(d_p < 100) and torch.all(d_p > -100):
        # Keyword form; the positional add_(Number, Tensor) overload is deprecated.
        p.data.add_(d_p, alpha=-lr)


class RiemannianSGD(Optimizer):
    r"""Riemannian stochastic gradient descent.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        rgrad (Function): Function to compute the Riemannian gradient from
            an Euclidean gradient
        retraction (Function): Function to update the parameters via a
            retraction of the Riemannian gradient
        lr (float): learning rate
    """

    def __init__(self, params, lr=required, rgrad=required, retraction=required):
        defaults = dict(lr=lr, rgrad=rgrad, retraction=retraction)
        super(RiemannianSGD, self).__init__(params, defaults)

    def step(self, lr=None):
        """Performs a single optimization step.

        For every parameter that has a gradient, converts the gradient with the
        group's ``rgrad`` and applies the group's ``retraction``.

        Arguments:
            lr (float, optional): learning rate that overrides every group's
                ``lr`` for this step. When omitted, each group uses its own ``lr``.

        Returns:
            None (closures are not supported).
        """
        loss = None

        for group in self.param_groups:
            # Resolve per group. Previously the first group's lr was written into
            # ``lr`` and then silently reused for every later group.
            group_lr = group['lr'] if lr is None else lr
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = group['rgrad'](p, p.grad.data)
                group['retraction'](p, d_p, group_lr)

        return loss

In [12]:
# Do Euclidean to hyperbolic mapping in one FC layer. (using GT)

def trainFCHyp(euclidean_embs, ground_truth, n, fc, fc_optimizer, max_length=MAX_LENGTH):
    fc_optimizer.zero_grad()
 
    final_embeddings = torch.zeros(fc.in_features, fc.out_features, device=device)

    loss = 0
    for idx in range(fc.in_features):
        output = fc(euclidean_embs[idx])
        final_embeddings[idx] = output[0]

    dist_recovered = distance_matrix_hyperbolic(final_embeddings) 
    loss += distortion(ground_truth, dist_recovered, n)
    loss.backward()
    fc_optimizer.step()

    return loss.item()


In [61]:
def trainFCIters(fc, n_iters=800, print_every=100, plot_every=100, learning_rate=0.01):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  
    plot_loss_total = 0  

    fc_optimizer = RiemannianSGD(fc.parameters(), lr=learning_rate, rgrad=poincare_grad, retraction=euclidean_retraction)

    training_pairs = [pairfromidx(idx) for idx in range(n_iters)]

    for iter in range(1, n_iters + 1):     
        input_matrix = euclidean_embeddings[iter - 1]
        target_matrix = training_pairs[iter-1][1]
        n = training_pairs[iter-1][2]
        loss = trainFCHyp(euclidean_embeddings, target_matrix, n, fc, fc_optimizer)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

input_size = 100
output_size = 10
fc = nn.Linear(input_size, output_size).to(device)
trainFCIters(fc)

3m 49s (- 26m 49s) (100 12%) 0.5048
7m 31s (- 22m 33s) (200 25%) 0.5074
20m 41s (- 34m 28s) (300 37%) 0.5064
94m 6s (- 94m 6s) (400 50%) 0.4960
97m 48s (- 58m 40s) (500 62%) 0.5023
101m 30s (- 33m 50s) (600 75%) 0.4972
105m 12s (- 15m 1s) (700 87%) 0.4920
108m 54s (- 0m 0s) (800 100%) 0.4908


In [None]:
from torch.autograd import Variable


def dot(x, y):
    """Inner product over the last dimension (leading dimensions broadcast)."""
    return (x * y).sum(-1)
def acosh(x):
    """Inverse hyperbolic cosine via log(x + sqrt(x^2 - 1)); NaN for x < 1.

    NOTE(review): this rebinds ``acosh`` with the same formula as the version in
    the "Distortion calculations" cell.
    """
    return torch.log(torch.sqrt(x**2 - 1) + x)


class RParameter(nn.Parameter):
    """Base class for parameters constrained to a Riemannian manifold.

    Subclasses implement ``_proj`` (projection onto the manifold) and may
    override ``initial_proj`` and ``modify_grad_inplace`` (Euclidean-to-Riemannian
    gradient conversion). ``use_exp`` stores the ``exp`` flag; presumably it
    selects an exponential-map update in the optimiser (confirm with its user).
    """
    def __new__(cls, data=None, requires_grad=True, sizes=None, exp=False):
        if data is None:
            assert sizes is not None
            # Small random init in double precision, clamped to [-3e-3, 3e-3].
            data = (1e-3 * torch.randn(sizes, dtype=torch.double)).clamp_(min=-3e-3,max=3e-3)
        #TODO get partial data if too big i.e. data[0:n,0:d]
        ret =  super().__new__(cls, data, requires_grad=requires_grad)
        ret.initial_proj()
        ret.use_exp = exp
        return ret

    @staticmethod
    def _proj(x):
        """Project ``x`` onto the manifold; subclasses must implement this."""
        # NotImplemented is a comparison sentinel, not an exception class; raising
        # it produced a TypeError instead of the intended error.
        raise NotImplementedError("RParameter subclasses must implement _proj")

    def proj(self):
        """Replace ``self.data`` with its projection onto the manifold."""
        self.data = self.__class__._proj(self.data.detach())

    def initial_proj(self):
        """ Project the initialization of the embedding onto the manifold """
        self.proj()

    def modify_grad_inplace(self):
        """Convert the Euclidean gradient in-place; the base class leaves it unchanged."""
        pass

    @staticmethod
    def correct_metric(ps):
        """Call ``modify_grad_inplace`` on every RParameter in ``ps``; others are ignored."""
        for p in ps:
            if isinstance(p,RParameter):
                p.modify_grad_inplace()


# TODO can use kwargs instead of pasting defaults
class HyperboloidParameter(RParameter):
    """Parameter constrained to the hyperboloid model, dot_h(x, x) = -1.

    dot_h is the Minkowski form -x0*y0 + sum_{i>=1} xi*yi over the last dimension.
    ``sizes`` of (..., d) allocate (..., d + 1) coordinates: index 0 is the "time"
    coordinate and indices 1..d are the spatial ones.
    """
    def __new__(cls, data=None, requires_grad=True, sizes=None, exp=True):
        if sizes is not None:
            # Reserve the extra time coordinate x0.
            sizes = list(sizes)
            sizes[-1] += 1
        return super().__new__(cls, data, requires_grad, sizes, exp)

    @staticmethod
    def dot_h(x,y):
        """Minkowski inner product over the last dim: -x0*y0 + sum_{i>=1} xi*yi."""
        return torch.sum(x * y, -1) - 2*x[...,0]*y[...,0]
    @staticmethod
    def norm_h(x):
        """sqrt(dot_h(x, x)); asserts dot_h(x, x) >= 0 (e.g. tangent vectors)."""
        assert torch.all(HyperboloidParameter.dot_h(x,x) >= 0), torch.min(HyperboloidParameter.dot_h(x,x))
        return torch.sqrt(torch.clamp(HyperboloidParameter.dot_h(x,x), min=0.0))
    @staticmethod
    def dist_h(x,y):
        """Hyperboloid distance acosh(-dot_h(x, y)); the argument is clamped to >= 1 + 1e-8."""
        bad = torch.min(-HyperboloidParameter.dot_h(x,y) - 1.0)
        if bad <= -1e-4:
            # Points are noticeably off the hyperboloid; report and continue.
            print("bad dist", bad.item())
        return acosh(torch.clamp(-HyperboloidParameter.dot_h(x,y), min=(1.0+1e-8)))

    @staticmethod
    def _proj(x):
        """ Project onto hyperboloid.

        Spatial parts with norm above 1e7 are scaled down to norm 1e7. x0 is then
        recomputed as sqrt(1 + |tail|^2), and the point is divided by
        sqrt(-dot_h(x, x)). Returns a new tensor; ``x`` is not modified.
        """
        # Explicit copy: torch.tensor(x) on an existing tensor warns and is discouraged.
        x_ = x.detach().clone()
        x_tail = x_[...,1:]
        current_norms = torch.norm(x_tail,2,-1)
        scale      = (current_norms/1e7).clamp_(min=1.0)
        x_tail /= scale.unsqueeze(-1)
        x_[...,1:] = x_tail
        x_[...,0] = torch.sqrt(1 + torch.norm(x_tail,2,-1)**2)

        debug = True
        if debug:
            bad = torch.min(-HyperboloidParameter.dot_h(x_,x_))
            if bad <= 0.0:
                print("way off hyperboloid", bad)
            assert torch.all(-HyperboloidParameter.dot_h(x_,x_) > 0.0), f"way off hyperboloid {torch.min(-HyperboloidParameter.dot_h(x_,x_))}"
        xxx = x_ / torch.sqrt(torch.clamp(-HyperboloidParameter.dot_h(x_,x_), min=0.0)).unsqueeze(-1)
        return xxx

    def initial_proj(self):
        """ Project the initialization of the embedding onto the manifold """
        self.data[...,0] = torch.sqrt(1 + torch.norm(self.data.detach()[...,1:],2,-1)**2)
        self.proj()


    def exp(self, lr):
        """ Exponential map: step along v = -lr * grad, then re-project.

        NOTE(review): the step norm n is clamped to 1.0 but v is not rescaled to
        match, so for larger steps this is not an exact exponential map.
        """
        x = self.data.detach()
        v = -lr * self.grad

        retract = False
        if retract:
            # First-order retraction instead of the exponential map.
            self.data = x + v

        else:
            # `1 - bool_tensor` is rejected by modern PyTorch; check for NaNs directly.
            assert not torch.any(torch.isnan(v))
            n = self.__class__.norm_h(v).unsqueeze(-1)
            assert not torch.any(torch.isnan(n))
            n.clamp_(max=1.0)
            # e = cosh(n)*x + sinh(n)*v/n, with the n -> 0 limit handled explicitly.
            mask = torch.abs(n)<1e-7
            cosh = torch.cosh(n)
            cosh[mask] = 1.0
            sinh = torch.sinh(n)
            sinh[mask] = 0.0
            n[mask] = 1.0
            e = cosh*x + sinh/n*v
            self.data = e
        self.proj()


    def modify_grad_inplace(self):
        """ Convert Euclidean gradient into Riemannian.

        Negates the time component (applies the inverse Minkowski metric), then
        removes the component along ``self.data``, i.e. projects onto the tangent space.
        """
        self.grad[...,0] *= -1
        self.grad -= self.__class__.dot_h(self.data, self.grad).unsqueeze(-1) / HyperboloidParameter.dot_h(self.data, self.data).unsqueeze(-1) * self.data


class Embedding(nn.Module):
    """Holds an (n, d) manifold-valued embedding table plus a log-scale.

    Args:
        dist_fn: distance function, stored as ``self.dist_fn`` (unused in this class).
        param_cls: called as ``param_cls(data=initialize, sizes=(n, d))`` to build ``self.w``.
        n, d: number of points and embedding dimension.
        project: stored as ``self.project``.
        initialize: optional initial tensor forwarded to ``param_cls``.
        learn_scale: if True, ``scale_log`` is a learnable double ``nn.Parameter``;
            otherwise it is a fixed double tensor on ``device``.
        initial_scale: initial value of ``scale_log``.
    """
    def __init__(self, dist_fn, param_cls, n, d, project=True, initialize=None, learn_scale=False, initial_scale=0.0):
        super().__init__()
        self.dist_fn = dist_fn
        self.n, self.d = n, d
        self.project   = project
        if initialize is not None:
            # Logs whether the provided initialisation contains NaNs, plus its size.
            logging.info(f"Initializing {np.any(np.isnan(initialize.numpy()))} {initialize.size()} {(n,d)}")
        self.w = param_cls(data=initialize, sizes=(n,d))
        if learn_scale:
            self.scale_log       = nn.Parameter(torch.tensor([initial_scale], dtype=torch.double))
        else:
            # Plain tensors are not moved by Module.to(), hence the explicit device.
            self.scale_log       = torch.tensor([initial_scale], dtype=torch.double, device=device)

    def normalize(self):
        """Re-project the embedding table onto its manifold via ``self.w.proj()``."""
        self.w.proj()


In [13]:
def trainHyperbolic(input_tensor, ground_truth, n, encoder, encoder_optimizer, fc, fc_optimizer, max_length=MAX_LENGTH):
    """One end-to-end step: encoder, then fc, then pairwise Poincaré distances, then distortion.

    Returns:
        The distortion loss as a Python float.

    NOTE(review): fc's outputs are not constrained to the unit ball that dist_h assumes.
    """
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()
    fc_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    for ei in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[ei], encoder_hidden)
        encoder_outputs[ei] = encoder_output[0, 0]

    # Map all token outputs in one batched call. The old per-row loop stored
    # fc(row)[0], the first *scalar* of a 1-D output, broadcast across the row.
    final_embeddings = fc(encoder_outputs[:input_length])

    dist_recovered = distance_matrix_hyperbolic(final_embeddings)
    loss = distortion(ground_truth, dist_recovered, n)
    loss.backward()
    encoder_optimizer.step()
    fc_optimizer.step()

    return loss.item()



In [14]:
#Does end to end hyperbolic.

def trainHypIters(encoder, fc, n_iters=800, print_every=100, plot_every=100, learning_rate=0.01):
    """End-to-end training of ``encoder`` + ``fc`` on the first ``n_iters`` sentences.

    Each step is a ``trainHyperbolic`` call. The running average loss is printed
    every ``print_every`` iterations.
    """
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    # rgrad/retraction are required RiemannianSGD options; leaving them out made
    # torch raise ValueError at construction. The encoder's weights are ordinary
    # Euclidean parameters; fc uses the Poincaré rescaling, as in trainFCIters.
    encoder_optimizer = RiemannianSGD(encoder.parameters(), lr=learning_rate, rgrad=euclidean_grad, retraction=retraction)
    fc_optimizer = RiemannianSGD(fc.parameters(), lr=learning_rate, rgrad=poincare_grad, retraction=retraction)
    training_pairs = [pairfromidx(idx) for idx in range(n_iters)]

    for iter in range(1, n_iters + 1):
        input_tensor, target_matrix, n, _ = training_pairs[iter - 1]
        loss = trainHyperbolic(input_tensor, target_matrix, n, encoder, encoder_optimizer, fc, fc_optimizer)
        print_loss_total += loss
        plot_loss_total += loss

        if iter % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                         iter, iter / n_iters * 100, print_loss_avg))

        if iter % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0



In [15]:
# End-to-end hyperbolic model: word encoder followed by a square linear layer.
hidden_size = 100
encoder = EncoderLSTM(input_size=input_vocab.n_words, hidden_size=hidden_size).to(device)
fc = nn.Linear(in_features=hidden_size, out_features=hidden_size).to(device)
# trainHypIters(encoder, fc)