## Run model and collect predictions on BABILong

In [1]:
import os
os.chdir('..')
os.environ['CUDA_VISIBLE_DEVICES'] = '6'

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets
from tqdm.auto import tqdm
import pandas as pd
import time
import json

from pathlib import Path

from babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input
from babilong.babilong_utils import compare_answers

from modeling_amt.language_modeling import *

In [3]:
class Holder:
    def __init__(self) -> None:
        pass

args = Holder()
# args.task_dataset = "qa1_single-supporting-fact"
# args.max_n_facts = 50
args.vary_n_segments = False
args.max_n_segments = 128
args.segment_size = 1024
args.sample_size = 512
args.segment_alignment = 'right'

base_model_path = "/home/jovyan/kuratov/models/Llama-3.2-1B-Instruct/"
dtype = torch.bfloat16
device = 'auto'

In [4]:
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True,
                                             device_map=device, 
                                             torch_dtype=dtype,
                                             use_cache=False,
                                             attn_implementation='flash_attention_2')
model = model.eval()

In [5]:
from peft import get_peft_model, LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    inference_mode=False, 
    r=8, 
    lora_alpha=32, 
    lora_dropout=0.1
    )
model = get_peft_model(model, peft_config)

In [6]:
mem_cell_args = dict(
    base_model=model.cpu(),
    num_mem_tokens=16,
)
mem_cell_args['d_mem'] = 64
mem_cell_args['wrap_pos'] = False
mem_cell_args['correction'] = False
mem_cell_args['layers_attr'] = "base_model.base_model.layers"

cell = AssociativeMemoryCell(**mem_cell_args)
model = AssociativeRecurrentWrapper(cell, 
                                segment_size=args.segment_size,
                                max_n_segments=args.max_n_segments, 
                                vary_n_segments=args.vary_n_segments,
                                segment_alignment='right',
                                k2=-1,
                                return_all_logits=False,
)

model = model.to(torch.bfloat16)  # Ensure all layers use bfloat16


In [7]:
# ls /home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v3/run_1

In [12]:
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_1x1024_mem16_bs64_bptt--1_from_cpt_0-1_lora/run_1/checkpoint-7500/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora/run_1/checkpoint-8000/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora/run_1/checkpoint-8000/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct/run_1/checkpoint-5000/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v2/run_1/checkpoint-2500/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora_ct-v3/run_1/checkpoint-15500/pytorch_model.bin"

checkpoints = [
    "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_1x1024_mem16_bs64_bptt--1_from_cpt_0-1_lora_ct-v3/run_1/checkpoint-6500/pytorch_model.bin",
    "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora_ct-v3/run_1/checkpoint-28000/pytorch_model.bin",
    "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v3/run_1/checkpoint-24500/pytorch_model.bin",
    "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_8x1024_mem16_bs64_bptt--1_from_cpt_4-8_lora_ct-v3/run_1/checkpoint-30000/pytorch_model.bin",
    # "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_16x1024_mem16_bs64_bptt--1_from_cpt_8-16_lora_ct-v3/run_1/checkpoint-6000/pytorch_model.bin",
]

eval_model_template = 'armt-llama3.2-1b-{}-ct-v3-retrain-align_right-prompt-v2'

dataset_name = "RMT-team/babilong"
results_folder = "/home/jovyan/rmt/babilong/babilong_evals/"

In [13]:
generate_kwargs = {
    'max_new_tokens': 30,
    'max_length': None,
    'num_beams': 1,
    'do_sample': False,
    'temperature': None,
    'top_p': None,
    'top_k': None,
    'pad_token_id': tokenizer.pad_token_id
}

if generate_kwargs['pad_token_id'] is None:
    generate_kwargs['pad_token_id'] = tokenizer.eos_token_id

# print(f'prompt template:\n{DEFAULT_TEMPLATE}')

In [None]:
tasks = ['qa1', 'qa2', 'qa3', 'qa4', 'qa5']
# split_names = ['0k', '1k', '2k', '4k', '8k', '16k']
split_names = ['32k', '64k']

use_chat_template = True
use_instruction = False
use_post_prompt = False
use_examples = False

for cpt_path in checkpoints[:-1]:
    eval_model_name = eval_model_template.format(cpt_path.split('adamw_wd1e-03_')[1].split('_')[0])
    print('Evaluating ', eval_model_name)

    with open(cpt_path, 'rb') as cpt:
        weights = torch.load(cpt)

    model.load_state_dict(weights)
    model.cuda()
    # model = model.eval()

    device = 'cuda:0'

    for task in tqdm(tasks, desc='tasks'):
        # configure the prompt
        prompt_cfg = {
            'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
            'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
            'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
            'template': DEFAULT_TEMPLATE,
            'chat_template': use_chat_template,
        }
        prompt_name = [f'{k}_yes' if prompt_cfg[k] else f'{k}_no' for k in prompt_cfg if k != 'template']
        prompt_name = '_'.join(prompt_name)

        for split_name in tqdm(split_names[::-1], desc='lengths'):
            # load dataset
            data = datasets.load_dataset(dataset_name, split_name)
            task_data = data[task]

            # Prepare files with predictions, prompt, and generation configurations
            outfile = Path(f'{results_folder}/{eval_model_name}/{task}_{split_name}_{prompt_name}.csv')
            outfile.parent.mkdir(parents=True, exist_ok=True)
            cfg_file = f'{results_folder}/{eval_model_name}/{task}_{split_name}_{prompt_name}.json'
            json.dump({'prompt': prompt_cfg, 'generate_kwargs': generate_kwargs}, open(cfg_file, 'w'), indent=4)

            df = pd.DataFrame({'target': [], 'output': [], 'question': []})

            for sample in tqdm(task_data, desc=f'task: {task} length: {split_name}'):
                target = sample['target']
                context = sample['input']
                question = sample['question']

                # format input text
                input_text = get_formatted_input(context, question + 'Answer with a single word.', prompt_cfg['examples'],
                                                    prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                                    template=prompt_cfg['template'])

                if use_chat_template:
                    input_text = [{'role': 'user', 'content': input_text}]
                    model_inputs = tokenizer.apply_chat_template(input_text, add_generation_prompt=True,
                                                                    return_tensors='pt').to(device)
                    model_inputs = {'input_ids': model_inputs}
                else:
                    model_inputs = tokenizer(input_text, return_tensors='pt',
                                                add_special_tokens=True).to(device)

                sample_length = model_inputs['input_ids'].shape[1]
                # with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                with torch.no_grad():
                    attn_mask = torch.ones_like(model_inputs['input_ids'].bool().to(device))
                    output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)

                output = output[0]#[sample_length:]
                output = tokenizer.decode(output, skip_special_tokens=True).strip()

                # 1/0

                df.loc[len(df)] = [target, output, question]
            # write results to csv file
            df.to_csv(outfile)

Evaluating  armt-llama3.2-1b-1x1024-ct-v3-retrain-align_right-prompt-v2


tasks:   0%|          | 0/5 [00:00<?, ?it/s]

lengths:   0%|          | 0/2 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '64k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/64k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Wed Jul 24 10:04:31 2024).


task: qa1 length: 64k:   0%|          | 0/100 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '32k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/32k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Wed Jul 24 14:22:30 2024).


task: qa1 length: 32k:   0%|          | 0/100 [00:00<?, ?it/s]

lengths:   0%|          | 0/2 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '64k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/64k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Wed Jul 24 10:04:31 2024).


task: qa2 length: 64k:   0%|          | 0/100 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '32k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/32k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Wed Jul 24 14:22:30 2024).


task: qa2 length: 32k:   0%|          | 0/100 [00:00<?, ?it/s]

lengths:   0%|          | 0/2 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '64k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/64k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Wed Jul 24 10:04:31 2024).


task: qa3 length: 64k:   0%|          | 0/100 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '32k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/32k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Wed Jul 24 14:22:30 2024).


task: qa3 length: 32k:   0%|          | 0/100 [00:00<?, ?it/s]

In [19]:
input_text

[{'role': 'user',

### Interpret

In [190]:
use_chat_template = True
use_instruction = False
use_post_prompt = False
use_examples = False

cpt_path = checkpoints[-1]
task = 'qa1'

split_name = '4k'

model.rmt_config['segment_alignment'] = 'right'
eval_model_name = eval_model_template.format(cpt_path.split('adamw_wd1e-03_')[1].split('_')[0])
print('Evaluating ', eval_model_name)

with open(cpt_path, 'rb') as cpt:
    weights = torch.load(cpt)

model.load_state_dict(weights)
model.cuda()
model = model.eval()

device = 'cuda:0'

# configure the prompt
prompt_cfg = {
    'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
    'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
    'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
    'template': DEFAULT_TEMPLATE,
    'chat_template': use_chat_template,
}
prompt_name = [f'{k}_yes' if prompt_cfg[k] else f'{k}_no' for k in prompt_cfg if k != 'template']
prompt_name = '_'.join(prompt_name)

# load dataset
data = datasets.load_dataset(dataset_name, split_name)
task_data = data[task]

# Prepare files with predictions, prompt, and generation configurations
outfile = Path(f'{results_folder}/{eval_model_name}/{task}_{split_name}_{prompt_name}.csv')
outfile.parent.mkdir(parents=True, exist_ok=True)
cfg_file = f'{results_folder}/{eval_model_name}/{task}_{split_name}_{prompt_name}.json'
# json.dump({'prompt': prompt_cfg, 'generate_kwargs': generate_kwargs}, open(cfg_file, 'w'), indent=4)

df = pd.DataFrame({'target': [], 'output': [], 'question': [], 'input': [], 
                   'n_segments': [], 'segments': []})

inputs = []
for sample in tqdm(task_data.select(range(20)), desc=f'task: {task} length: {split_name}'):
    target = sample['target']
    context = sample['input']
    question = sample['question']

    # format input text
    input_text = get_formatted_input(context, question, prompt_cfg['examples'],
                                        prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                        template=prompt_cfg['template'])

    if use_chat_template:

        template = "{} {} Answer with a single word."
        # context = tokenizer.decode(sample['input_tokens'])
        input_text = [
            {"role": "user", "content": template.format(context, question)},
        ]
        # input_text = [{'role': 'user', 'content': input_text}]

        model_inputs = tokenizer.apply_chat_template(input_text, add_generation_prompt=True,
                                                        return_tensors='pt').to(device)
        model_inputs = {'input_ids': model_inputs}
    else:
        model_inputs = tokenizer(input_text, return_tensors='pt',
                                    add_special_tokens=True).to(device)

    sample_length = model_inputs['input_ids'].shape[1]
    with torch.cuda.amp.autocast():
        with torch.no_grad():
            attn_mask = torch.ones_like(model_inputs['input_ids'].bool().to(device))
            output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)
            # we need to reset memory states between samples for activation-beacon models
            # if 'activation-beacon' in model.name_or_path and hasattr(model, 'memory'):
            #     model.memory.reset()

    output = output[0]#[sample_length:]
    output = tokenizer.decode(output, skip_special_tokens=True).strip()
    # 1/0
    inp = tokenizer.decode(model_inputs['input_ids'][0].cpu(), add_special_tokens=True)

    segments = model.segment(input_ids=model_inputs['input_ids'])
    seg_texts = [tokenizer.decode(s['input_ids'][0], add_special_tokens=True) for s in segments]

    df.loc[len(df)] = [target, output, question, inp, len(segments), '\n------------\n'.join(seg_texts)]
# write results to csv file

# df.to_csv(outfile)

Evaluating  rmt-llama3.2-1b-8x1024-ct-v3-retrain-align_right


Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '4k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/4k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Thu Aug  1 17:59:53 2024).


task: qa1 length: 4k:   0%|          | 0/20 [00:00<?, ?it/s]

In [181]:
df

Unnamed: 0,target,output,question,input,n_segments,segments
0,bathroom,!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
1,kitchen,kitchen,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,3,<|begin_of_text|><|start_header_id|>system<|en...
2,kitchen,kitchen,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
3,kitchen,!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
4,bedroom,bedroom,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
5,office,office,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
6,garden,!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
7,bathroom,bathroom,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
8,kitchen,kitchen,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
9,bedroom,bedroom,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...


In [191]:
df

Unnamed: 0,target,output,question,input,n_segments,segments
0,bathroom,bathroom,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
1,kitchen,kitchen,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,3,<|begin_of_text|><|start_header_id|>system<|en...
2,kitchen,kitchen,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
3,kitchen,kitchen,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
4,bedroom,bedroom,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
5,office,office,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
6,garden,garden,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
7,bathroom,bathroom,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
8,kitchen,kitchen,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
9,bedroom,bedroom,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...


In [246]:
sample = task_data[16]

target = sample['target']
context = sample['input']
question = sample['question']

# format input text
input_text = get_formatted_input(context, question, prompt_cfg['examples'],
                                    prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                    template=prompt_cfg['template'])

if use_chat_template:

    template = "{} {} Answer with a single word."
    # context = tokenizer.decode(sample['input_tokens'])
    input_text = [
        {"role": "user", "content": template.format(context, question)},
    ]
    # input_text = [{'role': 'user', 'content': input_text}]

    model_inputs = tokenizer.apply_chat_template(input_text, add_generation_prompt=True,
                                                    return_tensors='pt').to(device)
    model_inputs = {'input_ids': model_inputs}
else:
    model_inputs = tokenizer(input_text, return_tensors='pt',
                                add_special_tokens=True).to(device)

sample_length = model_inputs['input_ids'].shape[1]

In [247]:
with torch.cuda.amp.autocast():
    with torch.no_grad():
        attn_mask = torch.ones_like(model_inputs['input_ids'].bool().to(device))
        output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)
        # we need to reset memory states between samples for activation-beacon models
        # if 'activation-beacon' in model.name_or_path and hasattr(model, 'memory'):
        #     model.memory.reset()

output = output[0]#[sample_length:]
tokenizer.decode(output, skip_special_tokens=False)

'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'

In [250]:
target

'kitchen'

In [248]:
with torch.cuda.amp.autocast():
    with torch.no_grad():
        attn_mask = torch.ones_like(model_inputs['input_ids'].bool().to(device))
        output = model(**model_inputs)#, attention_mask=attn_mask)
        # output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)5            output = model(**model_inputs, attention_mask=attn_mask)

In [249]:
tokenizer.decode(output.logits.argmax(dim=-1)[0][:])

"Tags\n\n\n\n\n\nIting back Space\n\n  2022\nI's: December0 of 2023\nLet<|start_header_id|>system<|start_header_id|>\n\nLet the timeI best of the two, the two of the the, theest, for twokustering of the most of of the two of of of the the of the beginningthe of of of of the world of to of to of of or of most ofof the, variety of of the two world of of the only the ofof word of the case of<|eot_id|>'s to the lab<|eot_id|><|eot_id|> and, you same of the is is and of thetheasesuriousbers of of is the the, or is same of of ofof the thousand of of of of from of the of the many of the,the of<|eot_id|> room of the which of the, not in in thethe of of the is a that to of of the first way ofof of the of of one of the one of to be a thousand tothere of1,<|eot_id|> if of the two to of the to ofof to are to the and the to of the the, the the first are and firstthe number to of the to to to make the on ininside the the first of of of to of to to, to are no to toto of1,<|eot_id|> is to the other<|eo

In [251]:
tokenizer.decode(output.logits.argmax(dim=-1)[0][-1024*3 -1:])

",Tags k havek this such kbedimateskkbedkkkkkbedkbedkkkbedilitykbed kbedbedkbedbed bedbedbedbedbedbedbed Kbedbedbedbedkkkbedbed arebedbedbedbed kbedbedbedk k k of k kbed k and to to the city, \n\nbedODbed K k k kkbededbedbed kbedbedkbedbed the bedbedbed and I arebedbed kbed of K bedk k and onbedbedbed thebeded ofbed K Kbedbed k thek k k kkbedbedbedbedbedbedbed k K kbed a k bed ofbedbed kbedbed k bed world ofbed k bedk ofbed the bedbed ofbedbedbedbed kbed the bedbedbedbedbed is a k of be bebedbed k of bedkbed of the Kbedbed the k of thebedbed k isbed ak of k I is notbed k k are thebed k ofbedbedbed k the k to the we bed k ofbedbed kbedbed of k on k K not a about the kbedkbed I thek K I bed k arebed k with to the bed of k the kk k of I bed k ofbed in of sidebed of andbed, thebedbed kbedbed and I arekplacementedbed arebed sincebed kbed kbed the of kbedbedbed of k K kbed k of the bed kkARD K k and k thebed arebedbed k ink K k of the kbedkk the k, bed of by k ofkbed bed ofbed k k of to k k 

In [239]:
output.loss

0

In [237]:
tokenizer.decode(output.logits.argmax(dim=-1)[0][-2049:])

"bedimportk k kk kbedbed bed bedbed bed the bedbedbedbed bybedbedbedbedbedbedbedbedbedbedbedbednowledbedbedbedbed ofbedbedbedbedbedbed of ofbedbedbedbedbed ed edk kbedkbedbed though bed bedbedbed bedbedbed kbedbedbedbedkbed bed of thebededbedbed k intobed k ofbed bedbedbedbed ofbed bedbedarden bed of thebedbedbedbed a k of kbedbed we bedbed bedbedbed bedbedbedbedbed bed kbedbedbedbedbedbedbedbedbedbedbedbedbedbedbedbedbedbedbed ofbedbedbedbedbed k bedbedbedbedbed kbedbedbedbed bybed kbedbedbed bedbed k ofbed bedbedbedbedkbed least kbed bedk k kbed bedbedbedbedbed kbedbed ofbedbed kbedbedbedbedbedbedbedbed and \n\n much bed k of arebedbedbedbedbedbedbedbedbed ofbedbed thebedbedbedbed bed k frombedbedbed bedbed of k k isbedbedbed edbedbed bedbedbedbedbeded of bedbedbed ofdbedbedbedbedbedbedbedbed bedbed bedbedbedbedbedbedbedbedbedbed to bedbedbed ofbedbedbedbedbedbedbedbedbedbedbedkbed the k kbedbed ofbed bed k frombedbed I kbed bedbed bedbedbedbed kbed k bed to tokbed ofbedbedbedbedbed 

In [238]:
tokenizer.decode(model_inputs['input_ids'][0][-2049:])

'\nwhich remind us of days far earlier. Edward the Fourth and Richard the\nThird were chosen Kings, or at least had their claims to the Crown\nacknowledged, by gatherings of the citizens of London which remind us\nof the wars of Stephen and Matilda(59). Still even in this age, the\npower of Parliament was advancing(60); the anxiety of every pretender\nto get a parliamentary sanction for his claims was a sign of the\ngrowing importance of Parliament, and we get incidental notices which\nshow that a seat in the House of Commons, and that not as a knight of a\nshire, but as a burgess of a borough, was now an object of ambition for\nmen of the class from which knights of the shire were chosen, and even\nfor the sons of members of the Upper House(61). At last came the sixteenth century, the time of trial for parliamentary\ninstitutions in so many countries of Europe. Daniel travelled to the kitchen. Not a few assemblies which\nhad once been as free as our own Parliament were, during that ag

In [179]:
context

'If all this effort and\nexpenditure had resulted in success, it would be possible to keep\nsilent and shrug one\'s shoulders; but when the mode of undertaking\nthis expedition can be clearly shown to have been the direct cause of\nits failure, silence would be a crime. When Lord Wolseley told the\nsoldiers at Korti on their return from Metemmah, "It was not _your_\nfault that Gordon has perished and Khartoum fallen," the positiveness\nof his assurance may have been derived from the inner conviction of\nhis own stupendous error. The expedition was finally sanctioned in August, and the news of its\ncoming was known to General Gordon in September, before, indeed, his\nown despatches of 31st July were received in London, and broke the\nsuspense of nearly half a year. He thought that only a small force was\ncoming, under the command of Major-General Earle, and he at once, as\nalready described, sent his steamers back to Shendy, there to await\nthe troops and convey them to Khartoum. He see

In [51]:
df

Unnamed: 0,target,output,question,input,n_segments,segments
0,bathroom,bathroom,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
1,kitchen,kitchen,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,3,<|begin_of_text|><|start_header_id|>system<|en...
2,kitchen,kitchen,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
3,kitchen,!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
4,bedroom,bedroom,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
5,office,!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
6,garden,garden,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
7,bathroom,bathroom,Where is Sandra?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
8,kitchen,kitchen,Where is Mary?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...
9,bedroom,bedroom,Where is John?,<|begin_of_text|><|start_header_id|>system<|en...,4,<|begin_of_text|><|start_header_id|>system<|en...


In [None]:
for i, row in df.iterrows():
    print(i, row.question, row.output, row.segments)
    print('\n\n\n\n')

In [31]:
segments = model.segment(input_ids=model_inputs['input_ids'])

In [32]:
[s['input_ids'].shape for s in segments]

[torch.Size([1, 594]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024])]