# Chinese Poetry Generator

Generates Tang poetry with a character-level RNN (CharRNN) model built on an LSTM.

Data: from https://github.com/chinese-poetry/chinese-poetry

In [1]:
import os
import random
from collections import Counter

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from zhconv import convert


In [2]:
from data_loader import ParseRawData
# This loads the json data and only takes the main body of each poem
data_all = ParseRawData()

In [3]:
len(data_all)

57598

We have 57,598 Tang poems in total.

In [4]:
data_all[0]

'秦川雄帝宅，函谷壯皇居。綺殿千尋起，離宮百雉餘。連甍遙接漢，飛觀迥凌虛。雲日隱層闕，風煙出綺疎。'

# Preprocess data

Filter and prepare the data:
- convert traditional Chinese to simplified Chinese for better readability
- keep only poems whose first line has 5 characters (five-character verse), which makes the poem structure easier for the model to learn

In [5]:
# Lines are separated by the full-width Chinese comma. State it explicitly
# instead of relying on the 6th character of the first poem, and check the
# assumption against the corpus.
comma = "，"
assert data_all[0][5] == comma, "unexpected punctuation in the first poem"
# Keep poems whose FIRST line has exactly 5 characters (only that line is checked),
# converted to simplified Chinese.
data = [convert(poem, "zh-hans") for poem in data_all if len(poem.split(comma)[0]) == 5]
len(data)

30379

In [6]:
data[0]

'秦川雄帝宅，函谷壮皇居。绮殿千寻起，离宫百雉馀。连甍遥接汉，飞观迥凌虚。云日隐层阙，风烟出绮疏。'

In [7]:
# Frequency of every character (punctuation included) across the simplified corpus.
char_counter = Counter(ch for poem in data for ch in poem)

In [8]:
# Number of distinct characters seen fewer than 5 times in the corpus.
sum(1 for count in char_counter.values() if count < 5)

2607

In [9]:
min_freq = 5  # a relatively large threshold keeps the vocabulary (and dataset) small enough for a laptop

# Give each sufficiently frequent character an index in order of first appearance.
# Characters seen fewer than min_freq times get no entry here.
char_to_ix = {}
for ch in (c for poem in data for c in poem):
    if ch in char_to_ix or char_counter[ch] < min_freq:
        continue
    char_to_ix[ch] = len(char_to_ix)
ix_to_char = {ix: ch for ch, ix in char_to_ix.items()}

In [48]:
# Build (input, target) training pairs with a sliding window:
# the previous `seq_len` characters are used to predict the next one.
seq_len = 12
input_data = []
target = []
PADDING = 'O'
# Register the padding token only once. Re-running this cell used to reassign it to
# len(char_to_ix), i.e. the index that transform_X reserves for unknown characters.
if PADDING not in char_to_ix:
    char_to_ix[PADDING] = len(char_to_ix)

n_pad = 6  # one 5-character line plus its punctuation
for poem in data:
    # Left-pad so (with seq_len = 12) the first window is padding + the first line
    # and the first prediction target is the first character of the second line.
    padded = PADDING * n_pad + poem
    for i in range(len(padded) - seq_len):
        input_data.append(padded[i:i + seq_len])
        target.append(padded[i + seq_len])

print("Number of training samples: {}".format(len(input_data)))

Number of training samples: 1775989


In [49]:
# plus one to count for the unknown chars
# One extra slot beyond char_to_ix for the shared "unknown character" index len(char_to_ix).
vocab_size = len(char_to_ix) + 1
print(f"vocab size: {vocab_size}")

vocab size: 5196


In [50]:
input_data[0]

'OOOOOO秦川雄帝宅，'

In [51]:
def transform_X(text, char_to_ix, onehot=False):
    """Encode one input string as an index array or a one-hot matrix.

    Characters missing from ``char_to_ix`` map to the shared unknown index
    ``len(char_to_ix)``.

    Args:
        text: input characters, e.g. '秦川雄帝宅，'.
        char_to_ix: dict mapping character -> integer index.
        onehot: if False (default), return an int64 index array of shape (len(text),).
            If True, return a float array of shape (len(text), len(char_to_ix) + 1).

    Returns:
        numpy.ndarray as described above.
    """
    unknown_ix = len(char_to_ix)
    # Explicit int64: np.array([]) would otherwise be float64, and NumPy's default
    # integer dtype is platform dependent (int32 on Windows with NumPy < 2), while
    # nn.Embedding has historically required int64 (Long) indices.
    idx = np.array([char_to_ix.get(char, unknown_ix) for char in text], dtype=np.int64)
    if onehot:
        X = np.zeros((len(text), unknown_ix + 1))
        X[np.arange(len(idx)), idx] = 1
        return X
    return idx


# nn.CrossEntropyLoss takes class indices as targets, so the target character is
# encoded as a single integer rather than a one-hot vector.
def transform_y(text, char_to_ix):
    """Return the index of target character ``text``, or the unknown index ``len(char_to_ix)``."""
    unknown_ix = len(char_to_ix)
    return char_to_ix[text] if text in char_to_ix else unknown_ix

class PoemDataset(Dataset):
    """Map-style dataset that encodes samples lazily, one index at a time.

    Encoding on access (instead of materialising every encoded array up front)
    keeps memory low for the ~1.8M training windows, which was too much for a laptop.
    """

    def __init__(self, data, target, transform_X, transform_y, char_to_ix):
        # Raw input windows / target characters, plus the encoders applied in __getitem__.
        self.data, self.target = data, target
        self.transform_X, self.transform_y = transform_X, transform_y
        self.char_to_ix = char_to_ix

    def __len__(self):
        """Number of (input, target) pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return {"input": encoded input window, "target": encoded next character}."""
        return {
            "input": self.transform_X(self.data[index], self.char_to_ix),
            "target": self.transform_y(self.target[index], self.char_to_ix),
        }

In [52]:
poem_dataset = PoemDataset(input_data, target, transform_X, transform_y, char_to_ix)

In [53]:
batch_size = 128
dataloader = DataLoader(poem_dataset, batch_size=batch_size, shuffle=True)

In [16]:
# # check we have the right shape
# for i, sample_batched in enumerate(dataloader):
#     print(i, sample_batched['input'].size(),
#          sample_batched['target'].size(),
#          sample_batched['target'].dtype)
#     if i == 3:
#         break

In [63]:
class PoemGenerationModel(nn.Module):
    """Character-level LSTM language model.

    Embeds a batch of character-index sequences, runs a 2-layer LSTM over them and
    projects the (dropout-regularised) hidden state of the final time step onto
    vocabulary logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        # Sub-modules are registered in the same order as before, so parameter
        # initialisation and the state_dict layout are unchanged.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=2, batch_first=True)
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input_data):
        """Map a (batch, seq) index tensor to (batch, vocab_size) logits.

        Note: ``squeeze()`` drops the batch axis when batch == 1, so a single
        sequence yields a 1-D (vocab_size,) tensor; the generation helpers expect
        that shape.
        """
        hidden_seq, _ = self.lstm(self.embedding(input_data))
        # Dropout is applied to the whole LSTM output, then only the last step is kept.
        last_step = self.dropout(hidden_seq)[:, -1, :]
        return self.linear(last_step.squeeze())

In [64]:
# Model / optimiser hyperparameters.
embed_dim = 256
hidden_dim = 256
lr = 0.001
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG in play so runs are repeatable. That covers Python `random` (seed-poem
# choice), NumPy (character sampling) and PyTorch (weight init, dropout, DataLoader
# shuffling). cuDNN kernels may still be nondeterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [65]:
device

'cuda'

In [66]:
# Instantiate the network with float32 weights on the selected device, plus loss and optimiser.
model = PoemGenerationModel(vocab_size, embed_dim, hidden_dim)
model = model.float().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [67]:
print(model)

PoemGenerationModel(
  (embedding): Embedding(5196, 256)
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=5196, bias=True)
)


In [68]:
def sample(preds, temperature=0.8):
    """Draw one vocabulary index from temperature-scaled model logits.

    Args:
        preds: logits tensor for a single position, shape (vocab,) or (1, vocab).
        temperature: must be > 0. Low values concentrate probability on the most
            likely characters; high values make the model more adventurous.

    Returns:
        int: the sampled index.

    Raises:
        ValueError: if temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive, got {}".format(temperature))
    # softmax(logits / T) equals the renormalised softmax(logits) ** (1 / T), but
    # computing it on float64 logits cannot underflow to all zeros (and then NaN)
    # for small temperatures or large, flat vocabularies.
    logits = preds.detach().reshape(-1).double().cpu() / temperature
    probs = torch.softmax(logits, dim=0).numpy()
    probs /= probs.sum()  # guard np.random.choice's sum-to-1 check against rounding
    return int(np.random.choice(len(probs), p=probs))

In [69]:
def generate_poem(model, input_text, output_length=18, temperature=1):
    """Generate ``output_length`` characters that continue ``input_text``.

    input_text: seed text, typically one or two complete lines ending with
        punctuation, e.g. "我有紫霞想，". Seeds shorter than ``seq_len`` are
        left-padded with PADDING, the same way training windows were built.
    temperature: sampling temperature, passed through to generate_one_char.

    The model is switched to eval mode (dropout off) and autograd is disabled for
    the duration of the call. The model's previous train/eval mode is restored
    afterwards. Only the generated characters are returned, not the seed.
    """
    if len(input_text) < seq_len:
        input_text = PADDING * (seq_len - len(input_text)) + input_text

    was_training = model.training
    model.eval()
    generated = []
    try:
        with torch.no_grad():
            for _ in range(output_length):
                next_char = generate_one_char(model, input_text, temperature=temperature)
                generated.append(next_char)
                # Slide the window: drop the oldest character, append the new one.
                input_text = input_text[1:] + next_char
    finally:
        model.train(was_training)
    return "".join(generated)
    
def generate_one_char(model, input_text, temperature=1):
    """Predict and sample the character that follows ``input_text``.

    Returns "?" when the sampled index has no entry in ``ix_to_char`` (such as
    the unknown-character index).
    """
    X_test = transform_X(input_text, char_to_ix)
    # Build a (1, seq) Long batch on the same device as the model's weights, so a
    # model that lives on a different device than the global `device` still works.
    model_device = next(model.parameters()).device
    batch = torch.as_tensor(X_test, dtype=torch.long, device=model_device).unsqueeze(0)
    with torch.no_grad():  # inference only: no autograd graph needed
        pred = model(batch)
    next_index = sample(pred, temperature)
    return ix_to_char.get(next_index, "?")

In [70]:
def generate_sample():
    """Print sample continuations from the global ``model`` to monitor training.

    Picks a random poem and generates at several temperatures. The first seed is
    its first ``seq_len`` characters, continued for 12 characters. The second seed
    is PADDING plus its first ``seq_len - 6`` characters (normally the first line),
    continued for 18 characters. The model runs in eval mode without gradient
    tracking, and its previous mode is restored afterwards. This makes the function
    safe to call from inside the training loop.
    """
    print('\n----- Generating text:')

    # randrange excludes the upper bound; random.randint(0, len(data)) could
    # return len(data) and raise IndexError.
    poem = data[random.randrange(len(data))]
    seeds = [
        (poem[:seq_len], 12),
        (PADDING * 6 + poem[:seq_len - 6], 18),
    ]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for seed_text, length in seeds:
                print("Generating with seed: {}".format(seed_text))
                for temperature in [0.2, 0.5, 0.7, 1.0]:
                    print('----- temperature:', temperature)
                    print(generate_poem(model, seed_text, output_length=length, temperature=temperature))
    finally:
        model.train(was_training)

In [71]:
n_epochs = 5
log_every = 1000  # steps between loss printouts / sample generations

for epoch in range(n_epochs):
    # Make sure dropout is active for training.
    model.train()
    for step, samples in enumerate(tqdm(dataloader)):
        batch_X = samples['input'].to(device).long()
        batch_y = samples['target'].to(device).long()

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        y_pred = model(batch_X)
        # Compute loss and update gradients
        loss = loss_function(y_pred, batch_y)
        loss.backward()
        # Optional: nn.utils.clip_grad_norm_(model.parameters(), 1) guards against
        # exploding gradients in RNNs / LSTMs; it is disabled in this run.
        optimizer.step()

        if step % log_every == 0:
            # .item() prints a plain float rather than the tensor repr.
            print("\n\nLoss at epoch {} step {}: {}".format(epoch, step, loss.item()))
            generate_sample()
            model.train()  # back to training mode in case sampling switched to eval

  0%|          | 0/13875 [00:00<?, ?it/s]



Loss at epoch 0 step 0: 8.568345069885254

----- Generating text:
Generating with seed: 兹山昔飞来，远自琅琊台。
----- temperature: 0.2
苡枿绚臭䲡赞购蜜鹏椅何熨
----- temperature: 0.5
拱祗陁良畬胫衽燧拗申掺憾
----- temperature: 0.7
夭迨舲舻棰舁阿宵阑弃滟藜
----- temperature: 1.0
缰乙忘繁骄灞綦粝藟鄜聋亹
Generating with seed: OOOOOO兹山昔飞来，
----- temperature: 0.2
觊铁轘麈予姒裛潢壅猷超佞饯祢吟蜓醭情
----- temperature: 0.5
惋纽瀍颍番瓮旸仰浓娘动魑佯乾睚瞽纤脚
----- temperature: 0.7
楫𪨗蛙樽涟并躔弄卢窕掀岘晡侯谪唯谒鸳
----- temperature: 1.0
鹆帷油旸闪漘娄踯踉槌础涡迢廉函判漫塞


Loss at epoch 0 step 1000: 5.808252811431885

----- Generating text:
Generating with seed: 大雅何寥阔，斯人尚典刑。
----- temperature: 0.2
一人不云月，不人不天归。
----- temperature: 0.5
归心忽欲道，天风未生阴。
----- temperature: 0.7
闲城少思音，应虎外出同。
----- temperature: 1.0
随野里似湔，风恨后得精。
Generating with seed: OOOOOO大雅何寥阔，
----- temperature: 0.2
无人无不时。一日不不月，不人不云风。
----- temperature: 0.5
不令自水新。林柳一多阻，?来不如台。
----- temperature: 0.7
霜思代枝知。幽滴终药楼，孤朝莫渡尘。
----- temperature: 1.0
争分日长回。夜扉非呼奇，今防处多昼。


Loss at epoch 0 step 2000: 5.515780448913574

----- Generating text:
Generating with seed: 偶作关东

  0%|          | 0/13875 [00:00<?, ?it/s]



Loss at epoch 1 step 0: 5.089142322540283

----- Generating text:
Generating with seed: 高阁逼诸天，登临近日边。
----- temperature: 0.2
一朝无一事，何日不相亲。
----- temperature: 0.5
谁怜明月后，多见白云来。
----- temperature: 0.7
鬓声来未定，星草落清灰。
----- temperature: 1.0
如何沈次气，事好列金威。
Generating with seed: OOOOOO高阁逼诸天，
----- temperature: 0.2
闲居一枝?。何时一枝雪，不觉一枝?。
----- temperature: 0.5
一船空客船。一朝不可识，不得天涯归。
----- temperature: 0.7
青楼水漠流。扣杖犹不识，传君独为行。
----- temperature: 1.0
空痕寄五枢。隋陵至人错，经焉可相当。


Loss at epoch 1 step 1000: 4.954770088195801

----- Generating text:
Generating with seed: 道北冯都使，高斋见一川。
----- temperature: 0.2
一朝无限处，万里有秋风。
----- temperature: 0.5
时朝客来远，寒月上山川。
----- temperature: 0.7
万国徒不及，君子不见恩。
----- temperature: 1.0
假年心大园，蜂吏枕东窗。
Generating with seed: OOOOOO道北冯都使，
----- temperature: 0.2
天高有一年。不知无事者，不得是何如。
----- temperature: 0.5
生涯此一年。何当见中郡，何必有游才。
----- temperature: 0.7
神心入国生。别来同早日，吟得到天田。
----- temperature: 1.0
分时有力强。致惭吾不慨，岂忽为地膺。


Loss at epoch 1 step 2000: 5.169358253479004

----- Generating text:
Generating with seed: 彭蠡隐深

  0%|          | 0/13875 [00:00<?, ?it/s]



Loss at epoch 2 step 0: 4.656186103820801

----- Generating text:
Generating with seed: 吾爱王子晋，得道伊洛滨。
----- temperature: 0.2
一朝无一事，一醉不知贫。
----- temperature: 0.5
宝剑登高殿，金鞍动玉轮。
----- temperature: 0.7
松下紫蒿盖，天留白日新。
----- temperature: 1.0
疑走破头人，哀哉度浩亲。
Generating with seed: OOOOOO吾爱王子晋，
----- temperature: 0.2
时来不可亲。何以问君子，不为生死身。
----- temperature: 0.5
真为明主恩。山川自有地，百化已长言。
----- temperature: 0.7
。巍童王氏宅，根懦真??。小师转忘筌
----- temperature: 1.0
眼祗接系寇。一朝见封缄，引集无愠戚。


Loss at epoch 2 step 1000: 4.629902362823486

----- Generating text:
Generating with seed: 不食非关药，天生是女仙。
----- temperature: 0.2
一身无一醉，何日有柴门。
----- temperature: 0.5
谁能知别离，不得问心然。
----- temperature: 0.7
此中嗟苦醉，谁能不念禅。
----- temperature: 1.0
乡行终不极，贫里复成禅。
Generating with seed: OOOOOO不食非关药，
----- temperature: 0.2
无人见故人。不知山上客，不觉到山人。
----- temperature: 0.5
无人爱我流。诗书不可识，终日莫相寻。
----- temperature: 0.7
因依居。有至客不知，文字天下前。食中
----- temperature: 1.0
含心幸一朝。鸦来吴漏了，那折卜山斜。


Loss at epoch 2 step 2000: 4.592876434326172

----- Generating text:
Generating with seed: 凤扆朝碧

  0%|          | 0/13875 [00:00<?, ?it/s]



Loss at epoch 3 step 0: 4.507320404052734

----- Generating text:
Generating with seed: 初禅韵高柳，密茑挂深松。
----- temperature: 0.2
不得无人处，无心更白松。
----- temperature: 0.5
虽在天台里，何曾作客行。
----- temperature: 0.7
池竹时平柳，湖潭讵有峰。
----- temperature: 1.0
牛雁何时归？鬟起接山更。
Generating with seed: OOOOOO初禅韵高柳，
----- temperature: 0.2
独坐向南归。日暮风声尽，风声雨露微。
----- temperature: 0.5
送别有馀春。夜月波上叶，寒窗石上尘。
----- temperature: 0.7
开穴背秋槐。此夕仍何别，扁舟去未期。
----- temperature: 1.0
微恐乱牛开。神皇一变结，散后是班樽。


Loss at epoch 3 step 1000: 4.263589859008789

----- Generating text:
Generating with seed: 闭门迹群化，凭林结所思。
----- temperature: 0.2
不知一杯酒，谁肯有时期。
----- temperature: 0.5
秋风吹落日，夜雨下寒时。
----- temperature: 0.7
溪雨昼不尽，空意清不知。
----- temperature: 1.0
惜中争倚槛，卿事共争衣。
Generating with seed: OOOOOO闭门迹群化，
----- temperature: 0.2
一别复何人。不见天地内，不知心不勤。
----- temperature: 0.5
佳气布尘期。北阙开天乐，南风发楚词。
----- temperature: 0.7
新宫日无愁。浮云动佳气，叠嶂无情收。
----- temperature: 1.0
百类乱何为。我虽白鹦鹉，始喜千山衰。


Loss at epoch 3 step 2000: 4.167479991912842

----- Generating text:
Generating with seed: 君平既弃

  0%|          | 0/13875 [00:00<?, ?it/s]



Loss at epoch 4 step 0: 4.200263977050781

----- Generating text:
Generating with seed: 追逐轻薄伴，闲游不著绯。
----- temperature: 0.2
不知人不见，何处是年重？
----- temperature: 0.5
一年兼好兴，一夜更何为。
----- temperature: 0.7
老歌鸿变绝，他日觉谁栖。
----- temperature: 1.0
黔凡久美阔，愿羡君知机。
Generating with seed: OOOOOO追逐轻薄伴，
----- temperature: 0.2
何人知此心。不知山水客，不敢问渔心。
----- temperature: 0.5
心期无所依。无人不知己，此事不能归。
----- temperature: 0.7
对行思主人。沧波踏青桂，白马入丹青。
----- temperature: 1.0
自身暗延王。宛天会尔所，新唱结金光。


Loss at epoch 4 step 1000: 4.384469509124756

----- Generating text:
Generating with seed: 大业来四夷，仁风和万国。
----- temperature: 0.2
天地有灵境，天地无人识。
----- temperature: 0.5
何当见东北，万里复何望。
----- temperature: 0.7
有时因饮使，一别即言宴。
----- temperature: 1.0
横舟映桑田，寂寂长沙外。
Generating with seed: OOOOOO大业来四夷，
----- temperature: 0.2
天高一何息。天下有所求，身是此生死。
----- temperature: 0.5
岁岁不敢止。不中不可测，空使其如水。
----- temperature: 0.7
乾坤一哉贼。仕真三晋下，迷禀五色巡。
----- temperature: 1.0
歌声既秋色。宝刀结仙道，持衣桃李碧。


Loss at epoch 4 step 2000: 4.75571346282959

----- Generating text:
Generating with seed: 静谈云鹤趣

In [72]:
save_path = "saved_model/charrnn_pytorch_embedding_model_v2"

In [43]:
# torch.save(model.state_dict(), save_path)

In [73]:
# Create the target directory if needed, then pickle the whole module.
# NOTE: torch.save(model) stores the class by reference, so loading needs
# PoemGenerationModel importable; model.state_dict() is the more portable option
# (the loader cell below expects the full model).
os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
torch.save(model, save_path)

1st try of model with embedding:    
`PoemGenerationModel(
  (embedding): Embedding(5195, 128)
  (lstm): LSTM(128, 128, batch_first=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=128, out_features=5195, bias=True)
)
`
seq_len = 6

2nd try of model with embedding:
`PoemGenerationModel(
  (embedding): Embedding(5196, 256)
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=5196, bias=True)
)
`
seq_len = 12


# Load the saved model and test

In [74]:
# The checkpoint is a pickled nn.Module: unpickling can execute arbitrary code,
# so only load files you trust. (If only a state_dict had been saved, build
# PoemGenerationModel(vocab_size, embed_dim, hidden_dim) and call load_state_dict.)
# map_location puts the weights on the current `device` even if saved elsewhere.
# PyTorch >= 2.6 defaults to weights_only=True, which rejects full-module pickles.
try:
    trained_model = torch.load(save_path, map_location=device, weights_only=False)
except TypeError:  # older PyTorch without the weights_only argument
    trained_model = torch.load(save_path, map_location=device)
trained_model.eval()

PoemGenerationModel(
  (embedding): Embedding(5196, 256)
  (lstm): LSTM(256, 256, num_layers=2, batch_first=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (linear): Linear(in_features=256, out_features=5196, bias=True)
)

In [80]:
seed_text = "我有紫霞想，"
# Use the reloaded model, which is in eval mode. The training `model` was left in
# train mode, so its dropout would perturb generation.
generated = generate_poem(trained_model, seed_text, temperature=0.5)
print(generated)

不如玉床空。不知世上人，还得一何通。


In [128]:
generated = generate_poem(trained_model, seed_text, temperature=0.3)
print(generated)

不如今日时。山川空有路，松竹未成枝。


In [131]:
generated = generate_poem(trained_model, seed_text, output_length=42, temperature=0.5)
print(generated)

忽然虚且清。不知真士术，不见有君名。天子无人迹，三年失所行。几时无一事，有貌数年情。


In [109]:
generated = generate_poem(trained_model, "明月几时有，", output_length=18, temperature=0.3)
print(generated)

此夜独悠悠。白日无人见，青山一曲流。


In [124]:
generated = generate_poem(trained_model, "八月湖水平，", output_length=18, temperature=0.2)
print(generated)

一望青云端。高风吹寒色，远水连秋湍。
