In [10]:
import nltk
from nltk.metrics import jaccard_distance
from scipy.stats import pearsonr
from nltk.wsd import lesk
from nltk import pos_tag
from nltk.corpus import wordnet as wn
from nltk.stem import WordNetLemmatizer
#nltk.download('punkt')
#nltk.download('wordnet')
#nltk.download('universal_tagset')

[nltk_data] Downloading package universal_tagset to
[nltk_data]     /home2/users/alumnes/1202114/nltk_data...
[nltk_data]   Unzipping taggers/universal_tagset.zip.


True

In [34]:
with open('trial/STS.input.txt') as fp:
    data = fp.readlines()
    
with open('trial/STS.gs.txt') as e:
    gs = e.readlines()

def nltk_pos_to_wordnet_pos(nltk_pos):
    mapping = {'NOUN': wn.NOUN, 'ADJ': wn.ADJ, 'VERB': wn.VERB, 'ADP': wn.ADV}
    if nltk_pos in mapping:
        return mapping[nltk_pos]
    else:
        return None
    
def get_synsets(context):
    words = [word for (word, pos) in context]
    return {lesk(words, w, p) for w, p in context if lesk(words, w, p)}


def get_pos(sent1):
    tok_sent1 = nltk.word_tokenize(sent1)
    pos_sent1 = pos_tag(tok_sent1, tagset='universal')
    w_pos_sent1 = [(word, nltk_pos_to_wordnet_pos(pos)) for (word, pos) in pos_sent1]
    filtered_pos_sent1 = [(word, pos) for (word, pos) in w_pos_sent1 if pos is not None]
    return filtered_pos_sent1

def eval_synsets(sent1, sent2):
    pos_sent1 = get_pos(sent1)
    pos_sent2 = get_pos(sent2)
    synsets1 = get_synsets(pos_sent1)
    synsets2 = get_synsets(pos_sent2)
    return jaccard_distance(synsets1, synsets2)

def eval_definitions(sent1, sent2):
    pos_sent1 = get_pos(sent1)
    pos_sent2 = get_pos(sent2)
    synsets1 = get_synsets(pos_sent1)
    synsets2 = get_synsets(pos_sent2)
    definitions1 = set([])
    for synset in synsets1:
        for word in nltk.word_tokenize(synset.definition()):
            definitions1.add(word)
    definitions2 = set([])
    for synset in synsets2:
        for word in nltk.word_tokenize(synset.definition()):
            definitions2.add(word)
        
    return jaccard_distance(definitions1, definitions2)
    

jaccard_synsets = []
jaccard_definitions = []
gold = []
for index, line in enumerate(data):
    (num, sent1, sent2) = line.split('\t')
    jaccard_synsets.append(eval_synsets(sent1, sent2))
    jaccard_definitions.append(eval_definitions(sent1, sent2))
    gold.append(int(gs[index].split('\t')[1][0]))

# 0.490670810375692


print('pearson correlation between gold and jaccard distance using synsets:', pearsonr(gold, jaccard_synsets)[0])
print('pearson correlation between gold and jaccard distance using definition words:'pearsonr(gold, jaccard_definitions)[0])

0.4697360281253835
0.4539469938971054
