## Prepare Data + Model

In [None]:
# Print the raw text sample that is later loaded as the forward dataset.
!cat examples/data/text_forward.txt

In [None]:
# List the checkpoint directory that the model and tokenizer are loaded from.
!ls -al /mnt/data4/made_workspace/newlm-output/elmo-bert-causal-en.1-percent-rerun/model

In [None]:
import torch
from newlm.lm.elmo.modeling_elmo.elmo_head import ELMOBertLMHeadModel
from newlm.lm.elmo.lm_builder import ELMOLMBuilder
from transformers import BertConfig

#### Model

In [None]:
# Load the ELMOBertLMHeadModel weights from the local checkpoint directory.
pretrained_dir = "/mnt/data4/made_workspace/newlm-output/elmo-bert-causal-en.1-percent-rerun/model"
model = ELMOBertLMHeadModel.from_pretrained(pretrained_dir)

In [None]:
# Put the model in evaluation mode before the forward/reverse comparison
# (in PyTorch this turns off training-only behaviour such as dropout).
model.eval()
print("Model in eval mode for consistency")

#### Data

In [None]:
%%capture

from newlm.utils.file_util import read_from_yaml
# lm builder (helper)
elmo_lm_builder = ELMOLMBuilder(
    model_config = read_from_yaml('examples/configs/run.1-percent-bert-causal.yaml'),
    tokenizer="/mnt/data4/made_workspace/newlm-output/elmo-bert-causal-en.1-percent-rerun/model",
    model_type="bert-causal-elmo"
)

# dataset-forward
train_path = "./examples/data/text_forward.txt"
ds_f = elmo_lm_builder._get_dataset(train_path)

In [None]:
# trainer (helper)
from transformers import TrainingArguments, Trainer
from newlm.utils.file_util import read_from_yaml

# Load the run config here so this cell does not depend on a `config_file`
# variable that no earlier cell defines (the YAML is otherwise only passed
# inline to the LM builder).
# NOTE(review): assumes the trainer args live in the same run config that
# the LM builder was created from -- confirm.
config_file = read_from_yaml('examples/configs/run.1-percent-bert-causal.yaml')
args = TrainingArguments(output_dir="tmpout", **config_file['lm']['hf_trainer']['args'])

# dataloader-forward: the Trainer is only used as a convenient way to get a
# dataloader wired up with the builder's data collator.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=elmo_lm_builder.data_collator,
    train_dataset=ds_f,
)
dl_f = trainer.get_train_dataloader()  # Data Loader-forward

In [None]:
# Take the first batch from the forward dataloader.
batch_f = next(iter(dl_f))
# Display the shape of the batch's token-id tensor.
batch_f['input_ids'].shape

## Sanity Check

In [None]:
# batch_f

In [None]:
import torch


def _reverse_inner_tokens(seqs):
    """Return a copy of `seqs` with the interior of every row reversed.

    The first and last position of each row stay where they are; only
    positions 1 .. -2 are flipped. Rows are handled independently, so the
    result keeps the same (batch, sequence) shape as the input.

    Args:
        seqs: 2-D tensor laid out as (batch, sequence).

    Returns:
        A new tensor with the same shape as `seqs`.
    """
    # With fewer than two positions there is no distinct first/last pair;
    # concatenating the slices would duplicate the single token.
    if seqs.size(1) < 2:
        return seqs.clone()
    return torch.cat(
        (seqs[:, :1], torch.flip(seqs[:, 1:-1], [1]), seqs[:, -1:]),
        dim=1,
    )


# The first/last positions are kept fixed, presumably because they hold
# special boundary tokens -- confirm against the token table.
# NOTE(review): if rows are padded, the padding is flipped into the interior
# while the other batch keys (e.g. an attention mask) are left untouched --
# confirm the batch is unpadded.

# reverse input
batch_f_input = torch.clone(batch_f['input_ids'])
batch_f_rev_input = _reverse_inner_tokens(batch_f_input)

# reverse labels (flipped the same way so they stay aligned with the inputs)
batch_f_labels = torch.clone(batch_f['labels'])
batch_f_rev_labels = _reverse_inner_tokens(batch_f_labels)

# Shallow copy: every other key is shared with the forward batch, only the
# two reversed tensors are swapped in.
batch_rev = batch_f.copy()
batch_rev['input_ids'] = batch_f_rev_input
batch_rev['labels'] = batch_f_rev_labels

In [None]:
import pandas as pd

# Decode the first sequence of the forward and reversed batches back into
# token strings and show them side by side to eyeball the reversal.
to_tokens = elmo_lm_builder.tokenizer.convert_ids_to_tokens
tokens_f = to_tokens(batch_f['input_ids'][0])
tokens_f_rev = to_tokens(batch_rev['input_ids'][0])

pd.DataFrame({"forward": tokens_f, "reverse": tokens_f_rev})

In [None]:
# Show the forward and reversed input_ids shapes side by side for comparison.
batch_f['input_ids'].shape, batch_rev['input_ids'].shape

In [None]:
# Forward-direction sanity pass. This is an inference-only check, so
# gradient tracking is disabled to avoid building an autograd graph.
# NOTE(review): constructing a Trainer may move `model` to an accelerator
# while `batch_f` stays on CPU -- confirm both are on the same device.
with torch.no_grad():
    res = model(**batch_f) # forward

In [None]:
# Reversed-direction sanity pass on the batch whose inner tokens were
# flipped. Inference only, so gradient tracking is disabled.
# NOTE(review): this overwrites `res` from the forward pass -- use a separate
# name if the two outputs need to be compared.
with torch.no_grad():
    res = model(**batch_rev) # reverse