In [1]:
from operator import itemgetter

import torch
import torch.nn as nn



In [2]:
# Search settings shared by the scratch cells below.
beam_size, max_length = 5, 50
device = 'cpu'  # every scratch tensor is moved to this device

In [16]:
# Scratch replica of SingleBeamSearchBoard.__init__, printed for inspection.
# Step-0 histories: BOS (1) in every beam, no parent beam (-1), log-prob 0 for
# beam 0 and -inf for the rest, and no beam finished yet.
word_indice = [torch.ones(beam_size, dtype=torch.long).to(device)]
beam_indice = [torch.full((beam_size,), -1, dtype=torch.long).to(device)]
cumulative_probs = [torch.tensor([.0] + [-float('inf')] * (beam_size - 1)).to(device)]
masks = [torch.zeros(beam_size, dtype=torch.bool).to(device)]

print(f'word_indice = {word_indice}')
print(f'beam_indice = {beam_indice}')
print(f'cumulative_probs = {cumulative_probs}')
print(f'masks = {masks}')

print('--------------------------------------------')

prev_status_config = {}
# State that corresponds to the decoder input.
prev_status_config['prev_state_0'] = {'init_status': None, 'batch_dim_index': 0}
# State after passing the first block.
prev_status_config['prev_state_1'] = {'init_status': None, 'batch_dim_index': 0}

prev_status, batch_dims = {}, {}
for prev_status_name, each_config in prev_status_config.items():
    print(prev_status_name, each_config)
    init_status = each_config['init_status']
    batch_dim_index = each_config['batch_dim_index']
    batch_dims[prev_status_name] = batch_dim_index
    # A real initial status is repeated once per beam along its batch dimension.
    prev_status[prev_status_name] = (
        None if init_status is None
        else torch.cat([init_status] * beam_size, dim=batch_dim_index)
    )

current_time_step = done_cnt = 0



word_indice = [tensor([1, 1, 1, 1, 1])]
beam_indice = [tensor([-1, -1, -1, -1, -1])]
cumulative_probs = [tensor([0., -inf, -inf, -inf, -inf])]
masks = [tensor([False, False, False, False, False])]
--------------------------------------------
prev_state_0 {'init_status': None, 'batch_dim_index': 0}
prev_state_1 {'init_status': None, 'batch_dim_index': 0}


In [19]:
# Last time-step's word indices as a column, the same reshaping that
# SingleBeamSearchBoard.get_batch performs.
y_hat = word_indice[-1].unsqueeze(-1)

In [21]:
# (beam_size, 1): one previous token per beam.
y_hat.shape

torch.Size([5, 1])

In [22]:
# Every entry is None because no init_status was configured.
prev_status

{'prev_state_0': None, 'prev_state_1': None}

In [23]:
# Two fake (beam_size, 1) inputs stacked along dim 0, the way the search loop
# builds its fake mini-batch from several boards.
fab_input = [
    torch.tensor([[0], [0], [0], [0], [0]]),
    torch.tensor([[1], [0], [0], [0], [0]]),
]
torch.cat(fab_input, dim = 0)

tensor([[0],
        [0],
        [0],
        [0],
        [0],
        [1],
        [0],
        [0],
        [0],
        [0]])

In [113]:
from operator import itemgetter

import torch
import torch.nn as nn

# import simple_nmt.data_loader as data_loader

# Default exponent (alpha) and reference length used by get_length_penalty.
LENGTH_PENALTY = .2
MIN_LENGTH = 5


class SingleBeamSearchBoard():
    '''
    Beam-search book-keeping for a single sample.

    For every time-step the board records the chosen word index, the beam it
    was extended from, the cumulative log-probability and an EOS mask. Of the
    decoder status tensors (e.g. previous input, hidden/cell states, h_tilde,
    or per-layer outputs for a transformer) only the latest values are kept;
    get_batch() hands them out as a fake mini-batch of `beam_size` rows.
    '''
    def __init__(
        self,
        device,
        prev_status_config,
        beam_size=5,
        max_length=255,
        bos=1,
        eos=2,
    ):
        '''
        Args:
            device: device on which every history tensor is created.
            prev_status_config: dict of dicts keyed by status name. Each value
                holds 'init_status' (a single-sample tensor, or None) and
                'batch_dim_index' (the dimension that acts as the batch).
                A non-None init_status is repeated `beam_size` times along it.
            beam_size: number of hypotheses kept at every time-step.
            max_length: stored on the board; not used by the board itself.
            bos: word index placed in every beam at time-step 0.
            eos: word index that marks a beam as finished.
        '''
        self.beam_size = beam_size
        self.max_length = max_length
        self.eos = eos

        # To put data to same device.
        self.device = device

        # Word index chosen at each time-step; step 0 is BOS in every beam.
        self.word_indice = [
            torch.full((beam_size,), bos, dtype=torch.long, device=self.device)
        ]
        # Beam each word was extended from; -1 at step 0 (no parent).
        self.beam_indice = [
            torch.full((beam_size,), -1, dtype=torch.long, device=self.device)
        ]
        # Cumulative log-probability per beam. BOS is certain (log 1 = 0) and
        # only beam 0 is live before the first split, so the other beams start
        # at -inf and every first-step candidate comes from beam 0.
        self.cumulative_probs = [
            torch.tensor([.0] + [-float('inf')] * (beam_size - 1),
                         dtype=torch.float, device=self.device)
        ]
        # True where the beam emitted EOS at that time-step.
        self.masks = [torch.zeros(beam_size, dtype=torch.bool, device=self.device)]

        # Only the latest status is kept, not every time-step.
        self.prev_status = {}
        # Which dimension of each status tensor indexes the beams.
        self.batch_dims = {}
        for prev_status_name, each_config in prev_status_config.items():
            init_status = each_config['init_status']
            batch_dim_index = each_config['batch_dim_index']
            if init_status is not None:
                # Repeat the single-sample status once per beam.
                self.prev_status[prev_status_name] = torch.cat(
                    [init_status] * beam_size, dim=batch_dim_index)
            else:
                self.prev_status[prev_status_name] = None
            self.batch_dims[prev_status_name] = batch_dim_index

        self.current_time_step = 0
        # Number of (time-step, beam) pairs that produced EOS so far.
        self.done_cnt = 0

    def get_length_penalty(
        self,
        length,
        alpha=LENGTH_PENALTY,
        min_length=MIN_LENGTH,
    ):
        '''
        Return ((min_length + 1) / (min_length + length)) ** alpha.

        The factor is 1 at length 1 and shrinks as length grows. Multiplying a
        (negative) log-probability by it therefore offsets the natural bias of
        summed log-probabilities towards shorter sentences.
        '''
        return ((min_length + 1) / (min_length + length)) ** alpha

    def is_done(self):
        '''
        Return 1 once EOS has been produced at least `beam_size` times, else 0.

        done_cnt is updated by collect_result().
        '''
        return 1 if self.done_cnt >= self.beam_size else 0

    def get_batch(self):
        '''
        Return the inputs for this board's next decoding step.

        Returns:
            y_hat: (beam_size, 1) word indices chosen at the last time-step.
            prev_status: dict of the latest status tensors (or None), each with
                beam_size entries along its batch dimension.
                if model != transformer:
                    |hidden| = |cell| = (n_layers, beam_size, hidden_size)
                    |h_t_tilde| = (beam_size, 1, hidden_size) or None
                else:
                    |prev_state_i| = (beam_size, length, hidden_size),
                    where i is an index of layer.
        '''
        y_hat = self.word_indice[-1].unsqueeze(-1)
        return y_hat, self.prev_status

    def collect_result(self, y_hat, prev_status):
        '''
        Extend every beam with one decoding step and keep the best `beam_size`.

        Args:
            y_hat: (beam_size, 1, output_size) log-probabilities of this step.
            prev_status: dict mapping status names to the new status tensors of
                this board's beams; each is re-ordered along its batch dimension
                so that it follows the surviving beams.
        '''
        output_size = y_hat.size(-1)

        self.current_time_step += 1

        # Finished beams must not be extended, so their score is -inf for this
        # step only. masked_fill (not masked_fill_) leaves the stored history
        # untouched: get_n_best reads each beam's score at the step where it
        # produced EOS.
        cumulative_prob = self.cumulative_probs[-1].masked_fill(self.masks[-1], -float('inf'))
        # (beam_size,) -> (beam_size, 1, 1) -> (beam_size, 1, output_size)
        cumulative_prob = y_hat + cumulative_prob.view(-1, 1, 1).expand(self.beam_size, 1, output_size)

        # Candidates from all beams compete: take the top-k of the flattened
        # (beam_size * output_size,) scores. torch.sort is kept instead of
        # torch.topk because the original author found it faster here.
        top_log_prob, top_indice = cumulative_prob.view(-1).sort(descending=True)
        # clone() so the history does not keep the whole sorted buffer alive.
        top_log_prob = top_log_prob[:self.beam_size].clone()
        top_indice = top_indice[:self.beam_size]
        # |top_log_prob| = |top_indice| = (beam_size,)

        # Flat index = beam * output_size + word.
        self.word_indice += [top_indice.fmod(output_size)]
        # Exact integer floor division (float division can round up for large
        # vocabularies).
        self.beam_indice += [torch.div(top_indice, output_size, rounding_mode='floor')]

        self.cumulative_probs += [top_log_prob]
        # A beam is finished once it emits EOS.
        self.masks += [torch.eq(self.word_indice[-1], self.eos)]
        # Kept as a plain int so is_done() compares Python numbers.
        self.done_cnt += int(self.masks[-1].sum().item())

        # Only the latest status is needed. The batch dimension differs between
        # status tensors, hence self.batch_dims. For a transformer the status is
        # each layer's output from the beginning, so it must follow the beams too.
        for prev_status_name, status in prev_status.items():
            self.prev_status[prev_status_name] = torch.index_select(
                status,
                dim=self.batch_dims[prev_status_name],
                index=self.beam_indice[-1],
            ).contiguous()

    def get_n_best(self, n=1, length_penalty=.2):
        '''
        Return the `n` best hypotheses with their length-penalized scores.

        Candidates are (a) every beam that emitted EOS, scored at the step it
        did so, and (b) every beam still finite at the last time-step and not
        already counted under (a). Case (b) covers hypotheses that reached the
        length limit without EOS. Both use the length penalty of their end step.

        Returns:
            sentences: up to `n` lists of word-index tensors (BOS excluded; EOS
                included when it was produced), best first.
            probs: matching penalized log-probabilities.
        '''
        sentences, probs, founds = [], [], []
        last_t = len(self.cumulative_probs) - 1

        # (a) beams that produced EOS at time-step t.
        for t in range(len(self.word_indice)):
            for b in range(self.beam_size):
                if self.masks[t][b]:
                    probs += [self.cumulative_probs[t][b] * self.get_length_penalty(t, alpha=length_penalty)]
                    founds += [(t, b)]

        # (b) beams alive at the last step without EOS.
        for b in range(self.beam_size):
            if self.cumulative_probs[-1][b] != -float('inf') and (last_t, b) not in founds:
                probs += [self.cumulative_probs[-1][b] * self.get_length_penalty(last_t, alpha=length_penalty)]
                founds += [(last_t, b)]

        # Sort and take n-best.
        sorted_founds_with_probs = sorted(
            zip(founds, probs),
            key=itemgetter(1),
            reverse=True,
        )[:n]
        probs = []

        for (end_index, b), prob in sorted_founds_with_probs:
            sentence = []
            # Trace back through the beam pointers, from the end to step 1.
            for t in range(end_index, 0, -1):
                sentence = [self.word_indice[t][b]] + sentence
                b = self.beam_indice[t][b]

            sentences += [sentence]
            probs += [prob]

        return sentences, probs


In [123]:



def _generate_mask( x, length):
    '''

    에) length : [4,3,2]
    '''
    mask = []

    max_length = max(length)
    for l in length:
        if max_length - l > 0:
            # If the length is shorter than maximum length among samples,
            # set last few values to be 1s to remove attention weight.
            mask += [torch.cat([x.new_ones(1, l).zero_(),
                                x.new_ones(1, (max_length - l))
                                ], dim=-1)]
        else:
            # If length of sample equals to maximum length among samples,
            # set every value in mask to be 0.
            mask += [x.new_ones(1, l).zero_()]

    mask = torch.cat(mask, dim=0).bool()
    # |mask| = (batch_size, max_length)

    return mask



# Fake encoder inputs/outputs for driving the beam-search boards.
# Seeded so that Restart & Run All reproduces the same demo.
torch.manual_seed(0)

batch_size = 64
seq_len = 20
hidden_size = 100
n_dec_layers = 10

x = torch.randn(batch_size, seq_len, hidden_size)
# Lengths in [1, seq_len] with the longest equal to seq_len, as in a real padded
# mini-batch. torch.randint(20, ...) could yield 0-length samples and a mask
# narrower than x.size(1).
src_lengths = torch.randint(1, seq_len + 1, (batch_size,))
src_lengths[0] = seq_len

decoder = nn.Sequential(*[nn.Linear(hidden_size, hidden_size) for _ in range(n_dec_layers)])

mask = _generate_mask(x, src_lengths)
# |mask| = (batch_size, n)

mask_enc = mask.unsqueeze(1).expand(mask.size(0), x.size(1), mask.size(-1))
mask_dec = mask.unsqueeze(1)
# |mask_enc| = (batch_size, n, n)
# |mask_dec| = (batch_size, 1, n)

z = torch.randn(batch_size, seq_len, hidden_size)
# |z| = (batch_size, n, hidden_size)
# -------- up to here this mirrors the setup of the search function --------


# One status slot for the decoder input plus one per decoder layer, none of
# them initialised, all with the batch on dimension 0.
prev_status_config = dict()
for layer_index in range(n_dec_layers + 1):
    prev_status_config['prev_state_%d' % layer_index] = dict(init_status=None, batch_dim_index=0)

# One independent board per sample; all boards share the same config dict.
boards = [SingleBeamSearchBoard(z.device, prev_status_config,
                                beam_size=beam_size, max_length=max_length)
          for _ in range(batch_size)]
done_cnt = list(map(SingleBeamSearchBoard.is_done, boards))

length = 0

# Demo search loop: random tensors stand in for the real decoder and generator.
demo_step_limit = 20   # stop the demo after this many steps (was a hard-coded `break`)
demo_vocab_size = 100  # size of the fake generator output

with torch.no_grad():  # inference only; avoids growing an autograd graph across steps
    while sum(done_cnt) < batch_size and length <= max_length and length < demo_step_limit:
        # Fake mini-batch: beam_size rows for every board that is still searching.
        fab_input, fab_z, fab_mask = [], [], []
        fab_prevs = [[] for _ in range(n_dec_layers + 1)]

        for i, board in enumerate(boards):
            if board.is_done() == 0:
                y_hat_i, prev_status = board.get_batch()

                fab_input += [y_hat_i]
                fab_z += [z[i].unsqueeze(0)] * beam_size
                fab_mask += [mask_dec[i].unsqueeze(0)] * beam_size

                for layer_index in range(n_dec_layers + 1):
                    prev_i = prev_status['prev_state_%d' % layer_index]
                    if prev_i is not None:
                        fab_prevs[layer_index] += [prev_i]
                    else:
                        # No history yet (first step): mark the whole slot as empty.
                        fab_prevs[layer_index] = None

        fab_input = torch.cat(fab_input, dim=0)
        fab_z = torch.cat(fab_z, dim=0)
        fab_mask = torch.cat(fab_mask, dim=0)
        for i, fab_prev in enumerate(fab_prevs):  # i == layer_index
            if fab_prev is not None:
                fab_prevs[i] = torch.cat(fab_prev, dim=0)

        # Rows in the fake mini-batch. It shrinks as boards finish, so it must
        # not be hard-coded to batch_size * beam_size (320 at the start).
        n_active = fab_input.size(0)
        d_model = z.size(-1)

        # Stand-in for the embedded current token.
        h_t = torch.randn(n_active, 1, d_model)

        # prev_state_0 accumulates the decoder inputs along the time dimension.
        if fab_prevs[0] is None:
            fab_prevs[0] = h_t
        else:
            fab_prevs[0] = torch.cat([fab_prevs[0], h_t], dim=1)

        for layer_index, block in enumerate(decoder):
            # NOTE(review): a real decoder block would presumably attend over
            # fab_prevs[layer_index]; this Linear stand-in only maps h_t.
            h_t = block(h_t)

            # prev_state_{k+1} accumulates the output of block k.
            if fab_prevs[layer_index + 1] is None:
                fab_prevs[layer_index + 1] = h_t
            else:
                fab_prevs[layer_index + 1] = torch.cat([fab_prevs[layer_index + 1], h_t], dim=1)

        # y_hat_t = generator(h_t)
        # Fake generator output, normalised to log-probabilities because the
        # boards add y_hat to cumulative log-probabilities.
        y_hat_t = torch.log_softmax(torch.randn(n_active, 1, demo_vocab_size), dim=-1)

        # Hand each unfinished board its own beam_size rows, in the same order
        # they were gathered above.
        cnt = 0
        for board in boards:
            if board.is_done() == 0:
                begin = cnt * beam_size
                end = begin + beam_size

                prev_status = {
                    'prev_state_%d' % layer_index: fab_prevs[layer_index][begin:end]
                    for layer_index in range(n_dec_layers + 1)
                }
                board.collect_result(y_hat_t[begin:end], prev_status)

                cnt += 1

        done_cnt = [board.is_done() for board in boards]
        length += 1
    
    
    

In [134]:
# Input history gathered by the demo loop: (active rows, steps, hidden size).
# The recorded (320, 20, 100) is 64 boards x 5 beams after 20 steps.
fab_prevs[0].shape

torch.Size([320, 20, 100])

In [125]:
# `board` is the last board left over from the previous cell's `for board in boards` loop.
board.is_done()

0

In [135]:
# Collect three (64, 10, 100) tensors in one list slot, presumably mirroring how
# the search loop fills fab_prevs[layer_index] with `+= [prev_i]`.
a = [[]]

for i in range(3):
    # Wrap the tensor in a list: `list += tensor` would extend the list with the
    # tensor's 64 first-dimension slices instead of appending one tensor.
    a[0] += [torch.randn(64, 10, 100)]

In [131]:
# One board's slice of the input history: (beam_size, steps, hidden size).
board.get_batch()[1]['prev_state_0'].shape

torch.Size([5, 20, 100])

In [142]:
# LayerNorm over the last dimension (size 100) of a fake (320, 10, 100) batch.
# NOTE(review): displaying the whole tensor floods the output; `.shape` is
# usually enough here.
aa = nn.LayerNorm(100)
aa(torch.randn(320,10,100))

tensor([[[ 0.8279, -0.2189, -0.3758,  ..., -1.0492, -0.2075, -0.1194],
         [ 0.5449, -0.4042,  0.7085,  ..., -1.5870,  0.2523,  1.5955],
         [-1.2392,  0.2312, -2.3776,  ..., -0.4245,  1.4392, -0.0119],
         ...,
         [-0.5817, -1.4574, -0.9687,  ...,  0.8331,  0.0382,  1.8559],
         [ 0.2154, -2.1307, -0.4138,  ..., -0.6675, -0.9962, -0.7747],
         [-0.8420,  0.5955, -0.6288,  ..., -1.0931, -0.5311, -1.2907]],

        [[ 1.1565, -1.1198, -0.9311,  ...,  0.8487,  1.5364, -1.0445],
         [ 0.8156, -1.6106,  0.7313,  ...,  0.4643,  1.3690,  1.3075],
         [ 0.6905,  0.9609,  2.2514,  ..., -0.8298,  0.1296, -0.5510],
         ...,
         [-0.3318,  1.1115,  0.2640,  ...,  1.2754,  0.8217, -0.0123],
         [-0.0852,  0.4577,  0.8864,  ..., -0.7627, -0.3184, -0.9649],
         [-0.1376,  0.8702, -0.0674,  ..., -0.0086,  0.0292,  1.4140]],

        [[-1.4239,  1.2559, -0.0205,  ..., -1.4270,  1.9788, -0.6371],
         [-0.1323, -0.7480,  1.4446,  ...,  0

In [83]:
# NOTE(review): the recorded (320, 20, 100) comes from In [83], before the demo
# loop; that loop feeds h_t of shape (rows, 1, 100), so a fresh run likely differs.
h_t.shape

torch.Size([320, 20, 100])

In [67]:
# (beam_size * active boards, n, hidden): each sample's encoder output repeated once per beam.
fab_z.shape

torch.Size([320, 20, 100])

In [73]:
# Sample 0's encoder output repeated 5 times (once per beam) along the batch dimension.
torch.cat([z[0].unsqueeze(0)]*5, dim = 0).shape

torch.Size([5, 20, 100])

In [174]:
# Initial cumulative log-probabilities: 0 for beam 0, -inf for every other beam.
cumulative_probs = [torch.tensor([0.0] + [float('-inf')] * (beam_size - 1))]
cumulative_probs

[tensor([0., -inf, -inf, -inf, -inf])]

In [184]:
# Replay of the first part of collect_result with a fake step output.
output_size = 10000
y_hat = torch.randn(beam_size, 1, output_size)
masks = [torch.zeros(beam_size, dtype=torch.bool)]

# Out-of-place masked_fill: masked_fill_ would also overwrite the stored
# history cumulative_probs[-1] for every finished beam.
cumulative_prob = cumulative_probs[-1].masked_fill(masks[-1], -float('inf'))
cumulative_prob

tensor([0., -inf, -inf, -inf, -inf])

In [185]:
# Add each beam's cumulative log-prob to its step scores. The expand is optional:
# broadcasting (5, 1, 1) against (5, 1, output_size) gives the same result.
cumulative_prob = y_hat + cumulative_prob.view(-1,1,1).expand(5,1,output_size)
cumulative_prob.shape

torch.Size([5, 1, 10000])

In [188]:
# Only beam 0's row is finite; rows 1-4 inherit -inf from the initial scores.
cumulative_prob

tensor([[[-1.5879,  0.7067,  0.7412,  ...,  0.6248, -0.5969, -0.9310]],

        [[   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf]],

        [[   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf]],

        [[   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf]],

        [[   -inf,    -inf,    -inf,  ...,    -inf,    -inf,    -inf]]])

In [189]:
# Flatten to (beam_size * output_size,) so candidates from all beams compete, then sort descending.
top_log_prob, top_indice = cumulative_prob.view(-1).sort(descending=True)


In [190]:
# Sorted scores; the -inf tail comes from the dead beams 1-4.
top_log_prob

tensor([3.6192, 3.3950, 3.3080,  ...,   -inf,   -inf,   -inf])

In [191]:
# Flat indices = beam * output_size + word (max 5 * 10000 - 1 = 49999).
top_indice

tensor([ 2534,  7517,  9918,  ..., 49997, 49998, 49999])

In [180]:
# NOTE(review): recorded output is from In [180], before In [185] reshaped
# cumulative_prob to (5, 1, output_size); a fresh run shows 5 * output_size values.
cumulative_prob.view(-1)

tensor([0., -inf, -inf, -inf, -inf])

In [181]:
# NOTE(review): recorded (5,) is from In [181]; after In [185] this is (5, 1, output_size).
cumulative_prob.shape

torch.Size([5])

In [192]:
# Repeated indices duplicate rows: this is how collect_result copies one
# surviving parent beam's status into several child beams.
# NOTE(review): this rebinds the global `x` used by the demo setup.
x = torch.randn(3, 4)
indices = torch.tensor([0, 0])
x.index_select(0, indices)

tensor([[0.9862, 1.3751, 0.0278, 0.4193],
        [0.9862, 1.3751, 0.0278, 0.4193]])