In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from transformers.tokenization_utils_base import BatchEncoding

In [2]:
# Single source of truth for the checkpoint so the model and tokenizer
# can never be loaded from two different identifiers by accident.
MODEL_NAME = "meta-llama/Llama-2-7b-hf"

# Generation in this notebook samples tokens with torch.multinomial; seeding
# torch's global RNG makes a Restart & Run All reproduce the same samples.
SEED = 0
torch.manual_seed(SEED)

# Load the model with output_hidden_states set to True so that every forward
# pass also returns the per-layer hidden states.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
def generate_text(
    model: "AutoModelForCausalLM",
    tokenizer: "AutoTokenizer",
    prompt: str,
    max_num_tokens: int = 25,
    top_k: int = 5,
    layer: int = 8,
    temperature: float = 1.0,
    stop_token_ids: list = None,
    stop_words: list = None,
    eos_weight: float = 2.0,
    enable_logging: bool = False,
    device: str = "cpu",
) -> str:
    """
    Generate text by sampling from the hidden state of a chosen layer.

    At every step the full sequence is run through the model, the hidden
    state of ``layer`` at the last position is projected through
    ``model.lm_head``, and the next token is sampled from the renormalized
    top-k of the resulting distribution.

    Args:
        model (model): The language model. It is moved to ``device``.
        tokenizer (model): The tokenizer corresponding to the model.
        prompt (str): The initial text to start generation from.
        max_num_tokens (int, optional): Upper bound on the number of tokens to generate. Defaults to 25.
        top_k (int, optional): The number of top tokens to consider for sampling. Defaults to 5.
        layer (int, optional): Index into ``outputs.hidden_states`` to decode from. Defaults to 8.
        temperature (float, optional): The temperature for softmax; must be > 0. Defaults to 1.0.
        stop_token_ids (list, optional): Token ids that end generation if sampled.
            The tokenizer's EOS id is always treated as a stop id. The list passed
            in is never modified. Defaults to None (EOS only).
        stop_words (list, optional): Decoded token strings that end generation if sampled. Defaults to None.
        eos_weight (float, optional): Factor the EOS probability is multiplied by before top-k. Defaults to 2.0.
        enable_logging (bool, optional): Print each accepted token and its top-k candidates. Defaults to False.
        device (str, optional): Device to run on. Defaults to "cpu".

    Returns:
        str: The decoded prompt plus generated tokens. The token that triggered
        a stop (stop id or stop word) is not included.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` is less than 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    device = torch.device(device)
    model = model.to(device)

    # Tokenize the prompt and keep the two tensors we need on the target device.
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Work on private copies: the caller's lists (and the defaults) stay untouched,
    # so stop ids no longer accumulate from one call to the next.
    eos_token_id = tokenizer.eos_token_id
    stop_ids = list(stop_token_ids) if stop_token_ids is not None else []
    if eos_token_id is not None:
        stop_ids.append(eos_token_id)
    stop_texts = list(stop_words) if stop_words is not None else []

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max_num_tokens):
            # Request hidden states explicitly so this also works when the model
            # was not loaded with output_hidden_states=True.
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

            # Only the last position is used for sampling, so project just that one.
            # NOTE(review): the hidden state is fed to lm_head as-is; no extra
            # normalization is applied here - confirm that is intended for
            # intermediate layers.
            last_hidden = outputs.hidden_states[layer][:, -1, :]
            logits = model.lm_head(last_hidden)

            # Temperature-scaled distribution over the vocabulary for the next token.
            probabilities = F.softmax(logits / temperature, dim=-1)[0]

            # Increase the weight of the EOS token before candidates are selected.
            if eos_token_id is not None:
                probabilities[eos_token_id] *= eos_weight

            # Keep the k most likely tokens (never more than the vocabulary holds)
            # and renormalize them into a proper distribution.
            k = min(top_k, probabilities.shape[-1])
            top_k_probabilities, top_k_indices = torch.topk(probabilities, k)
            top_k_probabilities = top_k_probabilities / top_k_probabilities.sum()

            # Sample one of the k candidates.
            choice = torch.multinomial(top_k_probabilities, 1)
            sampled_token_id = top_k_indices[choice].item()
            sampled_token_text = tokenizer.decode([sampled_token_id])

            # A stop token / stop word ends generation and is not appended.
            if sampled_token_id in stop_ids or sampled_token_text in stop_texts:
                break

            # Extend the sequence in place on the device (no CPU round trip).
            next_token = input_ids.new_tensor([[sampled_token_id]])
            input_ids = torch.cat((input_ids, next_token), dim=-1)
            attention_mask = torch.cat(
                (attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))),
                dim=-1,
            )

            # Log the token and top k tokens if logging is enabled
            if enable_logging:
                print(f'Token: {sampled_token_text}')
                print('      -- Top tokens --')
                for candidate_id, candidate_probability in zip(
                    top_k_indices.tolist(), top_k_probabilities.tolist()
                ):
                    token = tokenizer.decode([candidate_id])
                    print(f'   {token}: {candidate_probability}')

                print()

    # Same return type on every path: the decoded prompt + accepted tokens.
    return tokenizer.decode(input_ids[0])

In [17]:
# Prompt the model is asked to continue.
prompt = "My name is"

# Decode from hidden-state index 8: at most 10 tokens, low temperature,
# stop on a newline token, and print the top-k candidates at each step.
output_8 = generate_text(
    model,
    tokenizer,
    prompt,
    max_num_tokens=10,
    layer=8,
    temperature=0.1,
    stop_words=['\n'],
    enable_logging=True,
)
print(f'layer 8: {output_8}')

Token: Prem
      -- Top tokens --
   Mand: 0.28617581725120544
   Prem: 0.19170637428760529
   ears: 0.19133175909519196
   Het: 0.1694001704454422
   gresql: 0.16138584911823273

Token: Prem
      -- Top tokens --
   ゆ: 0.2643747329711914
   Prem: 0.22167639434337616
   ette: 0.1893041580915451
   commission: 0.163179412484169
   ља: 0.16146530210971832

Token: Singh
      -- Top tokens --
   Spr: 0.30036622285842896
   aj: 0.2109912782907486
   hi: 0.17049676179885864
   ette: 0.16448424756526947
   Singh: 0.15366148948669434

Token: hel
      -- Top tokens --
   hel: 0.5729809403419495
   (: 0.13473165035247803
   Dutch: 0.12364538013935089
   op: 0.10515786707401276
   : 0.06348417699337006

Token: Son
      -- Top tokens --
   i: 0.8694928884506226
   own: 0.0381438173353672
   Giorg: 0.03432023897767067
   rec: 0.0301945973187685
   Son: 0.02784840390086174

Token: ham
      -- Top tokens --
   ham: 0.33561453223228455
   @@: 0.23042529821395874
   rich: 0.16570591926574707
   a

In [22]:
# Decode from hidden-state index 16: at most 10 tokens, low temperature,
# stop on a newline token, and print the top-k candidates at each step.
output_16 = generate_text(
    model,
    tokenizer,
    prompt,
    max_num_tokens=10,
    layer=16,
    temperature=0.1,
    stop_words=['\n'],
    enable_logging=True,
)
print(f'layer 16: {output_16}')

Token: Außer
      -- Top tokens --
   Außer: 0.7033793926239014
   ̄: 0.09360671788454056
   Bedeut: 0.09052415937185287
   pur: 0.06295628100633621
   Prem: 0.04953346773982048

Token: gr
      -- Top tokens --
   gr: 0.4242660701274872
   tie: 0.26104244589805603
   p: 0.12375162541866302
   poj: 0.12263013422489166
   norm: 0.06830968707799911

Token: péri
      -- Top tokens --
   péri: 0.8730854392051697
   statunit: 0.040330689400434494
   xxx: 0.032812707126140594
   Orange: 0.031317807734012604
   hab: 0.02245336025953293

Token: Um
      -- Top tokens --
   amb: 0.28458425402641296
   hat: 0.21934747695922852
   odkazy: 0.19791728258132935
   cia: 0.180925190448761
   Um: 0.11722587794065475

Token: idense
      -- Top tokens --
   idense: 0.3420819044113159
   bre: 0.326993852853775
   endi: 0.17018598318099976
   ass: 0.08770638704299927
   bers: 0.073031947016716

Token: ur
      -- Top tokens --
   ur: 0.9688129425048828
   ek: 0.02209775149822235
   curity: 0.00491807935

In [23]:
# Decode from hidden-state index 24: at most 10 tokens, low temperature,
# stop on a newline token, and print the top-k candidates at each step.
output_24 = generate_text(
    model,
    tokenizer,
    prompt,
    max_num_tokens=10,
    layer=24,
    temperature=0.1,
    stop_words=['\n'],
    enable_logging=True,
)
print(f'layer 24: {output_24}')

Token: in
      -- Top tokens --
   in: 0.9529890418052673
   .: 0.046198099851608276
   …: 0.0008117356919683516
   today: 9.878916671368643e-07
   and: 8.411416985154574e-08

Token: bold
      -- Top tokens --
   bold: 0.9978393316268921
   in: 0.0020528656896203756
   capital: 6.939901504665613e-05
   exp: 2.5015280698426068e-05
   print: 1.3429616046778392e-05

Token: ital
      -- Top tokens --
   ital: 0.5508857369422913
   bold: 0.37772342562675476
   font: 0.07139083743095398
   Ital: 7.815165048441486e-08
   letters: 3.5846399182304367e-09

Token: ital
      -- Top tokens --
   ital: 0.9998922348022461
   ics: 0.00010779927833937109
   bold: 1.0148318851932459e-09
   font: 2.600940574026822e-10
   Ital: 3.535142664712332e-11

Token: ital
      -- Top tokens --
   ital: 0.9992073178291321
   Ital: 0.0007926259422674775
   Ital: 3.913533497268418e-09
   ics: 6.642618041990261e-14
   bold: 2.3764170020918254e-14

Token: ital
      -- Top tokens --
   ital: 0.9977964162826538
   I

In [27]:
# Decode from hidden-state index 32: at most 10 tokens, low temperature,
# stop on a newline token, and print the top-k candidates at each step.
output_32 = generate_text(
    model,
    tokenizer,
    prompt,
    max_num_tokens=10,
    layer=32,
    temperature=0.1,
    stop_words=['\n'],
    enable_logging=True,
)
print(f'layer 32: {output_32}')

Token: K
      -- Top tokens --
   K: 0.9287389516830444
   J: 0.030214475467801094
   L: 0.02070675790309906
   D: 0.014593163505196571
   A: 0.005746637936681509

Token: atie
      -- Top tokens --
   atie: 0.9969188570976257
   yle: 0.0026733395643532276
   atherine: 0.0004066026012878865
   ait: 7.016427048256446e-07
   else: 4.5767313849864877e-07

Token: and
      -- Top tokens --
   and: 0.9799004197120667
   .: 0.018081780523061752
   ,: 0.002017837017774582
   K: 8.309279913437662e-13
   H: 3.2420905695350333e-13

Token: I
      -- Top tokens --
   I: 1.0
   this: 1.7975648503336036e-16
   my: 7.999980100584632e-19
   i: 4.267395762701532e-23
   in: 2.2425649066771086e-24

Token: am
      -- Top tokens --
   am: 0.9986220598220825
   ’: 0.0010929569834843278
   ': 0.00028478875174187124
   have: 1.4038313622677379e-07
   live: 4.1338052980677276e-09

Token: a
      -- Top tokens --
   a: 0.9999997615814209
   the: 2.706961481635517e-07
   : 5.0966402653784826e-08
   an: 1.8923

In [28]:
# Side-by-side summary of the text decoded from each hidden-state index.
print(
    f'layer 8: {output_8}, \n'
    f'layer 16: {output_16}, \n'
    f'layer 24: {output_24}, \n'
    f'layer 32: {output_32}'
)

layer 8: <s> My name is Prem Prem Singh hel Sonhamanaço AJAX, 
layer 16: <s> My name is Außergr péri Umidenseurclesoft™, 
layer 24: <s> My name is in bold ital ital ital ital ital Ital ital Ital, 
layer 32: <s> My name is Katie and I am a 20 year
