## 데이터 전처리 부분은 책 코드 사용

In [None]:
import numpy as np
import torch
from collections import Counter
import re
#https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html
#부분 참고

class textDataset(torch.utils.data.Dataset):
    """Word-level next-word dataset built from the Aesop's Fables text file.

    Each item is a window of ``seq_length`` word indices plus a one-hot
    (float64) vector marking the word that follows the window.

    Args:
        seq_length: number of context words per sample; also the number of
            ``'|'`` story-start markers inserted before each story.
        filename: path to the UTF-8 (optionally BOM-prefixed) source text.
    """

    def __init__(self, seq_length=20, filename="./data/aesop/data.txt"):
        self.filename = filename

        self.seq_length = seq_length
        self.words = self.load_words()
        self.uniq_words = self.get_uniq_words()

        # Vocabulary is ordered by frequency, so index 0 is the most common
        # word; no index is reserved for padding.
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indices = [self.word_to_index[w] for w in self.words]
        print(self.words_indices[:30])
        self.len_indices = len(self.index_to_word)
        self.len_text = len(self.words_indices)

    def load_words(self):
        """Read ``self.filename``, normalise the stories and return word tokens."""
        with open(self.filename, encoding='utf-8-sig') as f:
            text = f.read()
        # Remove the text before and after the main stories.
        start = text.find("THE FOX AND THE GRAPES\n\n\n")
        end = text.find("ILLUSTRATIONS\n\n\n[")
        text = text[start:end]

        # Prefix the text, and every story break (5 newlines), with
        # seq_length '|' markers so a window can start at a story boundary.
        start_story = '| ' * self.seq_length
        text = start_story + text
        text = text.lower()
        text = text.replace('\n\n\n\n\n', start_story)
        text = text.replace('\n', ' ')
        text = re.sub('  +', '. ', text).strip()
        text = text.replace('..', '.')

        # Surround every punctuation character with spaces so it becomes
        # its own token, then collapse whitespace runs.
        text = re.sub(r'([!"#$%&()*+,-./:;<=>?@[\]^_`{|}~])', r' \1 ', text)
        text = re.sub(r'\s{2,}', ' ', text)
        # The leading '|' was padded to ' | ', so drop the leading space.
        return text[1:].split(' ')

    def get_uniq_words(self):
        """Return the distinct words, most frequent first (ties keep first-seen order)."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        # Every sample needs seq_length context words plus one target word.
        return self.len_text - self.seq_length

    def __getitem__(self, index):
        """Return (context word indices, one-hot float64 label of the next word)."""
        y = self.words_indices[index + self.seq_length]
        # Build just the one row that is needed; np.eye(V)[y] materialised a
        # full V x V float64 matrix on every call.
        one_hot_label = torch.zeros(self.len_indices, dtype=torch.float64)
        one_hot_label[y] = 1.0
        return (
            torch.tensor(self.words_indices[index:index + self.seq_length]),
            one_hot_label,
        )

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import RMSprop,Adam

class textGenLSTM(nn.Module):
    """LSTM language model that predicts the next word from a window of word indices.

    ``forward`` returns softmax probabilities (used for sampling), while
    training feeds the raw logits to ``CrossEntropyLoss``. That loss already
    applies log-softmax internally, so it must not receive probabilities.

    Args:
        seq_length: length of the input windows (stored for reference).
        total_words: vocabulary size; sets the embedding rows and output width.
    """

    def __init__(self, seq_length=20, total_words=4169):
        super().__init__()

        self.seq_length = seq_length
        self.n_units = 256
        self.num_layers = 1  # kept as an attribute so stacked layers are easy to try
        embedding_size = 256

        self.embedding = nn.Embedding(
            num_embeddings=total_words, embedding_dim=embedding_size
        )

        self.lstm = nn.LSTM(
            input_size=embedding_size,
            hidden_size=self.n_units,
            num_layers=self.num_layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers; with a
            # single layer it has no effect and just emits a warning.
            dropout=0.2 if self.num_layers > 1 else 0.0,
        )

        # fc[0] is the projection and fc[1] the softmax. The Sequential layout
        # is kept so existing state_dict keys ("fc.0.*") still load.
        self.fc = nn.Sequential(
            nn.Linear(self.n_units, total_words),
            nn.Softmax(dim=-1),
        )

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = Adam(self.parameters(), lr=0.001)

    def _logits(self, x, prev_state):
        """Return (unnormalised next-word scores, LSTM state) for index batch ``x``."""
        emb = self.embedding(x)
        output, state = self.lstm(emb, prev_state)
        # Only the last time step's output predicts the next word.
        return self.fc[0](output[:, -1, :]), state

    def forward(self, x, prev_state):
        """Return (next-word probabilities of shape (batch, total_words), state).

        ``prev_state`` is an ``(h0, c0)`` tuple, or ``None`` to start from zeros.
        """
        logits, state = self._logits(x, prev_state)
        return self.fc[1](logits), state

    def init_state(self, batch_size=None):
        """Return zero ``(h0, c0)`` tensors of shape (num_layers, batch_size, n_units).

        ``batch_size`` defaults to the size of the most recent training batch,
        or 1 if the model has not been trained yet.
        """
        if batch_size is None:
            batch_size = getattr(self, "batch_size", 1)
        shape = (self.num_layers, batch_size, self.n_units)
        return torch.zeros(shape), torch.zeros(shape)

    def train(self, dataloader=True):
        """Run one training epoch over ``dataloader`` and return the last batch loss.

        Passing a bool keeps the ``nn.Module.train(mode)`` contract: it only
        toggles training mode. This is what ``eval()`` and ``train()`` rely on.

        Args:
            dataloader: iterable of ``(x, y)`` batches, where ``y`` holds class
                indices or one-hot/probability targets; or a bool mode flag.

        Returns:
            The loss tensor of the final batch (None if ``dataloader`` is
            empty), or ``self`` when called with a bool.
        """
        if isinstance(dataloader, bool):
            return super().train(dataloader)

        super().train(True)
        loss = None
        for step, (x, y) in enumerate(dataloader):
            self.batch_size = x.shape[0]
            # Batches are independent windows, so every batch starts from zeros.
            state = self.init_state()

            self.optimizer.zero_grad()
            logits, _ = self._logits(x, state)
            if y.is_floating_point():
                # Match the model's dtype (the dataset yields float64 one-hots).
                y = y.to(logits.dtype)
            loss = self.loss_fn(logits, y)

            loss.backward()
            self.optimizer.step()
            if step % 100 == 0:
                print(step, f"loss {loss.item():.4f}")
        return loss
        

In [None]:
seq_length = 20
my_dataset = textDataset(seq_length=seq_length)
# The vocabulary size fixes both the embedding table and the output layer width.
total_indexing_num = len(my_dataset.word_to_index)
# Pass the shared seq_length rather than repeating the literal 20, so the
# model and the dataset cannot drift apart.
model = textGenLSTM(seq_length=seq_length, total_words=total_indexing_num)

In [None]:
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader


batch_size = 32
num_epochs = 10

dataloader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)

for epoch in range(1, num_epochs + 1):
    # model.train runs one full epoch and returns the final batch's loss.
    loss = model.train(dataloader)

    # The batch counter inside model.train is local to that method; the old
    # `'iter': iter` printed the builtin function `iter`. Report the number
    # of batches per epoch instead.
    print({'epoch': epoch, 'iter': len(dataloader), 'loss': loss.item()})



In [None]:

def load_words(input_text, dataset):
    """Convert raw seed text into vocabulary indices, tokenised like textDataset.

    The text is prefixed with ``dataset.seq_length`` story-start markers
    (``'|'``) and lower-cased. Punctuation is then split into separate tokens
    with the same regex the dataset uses, so seeds such as ``"snake."`` and
    ``"snake . "`` encode identically.

    Args:
        input_text: seed text to encode.
        dataset: object exposing ``seq_length`` and ``word_to_index``.

    Returns:
        List of int word indices.

    Raises:
        KeyError: if a token is not in the training vocabulary.
    """
    text = ('| ' * dataset.seq_length + input_text).lower()
    text = re.sub(r'([!"#$%&()*+,-./:;<=>?@[\]^_`{|}~])', r' \1 ', text)
    # split() without an argument discards the empty strings that
    # split(' ') produced for leading/trailing/doubled spaces. The former
    # text[1:] also removed the first '|' character, not a space.
    indices = []
    for word in text.split():
        try:
            indices.append(dataset.word_to_index[word])
        except KeyError:
            raise KeyError(f"word {word!r} is not in the training vocabulary") from None
    return indices

def sample_with_temp(preds, temperature=1.0):
    """Sample an index from a probability vector after temperature scaling.

    Args:
        preds: 1-D array-like of non-negative probabilities.
        temperature: positive scale; values below 1 sharpen the distribution
            and values above 1 flatten it.

    Returns:
        The sampled index as an int.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    preds = np.asarray(preds).astype('float64')
    # log(0) = -inf is expected for impossible words; it maps back to 0 below.
    with np.errstate(divide='ignore'):
        scaled = np.log(preds) / temperature
    # Shift so the largest entry is 0 before exponentiating. Otherwise small
    # temperatures underflow every exp() to 0 and the normalisation is 0/0.
    scaled -= np.max(scaled)
    exp_preds = np.exp(scaled)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return int(np.argmax(probas))



def generate_text(seed_text, next_words, model, max_sequence_len=20, temp=0.2, dataset=my_dataset):
    """Generate up to ``next_words`` words continuing ``seed_text``.

    Generation stops early if the model samples the story-start marker ``'|'``.

    Args:
        seed_text: starting text; the returned string begins with it verbatim.
        next_words: maximum number of words to generate.
        model: model whose ``forward(x, state)`` returns (probabilities, state).
        max_sequence_len: number of most recent tokens fed to the model.
        temp: sampling temperature passed to ``sample_with_temp``.
        dataset: provides ``index_to_word`` (and ``word_to_index`` via load_words).

    Returns:
        The seed text followed by the generated words, each followed by a space.
    """
    output_text = seed_text
    # Pad with story-start markers so at least max_sequence_len tokens exist.
    token_ids = load_words('| ' * max_sequence_len + seed_text, dataset)

    with torch.no_grad():
        for _ in range(next_words):
            window = torch.tensor([token_ids[-max_sequence_len:]])
            # prev_state=None lets nn.LSTM start from zeros sized for this
            # batch of 1. model.init_state() reuses the batch size of the last
            # training batch, which fails LSTM's hidden-size check.
            probs = model.forward(window, None)[0].numpy()
            y_class = int(sample_with_temp(probs[0], temperature=temp))

            # Every index is a real word: index 0 is the most frequent word,
            # not padding. So always map the index back through the vocabulary.
            output_word = dataset.index_to_word[y_class]
            if output_word == "|":
                break

            output_text += output_word + ' '
            # Extend the context with the sampled index directly instead of
            # re-tokenising the whole text on every step.
            token_ids.append(y_class)

    return output_text

In [None]:
# Generate up to 100 words continuing a fable title with the trained model.
gen_words = 100
seed_text = "the frog and the snake . "
generated = generate_text(seed_text, gen_words, model)
print(generated)