## Run model and collect predictions on BABILong

In [1]:
import os
os.chdir('..')
os.environ['CUDA_VISIBLE_DEVICES'] = '5'

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets
from tqdm.auto import tqdm
import pandas as pd
import time
import json

from pathlib import Path

from babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input
from babilong.babilong_utils import compare_answers

from modeling_amt.language_modeling import *

In [5]:
class Holder:
    def __init__(self) -> None:
        pass

args = Holder()
args.task_dataset = "qa1_single-supporting-fact"
# args.max_n_facts = 50
args.vary_n_segments = False
args.max_n_segments = 128
args.segment_size = 1024
args.sample_size = 512
args.segment_alignment = 'right'

# base_model_path = "unsloth/Llama-3.2-1B-Instruct"#"/home/jovyan/kuratov/models/Llama-3.2-1B-Instruct/"
base_model_path = "/home/jovyan/kuratov/models/Llama-3.2-1B-Instruct/"
#base_model_path = "meta-llama/Llama-3.2-1B"
dtype = torch.bfloat16
device = 'auto'

In [6]:
tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True,
                                             device_map=device, 
                                             torch_dtype=dtype,
                                             use_cache=False,
                                             attn_implementation='flash_attention_2')
if base_model_path == "meta-llama/Llama-3.2-1B":
    it_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", trust_remote_code=True)
    tokenizer.chat_template = it_tokenizer.chat_template
model = model.eval()

In [7]:
from peft import get_peft_model, LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    inference_mode=False, 
    r=8, 
    lora_alpha=32, 
    lora_dropout=0.1
    )
model = get_peft_model(model, peft_config)

In [23]:
mem_cell_args = dict(
    base_model=model.cpu(),
    num_mem_tokens=16,
)
mem_cell_args['d_mem'] = 64
mem_cell_args['wrap_pos'] = False
mem_cell_args['correction'] = False
mem_cell_args['layers_attr'] = "base_model.base_model.layers"

cell = AssociativeMemoryCell(**mem_cell_args)
model = AssociativeRecurrentWrapper(cell, 
                                segment_size=args.segment_size,
                                max_n_segments=args.max_n_segments, 
                                vary_n_segments=args.vary_n_segments,
                                segment_alignment='right',
                                k2=-1,
                                return_all_logits=False,
)

AttributeError: 'AssociativeRecurrentWrapper' object has no attribute 'get_input_embeddings'

In [8]:
# ls /home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v3/run_1

In [25]:
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_1x1024_mem16_bs64_bptt--1_from_cpt_0-1_lora/run_1/checkpoint-7500/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora/run_1/checkpoint-8000/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora/run_1/checkpoint-8000/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct/run_1/checkpoint-5000/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v2/run_1/checkpoint-2500/pytorch_model.bin"
# cpt_path = "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora_ct-v3/run_1/checkpoint-15500/pytorch_model.bin"
# checkpoints = [
#     "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_1x1024_mem16_bs64_bptt--1_from_cpt_0-1_lora_ct-v3/run_1/checkpoint-6500/pytorch_model.bin",
#     "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora_ct-v3/run_1/checkpoint-15500/pytorch_model.bin",
#     "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v3/run_1/checkpoint-2500/pytorch_model.bin",
#     "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_8x1024_mem16_bs64_bptt--1_from_cpt_4-8_lora_ct-v3/run_1/checkpoint-9000/pytorch_model.bin",
#     "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_linear_adamw_wd1e-03_16x1024_mem16_bs64_bptt--1_from_cpt_8-16_lora_ct-v3/run_1/checkpoint-6000/pytorch_model.bin",

# ]
# MODEL_CPT='/home/jovyan/rmt_it/data/pretrained_models/RMT-Llama-3.2-1B-Instruct-4x1024-mem16-lora-babilong-qa1-5_ct-v3/model.safetensors'

# checkpoints = [
#     '/home/jovyan/gkuzmin/rmt_it/data/pretrained_models/RMT-Llama-3.2-1B-Instruct-4x1024-mem16-lora-babilong-qa1-5_ct-v3/model.safetensors',
    
# ]
# checkpoints = [
#     '/home/jovyan/gkuzmin/rmt_it/data/pretrained_models/RMT-Llama-3.2-1B-Instruct-8x1024-mem16-lora-babilong-qa1-5_ct-v3.1/model.safetensors'
# ]

checkpoints = [
    # "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_1x1024_mem16_bs64_bptt--1_from_cpt_0-1_lora_ct-v3/run_1/checkpoint-6500/pytorch_model.bin",
    # "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_2x1024_mem16_bs64_bptt--1_from_cpt_1-2_lora_ct-v3/run_1/checkpoint-28000/pytorch_model.bin",
    # "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_4x1024_mem16_bs64_bptt--1_from_cpt_2-4_lora_ct-v3/run_1/checkpoint-24500/pytorch_model.bin",
    "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_8x1024_mem16_bs64_bptt--1_from_cpt_4-8_lora_ct-v3/run_1/checkpoint-30000/pytorch_model.bin",
    # "/home/jovyan/rmt/runs/test/babilong_multitask/meta-llama/Llama-3.2-1B-Instruct/lr_3e-04_d64_cosine_adamw_wd1e-03_16x1024_mem16_bs64_bptt--1_from_cpt_8-16_lora_ct-v3/run_1/checkpoint-6000/pytorch_model.bin",
]

eval_model_template = 'rmt-llama3.2-1b-{}-ct-v3-align_right'

dataset_name = "RMT-team/babilong"
results_folder = "/home/jovyan/rmt/babilong/babilong_evals/test/"
# results_folder = "/home/jovyan/gkuzmin/rmt_it/rmt_ift/babilong_evals/DEBUG_eval_versions/orig_eval_8s_best_bf16_no_autocast_ainst_right_our_loading/"

In [26]:
generate_kwargs = {
    'max_new_tokens': 30,
    'max_length': None,
    'num_beams': 1,
    'do_sample': False,
    'temperature': None,
    'top_p': None,
    'top_k': None,
    'pad_token_id': tokenizer.pad_token_id
}

if generate_kwargs['pad_token_id'] is None:
    generate_kwargs['pad_token_id'] = tokenizer.eos_token_id

# print(f'prompt template:\n{DEFAULT_TEMPLATE}')

In [27]:
model_name = "unsloth/Llama-3.2-1B-Instruct"
# model_cpt = '/home/jovyan/gkuzmin/rmt_it/data/pretrained_models/RMT-Llama-3.2-1B-Instruct-8x1024-mem16-lora-babilong-qa1-5_ct-v3.1/model.safetensors'
model_cpt = checkpoints[0]
use_peft = 1
segment_size = 1056
max_n_segments = 32
layers_attr = "model.layers"
add_question_prompt = "Answer with a single word."
segment_alignment = "right"
lora_r = 8
lora_alpha = 32
lora_dropout = 0.1
mem_size = 16
d_mem = 64
wrap_pos = False
no_correction = False
attend_to_previous_input = 0
api_url = ''

In [28]:
model_name = "/home/jovyan/kuratov/models/Llama-3.2-1B-Instruct/"
# model_name = ''
# model_name = "unsloth/Llama-3.2-1B-Instruct"


In [29]:
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if model_name == "meta-llama/Llama-3.2-1B":
    # hotfix - load 8b instruct tokenizer and reuse chat template from instruct tokenizer
    it_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.chat_template = it_tokenizer.chat_template
if not api_url:
    if len(model_cpt) != 0:
        model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,
                                                     device_map='cpu', torch_dtype=dtype,
                                                     attn_implementation='flash_attention_2')
        if use_peft:
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, 
                inference_mode=False, 
                r=lora_r, 
                lora_alpha=lora_alpha, 
                lora_dropout=lora_dropout,
                )
            model = get_peft_model(model, peft_config)
        device = "cpu"
        if d_mem is not None:
            mem_cell_args = dict(
                base_model=model,
                num_mem_tokens=mem_size,
            )
            # additional parameters for ARMT model
            mem_cell_args['d_mem'] = d_mem
            mem_cell_args['wrap_pos'] = wrap_pos
            mem_cell_args['correction'] = not(no_correction)
            mem_cell_args['use_lora'] = use_peft
            if layers_attr is not None:
                mem_cell_args['layers_attr'] = layers_attr
            # if attend_to_previous_input is not None:
            #     mem_cell_args['attend_to_previous_input'] = bool(attend_to_previous_input)
            cell = AssociativeMemoryCell(**mem_cell_args)
            model = AssociativeRecurrentWrapper(cell,
                                                segment_size=segment_size-2*mem_size,
                                                max_n_segments=max_n_segments,
                                                segment_alignment=segment_alignment,
                                                attend_to_previous_input=attend_to_previous_input,
            ).to(device)
        else:
            cell = MemoryCell(model, num_mem_tokens=mem_size)
            model = RecurrentWrapper(cell,
                                     segment_size=segment_size-2*mem_size,
                                     max_n_segments=max_n_segments,
            ).to(device)
        #model = model.to(dtype)
        try:
            cpt = torch.load(model_cpt, map_location=device)    
            model.load_state_dict(cpt, strict=True)
        except:
            # if the model saved in safetensors
            from safetensors.torch import load_model
            load_model(model, model_cpt, device=device)
            #model.load_state_dict(model_cpt, map_location=device)
        
        device = 'cuda:0'
        model.to(device)
        model.to(dtype)
        model.name_or_path = "custom_rmt" # workaround
        model.device = device
    # load the model locally if llamacpp API is not used
    else:
        try:
            print('trying to load model with flash attention 2...')
            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,
                                                         device_map='auto', torch_dtype=dtype,
                                                         attn_implementation='flash_attention_2')
        except ValueError as e:
            print(e)
            print('trying to load model without flash attention 2...')
            model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,
                                                         device_map='auto', torch_dtype=dtype)

    model = model.eval()

In [30]:
model.rmt_config

{'segment_size': 1024,
 'max_n_segments': 32,
 'segment_alignment': 'right',
 'attend_to_previous_input': 0}

In [31]:
tokenizer.eos_token_id

128009

In [32]:
# cpt

In [None]:
for n, p in model.memory_cell.named_parameters():
    print(n, p[0, :10])

memory tensor([ 0.0070, -0.0124,  0.0410, -0.0322,  0.0082,  0.0262,  0.0776,  0.0442,
        -0.0332,  0.0835], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
model.base_model.model.model.embed_tokens.weight tensor([ 0.0031,  0.0178,  0.0210, -0.0101,  0.0070,  0.0068,  0.0139,  0.0052,
         0.0094,  0.0062], device='cuda:0', dtype=torch.bfloat16)
model.base_model.model.model.layers.0.W_mq.weight tensor([-0.0320, -0.0195,  0.0037, -0.0053,  0.0206,  0.0063, -0.0074,  0.0386,
        -0.0120,  0.0060], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
model.base_model.model.model.layers.0.W_mk.weight tensor([ 0.0117,  0.0054,  0.0014, -0.0046,  0.0020, -0.0098, -0.0021, -0.0262,
        -0.0089,  0.0078], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
model.base_model.model.model.layers.0.W_mv.weight tensor([ 0.0135, -0.0090, -0.0017,  0.0042, -0.0049, -0.0041,  0.0140, -0.0074,
         0.0035, -0.0069], devi

IndexError: too many indices for tensor of dimension 1

In [33]:
# model

In [99]:
model.load_state_dict(cpt, strict=True)


<All keys matched successfully>

In [118]:
for n, p in model.memory_cell.named_parameters():
    print(n, p[0, :10])

memory tensor([ 0.0070, -0.0124,  0.0410, -0.0322,  0.0082,  0.0262,  0.0776,  0.0442,
        -0.0332,  0.0835], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
model.base_model.model.model.embed_tokens.weight tensor([ 0.0031,  0.0178,  0.0210, -0.0101,  0.0070,  0.0068,  0.0139,  0.0052,
         0.0094,  0.0062], device='cuda:0', dtype=torch.bfloat16)
model.base_model.model.model.layers.0.W_mq.weight tensor([-0.0320, -0.0195,  0.0037, -0.0053,  0.0206,  0.0063, -0.0074,  0.0386,
        -0.0120,  0.0060], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
model.base_model.model.model.layers.0.W_mk.weight tensor([ 0.0117,  0.0054,  0.0014, -0.0046,  0.0020, -0.0098, -0.0021, -0.0262,
        -0.0089,  0.0078], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SliceBackward0>)
model.base_model.model.model.layers.0.W_mv.weight tensor([ 0.0135, -0.0090, -0.0017,  0.0042, -0.0049, -0.0041,  0.0140, -0.0074,
         0.0035, -0.0069], devi

IndexError: too many indices for tensor of dimension 1

In [112]:
model.memory_cell.model.base_model.model.model.embed_tokens.weight[0:1, 0:20]

tensor([[ 0.0031,  0.0178,  0.0210, -0.0101,  0.0070,  0.0068,  0.0139,  0.0052,
          0.0094,  0.0062,  0.0052,  0.0008, -0.0009,  0.0021, -0.0247,  0.0254,
         -0.0078,  0.0267,  0.0096, -0.0137]], device='cuda:0',
       dtype=torch.bfloat16)

In [115]:
model_inputs = torch.load("tmp")
model_inputs

{'input_ids': tensor([[128000, 128006,   9125,  ...,  78191, 128007,    271]],
        device='cuda:0')}

In [109]:
generate_kwargs

{'max_new_tokens': 30,
 'max_length': None,
 'num_beams': 1,
 'do_sample': False,
 'temperature': None,
 'top_p': None,
 'top_k': None,
 'pad_token_id': 128004}

In [113]:
with torch.no_grad():
    attn_mask = torch.ones_like(model_inputs['input_ids'].bool().to(device))
    # output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)
    output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)

In [114]:
output

tensor([[    65,  78932, 128009]], device='cuda:0')

In [84]:
# for n, p in model.memory_cell.model.base_model.named_parameters():
#     print(n, p[:10])a

In [86]:
# for n, p in model.named_parameters():
#     print(n, p[0, :10])

In [87]:
# model.memory_cell.model.base_model.model.model.layers[0].layer.self_attn.q_proj.lora_A.default.weight[0:1, 0:20]

In [91]:
model.load_state_dict(cpt, strict=True)

<All keys matched successfully>

In [37]:
model.memory_cell.model.disable_adapter_layers()
# model.memory_cell.model.enable_adapter_layers()

In [38]:
tasks = ['qa1']
tasks = ['qa1', 'qa2', 'qa3', 'qa4', 'qa5']
split_names = ['0k', '1k', '2k', '4k', '8k', '16k']
split_names = ['16k']

use_chat_template = True
use_instruction = False
use_post_prompt = False
use_examples = False
#model = model.to(torch.bfloat16)
for cpt_path in checkpoints:
    eval_model_name = eval_model_template.format(cpt_path.split('Llama-3.2-1B-Instruct')[1].split('_')[0])
    print('Evaluating ', eval_model_name)
    """
    device = 'cuda:0'
    if "safetensors" in cpt_path:
        load_model(model, cpt_path, device=device)
    else:
        with open(cpt_path, 'rb') as cpt:
            weights = torch.load(cpt)

        model.load_state_dict(weights)
    model.cuda()
    model = model.to(torch.bfloat16)
    model.eval()
    """


    

    for task in tqdm(tasks, desc='tasks'):
        # configure the prompt
        prompt_cfg = {
            'instruction': DEFAULT_PROMPTS[task]['instruction'] if use_instruction else '',
            'examples': DEFAULT_PROMPTS[task]['examples'] if use_examples else '',
            'post_prompt': DEFAULT_PROMPTS[task]['post_prompt'] if use_post_prompt else '',
            'template': DEFAULT_TEMPLATE,
            'chat_template': use_chat_template,
        }
        prompt_name = [f'{k}_yes' if prompt_cfg[k] else f'{k}_no' for k in prompt_cfg if k != 'template']
        prompt_name = '_'.join(prompt_name)

        for split_name in tqdm(split_names, desc='lengths'):
            # load dataset
            data = datasets.load_dataset(dataset_name, split_name)
            task_data = data[task]

            # Prepare files with predictions, prompt, and generation configurations
            outfile = Path(f'{results_folder}/{eval_model_name}/{task}_{split_name}_{prompt_name}.csv')
            outfile.parent.mkdir(parents=True, exist_ok=True)
            cfg_file = f'{results_folder}/{eval_model_name}/{task}_{split_name}_{prompt_name}.json'
            json.dump({'prompt': prompt_cfg, 'generate_kwargs': generate_kwargs}, open(cfg_file, 'w'), indent=4)

            df = pd.DataFrame({'target': [], 'output': [], 'question': []})

            for sample in tqdm(task_data, desc=f'task: {task} length: {split_name}'):
                target = sample['target']
                context = sample['input']
                question = sample['question']

                # format input text
                input_text = get_formatted_input(context, question + 'Answer with a single word.', prompt_cfg['examples'],
                                                    prompt_cfg['instruction'], prompt_cfg['post_prompt'],
                                                    template=prompt_cfg['template'])

                if use_chat_template:
                    input_text = [{'role': 'user', 'content': input_text}]
                    model_inputs = tokenizer.apply_chat_template(input_text, add_generation_prompt=True,
                                                                    return_tensors='pt').to(device)
                    model_inputs = {'input_ids': model_inputs}
                else:
                    model_inputs = tokenizer(input_text, return_tensors='pt',
                                                add_special_tokens=True).to(device)

                sample_length = model_inputs['input_ids'].shape[1]
                #with torch.cuda.amp.autocast():
                with torch.no_grad():
                    attn_mask = torch.ones_like(model_inputs['input_ids'].bool().to(device))
                    output = model.generate(**model_inputs, **generate_kwargs, attention_mask=attn_mask)
                    # we need to reset memory states between samples for activation-beacon models
                    # if 'activation-beacon' in model.name_or_path and hasattr(model, 'memory'):
                    #     model.memory.reset()

                output = output[0]#[sample_length:]
                output = tokenizer.decode(output, skip_special_tokens=True).strip()

                df.loc[len(df)] = [target, output, question]
            # write results to csv file
                # df.to_csv(outfile)
                if df.shape[0] > 7:
                    1/0

Evaluating  rmt-llama3.2-1b-/lr-ct-v3-align_right


tasks:   0%|          | 0/5 [00:00<?, ?it/s]

lengths:   0%|          | 0/1 [00:00<?, ?it/s]

Using the latest cached version of the dataset since RMT-team/babilong couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration '16k' at /home/jovyan/.cache/huggingface/datasets/RMT-team___babilong/16k/0.0.0/ee0d588794c7ac098062ee0d247c733d62e94fe2 (last modified on Thu Aug  1 18:01:42 2024).


task: qa1 length: 16k:   0%|          | 0/100 [00:00<?, ?it/s]

ZeroDivisionError: division by zero

In [39]:
df

Unnamed: 0,target,output,question
0,bathroom,b b b b b b b b b b b b b b b b b b b b b b b ...,Where is Mary?
1,kitchen,hallhallhallhallhallhallhallhallhallhallhallha...,Where is Sandra?
2,kitchen,bedbedbedbedbedbedbedbedbedbedbedbedbedbedbedb...,Where is Mary?
3,kitchen,gggggggggggggggggggggggggggggg,Where is John?
4,bedroom,Kbedbed Kbedbed Kbedbed Kbedbed Kbedbed Kbedbe...,Where is Sandra?
5,office,gggggggggggggggggggggggggggggg,Where is John?
6,garden,bbbbbbbbbbbbbbbbbbbbbbbbbbbbbb,Where is Mary?
7,bathroom,bedbedbedbedbedbedbedbedbedbedbedbedbedbedbedb...,Where is Sandra?


In [36]:
df

Unnamed: 0,target,output,question
0,bathroom,bathroom,Where is Mary?
1,kitchen,kitchen,Where is Sandra?
2,kitchen,kitchen,Where is Mary?
3,kitchen,kitchen,Where is John?
4,bedroom,bedroom,Where is Sandra?
5,office,office,Where is John?
6,garden,garden,Where is Mary?
7,bathroom,bathroom,Where is Sandra?


In [None]:
df

Unnamed: 0,target,output,question
0,bathroom,bathroom,Where is Mary?
1,kitchen,kitchen,Where is Sandra?
2,kitchen,kitchen,Where is Mary?
3,kitchen,kitchen,Where is John?
4,bedroom,bedroom,Where is Sandra?
5,office,office,Where is John?
6,garden,garden,Where is Mary?
7,bathroom,bathroom,Where is Sandra?


In [56]:
target

'bathroom'

In [57]:
output

'bathroom'

In [25]:
segments = model.segment(input_ids=model_inputs['input_ids'])

In [32]:
[s['input_ids'].shape for s in segments]

[torch.Size([1, 594]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024]),
 torch.Size([1, 1024])]