Commit 1b491b2
added tsne visualizations
Marten Postma committed Jan 19, 2018
1 parent 6ee8d22 commit 1b491b2
Showing 4 changed files with 268 additions and 20 deletions.
74 changes: 57 additions & 17 deletions evaluate/perform_wsd.py
@@ -11,6 +11,7 @@
import morpho_utils
import tensor_utils as utils
import score_utils
import tsne_utils

parser = argparse.ArgumentParser(description='Perform WSD using LSTM model')
parser.add_argument('-m', dest='model_path', required=True, help='path to model trained LSTM model')
@@ -150,6 +151,7 @@ def score_synsets(target_embedding, candidate_synsets, sense_embeddings, instanc
"""
highest_synsets = []
highest_conf = 0.0
synset_std = None
candidate_freq = dict()
strategy = 'lstm'

@@ -166,43 +168,51 @@

if candidate not in sense_embeddings:
#print('%s %s %s: candidate %s missing in sense embeddings' % (instance_id, lemma, pos, candidate))
            continue

        cand_embedding, cand_std = sense_embeddings[candidate]
sim = 1 - spatial.distance.cosine(cand_embedding, target_embedding)

potentially_added_synset = (synset, cand_std)

        if sim == highest_conf:
            highest_synsets.append(potentially_added_synset)
        elif sim > highest_conf:
            highest_synsets = [potentially_added_synset]
            highest_conf = sim

    if len(highest_synsets) == 1:
        highest_synset, synset_std = highest_synsets[0]
    elif len(highest_synsets) >= 2:
        highest_synset, synset_std = highest_synsets[0]
        #print('%s %s %s: 2> synsets with same conf %s: %s' % (instance_id, lemma, pos, highest_conf, highest_synsets))
else:
if args.mfs_fallback:
highest_synset = candidate_synsets[0]
synset_std = None
#print('%s: no highest synset -> mfs' % instance_id)
strategy = 'mfs_fallback'
else:
            highest_synset = None
            synset_std = None
    return highest_synset, synset_std, candidate_freq, strategy
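
# Illustrative sketch, not part of the commit: a minimal runnable version of
# the scoring step above, assuming (as this commit introduces) that
# sense_embeddings maps synset id -> (mean embedding, std) tuple.
# Ids and vectors below are made up.
import numpy as np
from scipy import spatial

sense_embeddings = {
    'eng-30-02084071-n': (np.array([0.1, 0.9, 0.3]), 0.05),
    'eng-30-01317541-n': (np.array([0.8, 0.1, 0.4]), 0.12),
}
target_embedding = np.array([0.2, 0.8, 0.3])

best_synset, best_sim, best_std = None, -1.0, None
for synset_id, (cand_embedding, cand_std) in sense_embeddings.items():
    # cosine similarity = 1 - cosine distance
    sim = 1 - spatial.distance.cosine(cand_embedding, target_embedding)
    if sim > best_sim:
        best_synset, best_sim, best_std = synset_id, sim, cand_std

print(best_synset, round(best_sim, 3), best_std)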


# load wsd competition dataframe
wsd_df = pandas.read_pickle(args.wsd_df_path)

# add output column
wsd_df['lstm_output'] = [None for _ in range(len(wsd_df))]
wsd_df['target_embedding'] = [None for _ in range(len(wsd_df))]
wsd_df['std_chosen_synset'] = [None for _ in range(len(wsd_df))]
wsd_df['lstm_acc'] = [None for _ in range(len(wsd_df))]
wsd_df['emb_freq'] = [None for _ in range(len(wsd_df))]
wsd_df['#_cand_synsets'] = [None for _ in range(len(wsd_df))]
wsd_df['#_new_cand_synsets'] = [None for _ in range(len(wsd_df))]
wsd_df['gold_in_new_cand_synsets'] = [None for _ in range(len(wsd_df))]
wsd_df['wsd_strategy'] = [None for _ in range(len(wsd_df))]
wsd_df['only_one_embedding'] = [None for _ in range(len(wsd_df))]
wsd_df['has_gold_embedding'] = [None for _ in range(len(wsd_df))]

# load sense embeddings
with open(args.sense_embeddings_path, 'rb') as infile:
@@ -240,8 +250,6 @@ def score_synsets(target_embedding, candidate_synsets, sense_embeddings, instanc
pos = row['pos']
if the_wn_version in {'171'}:
        pos = None
candidate_synsets, \
new_candidate_synsets, \
gold_in_candidates = morpho_utils.candidate_selection(wn,
Expand All @@ -259,8 +267,6 @@ def score_synsets(target_embedding, candidate_synsets, sense_embeddings, instanc
the_chosen_candidates = [synset2identifier(synset, wn_version=the_wn_version)
for synset in new_candidate_synsets]

# get mapping to higher abstraction level
synset2higher_level = dict()
if args.gran in {'sensekey', 'blc20', 'direct_hypernym'}:
@@ -288,6 +294,7 @@ def score_synsets(target_embedding, candidate_synsets, sense_embeddings, instanc
# perform wsd
if len(the_chosen_candidates) >= 2:
chosen_synset, \
candidate_std, \
candidate_freq, \
strategy = score_synsets(target_embedding,
the_chosen_candidates,
@@ -303,23 +310,47 @@ def score_synsets(target_embedding, candidate_synsets, sense_embeddings, instanc

else:
chosen_synset = None
candidate_std = None
        if the_chosen_candidates:
            chosen_synset = the_chosen_candidates[0]
candidate_freq = dict()

# add to dataframe
wsd_df.set_value(row_index, col='target_embedding', value=target_embedding)
wsd_df.set_value(row_index, col='lstm_output', value=chosen_synset)
wsd_df.set_value(row_index, col='std_chosen_synset', value=candidate_std)

wsd_df.set_value(row_index, col='#_cand_synsets', value=len(candidate_synsets))
wsd_df.set_value(row_index, col='#_new_cand_synsets', value=len(new_candidate_synsets))
wsd_df.set_value(row_index, col='gold_in_new_cand_synsets', value=gold_in_candidates)
wsd_df.set_value(row_index, col='wsd_strategy', value=wsd_strategy)

# score it
print(chosen_synset, row['source_wn_engs'])
    lstm_acc = chosen_synset in row['source_wn_engs']  # used to be wn30_engs

    only_one_embedding = False
has_gold_embedding = True

if wsd_strategy != 'monosemous':
num_embeddings = 0
has_gold_embedding = False

for source_wn_eng in row['source_wn_engs']:
if source_wn_eng in candidate_freq:
if candidate_freq[source_wn_eng]:
has_gold_embedding = True
for synset_id, freq in candidate_freq.items():
if freq:
num_embeddings += 1

only_one_embedding = num_embeddings == 1

wsd_df.set_value(row_index, col='has_gold_embedding', value=has_gold_embedding)
wsd_df.set_value(row_index, col='only_one_embedding', value=only_one_embedding)
wsd_df.set_value(row_index, col='lstm_acc', value=lstm_acc)
    wsd_df.set_value(row_index, col='emb_freq', value=candidate_freq)

if lstm_acc:
num_correct += 1
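
# Illustrative sketch, not part of the commit: condensed version of the
# bookkeeping above with made-up frequencies. Two candidate senses have
# annotated training data and one of them is gold, so has_gold_embedding
# is True and only_one_embedding is False.
candidate_freq = {'synset-a': 5, 'synset-b': 2, 'synset-c': 0}
source_wn_engs = {'synset-a'}

has_gold_embedding = any(candidate_freq.get(s) for s in source_wn_engs)
num_embeddings = sum(1 for freq in candidate_freq.values() if freq)
only_one_embedding = num_embeddings == 1
print(has_gold_embedding, only_one_embedding)  # True False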

@@ -341,4 +372,13 @@ def score_synsets(target_embedding, candidate_synsets, sense_embeddings, instanc




# write tsne visualizations
visualize = True
if visualize:
output_folder = args.results.replace('/results.txt', '')
tsne_utils.create_tsne_visualizations(output_folder,
correct={False, True},
meanings=True,
instances=True,
polysemy=range(2, 3),
num_embeddings=range(2, 100))
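
# Illustrative sketch, not part of the commit: the tsne_utils module added
# here is not shown in this excerpt, so the helper below is only an
# assumption about how such a visualization could be written with
# scikit-learn and matplotlib, not the module's actual API.
import matplotlib
matplotlib.use('Agg')  # render to file; no display needed
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

def plot_tsne(embeddings, labels, out_path, perplexity=5):
    """Project embeddings to 2d with t-SNE and save an annotated scatter plot.

    perplexity must be smaller than the number of embeddings.
    """
    points = TSNE(n_components=2, perplexity=perplexity,
                  random_state=0).fit_transform(np.array(embeddings))
    plt.figure()
    plt.scatter(points[:, 0], points[:, 1])
    for (x, y), label in zip(points, labels):
        plt.annotate(label, (x, y), fontsize=6)
    plt.savefig(out_path)
    plt.close()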
24 changes: 24 additions & 0 deletions evaluate/score_utils.py
@@ -1,5 +1,29 @@
import os


def no_sense_data_for_non_gold_cand(emb_freq, source_wn_engs):
"""
check whether there are training instances
for the other senses than the gold sense
:param dict emb_freq: mapping synset_id -> number of sense annotated examples
:param set source_wn_engs: set of gold synset ids
:rtype: bool
:return: True -> there are no training instances for non gold synset ids
False, there are instances for non gold synset ids
"""
    only_data_for_answer = True
    for synset_id, freq in emb_freq.items():
        if all([freq >= 1,                       # 1 or more training instances
                synset_id not in source_wn_engs  # not one of the gold synsets
                ]):
            only_data_for_answer = False

    return only_data_for_answer
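
# Illustrative example, not part of the commit (ids and counts are made up):
# the gold synset has annotated training examples and the competing synset
# has none, so the function returns True.
emb_freq = {'eng-30-02084071-n': 12, 'eng-30-01317541-n': 0}
source_wn_engs = {'eng-30-02084071-n'}
print(no_sense_data_for_non_gold_cand(emb_freq, source_wn_engs))  # True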


def experiment_results(df, mfs_fallback, wsd_df_path):
"""
given df with wsd results, return information for table
9 changes: 6 additions & 3 deletions evaluate/test-lstm_v2.py
@@ -53,6 +53,7 @@ def ctx_embd_input(sentence):
print('loaded vocab')

synset2context_embds = defaultdict(list)
synset2instances = dict()
meaning_freqs = defaultdict(int)
batch_size = int(args.batch_size)
counter = 0
@@ -83,8 +84,6 @@ def ctx_embd_input(sentence):

        for line in n_lines:

            sentence = line.strip()
            tokens, annotation_indices = ctx_embd_input(sentence)

@@ -121,10 +120,14 @@ def ctx_embd_input(sentence):
synset2avg_embedding = dict()
for synset, embeddings in synset2context_embds.items():
average = sum(embeddings) / len(embeddings)
    std = np.std(embeddings)
    synset2avg_embedding[synset] = average, std

with open(args.output_path, 'wb') as outfile:
pickle.dump(synset2avg_embedding, outfile)

with open(args.output_path + '.instances', 'wb') as outfile:
pickle.dump(synset2context_embds, outfile)

with open(args.output_path + '.freq', 'wb') as outfile:
pickle.dump(meaning_freqs, outfile)
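
# Illustrative sketch, not part of the commit: toy run of the aggregation
# above. Each synset's context embeddings are averaged into one sense
# embedding; np.std over all vector components is stored as a spread estimate.
import numpy as np

synset2context_embds = {
    'eng-30-02084071-n': [np.array([0.1, 0.9]), np.array([0.3, 0.7])],
}
synset2avg_embedding = {}
for synset, embeddings in synset2context_embds.items():
    average = sum(embeddings) / len(embeddings)
    std = np.std(embeddings)  # scalar std across both vectors' components
    synset2avg_embedding[synset] = average, std

print(synset2avg_embedding)  # {'eng-30-02084071-n': (array([0.2, 0.8]), 0.316...)}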
