In [2]:
import torch
from torch import softmax
from torch.nn import CrossEntropyLoss, NLLLoss, KLDivLoss
from torch import nn
from patbert.common import medical
from os.path import join
data_path = "C:\\Users\\fjn197\\PhD\\projects\\PHAIR\\pipelines\\patbert\\data"
import pickle

  from .autonotebook import tqdm as notebook_tqdm


pytorch crossentropyloss takes logits, since we want to operate on probabilities, it's better to use NLL and sum,
let's see if we get the same results

In [3]:
logits = torch.tensor([1,5,5], dtype=torch.float32)
logprobs = torch.log(softmax(logits, dim=0))
clas_true = torch.tensor([1])
cr = CrossEntropyLoss()
print('CE',cr(logits.view(1,-1), clas_true))
nll = NLLLoss()
print('NLL',nll(logprobs.view(1,-1), clas_true))
cr(logits.view(1,-1), clas_true)==nll(logprobs.view(1,-1), clas_true)
print("We can use NLL loss with logprobs")

CE tensor(0.7023)
NLL tensor(0.7023)
We can use NLL loss with logprobs


we need also a loss which can take soft labels, here we use directly KL divergence

In [4]:
kl = KLDivLoss()
probs = torch.tensor([[1, 0.8, 0.1],[0.2,0.7,1]], dtype=torch.float)
probs_true = torch.tensor([[1, 0, 0],[0,0,1]], dtype=torch.float)
kl(probs.view(1,-1), probs_true.view(1,-1))



tensor(-0.3333)

## Lets try to implement the masking with a target in parallel, to access leaf nodes

In [5]:
# example targets and tree
leaf_nodes = torch.tensor([[1,2,1],[1,2,2], [1,1,0], [2,1,2],[2,1,3]])
y_true_enc = torch.tensor([[[1,2,0], [1,0,0], [2,1,0], [2,1,0]],
                        [[1,2,0], [1,1,0], [2,0,0], [2,1,0],]], requires_grad=False)
leaf_logits = torch.tensor([[10.0, 12, 4, 3, 0],[1,2,.3,0,0]], dtype=torch.float32, requires_grad=True)
leaf_probs = softmax(leaf_logits, dim=-1)
print(leaf_probs)

tensor([[1.1915e-01, 8.8044e-01, 2.9535e-04, 1.0865e-04, 5.4096e-06],
        [2.0199e-01, 5.4908e-01, 1.0031e-01, 7.4310e-02, 7.4310e-02]],
       grad_fn=<SoftmaxBackward0>)


In [6]:
# should be part of a class, leaf nodes stay the same
def get_leaf_node_probabilities(leaf_probs:torch.tensor, y_true_enc: torch.tensor, leaf_nodes: torch.tensor):
    """Selects leaf probabilities for a given target tensor.
    Args:
        leaf_probs (torch.tensor): Probabilities (batchsize, num_leaf_nodes)
        y_true_enc (torch.tensor): Target vector (batchsize, seq_len, levels)
        leaf_nodes (torch.tensor): Leaf nodes (num_leaf_nodes, levels)
    Returns:
        torch.tensor: Selected leaf probabilities (batchsize, seq_len)"""
    # we want to match all the leaf nodes with a target, e.g. target: 1,2,0 should select 1,2,1 and 1,2,2
    zeros_mask = y_true_enc == 0
    leaf_mask = (leaf_nodes == y_true_enc[:, :, None, :]) | zeros_mask[:, :, None,:] # select all leaf nodes that match the target
    leaf_mask = leaf_mask.all(dim=-1).to(torch.int16)

    leaf_probs = leaf_probs[:,None,:].expand(leaf_mask.shape) # batch, seq_len, num_leafes
    selected_leaf_probs = leaf_probs * leaf_mask
    selected_leaf_probs = selected_leaf_probs.sum(dim=-1)
    return selected_leaf_probs

In [7]:
# final_probs = get_leaf_node_probabilities(leaf_logits, y_true_enc, leaf_nodes)
def flat_softmax_cross_entropy(leaf_logits, y_true_enc, leaf_nodes):
    leaf_probs = softmax(leaf_logits, dim=-1)
    selected_leaf_probs = get_leaf_node_probabilities(leaf_probs, y_true_enc, leaf_nodes)
    # print(selected_leaf_probs)
    log_probs = torch.log(selected_leaf_probs)
    log_probs = log_probs.flatten() # batchsize * seq_len
    loss = nll(log_probs.unsqueeze(-1), torch.zeros_like(log_probs, dtype=torch.int64))
    return loss
flat_softmax_cross_entropy(leaf_logits, y_true_enc, leaf_nodes)

tensor(3.0696, grad_fn=<NllLossBackward0>)

 We can alternatively perform matrix multiplication to go up one level of hierarchy from the leaf probabilities
 for example if we have three leaf probabilities [p1, p2, p3] and p1 and p2 are from the same parent, we would multiply by a matrix [[1, 1, 0], [0,0,1]] from the left, to get [p1+p2, p3]
 for details see Hierarchical Classification at multiple operating points

## Loss on multiple levels

it might be easier to construct the targets on different levels and compute the loss as we did previously
e.g. [1,2,0] would be [1,0,0], [1,2,0], [1,2,int] on the three levels that are present
We might either say that on the lowest level we don't have a target and the loss is thus 0*log(q_i) or we maximize the entropy on levels below to not allow for random predictions, meaning
l_below_target = 1/n*sum(log(q_i))

In [178]:
# we reuse this example
leaf_nodes = torch.tensor([[1,2,1],[1,2,2], [1,1,0], [2,1,2],[2,1,3], [3,0,0]])
y_true_enc = torch.tensor([[[1,2,0], [1,0,0], [2,1,0], [2,1,0], [3,0,0]],
                        [[1,2,1], [1,1,0], [2,1,2], [2,1,0],[1,0,0]]], requires_grad=False)
leaf_logits = torch.tensor([[10.0, 12, 4, 3, 4, 0],[1,2,.3,0,0,0]], dtype=torch.float32, requires_grad=True)
leaf_probs = softmax(leaf_logits, dim=-1)
print(leaf_probs)

tensor([[1.1912e-01, 8.8018e-01, 2.9527e-04, 1.0862e-04, 2.9527e-04, 5.4080e-06],
        [1.8802e-01, 5.1110e-01, 9.3369e-02, 6.9170e-02, 6.9170e-02, 6.9170e-02]],
       grad_fn=<SoftmaxBackward0>)


 Level 1

In [179]:
from typing import List, Dict
from patbert.features import utils
class FlatSoftmaxMultOP:
    def __init__(self, leaf_nodes, trainable_weights=0) -> None:
        self.leaf_nodes = leaf_nodes
        self.lvl_mappings = self.get_level_mappings()
        self.lvl_sel_mats = self.get_level_selection_mats()
        weights = self.initialize_geometric_weights()
        if trainable_weights:
            weights.requires_grad = True
        

    def __call__(self, predicted_probs:torch.tensor, y_true_enc:torch.tensor)->float:
        level_probs = []
        for i, mat in enumerate(self.lvl_sel_mats):
            print('mat', mat)
            print('pred', predicted_probs)
            prob_lvl = mat @ predicted_probs
            y_true_lvl = y_true_enc[:, :i+1]
            print('lvl_prob:', y_true_lvl)


    def initialize_geometric_weights(self):
        """We initialize weights as e**(-i)"""
        return torch.exp(-1*torch.arange(len(self.lvl_sel_mats)))

    def get_true_prob(level, y_true_enc):
        """Returns the probability for a given level and target"""
        pass

    def get_level_mappings(self):
        """Returns a dictionary, where the key is the level and the value is a tensor
        with the corresponding indices for this level.
        """
        lvl_mappings = []
        for i in range(len(self.leaf_nodes[0])):
            leaf_nodes_part = self.leaf_nodes[:,:i+1]
            _, unique_indices = torch.unique(leaf_nodes_part, dim=0, return_inverse=True) # we can use these to enumerate the target, in order to access the correct prob
            lvl_mappings.append([leaf_nodes_part, unique_indices])
        return lvl_mappings

    def get_level_selection_mats(self)->List[torch.tensor]:
        """For every level builds a matrix, such that when multiplied from the left
        with leaf probabilities, returns probabs for every node on this level. 
        Matrices are returned as list. 
        """
        mats = []
        nodes = []
        for i in range(len(self.leaf_nodes[0])):
            leaf_nodes_part = self.leaf_nodes[:,:i+1]
            # remove zero nodes (below leaf)
            unique_nodes, unique_indices = torch.unique(leaf_nodes_part, dim=0, return_inverse=True) # we can use these to enumerate the target, in order to access the correct prob
            # print(unique_nodes)
            # print('indices:', unique_indices)
            # Map each unique row to an integer based on its position in the sorted unique tensor
            mat = self.create_leaf_selection_matrix(unique_indices)
            # zero means we are below the leaf, so we set the corresponding row to zero
            zero_mask = leaf_nodes_part[:,-1]==0
            mat[:,zero_mask] = 0 
            all_zero_mask = mat.sum(dim=1)!=0
            mat = mat[all_zero_mask] # remove zero rows (corresponding to nodes from other levels)
            # mat = mat[mat.nonzero(dim=1)]
            unique_nodes = unique_nodes[all_zero_mask]
            mats.append(mat)
            nodes.append(unique_nodes)
        return mats, nodes
    # write a function that takes torch.tensor([1,1,1,2,2]) and returns a matrix that looks like this torch.tensor([[1,1,1,0,0],[0,0,0,1,1]])
    @staticmethod
    def create_leaf_selection_matrix(indices):
        """This function takes an array of integers and 
        returns a matrix where each row is a one-hot encoded version of the input.
        e,g, [1,1,1,2,2] -> [[1,1,1,0,0],[0,0,0,1,1]]]"""
        unique_values = torch.unique(indices)
        mask = (indices.unsqueeze(0) == unique_values.unsqueeze(1)).float()
        return mask

In [180]:
leaf_nodes

tensor([[1, 2, 1],
        [1, 2, 2],
        [1, 1, 0],
        [2, 1, 2],
        [2, 1, 3],
        [3, 0, 0]])

In [181]:
fsm = FlatSoftmaxMultOP(leaf_nodes)
mats, nodes = fsm.get_level_selection_mats()
mats, nodes

([tensor([[1., 1., 1., 0., 0., 0.],
          [0., 0., 0., 1., 1., 0.],
          [0., 0., 0., 0., 0., 1.]]),
  tensor([[0., 0., 1., 0., 0., 0.],
          [1., 1., 0., 0., 0., 0.],
          [0., 0., 0., 1., 1., 0.]]),
  tensor([[1., 0., 0., 0., 0., 0.],
          [0., 1., 0., 0., 0., 0.],
          [0., 0., 0., 1., 0., 0.],
          [0., 0., 0., 0., 1., 0.]])],
 [tensor([[1],
          [2],
          [3]]),
  tensor([[1, 1],
          [1, 2],
          [2, 1]]),
  tensor([[1, 2, 1],
          [1, 2, 2],
          [2, 1, 2],
          [2, 1, 3]])])

Here, we will access the proper indices, which we then use to obtain the probabilities

In [182]:
# for one target it's easy
selection_index = (nodes[1]==torch.tensor([1,2])).all(dim=1).nonzero()[0]
print(selection_index)

tensor([1])


Now we do it for two batches of sequences of targets

In [183]:
print(leaf_nodes.shape, 'leaf_nodes') # n_nodes, levels
print(y_true_enc.shape, 'y_true_enc') # batchsize, seq_len, levels

torch.Size([6, 3]) leaf_nodes
torch.Size([2, 5, 3]) y_true_enc


In [184]:
nodes[2]

tensor([[1, 2, 1],
        [1, 2, 2],
        [2, 1, 2],
        [2, 1, 3]])

Level 3

In [217]:
A = torch.zeros(2,4)
m = torch.tensor([[1,0,0,0,0],[1,0,0,1,1]])
probs = 1/m.sum(dim=0).float()
probs.repeat(2,1)

tensor([[0.5000,    inf,    inf, 1.0000, 1.0000],
        [0.5000,    inf,    inf, 1.0000, 1.0000]])

In [260]:
def get_target_child_mask(target_mask, y_flat):
    """Returns a mask that is true if the node is a child of the target"""
    target_mask_parent = target_mask[:,:,:-1].all(dim=2) & (~target_mask[:,:,-1])  # mask nodes that match target up to the last level 
    target_mask_parent = target_mask_parent & (y_flat[:,-1]==0) # on the last level the target is zero
    return target_mask_parent

def get_child_probabilities(target_child_mask):
    """Child probabilities are computed as 1/number of children"""
    return (1 / target_child_mask.sum(dim=0).float()).repeat(target_child_mask.shape[0], 1)

def get_target_mask(nodes_lvl, y_flat):
    """Returns a mask of shape num_nodes x num_targets x n_levels that is true if the node matches the target"""
    return nodes_lvl[:,None,:]==y_flat

In [270]:
def construct_target_probability_matrix(nodes, y_true_enc, level):
    """
    For a specific level, construct a matrix of target probabilities.
    If target is given on level: one hot encoding
    elif target is given on a higher level: probabilities split equally among all children
    else: zero
    Parameters: 
        nodes: list of tensors, where each tensor is a list of nodes for a specific level
        y_true_enc: target tensor of shape (batchsize, seq_len, levels) 
        level: level for which we want to construct the matrix (starts with 1)
    Returns:
        A: matrix of shape (batchsize x seq_len, num_nodes) where each row corresponds to a target and each column to a node on the given level"""
    nodes_lvl = nodes[level-1]
    num_nodes = nodes_lvl.shape[0]

    y_flat = torch.flatten(y_true_enc[:,:,:level], start_dim=0, end_dim=1) # flatten batch dim to simplify

    probability_matrix = torch.zeros((num_nodes, y_flat.shape[0])) # we populate this matrix num_nodes x num_targets

    target_mask = get_target_mask(nodes_lvl, y_flat) # mask nodes that match target, shape: num_nodes x num_targets x levels

    target_mask_exact = target_mask.all(dim=2) # mask nodes that match target exactly
    probability_matrix[target_mask_exact] = 1
    target_child_mask = get_target_child_mask(target_mask, y_flat)
    
    child_probabilities = get_child_probabilities(target_child_mask)

    probability_matrix[target_child_mask] = child_probabilities[target_child_mask]

    return probability_matrix.T

In [272]:
construct_target_probability_matrix(nodes, y_true_enc, 1)

torch.Size([3, 1]) torch.Size([10, 1])
torch.Size([3, 10, 1]) target_mask


tensor([[1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.],
        [1., 0., 0.],
        [0., 1., 0.],
        [0., 1., 0.],
        [1., 0., 0.]])

In [269]:
nodes[1]

tensor([[1, 1],
        [1, 2],
        [2, 1]])

In [264]:
y_true_enc

tensor([[[1, 2, 0],
         [1, 0, 0],
         [2, 1, 0],
         [2, 1, 0],
         [3, 0, 0]],

        [[1, 2, 1],
         [1, 1, 0],
         [2, 1, 2],
         [2, 1, 0],
         [1, 0, 0]]])

In [240]:

y_flat = torch.flatten(y_true_enc, start_dim=0, end_dim=1) # remove batch dimension to make it simples

node = nodes[-1]
num_nodes = node.shape[0]
probability_matrix = torch.zeros((y_flat.shape[0], num_nodes))
mask = (node[:,None,:]==y_flat)
# populate exact
mask_exact = mask.all(dim=2)
# probability 1 for exact matches else zero
# print(mask_exact)
probability_matrix[mask_exact.T] = 1
# print(probability_matrix)
# now let's treat targets, which are parents of the nodes on this level
print(y_flat.shape)
print(mask.shape)
mask_above = mask[:,:,:-1].all(dim=2) & (~mask[:,:,-1]) 
mask_above = mask_above & (y_flat[:,-1]==0)
# print(mask_above.shape, 'mask_above')
print(1 / mask_above.sum(dim=0).float())
child_probabilities = (1 / mask_above.sum(dim=0).float()).repeat(mask_above.shape[0], 1).T
probability_matrix[mask_above.T] = child_probabilities[mask_above.T]
print(probability_matrix)
# idx = reversed(torch.Tensor(range(1,8)))

# print(mask.shape)
mask_exact = mask.all(dim=2)
mask_missing = (~mask).all(dim=2) # we don't need this, since probs can be just zero

print(mask.shape)
print(mask[2])
# print(mask_missing)
print(mask_above)

torch.Size([10, 3])
torch.Size([4, 10, 3])
tensor([0.5000,    inf, 0.5000, 0.5000,    inf,    inf,    inf,    inf, 0.5000,
           inf])
tensor([[0.5000, 0.5000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5000, 0.5000],
        [0.0000, 0.0000, 0.5000, 0.5000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.0000, 0.0000],
        [0.0000, 0.0000, 0.5000, 0.5000],
        [0.0000, 0.0000, 0.0000, 0.0000]])
torch.Size([4, 10, 3])
tensor([[False, False, False],
        [False, False, False],
        [ True,  True, False],
        [ True,  True, False],
        [False, False, False],
        [False, False, False],
        [False,  True, False],
        [ True,  True,  True],
        [ True,  True, False],
        [False, False, False]])
tensor([[ True, False, False, False, False, False, False, False, False, False],
        [ True, False, Fa

In [146]:
y = y_true_enc[:,:,:2] #level 2
y_flat = torch.flatten(y, start_dim=0, end_dim=1) # remove batch dimension to make it simples

node = nodes[1]
# idx = reversed(torch.Tensor(range(1,8)))

mask = (node[:,None,:]==y_flat)
# print(mask.shape)
mask_exact = mask.all(dim=2)
mask_missing = (~mask).all(dim=2)
mask_above = mask[:,:,:-1].all(dim=2) & (~mask[:,:,-1])
print(mask.shape)
print(mask[0])
print(mask_missing)
print(mask_above)
# indices = torch.nonzero(mask[:,:])
# indices_list = indices[:,0].tolist()
# indices_list += [-1] * (mask.shape[1] - len(indices_list))
# indices_list 
# mask

torch.Size([3, 10, 2])
tensor([[ True, False],
        [ True, False],
        [False,  True],
        [False,  True],
        [False, False],
        [ True, False],
        [ True,  True],
        [False, False],
        [False,  True],
        [ True, False]])
tensor([[False, False, False, False,  True, False, False,  True, False, False],
        [False, False,  True,  True,  True, False, False,  True,  True, False],
        [ True,  True, False, False,  True,  True, False, False, False,  True]])
tensor([[ True,  True, False, False, False,  True, False, False, False,  True],
        [False,  True, False, False, False, False,  True, False, False,  True],
        [False, False, False, False, False, False, False,  True, False, False]])


In [10]:
def get_true_prob(level, y_true_enc):
    """Returns the probability for a given level and target"""
    y_true_enc = y_true_enc[:, :, :level+1]
    
    return y_true_enc
get_true_prob(1, y_true_enc)

tensor([[[1, 2],
         [1, 0],
         [2, 1],
         [2, 1],
         [3, 0]],

        [[1, 2],
         [1, 1],
         [2, 0],
         [2, 1],
         [1, 0]]])

# TODO: figure this part out!!

In [11]:
flat_softmax = FlatSoftmaxMultOP(leaf_nodes, trainable_weights=0)
mappings = flat_softmax.lvl_mappings
# level 1
level1_nodes = mappings[1][0]
level1_node_ints = mappings[1][1]
y_enc2 = y_true_enc[:,:2]
y_enc2 = y_enc2[~(y_enc2==0).any(dim=1),:]
print(y_enc2)
mask = (level1_nodes == y_enc2[:,None,:]).all(dim=-1)
print(mask.shape)
print(mask)

print(level1_nodes)
print(y_true_enc[0][:, :1])
# level1_node_ints.expand()[mask]
print(level1_node_ints)
print(level1_node_ints.expand(4,-1).transpose(0,1))

IndexError: The shape of the mask [2, 3] at index 1 does not match the shape of the indexed tensor [2, 2, 3] at index 1

In [226]:
unique_nodes, _ = torch.unique(leaf_nodes[:,:2], return_inverse=True, dim=0)

In [245]:
y2 = y_true_enc[:,:,:2]
y2

tensor([[[1, 2],
         [1, 0],
         [2, 1],
         [2, 1],
         [3, 0]],

        [[1, 2],
         [1, 1],
         [2, 0],
         [2, 1],
         [1, 0]]])

In [246]:
print(unique_nodes)

tensor([[1, 1],
        [1, 2],
        [2, 1],
        [3, 0]])


In [274]:
y_true_enc[:,:,:2]

tensor([[[1, 2],
         [1, 0],
         [2, 1],
         [2, 1],
         [3, 0]],

        [[1, 2],
         [1, 1],
         [2, 0],
         [2, 1],
         [1, 0]]])

In [275]:
y_enc2 = y_true_enc[:,:,:2]
y_enc2 = y_enc2[~(y_enc2==0).any(dim=2),:]
print(y_enc2)

tensor([[1, 2],
        [2, 1],
        [2, 1],
        [1, 2],
        [1, 1],
        [2, 1]])


In [268]:
mask = (unique_nodes == y_enc2[:,:,None,:]).all(dim=-1)
print(mask.shape)
print(mask)
torch.nonzero(mask)

IndexError: too many indices for tensor of dimension 2

In [12]:
# write a function that takes torch.tensor([1,1,1,2,2]) and returns a matrix that looks like this torch.tensor([[1,1,1,0,0],[0,0,0,1,1]])
def create_leaf_selection_matrix(tensor):
    """This function takes an array of integers and 
    returns a matrix where each row is a one-hot encoded version of the input.
    e,g, [1,1,1,2,2] -> [[1,1,1,0,0],[0,0,0,1,1]]]"""
    unique_values = torch.unique(tensor)
    mask = (tensor.unsqueeze(0) == unique_values.unsqueeze(1)).float()
    return mask

def get_probs_on_level(probs, leaf_nodes, level):
    tensor = leaf_nodes[:, :level]
    selection_matrix = create_leaf_selection_matrix(tensor)
    return selection_matrix@probs

In [13]:
get_probs_on_level(leaf_probs[0], leaf_nodes, 0)

RuntimeError: The size of tensor a (6) must match the size of tensor b (0) at non-singleton dimension 1

 the model should return logits for every leaf node:
 in this case, we will get 6 logits

In [157]:
logits = torch.tensor([1, 1, 1,1,1,1], dtype=torch.float32)

# this is one of the leaf nodes, so we can use the corresponding probability:
target = torch.tensor([1, 1, 1])
loss = compute_loss(logits, target, leafs)
target2 = torch.tensor([1, 1, 0])
# if last entry is zero, we sum over leafs that start with the same integers

loss2 = compute_loss(logits, target2, leafs)
print(loss, loss2)

UnboundLocalError: local variable 'newleafs' referenced before assignment

In [2]:
main_vocab = torch.load(join(data_path,"tokenized\\mimic3\\plain\\vocabulary.pt"))
atc_codes = medical.MedicalCodes().get_atc()
icd_codes = medical.MedicalCodes().get_icd()
codes = []
codes.append('[PAD]')
codes = codes + atc_codes[:5] + icd_codes[:5]

In [3]:
codes

['[PAD]',
 'MA01',
 'MA01A',
 'MA01AA',
 'MA01AA01',
 'MA01AA02',
 'DA00',
 'DA000',
 'DA001',
 'DA009',
 'DA01']

In [4]:
main_vocab = {}
for random_code in codes:
    main_vocab[random_code] = len(main_vocab)
vocab_ls = medical.SKSVocabConstructor(main_vocab)()

In [5]:
pad = []
ma01 = []
ma01aa02  = []
da009 = []
for vocab in vocab_ls:
    pad.append(vocab['[PAD]'])
    da009.append(vocab['DA009'])
    ma01.append(vocab['MA01'])
    ma01aa02.append(vocab['MA01AA02'])
print(pad)
print(da009)
print(ma01)
print(ma01aa02)


[0, 0, 0, 0, 0, 0]
[1, 1, 2, 11, 0, 0]
[2, 1, 3, 0, 0, 0]
[2, 1, 3, 12, 12, 4]


In [192]:
for vocab in vocab_ls:
    print(vocab['DA009'])

1
1
2
11
0
0


In [49]:
vocab = {
    'D':1,
    'M':2,
    'DA00':10,
    'DA01':11,
    'DB99':13,
    'MA00':20,
    'DA000':100,
    'DA001':101,
    'MA000':200,
    'MA001':201,}

In [6]:
def get_ex_pred(ids):
    ex_pred = torch.rand(204)
    for id in ids:
        ex_pred[id] = 100
    return softmax(ex_pred)
def get_targets(target_id):
    targets_ls = []
    while target_id>0:
        targets = torch.zeros(204)
        targets[target_id] = 1
        targets_ls.append(targets)
        target_id = target_id//10
    return targets_ls


In [7]:
# import cross entropy loss
ce = CrossEntropyLoss()
def f_loss(pred, targets):
    loss = 0 
    for i, target in enumerate(targets):
        loss+= 1/(10**i)*float(ce(pred.unsqueeze(0), target.unsqueeze(0)) )
    return loss

In [8]:
target_id = 201
targets = get_targets(target_id)

In [3]:
target = torch.zeros(10)
target[0] = 1

In [24]:
torch.tensor(logits, dtype=torch.float32)

tensor([ 0., 10., 10.,  0.,  5.,  5.])

In [43]:
tokens = ['A', 'B', 'Aa', 'Ab', 'Ba', 'Bb']
logits = [0, 100, 100, 0, 0, 100]
level0_probs = Softmax(dim=0)(torch.tensor(logits[:2], dtype=torch.float32))
level1_probs_A = Softmax(dim=0)(torch.tensor(logits[2:4], dtype=torch.float32))
level1_probs_B = Softmax(dim=0)(torch.tensor(logits[4:], dtype=torch.float32))
print(level0_probs)
print(level1_probs_A)
print(level1_probs_B)

tensor([3.7835e-44, 1.0000e+00])
tensor([1.0000e+00, 3.7835e-44])
tensor([3.7835e-44, 1.0000e+00])


RuntimeError: 0D or 1D target tensor expected, multi-target not supported