# 텍스트 생성
Winnie the Pooh 텍스트 파일을 이용한 학습  
https://machinelearningmastery.com/text-generation-with-lstm-in-pytorch/

In [1]:
import numpy as np
import os 
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [3]:
# Read the whole book and lower-case it (lower-casing shrinks the character vocabulary).
file_name = './Winnie-the-Pooh.txt'
# Use a context manager so the file handle is closed once the text is read.
with open(file_name, 'r', encoding='utf-8') as f:
    raw_text = f.read()
lower_text = raw_text.lower()

# chars -> integer mapping: every distinct character gets an id (sorted for a stable order).
chars = sorted(set(lower_text))
chars_to_int = {c: i for i, c in enumerate(chars)}
print(chars_to_int)  # the character vocabulary built from the unique characters in the file

n_chars = len(lower_text)
n_vocab = len(chars)
print("Total Characters: ", n_chars)
print("Total Vocab: ", n_vocab)

{'\n': 0, ' ': 1, '!': 2, '"': 3, '#': 4, '$': 5, '%': 6, '&': 7, "'": 8, '(': 9, ')': 10, '*': 11, ',': 12, '-': 13, '.': 14, '/': 15, '0': 16, '1': 17, '2': 18, '3': 19, '4': 20, '5': 21, '6': 22, '7': 23, '8': 24, '9': 25, ':': 26, ';': 27, '?': 28, '[': 29, ']': 30, '_': 31, 'a': 32, 'b': 33, 'c': 34, 'd': 35, 'e': 36, 'f': 37, 'g': 38, 'h': 39, 'i': 40, 'j': 41, 'k': 42, 'l': 43, 'm': 44, 'n': 45, 'o': 46, 'p': 47, 'q': 48, 'r': 49, 's': 50, 't': 51, 'u': 52, 'v': 53, 'w': 54, 'x': 55, 'y': 56, 'z': 57, 'æ': 58, '—': 59, '‘': 60, '’': 61, '“': 62, '”': 63, '•': 64, '™': 65}
Total Characters:  148066
Total Vocab:  66


In [4]:
seq_len = 100  # each sample is a 101-char window: 100 input chars followed by 1 target char
# Encode the text once; the original looked every character up ~seq_len times.
encoded = [chars_to_int[c] for c in lower_text]
# Sliding window with stride 1 over the encoded text.
X_data = [encoded[i:i + seq_len] for i in range(n_chars - seq_len)]
Y_data = [encoded[i + seq_len] for i in range(n_chars - seq_len)]
n_patterns = len(X_data)
print("Total Patterns: ", n_patterns)

Total Patterns:  147966


In [5]:
# Pack the integer windows into a float tensor shaped (n_patterns, seq_len, 1):
# one scalar feature per time step.
X = torch.tensor(X_data, dtype=torch.float32)
X = X.reshape(n_patterns, seq_len, 1)
# Scale the ids by the vocabulary size so every input lies in [0, 1).
X = X / float(n_vocab)
# Targets remain integer class ids (consumed later by CrossEntropyLoss).
y = torch.tensor(Y_data)
print(X.shape, y.shape)

torch.Size([147966, 100, 1]) torch.Size([147966])


In [8]:
class BuildModel(nn.Module):
    """Single-layer character LSTM that predicts logits for the next character.

    Expects a float tensor of shape (batch, seq_len, 1) (``batch_first=True``,
    ``input_size=1``) and returns logits of shape (batch, vocab_size) computed
    from the hidden state of the last time step.

    Args:
        vocab_size: number of output classes; defaults to the notebook-level
            ``n_vocab`` so ``BuildModel()`` behaves exactly as before.
        hidden_size: LSTM hidden size (default 256).
        dropout: dropout probability applied before the output layer (default 0.2).
    """

    def __init__(self, vocab_size=None, hidden_size=256, dropout=0.2):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible fallback to the global vocabulary size.
            vocab_size = n_vocab
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        # Only the last time step's output is used to predict the next character.
        x = x[:, -1, :]
        x = self.linear(self.dropout(x))
        return x

In [9]:
epochs = 20
batch_size = 128
model = BuildModel()

optimizer = optim.Adam(model.parameters())
# Summed loss: the per-epoch total printed below is a sum over all patterns, not a mean.
loss_fn = nn.CrossEntropyLoss(reduction="sum")
loader = data.DataLoader(data.TensorDataset(X, y), shuffle=True, batch_size=batch_size)

best_model = None
best_loss = np.inf
for epoch in range(epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluation pass (dropout off, no gradients). There is no held-out split,
    # so this measures the loss on the training data itself.
    model.eval()
    loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            y_pred = model(X_batch)
            loss += loss_fn(y_pred, y_batch).item()
    if loss < best_loss:  # keep a snapshot whenever the loss improves
        best_loss = loss
        # state_dict() returns references to the live parameters; clone them so
        # later optimizer steps do not overwrite the best-epoch snapshot.
        best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
    print('Epoch %d: Cross-entrophy: %.4f' % (epoch, loss))

# torch.save does not create missing directories.
os.makedirs("./model_checkpoints", exist_ok=True)
torch.save([best_model, chars_to_int], "./model_checkpoints/text_generator.pth")

  from .autonotebook import tqdm as notebook_tqdm


Epoch 0: Cross-entrophy: 422362.9688
Epoch 1: Cross-entrophy: 405260.9375
Epoch 2: Cross-entrophy: 389037.4688
Epoch 3: Cross-entrophy: 376983.3438
Epoch 4: Cross-entrophy: 367182.5312
Epoch 5: Cross-entrophy: 358907.2500
Epoch 6: Cross-entrophy: 351417.1875
Epoch 7: Cross-entrophy: 341741.1562
Epoch 8: Cross-entrophy: 334116.6875
Epoch 9: Cross-entrophy: 328914.7500
Epoch 10: Cross-entrophy: 321809.8125
Epoch 11: Cross-entrophy: 317064.7812
Epoch 12: Cross-entrophy: 308550.7812
Epoch 13: Cross-entrophy: 303386.5312
Epoch 14: Cross-entrophy: 297692.3125
Epoch 15: Cross-entrophy: 292292.9375
Epoch 16: Cross-entrophy: 288756.2500
Epoch 17: Cross-entrophy: 282903.8750
Epoch 18: Cross-entrophy: 279991.3750
Epoch 19: Cross-entrophy: 274500.4062


## Text Generation Test

In [102]:
best_model, chars_to_int = torch.load("./model_checkpoints/text_generator.pth")
# Apply the saved best-epoch weights; otherwise generation would use whatever
# weights the in-memory model currently holds.
model.load_state_dict(best_model)
n_vocab = len(chars_to_int)
int_to_chars = {i: c for c, i in chars_to_int.items()}

# Pick a random seq_len-character window of the text as the prompt. Bound the start
# by lower_text (the string being sliced); lower() may change the length vs raw_text.
start = np.random.randint(0, len(lower_text) - seq_len)
prompt = lower_text[start:start + seq_len]
pattern = [chars_to_int[c] for c in prompt]

model.eval()
print("Prompt: \n %s  _____" % prompt)
with torch.no_grad():
    for _ in range(1000):  # generate 1000 characters, one per step
        # Encode the current window exactly like the training inputs.
        x = np.reshape(pattern, (1, len(pattern), 1)) / float(n_vocab)
        x = torch.tensor(x, dtype=torch.float32)

        # Logits over the vocabulary for the next character.
        prediction = model(x)

        # Greedy decoding: take the most likely id and convert it back to a character.
        index = int(prediction.argmax())
        result = int_to_chars[index]
        print(result, end="")

        # Slide the window: append the new character, drop the oldest.
        pattern.append(index)
        pattern = pattern[1:]
print()
print("__Fin__")

Prompt: 
 balloon?"

"yes, i just said to myself coming along: 'i wonder if christopher robin
has such a thing  _____
 and then the boll hs toe tore 

"ioo the tooe   "a larpy hirteey," 
"yhs, yhu kave ho in " 
"the oor toe lo the hertoon " 
"yhs, yhu iave g vorl ho in an hnnettor b aoa oo ho so bo the bootoe " 
"the oere po he poe"thre " soo tere airistopher robin sas shit the borto tf the pore of the sore of the horest, 
"io whu den tooh " 
"yhet so he soo to lave a lort of toeek to an toeetllng "

"that so he soo tool " 
"yhs  bedrueey _ dare if io " said thbnit  "io i soele to toell to soeethin " 
"yhut ao the btruon oo the tore " said pooh. 
"th toal i yout oo tee horert?" 
"yes, 

"i shsught to he c derroon " said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee woul   said pooh.

"io sou tee wou