In [1]:
import numpy as np
import nltk
from tensorflow.keras.preprocessing.text import Tokenizer
from nltk.stem import LancasterStemmer
from tensorflow.keras.layers import Input, Embedding, Dense
from tensorflow.keras.models import Model
from sklearn.metrics.pairwise import cosine_similarity

nltk.download('punkt')
nltk.download('gutenberg')
nltk.download('stopwords')

# 영문 소설 18개를 읽어와서 전처리를 수행한다.
n = 18
stemmer = LancasterStemmer()
stopwords = nltk.corpus.stopwords.words('english')
stopwords.extend(['and', 'but', 'the', 'for', 'would', 'shall'])

sent_stem = []
files = nltk.corpus.gutenberg.fileids()
for i, text_id in enumerate(files[:n]):
    text = nltk.corpus.gutenberg.raw(text_id)
    sentences = nltk.sent_tokenize(text)

    # 각 단어에 Lancaster stemmer를 적용한다.
    for sentence in sentences:
        word_tok = nltk.word_tokenize(sentence)
        stem = [stemmer.stem(word) for word in word_tok if word not in stopwords if len(word) > 2]
        sent_stem.append(stem)
    print('{}: {} ----- processed.'.format(i+1, text_id))

print("\n총 문장 개수 =", len(sent_stem))
print(sent_stem[0])

# 단어사전
tokenizer = Tokenizer()
tokenizer.fit_on_texts(sent_stem)

# 단어사전
word2idx = tokenizer.word_index
word2idx['<PAD>'] = 0
idx2word = {v:k for k, v in word2idx.items()}

print("사전 크기 =", len(word2idx))

# 문장을 단어의 인덱스로 표현
sent_idx = tokenizer.texts_to_sequences(sent_stem)
print(sent_idx[0])

# trigram
x_train = []
y_train = []
for sent in sent_idx:
    if len(sent) < 3:
        continue

    for a, b, c in nltk.trigrams(sent):
        x_train.append(b)
        x_train.append(b)

        y_train.append(a)
        y_train.append(c)

x_train = np.array(x_train).reshape(-1, 1)
y_train = np.array(y_train).reshape(-1, 1)

x_train.shape, y_train.shape

VOC_SIZE = len(word2idx)
EMB_SIZE = 32

x_input = Input(batch_shape=(None, 1))
x_emb = Embedding(VOC_SIZE, EMB_SIZE, name='emb')(x_input)
y_output = Dense(VOC_SIZE, activation='softmax')(x_emb)

model = Model(x_input, y_output)
model.compile(loss = 'sparse_categorical_crossentropy', optimizer='adam')
model.summary()

# word --> word2vec을 확인하기 위한 모델
model_vec = Model(x_input, x_emb)

hist = model.fit(x_train, y_train, batch_size=20480, epochs=1)

def get_word2vec(word):
    stem_word = stemmer.stem(word)
    if stem_word not in word2idx:
        print('{}가 없습니다.'.format(word))
        return
    
    word2vec = model_vec.predict(np.array(word2idx[stem_word]).reshape(1,1))[0]
    return word2vec

father = get_word2vec('father')
mother = get_word2vec('mother')
doctor = get_word2vec('doctor')

print(father)

cosine_similarity(father, mother)

cosine_similarity(father, doctor)

W = model.get_layer('emb').get_weights()[0]
W.shape

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\배진우\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package gutenberg to
[nltk_data]     C:\Users\배진우\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping corpora\gutenberg.zip.
[nltk_data] Downloading package stopwords to
[nltk_data]     C:\Users\배진우\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


1: austen-emma.txt ----- processed.
2: austen-persuasion.txt ----- processed.
3: austen-sense.txt ----- processed.
4: bible-kjv.txt ----- processed.
5: blake-poems.txt ----- processed.
6: bryant-stories.txt ----- processed.
7: burgess-busterbrown.txt ----- processed.
8: carroll-alice.txt ----- processed.
9: chesterton-ball.txt ----- processed.
10: chesterton-brown.txt ----- processed.
11: chesterton-thursday.txt ----- processed.
12: edgeworth-parents.txt ----- processed.
13: melville-moby_dick.txt ----- processed.
14: milton-paradise.txt ----- processed.
15: shakespeare-caesar.txt ----- processed.
16: shakespeare-hamlet.txt ----- processed.
17: shakespeare-macbeth.txt ----- processed.
18: whitman-leaves.txt ----- processed.

총 문장 개수 = 94434
['emm', 'jan', 'aust', '1816', 'volum', 'chapt', 'emm', 'woodh', 'handsom', 'clev', 'rich', 'comfort', 'hom', 'happy', 'disposit', 'seem', 'unit', 'best', 'bless', 'ex', 'liv', 'near', 'twenty-one', 'year', 'world', 'littl', 'distress', 'vex']
사전 크기

(32395, 32)