In [1]:
import nltk
import numpy as np
# Download the Brown corpus and the universal tagset
nltk.download("brown")
nltk.download('universal_tagset')

[nltk_data] Downloading package brown to /home/labuser/nltk_data...
[nltk_data]   Package brown is already up-to-date!
[nltk_data] Downloading package universal_tagset to
[nltk_data]     /home/labuser/nltk_data...
[nltk_data]   Package universal_tagset is already up-to-date!


True

In [2]:
from nltk.corpus import brown

# Get the tagged sentences from the Brown corpus
tagged_sentences = brown.tagged_sents(tagset="universal")
tagged_sentences = list(tagged_sentences)


In [3]:
folds = []

num_sentences = len(tagged_sentences)

for i in range(5):
    folds.append(tagged_sentences[i * int(num_sentences / 5) : (i + 1) * int(num_sentences / 5)])

In [10]:
"""---------------------GLOBAL VARIABLES------------------------"""
states = set()
words = set()
for sentence in tagged_sentences:
    for word in sentence:
        states.add(word[1])
        words.add(word[0])


state_index_mapping = {}
word_index_mapping = {}


state_index_mapping["^"] = 0
word_index_mapping["~epsilon~"] = 0

i = 1
for state in states:
    state_index_mapping[state] = i
    i += 1
state_index_mapping["~END~"] = i
    
i = 1
for word in words:
    word_index_mapping[word] = i
    i += 1
    
n_states = len(states) + 2
n_words = len(words) + 1
    
print(state_index_mapping)
"""---------------------FUNCTIONS------------------------"""


def make_set(folds, fold_num):
    train_sentences = []
    test_sentences = []
    for i in range(5):
        if ( i != fold_num ):
            train_sentences.extend(folds[i])
        else:
            test_sentences.extend(folds[i])
    
    return train_sentences, test_sentences

def get_matrices(train_sentences):
    
    transition_matrix = np.ones((n_states, n_states))
    emission_matrix = np.ones((n_states, n_words))

    for sentence in train_sentences:
        transition_matrix[0][state_index_mapping[sentence[0][1]]] += 1 # for the start
        emission_matrix[0][0] += 1
        for i in range(len(sentence) - 1):
            cur_word = sentence[i]
            next_word = sentence[i + 1]
            row = state_index_mapping[cur_word[1]]
            col = state_index_mapping[next_word[1]]
            emission_matrix[row][word_index_mapping[cur_word[0]]] += 1

            transition_matrix[row][col] += 1
        emission_matrix[col][word_index_mapping[next_word[0]]] += 1
        end_state = state_index_mapping["~END~"]
        epsilon = word_index_mapping["~epsilon~"]
        transition_matrix[col][end_state] += 1
        transition_matrix[end_state][epsilon] += 1
    

    transition_matrix = transition_matrix / np.sum(transition_matrix, axis = 1)
    emission_matrix = emission_matrix / np.sum(emission_matrix, axis = 1, keepdims=True)

    return transition_matrix, emission_matrix



{'^': 0, 'CONJ': 1, 'PRON': 2, 'NUM': 3, 'ADV': 4, '.': 5, 'NOUN': 6, 'VERB': 7, 'DET': 8, 'ADJ': 9, 'ADP': 10, 'X': 11, 'PRT': 12, '~END~': 13}


In [33]:
def Viterbi(output_sequence: list, transition_matrix: np.ndarray, emission_matrix: np.ndarray):
    N = transition_matrix.shape[0]
    T = len(output_sequence)
    SEQSCORE = np.zeros((N, T))
    BACKPTR = np.zeros((N, T))
    C = np.zeros(T)
    SEQSCORE[0, 0] = 1

    for t in range(1, T):
        for i in range(N):
            max_j = 0
            max_val = 0
            for j in range(N):
                temp = SEQSCORE[j, t - 1] * transition_matrix[j][i] * emission_matrix[j][word_index_mapping[output_sequence[t][0]]]
                if max_val < temp: 
                    max_j = j
                    max_val = temp
            BACKPTR[i, t] = max_j

    C[T-1] = np.argmax(SEQSCORE[:, T - 1])
    for i in range(T-2, -1, -1):
        BACKPTR[int(C[i + 1]), i + 1]
    return C


In [35]:
for fold_num in range(5):
    train_sent, test_sent = make_set(folds, fold_num)
    transition_matrix, emission_matrix = get_matrices(train_sent)
    print(transition_matrix[4][4])
    # print(Viterbi(test_sent[0], transition_matrix, emission_matrix))
    # break

0.09657258954993155
0.09576248191027496
0.09556491800308055
0.09694394293011581
0.09920790355572964
