In [1]:
import torch
import torch.nn as nn
import string
import random
import torch.nn.functional as F
from nltk.tokenize.sonority_sequencing import SyllableTokenizer
from nltk import word_tokenize
from nltk import TweetTokenizer
st = SyllableTokenizer()
tt = TweetTokenizer()

In [2]:
EPOCHS = 5000
LEARNING_RATE = 0.004
LAYERS = 2
HIDDEN_SIZE = 256
SEQ_SIZE = 25
PRINT_EVERY_ITR = 500
TEMPERATURE = 0.75

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
#convert sonnets to data
shakespeare = open('shakespeare.txt', 'r')
sonnets = []

line = shakespeare.readline()
while line != "":
    sonnet_number = int(line.strip('\n'))
    sonnet = ""
    
    while line != '\n' and line != '':
        line = shakespeare.readline()
        sonnet += line
    sonnets.append(sonnet.strip('\n'))
    
    if line == '':
        break
    
    spacer = shakespeare.readline()
    line = shakespeare.readline() #read next number
shakespeare.close()

In [5]:
#convert syllable dict file to dictionary
syllable_dict_file = open('Syllable_dictionary.txt', 'r')
syllable_dict = dict()


line = syllable_dict_file.readline()
while line != "":
    word = ""
    syllables = 0
    alt_syllables = 0
    end_syllables = 0
    
    elements = line.split()
    word = str(elements[0])
    
    if len(elements) > 2:
        if elements[1][0] == 'E':
            end_syllables = int(elements[1][1])
            syllables = int(elements[2])
        elif elements[2][0] == 'E':
            end_syllables = int(elements[2][1])
            syllables = int(elements[1])
        else:
            syllables = int(elements[1])
            alt_syllables = int(elements[2])
    else:
        syllables = int(elements[1])

    syllable_dict.update({word: [syllables, alt_syllables, end_syllables]})
    line = syllable_dict_file.readline()    
    
syllable_dict_file.close()

In [6]:
#tokenize the sonnets into words
all_sonnet_words = []
for sonnet in sonnets:
    lines = sonnet.split('\n')
    sonnet_lines = []
    for line in lines:
        sonnet_lines += tt.tokenize(line) + ['\n']
        
    all_sonnet_words.append(sonnet_lines)

In [7]:
punctuation = [',', '.', '?', '!', ':', '(', ')', ';', "'", '\n']

In [8]:
#find words not accounted for by syllable dict
to_remove = []

for i, sonnet in enumerate(all_sonnet_words):
    for j, word in enumerate(sonnet):
        if word.lower() not in syllable_dict and word not in punctuation:
            
            #fix some edge cases that aren't in the syllable dictionary
            if word.lower() in ['gainst', 'greeing', 'scaped', 'tis', 'twixt']:
                to_remove.append((i, j-1))
                sonnet[j] = "'" + word
            elif word.lower() == 't':
                to_remove.append((i, j+1))
                sonnet[j] = 'to'
            elif word.lower() == 'th':
                to_remove.append((i, j+1))
                sonnet[j] = word + 'e'

for removal in to_remove[::-1]:
    i, j = removal
    all_sonnet_words[i].pop(j)

#print words that aren't in the syllable dictionary to see if we missed any
for i, sonnet in enumerate(all_sonnet_words):
    for j, word in enumerate(sonnet):
        if word.lower() not in syllable_dict and word not in punctuation:
            print(word)

In [9]:
#tokenize the words in the sonnets to syllables, adjusting for the shakespeare dictionary
tokenized_sonnets = []

for sonnet in all_sonnet_words:
    tokenized_sonnet = []
    for word in sonnet:
        if word in punctuation:
            if tokenized_sonnet[-1] not in punctuation:
                tokenized_sonnet[-1] = word
            else:
                tokenized_sonnet += word
        else:
            syllablized = st.tokenize(word)
            syllables = -1
            if word.lower() in syllable_dict:
                syllables = syllable_dict[word.lower()][0]

                while syllables < len(syllablized):
                    if word[-1] == 'e' or word[-1] == 's':
                        syllablized = syllablized[:-2] + [syllablized[-2] + syllablized[-1]]
                    elif len(syllablized) > 2:
                        syllablized = [syllablized[0] + syllablized[1]] + syllablized[2:]
                    else:
                        syllablized = [syllablized[0] + syllablized[1]]
                    
            tokenized_sonnet += syllablized + [" "]
            
    tokenized_sonnets.append(tokenized_sonnet)

In [10]:
#account for all unique syllables
corpus_syllables = set()

for sonnet in tokenized_sonnets:
    for syllable in sonnet:
        corpus_syllables.add(syllable)

In [11]:
#create mapping for all unique syllables
corpus_dictionary = dict()
reverse_dict = dict()

for i,syllable in enumerate(corpus_syllables):
    corpus_dictionary.update({syllable: i})
    reverse_dict.update({i: syllable})

In [12]:
vocab_size = len(corpus_dictionary)

In [13]:
class RNN(nn.Module):
    def __init__(self, input_size, num_layers, output_size, hidden_size):
        super(RNN, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.LSTM(hidden_size, hidden_size, num_layers)
        self.decoder = nn.Linear(hidden_size, output_size)
    
    def forward(self, input_seq, hidden_state):
        embedding = self.embedding(input_seq)
        output, hidden_state = self.rnn(embedding.unsqueeze(1), hidden_state)
        output = self.decoder(output)
        return output, (hidden_state[0].detach(), hidden_state[1].detach())

In [14]:
def toTensor(syllables):
    tensor = torch.zeros(len(syllables)).long()
    for i, syllable in enumerate(syllables):
        tensor[i] = corpus_dictionary[syllable]
    return tensor

In [15]:
def getRandomExample(data, sequence_size):
    sonnet_index = random.randint(0, len(data) - 1)
    sonnet = data[sonnet_index]
    sonnet_position = random.randint(0, len(sonnet) - sequence_size)
    seq = sonnet[sonnet_position : sonnet_position + sequence_size + 1]
    input_seq = toTensor(seq[:-1])
    target_seq = toTensor(seq[1:])
    return input_seq, target_seq

In [16]:
model = RNN(input_size=vocab_size, num_layers=LAYERS, output_size=vocab_size, hidden_size=HIDDEN_SIZE).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [17]:
print("Beginning training")
for epoch in range(1, EPOCHS + 1):
    input_seq, target_seq = getRandomExample(tokenized_sonnets, SEQ_SIZE)
    hidden = torch.zeros(LAYERS, 1, HIDDEN_SIZE).to(device)
    cell = torch.zeros(LAYERS, 1, HIDDEN_SIZE).to(device)
    
    model.zero_grad()
    loss = 0
    input_seq, target_seq = input_seq.to(device), target_seq.to(device)
    
    for c in range(SEQ_SIZE):
        output, (hidden, cell) = model(input_seq, (hidden, cell))
        loss += loss_fn(torch.squeeze(output), torch.squeeze(target_seq))
        
    loss.backward()
    optimizer.step()
    loss = loss.item() / SEQ_SIZE
    
    if epoch % PRINT_EVERY_ITR == 0:
        print(f'Epoch: {epoch}, Loss: {loss}')

Beginning training
Epoch: 500, Loss: 3.8124874877929686
Epoch: 1000, Loss: 2.92393798828125
Epoch: 1500, Loss: 3.281172180175781
Epoch: 2000, Loss: 4.035428161621094
Epoch: 2500, Loss: 3.21732666015625
Epoch: 3000, Loss: 3.6085614013671874
Epoch: 3500, Loss: 3.714349670410156
Epoch: 4000, Loss: 3.7021466064453126
Epoch: 4500, Loss: 3.2543179321289064
Epoch: 5000, Loss: 2.953560791015625


In [18]:
temperatures = [1.5, 0.75, 0.5]

for temp in temperatures:
    prompt = ["Shall", " ", "I", " ", "com", "pare", " ", "thee", " ", "to", " ", "a", " ", "sum", "mer's", " ", "day", "?", "\n"]
    generated_string = ""
    for s in prompt:
        generated_string += s
    lines = 1

    prompt = toTensor(prompt)
    hidden = torch.zeros(LAYERS, 1, HIDDEN_SIZE).to(device)
    cell = torch.zeros(LAYERS, 1, HIDDEN_SIZE).to(device)

    for c in range(len(prompt) - 1):
        _, (hidden, cell) = model(prompt[c].view(1).to(device), (hidden, cell))

    prime_char = prompt[-1]

    while lines < 14:
        output, (hidden, cell) = model(prime_char.view(1).to(device), (hidden, cell))
        distribution = F.softmax(output / temp, dim=2)
        syllable_id = torch.multinomial(distribution[0], 1)[0][0].item()
        next_syllable = reverse_dict[syllable_id]
        generated_string += str(next_syllable)
        prime_char = toTensor([next_syllable])
        if next_syllable == '\n':
            lines += 1
    
    print(f'Sonnet generated with temperature: {temp}')
    print(generated_string)

Sonnet generated with temperature: 1.5
Shall I compare thee to a summer's day?
It him best,dead still consmost darhes,goesnature were hast writetain
Not taught thou our days,tiesjusiege astaintned death-sesnalatetaint!
Unless grief to love,tatain flattance matage best
Whate'tions have grant much,than flowers going,and thou frown if give,the two none.
Is whom of gay,
Against otcause know,a Under,abunby charial?
Ere to not discahehold,by thou end,if thou deep sertant need gait,it tells-beus ,
Yet temfore hateselfhing proritate,to changetyranrest prime cast,
Unless aputy's moun valate may,from jolThy satekens.I my costling bier baall mourn pleaty,as thy waigarden'riner's strong,from of wake,ter hues,then summer life lory,some datresscy  sumsumbe raspeak,though sight?
Or hard
Save truthsrigly memtish cheer;of,where two the news;
and estan hooks?
How faigoingly saufinds:

Sonnet generated with temperature: 0.75
Shall I compare thee to a summer's day?
Which find too do away,
But love I age w