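### Create a Synthetic Dataset from the Penn Treebank

Read the Penn Treebank training split (CoNLL-U), replace rare words with `UNK`, map words and POS tags to integer indices, randomly corrupt a fraction of the tags, and save the clean and noisy sequences to an `.npz` file.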

In [2]:
import numpy as np
from collections import Counter
from collections import defaultdict

In [3]:
def read_sentences_with_pos_tags(path):
    """Read a CoNLL-U file into sentences of ``(word, upos, xpos)`` tuples.

    Token lines are tab-separated; FORM (column 2, lowercased), UPOS
    (column 4) and XPOS (column 5) are kept. Lines with fewer than five
    columns are ignored. A blank line or a ``#`` comment line closes the
    sentence being built, and the final sentence is kept even when the file
    does not end with a blank line.

    Parameters
    ----------
    path : str or path-like
        Path to a UTF-8 encoded CoNLL-U file.

    Returns
    -------
    list[list[tuple[str, str, str]]]
        One inner list per non-empty sentence.
    """
    sentences_with_pos_tags = []
    current_sentence = []

    with open(path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip() or line.startswith('#'):
                # Sentence boundary (blank line) or comment: close any open sentence.
                if current_sentence:
                    sentences_with_pos_tags.append(current_sentence)
                    current_sentence = []
                continue

            fields = line.rstrip('\r\n').split('\t')
            # fields[4] is read below, so at least five columns are required
            # (a `> 3` check would let a 4-column line raise IndexError).
            if len(fields) > 4:
                word = fields[1].lower()  # FORM, lowercased
                upos = fields[3]          # universal POS tag
                xpos = fields[4]          # language-specific POS tag
                current_sentence.append((word, upos, xpos))

    # Keep the last sentence when the file lacks a trailing blank line.
    if current_sentence:
        sentences_with_pos_tags.append(current_sentence)

    return sentences_with_pos_tags

In [4]:
# CoNLL-U files for the train / test / dev splits of the Penn Treebank.
file_path1, file_path2, file_path3 = (
    f'../data/ptb/penn-{split}.conllu' for split in ('train', 'test', 'dev')
)

In [5]:
# Load the training split: one list of (word, upos, xpos) triples per sentence.
sentences_pos_tags = read_sentences_with_pos_tags(file_path1)
print(f'{len(sentences_pos_tags)}')  # number of sentences read
sentences_pos_tags[:3]

39832


[[('in', 'ADP', 'IN'),
  ('an', 'DET', 'DT'),
  ('oct.', 'PROPN', 'NNP'),
  ('19', 'NUM', 'CD'),
  ('review', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('``', 'PUNCT', '``'),
  ('the', 'DET', 'DT'),
  ('misanthrope', 'NOUN', 'NN'),
  ("''", 'PUNCT', "''"),
  ('at', 'ADP', 'IN'),
  ('chicago', 'PROPN', 'NNP'),
  ("'s", 'PART', 'POS'),
  ('goodman', 'PROPN', 'NNP'),
  ('theatre', 'PROPN', 'NNP'),
  ('-lrb-', 'PUNCT', '-LRB-'),
  ('``', 'PUNCT', '``'),
  ('revitalized', 'VERB', 'VBN'),
  ('classics', 'NOUN', 'NNS'),
  ('take', 'VERB', 'VBP'),
  ('the', 'DET', 'DT'),
  ('stage', 'NOUN', 'NN'),
  ('in', 'ADP', 'IN'),
  ('windy', 'PROPN', 'NNP'),
  ('city', 'PROPN', 'NNP'),
  (',', 'PUNCT', ','),
  ("''", 'PUNCT', "''"),
  ('leisure', 'NOUN', 'NN'),
  ('&', 'CONJ', 'CC'),
  ('arts', 'NOUN', 'NNS'),
  ('-rrb-', 'PUNCT', '-RRB-'),
  (',', 'PUNCT', ','),
  ('the', 'DET', 'DT'),
  ('role', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('celimene', 'PROPN', 'NNP'),
  (',', 'PUNCT', ','),
  ('played', 'VE

In [6]:
def replace_low_frequency_words(sentences_with_pos_tags, filter_count=1):
    """Mask rare words together with their tags.

    A word whose total count over all sentences is below ``filter_count`` is
    replaced by ``'UNK'`` and both of its tags by ``'UNK_TAG'``; every other
    ``(word, upos, xpos)`` triple is copied through unchanged. With the
    default ``filter_count=1`` nothing is masked.

    Returns a new list of sentences; the input is not modified.
    """
    frequency = Counter(
        token for sent in sentences_with_pos_tags for token, _, _ in sent
    )
    masked = ('UNK', 'UNK_TAG', 'UNK_TAG')

    return [
        [
            masked if frequency[token] < filter_count else (token, upos, xpos)
            for token, upos, xpos in sent
        ]
        for sent in sentences_with_pos_tags
    ]

In [7]:
# Words seen fewer than this many times in the training split are mapped to
# 'UNK' (and their tags to 'UNK_TAG'); named so the threshold is easy to tune.
min_word_count = 20
filtered_sentences = replace_low_frequency_words(sentences_pos_tags, filter_count=min_word_count)
filtered_sentences[:3]

[[('in', 'ADP', 'IN'),
  ('an', 'DET', 'DT'),
  ('oct.', 'PROPN', 'NNP'),
  ('19', 'NUM', 'CD'),
  ('review', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('``', 'PUNCT', '``'),
  ('the', 'DET', 'DT'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ("''", 'PUNCT', "''"),
  ('at', 'ADP', 'IN'),
  ('chicago', 'PROPN', 'NNP'),
  ("'s", 'PART', 'POS'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('-lrb-', 'PUNCT', '-LRB-'),
  ('``', 'PUNCT', '``'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('take', 'VERB', 'VBP'),
  ('the', 'DET', 'DT'),
  ('stage', 'NOUN', 'NN'),
  ('in', 'ADP', 'IN'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('city', 'PROPN', 'NNP'),
  (',', 'PUNCT', ','),
  ("''", 'PUNCT', "''"),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('&', 'CONJ', 'CC'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  ('-rrb-', 'PUNCT', '-RRB-'),
  (',', 'PUNCT', ','),
  ('the', 'DET', 'DT'),
  ('role', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('UNK', 'UNK_TAG', 'UNK_TAG'),
  (',', 'PUNCT', ',

In [44]:
def create_vocab_index(sentences_with_pos_tags):
    """Assign integer ids to every distinct word, UPOS tag and XPOS tag.

    Ids follow order of first occurrence. Word ids start at 0 and tag ids
    start at 1.

    Returns
    -------
    tuple[dict, dict, dict]
        ``(word_to_index, upos_to_index, xpos_to_index)``.
    """
    def number_in_order(values, first_id=0):
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_values = dict.fromkeys(values)
        return {value: value_id for value_id, value in enumerate(unique_values, start=first_id)}

    words = [word for sentence in sentences_with_pos_tags for word, _, _ in sentence]
    upos_tags = [upos for sentence in sentences_with_pos_tags for _, upos, _ in sentence]
    xpos_tags = [xpos for sentence in sentences_with_pos_tags for _, _, xpos in sentence]

    return (
        number_in_order(words, first_id=0),
        number_in_order(upos_tags, first_id=1),
        number_in_order(xpos_tags, first_id=1),
    )

In [45]:
# Index the UNK-filtered vocabularies (word ids from 0, tag ids from 1).
(word_to_index,
 upos_to_index,
 xpos_to_index) = create_vocab_index(filtered_sentences)
print(len(word_to_index))  # number of distinct words, 'UNK' included
(list(word_to_index.items())[10:15],
 list(upos_to_index.items())[:5],
 list(xpos_to_index.items())[:5])

4110


([('at', 10), ('chicago', 11), ("'s", 12), ('-lrb-', 13), ('take', 14)],
 [('ADP', 1), ('DET', 2), ('PROPN', 3), ('NUM', 4), ('NOUN', 5)],
 [('IN', 1), ('DT', 2), ('NNP', 3), ('CD', 4), ('NN', 5)])

In [46]:
# Sizes of the word, universal-tag and specific-tag vocabularies.
tuple(len(mapping) for mapping in (word_to_index, upos_to_index, xpos_to_index))

(4110, 18, 46)

In [47]:
def convert_to_indexes(filtered_sentences_tags, word_to_index, upos_to_index, xpos_to_index,
                       min_length=6):
    """Map sentences of ``(word, upos, xpos)`` triples to integer index sequences.

    Parameters
    ----------
    filtered_sentences_tags : list[list[tuple[str, str, str]]]
        Sentences to convert.
    word_to_index, upos_to_index, xpos_to_index : dict
        Lookup tables for words, universal tags and language-specific tags.
    min_length : int, optional
        Sentences with fewer tokens are skipped. The default of 6 reproduces
        the previous hard-coded ``len(sentence) <= 5`` filter.

    Returns
    -------
    tuple[list[list[int]], list[list[int]], list[list[int]]]
        ``(hidden_states_universal, hidden_states_specific, observations)``,
        aligned sentence by sentence.

    Raises
    ------
    KeyError
        If a word or tag of a kept sentence is missing from its lookup table.
    """
    hidden_states_universal = []
    hidden_states_specific = []
    observations = []

    for sentence in filtered_sentences_tags:
        if len(sentence) < min_length:
            continue
        hidden_states_universal.append([upos_to_index[upos] for _, upos, _ in sentence])
        hidden_states_specific.append([xpos_to_index[xpos] for _, _, xpos in sentence])
        observations.append([word_to_index[word] for word, _, _ in sentence])

    return hidden_states_universal, hidden_states_specific, observations

In [48]:
# Turn every kept sentence into aligned index sequences
# (universal tags, specific tags, word ids).
hidden_states_universal, hidden_states_specific, observations = convert_to_indexes(
    filtered_sentences,
    word_to_index=word_to_index,
    upos_to_index=upos_to_index,
    xpos_to_index=xpos_to_index,
)

In [49]:
# First five sentences: universal tags, specific tags and word ids, one row each.
for upos_ids, xpos_ids, word_ids in zip(hidden_states_universal[:5],
                                        hidden_states_specific[:5],
                                        observations[:5]):
    for row in (upos_ids, xpos_ids, word_ids):
        print(f"[{', '.join(str(value) for value in row)}]")
    print('-----------------------------')
    print('-----------------------------')

[1, 2, 3, 4, 5, 1, 6, 2, 7, 6, 1, 3, 8, 7, 7, 6, 6, 7, 7, 9, 2, 5, 1, 7, 3, 6, 6, 7, 10, 7, 6, 6, 2, 5, 1, 7, 6, 9, 1, 7, 7, 6, 11, 7, 9, 1, 7, 7, 6]
[1, 2, 3, 4, 5, 1, 6, 2, 7, 8, 1, 3, 9, 7, 7, 10, 6, 7, 7, 11, 2, 5, 1, 7, 3, 12, 8, 7, 13, 7, 14, 12, 2, 5, 1, 7, 12, 15, 1, 7, 7, 12, 16, 7, 15, 17, 7, 7, 18]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 8, 8, 13, 6, 8, 8, 14, 7, 15, 0, 8, 16, 17, 9, 8, 18, 8, 19, 17, 7, 20, 5, 8, 17, 21, 22, 8, 8, 17, 23, 8, 24, 25, 8, 8, 26]
-----------------------------
[7, 3, 3, 3, 9, 12, 9, 12, 3, 5, 8, 9, 13, 1, 1, 7, 5, 1, 4, 6]
[7, 3, 20, 3, 16, 21, 19, 22, 3, 23, 17, 24, 25, 1, 1, 7, 23, 1, 4, 18]
[8, 29, 30, 31, 32, 33, 34, 35, 36, 37, 25, 38, 39, 10, 40, 8, 30, 0, 41, 26]
-----------------------------
[2, 5, 5, 5, 13, 5, 9, 7, 5, 1, 2, 3]
[2, 5, 5, 5, 25, 5, 16, 7, 23, 1, 2, 3]
[7, 42, 43, 44, 45, 46, 47, 8, 30, 0, 7, 36]
-----------------------------
[3, 7, 6, 5, 10, 13, 5, 5, 6, 9, 12, 7, 5, 1, 2, 5, 5, 5, 1, 3, 10, 3, 6, 10, 1, 13, 13, 5, 6]

In [50]:
def add_noise_to_states_ptb(hidden_states, number_states, flip_prob=0.5, rng=None):
    """Return a copy of ``hidden_states`` with each state randomly corrupted.

    Independently at every position, with probability ``flip_prob``, the state
    is replaced by one drawn uniformly from ``1..number_states`` excluding the
    original value; otherwise it is kept. The input lists are not modified.

    Parameters
    ----------
    hidden_states : list[list[int]]
        State sequences; values are expected to lie in ``1..number_states``.
    number_states : int
        Number of valid (1-based) states.
    flip_prob : float, optional
        Per-position corruption probability.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. ``None`` (default) uses the global ``np.random``
        state, exactly as before; pass e.g. ``np.random.default_rng(seed)``
        for reproducible noise.

    Returns
    -------
    list[list[int]]
        Noisy sequences with the same shape as ``hidden_states``.

    Raises
    ------
    ValueError
        When a position selected for flipping holds a value outside
        ``1..number_states``, or when ``number_states < 2`` leaves no
        alternative state to choose.
    """
    rng = np.random if rng is None else rng
    # Same candidate set for every flip; built once instead of per position.
    all_states = range(1, number_states + 1)

    noisy_hidden_states = []
    for sequence in hidden_states:
        noisy_sequence = []
        for state in sequence:
            if rng.random() < flip_prob:
                possible_states = list(all_states)
                possible_states.remove(state)  # never "flip" to the same state
                # int() keeps every element a plain Python int (choice returns np.int64).
                noisy_sequence.append(int(rng.choice(possible_states)))
            else:
                noisy_sequence.append(state)
        noisy_hidden_states.append(noisy_sequence)
    return noisy_hidden_states

In [51]:
ptb_noisy_level = 0.3
# Seed the global NumPy RNG used by add_noise_to_states_ptb so the noisy tags
# (and therefore the saved .npz dataset) are reproducible across runs.
ptb_noise_seed = 42
np.random.seed(ptb_noise_seed)
noisy_hidden_states_universal = add_noise_to_states_ptb(
    hidden_states_universal, len(upos_to_index), flip_prob=ptb_noisy_level)
noisy_hidden_states_specific = add_noise_to_states_ptb(
    hidden_states_specific, len(xpos_to_index), flip_prob=ptb_noisy_level)

In [52]:
# Prepend a start symbol to every sequence: state 0 for the tag sequences
# (tag ids are numbered from 1) and -1 for the word ids (numbered from 0),
# so neither marker collides with a real value.
# The guard makes the cell idempotent: re-running it does not stack a second
# start symbol onto sequences that already have one.
for sequences, start_symbol in (
    (hidden_states_universal, 0),
    (noisy_hidden_states_universal, 0),
    (hidden_states_specific, 0),
    (noisy_hidden_states_specific, 0),
    (observations, -1),
):
    for sequence in sequences:
        if sequence[:1] != [start_symbol]:
            sequence.insert(0, start_symbol)

In [53]:
# Clean vs. noisy universal tag sequences for the first five sentences.
# All lists are sliced the same way so the loop bound and the rows shown agree.
for clean_seq, noisy_seq in zip(hidden_states_universal[:5], noisy_hidden_states_universal[:5]):
    print('[' + ', '.join(map(str, clean_seq)) + ']')
    print('[' + ', '.join(map(str, noisy_seq)) + ']')
    print('-----------------------------')

[0, 1, 2, 3, 4, 5, 1, 6, 2, 7, 6, 1, 3, 8, 7, 7, 6, 6, 7, 7, 9, 2, 5, 1, 7, 3, 6, 6, 7, 10, 7, 6, 6, 2, 5, 1, 7, 6, 9, 1, 7, 7, 6, 11, 7, 9, 1, 7, 7, 6]
[0, 8, 2, 3, 4, 5, 1, 6, 2, 7, 10, 15, 3, 5, 2, 7, 11, 4, 3, 13, 9, 3, 10, 14, 17, 3, 6, 6, 7, 10, 7, 16, 6, 2, 5, 1, 7, 7, 9, 1, 13, 7, 9, 11, 7, 9, 1, 14, 7, 6]
-----------------------------
[0, 7, 3, 3, 3, 9, 12, 9, 12, 3, 5, 8, 9, 13, 1, 1, 7, 5, 1, 4, 6]
[0, 7, 3, 18, 3, 9, 15, 9, 12, 3, 14, 8, 9, 17, 16, 3, 7, 5, 1, 7, 6]
-----------------------------
[0, 2, 5, 5, 5, 13, 5, 9, 7, 5, 1, 2, 3]
[0, 2, 5, 3, 2, 13, 11, 13, 7, 8, 1, 2, 3]
-----------------------------
[0, 3, 7, 6, 5, 10, 13, 5, 5, 6, 9, 12, 7, 5, 1, 2, 5, 5, 5, 1, 3, 10, 3, 6, 10, 1, 13, 13, 5, 6]
[0, 3, 7, 6, 5, 12, 13, 5, 5, 6, 9, 12, 7, 1, 1, 2, 5, 5, 5, 1, 6, 10, 3, 11, 10, 1, 13, 13, 5, 6]
-----------------------------
[0, 3, 3, 3, 9, 12, 5, 1, 4, 5, 1, 4, 5, 2, 5, 6]
[0, 3, 3, 2, 7, 12, 5, 1, 4, 5, 1, 3, 10, 10, 5, 1]
-----------------------------


In [54]:
# Clean vs. noisy language-specific tag sequences for the first five sentences.
for clean_tags, noisy_tags in zip(hidden_states_specific[:5], noisy_hidden_states_specific[:5]):
    for row in (clean_tags, noisy_tags):
        print(f"[{', '.join(str(tag) for tag in row)}]")
    print('-----------------------------')

[0, 1, 2, 3, 4, 5, 1, 6, 2, 7, 8, 1, 3, 9, 7, 7, 10, 6, 7, 7, 11, 2, 5, 1, 7, 3, 12, 8, 7, 13, 7, 14, 12, 2, 5, 1, 7, 12, 15, 1, 7, 7, 12, 16, 7, 15, 17, 7, 7, 18]
[0, 1, 20, 3, 20, 5, 1, 6, 2, 7, 8, 33, 3, 9, 7, 7, 10, 6, 7, 7, 17, 2, 5, 26, 7, 3, 28, 44, 7, 14, 7, 41, 12, 2, 5, 1, 7, 16, 16, 8, 7, 7, 12, 16, 22, 15, 22, 7, 7, 18]
-----------------------------
[0, 7, 3, 20, 3, 16, 21, 19, 22, 3, 23, 17, 24, 25, 1, 1, 7, 23, 1, 4, 18]
[0, 7, 3, 20, 3, 42, 21, 19, 36, 3, 23, 17, 24, 27, 1, 46, 4, 39, 1, 13, 18]
-----------------------------
[0, 2, 5, 5, 5, 25, 5, 16, 7, 23, 1, 2, 3]
[0, 2, 5, 5, 36, 25, 30, 16, 7, 23, 1, 2, 24]
-----------------------------
[0, 3, 7, 12, 5, 13, 25, 5, 5, 12, 16, 21, 7, 5, 1, 2, 5, 5, 5, 1, 3, 13, 3, 12, 13, 1, 25, 25, 23, 18]
[0, 3, 7, 12, 15, 13, 25, 5, 5, 12, 16, 21, 7, 5, 1, 19, 5, 44, 43, 1, 3, 27, 3, 12, 35, 1, 22, 39, 21, 18]
-----------------------------
[0, 3, 3, 3, 16, 22, 5, 17, 4, 23, 1, 4, 23, 2, 5, 18]
[0, 3, 3, 19, 16, 11, 5, 10, 4, 21, 6,

In [55]:
# Word-id sequences of the first five sentences.
for word_ids in observations[:5]:
    print('[' + ', '.join([str(word_id) for word_id in word_ids]) + ']')
    print('-----------------------------')

[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 8, 8, 13, 6, 8, 8, 14, 7, 15, 0, 8, 16, 17, 9, 8, 18, 8, 19, 17, 7, 20, 5, 8, 17, 21, 22, 8, 8, 17, 23, 8, 24, 25, 8, 8, 26]
-----------------------------
[-1, 8, 29, 30, 31, 32, 33, 34, 35, 36, 37, 25, 38, 39, 10, 40, 8, 30, 0, 41, 26]
-----------------------------
[-1, 7, 42, 43, 44, 45, 46, 47, 8, 30, 0, 7, 36]
-----------------------------
[-1, 48, 8, 17, 49, 50, 51, 52, 53, 17, 32, 54, 8, 55, 56, 7, 42, 43, 44, 0, 57, 50, 58, 17, 50, 0, 59, 60, 61, 26]
-----------------------------
[-1, 62, 63, 31, 64, 35, 65, 25, 66, 67, 68, 69, 67, 70, 71, 26]
-----------------------------


In [56]:
def _as_object_array(sequences):
    """Pack a list of sequences into a 1-D object array, one element per sequence.

    ``np.array(seqs, dtype=object)`` silently builds a 2-D array when every
    sequence has the same length; filling a pre-sized object array avoids that.
    """
    packed = np.empty(len(sequences), dtype=object)
    for position, sequence in enumerate(sequences):
        packed[position] = sequence
    return packed


file_path = f"../data/PennTreebank_synthetic_dataset(noise-{ptb_noisy_level}).npz"
obs_object = _as_object_array(observations)
uni_hid_object = _as_object_array(hidden_states_universal)
noisy_uni_hid_object = _as_object_array(noisy_hidden_states_universal)
spc_hid_object = _as_object_array(hidden_states_specific)
noisy_spc_hid_object = _as_object_array(noisy_hidden_states_specific)
# Object arrays are pickled, so reading them back needs np.load(..., allow_pickle=True).
np.savez(
    file_path,
    num_states=len(upos_to_index) + 1,           # universal tags + start state 0
    num_states_specific=len(xpos_to_index) + 1,  # specific tags + start state 0
    num_obs=len(word_to_index),
    observation=obs_object,
    real_hidden_universal=uni_hid_object,
    noisy_hidden_universal=noisy_uni_hid_object,
    real_hidden_specific=spc_hid_object,
    # Misspelled key kept in case other code loads it by this name;
    # 'noisy_hidden_specific' is the correctly spelled alias of the same data.
    noisy_hidden_specifc=noisy_spc_hid_object,
    noisy_hidden_specific=noisy_spc_hid_object,
    noisy_level=ptb_noisy_level,
)

In [57]:
# Re-open the file just written, reusing `file_path` so the noise level in the
# name cannot drift from ptb_noisy_level. The context manager closes the NpzFile.
with np.load(file_path) as read_npz:
    saved_sizes = read_npz['num_obs'], read_npz['num_states']
saved_sizes

(array(4110), array(19))

In [31]:
# Index assigned to the placeholder tag that replaced the tags of rare words.
upos_to_index['UNK_TAG']

7