In [9]:
import json
import os
import utils

model_dir = "experiments/"
json_path = os.path.join(model_dir, 'params.json')
params = utils.Params(json_path)
params.device = torch.device(params.device)

In [12]:
from model.data_loader import DataLoader

data_dir = "data/full_version/"
data_loader = DataLoader(data_dir, params)
params.vocab_size = len(data_loader.BABYNAME.vocab)

vocab built


In [14]:
import model.charRNN as net
model = net.Net(params).to(params.device)

In [17]:
weight_path = os.path.join(model_dir, "best.pth.tar")
checkpoint = utils.load_checkpoint(weight_path, model)

In [54]:
import torch
import torch.nn.functional as F

In [69]:
def sample(net, prime="A", category="boy"):
    
    category_tensor = data_loader.SEX.process(['boy']).data.sub_(1).float()
    
    prime = prime.lower()
    prime_tensor = data_loader.BABYNAME.process([prime])[:, :-1]
    bsz, prime_tensor_length = prime_tensor.size()
    
    net.eval()
    hidden = net.init_hidden(1)
    
    for step in range(prime_tensor_length):
        outputs, hidden = net(category_tensor, prime_tensor[:, step], hidden)
        probabilities = F.softmax(outputs, 1)
    
    return probabilities.squeeze()
        
probabilities = sample(model, prime="A", category="boy")

In [79]:
def calcualte_top_k(probabilities, k):
    prob, idx = torch.topk(probabilities, k)
    chars = [data_loader.BABYNAME.vocab.itos[char] for char in idx.cpu().numpy()]
    return prob.detach().cpu().numpy(), chars

In [137]:
from collections import OrderedDict, defaultdict

In [95]:
def beam_search(net, prime="A", category="boy", beam_width=3):
    print("Sampling a {} name beginning with {}..".format(category, prime))
    
    initial_probabilities = sample(net, prime=prime, category=category)
    
    _prob_dict = {}
    prob, chars = calcualte_top_k(initial_probabilities, beam_width)
    for p, c in zip(prob, chars):
        _prob_dict[prime + c] = p
        
    _prob_dict_2 = {}
    for prime, prob in _prob_dict.items():
        probabilities = sample(net, prime=prime, category=category)
        prob, chars = calcualte_top_k(probabilities, probabilities.size(0))
        for p, c in zip(prob, chars):
            _prob_dict_2[prime + c] = p
    
    print(_prob_dict_2)
    
beam_search(model)

Sampling a boy name beginning with A..
{'All': 0.21784028, 'Ald': 0.09377982, 'Alf': 0.08585877, 'Ale': 0.06972965, 'Alv': 0.06927484, 'Alb': 0.067968026, 'Ala': 0.057595167, 'Alt': 0.046755683, 'Ali': 0.041659147, 'Alo': 0.030670848, 'Alp': 0.030295575, 'Alm': 0.02858818, 'Alu': 0.026979713, 'Als': 0.020680347, 'Alc': 0.01978152, 'Aly': 0.016061509, 'Alg': 0.014761034, 'Alw': 0.013700099, 'Alk': 0.0119596515, 'Alj': 0.008504599, 'Alh': 0.0076925475, 'Alr': 0.005483069, 'Alz': 0.0041428055, 'Aln': 0.0030471366, 'Alq': 0.0015253004, 'Al ': 0.0014234417, 'Al<eos>': 0.0010466484, 'Alx': 0.00090113824, "Al'": 0.0006199554, 'Al.': 0.00033923425, 'Al<unk>': 0.00033805674, 'Al<bos>': 0.00033805674, 'Al,': 0.00032149217, 'Al-': 0.0003200074, 'Al<pad>': 1.661644e-05, 'Arn': 0.1431779, 'Arl': 0.11548514, 'Art': 0.08677163, 'Ari': 0.07781561, 'Arr': 0.07774458, 'Ard': 0.07005276, 'Arm': 0.0570677, 'Arc': 0.054755755, 'Ars': 0.050334852, 'Arv': 0.04710375, 'Are': 0.03305963, 'Ara': 0.030229568, 'A

In [155]:
def sample(net, prime="A", category="boy", hidden=None):
    
    category_tensor = data_loader.SEX.process(['boy']).data.sub_(1).float()
    
    prime = prime.lower()
    prime_tensor = data_loader.BABYNAME.process([prime])[:, :-1]
    bsz, prime_tensor_length = prime_tensor.size()
    
    net.eval()
    if not hidden:
        hidden = net.init_hidden(1)
    
    for step in range(prime_tensor_length):
        with torch.no_grad():
            outputs, hidden = net(category_tensor, prime_tensor[:, step], hidden)
        probabilities = F.softmax(outputs, 1)
    
    return torch.log(probabilities.squeeze()), hidden
        
probabilities, prime_hidden = sample(model, prime="A", category="boy")

In [156]:
probabilities

tensor([ -8.8408, -11.2302,  -8.8408,  -6.5204,  -4.0978,  -4.2354,
         -2.1199,  -3.3469,  -1.4864,  -1.8328,  -5.2466,  -3.5051,
         -3.5879,  -4.2780,  -2.7338,  -3.8513,  -2.9054,  -3.7857,
         -3.6540,  -4.3559,  -2.7267,  -4.1226,  -5.3045,  -3.9120,
         -4.2732,  -4.9682,  -5.3083,  -4.7135,  -6.5695,  -5.8877,
         -8.3208,  -8.4416,  -8.3012,  -8.9544,  -8.9642])

In [157]:
def clean_beam_basket(basket, beam_width):
    _tmp_basket = basket.copy()
    to_remove = sorted(basket)[beam_width:]
    for item in to_remove:
        _tmp_basket.pop(item)
        
    return _tmp_basket

In [138]:
def beam_search(net, prime="A", category="boy", beam_width=3):
    print("Sampling a {} name beginning with {}..".format(category, prime))
    
#     initial_probabilties, prime_hidden = sample(net, prime=prime, category=category)
    
    beam_basket = OrderedDict()
    beam_basket[prime] = 0
    hidden_dict = defaultdict()
#     hidden_dict[prime] = prime_hidden
    
    
    counter = 0
    while True:
        counter += 1
        print("counter: {}".format(counter))
        
        # 바스켓을 청소한다.
        beam_basket = clean_beam_basket(beam_basket, beam_width)
        print(beam_basket)
        
        # 만약 바스켓에 모든 아이템이 <eos>가 있으면 루프를 멈춘다.
        eos_cnt = 0
        for k in beam_basket.keys():
            if "<eos>" in k:
                eos_cnt += 1
        if eos_cnt == beam_width:
            print("all items have <eos>")
            break
            
        # 모든 key를 돌면서
        ## <eos>가 없는 경우 inference를 한다.
        for k in beam_basket.keys():
            if "<eos>" not in k:
                hidden = hidden_dict.get(k)
                probabilities, hidden = sample(net, prime=k, category=category, hidden=hidden)
                for ix, prob in enumerate(probabilities):
                    new_k = k + data_loader.BABYNAME.vocab.itos[ix]
                    beam_basket[new_k] = beam_basket
                
            
        
        
        
        break
    
    
    
#     initial_probabilities = sample(net, prime=prime, category=category)
    
#     _prob_dict = {}
#     prob, chars = calcualte_top_k(initial_probabilities, beam_width)
#     for p, c in zip(prob, chars):
#         _prob_dict[prime + c] = p
        
#     _prob_dict_2 = {}
#     for prime, prob in _prob_dict.items():
#         probabilities = sample(net, prime=prime, category=category)
#         prob, chars = calcualte_top_k(probabilities, probabilities.size(0))
#         for p, c in zip(prob, chars):
#             _prob_dict_2[prime + c] = p
    
#     print(_prob_dict_2)
    
beam_search(model)

Sampling a boy name beginning with A..
counter: 1
OrderedDict([('A', 1.0)])


In [162]:
for ix, p in enumerate(probabilities):
    print(ix, p)

0 tensor(-8.8408)
1 tensor(-11.2302)
2 tensor(-8.8408)
3 tensor(-6.5204)
4 tensor(-4.0978)
5 tensor(-4.2354)
6 tensor(-2.1199)
7 tensor(-3.3469)
8 tensor(-1.4864)
9 tensor(-1.8328)
10 tensor(-5.2466)
11 tensor(-3.5051)
12 tensor(-3.5879)
13 tensor(-4.2780)
14 tensor(-2.7338)
15 tensor(-3.8513)
16 tensor(-2.9054)
17 tensor(-3.7857)
18 tensor(-3.6540)
19 tensor(-4.3559)
20 tensor(-2.7267)
21 tensor(-4.1226)
22 tensor(-5.3045)
23 tensor(-3.9120)
24 tensor(-4.2732)
25 tensor(-4.9682)
26 tensor(-5.3083)
27 tensor(-4.7135)
28 tensor(-6.5695)
29 tensor(-5.8877)
30 tensor(-8.3208)
31 tensor(-8.4416)
32 tensor(-8.3012)
33 tensor(-8.9544)
34 tensor(-8.9642)


In [89]:
torch.topk(probabilities, probabilities.size(0))

(tensor([ 0.2262,  0.1600,  0.1200,  0.0654,  0.0650,  0.0547,  0.0352,
          0.0300,  0.0277,  0.0259,  0.0227,  0.0213,  0.0200,  0.0166,
          0.0162,  0.0145,  0.0139,  0.0139,  0.0128,  0.0090,  0.0070,
          0.0053,  0.0050,  0.0050,  0.0028,  0.0015,  0.0014,  0.0002,
          0.0002,  0.0002,  0.0001,  0.0001,  0.0001,  0.0001,  0.0000]),
 tensor([  8,   9,   6,  20,  14,  16,   7,  11,  12,  18,  17,  15,
          23,   4,  21,   5,  24,  13,  19,  27,  25,  10,  22,  26,
          29,   3,  28,  32,  30,  31,   2,   0,  33,  34,   1]))

In [88]:
probabilities.size()

torch.Size([35])

In [97]:
beam_basket = OrderedDict()

In [139]:
a = defaultdict()

In [151]:
print(a.get('a'))

None


In [104]:
beam_basket = OrderedDict()
beam_basket['A'] = 1.0
beam_basket['B'] = 0.8
beam_basket['C'] = 0.7
beam_basket['D'] = 0.4

In [160]:
probabilities.sort(descending=True)

(tensor([ -1.4864,  -1.8328,  -2.1199,  -2.7267,  -2.7338,  -2.9054,
          -3.3469,  -3.5051,  -3.5879,  -3.6540,  -3.7857,  -3.8513,
          -3.9120,  -4.0978,  -4.1226,  -4.2354,  -4.2732,  -4.2780,
          -4.3559,  -4.7135,  -4.9682,  -5.2466,  -5.3045,  -5.3083,
          -5.8877,  -6.5204,  -6.5695,  -8.3012,  -8.3208,  -8.4416,
          -8.8408,  -8.8408,  -8.9544,  -8.9642, -11.2302]),
 tensor([  8,   9,   6,  20,  14,  16,   7,  11,  12,  18,  17,  15,
          23,   4,  21,   5,  24,  13,  19,  27,  25,  10,  22,  26,
          29,   3,  28,  32,  30,  31,   2,   0,  33,  34,   1]))

In [124]:
data_loader.BABYNAME.vocab.stoi

defaultdict(<function torchtext.vocab._default_unk_index()>,
            {'<unk>': 0,
             '<pad>': 1,
             '<bos>': 2,
             '<eos>': 3,
             'a': 4,
             'e': 5,
             'n': 6,
             'i': 7,
             'l': 8,
             'r': 9,
             'o': 10,
             's': 11,
             't': 12,
             'h': 13,
             'd': 14,
             'y': 15,
             'm': 16,
             'c': 17,
             'u': 18,
             'k': 19,
             'b': 20,
             'g': 21,
             'j': 22,
             'v': 23,
             'p': 24,
             'f': 25,
             'w': 26,
             'z': 27,
             'q': 28,
             'x': 29,
             ' ': 30,
             '-': 31,
             "'": 32,
             '.': 33,
             ',': 34,
             'A': 0,
             '<': 0,
             '>': 0})