In [1]:
# -*- coding: utf-8 -*-

# *~ coding convention ~*
from overrides import overrides
from typing import Callable

# Python Standard Library
import collections
import itertools
import logging
import random
import codecs
import json
import os

# Python Installed Library
import torch
import torch.nn as nn
from torch.autograd import Variable

In [2]:
# fuction: dict to namedtuple
def dict2namedtuple(dic):
    return collections.namedtuple('Namespace', dic.keys())(**dic)

# input your directories path
model_dir = 'C:\workspace\implement_elmo\elmo\configs'
args2 = dict2namedtuple(
    json.load(
        codecs.open(
            os.path.join(model_dir, 'config.json'), 
            'r', encoding='utf-8')
    )
)

# args2.config_path == 'cnn_50_100_512_4096_sample.json'

# load config
with open(os.path.join(model_dir, args2.config_path), 'r') as fin:
    config = json.load(fin)
    
token_embedding = torch.load('token_embedding.pt') 
masks = [torch.load(f'mask[{ix}].pt') for ix in range(3)]

초기 세팅

In [3]:
# _EncoderBase
stateful = False
_states = None
use_cuda = torch.cuda.is_available()

input_size = config['encoder']['projection_dim']
hidden_size = config['encoder']['projection_dim']
cell_size = config['encoder']['dim']
num_layers = config['encoder']['n_layers']
memory_cell_clip_value = config['encoder']['cell_clip']
state_projection_clip_value = config['encoder']['proj_clip']
recurrent_dropout_probability = config['dropout']

print(f"input_size = {input_size}")
print(f"hidden_size = {hidden_size}")
print(f"cell_size = {cell_size}")
print(f"num_layers = {num_layers}")
print(f"memory_cell_clip_value = {memory_cell_clip_value}")
print(f"state_projection_clip_value = {state_projection_clip_value}")
print(f"recurrent_dropout_probability = {config['dropout']}")

forward_layers = []
backward_layers = []

lstm_input_size = input_size
go_forward = True

input_size = 512
hidden_size = 512
cell_size = 4096
num_layers = 2
memory_cell_clip_value = 3
state_projection_clip_value = 3
recurrent_dropout_probability = 0.1


In [80]:
from collections import defaultdict
from typing import Dict, List, Optional, Any, Tuple, Callable
import logging
import itertools
import math
import torch
from torch.autograd import Variable

def get_lengths_from_binary_sequence_mask(mask: torch.Tensor):
    return mask.long().sum(-1)

def get_dropout_mask(dropout_probability: float,
                     tensor_for_masking: Variable):
    print('*-*** get_dropout_mask ***-*')
    binary_mask = tensor_for_masking.clone()
    print('binary_mask', binary_mask)
    binary_mask.data.copy_(torch.rand(tensor_for_masking.size()) > dropout_probability)
    print(f'binary_mask = {torch.rand(tensor_for_masking.size()) > dropout_probability}')
    dropout_mask = binary_mask.float().div(1.0 - dropout_probability)
    print(f"Calc 1.0 / (1 - p) or 0.0")
    print(f"dropout_mask = {dropout_mask}")
    print('*-*** ---------------- ***-*')
    return dropout_mask

def block_orthogonal(tensor: torch.Tensor,
                     split_sizes: List[int],
                     gain: float = 1.0) -> None:
    """
    An initializer which allows initaliizing model parametes in "block".
    """
    if isinstance(tensor, Variable):
    # in pytorch 4.0, Variable equals Tensor
    #     block_orthogonal(tensor.data, split_sizes, gain)
    # else:
        sizes = list(tensor.size())
        if any([a % b != 0 for a, b in zip(sizes, split_sizes)]):
            raise ConfigurationError(
                "tensor dimentions must be divisible by their respective "
                f"split_sizes. Found size: {size} and split_sizes: {split_sizes}")
        indexes = [list(range(0, max_size, split))
                   for max_size, split in zip(sizes, split_sizes)]
        # Iterate over all possible blocks within the tensor.
        for block_start_indices in itertools.product(*indexes):
            index_and_step_tuples = zip(block_start_indices, split_sizes)
            block_slice = tuple([slice(start_index, start_index + step)
                                 for start_index, step in index_and_step_tuples])
            tensor[block_slice] = nn.init.orthogonal_(tensor[block_slice].contiguous(), gain=gain)
            
def sort_batch_by_length(tensor: torch.autograd.Variable,
                         sequence_lengths: torch.autograd.Variable):
    if not isinstance(tensor, Variable) or not isinstance(sequence_lengths, Variable):
        raise Exception("Both the tensor and sequence lengths must be torch.autograd.Variables.")
        
    sorted_sequence_lengths, permutation_index = sequence_lengths.sort(0, descending=True)
    sorted_tensor = tensor.index_select(0, permutation_index)
    
    # This is ugly, but required - we are creating a new variable at runtime, so we
    # must ensure it has the correct CUDA vs non-CUDA type. We do this by cloning and
    # refilling one of the inputs to the function.
    index_range = sequence_lengths.data.clone().copy_(torch.arange(0, len(sequence_lengths)))
    # This is the equivalent of zipping with index, sorting by the original
    # sequence lengths and returning the now sorted indices.
    index_range = Variable(index_range.long())
    _, reverse_mapping = permutation_index.sort(0, descending=False)
    restoration_indices = index_range.index_select(0, reverse_mapping)
    return sorted_tensor, sorted_sequence_lengths, restoration_indices, permutation_index

In [198]:
# 아직 코드 리뷰안한 코드!
from typing import Optional, Tuple, List, Callable, Union

import h5py
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence, pack_padded_sequence
from torch.autograd import Variable

# We have two types here for the state, because storing the state in something
# which is Iterable (like a tuple, below), is helpful for internal manipulation
# - however, the states are consumed as either Tensors or a Tuple of Tensors, so
# returning them in this format is unhelpful.
RnnState = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]  # pylint: disable=invalid-name
RnnStateStorage = Tuple[torch.Tensor, ...]  # pylint: disable=invalid-name


class _EncoderBase(nn.Module):
    # pyling: disable=abstract-method
    """
    This abstract class serves as a base for the 3 ``Encoder`` abstractions in AllenNLP.
    - :class:`~allennlp.modules.seq2seq_encoders.Seq2SeqEncoders`
    - :class:`~allennlp.modules.seq2vec_encoders.Seq2VecEncoders`
    Additionally, this class provides functionality for sorting sequences by length
    so they can be consumed by Pytorch RNN classes, which require their inputs to be
    sorted by length. Finally, it also provides optional statefulness to all of it's
    subclasses by allowing the caching and retrieving of the hidden states of RNNs.
    """
    def __init__(self, stateful: bool = False) -> None:
        super(_EncoderBase, self).__init__()
        self.stateful = stateful
        self._states: Optional[RnnStateStorage] = None

    def sort_and_run_forward(self,
                             module: Callable[[PackedSequence, Optional[RnnState]],
                                              Tuple[Union[PackedSequence, torch.Tensor], RnnState]],
                             inputs: torch.Tensor,
                             mask: torch.Tensor,
                             hidden_state: Optional[RnnState] = None):
        """
        Pytorch RNNs는 input이 passing되기 전에 정렬되있어야 함
        Seq2xxxEncoders가 이러한 기능을 모두 사용하기에 base class로 제공
        """
        # In some circumstances you may have sequences of zero length. ``pack_padded_sequence``
        # requires all sequence lengths to be > 0, so remove sequences of zero length before
        # calling self._module, then fill with zeros.

        # First count how many sequences are empty.
        batch_size = mask.size(0)
        num_valid = torch.sum(mask[:, 0]).int().item()
        print(f"\tbatch_size = {batch_size}, num_valid = {num_valid}")

        sequence_lengths = get_lengths_from_binary_sequence_mask(mask)
        print(f"\tsequence_lengths = {sequence_lengths}")
        sorted_inputs, sorted_sequence_lengths, restoration_indices, sorting_indices = \
            sort_batch_by_length(inputs, sequence_lengths)
        print(f"\t1. sorted_inputs.shape = {sorted_inputs.shape}")
        print(f"\t2. sorted_sequence_lengths = {sorted_sequence_lengths}")
        print(f"\t3. restoration_indices = {restoration_indices}")
        print(f"\t4. sorting_indices = {sorting_indices}")
        # Now create a PackedSequence with only the non-empty, sorted sequences.
        # pad token 제외, 유의미한 값들만 packing
        packed_sequence_input = pack_padded_sequence(sorted_inputs[:num_valid, :, :],
                                                     sorted_sequence_lengths[:num_valid].data.tolist(),
                                                     batch_first=True)
        print(f"\t             sorted_inputs.shape  = {sorted_inputs.shape}")
        print(f"\tpacked_sequence_input.data.shape  = {packed_sequence_input.data.shape}")
        print(f"\tpacked_sequence_input.batch_sizes = {packed_sequence_input.batch_sizes}")
        # Prepare teh initial states.
        print(f"\tself.stateful is {self.stateful}")
        if not self.stateful:
            print("\tstateful is False,", end='')
            print("If hidden_state is ", end='')
            if hidden_state == None:
                print("None,\n\t\tinitial_states = hidden_state")
                initial_states = hidden_state
            elif isinstance(hidden_state, tuple):
                print("tuple,\n\t\tinitial_states = [state.index_select(1, sorting_indices)[:, :num_valid, :] for state in hidden_state]")
                initial_states = [state.index_select(1, sorting_indices)[:, :num_valid, :]
                                  for state in hidden_state]
            else:
                print("not both None and tuple,\n\t\tConduct `_get_initial_states`")
                initial_stats = self._get_initial_states(batch_size, num_valid, sorting_indices)
        else:
            print("\tstateful is True,\n\t\tConduct `_get_initial_states`")
            initial_states = self._get_initial_states(batch_size, num_valid, sorting_indices)

        # Actually call the module on the sorted PackedSequence
        print("\tRUN `_lstm_forward`... by initial_states")
        module_output, final_states = module(packed_sequence_input, initial_states)
        

        return module_output, final_states, restoration_indices

    def _get_initial_states(self,
                            batch_size: int,
                            num_valid: int,
                            sorting_indices: torch.LongTensor) -> Optional[RnnState]:
        """
        RNN의 초기 상태를 반환
        추가적으로, 이 메서드는 batch의 새로운 요소의 초기 상태를 추가하기 위해 상태를 변경하여(mutate)
            호출시 batch size를 처리
        또한 이 메서드는
            1. 배치의 요소 seq. length로 상태를 정렬하는 것과
            2. pad가 끝난 row 제거도 처리
        중요한 것은 현재의 배치 크기가 이전에 호출되었을 때보다 더 크면 이 상태를 "혼합"하는 것이다.

        이 메서드는 (1) 처음 호출되어 아무 상태가 없는 경우 (2) RNN이 heterogeneous state를 가질 때
        의 경우를 처리해야 하기 때문에 return값이 복잡함

        (1) module이 처음 호출됬을 때 ``module``의 타입이 무엇이든 ``None`` 반환
        (2) Otherwise,
            - LSTM의 경우 tuple of ``torch.Tensor``
              shape: ``(num_layers, num_valid, state_size)``
                 and ``(num_layers, num_valid, memory_size)``
            - GRU의 경우  single ``torch.Tensor``
              shape: ``(num_layers, num_valid, state_size)``
        """
        # We don't know the state sizes the first time calling forward,
        # so we let the module define what it's initial hidden state looks like.
        if self._states is None:
            return None

        # Otherwise, we have some previous states.
        if batch_size > self._states[0].size(1):
            # This batch is larger than the all previous states.
            # If so, resize the states.
            num_states_to_concat = batch_size - self._states[0].size(1)
            resized_states = []
            # state has shape (num_layers, batch_size, hidden_size)
            for state in self._states:
                # This _must_ be inside the loop because some
                # RNNs have states with different last dimension sizes.
                zeros = state.data.new(state.size(0),
                                       num_states_to_concat,
                                       state.size(2)).fill_(0)
                zeros = Variable(zeros)
                resized_states.append(torch.cat([state, zeros], 1))
            self._states = tuple(resized_states)
            correctly_shaped_states = self._states
        elif batch_size < self._states[0].size(1):
            # This batch is smaller than the previous one.
            correctly_shaped_states = tuple(staet[:, :batch_size, :] for state in self._states)
        else:
            correctly_shaped_states = self._states

        # At this point, out states are of shape (num_layers, batch_size, hidden_size).
        # However, the encoder uses sorted sequences and additionally removes elements
        # of the batch which are fully padded. We need the states to match up to these
        # sorted and filtered sequences, so we do that in the next two blocks before
        # returning the states.
        if len(self._states) == 1:
            # GRU
            correctly_shaped_state = correctly_shaped_states[0]
            sorted_state = correctly_shaped_state.index_select(1, sorting_indices)
            return sorted_state[:, :num_valid, :]
        else:
            # LSTM
            sorted_states = [state.index_select(1, sorting_indices)
                             for state in correctly_shaped_states]
            return tuple(state[:, :num_valid, :] for state in sorted_states)

    def _update_states(self,
                       final_states: RnnStateStorage,
                       restoration_indices: torch.LongTensor) -> None:
        """
        RNN forward 동작 후에 state를 update
        새로운 state로 update하며 몇 가지 book-keeping을 실시
        즉, 상태를 해제하고 완전히 padding된 state가 업데이트되지 않도록 함
        마지막으로 graph가 매 batch iteration후에 gc되도록 계산 그래프에서
        state variable을 떼어냄.
        """
        # TODO(Mark)L seems weird to sort here, but append zeros in the subclasses.
        # which way around is best?
        print('_EncoderBase의 `_update_states` 메서드 실행')
        print(f'inputs:\nfinal_states = {final_states}\nrestoration_indices = {restoration_indices}')
        new_unsorted_states = [state.index_select(1, restoration_indices)
                               for state in final_states]
        print(f"new_unsorted_states = {new_unsorted_states}")
        print(f"self._states is None = {self._states is None}")
        if self._states is None:
            print("이전 상태가 존재하지 않습니다. new_unsorted_states로 새롭게 만들어 줍니다.")
            # We don't already have states, so just set the
            # ones we receive to be the current state.
            self._states = tuple([Variable(state.data)
                                  for state in new_unsorted_states])
            print('STATES:', self._states)
        else:
            print("이전 상태가 존재합니다. 현재 상태와 입력받은 final_state로 새로운 상태를 update합니다.")
            # Now we've sorted the states back so that they correspond to the original
            # indices, we need to figure out what states we need to update, because if we
            # didn't use a state for a particular row, we want to preserve its state.
            # Thankfully, the rows which are all zero in the state correspond exactly
            # to those which aren't used, so we create masks of shape (new_batch_size,),
            # denoting which states were used in the RNN computation.
            current_state_batch_size = self._states[0].size(1)
            new_state_batch_size = final_states[0].size(1)
            print(f"current_state_batch_size = {current_state_batch_size} = self._states[0].size(1)")
            print(f"new_state_batch_size = {new_state_batch_size} = final_states[0].size(1)")
            # Masks for the unused states of shape (1, new_batch_size, 1)
            used_new_rows_mask = [(state[0, :, :].sum(-1)
                                   != 0.0).float().view(1, new_state_batch_size, 1)
                                  for state in new_unsorted_states]
            new_states = []
            if current_state_batch_size > new_state_batch_size:
                # The new state is smaller than the old one,
                # so just update the indices which we used.
                for old_state, new_state, used_mask in zip(self._states,
                                                           new_unsorted_states,
                                                           used_new_rows_mask):
                    # zero out all rows in the previous state
                    # which _were_ used in the current state.
                    masked_old_state = old_state[:, :new_state_batch_size, :] * (1 - used_mask)
                    # The old state is larger, so update the relevant parts of it.
                    old_state[:, :new_state_batch_size, :] = new_state + masked_old_state
                    # Detatch the Variable.
                    new_states.append(torch.autograd.Variable(old_state.data))
            else:
                # The states are the same size, so we just have to
                # deal with the possibility that some rows weren't used.
                new_states = []
                for old_state, new_state, used_mask in zip(self._states,
                                                           new_unsorted_states,
                                                           used_new_rows_mask):
                    # zero out all rows which _were_ used in the current state.
                    masked_old_state = old_state * (1 - used_mask)
                    # The old state is larger, so update the relevant parts of it.
                    new_state += masked_old_state
                    # Detatch the Variable.
                    new_states.append(torch.autograd.Variable(new_state.data))

            # It looks like there should be another case handled here - when
            # the current_state_batch_size < new_state_batch_size. However,
            # this never happens, because the states themeselves are mutated
            # by appending zeros when calling _get_inital_states, meaning that
            # the new states are either of equal size, or smaller, in the case
            # that there are some unused elements (zero-length) for the RNN computation.
            self._states = tuple(new_states)

    def reset_states(self):
        self._states = None


class ElmobiLm(_EncoderBase):
    def __init__(self, config, use_cuda=False):
        super(ElmobiLm, self).__init__(stateful=True)
        self.config = config
        self.use_cuda = use_cuda
        input_size = config['encoder']['projection_dim']
        hidden_size = config['encoder']['projection_dim']
        cell_size = config['encoder']['dim']
        num_layers = config['encoder']['n_layers']
        memory_cell_clip_value = config['encoder']['cell_clip']
        state_projection_clip_value = config['encoder']['proj_clip']
        recurrent_dropout_probability = config['dropout']
        
        print('ELMo biLM layer params')
        print(f"\tinput_size = {input_size}")
        print(f"\thidden_size = {hidden_size}")
        print(f"\tcell_size = {cell_size}")
        print(f"\tnum_layers = {num_layers}")
        print(f"\tmemory_cell_clip_value = {memory_cell_clip_value}")
        print(f"\tstate_projection_clip_value = {state_projection_clip_value}")
#         print(f"\trecurrent_dropout_probability = {config['dropout']}")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cell_size = cell_size

        forward_layers = []
        backward_layers = []

        lstm_input_size = input_size
        go_forward = True
        for layer_index in range(num_layers):
            forward_layer = LstmCellWithProjection(lstm_input_size,
                                                   hidden_size,
                                                   cell_size,
                                                   go_forward,
                                                   recurrent_dropout_probability,
                                                   memory_cell_clip_value,
                                                   state_projection_clip_value).cuda()
            backward_layer = LstmCellWithProjection(lstm_input_size,
                                                    hidden_size,
                                                    cell_size,
                                                    not go_forward,
                                                    recurrent_dropout_probability,
                                                    memory_cell_clip_value,
                                                    state_projection_clip_value).cuda()
            if use_cuda:
                forward_layer = forward_layer.cuda()
                backward_layer = backward_layer.cuda()
            lstm_input_size = hidden_size

            self.add_module('forward_layer_{}'.format(layer_index), forward_layer)
            self.add_module('backward_layer_{}'.format(layer_index), backward_layer)
            forward_layers.append(forward_layer)
            backward_layers.append(backward_layer)
        self.forward_layers = forward_layers
        self.backward_layers = backward_layers
        print(f"forward_layers = {forward_layers}")
        print(f"backward_layers = {backward_layers}")

    def forward(self, inputs, mask):
        print('FORWARD!!!!**************')
        batch_size, total_sequence_length = mask.size()
        print(f"batch_size = {batch_size}")
        print(f"total_sequence_length = {total_sequence_length}")
        print("_EncoderBase.sort_and_run_forward 메서드 실시...")
        stacked_sequence_output, final_states, restoration_indices = \
            self.sort_and_run_forward(self._lstm_forward, inputs, mask)
        print(f"stacked_sequence_output.shape = {stacked_sequence_output.shape}")
        print(f"final_states = {final_states}")
        print(f"restoration_indices = {restoration_indices}")
        num_layers, num_valid, returned_timesteps, encoder_dim = stacked_sequence_output.size()
        # Add back invalid rows which were removed in the call to sort_and_run_forward.
        print("stacked")
        print(f"num_layers = {num_layers}")
        print(f"num_valid = {num_valid}")
        print(f"returned_timesteps = {returned_timesteps}")
        print(f"encoder_dim = {encoder_dim}")
        print(f"num_valid < batch_size -> {num_valid < batch_size}")
        if num_valid < batch_size:
            zeros = stacked_sequence_output.data.new(num_layers,
                                                     batch_size - num_valid,
                                                     returned_timesteps,
                                                     encoder_dim).fill_(0)
            zeros = Variable(zeros)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 1)

            # The states also need to have invalid rows added back.
            new_states = []
            for state in final_states:
                state_dim = state.size(-1)
                zeros = state.data.new(num_layers, batch_size - num_valid, state_dim).fill_(0)
                zeros = Variable(zeros)
                new_states.append(torch.cat([state, zeros], 1))
            final_states = new_states

        # It's possible to need to pass sequences which are padded to longer than the
        # max length of the sequence to a Seq2StackEncoder. However, packing and unpacking
        # the sequences mean that the returned tensor won't include these dimensions, because
        # the RNN did not need to process them. We add them back on in the form of zeros here.
        sequence_length_difference = total_sequence_length - returned_timesteps
        print("sequence_length_difference = total_sequence_length - returned_timesteps")
        print(f"sequence_length_difference = {sequence_length_difference}")
        print(f"sequence_length_difference is larger than 0? : {sequence_length_difference > 0}")
        if sequence_length_difference > 0:
            zeros = stacked_sequence_output.data.new(num_layers,
                                                     batch_size,
                                                     sequence_length_difference,
                                                     stacked_sequence_output[0].size(-1)).fill_(0)
            zeros = Variable(zeros)
            stacked_sequence_output = torch.cat([stacked_sequence_output, zeros], 2)
        print('UPDATE STATES... inputs: final_states, restoration_indices')
        self._update_states(final_states, restoration_indices)

        # Restore the original indices and return the sequence.
        # Has shape (num_layers, batch_size, sequence_length, hidden_size)
        return stacked_sequence_output.index_select(1, restoration_indices)


    def _lstm_forward(self,
                      inputs: PackedSequence,
                      initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> \
        Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        print(f"\t\tinitial_state is None? {initial_state is None}")
        if initial_state is None:
            print("\t\tOops, Assign hidden_state = [None] * len(self.forward_layers)")
            hidden_states: List[Optional[Tuple[torch.Tensor,
                                         torch.Tensor]]] = [None] * len(self.forward_layers)
            print(f"\t\thidden_states = {hidden_states}")
        elif initial_state[0].size()[0] != len(self.forward_layers):
            print(f"\t\tinitial_state[0].size()[0] = {initial_state[0].size()[0]}")
            print(f"\t\tlen(self.forward_layers) = {len(self.forward_layers)}")
            raise Exception("Initial states were passed to forward() but the number of "
                            "initial states does not match the number of layers.")
        else:
            print("\t\tinitial is not None and it's size equal to forward_layers' length,")
            print("\t\tthen hidden_states is")
            print(f"\t\t A = initial_state[0].split(1, 0) = {initial_state[0].split(1, 0)}")
            print(f"\t\t B = initial_state[1].split(1, 0) = {initial_state[1].split(1, 0)}")
            print("\t\t hidden_states = list(zip(A, B))")
            hidden_states = list(zip(initial_state[0].split(1, 0),
                                     initial_state[1].split(1, 0)))
            print(f"\t\t               = {hidden_states}")
        
        print("\t\tinputs is `PackedSequence`")
        print(f"\t\ttype(inputs) = {type(inputs)}")
        print(f"\t\t\tinputs.data.shape = {inputs.data.shape}")
        print(f"\t\t\tinputs.batch_sizes = {inputs.batch_sizes}")
        print(f"\t\t\tinputs.sorted_indices = {inputs.sorted_indices}")
        print(f"\t\t\tinputs.unsorted_indices = {inputs.unsorted_indices}")
        
        print("\t\tRestore PAD_char to inputs...")
        inputs, batch_lengths = pad_packed_sequence(inputs, batch_first=True)
        print("\t\t바뀐 inputs의 정보 출력")
        print(f"\t\ttype(inputs) = {type(inputs)}")
        print(f"\t\t\tinputs.shape = {inputs.shape}")
        print(f"\t\tbatch_lengths = {batch_lengths}")
        print("\t\tAssign forward_output_sequence = backward_output_sequence = inputs")
        forward_output_sequence = inputs
        backward_output_sequence = inputs
        
        print("\t\tSet final_states, sequqnce_outputs as empty list, []")
        final_states = []
        sequence_outputs = []
        for layer_index, state in enumerate(hidden_states):
            print(f"\t\tGet a forward layer and backward layer at layer {layer_index+1}")
            forward_layer = getattr(self, 'forward_layer_{}'.format(layer_index))
            backward_layer = getattr(self, 'backward_layer_{}'.format(layer_index))
            
            print("\t\tCaching...: output_sequence to cache both forward and backward")
            forward_cache = forward_output_sequence
            backward_cache = backward_output_sequence
            
            print(f"\t\tstate is None? {state is None}")
            if state is not None:
                print("\t\t\tAlright, Set hidden_state/memory_state for both forward and backward")
                print(f"\t\t\tstate[0](hidden_state) = {state[0]}")
                print(f"\t\t\tstate[1](memory_state) = {state[1]}")
                forward_hidden_state, backward_hidden_state = state[0].split(self.hidden_size, 2)
                forward_memory_state, backward_memory_state = state[1].split(self.cell_size, 2)
                forward_state = (forward_hidden_state, forward_memory_state)
                backward_state = (backward_hidden_state, backward_memory_state)
            else:
                print("\t\t\tOops, then forward and backward state is also 'None'")
                forward_state = None
                backward_state = None
                
            print("\t\tRUN forward_layer.forward method...")
            forward_output_sequence, forward_state = forward_layer(forward_output_sequence,
                                                                   batch_lengths,
                                                                   forward_state)
            print("\t\tRUN backward_layer.forward method...")
            backward_output_sequence, backward_state = backward_layer(backward_output_sequence,
                                                                      batch_lengths,
                                                                      backward_state)
            # Skip connections, just adding the input to the output.
            if layer_index != 0:
                print('\t\tsince layer_index != 0, adding cache to output sequence')
                forward_output_sequence += forward_cache
                backward_output_sequence += backward_cache
            
            
            sequence_outputs.append(torch.cat([forward_output_sequence,
                                               backward_output_sequence], -1))
            # Append the state tuples in a list, so that we can return
            # the final states for all the layers.
            final_states.append((torch.cat([forward_state[0], backward_state[0]], dim=-1),
                                 torch.cat([forward_state[1], backward_state[1]], dim=-1)))

        stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs)
        # Stack the hidden state and memory for each layer into 2 tensors of shape
        # (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size)
        # respectively.
        final_hidden_states, final_memory_states = zip(*final_states)
        final_state_tuple: Tuple[torch.FloatTensor,
                                 torch.FloatTensor] = (torch.cat(final_hidden_states, 0),
                                                       torch.cat(final_memory_states, 0))
        return stacked_sequence_outputs, final_state_tuple
    
class LstmCellWithProjection(torch.nn.Module):
    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 cell_size: int,
                 go_forward: bool = True,
                 recurrent_dropout_probability: float = 0.0,
                 memory_cell_clip_value: Optional[float] = None,
                 state_projection_clip_value: Optional[float] = None) -> None:
        super(LstmCellWithProjection, self).__init__()
        # Required to be wrapped with a :class:`PytorchSeq2SeqWrapper`.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.cell_size = cell_size

        self.go_forward = go_forward
        self.state_projection_clip_value = state_projection_clip_value
        self.memory_cell_clip_value = memory_cell_clip_value
        self.recurrent_dropout_probability = recurrent_dropout_probability

        # We do the projections for all the gates all at once.
        self.input_linearity = nn.Linear(input_size, 4 * cell_size, bias=False)
        self.state_linearity = nn.Linear(hidden_size, 4 * cell_size, bias=True)

        # Additional projection matrix for making the hidden state smaller.
        self.state_projection = nn.Linear(cell_size, hidden_size, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        # Use sensible default initializations for parameters.
        block_orthogonal(self.input_linearity.weight.data, [self.cell_size, self.input_size])
        block_orthogonal(self.state_linearity.weight.data, [self.cell_size, self.hidden_size])

        self.state_linearity.bias.data.fill_(0.0)
        # Initialize forget gate biases to 1.0 as per An Empirical
        # Exploration of Recurrent Network Architectures, (Jozefowicz, 2015).
        self.state_linearity.bias.data[self.cell_size:2 * self.cell_size].fill_(1.0)

    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.FloatTensor,
                batch_lengths: List[int],
                initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        print(f"\t\t\tinputs.size() = {inputs.size()}")
        batch_size = inputs.size()[0]
        total_timesteps = inputs.size()[1]
        print('\t\t\tUnpacking batch_size, total_timesteps = inputs.size()')
        print(f'\t\t\tbatch_size = {batch_size}, total_timesteps = {total_timesteps}')

        # We have to use this '.data.new().fill_' pattern to create tensors with the correct
        # type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors.
        output_accumulator = Variable(inputs.data.new(batch_size,
                                                      total_timesteps,
                                                      self.hidden_size).fill_(0))
        print(f"\t\t\tCreate tensor(output_accumulator) which has ({batch_size}, {total_timesteps}, {self.hidden_size}) shape, filling 0.")
        print(f"\t\t\tis `initial_state` is None? {initial_state is None}")
        if initial_state is None:
            print("\t\t\t\tOh, then create full_batch_previous memory and state by "
                  f"({batch_size}, {self.cell_size}) tensor filling 0.")
            full_batch_previous_memory = Variable(inputs.data.new(batch_size,
                                                                  self.cell_size).fill_(0))
            full_batch_previous_state = Variable(inputs.data.new(batch_size,
                                                                 self.hidden_size).fill_(0))
        else:
            print("\t\t\t\tOk, Using `initial_state`, create full_batch_previous memory and state.")
            print(f"\t\t\t\t(previous_state) = initial_state[0] = {initial_state[0]}")
            print(f"\t\t\t\tfull_batch_previous_state = initial_state[0].squeeze(0) = {initial_state[0].squeeze(0)}")
            full_batch_previous_state = initial_state[0].squeeze(0)
            print(f"\t\t\t\t(previous_memory) = initial_state[1] = {initial_state[1]}")
            print(f"\t\t\t\tfull_batch_previous_memory = initial_state[1].squeeze(0) = {initial_state[1].squeeze(0)}")
            full_batch_previous_memory = initial_state[1].squeeze(0)

        current_length_index = batch_size - 1 if self.go_forward else 0
        print(f"\t\t\t\tSet current_length_index... is it forward?? {self.go_forward}")
        if self.go_forward:
            print(f"\t\t\t\tOk, forward!! current_length_index = batch_size - 1 = {batch_size - 1}")
            current_length_index = batch_size - 1
        else:
            print(f"\t\t\t\tOops, backward!! current_length_index = 0")
            current_length_index = 0
            
        print('\t\t\t\tis recurrent_dropout_probability is larger than 0?', self.recurrent_dropout_probability > 0.0)
        print('\t\t\t\tand is training?', self.training)
        if self.recurrent_dropout_probability > 0.0 and self.training:
            print('\t\t\t\tok, both is True. Execute `get_dropout_mask` function! using full_abtch_previous_state!')
            dropout_mask = get_dropout_mask(self.recurrent_dropout_probability,
                                            full_batch_previous_state)
        else:
            print('\t\t\t\toh, is not trainig. then dropout_mask = None.')
            dropout_mask = None
            
        print(f"\t\t\t\tStarting Loops with {total_timesteps}...")
        for timestep in range(total_timesteps):
            # The index depends on which end we start.
            index = timestep if self.go_forward else total_timesteps - timestep - 1
            print(f"\t\t\t\tindex = {index} since {'forward' if self.go_forward else 'backward'}")

            # What we are doing here is finding the index into the batch dimension
            # which we need to use for this timestep, because the sequences have
            # variable length, so once the index is greater than the length of this
            # particular batch sequence, we no longer need to do the computation for
            # this sequence. The key thing to recognise here is that the batch inputs
            # must be _ordered_ by length from longest (first in batch) to shortest
            # (last) so initially, we are going forwards with every sequence and as we
            # pass the index at which the shortest elements of the batch finish,
            # we stop picking them up for the computation.
            if self.go_forward:
                print('\t\t\t\tIn case forward')
                print(f"\t\t\t\tbatch_lengths[current_length_index] <= index = {batch_lengths[current_length_index] <= index}")
                while batch_lengths[current_length_index] <= index:
                    print("\t\t\t\tcurrent_length_index -= 1")
                    current_length_index -= 1
            # If we're going backwards, we are _picking up_ more indices.
            else:
                # First conditional: Are we already at the maximum number of elements in the batch?
                # Second conditional: Does the next shortest sequence beyond the current batch
                # index require computation use this timestep?
                print('\t\t\t\tIn case backward,')
                print(f"\t\t\t\tbatch_lengths[current_length_index] <= index = {batch_lengths[current_length_index] <= index}")
                while current_length_index < (len(batch_lengths) - 1) and \
                                batch_lengths[current_length_index + 1] > index:
                    print("\t\t\t\tcurrent_length_index += 1")
                    current_length_index += 1
            print(f'\t\t\t\tbatch_lengths[length_index] is {batch_lengths[current_length_index]}')

            # Actually get the slices of the batch which we
            # need for the computation at this timestep.
            # shape (batch_size, cell_size)
            print("\t\t\t\tGet a previous memory...")
            print(full_batch_previous_memory[0: current_length_index + 1])
            previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone()
            print(previous_memory.shape)
            # Shape (batch_size, hidden_size)
            print("\t\t\t\tGet a previous state...")
            print(full_batch_previous_memory[0: current_length_index + 1])
            previous_state = full_batch_previous_state[0: current_length_index + 1].clone()
            print(previous_state.shape)
            # Shape (batch_size, input_size)
            timestep_input = inputs[0: current_length_index + 1, index]
            print("\t\t\t\tGet a timestep input...")
            print(timestep_input)
            print(timestep_input.shape)

            # Do the projections for all the gates all at once.
            # Both have shape (batch_size, 4 * cell_size)
            print("\t\t\t\tProjection to 4*cell_size...")
            projected_input = self.input_linearity(timestep_input)
            print("\t\t\t\t`input_linearity`: W1 * timestep_input")
            print(f"\t\t\t\tprojected_input.shape = {projected_input.shape}")
            projected_state = self.state_linearity(previous_state)
            print("\t\t\t\t`state_linearity`: W2 * previous_state + b")
            print(f"\t\t\t\tprojected_state.shape = {projected_state.shape}")

            # Main LSTM equations using relevant chunks of the big linear
            # projections of the hidden state and inputs.
            print("\t\t\t\tCalc LSTM hidden unit...")
            input_gate = torch.sigmoid(projected_input[:, (0 * self.cell_size):(1 * self.cell_size)] +
                                       projected_state[:, (0 * self.cell_size):(1 * self.cell_size)])
            forget_gate = torch.sigmoid(projected_input[:, (1 * self.cell_size):(2 * self.cell_size)] +
                                        projected_state[:, (1 * self.cell_size):(2 * self.cell_size)])
            memory_init = torch.tanh(projected_input[:, (2 * self.cell_size):(3 * self.cell_size)] +
                                     projected_state[:, (2 * self.cell_size):(3 * self.cell_size)])
            output_gate = torch.sigmoid(projected_input[:, (3 * self.cell_size):(4 * self.cell_size)] +
                                        projected_state[:, (3 * self.cell_size):(4 * self.cell_size)])
            memory = input_gate * memory_init + forget_gate * previous_memory

            # Here is the non-standard part of this LSTM cell; first, we clip the
            # memory cell, then we project the output of the timestep to a smaller size
            # and again clip it.
            print(f"\t\t\t\tis memory_cell_clip_value is exist? {'Yes' if self.memory_cell_clip_value else 'No'}")
            if self.memory_cell_clip_value:
                print(f"\t\t\t\tOh, it's float. Set lower bound and upper bound at memory_cell_clip_value", end='')
                print(self.memory_cell_clip_value)
                # pylint: disable=invalid-unary-operand-type
                memory = torch.clamp(memory, -self.memory_cell_clip_value, self.memory_cell_clip_value)
            else:
                print("\t\t\t\tOh, it's None. passing the way.")

            print("\t\t\t\tCalc next timestep output...")
            # shape (current_length_index, cell_size)
            pre_projection_timestep_output = output_gate * torch.tanh(memory)

            # shape (current_length_index, hidden_size)
            timestep_output = self.state_projection(pre_projection_timestep_output)
            print(f"\t\t\t\tstate_projection_clip_value is exist? {'Yes' if self.state_projection_clip_value else 'No'}")
            if self.state_projection_clip_value:
                print(f"\t\t\t\tOh, it's float. Set lower bound and upper bound at state_projection_clip_value", end='')
                print(self.state_projection_clip_value)
                # pylint: disable=invalid-unary-operand-type
                timestep_output = torch.clamp(timestep_output,
                                              -self.state_projection_clip_value,
                                              self.state_projection_clip_value)
            else:
                print("\t\t\t\tOh, it's None. passing the way.")

            # Only do dropout if the dropout prob is > 0.0 and we are in training mode.
            print("\t\t\t\tIf dropout_mask exists, Adjust.")
            if dropout_mask is not None:
                timestep_output = timestep_output * dropout_mask[0: current_length_index + 1]

            # We've been doing computation with less than the full batch, so here we create a new
            # variable for the the whole batch at this timestep and insert the result for the
            # relevant elements of the batch into it.
            print('\t\t\t\tset full_batch_previous memory/state!!')
            full_batch_previous_memory = Variable(full_batch_previous_memory.data.clone())
            full_batch_previous_state = Variable(full_batch_previous_state.data.clone())
            full_batch_previous_memory[0:current_length_index + 1] = memory
            full_batch_previous_state[0:current_length_index + 1] = timestep_output
            output_accumulator[0:current_length_index + 1, index] = timestep_output

        # Mimic the pytorch API by returning state in the following shape:
        # (num_layers * num_directions, batch_size, ...). As this
        # LSTM cell cannot be stacked, the first dimension here is just 1.
        final_state = (full_batch_previous_state.unsqueeze(0),
                       full_batch_previous_memory.unsqueeze(0))
        print(f"\t\t\t\tfinal_state = {final_state}")

        return output_accumulator, final_state

In [199]:
encoder = ElmobiLm(config, use_cuda=True)

ELMo biLM layer params
	input_size = 512
	hidden_size = 512
	cell_size = 4096
	num_layers = 2
	memory_cell_clip_value = 3
	state_projection_clip_value = 3
forward_layers = [LstmCellWithProjection(
  (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
  (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
  (state_projection): Linear(in_features=4096, out_features=512, bias=False)
), LstmCellWithProjection(
  (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
  (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
  (state_projection): Linear(in_features=4096, out_features=512, bias=False)
)]
backward_layers = [LstmCellWithProjection(
  (input_linearity): Linear(in_features=512, out_features=16384, bias=False)
  (state_linearity): Linear(in_features=512, out_features=16384, bias=True)
  (state_projection): Linear(in_features=4096, out_features=512, bias=False)
), LstmCellWithProjection(
  (input_line

In [200]:
encoder_output = encoder(token_embedding, Variable(masks[0]))

FORWARD!!!!**************
batch_size = 3
total_sequence_length = 10
_EncoderBase.sort_and_run_forward 메서드 실시...
	batch_size = 3, num_valid = 3
	sequence_lengths = tensor([ 7,  6, 10], device='cuda:0')
	1. sorted_inputs.shape = torch.Size([3, 10, 512])
	2. sorted_sequence_lengths = tensor([10,  7,  6], device='cuda:0')
	3. restoration_indices = tensor([1, 2, 0], device='cuda:0')
	4. sorting_indices = tensor([2, 0, 1], device='cuda:0')
	             sorted_inputs.shape  = torch.Size([3, 10, 512])
	packed_sequence_input.data.shape  = torch.Size([23, 512])
	packed_sequence_input.batch_sizes = tensor([3, 3, 3, 3, 3, 3, 2, 1, 1, 1])
	self.stateful is True
	stateful is True,
		Conduct `_get_initial_states`
	RUN `_lstm_forward`... by initial_states
		initial_state is None? True
		Oops, Assign hidden_state = [None] * len(self.forward_layers)
		hidden_states = [None, None]
		inputs is `PackedSequence`
		type(inputs) = <class 'torch.nn.utils.rnn.PackedSequence'>
			inputs.data.shape = torch.Size(

       device='cuda:0', grad_fn=<SelectBackward>)
torch.Size([1, 512])
				Projection to 4*cell_size...
				`input_linearity`: W1 * timestep_input
				projected_input.shape = torch.Size([1, 16384])
				`state_linearity`: W2 * previous_state + b
				projected_state.shape = torch.Size([1, 16384])
				Calc LSTM hidden unit...
				is memory_cell_clip_value is exist? Yes
				Oh, it's float. Set lower bound and upper bound at memory_cell_clip_value3
				Calc next timestep output...
				state_projection_clip_value is exist? Yes
				Oh, it's float. Set lower bound and upper bound at state_projection_clip_value3
				If dropout_mask exists, Adjust.
				set full_batch_previous memory/state!!
				index = 9 since forward
				In case forward
				batch_lengths[current_length_index] <= index = False
				batch_lengths[length_index] is 10
				Get a previous memory...
tensor([[-0.0154, -0.0035,  0.0223,  ...,  0.0023,  0.0046, -0.0197]],
       device='cuda:0', grad_fn=<SliceBackward>)
torch.Size([1, 409

tensor([[ 3.9943e-02, -6.3473e-04, -2.4333e-02, -2.6465e-02,  3.6282e-02,
         -8.1245e-02,  2.0234e-02, -4.7533e-02, -1.4706e-02,  1.8934e-02,
          4.9761e-04, -6.6569e-02,  3.0510e-02,  8.3242e-03,  3.4185e-03,
          4.2514e-02,  3.6875e-02,  1.1155e-02, -6.4569e-02, -2.6975e-02,
          1.4409e-02, -8.3901e-03,  1.4958e-02,  1.4861e-02,  1.6609e-02,
          4.6618e-02,  8.1440e-02,  6.3781e-02,  4.3447e-02,  5.9358e-03,
         -5.9265e-04,  3.1957e-02, -3.8824e-03,  6.6324e-02,  1.1476e-03,
          3.0135e-02,  2.8017e-02, -2.6136e-03, -8.7879e-02,  8.0971e-03,
          4.2865e-02,  2.4699e-02, -1.6333e-03,  4.5484e-02,  4.8548e-02,
         -2.7261e-02,  4.9371e-02, -1.1094e-02, -6.2017e-02,  6.3433e-02,
          1.1362e-02, -3.7070e-02, -2.8188e-02, -3.8195e-02, -5.8539e-02,
          1.2449e-02, -2.7576e-03,  3.2095e-02,  7.0868e-03,  7.0801e-03,
         -2.9167e-02, -6.1884e-02,  3.3293e-02,  7.4093e-03,  8.6460e-03,
          2.1394e-02,  2.8104e-02, -2.

         -3.4580e-04, -3.3531e-04]], device='cuda:0', grad_fn=<SliceBackward>)
torch.Size([3, 4096])
				Get a previous state...
tensor([[-2.3543e-04,  1.6494e-03, -8.9001e-05,  ...,  1.4458e-03,
          5.2153e-05, -1.1236e-03],
        [ 1.3564e-03,  1.1446e-03,  9.0791e-04,  ...,  5.4099e-04,
         -4.4836e-04, -8.9195e-04],
        [ 2.0109e-03,  4.3985e-04, -1.1258e-03,  ...,  1.2718e-03,
         -3.4580e-04, -3.3531e-04]], device='cuda:0', grad_fn=<SliceBackward>)
torch.Size([3, 512])
				Get a timestep input...
tensor([[ 0.0000,  0.0024, -0.0022,  ..., -0.0039, -0.0060, -0.0030],
        [ 0.0058,  0.0014, -0.0022,  ..., -0.0024, -0.0074, -0.0039],
        [ 0.0045, -0.0001, -0.0016,  ..., -0.0028, -0.0049, -0.0042]],
       device='cuda:0', grad_fn=<SelectBackward>)
torch.Size([3, 512])
				Projection to 4*cell_size...
				`input_linearity`: W1 * timestep_input
				projected_input.shape = torch.Size([3, 16384])
				`state_linearity`: W2 * previous_state + b
				projected_

tensor([[ 6.5776e-04,  8.5008e-04, -9.0526e-04,  4.5458e-03, -7.3213e-04,
          2.0647e-03, -1.1265e-03, -3.2782e-03,  1.6024e-03,  3.6667e-03,
         -4.1003e-04, -1.7901e-03,  2.6226e-04, -3.7832e-04,  0.0000e+00,
          0.0000e+00,  1.1847e-03,  3.1437e-03, -2.5433e-03,  1.4463e-03,
         -5.9337e-04,  1.8439e-03, -1.7969e-04, -2.8421e-03,  0.0000e+00,
          7.0549e-04, -1.5668e-03, -7.0174e-04, -1.2155e-04, -2.2479e-03,
         -6.7636e-04,  1.9502e-03,  1.0213e-03, -5.0221e-04, -2.5682e-03,
         -1.5692e-03,  0.0000e+00, -6.0628e-04,  1.0332e-04, -7.1813e-04,
          1.1502e-05,  4.9686e-03, -0.0000e+00,  9.6054e-04, -3.1009e-03,
         -1.9814e-03, -2.2883e-03,  0.0000e+00,  1.3948e-03, -2.6666e-03,
          5.6644e-04, -2.0831e-04, -0.0000e+00,  1.8394e-03,  0.0000e+00,
         -7.8157e-05,  0.0000e+00, -2.1115e-03,  5.4939e-03, -0.0000e+00,
          1.2039e-03, -4.7978e-03,  1.9393e-03, -3.2964e-03, -6.0457e-05,
         -6.8667e-04,  2.8970e-03, -5.

       device='cuda:0', grad_fn=<SelectBackward>)
torch.Size([3, 512])
				Projection to 4*cell_size...
				`input_linearity`: W1 * timestep_input
				projected_input.shape = torch.Size([3, 16384])
				`state_linearity`: W2 * previous_state + b
				projected_state.shape = torch.Size([3, 16384])
				Calc LSTM hidden unit...
				is memory_cell_clip_value is exist? Yes
				Oh, it's float. Set lower bound and upper bound at memory_cell_clip_value3
				Calc next timestep output...
				state_projection_clip_value is exist? Yes
				Oh, it's float. Set lower bound and upper bound at state_projection_clip_value3
				If dropout_mask exists, Adjust.
				set full_batch_previous memory/state!!
				index = 0 since backward
				In case backward,
				batch_lengths[current_length_index] <= index = False
				batch_lengths[length_index] is 6
				Get a previous memory...
tensor([[-6.2347e-05,  6.1063e-04, -2.6934e-03,  ...,  4.7033e-03,
         -3.0298e-03, -3.0434e-03],
        [ 1.1751e-03,  1.1402e-03

In [24]:
encoder_output.shape

torch.Size([2, 3, 10, 1024])

In [25]:
encoder_output = encoder(token_embedding, Variable(masks[0]))

FORWARD!!!!**************
batch_size = 3
total_sequence_length = 10
_EncoderBase.sort_and_run_forward 메서드 실시...
	batch_size = 3, num_valid = 3
	sequence_lengths = tensor([ 7,  6, 10], device='cuda:0')
	1. sorted_inputs.shape = torch.Size([3, 10, 512])
	2. sorted_sequence_lengths = tensor([10,  7,  6], device='cuda:0')
	3. restoration_indices = tensor([1, 2, 0], device='cuda:0')
	4. sorting_indices = tensor([2, 0, 1], device='cuda:0')
	             sorted_inputs.shape  = torch.Size([3, 10, 512])
	packed_sequence_input.data.shape  = torch.Size([23, 512])
	packed_sequence_input.batch_sizes = tensor([3, 3, 3, 3, 3, 3, 2, 1, 1, 1])
	self.stateful is True
	stateful is True,
		Conduct `_get_initial_states`
	RUN `_lstm_forward`... by initial_states
		initial_state is None? False
		initial is not None and it's size equal to forward_layers' length,
		then hidden_states is
		 A = initial_state[0].split(1, 0) = (tensor([[[0., 0., 0.,  ..., -0., 0., 0.],
         [0., 0., 0.,  ..., -0., 0., 0.],
  

In [12]:
encoder_output.shape

torch.Size([2, 3, 10, 1024])