# Notebook settings
This notebook is a general-purpose template for training the DistilBERT + adapted HMCN-F model. It accepts any Parquet dataset, given a suitable configuration.

**This version is for datasets with separate Parquets for training, validation and test sets.**

In [1]:
### Dataset configuration
# The main folder. It should be located inside datasets/ and contain three Parquet files or folders: train.parquet, val.parquet and test.parquet.
DATASET_NAME   = 'DBPedia'
# The input text column
TEXT_COL_NAME  = 'text'
# Which column to use as labelled classes. It should be a column of lists of strings.
CLASS_COL_NAME = 'category'
# How many hierarchical levels to work on. Note that the dataset must also have at least this many levels for every example.
DEPTH = 3

### Checkpoint configuration
# Whether to train from scratch or to load a checkpoint
TRAIN_FROM_SCRATCH=True
# Checkpoint iteration to load if not training from scratch
LOAD_ITERATION=0
# Last or best results from that iteration?
LOAD_BEST=True

### System configuration
# Will try to use your NVIDIA GPU if one is available. Set to False to force CPU computation
PREFER_GPU         = True
# If you don't have the huggingface transformers library installed, flip this to True.
# You only need to do this once. Once DistilBERT has been downloaded, it will be cached in your system's default user cache folder.
# Once it is cached, please set this to False to avoid redownloads.
INSTALL_DISTILBERT = False

# Import common libraries
And also set up a few things.

In [2]:
import os
import shutil
import sys

import dask.dataframe as dd
import numpy as np
import torch
from sklearn import metrics
from tqdm.notebook import tqdm

# Set up GPU if available
device = 'cuda' if torch.cuda.is_available() and PREFER_GPU else 'cpu'
print('Using', device)

Using cuda


# Import data

In [3]:
data = dd.read_parquet('../../datasets/{}/train.parquet'.format(DATASET_NAME))
data_val = dd.read_parquet('../../datasets/{}/val.parquet'.format(DATASET_NAME))
data_test = dd.read_parquet('../../datasets/{}/test.parquet'.format(DATASET_NAME))
data.head(10)

Unnamed: 0,text,category
0,"William Alexander Massey (October 7, 1856 – Ma...","[Agent, Politician, Senator]"
1,Lions is the sixth studio album by American ro...,"[Work, MusicalWork, Album]"
2,"Pirqa (Aymara and Quechua for wall, hispaniciz...","[Place, NaturalPlace, Mountain]"
3,Cancer Prevention Research is a biweekly peer-...,"[Work, PeriodicalLiterature, AcademicJournal]"
4,The Princeton University Chapel is located on ...,"[Place, Building, HistoricBuilding]"
5,Sistrurus catenatus edwardsii is a subspecies ...,"[Species, Animal, Reptile]"
6,"The 1st Battalion, 68th Armor Regiment (1–68 A...","[Agent, Organisation, MilitaryUnit]"
7,John Warren Davis (commonly known as J. Warren...,"[Agent, Person, Judge]"
8,"Alfrēds Hartmanis (November 1, 1881, Riga, Lat...","[Agent, Athlete, ChessPlayer]"
9,The International Association of Plumbing and ...,"[Agent, Organisation, TradeUnion]"


# Categorical-encode the classes
Our implementation of HMCN-F uses ordered global-space indices. That is, all first-level classes come first in numerical order, then those on the second level, and so on, with each level's first index starting right after the previous level's last index.

For categorical encoding to work, the columns themselves must be in Dask's `category` datatype, instead of the default `object` type for non-numerical columns.

We'll only do the scanning step over the training set, that is, we'll assume that the training set is decent enough to include at least one example for every class.
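As a minimal illustration of the ordered global-space layout described above, the sketch below uses made-up level sizes (not the real DBPedia counts) and a hypothetical helper, `to_global`, that shifts per-level codes into the global index space.

```python
from itertools import accumulate

# Hypothetical level sizes: 3 classes on level 1, 4 on level 2, 5 on level 3.
level_sizes = [3, 4, 5]

# Each level starts where the previous one ends; the last entry is the total.
level_offsets = [0] + list(accumulate(level_sizes))

def to_global(local_codes, offsets):
    """Shift each per-level (local) code by its level's starting offset."""
    return [code + offsets[level] for level, code in enumerate(local_codes)]

print(level_offsets)                        # [0, 3, 7, 12]
print(to_global([2, 0, 4], level_offsets))  # [2, 3, 11]
```

The trailing entry of `level_offsets` doubles as the total class count, which is what lets later code slice `[offsets[i] : offsets[i + 1]]` without a special case for the last level.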

In [4]:
def preprocess_classes(data, original_name, depth, verbose=False):
    """
    Build a list of unique class names for each level and create bidirectional mappings.
    """
    cls2idx = []
    idx2cls = []
    for i in range(depth): 
        category_li = data[original_name].apply(
            lambda lst: lst[i], meta=(original_name, 'object')
        ).astype('category').cat.as_known()
        if verbose:
            print(category_li.cat.categories)
        cls2idx.append(dict([
            (category, index) 
            for (index, category) 
            in enumerate(category_li.cat.categories)
        ]))
        idx2cls.append(list(category_li.cat.categories))
    return cls2idx, idx2cls

cls2idx, idx2cls = preprocess_classes(data, CLASS_COL_NAME, DEPTH)    
print(cls2idx)
print('\n')
print(idx2cls)

[{'Agent': 0, 'Device': 1, 'Event': 2, 'Place': 3, 'Species': 4, 'SportsSeason': 5, 'TopicalConcept': 6, 'UnitOfWork': 7, 'Work': 8}, {'Actor': 0, 'AmusementParkAttraction': 1, 'Animal': 2, 'Artist': 3, 'Athlete': 4, 'BodyOfWater': 5, 'Boxer': 6, 'BritishRoyalty': 7, 'Broadcaster': 8, 'Building': 9, 'Cartoon': 10, 'CelestialBody': 11, 'Cleric': 12, 'ClericalAdministrativeRegion': 13, 'Coach': 14, 'Comic': 15, 'ComicsCharacter': 16, 'Company': 17, 'Database': 18, 'EducationalInstitution': 19, 'Engine': 20, 'Eukaryote': 21, 'FictionalCharacter': 22, 'FloweringPlant': 23, 'FootballLeagueSeason': 24, 'Genre': 25, 'GridironFootballPlayer': 26, 'Group': 27, 'Horse': 28, 'Infrastructure': 29, 'LegalCase': 30, 'MotorcycleRider': 31, 'MusicalArtist': 32, 'MusicalWork': 33, 'NaturalEvent': 34, 'NaturalPlace': 35, 'Olympics': 36, 'Organisation': 37, 'OrganisationMember': 38, 'PeriodicalLiterature': 39, 'Person': 40, 'Plant': 41, 'Politician': 42, 'Presenter': 43, 'Race': 44, 'RaceTrack': 45, 'Rac

Now we can generate indices to use as class labels for training:

In [5]:
def class_to_index(data, original_name, cls2idx, depth):
    data['codes'] = data[original_name].apply(
        lambda lst: [
            cls2idx[i][cat] 
            for (i, cat) 
            in enumerate(lst[:depth])
        ],
        meta=(original_name, 'object')
    ).astype('object')

class_to_index(data, CLASS_COL_NAME, cls2idx, DEPTH)
class_to_index(data_val, CLASS_COL_NAME, cls2idx, DEPTH)
class_to_index(data_test, CLASS_COL_NAME, cls2idx, DEPTH)

Lastly, binarise them.
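For intuition, binarising here means building an n-hot vector with one active position per hierarchy level. The following pure-Python sketch (toy offsets, hypothetical helper names `to_n_hot` and `from_n_hot`) shows the encoding and how it can be reversed; it is an illustration of the idea, not the notebook's Dask-based implementation.

```python
def to_n_hot(local_codes, offsets):
    """Return a 0/1 list with one active position per hierarchy level."""
    vector = [0] * offsets[-1]
    for level, code in enumerate(local_codes):
        vector[offsets[level] + code] = 1
    return vector

def from_n_hot(vector, offsets):
    """Recover per-level codes by locating the active position in each level's slice."""
    return [
        vector[offsets[level]:offsets[level + 1]].index(1)
        for level in range(len(offsets) - 1)
    ]

# Hypothetical hierarchy: 2 classes on level 1, 3 classes on level 2.
toy_offsets = [0, 2, 5]
encoded = to_n_hot([1, 2], toy_offsets)
print(encoded)                           # [0, 1, 0, 0, 1]
print(from_n_hot(encoded, toy_offsets))  # [1, 2]
```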

In [6]:
from functools import reduce
# HMCN-F needs global-space indices. As such we need to offset the level codes.
# We still make use of the local-space code above to increase commonality with other models.
def index_to_binary(data, index_col_name, offsets, sz, verbose=False):
    if verbose:
        print('Using offsets:', offsets)
    
    def generate_binary(codes):
        b = np.zeros(sz, dtype=int)
        indices = np.array(codes, dtype=int) + offsets[:-1]
        if verbose:
            print(codes, offsets, indices)
        b[indices] = 1
        return b.tolist()
    
    data[index_col_name + '_b'] = data[index_col_name].apply(
        lambda lst: generate_binary(lst),
        meta=(index_col_name + '_b', 'object')
    )
    
level_sizes = [*map(lambda lst: len(lst), idx2cls)]
level_offsets = np.array(reduce(lambda acc, elem: acc + [acc[-1] + elem], level_sizes, [0]))

index_to_binary(data, 'codes', level_offsets, sum(level_sizes), verbose=False)
index_to_binary(data_val, 'codes', level_offsets, sum(level_sizes), verbose=False)
index_to_binary(data_test, 'codes', level_offsets, sum(level_sizes), verbose=False)

data.head(10)

Unnamed: 0,text,category,codes,codes_b
0,"William Alexander Massey (October 7, 1856 – Ma...","[Agent, Politician, Senator]","[0, 42, 185]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,Lions is the sixth studio album by American ro...,"[Work, MusicalWork, Album]","[8, 33, 4]","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
2,"Pirqa (Aymara and Quechua for wall, hispaniciz...","[Place, NaturalPlace, Mountain]","[3, 35, 132]","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,Cancer Prevention Research is a biweekly peer-...,"[Work, PeriodicalLiterature, AcademicJournal]","[8, 39, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
4,The Princeton University Chapel is located on ...,"[Place, Building, HistoricBuilding]","[3, 9, 98]","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
5,Sistrurus catenatus edwardsii is a subspecies ...,"[Species, Animal, Reptile]","[4, 2, 172]","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ..."
6,"The 1st Battalion, 68th Armor Regiment (1–68 A...","[Agent, Organisation, MilitaryUnit]","[0, 37, 126]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
7,John Warren Davis (commonly known as J. Warren...,"[Agent, Person, Judge]","[0, 40, 111]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
8,"Alfrēds Hartmanis (November 1, 1881, Riga, Lat...","[Agent, Athlete, ChessPlayer]","[0, 4, 46]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
9,The International Association of Plumbing and ...,"[Agent, Organisation, TradeUnion]","[0, 37, 210]","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


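We can try recovering category names from this encoding to check that it preserves the original hierarchical order. Note that the index label 3 is not unique in this Dask dataframe, so `data.loc[3]` returns several rows; only the first one is decoded.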

In [7]:
def retrieve_classes(codes, idx2cls):
    return [ idx2cls[i][code] for (i, code) in enumerate(codes) ]

print('Original:', data.loc[3].compute()['category'][0:DEPTH])

print('Retrieved:', retrieve_classes(data['codes'].loc[3].compute().iloc[0], idx2cls))

Original: 3    [Work, PeriodicalLiterature, AcademicJournal]
3          [Agent, EducationalInstitution, School]
3                        [Place, Settlement, Town]
Name: category, dtype: object
Retrieved: ['Work', 'PeriodicalLiterature', 'AcademicJournal']


# Hierarchy generation
In this model, the hierarchical violation penalty is $L_H = \lambda \max(0, Y_{child} - Y_{parent})^2$, summed over all classes: it is non-zero only where a class scores higher than its parent. As such, we need to keep track of each node's parent and vectorise the calculation.

For now I'll be implementing this as simple arrays of category codes (in the global categorical space). We can then use these arrays of codes as vectorised indices to pull out $Y_{parent}$s and have our loss function somewhat vectorised too.
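To make the role of the parent array concrete, here is a small pure-Python sketch of the penalty for a single prediction. It mirrors the clamp-and-square form used in the training loop further down; the toy hierarchy, scores, and the function name `hierarchical_penalty` are assumptions made for illustration only.

```python
def hierarchical_penalty(outputs, parent_of, lambda_h):
    """Sum lambda_h * max(0, child - parent)^2 over every class in one prediction."""
    total = 0.0
    for child_idx, parent_idx in enumerate(parent_of):
        violation = max(0.0, outputs[child_idx] - outputs[parent_idx])
        total += lambda_h * violation ** 2
    return total

# Hypothetical toy hierarchy in global index space: classes 0 and 1 are
# top-level (stored as their own parents), classes 2 and 3 are children of
# class 0, and class 4 is a child of class 1.
toy_parent_of = [0, 1, 0, 0, 1]
toy_outputs = [0.9, 0.2, 0.5, 0.1, 0.6]

# Only class 4 outscores its parent (0.6 vs 0.2), so the result is
# roughly 0.5 * 0.4 ** 2 = 0.08.
penalty = hierarchical_penalty(toy_outputs, toy_parent_of, lambda_h=0.5)
print(penalty)
```

Because top-level classes point at themselves, their violation is always zero, which is why they contribute nothing to the penalty.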

In [8]:
# TODO: Bring the above code into this thing's constructor entirely.
from functools import reduce
class PerLevelHierarchy:
    # data is a dataframe with a 'codes' column of per-level (local) class indices.
    # cls2idx is the list of per-level class-to-index dicts built above.
    #   Its length dictates the maximum hierarchy construction depth.
    # self.classes is the list of distinct classes, in the order we have assembled.
    def __init__(self, data, cls2idx):
        self.levels = [ len(d.keys()) for d in cls2idx ] # TODO: Rename to level_sizes
        self.classes = reduce(lambda acc, elem: acc + elem, [ list(d.keys()) for d in cls2idx ], [])
        # Where each level starts in a global n-hot category vector
        # Its last element is coincidentally the length, which also allows us
        # to simplify the slicing code by blindly doing [offset[i] : offset[i+1]]
        self.level_offsets = reduce(lambda acc, elem: acc + [acc[len(acc) - 1] + elem], self.levels, [0])
        # Use -1 to indicate 'undiscovered'
        self.parent_of = [-1] * len(self.classes)
        for lst in data['codes']:
            # First-level classes' parent is root, but here we set them to themselves.
            # This effectively zeroes out the hierarchical loss for this level.
            self.parent_of[lst[0]] = lst[0]
            for i in range(1, len(self.levels)):
                child_idx = lst[i] + self.level_offsets[i]
                parent_idx = lst[i-1] + self.level_offsets[i - 1]
                if self.parent_of[child_idx] == -1:
                    self.parent_of[child_idx] = parent_idx

In [9]:
hierarchy = PerLevelHierarchy(data, cls2idx)
hierarchy.parent_of

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 0,
 3,
 4,
 0,
 0,
 3,
 0,
 0,
 0,
 3,
 8,
 3,
 0,
 3,
 0,
 8,
 0,
 0,
 8,
 0,
 1,
 4,
 0,
 4,
 5,
 6,
 0,
 0,
 4,
 3,
 7,
 0,
 0,
 8,
 2,
 3,
 2,
 0,
 0,
 8,
 0,
 4,
 0,
 0,
 2,
 3,
 0,
 3,
 3,
 0,
 3,
 2,
 8,
 8,
 3,
 2,
 0,
 0,
 0,
 5,
 3,
 3,
 2,
 3,
 3,
 0,
 0,
 0,
 0,
 8,
 48,
 9,
 26,
 38,
 42,
 15,
 49,
 35,
 11,
 25,
 19,
 11,
 49,
 57,
 42,
 49,
 67,
 13,
 29,
 13,
 36,
 26,
 16,
 65,
 13,
 68,
 65,
 13,
 67,
 74,
 49,
 27,
 11,
 13,
 26,
 56,
 17,
 26,
 49,
 67,
 70,
 13,
 21,
 18,
 44,
 49,
 13,
 21,
 41,
 42,
 23,
 12,
 24,
 12,
 51,
 50,
 60,
 63,
 67,
 13,
 11,
 50,
 75,
 50,
 53,
 67,
 13,
 38,
 13,
 22,
 43,
 49,
 60,
 49,
 58,
 62,
 12,
 50,
 75,
 60,
 11,
 64,
 55,
 30,
 13,
 20,
 44,
 63,
 13,
 71,
 51,
 64,
 32,
 50,
 13,
 13,
 67,
 77,
 18,
 67,
 19,
 53,
 13,
 49,
 18,
 18,
 65,
 75,
 11,
 13,
 49,
 49,
 13,
 14,
 26,
 46,
 28,
 72,
 48,
 24,
 13,
 51,
 58,
 51,
 60,
 49,
 46,
 64,
 49,
 11,
 49,
 50,
 44,
 44,
 44,
 18,
 60,


## Checkpoints

In [10]:
def load_checkpoint(checkpoint_fpath, model):
    encoder, classifier = model
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    checkpoint = torch.load(checkpoint_fpath, map_location=device)
    classifier.load_state_dict(checkpoint['state_dict'])
    return (encoder, classifier)

def save_checkpoint(state, is_best, checkpoint_path, best_model_path):
    f_path = checkpoint_path
    torch.save(state, f_path)
    if is_best:
        best_fpath = best_model_path
        shutil.copyfile(f_path, best_fpath)

## Metrics
We define hierarchical accuracy as simply the averaged accuracy over each level. Same for precision. With HMCN-F, we use the final output $P_F$.
In addition to those, at the end of the testing phase we'll also compute the average area under the precision-recall curve (AU(PRC)).
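To make the per-level averaging concrete, the sketch below computes level-wise accuracy from a global output vector and then averages it. It is a pure-Python illustration with toy scores; `per_level_accuracy` is a hypothetical helper, not the notebook's `get_metrics`.

```python
from itertools import accumulate

def per_level_accuracy(outputs, targets, level_sizes):
    """Accuracy of the argmax prediction inside each level's slice of the global vector."""
    offsets = [0] + list(accumulate(level_sizes))
    accuracies = []
    for level in range(len(level_sizes)):
        start, end = offsets[level], offsets[level + 1]
        correct = 0
        for scores, target in zip(outputs, targets):
            level_scores = scores[start:end]
            predicted = level_scores.index(max(level_scores))
            actual = target[start:end].index(1)
            correct += predicted == actual
        accuracies.append(correct / len(outputs))
    return accuracies

# Hypothetical scores for two examples over levels of size 2 and 3.
toy_outputs = [[0.8, 0.1, 0.2, 0.7, 0.1],
               [0.3, 0.6, 0.5, 0.2, 0.3]]
toy_targets = [[1, 0, 0, 1, 0],
               [0, 1, 0, 0, 1]]

level_acc = per_level_accuracy(toy_outputs, toy_targets, [2, 3])
print(level_acc)                        # [1.0, 0.5]
print(sum(level_acc) / len(level_acc))  # 0.75
```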

In [11]:
def get_metrics(outputs, targets, level_sizes, print_metrics=True):
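    # Cumulative offsets: where each level starts in the global output vector
    offsets = np.concatenate([[0], np.cumsum(level_sizes)])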
    level_codes = [ 
        np.argmax(outputs[:, offsets[level] : offsets[level + 1]], axis=1) + offsets[level] 
        for level in range(len(level_sizes))
    ]
    
    target_codes = np.array([ np.nonzero(lst)[0] for lst in targets ], dtype=int)
    
    # sklearn metrics expect (y_true, y_pred), in that order
    accuracies = [ metrics.accuracy_score(target_codes[:, level], level_codes[level]) for level in range(len(level_sizes)) ]
    precisions = [ metrics.precision_score(target_codes[:, level], level_codes[level], average='weighted') for level in range(len(level_sizes)) ]
    
    global_accuracy = sum(accuracies)/len(accuracies)
    global_precision = sum(precisions)/len(precisions)
    
    if print_metrics:
        for i in range(len(level_sizes)):
            print('Level {}:'.format(i))
            print("Accuracy:", accuracies[i])
            # Precision: what fraction of the examples predicted as a class actually belong to it?
            print("Precision:", precisions[i],'\n')
        print('Path average:')
        print('Accuracy:', global_accuracy)
        print('Precision:', global_precision)
    
    return np.array([accuracies[-1], precisions[-1], global_accuracy, global_precision])

# Data and model preparation

## Installing DistilBERT
DistilBERT is a lighter alternative to full-size BERT that roughly matches its performance while being faster.

In [12]:
if not INSTALL_DISTILBERT:
    os.environ['TRANSFORMERS_OFFLINE'] = '1'
else:
    !pip install transformers
    
import transformers as ppb
tokenizer = ppb.DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
base_encoder = ppb.DistilBertModel.from_pretrained('distilbert-base-uncased')
base_encoder_state = base_encoder.state_dict()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Define our dataset adapter class
This wraps around our data and provides a PyTorch-compatible interface.

In [13]:
from torch.utils.data import IterableDataset
class CustomDataset(IterableDataset):
    def __init__(self, df, hierarchy, tokenizer, max_len, text_col_name = TEXT_COL_NAME):
        self.tokenizer = tokenizer
        # Single-pass iterator over the rows; build a new dataset for every epoch.
        self.iterator = df.itertuples()
        self.text_col_name = text_col_name
        # Level sizes
        self.levels = hierarchy.levels
        self.level_offsets = hierarchy.level_offsets
        self.max_len = max_len

    def __iter__(self):
        return self

    def __next__(self):
        row = next(self.iterator)
        text = str(getattr(row, self.text_col_name))
        # Collapse runs of whitespace into single spaces
        text = " ".join(text.split())
        inputs = self.tokenizer(
            text,
            None, # No text_pair
            add_special_tokens=True, # CLS, SEP
            max_length=self.max_len, # For us it's a hyperparam. See next cells.
            padding='max_length',
            truncation=True
            # BERT tokenisers return attention masks by default
        )

        result = {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            # Binarised global-space labels produced by index_to_binary()
            'labels': torch.tensor(getattr(row, 'codes_b'), dtype=torch.long),
        }

        return result

Regarding that `max_len` hyperparameter, let's see the text lengths' distribution:

In [None]:
data[TEXT_COL_NAME].apply(lambda s: len(s.split())).compute().hist()

We prefer `max_len` to be a power of two that covers most of the strings.
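One way to turn that preference into a number is to take a high quantile of the length distribution and round it up to the next power of two. The sketch below does this in pure Python; the helper name `suggest_max_len`, the default coverage of 95%, and the toy lengths are all assumptions. Note also that the histogram above counts whitespace-separated words, whereas `max_length` is measured in tokens, so the result is only a rough guide.

```python
import math

def suggest_max_len(lengths, coverage=0.95):
    """Smallest power of two that is >= the given quantile of the lengths."""
    ordered = sorted(lengths)
    # Index of the first length that covers the requested share of examples.
    cutoff_idx = max(0, math.ceil(coverage * len(ordered)) - 1)
    cutoff = max(1, ordered[cutoff_idx])
    # Round up to the next power of two.
    return 1 << (cutoff - 1).bit_length()

toy_lengths = [12, 30, 41, 45, 50, 52, 55, 58, 60, 300]
print(suggest_max_len(toy_lengths, coverage=0.9))  # 64
print(suggest_max_len(toy_lengths, coverage=1.0))  # 512
```

The toy data shows why full coverage is rarely worth it: a single outlier of 300 words would push the padded length from 64 to 512.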

In [None]:
###
### TUNING HYPERPARAMETERS?
### Simply adjust here then run this cell and those below it. No need to run those above.
###

folder_name = 'checkpoints-' + DATASET_NAME
# Create the checkpoint folder if it does not exist yet (no error on re-runs)
os.makedirs(folder_name, exist_ok=True)
CHECKPOINT_IDX = len(os.listdir(folder_name)) // 2
CHECKPOINT_PATH = './{}/{}_current.pt'.format(folder_name, CHECKPOINT_IDX)
BEST_CHECKPOINT_PATH = './{}/{}_best.pt'.format(folder_name, CHECKPOINT_IDX)

config = {
    'cls_lr': 1e-03,
    'lambda_h': 0.7,
    'epochs': 5,
    'dropout': 0.25,
    'global_hidden_sizes': [384] * len(hierarchy.levels),
    'local_hidden_sizes': [384] * len(hierarchy.levels),
    'global_weight': 0.5,
    'hidden_nonlinear': 'relu'
}

### Don't change these if you need to compare with published results
MAX_LEN = 64
TRAIN_MINIBATCH_SIZE = 4
VAL_TEST_MINIBATCH_SIZE = 64

# Don't touch this part
train_minibatch_count = len(data) // TRAIN_MINIBATCH_SIZE
val_minibatch_count = len(data_val) // VAL_TEST_MINIBATCH_SIZE
test_minibatch_count = len(data_test) // VAL_TEST_MINIBATCH_SIZE

## Prepare the model itself
Here we use DistilBERT as the encoding layers, followed by our implementation of HMCN-F.

### HMCN-F
HMCN-F is specifically designed for maximizing the learning capacity regarding the hierarchical structure of the labeled data.

In this model, information flows in two ways: i) the main flow, which begins with the input layer and traverses all fully-connected (FC) layers until it reaches the global output; and ii) the local flows, which also begin in the input layer and pass by their respective global FC layers but also through specific local FC layers, finally ending at the corresponding local output. For generating the final prediction, all local outputs are then concatenated and pooled with the global output for a consensual prediction.

Code2paper notation mapping:
- `feature_size` = $|D|$
- `global_hidden_sizes` = list of $|A^i_G|$ for i in $[1, |H|]$
- `local_hidden_sizes` = list of $|A^i_L|$ for i in $[1, |H|]$
- hierarchy:
  - `len(hierarchy.levels)` = $|H|$
  - `len(hierarchy.classes)` = $|C|$
- `global_weight` = $\beta$
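The final consensual prediction can be pictured as a weighted blend of the two views. The sketch below follows the form used in this notebook's `forward()` method (global output weighted by `global_weight`, concatenated local outputs by the remainder); the function name `combine_outputs` and the toy probabilities are assumptions for illustration.

```python
def combine_outputs(global_output, local_outputs, global_weight):
    """Blend the global prediction with the concatenated per-level local predictions."""
    local_concat = [p for level_output in local_outputs for p in level_output]
    # Both views must cover every class in the hierarchy.
    assert len(local_concat) == len(global_output)
    return [
        global_weight * glob + (1 - global_weight) * loc
        for glob, loc in zip(global_output, local_concat)
    ]

# Hypothetical probabilities for a hierarchy with 2 + 3 classes.
toy_global = [0.8, 0.2, 0.6, 0.3, 0.1]
toy_local = [[0.6, 0.4], [0.2, 0.7, 0.1]]

# With equal weighting, each entry is the mean of the two views,
# i.e. roughly [0.7, 0.3, 0.4, 0.5, 0.1].
combined = combine_outputs(toy_global, toy_local, global_weight=0.5)
print(combined)
```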

**One significant difference between our version and the one in the paper** is that we replace batch normalisation with layer normalisation, which doesn't wreak havoc on NLP tasks like ours.

In the paper, the FC (linear) layers comprise 384 ReLU neurons, followed by batch normalization, residual connections, and dropout of 60%. Dropout is important given that these models could easily overfit the small training sets. Here, the dropout rate is set through `config['dropout']`.

In [None]:
from tqdm.notebook import tqdm

class HMCNF(torch.nn.Module):
  def __init__(
      self, 
      input_dim, 
      hierarchy,
      config,
      ):
    super(HMCNF, self).__init__()

    # Back up some parameters for use in forward()
    self.depth = len(hierarchy.levels)
    self.global_weight = config['global_weight']

    # Construct global layers (main flow)
    global_layers = []
    global_layer_norms = []
    for i in range(len(hierarchy.levels)):
      if i == 0:
        global_layers.append(
            torch.nn.Linear(input_dim, config['global_hidden_sizes'][0]))
      else:
        global_layers.append(
            torch.nn.Linear(config['global_hidden_sizes'][i-1] + input_dim, config['global_hidden_sizes'][i]))
      global_layer_norms.append(torch.nn.LayerNorm(config['global_hidden_sizes'][i]))
    self.global_layers = torch.nn.ModuleList(global_layers)
    self.global_layer_norms = torch.nn.ModuleList(global_layer_norms)
    # Global prediction layer
    self.global_prediction_layer = torch.nn.Linear(
        config['global_hidden_sizes'][-1] + input_dim, 
        len(hierarchy.classes)
        )
    
    # Construct local branches (local flow).
    # Each local branch has two linear layers: a transition layer and a local
    # classification layer 
    transition_layers = []
    local_layer_norms = []
    local_layers = []
    
    for i in range(len(hierarchy.levels)):
      transition_layers.append(
          torch.nn.Linear(config['global_hidden_sizes'][i], config['local_hidden_sizes'][i]),
      )
      local_layer_norms.append(
          torch.nn.LayerNorm(config['local_hidden_sizes'][i])
      )
      local_layers.append(
          torch.nn.Linear(config['local_hidden_sizes'][i], hierarchy.levels[i])
      )
    self.local_layer_norms = torch.nn.ModuleList(local_layer_norms)
    self.transition_layers = torch.nn.ModuleList(transition_layers)
    self.local_layers = torch.nn.ModuleList(local_layers)
    
    # Activation functions
    self.hidden_nonlinear = torch.nn.ReLU() if config['hidden_nonlinear'] == 'relu' else torch.nn.Tanh()
    self.output_nonlinear = torch.nn.Sigmoid()

    # Dropout
    self.dropout = torch.nn.Dropout(p=config['dropout'])

  def forward(self, x):
    # We have |H| hidden layers (one per hierarchy level) plus one global prediction layer
    local_outputs = []
    output = x # Would be global path output until the last step
    for i in range(len(self.global_layers)):
      # Global path
      if i == 0:
        # Don't concatenate x into the first layer's input
        output = self.hidden_nonlinear(
            self.global_layer_norms[i](
                self.global_layers[i](output)
                )
            )
      else:
        output = self.hidden_nonlinear(self.global_layer_norms[i](
                self.global_layers[i](torch.cat([output, x], dim=1))
              )
            )

      # Local path. Note the dropout between the transition ReLU layer and the local layer.
      local_output = self.dropout(
          self.hidden_nonlinear(
              self.local_layer_norms[i](self.transition_layers[i](output))
              )
            )
      local_output = self.output_nonlinear(self.local_layers[i](local_output))
      local_outputs.append(local_output)

      # Dropout main flow for next layer
      # TODO: investigate introducing layernorm here
      output = self.dropout(output)

    global_outputs = self.output_nonlinear(
      self.global_prediction_layer(torch.cat([output, x], dim=1))
      )
    local_outputs_concat = torch.cat(local_outputs, dim=1)
    output = self.global_weight * global_outputs + (1 - self.global_weight) * local_outputs_concat
    return output, local_outputs

### Entire model

In [None]:
encoder = base_encoder
encoder.load_state_dict(base_encoder_state)
encoder.to(device)

depth = len(hierarchy.levels)
classifier = HMCNF(
  768, # DistilBERT outputs 768 values.
  hierarchy,
  config
)

classifier.to(device)

# Training time

In [None]:
from tqdm.notebook import tqdm

def train_model(config, data, data_val, model, hierarchy, checkpoint_path, best_checkpoint_path):
  encoder, classifier = model

  # TODO: Somehow do this on GPU
  parent_of = torch.LongTensor(hierarchy.parent_of)

  # Store validation metrics after each epoch
  val_metrics = np.empty((4, 0), dtype=float)

  # Keep the minimum validation loss so we can separately back up our best-yet model
  val_loss_min = np.inf

  criterion = torch.nn.BCELoss()
  optimizer = torch.optim.Adam(params=classifier.parameters(), lr=config['cls_lr'])

  # Hierarchical loss gain
  lambda_h = config['lambda_h']
  for epoch in range(1, config['epochs'] + 1):
    train_set = CustomDataset(data, hierarchy, tokenizer, MAX_LEN)
    val_set = CustomDataset(data_val, hierarchy, tokenizer, MAX_LEN)

    train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=TRAIN_MINIBATCH_SIZE)
    val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=VAL_TEST_MINIBATCH_SIZE)
    train_loss = 0
    val_loss = 0
    # Put model into training mode. Note that this call DOES NOT train it yet.
    classifier.train()
    print('Epoch {}: Training'.format(epoch))
    for batch_idx, batch in enumerate(tqdm(train_loader)):
      ids = batch['ids'].to(device, dtype = torch.long)
      mask = batch['mask'].to(device, dtype = torch.long)
      targets = batch['labels'].to(device, dtype = torch.float)

      features = encoder(ids, mask)[0][:,0,:]
      output, local_outputs = classifier(features)

      optimizer.zero_grad()

      # We have three loss functions: (g)lobal, (l)ocal, and (h)ierarchical.
      loss_g = criterion(output, targets)
      loss_l = sum([ criterion(
            local_outputs[level],
            targets[:, hierarchy.level_offsets[level] : hierarchy.level_offsets[level + 1]]
            ) for level in range(len(hierarchy.levels))])
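      # Compute the hierarchical penalty on the live output rather than a detached
      # copy, so that its gradient actually reaches the classifier.
      parent_output = output.index_select(1, parent_of.to(output.device))
      loss_h = lambda_h * torch.sum(torch.clamp(output - parent_output, min=0) ** 2)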
      loss = loss_g + loss_l + loss_h

      # PyTorch defaults to accumulating gradients, but we don't need that here
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      train_loss = train_loss + (loss.item() - train_loss) / (batch_idx + 1)

    print('Epoch {}: Testing'.format(epoch))
    
    
    # Switch to evaluation (prediction) mode. Again, this doesn't evaluate anything.
    classifier.eval()

    val_targets = np.empty((0, len(hierarchy.classes)), dtype=bool)
    val_outputs = np.empty((0, len(hierarchy.classes)), dtype=float)
    local_val_outputs = np.empty((0, len(hierarchy.classes)), dtype=float)

    # We're only testing here, so don't run the backward direction (no_grad).
    with torch.no_grad():
      total_loss_g = 0
      total_loss_l = 0
      total_loss_h = 0
      for batch_idx, batch in enumerate(tqdm(val_loader)):
        ids = batch['ids'].to(device, dtype = torch.long)
        mask = batch['mask'].to(device, dtype = torch.long)
        targets = batch['labels'].to(device, dtype = torch.float)

        features = encoder(ids, mask)[0][:,0,:]
        output, local_outputs = classifier(features)

        loss_g = criterion(output, targets)
        loss_l = sum([ criterion(
              local_outputs[level],
              targets[:, hierarchy.level_offsets[level] : hierarchy.level_offsets[level + 1]]
              ) for level in range(len(hierarchy.levels))])
        output_cpu = output.cpu().detach()
        loss_h = torch.sum(lambda_h * torch.clamp(torch.FloatTensor(
          output_cpu - 
          output_cpu.index_select(1, parent_of)
        ), min=0) ** 2)
        loss = loss_g + loss_l + loss_h

        total_loss_g += loss_g
        total_loss_l += loss_l
        total_loss_h += loss_h

        val_loss = val_loss + (loss.item() - val_loss) / (batch_idx + 1)

        val_targets = np.concatenate([val_targets, targets.cpu().detach().numpy()])
        val_outputs = np.concatenate([val_outputs, output_cpu.numpy()])
        # Concatenate local test outputs
        local_val_outputs_concat = np.concatenate([*map(lambda t: t.cpu().detach().numpy(), local_outputs)], axis=1)
        local_val_outputs = np.concatenate([local_val_outputs, local_val_outputs_concat])

      # calculate average losses
      #print('before cal avg train loss', train_loss)
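      # The loader wraps an IterableDataset without __len__, so count batches directly
      num_val_batches = batch_idx + 1
      print('Average minibatch global loss:', total_loss_g / num_val_batches)
      print('Average minibatch local loss:', total_loss_l / num_val_batches)
      print('Average minibatch hierarchical loss:', total_loss_h / num_val_batches)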
    
      val_metrics = np.concatenate(
          [
            val_metrics, 
            np.expand_dims(
              get_metrics(val_outputs, val_targets, hierarchy.levels), 
              axis=1
            )
          ],
          axis=1
      )
      # train_loss and val_loss are already running means over minibatches, so no
      # further division is needed (len() is also undefined for these loaders).
      # Print training/validation statistics
      print('Average training loss: {:.6f}\nAverage validation loss: {:.6f}'.format(
            train_loss,
            val_loss
            ))

      # create checkpoint variable and add important data
      checkpoint = {
            'state_dict': classifier.state_dict(),
            'optimizer': optimizer.state_dict()
      }

      best_yet = False
      if val_loss <= val_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(val_loss_min,val_loss))
        # save checkpoint as best model
        best_yet = True
        val_loss_min = val_loss
      save_checkpoint(checkpoint, best_yet, checkpoint_path, best_checkpoint_path)
    print('Epoch {}: Done\n'.format(epoch))
  return (encoder, classifier), val_metrics

# Run a trained (or loaded-from-disk) model over a data loader and collect its outputs
def run_model(model, loader, hierarchy):
  encoder, classifier = model
  # Switch to evaluation (prediction) mode. Again, this doesn't evaluate anything.
  classifier.eval()

  all_targets = np.empty((0, len(hierarchy.classes)), dtype=bool)
  all_outputs = np.empty((0, len(hierarchy.classes)), dtype=float)
  all_local_outputs = np.empty((0, len(hierarchy.classes)), dtype=float)

  # We're only testing here, so don't run the backward direction (no_grad).
  with torch.no_grad():
    for batch_idx, batch in enumerate(tqdm(loader)):
      ids = batch['ids'].to(device, dtype = torch.long)
      mask = batch['mask'].to(device, dtype = torch.long)
      targets = batch['labels']

      features = encoder(ids, mask)[0][:,0,:]
      output, local_outputs = classifier(features)

      all_targets = np.concatenate([all_targets, targets.numpy()])
      all_outputs = np.concatenate([all_outputs, output.cpu().detach().numpy()])
      # Concatenate local test outputs
      local_outputs_concat = np.concatenate([*map(lambda t: t.cpu().detach().numpy(), local_outputs)], axis=1)
      all_local_outputs = np.concatenate([all_local_outputs, local_outputs_concat])
  return {
      'targets': all_targets,
      'outputs': all_outputs,
      'local_outputs': all_local_outputs,
  }

In [None]:
import matplotlib.pyplot as plt

trained_model = None
if TRAIN_FROM_SCRATCH:
    trained_model, val_metrics = train_model(
        config,
        data,
        data_val,
        (encoder, classifier),
        hierarchy,
        CHECKPOINT_PATH,
        BEST_CHECKPOINT_PATH,
    )
    x = np.arange(config['epochs'])
    fig, ax = plt.subplots()  # Create a figure and an axes.
    ax.plot(x, val_metrics[0], label='leaf accuracy')
    ax.plot(x, val_metrics[1], label='leaf precision')
    ax.plot(x, val_metrics[2], label='average global accuracy')
    ax.plot(x, val_metrics[3], label='average global precision')
    ax.set_xlabel('epoch')  # Add an x-label to the axes.
    ax.set_ylabel('score')  # Add a y-label to the axes.
    ax.set_title("Accuracy/precision over epochs")  # Add a title to the axes.
    ax.legend()  # Add a legend.
    fig.show()
else:
    load_path = '{}/{}_{}.pt'.format(folder_name, LOAD_ITERATION, 'best' if LOAD_BEST else 'current')
    trained_model = load_checkpoint(load_path, (encoder, classifier))

In [None]:
test_set = CustomDataset(data_test, hierarchy, tokenizer, MAX_LEN)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=VAL_TEST_MINIBATCH_SIZE)
test_result = run_model(trained_model, test_loader, hierarchy)

# Evaluation
We'll mainly use the leaf prediction in real-world applications to ensure 100% hierarchy matches. However, we'll still test with the global encoding just to see what we are getting.
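The notebook does not show how a full path would be derived from a leaf prediction. One possible approach, assuming the `parent_of` convention from `PerLevelHierarchy` (global indices, top-level classes stored as their own parents, and no undiscovered `-1` entries), is to walk the parent links upwards. The helper name `path_from_leaf` and the toy hierarchy are hypothetical.

```python
def path_from_leaf(leaf_idx, parent_of):
    """Walk parent links from a leaf up to a top-level class; return root-to-leaf order."""
    path = [leaf_idx]
    # Top-level classes are stored as their own parents, which ends the walk.
    while parent_of[path[-1]] != path[-1]:
        path.append(parent_of[path[-1]])
    return path[::-1]

# Hypothetical 3-level hierarchy in global index space:
# 0 and 1 are top-level; 2 and 3 sit under 0, 4 under 1; 5 sits under 2, 6 under 4.
toy_parent_of = [0, 1, 0, 0, 1, 2, 4]
print(path_from_leaf(6, toy_parent_of))  # [1, 4, 6]
print(path_from_leaf(5, toy_parent_of))  # [0, 2, 5]
```

A path built this way is consistent with the hierarchy by construction, which is the sense in which leaf-based prediction guarantees matching ancestors.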

In [None]:
get_metrics(test_result['outputs'], test_result['targets'], level_sizes)

# Rectified leaf AU(PRC) due to an sklearn bug.
# We add one artificial example that belongs to all classes at once and a corresponding prediction
# full of true positives. This way each class has at least one true positive, even if the test set
# does not contain enough examples to cover all classes.
rectified_outputs = np.concatenate([test_result['outputs'][:, level_offsets[-2]:], np.ones((1, level_sizes[-1]))], axis=0)
rectified_targets = np.concatenate([test_result['targets'][:, level_offsets[-2]:], np.ones((1, level_sizes[-1]), dtype=bool)], axis=0)

print('\n')
print('Rectified leaf-level AU(PRC) score:', metrics.average_precision_score(rectified_targets, rectified_outputs))

## Hierarchical predictions
Let's have another visual match-up, but this time for the entire hierarchy.

In [None]:
path_codes = np.concatenate([
    np.expand_dims(
        np.argmax(test_result['outputs'][:, hierarchy.level_offsets[level] : hierarchy.level_offsets[level + 1]], axis=1),
        axis=1
    ) for level in range(len(hierarchy.levels))],
    axis=1
)

target_codes = np.array([ np.nonzero(lst)[0] - hierarchy.level_offsets[:-1] for lst in test_result['targets'] ], dtype=int)
print(path_codes.shape)
print(target_codes.shape)

In [None]:
predicted_classes = [retrieve_classes(row, idx2cls) for row in tqdm(path_codes)]
actual_classes = [retrieve_classes(row, idx2cls) for row in tqdm(target_codes)]
import pandas as pd
comp_df = pd.DataFrame({ 'Hierarchical prediction': predicted_classes, 'Actual hierarchy': actual_classes})
comp_df

# Past results
- `_1`: Equal to `_9` in Walmart_Marketing. However, we now use 10% validation 10% test instead of 20% test.
