# Train KVETA

In [1]:
%load_ext dotenv
%dotenv
import os

# WORKING_DIR is expected to come from the environment / the .env file loaded by %dotenv above.
base_dir = os.getenv("WORKING_DIR")
if not base_dir:
    # Fail with an actionable message instead of the opaque error os.chdir
    # raises for a missing (None) or empty path.
    raise RuntimeError("WORKING_DIR is not set; define it in the environment or in the .env file.")
# All later relative paths (e.g. "src/data") are resolved against this directory.
os.chdir(base_dir)

In [2]:
from src.data_loader_and_saver import JSONDataLoaderAndSaver

# Single loader/saver instance reused by every load_all_data / save_data call below.
data_loader_and_saver = JSONDataLoaderAndSaver(base_dir, input_data_dir="src/data", output_data_dir="src/kveta/data")

In [11]:
from collections import Counter
from typing import Tuple

from src.util import Util
from src.kveta.sampa_syllable_parser import SampaSyllableParser
from src.kveta.syllable_class_parser import SyllableClassParser


def get_counters(train_X: list, train_y: list) -> Tuple[Counter, Counter, Counter]:
    """
    Count Syllable Classes, metrical positions and their co-occurrences in the training data.

    Every poem in ``train_X`` is parsed with ``SampaSyllableParser`` and then with
    ``SyllableClassParser``. The per-word Syllable Classes of each line are flattened
    and paired, position by position, with the line's normalized metrical pattern.

    Pairing is done with ``zip``: if a line has a different number of syllables than
    pattern positions (or a poem has a different number of parsed lines than annotated
    lines), the surplus items are silently ignored.

    :param train_X: X training data (iterated poem by poem)
    :param train_y: y training data (iterated poem by poem; every line is indexed with "pattern")
    :return: Counter of Syllable Classes (keyed by their string form), Counter of metrical
        positions, and Counter of (metrical position, Syllable Class string) pairs
    """
    syll_cls_counter = Counter()
    metrical_pos_counter = Counter()
    metrical_pos_syll_cls_counter = Counter()

    sampa_parser = SampaSyllableParser()
    syllable_class_parser = SyllableClassParser()

    for poem_X, poem_y in zip(train_X, train_y):
        poem_sonority_peaks = sampa_parser.parse_poem(poem_X)
        poem_syllable_classes = syllable_class_parser.parse_poem(poem_sonority_peaks, poem_X)

        for line_syllable_classes, line_y in zip(poem_syllable_classes, poem_y):
            metrical_pattern = list(Util.normalize_metrical_pattern(line_y["pattern"]))
            # Flatten the per-word lists into one sequence of Syllable Classes for the line.
            syllable_classes = [syll_class for word in line_syllable_classes for syll_class in word]

            for syllable_cls, metrical_pos in zip(syllable_classes, metrical_pattern):
                # Counters are keyed by the string form of the Syllable Class; build it once.
                syll_cls_key = str(syllable_cls)
                syll_cls_counter[syll_cls_key] += 1
                metrical_pos_counter[metrical_pos] += 1
                metrical_pos_syll_cls_counter[(metrical_pos, syll_cls_key)] += 1

    return syll_cls_counter, metrical_pos_counter, metrical_pos_syll_cls_counter

In [12]:
from collections import OrderedDict, Counter


def get_metrical_pos_given_syll_class_proba(metrical_pos_syll_cls_counter: Counter, syll_cls_counter: Counter) -> OrderedDict:
    """
    Estimate P(metrical position | Syllable Class) from co-occurrence counts.

    The (metrical position, Syllable Class) keys are visited in sorted order; each joint
    count is divided by the total count of its Syllable Class.
    :param metrical_pos_syll_cls_counter: Counter keyed by (metrical position, Syllable Class) pairs
    :param syll_cls_counter: Counter keyed by Syllable Class
    :return: OrderedDict mapping each Syllable Class to an OrderedDict of
        metrical position -> conditional probability
    """
    proba_by_syll_cls = OrderedDict()

    for pair in sorted(metrical_pos_syll_cls_counter):
        position, cls_key = pair
        joint_count = metrical_pos_syll_cls_counter[pair]
        cls_total = syll_cls_counter[cls_key]

        # Create the per-class mapping on first sight of the class, then fill in this position.
        per_position = proba_by_syll_cls.setdefault(cls_key, OrderedDict())
        per_position[position] = joint_count / cls_total

    return proba_by_syll_cls

## All poems just 1 metre, no unknown metres

In [13]:
extension = "_one_metre_all_metres_recognized"

### Load data

In [14]:
# The third value returned by load_all_data is discarded; only the first two (bound as train/dev) are used.
train_X, dev_X, _ = data_loader_and_saver.load_all_data(f"_X{extension}")

train_X_one_metre_all_metres_recognized.json: loaded 40137 records.
dev_X_one_metre_all_metres_recognized.json: loaded 8601 records.
test_X_one_metre_all_metres_recognized.json: loaded 8601 records.


In [15]:
# Fold the dev split into the training data.
# NOTE(review): `+=` grows train_X in place (assuming it is a list), so re-running this cell
# without re-running the load cell would append dev_X a second time.
train_X += dev_X

In [16]:
len(train_X)

48738

In [17]:
# The third value returned by load_all_data is discarded; only the first two (bound as train/dev) are used.
train_y, dev_y, _ = data_loader_and_saver.load_all_data(f"_y{extension}")

train_y_one_metre_all_metres_recognized.json: loaded 40137 records.
dev_y_one_metre_all_metres_recognized.json: loaded 8601 records.
test_y_one_metre_all_metres_recognized.json: loaded 8601 records.


In [18]:
# Fold the dev split into the training labels.
# NOTE(review): `+=` grows train_y in place (assuming it is a list), so re-running this cell
# without re-running the load cell would append dev_y a second time.
train_y += dev_y

In [19]:
len(train_y)

48738

### Preprocess and count probabilities

In [20]:
syll_cls_counter, metrical_pos_counter, metrical_pos_syll_cls_counter = get_counters(train_X, train_y)

In [18]:
syll_cls_counter.most_common()

[('SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
  4959072),
 ('SyllableClass(initial=True, final=True, content_word=False, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
  2390394),
 ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=False)',
  2297302),
 ('SyllableClass(initial=False, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
  1772143),
 ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=True, next_long=False)',
  1216977),
 ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=True)',
  821695),
 ('SyllableClass(i

In [22]:
len(syll_cls_counter.keys())

15

In [23]:
metrical_pos_counter.most_common()

[('W', 7793803), ('S', 7273315)]

In [24]:
len(metrical_pos_syll_cls_counter.keys())

30

In [22]:
metrical_pos_syll_cls_counter.most_common()

[(('W',
   'SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)'),
  4158333),
 (('S',
   'SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=False)'),
  2204835),
 (('W',
   'SyllableClass(initial=True, final=True, content_word=False, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)'),
  1726755),
 (('W',
   'SyllableClass(initial=False, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)'),
  1157366),
 (('S',
   'SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=True, next_long=False)'),
  1117978),
 (('S',
   'SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_ini

In [26]:
# Sanity check: get_counters increments all three counters once per (syllable, position) pair,
# so their totals must be identical.
metrical_pos_cnt = sum(metrical_pos_counter.values())
syll_cls_cnt = sum(syll_cls_counter.values())
metrical_pos_syll_cls_cnt = sum(metrical_pos_syll_cls_counter.values())

assert (metrical_pos_cnt == syll_cls_cnt == metrical_pos_syll_cls_cnt)

In [27]:
pos_cnt = metrical_pos_cnt
pos_cnt

15067118

In [28]:
metrical_pos_given_syll_class_proba = get_metrical_pos_given_syll_class_proba(metrical_pos_syll_cls_counter, syll_cls_counter)
metrical_pos_given_syll_class_proba

OrderedDict([('SyllableClass(initial=False, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
              OrderedDict([('S', 0.34691162056335184),
                           ('W', 0.6530883794366482)])),
             ('SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
              OrderedDict([('S', 0.1614695249433765),
                           ('W', 0.8385304750566235)])),
             ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=False)',
              OrderedDict([('S', 0.9597497412181768),
                           ('W', 0.0402502587818232)])),
             ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=True)',
          

In [29]:
# Payload handed to save_data in the next cell; the conditional probabilities are stored under a single key.
kveta_probabilities = {
    "metrical_pos_given_syll_cls_proba": metrical_pos_given_syll_class_proba
}

In [30]:
data_loader_and_saver.save_data(kveta_probabilities, f"kveta_probabilities{extension}")

Data saved to kveta_probabilities_one_metre_all_metres_recognized.json


## All lines just 1 metre, no unknown metres

In [31]:
extension = "_one_metre_line_all_metres_recognized"

### Load data

In [32]:
# The third value returned by load_all_data is discarded; only the first two (bound as train/dev) are used.
train_X, dev_X, _ = data_loader_and_saver.load_all_data(f"_X{extension}")

train_X_one_metre_line_all_metres_recognized.json: loaded 41762 records.
dev_X_one_metre_line_all_metres_recognized.json: loaded 8949 records.
test_X_one_metre_line_all_metres_recognized.json: loaded 8950 records.


In [33]:
# Fold the dev split into the training data.
# NOTE(review): `+=` grows train_X in place (assuming it is a list), so re-running this cell
# without re-running the load cell would append dev_X a second time.
train_X += dev_X

In [34]:
len(train_X)

50711

In [35]:
# The third value returned by load_all_data is discarded; only the first two (bound as train/dev) are used.
train_y, dev_y, _ = data_loader_and_saver.load_all_data(f"_y{extension}")

train_y_one_metre_line_all_metres_recognized.json: loaded 41762 records.
dev_y_one_metre_line_all_metres_recognized.json: loaded 8949 records.
test_y_one_metre_line_all_metres_recognized.json: loaded 8950 records.


In [36]:
# Fold the dev split into the training labels.
# NOTE(review): `+=` grows train_y in place (assuming it is a list), so re-running this cell
# without re-running the load cell would append dev_y a second time.
train_y += dev_y

In [37]:
len(train_y)

50711

### Preprocess and count probabilities

In [38]:
syll_cls_counter, metrical_pos_counter, metrical_pos_syll_cls_counter = get_counters(train_X, train_y)

In [39]:
syll_cls_counter.most_common()

[('SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
  5387807),
 ('SyllableClass(initial=True, final=True, content_word=False, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
  2597846),
 ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=False)',
  2498180),
 ('SyllableClass(initial=False, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
  1930454),
 ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=True, next_long=False)',
  1322609),
 ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=True)',
  890423),
 ('SyllableClass(i

In [40]:
len(syll_cls_counter.keys())

15

In [41]:
metrical_pos_counter.most_common()

[('W', 8485897), ('S', 7887959)]

In [42]:
len(metrical_pos_syll_cls_counter.keys())

30

In [43]:
metrical_pos_syll_cls_counter.most_common()

[(('W',
   'SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)'),
  4523044),
 (('S',
   'SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=False)'),
  2394828),
 (('W',
   'SyllableClass(initial=True, final=True, content_word=False, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)'),
  1876096),
 (('W',
   'SyllableClass(initial=False, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)'),
  1264527),
 (('S',
   'SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=True, next_long=False)'),
  1212739),
 (('S',
   'SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_ini

In [44]:
# Sanity check: get_counters increments all three counters once per (syllable, position) pair,
# so their totals must be identical.
metrical_pos_cnt = sum(metrical_pos_counter.values())
syll_cls_cnt = sum(syll_cls_counter.values())
metrical_pos_syll_cls_cnt = sum(metrical_pos_syll_cls_counter.values())

assert (metrical_pos_cnt == syll_cls_cnt == metrical_pos_syll_cls_cnt)

In [45]:
pos_cnt = metrical_pos_cnt
pos_cnt

16373856

In [46]:
metrical_pos_given_syll_class_proba = get_metrical_pos_given_syll_class_proba(metrical_pos_syll_cls_counter, syll_cls_counter)
metrical_pos_given_syll_class_proba

OrderedDict([('SyllableClass(initial=False, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
              OrderedDict([('S', 0.3449587506358608),
                           ('W', 0.6550412493641392)])),
             ('SyllableClass(initial=False, final=True, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=None, next_long=None)',
              OrderedDict([('S', 0.16050370772375477),
                           ('W', 0.8394962922762452)])),
             ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=False)',
              OrderedDict([('S', 0.9586290819716754),
                           ('W', 0.04137091802832462)])),
             ('SyllableClass(initial=True, final=False, content_word=None, mpp_preposition=False, prev_mpp_preposition=False, prev_initial=False, next_long=True)',
         

In [47]:
# Payload handed to save_data in the next cell; the conditional probabilities are stored under a single key.
kveta_probabilities = {
    "metrical_pos_given_syll_cls_proba": metrical_pos_given_syll_class_proba
}

In [48]:
data_loader_and_saver.save_data(kveta_probabilities, f"kveta_probabilities{extension}")

Data saved to kveta_probabilities_one_metre_line_all_metres_recognized.json
