In [None]:
import re
import numpy as np
import os
os.sys.path.append('../1/')
from z2 import loader
from math import log
import sys
import heapq
import collections
import operator


# Polish vowel letters in lower and upper case; `count_syllable` counts one syllable per vowel.
_VOWEL_LETTERS = 'aeioóuyąę'
vowels = [*_VOWEL_LETTERS, *_VOWEL_LETTERS.upper()]
# Two-letter sequences 'i' + <another vowel>; `count_syllable` subtracts one syllable for each.
compacted_vovels = [f'i{v}' for v in vowels if v != 'i']
# Filled from supertags.txt further down: word -> tag, and tag -> list of words with that tag.
word2tag = {}
tag2word = {}
# Directory (relative to the notebook) holding every data file read below.
dataPath = 'data/'

# Punctuation classes used by `stringNorm`, precompiled once. Raw strings are used so that
# `\.`, `\*`, `\(` ... reach the regex engine untouched; in the previous non-raw f-string they
# were invalid Python escape sequences (SyntaxWarning on Python 3.12+).
_NORM_STRIP_WITH_DIGITS = re.compile(r'[,\.!?:;\'0-9\*\-“…\(\)„”—»«–––=\[\]’]')
_NORM_STRIP_KEEP_DIGITS = re.compile(r'[,\.!?:;\'\*\-“…\(\)„”—»«–––=\[\]’]')


def stringNorm(sent, num=False):
    """Lower-case `sent` and delete punctuation (and, by default, digits).

    Args:
        sent: text to normalise.
        num: when True, digits are kept; when False (default), digits 0-9 are removed too.

    Returns:
        The lower-cased string with the listed punctuation characters removed.
        Whitespace is left as-is.
    """
    pattern = _NORM_STRIP_KEEP_DIGITS if num else _NORM_STRIP_WITH_DIGITS
    return pattern.sub('', sent.lower())

def bigrams2unigrams(bigrams):
    """Derive per-word totals from a nested bigram-count mapping.

    Args:
        bigrams: mapping w1 -> mapping w2 -> count (counts may be anything `float()` accepts).

    Returns:
        dict w1 -> (sum of the counts of all bigrams starting with w1) / 2.
        NOTE(review): the reason for halving is not visible here — confirm the intended normalisation.
    """
    counts = {}
    for left in bigrams:
        followers = bigrams[left]
        total = sum(float(followers[right]) for right in followers)
        counts[left] = total / 2
    return counts

def count_syllable(phrase, verose=False):
    """Count syllables in `phrase` as vowel letters minus 'i'+vowel digraphs.

    Every character found in `vowels` adds one syllable. An 'i' directly followed by another
    vowel (a sequence listed in `compacted_vovels`) is treated as non-syllabic, so one syllable
    is subtracted for it. The digraph test is case-insensitive, so 'NIE', 'Nie' and 'nie' all
    count the same. Previously only a lowercase leading 'i' was recognised, which made
    capitalised words count one syllable too many.

    Args:
        phrase: the word/phrase to analyse.
        verose: when True, print every vowel and digraph as it is counted (debug aid).

    Returns:
        int syllable estimate.
    """
    res = 0
    for i, letter in enumerate(phrase):
        if letter in vowels:
            res += 1
            if verose:
                print(letter)
        # Lower-case only the 2-char window so uppercase 'I' + vowel is matched as well.
        if phrase[i:i+2].lower() in compacted_vovels:
            res -= 1
            if verose:
                print(phrase[i:i+2])
    return res


# Load the tag dictionary. Each non-blank line is "<key> <tag>"; keys are normalised with
# stringNorm (digits kept). Presumably keys include the 3-char suffix keys that `sameTags`
# falls back to — TODO confirm against the data file.
# Explicit UTF-8: the data is Polish text, and the platform default encoding (e.g. cp1250 on
# Windows) would mis-decode it. NOTE(review): assumes the file is UTF-8.
with open(dataPath + "supertags.txt", encoding="utf-8") as tags:
    for line in tags:
        fields = stringNorm(line, num=True).split()
        if not fields:
            continue  # tolerate blank lines instead of failing on tuple unpacking
        word, tag = fields
        word2tag[word] = tag
        tag2word.setdefault(tag, []).append(word)
            
# Map each word form to its base form ("<word> <base_word>" per line); `change_word` and
# `find_common_rime` look vectors up by base form.
# Explicit UTF-8 for the Polish data. NOTE(review): assumes the file is UTF-8.
base = {}
with open(dataPath + "superbazy.txt", encoding="utf-8") as file:
    for line in file:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines instead of failing on tuple unpacking
        word, base_word = fields
        base[word] = base_word

            
# Load word embeddings: each line is a word followed by its vector components. Vectors are
# L2-normalised, so a dot product between two of them is their cosine similarity.
# Explicit UTF-8 for the Polish vocabulary. NOTE(review): assumes the file is UTF-8.
vectors = {}
with open(dataPath + "poleval_base_vectors.txt", encoding="utf-8") as file:
    for line in file:
        vec = line.split()
        # Only rows with 151..249 fields are kept. This presumably drops a header line and
        # malformed rows — TODO confirm the expected dimensionality.
        if 150 < len(vec) < 250:
            x = np.array([float(v) for v in vec[1:]])
            norm = np.linalg.norm(x)
            # A zero vector cannot be normalised (0/0 -> NaN). A NaN entry would make every
            # np.argmax over similarities pick it, so such words are skipped.
            if norm > 0:
                vectors[vec[0]] = x / norm

                       

# Load rhythmic couplets. For every line containing 'RYM:', the text after the marker (with
# trailing spaces/dots/newline stripped) is split on '[*]' into verses. Each verse is stored
# as a list of whitespace-separated words, so one entry is a tuple of word lists.
# Lines without the marker are skipped; previously they raised IndexError.
# Explicit UTF-8 for the Polish text. NOTE(review): assumes the file is UTF-8.
with open(dataPath + 'rytmiczne_zdania_z_korpusu.txt', encoding='utf-8') as f:
    sentences = [
        tuple(
            verse.split()
            for verse in line.split('RYM:')[1].rstrip(' .\n').split('[*]')
        )
        for line in f
        if 'RYM:' in line
    ]
                       
                       


def PMI(w1, w2):
    """Pointwise mutual information of the word pair (w1, w2).

    Computes log(count(w1, w2) * uniSum / (unigrams[w1] * unigrams[w2]) + sys.float_info.min).
    An unseen pair gets count 1. The tiny additive constant keeps the log argument away from
    exactly zero.

    NOTE(review): `bigrams`, `unigrams` and `uniSum` are module globals that are not defined in
    the visible notebook; they must exist before this is called.
    """
    seen = w1 in bigrams and w2 in bigrams[w1]
    pair_count = bigrams[w1][w2] if seen else 1
    ratio = float(pair_count) * uniSum / (unigrams[w1] * unigrams[w2])
    return log(ratio + sys.float_info.min)
                       
                       

In [None]:

# Sanity check: vectors are L2-normalised on load, so each self dot product should print ~1.0.
# zip with range(10) stops after the first ten keys.
for _, key in zip(range(10), vectors):
    unit = vectors[key]
    print(unit.T @ unit.T)

In [None]:
def get_rym(w):
    """Return the rhyming part of `w`: its shortest suffix with exactly two syllables.

    Syllables are counted with `count_syllable`. Suffixes are tried from the end of the word
    backwards, and the first one with a count of 2 is returned. That is the same result as
    keeping the last match of a forward scan.

    Returns:
        str suffix, or None when no suffix has exactly two syllables.
    """
    for start in reversed(range(len(w))):
        suffix = w[start:]
        if count_syllable(suffix) == 2:
            return suffix
    return None

def sample_verset(corpus=None):
    """Return a uniformly random couplet from `corpus` as a tuple of fresh word lists.

    The verses are copied, so callers may edit them in place without corrupting the corpus.
    Previously the corpus's own lists were returned, and the perturbation loop permanently
    rewrote `sentences`.

    Args:
        corpus: sequence of tuples of word lists. Defaults to the global `sentences`.

    Returns:
        tuple of lists (one list of words per verse).

    Raises:
        ValueError: if `corpus` is empty.
    """
    if corpus is None:
        corpus = sentences
    # choice(n) draws the same index as choice(np.arange(n)) without building the array.
    index = np.random.choice(len(corpus))
    return tuple(list(verse) for verse in corpus[index])


def get_accents(phrase):
    """Return the syllable count (via `count_syllable`) of every word in `phrase`, in order."""
    return list(map(count_syllable, phrase))


In [None]:
def sameTags(w):
    """Return every word sharing the tag of `w`.

    The word itself is looked up first. If it is absent, the fallback key is the last three
    characters of '^' + w (so words shorter than 3 letters keep a '^' start marker).

    Returns:
        The list stored in `tag2word` for that tag (not a copy), or [] when neither key is known.
    """
    for key in (w, ('^' + w)[-3:]):
        if key in word2tag:
            return tag2word[word2tag[key]]
    return []
    
def createVerseAltWords(accent, verse, rime=None, vocabulary=None):
    """Per-position replacement candidates for a whole verse.

    This was originally a first definition named `createAltWords`. The single-word
    `createAltWords` defined right below rebinds that name in the same cell, so this version
    was unreachable. It is renamed so both can coexist.

    Args:
        accent: syllable counts, one per word of `verse` (e.g. from `get_accents`).
        verse: list of words.
        rime: unused; kept for signature parity with `createAltWords`.
        vocabulary: iterable of words allowed as replacements. Defaults to the global
            `safeGrams`. NOTE(review): `safeGrams` is not defined anywhere in this notebook;
            pass `vocabulary` explicitly or define it first.

    Returns:
        list with one list per word of `verse`. Each inner list holds the words sharing that
        word's tag (`sameTags`) with the same syllable count that are also in `vocabulary`.
        Order within each list is unspecified.
    """
    allowed = set(safeGrams if vocabulary is None else vocabulary)
    return [
        list({cand for cand in sameTags(word) if count_syllable(cand) == accent[i]} & allowed)
        for i, word in enumerate(verse)
    ]

def createAltWords(accent, w, rime=None):
    """Candidate replacements for the single word `w`.

    Args:
        accent: required syllable count (via `count_syllable`).
        w: the word to replace; it is never returned as its own alternative.
        rime: if not None, only candidates whose `get_rym` equals it are kept.

    Returns:
        list of words from `sameTags(w)` meeting the criteria, in `sameTags` order.
    """
    candidates = []
    for cand in sameTags(w):
        if count_syllable(cand) != accent or cand == w:
            continue
        if rime is not None and get_rym(cand) != rime:
            continue
        candidates.append(cand)
    return candidates
                        
def get_rime_set(alts):
    """Return the set of rhyme endings (`get_rym`) of the words in `alts`; may contain None."""
    return set(map(get_rym, alts))


def change_word(accent, word, rime=None):
    """Pick the alternative for `word` closest to it in embedding space.

    Candidates come from `createAltWords(accent, word, rime=rime)`. Each candidate is scored
    by the dot product of its base-form vector with the base-form vector of `word` (cosine
    similarity, since vectors are normalised on load). Candidates without a base form or
    vector score 0.

    Returns:
        (chosen_word, score), or (None, 0) when there are no candidates.

    Raises:
        KeyError: if candidates exist but `word` has no base form in `base`, or its base form
            has no vector.
    """
    alts = createAltWords(accent, word, rime=rime)
    if not alts:
        # Checked before the vector lookup: a word with nothing to swap in should report
        # "no change" instead of raising KeyError for a missing embedding.
        return None, 0
    target = vectors[base[word]]
    scores = [
        vectors[base[alt]] @ target if alt in base and base[alt] in vectors else 0
        for alt in alts
    ]
    best = int(np.argmax(scores))
    return alts[best], scores[best]


def find_common_rime(a1, w1, a2, w2):
    """Replace both `w1` and `w2` with alternatives that rhyme with each other.

    Only rhyme endings available to both words' alternatives are considered. Among `w1`'s
    alternatives with such an ending, the one most similar to `w1` is chosen. `w2` is then
    replaced via `change_word`, constrained to that ending.

    Args:
        a1, a2: required syllable counts for the replacements of w1 and w2.
        w1, w2: the two words to replace.

    Returns:
        (choosen1, choosen2), or None when no common rhyme exists. choosen2 may be None if
        `change_word` finds nothing.

    Raises:
        KeyError: if a common rhyme exists but `w1` (or `w2` in `change_word`) lacks a
            base form or vector.
    """
    alts1 = createAltWords(a1, w1, rime=None)
    alts2 = createAltWords(a2, w2, rime=None)
    # None from get_rym means "fewer than two syllables". Passed on as a rime it would
    # *disable* the rhyme filter in change_word, so it is not a usable common rhyme.
    common_rimes = (get_rime_set(alts1) & get_rime_set(alts2)) - {None}
    if not common_rimes:
        return None

    # Restrict to rhyming candidates before the argmax. Previously non-rhyming alternatives
    # scored 0 and won whenever all rhyming ones had negative similarity.
    candidates1 = [alt for alt in alts1 if get_rym(alt) in common_rimes]
    vec1 = vectors[base[w1]]
    scores1 = [
        vectors[base[alt]] @ vec1 if alt in base and base[alt] in vectors else 0
        for alt in candidates1
    ]
    choosen1 = candidates1[int(np.argmax(scores1))]
    choosen2 = change_word(a2, w2, rime=get_rym(choosen1))[0]
    return choosen1, choosen2
    
def change_last(a1, w1, a2, w2):
    """Rewrite the final words of a couplet so that they rhyme.

    Both directions are tried first: replacing `w1` to rhyme with `w2`, and replacing `w2` to
    rhyme with `w1`. The higher-similarity one is taken (ties favour replacing `w2`), provided
    its score is positive. Otherwise both words are replaced via `find_common_rime`.

    Previously the function read the notebook globals `v1`/`v2` instead of `w1`/`w2`, so it
    only worked when called from the driver loop.

    Args:
        a1, a2: syllable counts required for the replacements of w1 and w2.
        w1, w2: last words of the first and second verse.

    Returns:
        (new_w1, new_w2). When no rhyming replacement is found, (w1, w2) is returned unchanged.

    Raises:
        KeyError: propagated from `change_word` / `find_common_rime` for words without a
            base form or vector.
    """
    rime1 = get_rym(w1)
    rime2 = get_rym(w2)
    # A None rhyme (word under two syllables) would mean "no rhyme constraint" inside
    # change_word, so that direction is skipped rather than returning a non-rhyming word.
    alt1 = change_word(a1, w1, rime2) if rime2 is not None else (None, 0)
    alt2 = change_word(a2, w2, rime1) if rime1 is not None else (None, 0)
    if alt1[1] <= alt2[1] and alt2[1] > 0:
        return w1, alt2[0]
    if alt2[1] < alt1[1] and alt1[1] > 0:
        return alt1[0], w2
    common = find_common_rime(a1, w1, a2, w2)
    if common is None or None in common:
        return w1, w2
    return common
        
def pretty_print(v1, v2):
    """Print a couplet: each verse's words space-joined on its own line, then a blank line."""
    first = ' '.join(v1)
    second = ' '.join(v2)
    print(f'{first}\n{second}\n')

In [None]:
N_COUPLETS = 20        # how many corpus couplets to sample and perturb
SWAP_THRESHOLD = 0.8   # a word is perturbed when a uniform draw exceeds this (~20% of words)
SEED = None            # set to an int for a reproducible run


def _perturb_word(v1, v2, a1, a2, verse, accents, i, word):
    """Try to replace `verse[i]` (a word of v1 or v2) in place; return True if the couplet changed.

    Non-final words are swapped via `change_word`. The final word goes through `change_last`,
    which may rewrite the last word of either verse so the couplet rhymes. Failed lookups
    (None results, or KeyError for words missing from `base`/`vectors`) leave the couplet
    untouched. Previously a None was written into the verse, which broke every later print.
    """
    try:
        if i < len(verse) - 1:
            replacement = change_word(accents[i], word)[0]
            if replacement is None:
                return False
            verse[i] = replacement
        else:
            pair = change_last(a1[-1], v1[-1], a2[-1], v2[-1])
            if pair is None or None in pair:
                return False
            v1[-1], v2[-1] = pair
    except KeyError:
        return False
    return True


np.random.seed(SEED)

for _ in range(N_COUPLETS):
    # Copy the verses so the in-place edits below can never leak back into `sentences`.
    v1, v2 = (list(verse) for verse in sample_verset())
    a1 = get_accents(v1)
    a2 = get_accents(v2)
    pretty_print(v1, v2)
    # Same procedure for both verses; the separator printed after each matches the original output.
    for verse, accents, separator in ((v1, a1, '<->'), (v2, a2, '************')):
        for i, word in enumerate(verse):
            if np.random.rand() > SWAP_THRESHOLD:
                if _perturb_word(v1, v2, a1, a2, verse, accents, i, word):
                    pretty_print(v1, v2)
                else:
                    print('.')
        print(separator)
