### Create Synthetic Dateset on Penn Treebank

In [1]:
import numpy as np
from collections import Counter
from collections import defaultdict

In [2]:
def read_sentences_with_pos_tags(path):
    sentences_with_pos_tags = []

    with open(path, 'r', encoding='utf-8') as file:
        current_sentence = []
        for line in file:
            if line.strip() and not line.startswith('#'):   # Skip empty lines and comments
                fields = line.split('\t')
                if len(fields) > 3:                         # Ensure there are enough fields
                    word = fields[1].lower()                # Word form is the second field
                    upos = fields[3]                        # Universal POS tag is the fourth field
                    xpos = fields[4]                        # Language specific POS tag is the fifth field
                    if upos not in ['NUM', 'PUNCT', 'INTJ', 'X']:  # Remove punctuation, interjection, and undefined pos tags
                        current_sentence.append((word, upos, xpos))

            # New sentence
            elif current_sentence:
                sentences_with_pos_tags.append(current_sentence)
                current_sentence = []

    return sentences_with_pos_tags

In [7]:
file_path1 = '../../../data/ptb/penn-train.conllu'
file_path2 = '../data/ptb/penn-test.conllu'
file_path3 = '../data/ptb/penn-dev.conllu'

In [8]:
sentences_pos_tags = read_sentences_with_pos_tags(file_path1)
print(len(sentences_pos_tags))
sentences_pos_tags[:3]

39821


[[('in', 'ADP', 'IN'),
  ('an', 'DET', 'DT'),
  ('oct.', 'PROPN', 'NNP'),
  ('review', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('the', 'DET', 'DT'),
  ('misanthrope', 'NOUN', 'NN'),
  ('at', 'ADP', 'IN'),
  ('chicago', 'PROPN', 'NNP'),
  ("'s", 'PART', 'POS'),
  ('goodman', 'PROPN', 'NNP'),
  ('theatre', 'PROPN', 'NNP'),
  ('revitalized', 'VERB', 'VBN'),
  ('classics', 'NOUN', 'NNS'),
  ('take', 'VERB', 'VBP'),
  ('the', 'DET', 'DT'),
  ('stage', 'NOUN', 'NN'),
  ('in', 'ADP', 'IN'),
  ('windy', 'PROPN', 'NNP'),
  ('city', 'PROPN', 'NNP'),
  ('leisure', 'NOUN', 'NN'),
  ('&', 'CONJ', 'CC'),
  ('arts', 'NOUN', 'NNS'),
  ('the', 'DET', 'DT'),
  ('role', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('celimene', 'PROPN', 'NNP'),
  ('played', 'VERB', 'VBN'),
  ('by', 'ADP', 'IN'),
  ('kim', 'PROPN', 'NNP'),
  ('cattrall', 'PROPN', 'NNP'),
  ('was', 'AUX', 'VBD'),
  ('mistakenly', 'ADV', 'RB'),
  ('attributed', 'VERB', 'VBN'),
  ('to', 'ADP', 'TO'),
  ('christina', 'PROPN', 'NNP'),
  ('haag', 'PR

In [10]:
def replace_low_frequency_words(sentences_with_pos_tags, filter_count=1):
    # Count the frequencies of each word
    word_counts = Counter(word for sentence in sentences_with_pos_tags for word, _, _ in sentence)

    # Replace words with count less than filter_count to 'UNK' and their tags to 'UNK_TAG'
    processed_sentences = []
    for sentence in sentences_with_pos_tags:
        new_sentence = []
        for word, upos, xpos in sentence:
            if word_counts[word] < filter_count:
                new_word = 'UNK'                # Only set the word to UNK
                new_upos = upos
                new_xpos = xpos
            else:
                new_word = word
                new_upos = upos
                new_xpos = xpos
            new_sentence.append((new_word, new_upos, new_xpos))
        processed_sentences.append(new_sentence)

    return processed_sentences

In [11]:
filtered_sentences = replace_low_frequency_words(sentences_pos_tags, filter_count=10)
filtered_sentences[:3]

[[('in', 'ADP', 'IN'),
  ('an', 'DET', 'DT'),
  ('oct.', 'PROPN', 'NNP'),
  ('review', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('the', 'DET', 'DT'),
  ('UNK', 'NOUN', 'NN'),
  ('at', 'ADP', 'IN'),
  ('chicago', 'PROPN', 'NNP'),
  ("'s", 'PART', 'POS'),
  ('UNK', 'PROPN', 'NNP'),
  ('UNK', 'PROPN', 'NNP'),
  ('UNK', 'VERB', 'VBN'),
  ('UNK', 'NOUN', 'NNS'),
  ('take', 'VERB', 'VBP'),
  ('the', 'DET', 'DT'),
  ('stage', 'NOUN', 'NN'),
  ('in', 'ADP', 'IN'),
  ('UNK', 'PROPN', 'NNP'),
  ('city', 'PROPN', 'NNP'),
  ('UNK', 'NOUN', 'NN'),
  ('&', 'CONJ', 'CC'),
  ('arts', 'NOUN', 'NNS'),
  ('the', 'DET', 'DT'),
  ('role', 'NOUN', 'NN'),
  ('of', 'ADP', 'IN'),
  ('UNK', 'PROPN', 'NNP'),
  ('played', 'VERB', 'VBN'),
  ('by', 'ADP', 'IN'),
  ('UNK', 'PROPN', 'NNP'),
  ('UNK', 'PROPN', 'NNP'),
  ('was', 'AUX', 'VBD'),
  ('UNK', 'ADV', 'RB'),
  ('attributed', 'VERB', 'VBN'),
  ('to', 'ADP', 'TO'),
  ('UNK', 'PROPN', 'NNP'),
  ('UNK', 'PROPN', 'NNP')],
 [('ms.', 'PROPN', 'NNP'),
  ('UNK', 'PROPN'

In [12]:
def create_vocab_index(sentences_with_pos_tags):
    # Function to create a dictionary mapping each unique word/POS to an integer index
    # with specified start index
    def build_index(items, start_index=0):
        item_to_index = defaultdict(lambda: len(item_to_index) + start_index)
        for item in items:
            item_to_index[item]
        return dict(item_to_index)

    all_words = [word for sentence in sentences_with_pos_tags for word, upos, xpos in sentence]
    all_upos_tags = [upos for sentence in sentences_with_pos_tags for word, upos, xpos in sentence]
    all_xpos_tags = [xpos for sentence in sentences_with_pos_tags for word, upos, xpos in sentence]

    return build_index(all_words, start_index=0), build_index(all_upos_tags, start_index=1),  build_index(all_xpos_tags, start_index=1)

In [13]:
word_to_index, upos_to_index, xpos_to_index = create_vocab_index(filtered_sentences)
print(len(word_to_index))
list(word_to_index.items())[10:15], list(upos_to_index.items())

6466


([('take', 10), ('stage', 11), ('city', 12), ('&', 13), ('arts', 14)],
 [('ADP', 1),
  ('DET', 2),
  ('PROPN', 3),
  ('NOUN', 4),
  ('PART', 5),
  ('VERB', 6),
  ('CONJ', 7),
  ('AUX', 8),
  ('ADV', 9),
  ('PRON', 10),
  ('ADJ', 11),
  ('SCONJ', 12),
  ('SYM', 13)])

In [14]:
def count_tags(tag, sentences_with_pos_tags):
    count = 0
    for sentence in sentences_with_pos_tags: 
        for _, upos, _ in sentence:
            if upos == tag: 
                # print(sentence)
                count += 1
                
    return count

In [15]:
len(word_to_index), len(upos_to_index), len(xpos_to_index)

(6466, 13, 34)

In [16]:
for tag, _ in upos_to_index.items():
    print(count_tags(tag, filtered_sentences))

97030
80893
94139
187112
25869
106967
23947
28196
31211
35590
66402
13503
12504


In [17]:
def convert_to_indexes(filtered_sentences_tags, word_to_index, upos_to_index, xpos_to_index):
    hidden_states_universal = []
    hidden_states_specific = []
    observations = []

    for sentence in filtered_sentences_tags:
        if len(sentence) <= 5: 
            continue
        sentence_upos_indexes = [upos_to_index[upos] for _, upos, _ in sentence]
        sentence_xpos_indexes = [xpos_to_index[xpos] for _, _, xpos in sentence]
        sentence_word_indexes = [word_to_index[word] for word, _, _ in sentence]

        hidden_states_universal.append(sentence_upos_indexes)
        hidden_states_specific.append(sentence_xpos_indexes)
        observations.append(sentence_word_indexes)

    return hidden_states_universal, hidden_states_specific, observations

In [18]:
hidden_states_universal, hidden_states_specific, observations = convert_to_indexes(
    filtered_sentences, word_to_index, upos_to_index, xpos_to_index)

In [19]:
for index in range(len(observations[:5])):
    print('[' + ', '.join(map(str, hidden_states_universal[:5][index])) + ']')
    print('[' + ', '.join(map(str, hidden_states_specific[:5][index])) + ']')
    print('[' + ', '.join(map(str, observations[:5][index])) + ']')
    print('-----------------------------')

[1, 2, 3, 4, 1, 2, 4, 1, 3, 5, 3, 3, 6, 4, 6, 2, 4, 1, 3, 3, 4, 7, 4, 2, 4, 1, 3, 6, 1, 3, 3, 8, 9, 6, 1, 3, 3]
[1, 2, 3, 4, 1, 2, 4, 1, 3, 5, 3, 3, 6, 7, 8, 2, 4, 1, 3, 3, 4, 9, 7, 2, 4, 1, 3, 6, 1, 3, 3, 10, 11, 6, 12, 3, 3]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 6, 6, 6, 6, 10, 5, 11, 0, 6, 12, 6, 13, 14, 5, 15, 4, 6, 16, 17, 6, 6, 18, 6, 19, 20, 6, 6]
-----------------------------
[3, 3, 3, 3, 6, 10, 6, 10, 3, 4, 5, 6, 11, 1, 1, 4, 1]
[3, 3, 14, 3, 10, 15, 13, 16, 3, 7, 12, 17, 18, 1, 1, 7, 1]
[6, 23, 24, 25, 26, 27, 28, 29, 30, 31, 20, 32, 33, 7, 34, 24, 0]
-----------------------------
[2, 4, 4, 4, 11, 4, 6, 4, 1, 2, 3]
[2, 4, 4, 4, 18, 4, 10, 7, 1, 2, 3]
[5, 35, 36, 37, 38, 39, 40, 24, 0, 5, 30]
-----------------------------
[3, 3, 4, 7, 11, 4, 4, 6, 10, 6, 4, 1, 2, 4, 4, 4, 1, 3, 7, 3, 7, 1, 11, 11, 4]
[3, 3, 4, 9, 18, 4, 4, 10, 15, 13, 4, 1, 2, 4, 4, 4, 1, 3, 9, 3, 9, 1, 18, 18, 7]
[41, 6, 42, 43, 44, 45, 46, 26, 47, 48, 49, 50, 5, 35, 36, 37, 0, 51, 43, 52, 43, 0, 53, 54, 55]
-------

In [20]:
def add_noise_to_states_ptb(hidden_states, number_states, flip_prob=0.5):
    noisy_hidden_states = []
    for sequence in hidden_states:
        noisy_sequence = []
        for state in sequence:
            if np.random.rand() < flip_prob:
                # Flip the state to a different random state
                possible_states = list(range(1, number_states + 1))
                possible_states.remove(state)  # Remove the current state from possibilities
                new_state = np.random.choice(possible_states)
                noisy_sequence.append(new_state)
            else:
                noisy_sequence.append(state)
        noisy_hidden_states.append(noisy_sequence)
    return noisy_hidden_states

In [21]:
ptb_noisy_level = 0.4
noisy_hidden_states_universal = add_noise_to_states_ptb(hidden_states_universal, len(upos_to_index), flip_prob=ptb_noisy_level)
noisy_hidden_states_specific = add_noise_to_states_ptb(hidden_states_specific, len(xpos_to_index), flip_prob=ptb_noisy_level)

In [22]:
for i in range(len(hidden_states_universal)):
    hidden_states_universal[i].insert(0, 0)
    noisy_hidden_states_universal[i].insert(0, 0)

    hidden_states_specific[i].insert(0, 0)
    noisy_hidden_states_specific[i].insert(0, 0)
    
    observations[i].insert(0, -1)

In [23]:
for index in range(len(observations[20:25])):
    print('[' + ', '.join(map(str, hidden_states_universal[:5][index])) + ']')
    print('[' + ', '.join(map(str, noisy_hidden_states_universal[:5][index])) + ']')
    print('-----------------------------')

[0, 1, 2, 3, 4, 1, 2, 4, 1, 3, 5, 3, 3, 6, 4, 6, 2, 4, 1, 3, 3, 4, 7, 4, 2, 4, 1, 3, 6, 1, 3, 3, 8, 9, 6, 1, 3, 3]
[0, 2, 13, 6, 4, 1, 2, 4, 1, 3, 5, 3, 3, 13, 4, 3, 2, 4, 11, 10, 3, 12, 7, 4, 2, 4, 1, 3, 6, 1, 3, 1, 8, 9, 6, 1, 1, 3]
-----------------------------
[0, 3, 3, 3, 3, 6, 10, 6, 10, 3, 4, 5, 6, 11, 1, 1, 4, 1]
[0, 3, 3, 2, 9, 6, 8, 6, 6, 3, 4, 7, 6, 2, 13, 1, 4, 1]
-----------------------------
[0, 2, 4, 4, 4, 11, 4, 6, 4, 1, 2, 3]
[0, 10, 10, 4, 8, 11, 2, 11, 4, 1, 12, 3]
-----------------------------
[0, 3, 3, 4, 7, 11, 4, 4, 6, 10, 6, 4, 1, 2, 4, 4, 4, 1, 3, 7, 3, 7, 1, 11, 11, 4]
[0, 3, 3, 4, 2, 11, 7, 3, 3, 10, 6, 2, 5, 2, 10, 12, 4, 1, 3, 7, 3, 7, 1, 11, 11, 6]
-----------------------------
[0, 3, 3, 3, 6, 10, 4, 1, 4, 1, 4, 2, 4]
[0, 3, 3, 3, 6, 10, 4, 9, 12, 1, 4, 2, 7]
-----------------------------


In [24]:
for index in range(len(observations[:5])):
    print('[' + ', '.join(map(str, hidden_states_specific[:5][index])) + ']')
    print('[' + ', '.join(map(str, noisy_hidden_states_specific[:5][index])) + ']')
    print('-----------------------------')

[0, 1, 2, 3, 4, 1, 2, 4, 1, 3, 5, 3, 3, 6, 7, 8, 2, 4, 1, 3, 3, 4, 9, 7, 2, 4, 1, 3, 6, 1, 3, 3, 10, 11, 6, 12, 3, 3]
[0, 1, 17, 16, 4, 1, 2, 9, 28, 3, 5, 3, 8, 6, 7, 8, 2, 4, 1, 3, 3, 4, 9, 7, 2, 16, 34, 3, 6, 1, 10, 3, 10, 11, 32, 12, 19, 10]
-----------------------------
[0, 3, 3, 14, 3, 10, 15, 13, 16, 3, 7, 12, 17, 18, 1, 1, 7, 1]
[0, 3, 3, 14, 5, 33, 15, 31, 16, 31, 2, 12, 17, 31, 1, 1, 7, 7]
-----------------------------
[0, 2, 4, 4, 4, 18, 4, 10, 7, 1, 2, 3]
[0, 23, 4, 4, 33, 18, 4, 6, 7, 1, 2, 10]
-----------------------------
[0, 3, 3, 4, 9, 18, 4, 4, 10, 15, 13, 4, 1, 2, 4, 4, 4, 1, 3, 9, 3, 9, 1, 18, 18, 7]
[0, 3, 3, 25, 9, 18, 4, 4, 10, 15, 13, 21, 1, 17, 6, 4, 4, 1, 3, 9, 3, 9, 1, 18, 34, 11]
-----------------------------
[0, 3, 3, 3, 10, 16, 4, 12, 7, 1, 7, 2, 4]
[0, 30, 3, 34, 10, 21, 4, 8, 7, 1, 16, 2, 4]
-----------------------------


In [25]:
for index in range(len(observations[:5])):
    print('[' + ', '.join(map(str, observations[:5][index])) + ']')
    print('-----------------------------')

[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 6, 6, 6, 6, 10, 5, 11, 0, 6, 12, 6, 13, 14, 5, 15, 4, 6, 16, 17, 6, 6, 18, 6, 19, 20, 6, 6]
-----------------------------
[-1, 6, 23, 24, 25, 26, 27, 28, 29, 30, 31, 20, 32, 33, 7, 34, 24, 0]
-----------------------------
[-1, 5, 35, 36, 37, 38, 39, 40, 24, 0, 5, 30]
-----------------------------
[-1, 41, 6, 42, 43, 44, 45, 46, 26, 47, 48, 49, 50, 5, 35, 36, 37, 0, 51, 43, 52, 43, 0, 53, 54, 55]
-----------------------------
[-1, 56, 57, 25, 58, 29, 59, 20, 60, 61, 60, 62, 63]
-----------------------------


In [27]:
file_path = f"../../../data/PTB_synthetic_dataset(noise-{ptb_noisy_level}).npz"
obs_object = np.array(observations, dtype=object)
uni_hid_object = np.array(hidden_states_universal, dtype=object)
noisy_uni_hid_object = np.array(noisy_hidden_states_universal, dtype=object)
spc_hid_object = np.array(hidden_states_specific, dtype=object)
noisy_spc_hid_object = np.array(noisy_hidden_states_specific, dtype=object)
np.savez(file_path, num_states=len(upos_to_index) + 1, num_obs=len(word_to_index), observation=obs_object, real_hidden_universal=uni_hid_object, noisy_hidden_universal=noisy_uni_hid_object, real_hidden_specific=spc_hid_object, noisy_hidden_specifc=noisy_spc_hid_object, noisy_level=ptb_noisy_level)

In [129]:
read_npz = np.load("../data/PennTreebank_synthetic_dataset(noise-0.8).npz")
read_npz['num_obs'], read_npz['num_states']

(array(6466), array(14))

In [31]:
upos_to_index['UNK_TAG']

7

In [131]:
read_result = np.load("../data/ptb-noise-0.7_iter-20_timestamp-0118_222152_result.npz")
read_result['result']

array([55696.93966458, 48077.46788257, 40796.52998724, 40882.63262071,
       48252.11934206, 56556.51708689, 61556.06152119, 63970.7717321 ,
       65471.6533929 , 66335.87011565, 66808.8515842 , 67165.15026411,
       67429.15459948, 67580.05554896, 67665.67502065, 67821.98475421,
       67988.70769473, 68163.15558276, 68297.49185731, 68308.82919506])

In [3]:
np.sum([1, 2, 3] != [1, 2, 4])

1