In [2]:
import numpy as np
import pandas as pd
from string import punctuation

In [38]:
def load_data(path='data/shakespeare.txt'):
    """Split every sonnet in the text file into its structural sections.

    Any line containing a digit is treated as a header (presumably the
    sonnet number). The 14 lines that follow it are assigned by position:
    lines 1-8 (the first two quatrains) go to ``quatrain``, lines 9-12 (the
    volta) go to ``volta`` and lines 13-14 go to ``couplet``. Only the
    trailing newline of each line is removed; other whitespace is kept.

    Parameters
    ----------
    path : str, optional
        File to read; defaults to the original hard-coded location.

    Returns
    -------
    tuple of (list of str, list of str, list of str)
        ``(quatrain, volta, couplet)`` lines in file order.

    Raises
    ------
    IndexError
        If a header line is followed by fewer than 14 lines.
    """
    # The context manager guarantees the file is closed.
    # rstrip('\n') (rather than [:-1]) keeps the final character of a last
    # line that has no trailing newline.
    with open(path) as data_txt:
        lines = [line.rstrip('\n') for line in data_txt]

    quatrain, volta, couplet = [], [], []
    for i, line in enumerate(lines):
        if any(char.isdigit() for char in line):
            # Explicit indexing (not slicing) so a truncated sonnet raises
            # IndexError instead of silently misaligning the sections.
            quatrain.extend(lines[i + j] for j in range(1, 9))
            volta.extend(lines[i + j] for j in range(9, 13))
            couplet.extend(lines[i + j] for j in range(13, 15))
    return quatrain, volta, couplet


def load_key(includePunctuation=True, path='data/Syllable_dictionary.txt'):
    """Parse the syllable dictionary into a list of word entries.

    Each non-blank line is ``word tok [tok ...]``. A token prefixed with
    ``E`` (e.g. ``E1``) gives the syllable count to use when the word ends
    a line. Every other token is an ordinary possible syllable count.

    Parameters
    ----------
    includePunctuation : bool, optional
        When true, append zero-syllable entries for ``, . ? : ;``.
    path : str, optional
        Dictionary file to read; defaults to the original location.

    Returns
    -------
    list of list
        One ``[word, [counts...], end_count]`` entry per line, in file order.
        ``end_count`` is 0 when the line has no ``E`` token. If punctuation
        is requested, its entries follow the word entries.
    """
    with open(path) as data_key:
        lines = data_key.readlines()

    key_list = []
    for line in lines:
        # Stripping punctuation from the ends of the raw line removes any
        # leading punctuation of the word (e.g. "'tis" -> "tis").
        # convert_to_nums strips words from the text the same way, so
        # lookups still match.
        tokens = line.strip(punctuation).split()
        if not tokens:
            continue  # tolerate blank lines instead of raising IndexError
        word, count_tokens = tokens[0], tokens[1:]
        counts = [int(tok) for tok in count_tokens if not tok.startswith('E')]
        end_counts = [int(tok[1:]) for tok in count_tokens if tok.startswith('E')]
        key_list.append([word, counts, end_counts[0] if end_counts else 0])

    if includePunctuation:
        # Punctuation marks become vocabulary tokens with zero syllables.
        key_list.extend([mark, [0], 0] for mark in (',', '.', '?', ':', ';'))

    return key_list
    
    
    
def convert_to_nums(dataset, key_list, includePunctuation=True):
    """Encode each line of text as a list of indices into ``key_list``.

    Words are lower-cased and stripped of surrounding punctuation before
    lookup. The exceptions are "th'" and "t'", whose apostrophe is part of
    the dictionary key. When ``includePunctuation`` is true and a word ends
    in one of ``, . ? : ;``, that mark's index follows the word's index.

    Parameters
    ----------
    dataset : list of str
        Lines of text.
    key_list : list
        Entries whose element 0 is the word (as built by ``load_key``).
    includePunctuation : bool, optional
        Whether to emit tokens for trailing punctuation marks.

    Returns
    -------
    list of list of int
        One index sequence per input line.

    Raises
    ------
    ValueError
        If a word or punctuation mark is not in ``key_list``.
    """
    punc_list = [',', '.', '?', ':', ';']

    # Build the lookup once instead of scanning the list for every word.
    # setdefault keeps the FIRST index of a duplicate, matching list.index.
    index_of = {}
    for idx, entry in enumerate(key_list):
        index_of.setdefault(entry[0], idx)

    def lookup(token):
        try:
            return index_of[token]
        except KeyError:
            # Preserve the exception type list.index raised originally.
            raise ValueError('{!r} is not in key_list'.format(token)) from None

    seq_list = []
    for line in dataset:
        seq = []
        for word in line.split():
            lowered = word.lower()
            if lowered in ("th'", "t'"):
                seq.append(lookup(lowered))
            else:
                seq.append(lookup(lowered.strip(punctuation)))
            if includePunctuation and word[-1] in punc_list:
                seq.append(lookup(word[-1]))
        seq_list.append(seq)

    return seq_list

def convert_to_nums_red_keylist(dataset, key_list, includePunctuation=True):
    """Encode lines against a reduced vocabulary of only the entries used.

    Tokenisation is the same as in ``convert_to_nums``. Each distinct
    ``key_list`` entry encountered gets a reduced index, assigned in order
    of first appearance (a word before its trailing punctuation mark).

    Parameters
    ----------
    dataset : list of str
        Lines of text.
    key_list : list
        Full vocabulary; element 0 of each entry is the word.
    includePunctuation : bool, optional
        Whether to emit tokens for trailing ``, . ? : ;`` marks.

    Returns
    -------
    tuple of (list of list of int, list)
        ``(seq_list_red, key_list_red)``. ``seq_list_red`` holds one
        sequence of reduced indices per line, and ``key_list_red[k]`` is the
        ``key_list`` entry (same object) for reduced index ``k``.

    Raises
    ------
    ValueError
        If a word or punctuation mark is not in ``key_list``.
    """
    punc_list = [',', '.', '?', ':', ';']

    # First index of each word in the full list (list.index semantics).
    index_of = {}
    for idx, entry in enumerate(key_list):
        index_of.setdefault(entry[0], idx)

    def lookup(token):
        try:
            return index_of[token]
        except KeyError:
            raise ValueError('{!r} is not in key_list'.format(token)) from None

    # Full index -> reduced index. A dict replaces the original
    # list-membership test and second list.index pass.
    reduced_index = {}
    key_list_red = []

    def to_reduced(full_idx):
        if full_idx not in reduced_index:
            reduced_index[full_idx] = len(key_list_red)
            key_list_red.append(key_list[full_idx])
        return reduced_index[full_idx]

    seq_list_red = []
    for line in dataset:
        seq = []
        for word in line.split():
            lowered = word.lower()
            # "th'" and "t'" keep their apostrophe as part of the key.
            token = lowered if lowered in ("th'", "t'") else lowered.strip(punctuation)
            seq.append(to_reduced(lookup(token)))
            if includePunctuation and word[-1] in punc_list:
                seq.append(to_reduced(lookup(word[-1])))
        seq_list_red.append(seq)

    return seq_list_red, key_list_red

def generate_sequences():
    """Build the three training datasets over the full syllable dictionary.

    Call this to generate the training data.

    Returns
    -------
    seq_list_quatrain : list of list of int
        Encoded lines from the first two quatrains.
    seq_list_volta : list of list of int
        Encoded lines from the volta.
    seq_list_couplet : list of list of int
        Encoded lines from the couplet.
    key_list : list
        Entries from ``load_key()`` (punctuation included). Each entry is
        ``[word, [possible syllable counts], end-of-line syllable count]``,
        where the last element is 0 when no end-of-line count is given.
    """
    sections = load_data()
    full_key = load_key()
    # Encode each section in order: quatrain, volta, couplet.
    quatrain_seqs, volta_seqs, couplet_seqs = (
        convert_to_nums(section, full_key) for section in sections
    )
    return quatrain_seqs, volta_seqs, couplet_seqs, full_key

def generate_sequences_reduced(includePunctuation=True):
    """Build the three training datasets, each over its own reduced vocabulary.

    Call this to generate the training data.

    Parameters
    ----------
    includePunctuation : bool, optional
        Passed to both ``load_key`` and ``convert_to_nums_red_keylist``.

    Returns
    -------
    seq_lists : list
        ``[quatrain, volta, couplet]`` lists of encoded sequences.
    key_lists : list
        ``[quatrain, volta, couplet]`` reduced key lists. Indices in
        ``seq_lists[k]`` refer to ``key_lists[k]``.
    """
    sections = load_data()
    full_key = load_key(includePunctuation)

    seq_lists = []
    key_lists = []
    for section in sections:
        section_seqs, section_keys = convert_to_nums_red_keylist(
            section, full_key, includePunctuation)
        seq_lists.append(section_seqs)
        key_lists.append(section_keys)
    return seq_lists, key_lists

def generate_rhyme_lists(return_key_lists=False):
    """Collect the rhyming end-word pairs of each sonnet section.

    Sequences are built without punctuation, so the last token of every
    line is its final word (as ``get_rhymes`` requires).

    Parameters
    ----------
    return_key_lists : bool, optional
        When True, also return the reduced key lists that the rhyme indices
        refer to. The default False keeps the original return value.

    Returns
    -------
    rhymes : list
        ``[quatrain, volta, couplet]`` lists of ``[index, index]`` pairs from
        ``get_rhymes``. Each index refers to that section's reduced key list.
    key_lists : list
        Only returned when ``return_key_lists`` is True: the
        ``[quatrain, volta, couplet]`` reduced key lists.
    """
    # The original also called load_key() here, but the result was never
    # used; generate_sequences_reduced already loads the dictionary.
    seq_lists, key_lists = generate_sequences_reduced(includePunctuation=False)
    rhymes = get_rhymes(seq_lists)
    if return_key_lists:
        return rhymes, key_lists
    return rhymes

def get_rhymes(seq_lists):
    """Find rhyming end-word pairs in each section's encoded lines.

    The sequences must be encoded without punctuation; otherwise the last
    token of a line is a punctuation mark rather than its final word.

    Parameters
    ----------
    seq_lists : list
        ``[quatrain, volta, couplet]`` lists of encoded lines (lists of int).
        Quatrain and volta lines are read in stanzas of four with an ABAB
        scheme (line 1 with 3, line 2 with 4). Couplet lines are read in
        consecutive pairs.

    Returns
    -------
    list
        ``[rhyme_list_quatrain, rhyme_list_volta, rhyme_list_couplet]``.
        Each is a list of ``[index, index]`` pairs of last tokens in
        first-seen order, deduplicated regardless of pair order. Pairs
        involving an empty line are skipped, and a trailing incomplete
        stanza or pair is ignored.
    """

    def add_pair(seqs, first, second, rhyme_list):
        # An empty line has no last word to pair.
        if seqs[first] and seqs[second]:
            add_rhyme([seqs[first][-1], seqs[second][-1]], rhyme_list)

    def abab_rhymes(seqs):
        rhyme_list = []
        # len - 3 keeps start + 3 in range, so only complete stanzas are
        # read. The old bound, len - 2, indexed past the end whenever the
        # length was 3 mod 4.
        for start in range(0, len(seqs) - 3, 4):
            add_pair(seqs, start, start + 2, rhyme_list)
            add_pair(seqs, start + 1, start + 3, rhyme_list)
        return rhyme_list

    rhyme_list_quatrain = abab_rhymes(seq_lists[0])
    rhyme_list_volta = abab_rhymes(seq_lists[1])

    rhyme_list_couplet = []
    couplet_seqs = seq_lists[2]
    for start in range(0, len(couplet_seqs) - 1, 2):
        add_pair(couplet_seqs, start, start + 1, rhyme_list_couplet)

    return [rhyme_list_quatrain, rhyme_list_volta, rhyme_list_couplet]


def add_rhyme(rhyme, rhyme_list):
    """Append the pair ``rhyme`` to ``rhyme_list`` unless it is already there.

    A pair already present in reverse order also counts as a duplicate.
    Mutates ``rhyme_list`` in place and returns None.
    """
    reversed_pair = [rhyme[1], rhyme[0]]
    if rhyme not in rhyme_list and reversed_pair not in rhyme_list:
        rhyme_list.append(rhyme)

In [37]:
# Raw section lines and the punctuation-free syllable dictionary. They are
# kept at module level because later cells may use them.
quatrain, volta, couplet = load_data()
key_list = load_key(includePunctuation=False)

# Section sequences over per-section reduced vocabularies, and the
# [index, index] rhyme pairs found in each section. Indices in rhymes[k]
# refer to key_lists[k].
seq_lists, key_lists = generate_sequences_reduced(includePunctuation=False)
rhymes = get_rhymes(seq_lists)

# To print the quatrain rhyme pairs as words:
# for first, second in rhymes[0]:
#     print(key_lists[0][first][0] + ' ' + key_lists[0][second][0] + '\n')
