# ELMo bi-LM layer

In [86]:
# -*- coding: utf-8 -*-

import h5py
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# *~ coding convention ~*
from typing import Dict, List, Optional, Any, Tuple, Callable, Union
from overrides import overrides

# Python Standard Library
from collections import defaultdict
import collections
import itertools
import logging
import random
import codecs
import math
import json
import os

# Python Installed Library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import PackedSequence
from torch.nn.utils.rnn import pad_packed_sequence
from torch.nn.utils.rnn import pack_padded_sequence

In [2]:
# function: dict -> namedtuple (attribute-style access to config entries)
def dict2namedtuple(dic):
    """Return ``dic`` as an instance of a namedtuple type named ``Namespace``.

    Field names (and their order) are the keys of ``dic``; each field holds
    the matching value.
    """
    namespace_cls = collections.namedtuple('Namespace', list(dic))
    return namespace_cls(**dic)

# Input your directory path here.
# A raw string keeps the Windows backslashes literal. The plain literal only
# worked because `\w`, `\i`, `\e` and `\c` are unrecognized escapes, which
# newer Pythons warn about.
model_dir = r'C:\workspace\implement_elmo\elmo\configs'

# Run-level options from `config.json`, exposed as attributes.
# `with` guarantees the file handle is closed.
with codecs.open(os.path.join(model_dir, 'config.json'), 'r', encoding='utf-8') as fin:
    args2 = dict2namedtuple(json.load(fin))

# args2.config_path == 'cnn_50_100_512_4096_sample.json'

# Load the model config referenced by the run-level options. JSON is read as
# UTF-8 explicitly instead of relying on the platform's locale encoding.
with open(os.path.join(model_dir, args2.config_path), 'r', encoding='utf-8') as fin:
    config = json.load(fin)

# Tensors saved by an earlier step, loaded from the current working directory.
token_embedding = torch.load('token_embedding.pt')
masks = [torch.load(f'mask[{ix}].pt') for ix in range(3)]

## `__init__`

In [3]:
# Statefulness flags (off, no cached states) and CUDA availability.
stateful, _states = False, None
use_cuda = torch.cuda.is_available()

# Encoder hyper-parameters from the model config loaded earlier.
# The LSTM input width and the projected hidden width are both
# `projection_dim`; `dim` is used as the memory-cell width (`cell_size`).
input_size = hidden_size = config['encoder']['projection_dim']
cell_size = config['encoder']['dim']
num_layers = config['encoder']['n_layers']
memory_cell_clip_value = config['encoder']['cell_clip']
state_projection_clip_value = config['encoder']['proj_clip']
recurrent_dropout_probability = config['dropout']

print(f"input_size = {input_size}")
print(f"hidden_size = {hidden_size}")
print(f"cell_size = {cell_size}")
print(f"num_layers = {num_layers}")
print(f"memory_cell_clip_value = {memory_cell_clip_value}")
print(f"state_projection_clip_value = {state_projection_clip_value}")
print(f"recurrent_dropout_probability = {recurrent_dropout_probability}")

# Per-layer cell containers, filled when the layers are built.
forward_layers, backward_layers = [], []

# Input width for the first layer, and the direction flag for forward cells.
lstm_input_size, go_forward = input_size, True

input_size = 512
hidden_size = 512
cell_size = 4096
num_layers = 2
memory_cell_clip_value = 3
state_projection_clip_value = 3
recurrent_dropout_probability = 0.1


In [6]:
class LstmCellWithProjection(nn.Module):
    """
    One direction of an LSTM layer whose hidden state is a projection of the cell.

    Holds three linear maps:
      * ``input_linearity``  : input_size  -> 4 * cell_size (no bias)
      * ``state_linearity``  : hidden_size -> 4 * cell_size (with bias)
      * ``state_projection`` : cell_size   -> hidden_size   (no bias)
    The ``4 * cell_size`` outputs hold all gate pre-activations side by side.
    The direction, dropout probability and clip values are only stored here.
    """

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 cell_size: int,
                 go_forward: bool = True,
                 recurrent_dropout_probability: float = 0.0,
                 memory_cell_clip_value: Optional[float] = None,
                 state_projection_clip_value: Optional[float] = None) -> None:
        super().__init__()
        # Plain configuration attributes.
        self.input_size, self.hidden_size, self.cell_size = input_size, hidden_size, cell_size
        self.go_forward = go_forward
        self.state_projection_clip_value = state_projection_clip_value
        self.memory_cell_clip_value = memory_cell_clip_value
        self.recurrent_dropout_probability = recurrent_dropout_probability

        # All gates are projected in a single matmul, hence the 4x width.
        gates_width = 4 * cell_size
        self.input_linearity = nn.Linear(input_size, gates_width, bias=False)
        self.state_linearity = nn.Linear(hidden_size, gates_width, bias=True)
        # Shrinks the cell-sized state back down to hidden_size.
        self.state_projection = nn.Linear(cell_size, hidden_size, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Block-orthogonal gate weights; zero state bias except the forget gate (1.0).

        ``state_projection`` keeps ``nn.Linear``'s default initialization.
        """
        cell = self.cell_size
        block_orthogonal(self.input_linearity.weight.data, [cell, self.input_size])
        block_orthogonal(self.state_linearity.weight.data, [cell, self.hidden_size])

        bias = self.state_linearity.bias.data
        bias.fill_(0.0)
        # Second cell-sized slice = forget gate, started at 1.0 per Jozefowicz
        # et al. (2015), "An Empirical Exploration of Recurrent Network Architectures".
        bias[cell:2 * cell].fill_(1.0)
        
def block_orthogonal(tensor: torch.Tensor,
                     split_sizes: List[int],
                     gain: float = 1.0) -> None:
    """
    Initialize ``tensor`` in place as a grid of independently orthogonal blocks.

    The tensor is cut into blocks whose shape is given by ``split_sizes`` (one
    entry per dimension). Each block is filled by ``torch.nn.init.orthogonal_``.
    This suits fused weights such as ``LstmCellWithProjection``'s
    ``4 * cell_size``-row gate matrices, where each gate's sub-matrix should be
    orthogonal on its own.

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to initialize; modified in place. It may be a leaf that
        requires grad.
    split_sizes : List[int]
        Block size along each dimension of ``tensor``.
    gain : float, optional (default = 1.0)
        Scaling factor forwarded to ``orthogonal_``.

    Raises
    ------
    ValueError
        If a dimension of ``tensor`` is not divisible by its split size.
    """
    sizes = list(tensor.size())
    if any(size % split != 0 for size, split in zip(sizes, split_sizes)):
        raise ValueError(
            "tensor dimensions must be divisible by their respective "
            f"split_sizes. Found size: {sizes} and split_sizes: {split_sizes}")
    # Start offset of every block along each dimension.
    indexes = [list(range(0, max_size, split))
               for max_size, split in zip(sizes, split_sizes)]
    # In-place writes into a leaf that requires grad are only allowed under
    # no_grad; this also keeps initialization out of any autograd graph.
    with torch.no_grad():
        # Visit every block in the grid.
        for block_start_indices in itertools.product(*indexes):
            block_slice = tuple(slice(start, start + step)
                                for start, step in zip(block_start_indices, split_sizes))
            tensor[block_slice] = nn.init.orthogonal_(
                tensor[block_slice].contiguous(), gain=gain)

In [7]:
# Build the stacked bi-LM: one forward and one backward cell per layer.
# Layer 0 is built with input width `lstm_input_size`; every later layer is
# built with input width `hidden_size`.
for layer_index in range(num_layers):
    forward_layer = LstmCellWithProjection(
        input_size=lstm_input_size,
        hidden_size=hidden_size,
        cell_size=cell_size,
        go_forward=go_forward,
        recurrent_dropout_probability=recurrent_dropout_probability,
        memory_cell_clip_value=memory_cell_clip_value,
        state_projection_clip_value=state_projection_clip_value,
    )
    backward_layer = LstmCellWithProjection(
        input_size=lstm_input_size,
        hidden_size=hidden_size,
        cell_size=cell_size,
        go_forward=not go_forward,
        recurrent_dropout_probability=recurrent_dropout_probability,
        memory_cell_clip_value=memory_cell_clip_value,
        state_projection_clip_value=state_projection_clip_value,
    )
    if use_cuda:
        forward_layer, backward_layer = forward_layer.cuda(), backward_layer.cuda()
    forward_layers.append(forward_layer)
    backward_layers.append(backward_layer)
    lstm_input_size = hidden_size

In [15]:
# Show the per-layer forward cells built above.
forward_layers

[LstmCellWithProjection(
   (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
   (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
   (state_projection): Linear(in_features=4096, out_features=512, bias=False)
 ), LstmCellWithProjection(
   (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
   (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
   (state_projection): Linear(in_features=4096, out_features=512, bias=False)
 )]

In [16]:
# Show the per-layer backward cells (built with go_forward=False).
backward_layers

[LstmCellWithProjection(
   (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
   (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
   (state_projection): Linear(in_features=4096, out_features=512, bias=False)
 ), LstmCellWithProjection(
   (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
   (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
   (state_projection): Linear(in_features=4096, out_features=512, bias=False)
 )]

## 가보자~

In [25]:
# Walk through the embeddings with the first of the three saved masks.
inputs, mask = token_embedding, masks[0]

In [26]:
# Inspect the input embeddings (the same tensor object as `token_embedding`).
inputs

tensor([[[-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
           1.9630e-02, -2.7469e-02],
         [ 2.9012e-02,  4.7038e-03, -2.2742e-02,  ...,  1.6245e-02,
           2.9220e-02, -1.6054e-02],
         [-5.4546e-03, -2.7518e-02, -1.8246e-02,  ...,  9.4597e-03,
           2.9380e-02, -9.5604e-03],
         ...,
         [ 2.1840e-02, -2.4943e-02, -3.1549e-02,  ...,  2.9098e-02,
          -5.3617e-03, -1.8110e-02],
         [ 2.1840e-02, -2.4943e-02, -3.1549e-02,  ...,  2.9098e-02,
          -5.3617e-03, -1.8110e-02],
         [ 2.1840e-02, -2.4943e-02, -3.1549e-02,  ...,  2.9098e-02,
          -5.3617e-03, -1.8110e-02]],

        [[-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
           1.9630e-02, -2.7469e-02],
         [ 5.6490e-03, -2.6852e-02, -2.1564e-02,  ..., -4.3684e-03,
           5.7293e-02, -4.5267e-02],
         [ 2.6103e-02, -4.5548e-03, -1.5987e-02,  ...,  3.1253e-02,
           1.0739e-02, -5.7272e-02],
         ...,
         [ 2.1840e-02, -2

In [27]:
# Inspect the binary mask; each row's 1s are summed into its sequence length below.
mask

tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')

In [28]:
# Check that embeddings and mask agree on their two leading dimensions.
inputs.shape, mask.shape

(torch.Size([3, 10, 512]), torch.Size([3, 10]))

In [30]:
# The mask is 2-D, so its shape unpacks into (batch, time).
batch_size, total_sequence_length = mask.shape
print(f"batch_size = {batch_size}")
print(f"total_sequence_length = {total_sequence_length}")

batch_size = 3
total_sequence_length = 10


In [80]:
# Sort the batch by sequence length (longest first), which
# pack_padded_sequence expects. Also compute the indices that undo the sort.

# Number of rows whose first mask entry is set.
num_valid = torch.sum(mask[:, 0]).int().item()
print(f"num_valid = {num_valid}")

# get_lengths_from_binary_sequence_mask: each row's length is its count of 1s.
sequence_lengths = mask.long().sum(-1)
print(f"sequence_lengths = {sequence_lengths}")

# sort_batch_by_length
assert (isinstance(inputs, Variable) and
        isinstance(sequence_lengths, Variable)), \
        "Both the tensor and sequence lengths must be "\
        "torch.autograd.Variables."

print('Sorting lengths by descending...')
sorted_length_and_permIx = sequence_lengths.sort(0, descending=True)
print("sorted_length_and_permIx = torch.return_types.sort("
      f"\n\t{sorted_length_and_permIx[0]}\t# sorted_sequence_lengths"
      f"\n\t{sorted_length_and_permIx[1]}\t# sorting_indices\n)")
sorted_sequence_lengths, sorting_indices = sorted_length_and_permIx

print("Sorting tensor...")
# Sorted row i is original row sorting_indices[i].
sorted_inputs = inputs.index_select(0, sorting_indices)

# 0..batch_size-1 with the same device and dtype as the lengths.
index_range = torch.arange(len(sequence_lengths),
                           device=sequence_lengths.device,
                           dtype=sequence_lengths.dtype)
print(f"index_range = {index_range}")

# Argsorting the sort permutation yields its inverse. restoration_indices[i]
# is the position of original row i inside the sorted batch. This used the
# undefined name `permutation_index`; the permutation is `sorting_indices`.
_, reverse_mapping = sorting_indices.sort(0, descending=False)
restoration_indices = index_range.index_select(0, reverse_mapping)
print(f"restoration_indices = {restoration_indices}")

num_valid = 3
sequence_lengths = tensor([ 7,  6, 10], device='cuda:0')
Sorting lengths by descending...
sorted_length_and_permIx = torch.return_types.sort(
	tensor([10,  7,  6], device='cuda:0')	# sorted_sequence_lengths
	tensor([2, 0, 1], device='cuda:0')	# sorting_indices
)
Sorting tensor...
index_range = tensor([0, 1, 2], device='cuda:0')
restoration_indices = tensor([1, 2, 0], device='cuda:0')


In [81]:
# Inspect the batch after reordering rows by descending length.
sorted_inputs, sorted_inputs.shape

(tensor([[[-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
            1.9630e-02, -2.7469e-02],
          [ 7.7665e-03, -2.5998e-02, -2.0474e-02,  ...,  1.4450e-02,
            5.0952e-02, -1.0182e-02],
          [ 4.3900e-02, -2.0010e-02, -2.2308e-02,  ...,  1.1431e-02,
            4.4527e-02, -5.6586e-02],
          ...,
          [ 3.9943e-02, -6.3473e-04, -2.4333e-02,  ...,  1.9435e-02,
            5.1162e-02, -3.3509e-02],
          [ 2.2468e-02,  2.0974e-03, -7.3369e-03,  ...,  3.1688e-02,
            3.0309e-02, -4.4693e-02],
          [ 4.0999e-03, -6.1420e-03, -4.0220e-03,  ...,  3.6052e-02,
            3.8271e-02,  8.0397e-05]],
 
         [[-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
            1.9630e-02, -2.7469e-02],
          [ 2.9012e-02,  4.7038e-03, -2.2742e-02,  ...,  1.6245e-02,
            2.9220e-02, -1.6054e-02],
          [-5.4546e-03, -2.7518e-02, -1.8246e-02,  ...,  9.4597e-03,
            2.9380e-02, -9.5604e-03],
          ...,
    

In [82]:
# Sequence lengths after sorting (descending).
sorted_sequence_lengths

tensor([10,  7,  6], device='cuda:0')

In [83]:
# Position of each original row inside the sorted batch (inverse permutation).
restoration_indices

tensor([1, 2, 0], device='cuda:0')

In [84]:
# Sorting permutation: sorted row i is original row sorting_indices[i].
sorting_indices

tensor([2, 0, 1], device='cuda:0')

In [88]:
# pack_padded_sequence: keep only each row's first `length` timesteps, so the
# packed data excludes padded positions. The rows are already length-sorted.
packed_sequence_input = pack_padded_sequence(
    sorted_inputs[:num_valid],
    sorted_sequence_lengths[:num_valid].tolist(),
    batch_first=True,
)
packed_sequence_input

PackedSequence(data=tensor([[-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
          1.9630e-02, -2.7469e-02],
        [-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
          1.9630e-02, -2.7469e-02],
        [-3.3597e-03,  9.7976e-03,  1.4402e-03,  ...,  3.7601e-02,
          1.9630e-02, -2.7469e-02],
        ...,
        [ 3.9943e-02, -6.3473e-04, -2.4333e-02,  ...,  1.9435e-02,
          5.1162e-02, -3.3509e-02],
        [ 2.2468e-02,  2.0974e-03, -7.3369e-03,  ...,  3.1688e-02,
          3.0309e-02, -4.4693e-02],
        [ 4.0999e-03, -6.1420e-03, -4.0220e-03,  ...,  3.6052e-02,
          3.8271e-02,  8.0397e-05]], device='cuda:0',
       grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([3, 3, 3, 3, 3, 3, 2, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [90]:
# Packed data stacks only real timesteps: rows = sum of lengths (10 + 7 + 6 = 23).
packed_sequence_input.data.shape

torch.Size([23, 512])

- 처음일 경우, 아래와 같이 `initial_state`를 만든다.

In [93]:
# No previous hidden state on the first call.
hidden_state = None

In [94]:
# Check the `stateful` flag set during setup (False here).
stateful

False

In [97]:
# The initial state passed to the LSTM is whatever hidden_state holds (None here).
initial_state = hidden_state

- 위에서 초기화한 상태를 가지고 lstm_forward를 실시

In [99]:
# Per-layer initial states: one entry per stacked layer.
if initial_state is None:
    hidden_states = [None] * len(forward_layers)
elif initial_state[0].size()[0] != len(forward_layers):
    # Fail loudly instead of leaving `hidden_states` undefined
    # (or stale from an earlier notebook run).
    raise ValueError("Initial states were passed to forward() but the number of "
                     "initial states does not match the number of layers.")
else:
    # NOTE(review): assumes initial_state = (h, c), each stacked per layer along
    # dim 0 as in AllenNLP's ElmoLstm. Confirm against the caller.
    hidden_states = list(zip(initial_state[0].split(1, 0),
                             initial_state[1].split(1, 0)))
hidden_states

[None, None]

In [102]:
# Unpack back to a padded batch-first tensor plus per-row lengths. Note that
# `inputs` is rebound here, and its rows are now in length-sorted order.
inputs, batch_lengths = pad_packed_sequence(sequence=packed_sequence_input, batch_first=True)

In [107]:
# Padded inputs and per-row lengths, now in length-sorted order.
inputs.shape, batch_lengths

(torch.Size([3, 10, 512]), tensor([10,  7,  6]))

In [108]:
# Both directions start from the same (sorted, padded) input tensor.
forward_output_sequence = backward_output_sequence = inputs

# Accumulators filled while running the layers.
final_states, sequence_outputs = [], []

In [None]:
for layer_index, state in enumerate(hidden_states):
    forward_layer = forward_layers[layer_index]
    backward_layer = backward_layers[layer_index]
    
    forward_cache = forward_output_sequence
    backward_cache = backward_output_sequence
    
    # 맨 처음 실시될 당시에는 state == None임!
    forward_state = None
    backward_state = None
    