In [8]:
#only external lib allowed
import torch
import torch.nn as nn

import heapq # stdlib, basically the huffman tree.

In [9]:
class HierarchicalSoftmaxNode(nn.Module):
    
    def __init__(self, symbol, freq, classifier_vector=None, left=None, right=None):
        
        super(HierarchicalSoftmaxNode, self).__init__()
        self.classifier_vector = classifier_vector  # Classifier vector associated with this node
        self.left = left  # Reference to the left child node
        self.right = right  # Reference to the right child node
        self.symbol = symbol    # Leaf node: character
        self.freq = freq  # Internal node: value is the frequency


        
        # tree direction 0 or 1 
        self.huff = ''
        
        
    #comparator, lt
    def __lt__(self, other):
        return self.freq < other.freq
    
    #tostring   
    def __str__(self):
        return '(' + str(self.symbol)+','+str(self.freq) + ',' + str(self.huff) + ')'
    
        
    
    def forward(self, input):
        # Compute the probability based on the classifier vector
        # use a sigmoid function here
        return torch.sigmoid(torch.dot(self.classifier_vector, input))

In [10]:
def dict_to_nodes(d,vector_size):
    
    nodes = []
    
    for k,v in d.items():

        heapq.heappush(nodes, HierarchicalSoftmaxNode(symbol=k,freq=v,classifier_vector=torch.randn(vector_size)))
    

    return nodes

In [11]:
#returns root of the tree.
def build_tree(nodes):
    
    while len(nodes) > 1:
  
        # sort all the nodes in ascending order 
        # based on their frequency 
        left = heapq.heappop(nodes)
        right = heapq.heappop(nodes)

        # assign directional value to these nodes 
        # least
        
        left.huff = '0'
        #2nd least
        right.huff = '1'
        
        # combine the 2 smallest nodes to create 
        # new INTERNAL node as their parent 
        newNode = HierarchicalSoftmaxNode(symbol=left.symbol+right.symbol,freq=left.freq+right.freq, left=left, right=right,classifier_vector=torch.randn(vector_size))
        
        heapq.heappush(nodes, newNode)
        
    return nodes

In [12]:
# utility function for printing
def in_order_traversal(node,path):
    
    if node.left != None:
        
        in_order_traversal(node.left,path + node.left.huff)
    
    print(node)
    print(path)
    
    if node.right != None:
  
        in_order_traversal(node.right,path + node.huff)
    

In [6]:
# # Example usage:
# # Create a simple hierarchical softmax tree structure with three nodes
# node1 = HierarchicalSoftmaxNode(classifier_vector=torch.randn(10))
# node2 = HierarchicalSoftmaxNode(classifier_vector=torch.randn(10))
# node3 = HierarchicalSoftmaxNode(classifier_vector=torch.randn(10))

# # Set the hierarchy by connecting nodes
# node1.left = node2
# node1.right = node3

# # Calculate probability for a binary path (e.g., 01)
# #random input vector...
# input_vector = torch.randn(10)  # Example input vector
# print(input_vector)
# path = [0, 1]
# current_node = node1

# for bit in path:
#     if bit == 0:
#         current_node.left = node2
#     else:
#         current_node.left = node3

# print(current_node)

        
# probability = current_node(input_vector)
# print("Probability:", probability.item())

In [13]:
# Example usage
if __name__ == "__main__":

    ft = {
        
        'c':63,
        'a':27,
        'z':72,
        'b':17,
        'y':89
        
    }
    #vector size needs to be consistent
    vector_size = 100
    nodes = dict_to_nodes(ft,vector_size)

 
    ht = build_tree(nodes)
    root = ht[0]
#     print(root.left.left.left)
    print(in_order_traversal(root,''))

    

(b,17,0)
000
(ba,44,0)
00
(a,27,1)
000
(bac,107,0)
0
(c,63,1)
00
(baczy,268,)

(z,72,0)
0
(zy,161,1)

(y,89,1)
1
None


In [16]:
# # Calculate probability for a binary path (e.g., 01)
# #random input vector...

#word inputs are from the first model.
input_vector = torch.randn(100)  # Example input vector
print(input_vector)

path = [0, 1]
# 
# nodec = HierarchicalSoftmaxNode(symbol='s',freq=2,classifier_vector=torch.randn(100))

# for bit in path:
#     if bit == 0:
#         current_node.left = node2
#     else:
#         current_node.left = node3

# print(current_node)

        
probability = root(input_vector)
print("Probability:", probability.item())

tensor([-0.4548,  2.0100, -0.9720, -0.1868,  1.0166,  0.2454,  0.9000,  0.2919,
         0.8699, -1.0776, -0.7021,  0.3959, -1.1772,  0.0042, -1.6499, -0.8422,
        -1.2302, -0.3482,  0.1525,  0.5280, -0.1044,  0.3351,  0.1235, -0.7984,
         0.4337,  0.1426, -1.2984, -0.0863,  0.2242,  1.3191, -0.3685, -0.4374,
        -0.9463,  0.1679, -0.2838, -0.0890, -1.8985, -0.2662,  0.2577,  1.0705,
        -0.3483,  0.7257, -0.4420, -1.2803, -0.0086,  0.3645, -2.6064, -0.5855,
        -0.9720,  1.0682,  0.2688,  2.0939,  0.1621, -0.5371,  0.1490, -0.5604,
        -1.2002, -0.7579,  0.4235,  1.6897,  0.3964, -0.3861,  0.5726, -0.6173,
        -0.5422,  1.2196, -0.4933, -0.9939,  0.4104,  0.4222,  0.9236,  0.3939,
        -0.5635, -0.0964,  0.6098,  0.3082, -1.5390, -1.3897, -1.0531,  1.2085,
        -1.0832, -0.6796,  0.2832, -1.1013,  0.5982, -0.7136, -0.6607, -0.3279,
         0.3901,  0.7604,  1.2414,  1.0218,  0.2926,  0.5706, -0.9382, -0.4806,
        -0.2924,  0.5856, -1.2028, -0.85