# WordNet-based Query Expansion Playground

The goal of WordNet-based query expansion is the same as that of BERT-based query expansion, and the strategy is almost identical; the only difference is how the candidate tokens are generated.
Here, the candidates are drawn from the set of synonyms of the word to be expanded.
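
As a minimal sketch of this candidate-generation step, the snippet below collects the unique lemma names across the synsets of a word. The lemma lists are hard-coded to the values WordNet returns for "room" later in this notebook, so the example runs without NLTK; `collect_candidates` is a hypothetical helper name, not part of any library.

```python
from itertools import chain

# Lemma names of each synset of "room", hard-coded (assumption: in the
# notebook these come from `[s.lemma_names() for s in wn.synsets("room")]`).
room_lemmas = [
    ["room"],
    ["room", "way", "elbow_room"],
    ["room"],
    ["room"],
    ["board", "room"],
]

def collect_candidates(lemma_lists):
    """Flatten the per-synset lemma lists and return the sorted unique lemmas."""
    return sorted(set(chain.from_iterable(lemma_lists)))

candidates = collect_candidates(room_lemmas)
print(candidates)
```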

In [3]:
import pydash
from nltk.wsd import lesk
from nltk.corpus import wordnet as wn, wordnet_ic

def naive_wsd(list_of_synsets, term_dis):
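  """
  Pick the sense of term_dis that is most similar to the other terms.

  list_of_synsets: list of lists containing the synsets of each word
  term_dis: list of candidate synsets of the term to be disambiguated
  """
  brown_ic = wordnet_ic.ic("ic-brown.dat")
  # A lower res_similarity means the two concepts are less likely to be related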

  sense_confidence = float('-inf')
  disambiguated_sense = None

  for sense_dis in term_dis:
    confidence = 0
    for term_other in list_of_synsets:
      if term_dis != term_other:
        confidence += max([sense_dis.res_similarity(sense_other, brown_ic) for sense_other in term_other])
    if confidence > sense_confidence:
      disambiguated_sense = sense_dis
      sense_confidence = confidence
  
  return disambiguated_sense, sense_confidence

The experiments below explore the best way to obtain the candidates.
An empirical test showed that Lesk WSD underperforms the naive strategy. The next step would have been to take the hyponyms and hypernyms of the disambiguated synset, but the overhead of WSD and POS tagging (needed for a more accurate WSD) is not worth the effort.
Instead, taking the synonyms of a word across all of its senses seems to be a much more consistent method.

In [41]:
TOKEN_ID = 2

original_sentence_tokens = 'modern shared room near Harvard'.split()

tmp = original_sentence_tokens[:]
tmp[TOKEN_ID] = '{}'
original_sentence_fmt = ' '.join(tmp)
token = original_sentence_tokens[TOKEN_ID]

lesk_synset = lesk(original_sentence_tokens, token)
print(f'{lesk_synset=}')

nouns_synsets = (
    pydash.chain(original_sentence_tokens)
      .map(lambda n: wn.morphy(n, wn.NOUN))
      .filter(lambda n: n is not None)
      .map(lambda n: wn.synsets(n, wn.NOUN))
      .value()
  )

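# nouns_synsets is indexed after filtering, so position 1 (not TOKEN_ID) holds the senses of `token`
naive_synset = naive_wsd(nouns_synsets, nouns_synsets[1])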
print(f'{naive_synset=}')
print([s.lemma_names() for s in wn.synsets(token)])

for s in wn.synsets(token):
    print(s, s.definition())

lesk_synset=Synset('room.n.04')
naive_synset=(Synset('room.n.01'), 0.5962292078977726)
[['room'], ['room', 'way', 'elbow_room'], ['room'], ['room'], ['board', 'room']]
Synset('room.n.01') an area within a building enclosed by walls and floor and ceiling
Synset('room.n.02') space for movement
Synset('room.n.03') opportunity for
Synset('room.n.04') the people who are present in a room
Synset('board.v.02') live and take one's meals at or in


In [42]:
from transformers import BertTokenizer
from transformers import BertModel
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', cache_dir = 'hf_cache')
encoder = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True, cache_dir = 'hf_cache')

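def get_meaned_embeddings(sentence: str):
    """Return the mean of the last-layer BERT token embeddings of `sentence`."""
    # Tokens are converted directly to ids, so no [CLS]/[SEP] tokens are added
    tokens = tokenizer.tokenize(sentence)
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Add a batch dimension: shape (1, num_tokens)
    input_ids = torch.tensor(input_ids).unsqueeze(0)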
    with torch.no_grad():
        outputs = encoder(input_ids)
        embedding = outputs.last_hidden_state[0]

    return embedding.mean(dim = 0)
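
To make the pooling step concrete, here is a dependency-free sketch of what mean pooling followed by cosine similarity computes. The token vectors are made-up 3-dimensional toy values (the `bert-base-uncased` embeddings used above have 768 dimensions), and `mean_pool` and `cosine_similarity` are hypothetical helper names introduced only for illustration.

```python
import math

def mean_pool(token_vectors):
    """Average a list of equal-length token vectors into one sentence vector."""
    n = len(token_vectors)
    return [sum(column) / n for column in zip(*token_vectors)]

def cosine_similarity(u, v):
    """Cosine of the angle between u and v (1.0 means same direction)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Toy "token embeddings" for two sentences that differ in a single token.
sentence_a = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
sentence_b = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [0.0, 2.0, 1.0]]

pooled_a = mean_pool(sentence_a)
pooled_b = mean_pool(sentence_b)
print(cosine_similarity(pooled_a, pooled_a))
print(cosine_similarity(pooled_a, pooled_b))
```

Swapping one token changes only one of the averaged vectors, so the pooled sentence vectors stay close; this is one plausible reason the similarities between a sentence and its single-word variants tend to be high.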

In [43]:
from operator import itemgetter

candidates = (
    pydash.chain([s.lemma_names() for s in wn.synsets(token)])
        .flatten_deep()
        .sorted_uniq()
        .value()
)

print(candidates)

original_sentence_embedding = get_meaned_embeddings(original_sentence_fmt.format(token))
cos_sim = torch.nn.CosineSimilarity(dim = 0)

similarities = (
    pydash.chain(candidates)
        .map(lambda c: c.replace('_', ' '))
        .map(lambda c: original_sentence_fmt.format(c))
        .map(get_meaned_embeddings)
        .map(lambda x: cos_sim(x, original_sentence_embedding))
        .value()
)
print(similarities)

THRESHOLD = 0.8
SLICE = 5

expansions = (
    pydash.chain(candidates)
        .zip(similarities)
        .filter(lambda t: t[1] > THRESHOLD)
        .sort(key = itemgetter(1), reverse = True)
        .map(itemgetter(0))
        .take(SLICE)
        .value()
)

print(expansions)

['board', 'elbow_room', 'room', 'way']
[tensor(0.9377), tensor(0.9285), tensor(1.), tensor(0.9333)]
['room', 'board', 'way', 'elbow_room']
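
The selection step can also be written without pydash. The sketch below is an equivalent plain-Python version under the same `THRESHOLD` and `SLICE` settings, fed with the candidates and similarities printed above (as plain floats rather than tensors); `select_expansions` is a hypothetical helper name.

```python
THRESHOLD = 0.8
SLICE = 5

# Values copied from the output above, as plain floats rather than tensors.
candidates = ["board", "elbow_room", "room", "way"]
similarities = [0.9377, 0.9285, 1.0, 0.9333]

def select_expansions(candidates, similarities, threshold, limit):
    """Keep candidates above `threshold`, best first, at most `limit` of them."""
    scored = [(c, s) for c, s in zip(candidates, similarities) if s > threshold]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [c for c, _ in scored[:limit]]

expansions = select_expansions(candidates, similarities, THRESHOLD, SLICE)
print(expansions)  # ['room', 'board', 'way', 'elbow_room']
```

Since the original token is among its own candidates and scores 1.0 against the original sentence, it ranks first; a caller interested only in new terms would probably want to drop it before using the expansions.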
