In [None]:
from path import Path
import json, os
import collections
import re
import torchtext
import pandas as pd
import sklearn.model_selection
import torch
import torch.nn as nn
import torch.cuda
from torch import optim
import torch.nn.functional as F

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Layout of the chinese-poetry checkout: the poem shards live under <BASE_DIR>/json,
# and read_jsons steps the number in their filenames by POETS_PER_JSON.
BASE_DIR = Path("./chinese-poetry")
POET_DIR = BASE_DIR / "json"
POETS_PER_JSON = 1000

In [None]:
def read_jsons(dynasty=("tang",), max_count=None, poet_dir=None, poets_per_json=None):
    """Load poems from the ``poet.<dynasty>.<n>.json`` shards of chinese-poetry.

    For each dynasty, shards are read in order starting at ``n = 0`` and stepping
    ``n`` by ``poets_per_json``. Reading stops at the first shard file that does
    not exist.

    Args:
        dynasty: Iterable of dynasty names used in the shard filenames (e.g. "tang").
            A single string is treated as one dynasty name.
        max_count: If given, return at most this many poems and stop reading early.
        poet_dir: Directory containing the shards; defaults to ``POET_DIR``.
        poets_per_json: Step between shard filename numbers; defaults to
            ``POETS_PER_JSON``.

    Returns:
        List of the decoded JSON entries, concatenated across shards in read order.
    """
    # Resolve defaults lazily so the function works without the module constants
    # when callers pass explicit values.
    if poet_dir is None:
        poet_dir = POET_DIR
    if poets_per_json is None:
        poets_per_json = POETS_PER_JSON
    # Guard against read_jsons(dynasty="tang") iterating over single characters.
    if isinstance(dynasty, str):
        dynasty = [dynasty]

    poets = []
    for dyna in dynasty:
        json_count = 0
        while True:
            json_path = os.path.join(
                poet_dir, "poet.{dyna:s}.{count:d}.json".format(dyna=dyna, count=json_count)
            )
            if not os.path.exists(json_path):
                break
            with open(json_path, "r", encoding="utf8") as f:
                poets.extend(json.load(f))
            if max_count is not None and len(poets) >= max_count:
                # Slice to exactly max_count; the previous [:max_count+1] returned one extra poem.
                return poets[:max_count]
            json_count += poets_per_json
    return poets

In [None]:
# Load every Tang-dynasty shard, with no max_count cap.
tang_poets = read_jsons(dynasty=["tang"])

In [None]:
def text_preprocess(poets):
    """Keep only qijue poems (four lines of seven characters) and format them for the LM.

    For each poem, the ``"paragraphs"`` strings are concatenated. Full-width
    parenthesised annotations and square brackets are removed, and a doubled
    "。。" is collapsed. A poem is kept only if its cleaned text (ASCII spaces
    removed) is exactly ``7 chars，7 chars。7 chars，7 chars。``.

    Args:
        poets: Iterable of dicts, each with a ``"paragraphs"`` list of strings.

    Returns:
        List of strings ``"<SOP> c1 c2 ... <EOP>"``, where every character of the
        poem (punctuation included) is separated by a single space.
    """
    # Non-greedy, so two annotations in one poem do not swallow the verse between
    # them (the greedy "（.*）" deleted everything from the first "（" to the last "）").
    annotation = re.compile(r"（.*?）|\[|\]")
    qijue = re.compile(r"\w{7}，\w{7}。\w{7}，\w{7}。")

    results = []
    for poet in poets:
        cleaned = annotation.sub("", "".join(poet["paragraphs"])).replace("。。", "。")
        compact = cleaned.replace(" ", "")
        # fullmatch rather than ^...$: "$" also matches before a trailing "\n", which
        # would put a newline inside a line of the one-poem-per-line output files.
        # The full match also rejects "{" placeholders, since "{" is not a \w character.
        if qijue.fullmatch(compact):
            results.append("<SOP> " + " ".join(compact) + " <EOP>")
    return results

In [None]:
# Filter to the qijue poems, formatted as space-separated "<SOP> ... <EOP>" strings.
preprocessed = text_preprocess(poets=tang_poets)

In [None]:
# Fixed seed so the train/test split (and the files written from it) are
# reproducible on Restart & Run All.
SPLIT_SEED = 42
train, test = sklearn.model_selection.train_test_split(preprocessed, random_state=SPLIT_SEED)

In [None]:
# Write one poem per line. The encoding is explicit because the text is Chinese and
# the platform default (e.g. cp1252 on Windows) may fail to encode it; this also
# matches the encoding read_jsons uses when reading.
corpus_files = (
    ("tang_train.txt", train),
    ("tang_test.txt", test),
    ("tang.txt", preprocessed),
)
for filename, poems in corpus_files:
    with open(filename, "w", encoding="utf8") as f:
        f.write('\n'.join(poems))