In [10]:
# -*- coding: utf-8 -*-

# *~ coding convention ~*
from typing import Optional, Tuple, List, Callable, Union
from overrides import overrides

# Python Standard Library
import collections
import itertools
import logging
import random
import codecs
import json
import os

# Python Installed Library
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence
from torch.nn.utils.rnn import pad_packed_sequence
from torch.nn.utils.rnn import pack_padded_sequence

RnnState = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]  # pylint: disable=invalid-name
RnnStateStorage = Tuple[torch.Tensor, ...]  # pylint: disable=invalid-name

In [2]:
# fuction: dict to namedtuple
def dict2namedtuple(dic):
    return collections.namedtuple('Namespace', dic.keys())(**dic)

# input your directories path
model_dir = 'C:\workspace\implement_elmo\elmo\configs'
args2 = dict2namedtuple(
    json.load(
        codecs.open(
            os.path.join(model_dir, 'config.json'), 
            'r', encoding='utf-8')
    )
)

# args2.config_path == 'cnn_50_100_512_4096_sample.json'

# load config
with open(os.path.join(model_dir, args2.config_path), 'r') as fin:
    config = json.load(fin)
    
token_embedding = torch.load('token_embedding.pt') 
masks = [torch.load(f'mask[{ix}].pt') for ix in range(3)]

In [5]:
token_embedding.shape # batch, seq_len, emb_dim

torch.Size([3, 10, 512])

In [8]:
masks[0].shape

torch.Size([3, 10])

In [22]:
torch.Tensor([1,2]).item()

ValueError: only one element tensors can be converted to Python scalars

In [68]:
class _EncoderBase(nn.Module):
    
    def __init__(self, 
                 stateful: bool=False,
                 batch_first: bool=True) -> None:
        super(_EncoderBase, self).__init__()
        self.stateful = stateful
        self.batch_first = batch_first
        self._states: Optional[RnnStateStorage] = None
    
    def sort_and_run_forward(self,
                             module: Callable[
                                 [PackedSequence, Optional[RnnState]],
                                 Tuple[Union[PackedSequence, torch.Tensor], 
                                       RnnState]],
                             inputs: torch.Tensor,
                             mask: torch.Tensor,
                             hidden_state: Optional[RnnState]=None):
        """First, count how many sequences are empty."""
        if self.batch_first:
            batch_size = mask.size(0)
        else:
            batch_size = mask.size(1)
        # one elt tensor to scalar
        num_valid = torch.sum(mask[:, 0]).int().item()
        
        sequence_lengths = self.get_lengths_from_binary_sequence_mask(mask, self.batch_first)
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices = \
            sort_batch_by_length(inputs, sequence_lengths)
        
        """Now create a PackedSequence with only the non-empty, sorted sequences."""
        # pad token 제외, 유의미한 값들만 packing
        packed_sequence_input = pack_padded_sequence(sorted_inputs[:num_valid, :, :],
                                                     sorted_sequence_lengths[:num_valid].data.tolist(),
                                                     batch_first=self.batch_first)
        
        """Prepare the init states."""
        if not self.stateful:
            if hidden_state == None:
                initial_states = hidden_state # Set None
            elif isinstance(hidden_state, tuple):
                initial_states = [state.index_select(1, sorting_indices)[:,:num_valid,:]
                                  for state in hidden_state]
            else:
                initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)
        else:
            initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)
            
    def _get_initial_states(self,
                            batch_size: int,
                            num_valid: int,
                            sorting_indices: torch.LongTensor) -> Optional[RnnState]:
        """
        RNN의 초기 상태를 반환
        추가적으로, 이 메서드는 batch의 새로운 요소의 초기 상태를 추가하기 위해 상태를 변경하여(mutate)
            호출시 batch size를 처리
        또한 이 메서드는
            1. 배치의 요소 seq. length로 상태를 정렬하는 것과
            2. pad가 끝난 row 제거도 처리
        중요한 것은 현재의 배치 크기가 이전에 호출되었을 때보다 더 크면 이 상태를 "혼합"하는 것이다.

        이 메서드는 (1) 처음 호출되어 아무 상태가 없는 경우 (2) RNN이 heterogeneous state를 가질 때
        의 경우를 처리해야 하기 때문에 return값이 복잡함

        (1) module이 처음 호출됬을 때 ``module``의 타입이 무엇이든 ``None`` 반환
        (2) Otherwise,
            - LSTM의 경우 tuple of ``torch.Tensor``
                Cell state, Hidden state
              shape: ``(num_layers, num_valid, state_size)``
                 and ``(num_layers, num_valid, memory_size)``
            - GRU의 경우  single ``torch.Tensor``
                Hidden state
              shape: ``(num_layers, num_valid, state_size)``
        """
        # 초기 상태는 None
        if self._states is None:
            return None

        # Otherwise, we have some previous states.
        if batch_size > self._states[0].size(1):
            # This batch is larger than the all previous states.
            # If so, resize the states.
            num_states_to_concat = batch_size - self._states[0].size(1)
            resized_states = []
            # state has shape (num_layers, batch_size, hidden_size)
            for state in self._states:
                # This _must_ be inside the loop because some
                # RNNs have states with different last dimension sizes.
                zeros = state.data.new(state.size(0),
                                       num_states_to_concat,
                                       state.size(2)).fill_(0)
                zeros = Variable(zeros)
                resized_states.append(torch.cat([state, zeros], 1))
            self._states = tuple(resized_states)
            correctly_shaped_states = self._states
        elif batch_size < self._states[0].size(1):
            # This batch is smaller than the previous one.
            correctly_shaped_states = tuple(state[:, :batch_size, :] for state in self._states)
        else:
            correctly_shaped_states = self._states

        # At this point, out states are of shape (num_layers, batch_size, hidden_size).
        # However, the encoder uses sorted sequences and additionally removes elements
        # of the batch which are fully padded. We need the states to match up to these
        # sorted and filtered sequences, so we do that in the next two blocks before
        # returning the states.
        if len(self._states) == 1:
            # GRU
            correctly_shaped_state = correctly_shaped_states[0]
            sorted_state = correctly_shaped_state.index_select(1, sorting_indices)
            return sorted_state[:, :num_valid, :]
        else:
            # LSTM
            sorted_states = [state.index_select(1, sorting_indices)
                             for state in correctly_shaped_states]
            return tuple(state[:, :num_valid, :] for state in sorted_states)

    def _update_states(self,
                       final_states: RnnStateStorage,
                       restoration_indices: torch.LongTensor) -> None:
        """
        RNN forward 동작 후에 state를 update
        새로운 state로 update하며 몇 가지 book-keeping을 실시
        즉, 상태를 해제하고 완전히 padding된 state가 업데이트되지 않도록 함
        마지막으로 graph가 매 batch iteration후에 gc되도록 계산 그래프에서
        state variable을 떼어냄.
        """
        # TODO(Mark)L seems weird to sort here, but append zeros in the subclasses.
        # which way around is best?
        print('_EncoderBase의 `_update_states` 메서드 실행')
        print(f'inputs:\nfinal_states = {final_states}\nrestoration_indices = {restoration_indices}')
        new_unsorted_states = [state.index_select(1, restoration_indices)
                               for state in final_states]
        print(f"new_unsorted_states = {new_unsorted_states}")
        print(f"self._states is None = {self._states is None}")
        if self._states is None:
            print("이전 상태가 존재하지 않습니다. new_unsorted_states로 새롭게 만들어 줍니다.")
            # We don't already have states, so just set the
            # ones we receive to be the current state.
            self._states = tuple([Variable(state.data)
                                  for state in new_unsorted_states])
            print('STATES:', self._states)
        else:
            print("이전 상태가 존재합니다. 현재 상태와 입력받은 final_state로 새로운 상태를 update합니다.")
            # Now we've sorted the states back so that they correspond to the original
            # indices, we need to figure out what states we need to update, because if we
            # didn't use a state for a particular row, we want to preserve its state.
            # Thankfully, the rows which are all zero in the state correspond exactly
            # to those which aren't used, so we create masks of shape (new_batch_size,),
            # denoting which states were used in the RNN computation.
            current_state_batch_size = self._states[0].size(1)
            new_state_batch_size = final_states[0].size(1)
            print(f"current_state_batch_size = {current_state_batch_size} = self._states[0].size(1)")
            print(f"new_state_batch_size = {new_state_batch_size} = final_states[0].size(1)")
            # Masks for the unused states of shape (1, new_batch_size, 1)
            used_new_rows_mask = [(state[0, :, :].sum(-1)
                                   != 0.0).float().view(1, new_state_batch_size, 1)
                                  for state in new_unsorted_states]
            new_states = []
            if current_state_batch_size > new_state_batch_size:
                # The new state is smaller than the old one,
                # so just update the indices which we used.
                for old_state, new_state, used_mask in zip(self._states,
                                                           new_unsorted_states,
                                                           used_new_rows_mask):
                    # zero out all rows in the previous state
                    # which _were_ used in the current state.
                    masked_old_state = old_state[:, :new_state_batch_size, :] * (1 - used_mask)
                    # The old state is larger, so update the relevant parts of it.
                    old_state[:, :new_state_batch_size, :] = new_state + masked_old_state
                    # Detatch the Variable.
                    new_states.append(torch.autograd.Variable(old_state.data))
            else:
                # The states are the same size, so we just have to
                # deal with the possibility that some rows weren't used.
                new_states = []
                for old_state, new_state, used_mask in zip(self._states,
                                                           new_unsorted_states,
                                                           used_new_rows_mask):
                    # zero out all rows which _were_ used in the current state.
                    masked_old_state = old_state * (1 - used_mask)
                    # The old state is larger, so update the relevant parts of it.
                    new_state += masked_old_state
                    # Detatch the Variable.
                    new_states.append(torch.autograd.Variable(new_state.data))

            # It looks like there should be another case handled here - when
            # the current_state_batch_size < new_state_batch_size. However,
            # this never happens, because the states themeselves are mutated
            # by appending zeros when calling _get_inital_states, meaning that
            # the new states are either of equal size, or smaller, in the case
            # that there are some unused elements (zero-length) for the RNN computation.
            self._states = tuple(new_states)
            
        
        
    @staticmethod
    def get_lengths_from_binary_sequence_mask(mask: torch.Tensor, batch_first: bool):
        axis = 0 if batch_first else -1
        return mask.sum(axis)
    
    @staticmethod
    def sort_batch_by_length(tensor: torch.autograd.Variable,
                             sequence_lengths: torch.autograd.Variable,
                             batch_idx: int=0):
        if not isinstance(tensor, Variable) or not isinstance(sequence_lengths, Variable):
            raise Exception("Both the tensor and sequence lengths must be torch.autograd.Variables.")
        sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
        sorted_tensor = tensor.index_select(batch_idx, permutation_index)
        _, restoration_indices = permutation_index.sort(0, descending=False)
        return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index

IndentationError: expected an indented block (<ipython-input-68-81686c2ac3b9>, line 83)