In [None]:
from nltk.corpus import wordnet
from transformers import AutoTokenizer, CLIPTextModel
import torch
import numpy as np

In [None]:
# Reference sentence pair used by the embedding cells further down:
# the first sentence is encoded as the "original", the second as the "target".
original_sentence, target_sentence = (
  "pick up the red hexagon",
  "pick up the blue box",
)

In [None]:
# Display every WordNet synset returned for the word "hexagon".
wordnet.synsets("hexagon")

In [None]:
# Collect the distinct lemma names of the noun ('n') and satellite-adjective
# ('s') synsets of "hexagon" as candidate replacement words.
#
# The part of speech is read from Synset.pos(). Looking for the letters
# 'n' / 's' anywhere inside Synset.name() (e.g. "hexagon.n.01") also matches
# letters of the lemma part of the name (the 'n' in "hexagon"), so it does not
# actually filter by part of speech.
#
# The result is sorted so the displayed list is the same on every run
# (iteration order of a set of strings is not stable across kernels).
substitudes = sorted({
  lem.name().replace('_', ' ')
  for syns in wordnet.synsets('hexagon')
  if syns.pos() in ('n', 's')
  for lem in syns.lemmas()
})
substitudes

In [None]:
# Checkpoint shared by the tokenizer and the CLIP text encoder. Keeping the id
# in one constant guarantees both are always loaded from the same checkpoint.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

tokenizer = AutoTokenizer.from_pretrained(CLIP_CHECKPOINT)
model = CLIPTextModel.from_pretrained(CLIP_CHECKPOINT)

In [None]:
# Encode the original sentence and keep the pooled output of the first (and
# only) item in the batch as its embedding.
original_inputs = tokenizer(original_sentence, return_tensors="pt")
# Inference only: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
  original_emb = model(**original_inputs).pooler_output[0]

In [None]:
# Encode the target sentence and keep the pooled output of the first (and
# only) item in the batch as its embedding.
target_inputs = tokenizer(target_sentence, return_tensors="pt")
# Inference only: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
  target_emb = model(**target_inputs).pooler_output[0]

In [None]:
# Un-normalised dot product between the original and target sentence
# embeddings, shown as a plain Python number.
torch.dot(original_emb, target_emb).item()

In [None]:
# One entry per experiment: the word to substitute and the
# (original sentence, target sentence) tuple it belongs to.
sentence_pairs = [
  dict(
    target_word="red",
    sentences=("pick up the red hexagon", "pick up the blue box"),
  ),
]

In [None]:
# Notebook-level values: a word to substitute and an integer pair index.
# NOTE(review): presumably scratch settings for manual runs — confirm they are still needed.
target_word, pair_idx = 'pink', 2

In [None]:
def test_substitudes(pair):
  """Search for WordNet substitutions that pull a sentence towards a target.

  Every lemma of the noun / satellite-adjective synsets of
  ``pair["target_word"]`` is substituted for that word in the original
  sentence. A substitution is reported as a success when the CLIP text
  embedding of the perturbed sentence has a larger dot product with the
  target sentence's embedding than with the original sentence's embedding.

  Args:
    pair: dict with the keys ``"target_word"`` (the word to replace) and
      ``"sentences"`` (an ``(original_sentence, target_sentence)`` tuple).

  Returns:
    None. The sentence pair and every successful substitution are printed.
  """
  target_word = pair["target_word"]
  # The sentences are stored under the "sentences" key; indexing the dict with
  # the notebook-level integer ``pair_idx`` raised KeyError.
  original_sentence, target_sentence = pair["sentences"]

  def embed(sentence):
    """Return the pooled CLIP text output for the single sentence given."""
    inputs = tokenizer(sentence, return_tensors="pt")
    return model(**inputs).pooler_output[0]

  # Inference only: no autograd graph is needed for these comparisons.
  with torch.no_grad():
    original_emb = embed(original_sentence)
    target_emb = embed(target_sentence)
    print(original_sentence, "->", target_sentence)

    # Filter on the synset's part of speech. Testing for the letters 'n' / 's'
    # inside Synset.name() (e.g. "red.n.01") also matches letters of the lemma
    # itself, which let synsets of any part of speech through.
    # Sorted so candidates are tried (and printed) in a reproducible order.
    substitudes = sorted({
      lem.name().replace('_', ' ')
      for syns in wordnet.synsets(target_word)
      if syns.pos() in ('n', 's')
      for lem in syns.lemmas()
    })

    for subs in substitudes:
      perturbed = original_sentence.replace(target_word, subs)
      emb = embed(perturbed)
      odot = emb.dot(original_emb).item()
      tdot = emb.dot(target_emb).item()
      # Success: the perturbed sentence scores higher against the target
      # than against the sentence it was derived from.
      if odot < tdot:
        print("=================================")
        print(perturbed)
        print("Success", odot, tdot)
        print("=================================")

In [None]:
def test_rewrite(pair, target_sentence):
  """Print how a rewritten sentence scores against its source and a target.

  Args:
    pair: ``(original_sentence, perturbed)`` tuple; ``perturbed`` is the
      rewrite being evaluated.
    target_sentence: sentence whose embedding the rewrite is compared with.

  Returns:
    None. Prints the rewrite, then its dot product with the original
    sentence's embedding and with the target sentence's embedding, framed by
    separator lines.
  """
  def pooled(sentence):
    """Return the pooled CLIP text output for the single sentence given."""
    encoded = tokenizer(sentence, return_tensors="pt")
    return model(**encoded).pooler_output[0]

  source_sentence, rewrite = pair
  source_vec = pooled(source_sentence)
  target_vec = pooled(target_sentence)
  rewrite_vec = pooled(rewrite)

  score_vs_source = rewrite_vec.dot(source_vec).item()
  score_vs_target = rewrite_vec.dot(target_vec).item()

  separator = "================================="
  print(separator)
  print(rewrite)
  print(score_vs_source, score_vs_target)
  print(separator)

In [None]:
# (source instruction, rewrite) pairs. Every pair shares the same source
# instruction; the final rewrite repeats the source verbatim.
_rewrite_source = "put the blue box in the green bowl"
rewritten_sentences = [
    (_rewrite_source, rewrite)
    for rewrite in (
        "position the cobalt container within the emerald basin",
        "deposit the azure box into the verdant receptacle",
        "set the sapphire container inside the lush vessel",
        "place the indigo box within the grassy dish",
        "lay the cerulean container inside the leafy bowl",
        "arrange the navy box within the jade bowl",
        "position the sky-blue container inside the olive bowl",
        "insert the azure box into the viridescent basin",
        "put the blue box in the green bowl",
    )
]

In [None]:
# Score every (source, rewrite) pair against the instruction in which the two
# colours are swapped.
swapped_instruction = "put the green box in the blue bowl"
for source_and_rewrite in rewritten_sentences:
  test_rewrite(source_and_rewrite, swapped_instruction)