In [37]:
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
text = """Positional control of pneumatic manipulators for constructiobn tasks
This paper describes solutions that can be applied to pneumatic manipulator
	problems in positioning, both for angle trajectories and for long
	linear trajectories, used in construction tasks. Optimal positioning of
	a pneumatic manipulator along angle trajectories with minimum control
	energy consumption is given. The implementation of the control system
	is presented. Control algorithms for a long linear trajectory
	manipulator based on two-phase and three-phase motion modes of the
	end-effector are investigated. Conventional and fuzzy logic controls of
	a pneumatic manipulator were applied and experimental testing was
	carried out. The obtained results allow widening the application range
	of pneumatic manipulators in construction, particularly in gantry type
	machines"""

enc_tokenizer = encoder.tokenizer

In [38]:
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel

gpt_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
generator = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")
example_text = "I really like eating ice cream and drinking hot"
inputs = gpt_tokenizer(example_text, return_tensors="pt")
outputs = generator(**inputs, labels=inputs["input_ids"])
logits = outputs.logits
print(logits.shape)
am_tokens = torch.argmax(logits, dim=-1)
print(am_tokens.shape)
decoded_tokens = am_tokens[0].tolist()
last_token = decoded_tokens[-1]
decoded_token = gpt_tokenizer.decode(last_token)
print(decoded_token)
print("next token in the sentence is:", decoded_token)

torch.Size([1, 9, 50257])
torch.Size([1, 9])
 chocolate
next token in the sentence is:  chocolate


In [31]:
import torch.nn.functional as F

def get_top_k_next_tokens(text, k=1):
    assert k > 0 and k <= len(gpt_tokenizer.vocab)
    # Tokenize the input text
    encoded_input = gpt_tokenizer(text, return_tensors='pt')
    print("encoded input:", encoded_input)
    # Get model output
    output = generator(**encoded_input)
    logits = output.logits
    probabilities = F.softmax(logits, dim=-1)[0][-1]

    # Get top k indices for each position in the sequence
    top_k_indices = torch.topk(probabilities, k, dim=-1).indices
    # Decode each index using the tokenizer
    decoded_tokens = [gpt_tokenizer.decode(int(idx)) for idx in top_k_indices]

    return decoded_tokens

# Specify the number of top-k tokens to retrieve
k = 3
result = get_top_k_next_tokens(example_text, k)
print(result)


encoded input: {'input_ids': tensor([[  40, 1107,  588, 6600, 4771, 8566,  290, 7722, 3024]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
[' chocolate', ' coffee', ' water']


In [39]:
unused_token = next(
        k for k, v in enc_tokenizer.vocab.items()
        if "[unused" in k)
unused_token_id = enc_tokenizer.vocab[unused_token]
new_token_id = unused_token_id
non_new_token_ids = torch.tensor([i for i in range(len(enc_tokenizer.vocab)) if i != new_token_id])

for param in encoder.parameters():
    param.requires_grad = False
print(len(enc_tokenizer))
print(type(encoder[0].auto_model.embeddings.word_embeddings))

dot_prods = []

for idx, p in enumerate(encoder[0].auto_model.embeddings.word_embeddings.parameters()):
    p.requires_grad = True
    t = torch.tensor(p)
    for x in range(t.shape[0]):
        subt = t[x]
        dot_prods.append(torch.dot(subt, subt).item())

print(min(dot_prods), max(dot_prods))

assert any(p.requires_grad for p in encoder.parameters())

30522
<class 'torch.nn.modules.sparse.Embedding'>
0.0599682554602623 2.869729995727539


  t = torch.tensor(p)


In [44]:
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch import nn
from tqdm import tqdm
from torch.nn import CosineSimilarity

criterion = nn.CosineEmbeddingLoss()

def soft_prompt_for_text(text, target):
    tokenized_text = enc_tokenizer.encode(text)
    tokenized_text[-1] = new_token_id
    tokenized_text.append(102)
    input_ids = torch.tensor(tokenized_text).unsqueeze(0)
    attention_mask = torch.ones_like(input_ids)
    optimizer = optim.Adam(encoder.parameters(), lr=1)
    scheduler = CosineAnnealingLR(optimizer, T_max=100, eta_min=0.0001)
    best = float("inf")
    best_emb = None
    for _ in tqdm(range(200)):
        optimizer.zero_grad()
        output = encoder({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        })["sentence_embedding"]
        
        loss = criterion(output.squeeze(), target, torch.tensor(1.0))
        loss.backward()
        # set grad of non-new token to 0
        # all ids except new_token_id
        
        encoder[0].auto_model.embeddings.word_embeddings.weight.grad[non_new_token_ids] = 0

        optimizer.step()
        scheduler.step()

        if loss.item() < best:
            best = loss.item()
            best_emb = encoder[0].auto_model.embeddings.word_embeddings.weight.data[new_token_id].cpu().numpy()
    
    # now discretization
    cossims = []
    for token in tqdm(range(len(enc_tokenizer.vocab))):
        token_embedding = encoder[0].auto_model.embeddings.word_embeddings.weight.data[token].cpu().numpy()
        cossims.append(CosineSimilarity()(torch.Tensor(token_embedding).unsqueeze(0), torch.Tensor(best_emb).unsqueeze(0)).item())
    return -max(cossims)

    
target = torch.tensor(encoder.encode(text))
soft_prompt_for_text(example_text, target)

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:22<00:00,  9.00it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11498.21it/s]


-1.0

In [45]:
from queue import PriorityQueue

def dot_prod(text, target):
    vec = encoder.encode(text, convert_to_tensor=True)
    return vec @ target

def iterative_soft_prompt(text, target, k=3, iter_num=100):
    pq = PriorityQueue()
    pq.put((0, text))
    best_text = None
    best_cossim = float('-inf')
    for _ in range(iter_num):
        text = pq.get()[1]
        this_cossim = dot_prod(text, target)
        print("cos sim:", text, this_cossim)
        if this_cossim > best_cossim:
            best_cossim = this_cossim
            best_text = text
        print("text:", text)
        top_k_next_tokens = get_top_k_next_tokens(text, k)
        print(top_k_next_tokens)
        top_k_words = [text + token for token in top_k_next_tokens]
        print(top_k_words)
        soft_prompt_scores = [(soft_prompt_for_text(next_word, target), next_word) for next_word in top_k_words]
        print(soft_prompt_scores)
        for sp_score, next_word in soft_prompt_scores:
            pq.put((sp_score, next_word))
    return best_text, best_cossim
        
target = torch.tensor(encoder.encode(text))
text, cossim = iterative_soft_prompt("This", target)

cos sim: This tensor(-0.0359)
text: This
encoded input: {'input_ids': tensor([[1212]]), 'attention_mask': tensor([[1]])}
[' is', ',', '.']
['This is', 'This,', 'This.']


100%|██████████| 200/200 [00:21<00:00,  9.18it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11769.70it/s]
100%|██████████| 200/200 [00:21<00:00,  9.50it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11612.96it/s]
100%|██████████| 200/200 [00:21<00:00,  9.23it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11283.20it/s]


[(-1.0, 'This is'), (-1.0, 'This,'), (-0.9999999403953552, 'This.')]
cos sim: This is tensor(-0.0573)
text: This is
encoded input: {'input_ids': tensor([[1212,  318]]), 'attention_mask': tensor([[1, 1]])}
[' a', ' the', ' not']
['This is a', 'This is the', 'This is not']


100%|██████████| 200/200 [00:21<00:00,  9.18it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11884.70it/s]
100%|██████████| 200/200 [00:21<00:00,  9.46it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12454.96it/s]
100%|██████████| 200/200 [00:19<00:00, 10.47it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12343.28it/s]


[(-0.9999998807907104, 'This is a'), (-1.0, 'This is the'), (-1.0, 'This is not')]
cos sim: This is not tensor(-0.0120)
text: This is not
encoded input: {'input_ids': tensor([[1212,  318,  407]]), 'attention_mask': tensor([[1, 1, 1]])}
[' a', ' the', ' an']
['This is not a', 'This is not the', 'This is not an']


100%|██████████| 200/200 [00:19<00:00, 10.41it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12445.08it/s]
100%|██████████| 200/200 [00:18<00:00, 10.61it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12074.67it/s]
100%|██████████| 200/200 [00:18<00:00, 10.72it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12769.09it/s]


[(-1.0, 'This is not a'), (-1.0000001192092896, 'This is not the'), (-1.0, 'This is not an')]
cos sim: This is not the tensor(-0.0618)
text: This is not the
encoded input: {'input_ids': tensor([[1212,  318,  407,  262]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
[' first', ' end', ' case']
['This is not the first', 'This is not the end', 'This is not the case']


100%|██████████| 200/200 [00:18<00:00, 10.72it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12346.81it/s]
100%|██████████| 200/200 [00:18<00:00, 10.82it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12447.72it/s]
100%|██████████| 200/200 [00:20<00:00,  9.98it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12124.06it/s]


[(-0.9999998211860657, 'This is not the first'), (-1.0, 'This is not the end'), (-1.0000001192092896, 'This is not the case')]
cos sim: This is not the case tensor(-0.0988)
text: This is not the case
encoded input: {'input_ids': tensor([[1212,  318,  407,  262, 1339]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
['.', ' with', ' for']
['This is not the case.', 'This is not the case with', 'This is not the case for']


100%|██████████| 200/200 [00:20<00:00,  9.83it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12013.60it/s]
100%|██████████| 200/200 [00:18<00:00, 10.59it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12486.90it/s]
100%|██████████| 200/200 [00:19<00:00, 10.41it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11888.53it/s]


[(-0.9999998211860657, 'This is not the case.'), (-1.0, 'This is not the case with'), (-1.0, 'This is not the case for')]
cos sim: This is not a tensor(-0.0143)
text: This is not a
encoded input: {'input_ids': tensor([[1212,  318,  407,  257]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
[' good', ' new', ' problem']
['This is not a good', 'This is not a new', 'This is not a problem']


100%|██████████| 200/200 [00:18<00:00, 10.72it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12381.88it/s]
100%|██████████| 200/200 [00:18<00:00, 10.72it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12158.46it/s]
100%|██████████| 200/200 [00:18<00:00, 10.69it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12965.45it/s]


[(-0.9999998807907104, 'This is not a good'), (-0.9999998211860657, 'This is not a new'), (-0.9999998807907104, 'This is not a problem')]
cos sim: This is not an tensor(-0.0249)
text: This is not an
encoded input: {'input_ids': tensor([[1212,  318,  407,  281]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
[' easy', ' exhaustive', ' attempt']
['This is not an easy', 'This is not an exhaustive', 'This is not an attempt']


100%|██████████| 200/200 [00:18<00:00, 10.71it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12437.04it/s]
100%|██████████| 200/200 [00:18<00:00, 10.73it/s]
100%|██████████| 30522/30522 [00:02<00:00, 12495.07it/s]
100%|██████████| 200/200 [00:19<00:00, 10.23it/s]
100%|██████████| 30522/30522 [00:02<00:00, 11054.17it/s]


[(-1.0, 'This is not an easy'), (-1.0000001192092896, 'This is not an exhaustive'), (-0.9999999403953552, 'This is not an attempt')]
cos sim: This is not an exhaustive tensor(0.0294)
text: This is not an exhaustive
encoded input: {'input_ids': tensor([[ 1212,   318,   407,   281, 36049]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
[' list', ' review', ' study']
['This is not an exhaustive list', 'This is not an exhaustive review', 'This is not an exhaustive study']


 24%|██▍       | 49/200 [00:05<00:15,  9.81it/s]

In [None]:
text, cossim