## Installing libraries

In [1]:
import pandas as pd
import ast
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score

## Global variables

In [2]:
tag_set = set(['EX', 'RP', 'RBR', '-NONE-', 'UH', 'VBZ', 'VBG', '$', 'MD', 'TO', 'WP', 'WP$', '.', 'PRP', 'PDT', '#', 'POS', 'VBN', '-RRB-', 'DT', "''", ':', 'RBS', 'JJR', 'IN',
',', 'VBD', 'LS', 'JJS', 'WRB', 'VBP', '-LRB-', 'NNP', 'NNS', 'PRP$', 'JJ', 'CC', 'FW', 'CD', 'VB', 'NN', 'NNPS', 'SYM', 'WDT', '``', 'RB'])

P_given = {}
P_given_freq = {} # (Tag, Tag)

tag_freq = {} # Tag

word_freq = {} # Word
word_tag_freq = {} # (Word, Tag)
word_tag_prob = {}

vocab = set() # Words


## Util functions

In [3]:
def process_sentence(sent, tags):
  for word in sent:
    vocab.add(word)
    word_freq[word] = word_freq.get(word, 0) + 1

  for idx, (word, tag) in enumerate(zip(sent, tags)):
    tag_freq[tag] = tag_freq.get(tag, 0) + 1
    word_tag_freq[(word,tag)] = word_tag_freq.get((word,tag), 0) + 1
    if (idx > 0):
      prev_tag = tags[idx - 1]
      P_given_freq[(tag, prev_tag)] = P_given_freq.get( (tag, prev_tag), 0) + 1

def get_word_tag_prob(word, tag):
	if word not in vocab:
		for tag in tag_set:
			word_tag_prob[word, tag] = 1.0 / len(tag_set)
		vocab.add(word)

	return word_tag_prob.get((word, tag), 0)

## Function to train the model

In [4]:
def train(data):
  for sent, tags in zip(data['tokenized_sentences'], data['tags']):
    process_sentence(sent, tags)

  for prev_tag in tag_set:
    for cur_tag in tag_set:
      if prev_tag not in tag_freq:
        P_given[cur_tag, prev_tag] = 1.0 / len(tag_set)
      else:
        try:
          P_given[cur_tag, prev_tag] = P_given_freq[cur_tag, prev_tag] / tag_freq[prev_tag]
        except KeyError:
          P_given[cur_tag, prev_tag] = 0

  for word in vocab:
    for tag in tag_set:
      try:
        word_tag_prob[word, tag] = word_tag_freq[word, tag] / word_freq[word]
      except KeyError:
        word_tag_prob[word, tag] = 0

## Function to predict on given sentence

In [5]:
def predict(sent):
	prev_state = {}
	P = {}
	for idx, word in enumerate(sent):
		P_new = {}
		if idx == 0:
			for cur_tag in tag_set:
				P_new[cur_tag] = get_word_tag_prob(word, cur_tag)
		else:
			for prev_tag in tag_set:
				for cur_tag in tag_set:
					# How may cur_tag occur after prev_tag?
					prob = P[prev_tag]
					prob *= P_given.get( (cur_tag, prev_tag), 0)  # get_P_given(cur_tag, prev_tag)
					prob *= get_word_tag_prob(word, cur_tag)

					if cur_tag not in P_new or prob > P_new[cur_tag]:
						P_new[cur_tag] = prob
						prev_state[cur_tag, idx] = prev_tag
		P = P_new
	
	final_tag = None
	for tag in tag_set:
		if final_tag is None or P[tag] > P[final_tag]:
			final_tag = tag
	pred = []
	pred.append(final_tag)
	cur_tag = final_tag
	for idx in range(len(sent) - 1, 0, -1):
		cur_tag = prev_state[cur_tag, idx]
		pred.append(cur_tag)
	
	pred.reverse()
	return pred

## Functions to calculate accuracy and class-wise accuracy

In [6]:
def calculate_accuracy(predicted_tag_list, tag_list):
  total_matches = 0
  total_word_count = 0
  for idx, pred_tag in enumerate(predicted_tag_list):
      total_matches = total_matches + accuracy_score(tag_list[idx], pred_tag, normalize=False)
      total_word_count = total_word_count + len(pred_tag)
  accuracy = total_matches / total_word_count
  return accuracy
      
def calculate_class_accuracy(predicted_tag_list, tag_list):
  tag_count = {}
  tag_match_count = {}
  tag_wise_freq = {}
  for idx, pred_tag_l in enumerate(predicted_tag_list):
    for jdx, pred_tag in enumerate(pred_tag_l):
      tag_count[tag_list[idx][jdx]] = tag_count.get( tag_list[idx][jdx], 0) + 1
      if pred_tag == tag_list[idx][jdx]:
        tag_match_count[pred_tag] = tag_match_count.get(pred_tag, 0 ) + 1
  
  for tag in tag_match_count:
    tag_wise_freq[tag] = tag_match_count[tag] / tag_count[tag]
  
  return tag_wise_freq

## Processing Corpus

In [7]:

data = pd.read_csv("WSJ_treebank_corpus.csv", converters={0:ast.literal_eval, 1:ast.literal_eval})
max_accuracy = 0
class_accuracy = {}

# prepare cross validation
kfold = KFold(3, True, 1)

# enumerate splits
for trn, test in kfold.split(data):
  
  # clear previous runs
  word_freq.clear()
  tag_freq.clear()
  word_tag_freq.clear()
  vocab.clear()
  P_given_freq.clear()
  word_tag_prob.clear()
  P_given.clear()

  pred = [] 
  tag_list = [] 

  # train model for current fold
  train(data.iloc[trn])
  
  # validation for current fold
  for idx, row in data.iloc[test].iterrows():
    pred.append(predict(row['tokenized_sentences']))
    tag_list.append(row['tags'])
  
  # calculate accuracy of current fold
  accuracy = calculate_accuracy(pred, tag_list)

  print("Train Sample Size \t: ", len(trn))
  print("Test Sample Size \t: ", len(test))
  print("Accuracy \t: ", accuracy)
  print("\n\n")

  if accuracy > max_accuracy:
    max_accuracy = accuracy    
    class_accuracy = calculate_class_accuracy(pred, tag_list)
  
print("Class-wise Accuracy (Fold Accuracy : ", max_accuracy, ")")

for key, val in class_accuracy.items():
  print(key, "\t", val)

Train Sample Size 	:  2609
Test Sample Size 	:  1305
Accuracy 	:  0.8880385658330823



Train Sample Size 	:  2609
Test Sample Size 	:  1305
Accuracy 	:  0.8872023809523809



Train Sample Size 	:  2610
Test Sample Size 	:  1304
Accuracy 	:  0.8803045505518503



Class-wise Accuracy (Fold Accuracy :  0.8880385658330823 )
DT 	 0.9575489110372831
NN 	 0.9064202795546079
IN 	 0.9572257867399939
RB 	 0.8211640211640212
VBN 	 0.7371967654986523
-NONE- 	 0.9443684450524396
TO 	 0.9601139601139601
VB 	 0.7972027972027972
NNP 	 0.9144098963557339
NNS 	 0.7994157740993184
VBZ 	 0.8432432432432433
JJ 	 0.7885010266940452
PRP 	 0.9583333333333334
CD 	 0.8269742679680568
, 	 0.9725
VBD 	 0.8669396110542477
. 	 1.0
VBG 	 0.551440329218107
JJS 	 0.8709677419354839
POS 	 0.9372693726937269
`` 	 0.9555555555555556
CC 	 0.9680998613037448
WP 	 0.9102564102564102
'' 	 0.9515418502202643
JJR 	 0.8088235294117647
VBP 	 0.7221006564551422
MD 	 0.9003115264797508
WDT 	 0.4423076923076923
PRP$ 	 0.9523809523