# CICS544 Final Project: Chinese Couplets Generation

In [1]:
import pandas as pd
import numpy as np

In [2]:
# data reading @author: Joe Chen

# adding start token <s> and end_token <e> to each row
# spliting each chinese character (word segmentation)
# data structure: data = [[w11,w12,...],[w21,w22,...], ....]

# train 上联
with open("couplet/train/in.txt",encoding='utf8') as f:
    tr_in =  [('<s> ' + row.strip() + ' <e>').split() for row in f.readlines()]
# train 下联  
with open("couplet/train/out.txt",encoding='utf8') as f:
    tr_out = [('<s> ' + row.strip() + ' <e>').split() for row in f.readlines()]

#test
with open("couplet/test/in.txt",encoding='utf8') as f:
    te_in = [('<s> ' + row.strip() + ' <e>').split() for row in f.readlines()]
with open("couplet/test/out.txt",encoding='utf8') as f:
    te_out = [('<s> ' + row.strip() + ' <e>').split() for row in f.readlines()]

In [3]:
# first two rows in training input
tr_in[:2]

[['<s>', '晚', '风', '摇', '树', '树', '还', '挺', '<e>'],
 ['<s>', '愿', '景', '天', '成', '无', '墨', '迹', '<e>']]

In [4]:
# flatten each dataset for word vector training and vocabulary building
tr_in_flat = [i for x in tr_in for i in x]
tr_out_flat = [i for x in tr_out for i in x]
te_in_flat = [i for x in te_in for i in x]
te_out_flat = [i for x in te_out for i in x]

In [5]:
# first two rows in training input after flatten
tr_in_flat[:18]

['<s>',
 '晚',
 '风',
 '摇',
 '树',
 '树',
 '还',
 '挺',
 '<e>',
 '<s>',
 '愿',
 '景',
 '天',
 '成',
 '无',
 '墨',
 '迹',
 '<e>']

In [6]:
# training word2vec embeddings using gensim
from gensim.models import Word2Vec

In [8]:
sentences = tr_in_flat + tr_out_flat + te_in_flat + te_out_flat
model = Word2Vec(sentences = tr_out,vector_size=100, window=10, min_count=10, workers=4)

In [9]:
#testing 可以试试别的词
model.wv.most_similar('春')

[('新', 0.42505815625190735),
 ('欣', 0.40011340379714966),
 ('曙', 0.37929314374923706),
 ('千', 0.37075915932655334),
 ('园', 0.3696725070476532),
 ('<s>', 0.36914631724357605),
 ('花', 0.3633737564086914),
 ('栉', 0.362592875957489),
 ('煦', 0.3612635135650635),
 ('，', 0.35984164476394653)]

In [12]:
#combining word tokens for vocab
vocab = set(tr_in_flat + tr_out_flat + te_in_flat + te_out_flat)
# remove start/end tokens since they need to be at first 2 indece when creating word2id
vocab.remove('<s>')
vocab.remove('<e>')

# creating word2id lookup dictionary
word2id = {w:i for i,w in enumerate(vocab,start=2)}
word2id['<s>'] = 0
word2id['<e>'] = 1

id2word = {i:w for w,i in word2id.items()}

In [13]:
# first 10 words in lookup dictionary
[id2word[i] for i in range(10)]

['<s>', '<e>', '嗦', '烈', '隄', '馂', '学', '咪', '醁', '掼']

In [14]:
# size of data and vocab
len(tr_in),len(te_in),len(word2id)

(770491, 4000, 9129)