In [5]:
import numpy as np
import pandas as pd
from string import punctuation

In [36]:
def load_data():
    
    data_txt = open('data/shakespeare.txt')
    quatrain = []; volta = []; couplet = [];
    
    Lines = data_txt.readlines()
    count=0
    for i,line in enumerate(Lines):
        if any(char.isdigit() for char in line):
            for j in range(1,9):
                quatrain.append(Lines[i+j][:-1])
            for j in range(9,13):
                volta.append(Lines[i+j][:-1])
            for j in range(13,15):
                couplet.append(Lines[i+j][:-1])
    return quatrain,volta,couplet

def load_key(includePunctuation=True):
    
    data_key = open('data/Syllable_dictionary.txt')
    Lines = data_key.readlines()
    
    key_list = [];
    
    for i,line in enumerate(Lines):
        split_str = line.strip(punctuation).split()
        if len(split_str) == 3:
            if len(split_str[1])==2:
                key_list.append([split_str[0],[int(split_str[2])],int(split_str[1][1])])
            elif len(split_str[2])==2:
                key_list.append([split_str[0],[int(split_str[1])],int(split_str[2][1])])
            else:
                key_list.append([split_str[0],[int(split_str[1]),int(split_str[2])],0])

        else:
            key_list.append([split_str[0],[int(split_str[1])],0])
    
    #append punctuation
    if includePunctuation==True:
        #comma
        key_list.append([',',[0],0])
        #period
        key_list.append(['.',[0],0])
        #question
        key_list.append(['?',[0],0])
        #colon
        key_list.append([':',[0],0])
        #semicolon
        key_list.append([';',[0],0])
    
    return key_list
    
    
    
def convert_to_nums(dataset,key_list,includePunctuation=True):
    punc_list = [',','.','?',':',';']
    word_list = [i[0] for i in key_list]
    seq_list = []
    for i in range(0,len(dataset)):
        words = dataset[i].split()
        seq = []
        for j in range(0,len(words)):
            if words[j].lower() == "th'" or words[j].lower() == "t'":
                n = word_list.index(words[j].lower())
            else:
                n = word_list.index(words[j].lower().strip(punctuation))
            seq.append(n)
            
            if includePunctuation==True:
                if words[j][-1] in punc_list:
                    seq.append(word_list.index(words[j][-1]))
            
        seq_list.append(seq)
    
    return seq_list

def convert_to_nums_red_keylist(dataset,key_list,includePunctuation=True):
    punc_list = [',','.','?',':',';']
    word_list = [i[0] for i in key_list]
    seq_list = []
    
    key_list_red = []
    used_inds = []
    for i in range(0,len(dataset)):
        words = dataset[i].split()
        seq = []
        for j in range(0,len(words)):
            if words[j].lower() == "th'" or words[j].lower() == "t'":
                n = word_list.index(words[j].lower())
            else:
                n = word_list.index(words[j].lower().strip(punctuation))
            if n not in used_inds:
                used_inds.append(n)
                key_list_red.append(key_list[n])
                
            seq.append(n)    
            if includePunctuation==True:
                if words[j][-1] in punc_list:
                    punc_indx = word_list.index(words[j][-1])
                    seq.append(punc_indx)
                    if punc_indx not in used_inds:
                        used_inds.append(punc_indx)
                        key_list_red.append(key_list[punc_indx])
        seq_list.append(seq)
    
    #convert sequence to 'reduced' indices
    seq_list_red = []
    for i in range(0,len(seq_list)):
        seq_list_red_cur = []
        for j in range(0,len(seq_list[i])):
            seq_list_red_cur.append(used_inds.index(seq_list[i][j]))
        seq_list_red.append(seq_list_red_cur)
    
    
    return seq_list_red,key_list_red

def generate_sequences():
    #CALL THIS TO GENERATE THE 3 TRAINING DATASETS
    #OUTPUTS:
    #seq_list_quatrain: list of sequences corresponding to lines in the first 2 quatrains
    #seq_list_volta: list of sequences corresponding to lines in the volta
    #seq_list_couplet: list of sequences corresponding to lines in the couplet
    #key_list: list of lists containing information stored in Syllable_dictionary:
    #Each list corresponds to a word, with index 0 being the word, 
    #index 1 being a list containing possible numbers of syllables, and index 2 being the number of syllables if it's
    #an end word.
    
    quatrain,volta,couplet = load_data()
    key_list = load_key()
    seq_list_quatrain = convert_to_nums(quatrain,key_list)
    seq_list_volta = convert_to_nums(volta,key_list)
    seq_list_couplet = convert_to_nums(couplet,key_list)
    return seq_list_quatrain,seq_list_volta,seq_list_couplet,key_list

def generate_sequences_reduced(includePunctuation = True):
    #CALL THIS TO GENERATE THE 3 TRAINING DATASETS
    #OUTPUTS:
    #seq_lists: list with 3 elements corresponding to the 3 lists of sequences
    #key_lists: list with 3 elements corresponding to the 3 associated key lists
    
    quatrain,volta,couplet = load_data()
    key_list = load_key(includePunctuation)
    
    seq_list_quatrain,key_list_quatrain = convert_to_nums_red_keylist(quatrain,key_list,includePunctuation)
    seq_list_volta,key_list_volta = convert_to_nums_red_keylist(volta,key_list,includePunctuation)
    seq_list_couplet,key_list_couplet = convert_to_nums_red_keylist(couplet,key_list,includePunctuation)
    
    seq_lists = [seq_list_quatrain,seq_list_volta,seq_list_couplet]
    key_lists = [key_list_quatrain,key_list_volta,key_list_couplet]
    return seq_lists,key_lists
    
    
            
            

In [38]:
quatrain,volta,couplet = load_data()
key_list = load_key()

#seq_list = convert_to_nums(quatrain,key_list)
seq_lists,key_lists = generate_sequences_reduced()
print(seq_lists[0])


[[0, 1, 2, 3, 4, 6, 5], [7, 8, 9, 10, 11, 12, 6, 13], [14, 15, 16, 17, 18, 19, 20, 6, 21], [22, 23, 24, 11, 25, 22, 27, 26], [14, 28, 29, 30, 31, 32, 33, 6, 34], [35, 36, 37, 38, 39, 40, 6, 41], [42, 43, 44, 45, 46, 6, 47], [36, 48, 36, 6, 49, 30, 36, 50, 48, 51, 27, 52], [53, 54, 55, 56, 57, 36, 6, 58], [59, 60, 61, 62, 63, 36, 9, 6, 64], [36, 65, 66, 67, 68, 69, 70, 6, 71], [72, 73, 43, 74, 75, 76, 77, 78, 27, 79], [80, 81, 6, 82, 45, 83, 36, 84, 6, 47], [45, 83, 16, 85, 76, 36, 86, 88, 87], [30, 89, 90, 31, 32, 61, 91, 6, 34], [92, 93, 94, 6, 95, 59, 96, 98, 97], [99, 63, 36, 100, 59, 101, 16, 102, 28, 6, 103], [71, 104, 16, 20, 7, 102, 18, 105, 6, 106], [107, 108, 109, 110, 71, 28, 111, 6, 112], [28, 113, 114, 16, 6, 115, 116, 117, 98, 118], [119, 45, 104, 120, 68, 121, 107, 122, 123], [124, 16, 125, 76, 36, 127, 126], [128, 129, 104, 130, 68, 131, 72, 73, 16, 6, 132], [76, 22, 133, 30, 134, 127, 135], [136, 137, 138, 113, 28, 6, 139], [140, 36, 48, 36, 9, 127, 141], [142, 143, 144