In [1]:
# Imports: prompt formatting, dataset access, and the tokenizer class.
from prompts import formatting_prompts
from data import get_dataset
from transformers import AutoTokenizer

# Configuration read by the cells below.
dataset_name = "imdb-small"
split = "train"
batch_size = 16

# Tokenizer used both to build the prompts and to decode the probing tokens.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

# Restrict the requested split to its first 100 rows.
dataset = get_dataset(dataset_name, split).select(range(100))

# Last expression of the cell: display the dataset summary.
dataset



Dataset({
    features: ['text', 'label', 'answer', 'sample_id'],
    num_rows: 100
})

In [2]:
import torch 
def print_settings(fixedness, semantics, count, num_examples=5):
    """Print a sanity report for one probing-token configuration.

    Builds a dataloader with ``formatting_prompts`` for the given setting,
    prints the decoded probing tokens of the first sample of each of the
    first ``num_examples`` batches, and then prints the variance and mean of
    the number of probe positions per sample over the whole dataloader.

    Relies on the notebook-level globals ``dataset``, ``tokenizer`` and
    ``batch_size``.

    Args:
        fixedness: Forwarded to ``formatting_prompts`` (the notebook passes
            'fixed' or 'variable').
        semantics: Forwarded to ``formatting_prompts`` (the notebook passes
            'syntactical', 'special' or 'random').
        count: Forwarded to ``formatting_prompts`` (the notebook passes
            'single' or 'multi').
        num_examples: How many batches to preview. Defaults to 5, the value
            that was previously hard-coded.

    Returns:
        None. All results are written to stdout.
    """
    print("--------------------------------")
    print(f"FIXEDNESS: {fixedness}, SEMANTICS: {semantics}, COUNT: {count}")
    print("--------------------------------")
    print("Decoded Probing Tokens:")
    dataloader = formatting_prompts(dataset, fixedness, semantics, count, tokenizer, batch_size)

    # One entry per sample: how many probe positions that sample has.
    probe_lengths = []
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch['input_ids']
        probe_positions = batch['probe_positions']
        if batch_idx < num_examples:
            # Only the first sample of each batch is decoded, so the printed
            # "Sample i" index is a batch index, not a dataset row index.
            # The label is kept unchanged so previously recorded output
            # remains comparable.
            probe_text = tokenizer.decode(input_ids[0][probe_positions[0]])
            print(f"Sample {batch_idx} >>> ", probe_text.replace("\n", " \\n"))
        probe_lengths.extend(len(positions) for positions in probe_positions)

    print("--------------------------------")
    print("Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)")
    if not probe_lengths:
        # Nothing to summarise; avoid calling torch on an empty collection.
        print(" >>> No samples were produced by the dataloader.")
        return
    lengths = torch.tensor(probe_lengths, dtype=torch.float32)
    # torch.var uses the unbiased estimator by default, so a single sample
    # yields nan rather than 0.
    print(f" >>> Variance of {len(lengths)} Samples: ", torch.var(lengths).item())
    print(f" >>> Mean of {len(lengths)} Samples: ", torch.mean(lengths).item())

In [3]:
# Values accepted along each axis of the probing-token configuration;
# they match the arguments passed to print_settings in this notebook.
fixedness_list = ["fixed", "variable"]
semantics_list = ["syntactical", "special", "random"]
count_list = ["single", "multi"]

In [4]:
# Fixed + syntactical probing tokens: single-token, then multi-token variant.
print_settings(fixedness='fixed', semantics='syntactical', count='single')
print_settings(fixedness='fixed', semantics='syntactical', count='multi')

--------------------------------
FIXEDNESS: fixed, SEMANTICS: syntactical, COUNT: single
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>  ...
Sample 1 >>>  .
Sample 2 >>>  !.
Sample 3 >>>  .
Sample 4 >>>  ?.
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  1.0
--------------------------------
FIXEDNESS: fixed, SEMANTICS: syntactical, COUNT: multi
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>   Think step by step.
Sample 1 >>>   Think step by step.
Sample 2 >>>   Think step by step.
Sample 3 >>>   Think step by step.
Sample 4 >>>   Think step by step.
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  5.0


In [5]:
# Fixed + special probing tokens: single-token, then multi-token variant.
print_settings(fixedness='fixed', semantics='special', count='single')
print_settings(fixedness='fixed', semantics='special', count='multi')

--------------------------------
FIXEDNESS: fixed, SEMANTICS: special, COUNT: single
--------------------------------
Decoded Probing Tokens:


Sample 0 >>>  <|eot_id|>
Sample 1 >>>  <|eot_id|>
Sample 2 >>>  <|eot_id|>
Sample 3 >>>  <|eot_id|>
Sample 4 >>>  <|eot_id|>
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  1.0
--------------------------------
FIXEDNESS: fixed, SEMANTICS: special, COUNT: multi
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 1 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 2 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 3 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 4 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  5.0

In [6]:
# Fixed + random probing tokens: single-token, then multi-token variant.
print_settings(fixedness='fixed', semantics='random', count='single')
print_settings(fixedness='fixed', semantics='random', count='multi')

--------------------------------
FIXEDNESS: fixed, SEMANTICS: random, COUNT: single
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>   random
Sample 1 >>>   random
Sample 2 >>>   random
Sample 3 >>>   random
Sample 4 >>>   random
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  1.0
--------------------------------
FIXEDNESS: fixed, SEMANTICS: random, COUNT: multi
--------------------------------
Decoded Probing Tokens:


Sample 0 >>>   Random text is inserted.
Sample 1 >>>   Random text is inserted.
Sample 2 >>>   Random text is inserted.
Sample 3 >>>   Random text is inserted.
Sample 4 >>>   Random text is inserted.
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  5.0


In [7]:
# Variable + syntactical probing tokens: single-token, then multi-token variant.
print_settings(fixedness='variable', semantics='syntactical', count='single')
print_settings(fixedness='variable', semantics='syntactical', count='multi')

--------------------------------
FIXEDNESS: variable, SEMANTICS: syntactical, COUNT: single
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>  !
Sample 1 >>>  !
Sample 2 >>>  !
Sample 3 >>>  !
Sample 4 >>>  ,
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  1.0
--------------------------------
FIXEDNESS: variable, SEMANTICS: syntactical, COUNT: multi
--------------------------------
Decoded Probing Tokens:


Sample 0 >>>   Think through it slowly.
Sample 1 >>>   Think through it slowly.
Sample 2 >>>   Think through it slowly.
Sample 3 >>>   Think through it slowly.
Sample 4 >>>   as a careful analyst.
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  5.0


In [None]:
# Warning: as the message printed below states, 'variable' + 'special' is not
# supported and is equivalent to 'fixed' + 'special'.
print("!!! Variable Special is not supported and equivalent to Fixed Special as Candidates for Special Tokens are Non-trivial")
print_settings(fixedness='variable', semantics='special', count='single')
print_settings(fixedness='variable', semantics='special', count='multi')

!!! Variable Special is not supported and equivalent to Fixed Special as Candidates for Special Tokens are Non-trivial
--------------------------------
FIXEDNESS: variable, SEMANTICS: special, COUNT: single
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>  <|eot_id|>
Sample 1 >>>  <|eot_id|>
Sample 2 >>>  <|eot_id|>
Sample 3 >>>  <|eot_id|>
Sample 4 >>>  <|eot_id|>
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  1.0
--------------------------------
FIXEDNESS: variable, SEMANTICS: special, COUNT: multi
--------------------------------
Decoded Probing Tokens:


Sample 0 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 1 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 2 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 3 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
Sample 4 >>>  <|eot_id|><|start_header_id|>assistant<|end_header_id|> \n \n
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  5.0


In [9]:
# Variable + random probing tokens: single-token, then multi-token variant.
print_settings(fixedness='variable', semantics='random', count='single')
print_settings(fixedness='variable', semantics='random', count='multi')

--------------------------------
FIXEDNESS: variable, SEMANTICS: random, COUNT: single
--------------------------------
Decoded Probing Tokens:
Sample 0 >>>   beta
Sample 1 >>>   beta
Sample 2 >>>   probe
Sample 3 >>>   beta
Sample 4 >>>   probe
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  1.0
--------------------------------
FIXEDNESS: variable, SEMANTICS: random, COUNT: multi
--------------------------------
Decoded Probing Tokens:


Sample 0 >>>   Arbitrary tokens appended here.
Sample 1 >>>   Arbitrary tokens appended here.
Sample 2 >>>   Arbitrary tokens appended here.
Sample 3 >>>   Arbitrary tokens appended here.
Sample 4 >>>   Noise sequence without semantics.
--------------------------------
Variance of Probing Token Lengths (Expected to be 0 to ensure the same length of probing tokens)
 >>> Variance of 100 Samples:  0.0
 >>> Mean of 100 Samples:  5.0
