In [1]:
import torch
from transformers import GPT2Tokenizer
from GPTNeoWithIntermediates import GPTNeoWithIntermediates
from BatchTokenizer import WikiBatchTokenizer
from datasets import load_dataset

In [2]:
# Load the "wikitext-103-v1" configuration of Salesforce/wikitext and keep
# only its "test" split.
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1")["test"]

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# The tokenizer and the model are loaded from the same pretrained checkpoint.
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
model = GPTNeoWithIntermediates.from_pretrained("EleutherAI/gpt-neo-125M")

# Place the model on the device selected above.
model.to(device)

In [4]:
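# Batches needed to cover every sample: int(sample_count / batch_size) + 1 (the +1 applies only when sample_count is not an exact multiple of batch_size).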

In [5]:
def iterative_inference_with_intermediates(dataset, batch_size, selected_layers, max_batches=2):
    """Run the model batch by batch and print the greedy decoding of every sample.

    The function relies on the notebook-level ``model`` and ``device`` objects.
    For each batch it tokenizes the text through ``WikiBatchTokenizer``, runs a
    forward pass without gradients, takes the arg-max token at every position of
    ``outputs["logits"]`` and prints the decoded text.

    Args:
        dataset: Dataset handed to ``WikiBatchTokenizer``.
        batch_size: Batch size handed to ``WikiBatchTokenizer``; must be positive.
        selected_layers: Indices into ``outputs["layer_outputs"]`` to keep.
            ``None`` selects every layer.
        max_batches: Upper bound on the number of batches to process. Defaults
            to 2, which matches the previously hard-coded loop length. Pass
            ``None`` to process every batch.

    Returns:
        None. The tokenized batch and the decoded texts are printed.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    wiki_tokenizer = WikiBatchTokenizer(dataset=dataset, batch_size=batch_size)
    # NOTE(review): assumes sort_text(ret_len=True) returns the number of
    # samples as an int -- confirm against WikiBatchTokenizer.
    sample_count = wiki_tokenizer.sort_text(ret_len=True)

    # Ceiling division: a trailing partial batch still counts as one batch, and
    # a perfectly divisible sample count does not get a spurious extra batch.
    total_batches = (sample_count + batch_size - 1) // batch_size
    batch_count = total_batches if max_batches is None else min(max_batches, total_batches)

    # `batch_index` instead of `iter` so the builtin iter() is not shadowed;
    # gen_batch still receives it through its `iter` keyword.
    for batch_index in range(batch_count):
        tokenized_batch = wiki_tokenizer.gen_batch(iter=batch_index).to(device)
        input_ids, attention_mask = tokenized_batch.input_ids, tokenized_batch.attention_mask

        print(tokenized_batch)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)

        # Greedy choice: the highest-scoring token id at every position.
        predicted_ids = torch.argmax(outputs["logits"], dim=-1)
        generated_texts = [
            wiki_tokenizer.tokenizer.decode(ids, skip_special_tokens=True)
            for ids in predicted_ids
        ]

        if selected_layers is None:
            selected_layers = list(range(len(outputs["layer_outputs"])))

        # One mapping of layer index -> that layer's output for the whole batch.
        # NOTE(review): returning per-sample results is currently disabled. If
        # it is re-enabled, each entry has to be sliced per sample; the layout
        # of outputs["layer_outputs"] is not visible here -- confirm first.
        selected_outputs = {
            layer_index: layer_output
            for layer_index, layer_output in enumerate(outputs["layer_outputs"])
            if layer_index in selected_layers
        }

        # Iterate over what was actually decoded rather than range(batch_size),
        # so a final batch shorter than batch_size does not raise IndexError.
        for text in generated_texts:
            print(text)
            print(" ------------------------------------------------------------- > \n")

        # Drop every local that still references this batch's tensors;
        # otherwise empty_cache() cannot hand their memory back.
        del outputs, selected_outputs, predicted_ids
        del input_ids, attention_mask, tokenized_batch
        torch.cuda.empty_cache()
        

In [6]:
import time

In [8]:
# Measure elapsed time with perf_counter(): unlike time.time(), it is monotonic
# and is not affected by system clock adjustments during the run.
t = time.perf_counter()
iterative_inference_with_intermediates(dataset=dataset, batch_size=3, selected_layers=[6])
print(time.perf_counter() - t)

{'input_ids': tensor([[  383,  2615,  2512,   286,  6937,   479, 16628,   318,   262,  2063,
          2488,    12,    31,  2665,   366,   406,   366,  3127,   837, 13160,
           286,   257,  2168, 47472,  1168,   837,   290,   257,   427,  2797,
          6178, 47912,   575,    13,   383,   366,   479,   366,   287,   366,
          6937,   479,   366,   318,   262,  1988,  1813,   416,   837,   220,
           198, 50256, 50256, 50256, 50256],
        [49521, 26615,   373,  4642,   287,  8533,   837,  3936,   319,  2805,
           678,   837, 12122,   764,  2399,  3397, 25107,   618,   339,   373,
           838,   837,   290,   339,   373,  4376,   416,   465,  2802,   764,
           679,  2826,   287,  6205,  5701, 16861,   355,   257,  1200,   837,
          1390,  7703,  4041, 17362,   764,   220,   198, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256],
        [  554,  8275,   837, 46810,  6149,   257,  1862, 13459,   805,   837,
          1279,  2954,    2