In [1]:
import sys
import datasets
from transformers import AutoTokenizer
sys.path.append("..")
from babilong_utils import TaskDataset, SentenceSampler, NoiseInjectionDataset
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ### extract dataset archive
# !unzip ../data/tasks_1-20_v1-2.zip -d ../data/

In [3]:
!ls ../data/tasks_1-20_v1-2/en-10k/

qa10_indefinite-knowledge_test.txt   qa1_single-supporting-fact_test.txt
qa10_indefinite-knowledge_train.txt  qa1_single-supporting-fact_train.txt
qa11_basic-coreference_test.txt      qa20_agents-motivations_test.txt
qa11_basic-coreference_train.txt     qa20_agents-motivations_train.txt
qa12_conjunction_test.txt	     qa2_two-supporting-facts_test.txt
qa12_conjunction_train.txt	     qa2_two-supporting-facts_train.txt
qa13_compound-coreference_test.txt   qa3_three-supporting-facts_test.txt
qa13_compound-coreference_train.txt  qa3_three-supporting-facts_train.txt
qa14_time-reasoning_test.txt	     qa4_two-arg-relations_test.txt
qa14_time-reasoning_train.txt	     qa4_two-arg-relations_train.txt
qa15_basic-deduction_test.txt	     qa5_three-arg-relations_test.txt
qa15_basic-deduction_train.txt	     qa5_three-arg-relations_train.txt
qa16_basic-induction_test.txt	     qa6_yes-no-questions_test.txt
qa16_basic-induction_train.txt	     qa6_yes-no-questions_train.txt
qa17_positional-reasoning_test.

In [4]:
# bAbI task to evaluate; the train/test file paths in the next cell are built from this name
task = "qa2_two-supporting-facts"

In [5]:
# bAbI train/test files for the selected task (en-10k directory of the extracted archive)
babi_dir = "../data/tasks_1-20_v1-2/en-10k"
train_path = f"{babi_dir}/{task}_train.txt"
test_path = f"{babi_dir}/{task}_test.txt"

# PG-19 books provide the background ("noise") text the task facts are hidden in.
# pg19 is loaded through a dataset script; `datasets` warned that trust_remote_code=True
# will be mandatory for such datasets, so it is passed explicitly.
noise_dataset_name = "pg19"
noise_dataset = datasets.load_dataset(noise_dataset_name, trust_remote_code=True)
noise_dataset_train = noise_dataset['train']
noise_dataset_test = noise_dataset['test']

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


### Load task datasets

In [6]:
# bAbI task data for both splits, loaded from the files of the selected task
task_dataset_train, task_dataset_test = (TaskDataset(path) for path in (train_path, test_path))

In [7]:
# Background (noise) text: the GPT-2 tokenizer plus one shuffled sampler per split.
# The test sampler gets a fixed seed (42), presumably for a reproducible eval set; the train sampler is unseeded.
tokenizer = AutoTokenizer.from_pretrained('gpt2')

shared_sampler_kwargs = {'tokenizer': tokenizer, 'shuffle': True}
noise_sampler_train = SentenceSampler(noise_dataset_train, random_seed=None, **shared_sampler_kwargs)
noise_sampler_test = SentenceSampler(noise_dataset_test, random_seed=42, **shared_sampler_kwargs)

In [8]:
# Token budget per sample. The training config uses 480; 20 tokens are subtracted here
# (460 = 480 - 20), presumably to leave room for the question/answer.
sample_size = [460]

train_dataset = NoiseInjectionDataset(
    task_dataset=task_dataset_train,
    noise_sampler=noise_sampler_train,
    tokenizer=tokenizer,
    sample_size=sample_size,
)

# Test set: a single fixed size (first entry of `sample_size`), mixed_length_ratio=0.0,
# and no task_start_pct / task_end_pct bounds.
test_only_kwargs = dict(mixed_length_ratio=0.0, task_start_pct=None, task_end_pct=None)
test_dataset = NoiseInjectionDataset(
    task_dataset=task_dataset_test,
    noise_sampler=noise_sampler_test,
    tokenizer=tokenizer,
    sample_size=sample_size[0],
    **test_only_kwargs,
)

In [9]:
# Inspect which fields NoiseInjectionDataset returns for one training sample
sample = train_dataset[0]
sample.keys()

dict_keys(['facts', 'question', 'answer', 'references', 'background_text', 'fact_positions', 'input_tokens', 'question_tokens', 'target_tokens'])

In [10]:
# Facts of the sample, their recorded positions, and the question
for fact in sample['facts']:
    print(fact)
print("fact position:", sample['fact_positions'])
print("question:", sample['question'])
print("\nBACKGROUND:")

# First 20 background chunks (token id sequences) decoded back to text, one quoted line each
background_text = tokenizer.batch_decode(sample['background_text'][:20])
for chunk in background_text:
    print("'" + chunk + "',")

Mary moved to the bathroom.
Sandra journeyed to the bedroom.
Mary got the football there.
John went to the kitchen.
Mary went back to the kitchen.
Mary went back to the garden.
fact position: [3 5]
question: Where is the football? 

BACKGROUND:
'Dear Sir,

    I am much obliged to you for the compliment you make me in
    thinking my approbation of any value, to tell you the truth the
    reason of my setting so little value on it myself, proceeds not so
    much from modesty, or an opinion that I cannot feel the powers of
    Poetry, or distinguish beauties from defects, but from a
    consciousness that I am unable to determine (as all excellence in
    comparative) what rank it ought to hold in the scale of Art; and
    this judgement can be possess'd I think by those only who are
    acquainted with what the world has produced of that kind.',
'I have lately had the pleasure of reading your Poem to several
    friends, who have spoken much in its commendation, and Mr.',
'Johnson
   

### Visualize one sample

In [11]:
# Field listing repeated at the start of the visualization section
sample.keys()

dict_keys(['facts', 'question', 'answer', 'references', 'background_text', 'fact_positions', 'input_tokens', 'question_tokens', 'target_tokens'])

In [12]:
# Human-readable view of one sample: facts, question, answer, supporting references,
# and the full decoded model input.
facts = sample['facts']
question = sample['question']
answer = tokenizer.decode(sample['target_tokens'])
input_tokens = tokenizer.decode(sample['input_tokens'])  # decoded string, despite the name

header_lines = [
    f"Facts: {' '.join(facts)}",
    f"Question: {question}",
    f"Answer: {answer}",
    f"References: {' '.join(sample['references'])}",
    "",
]
for line in header_lines:
    print(line)
print('Fact positions: ', sample['fact_positions'])
print('Combined input: ', input_tokens)
print(f"Target: {answer}")


Facts: Mary moved to the bathroom. Sandra journeyed to the bedroom. Mary got the football there. John went to the kitchen. Mary went back to the kitchen. Mary went back to the garden.
Question: Where is the football? 
Answer: garden
References: Mary got the football there. Mary went back to the garden.

Fact positions:  [3 5]
Combined input:  Dear Sir,

    I am much obliged to you for the compliment you make me in
    thinking my approbation of any value, to tell you the truth the
    reason of my setting so little value on it myself, proceeds not so
    much from modesty, or an opinion that I cannot feel the powers of
    Poetry, or distinguish beauties from defects, but from a
    consciousness that I am unable to determine (as all excellence in
    comparative) what rank it ought to hold in the scale of Art; and
    this judgement can be possess'd I think by those only who are
    acquainted with what the world has produced of that kind.I have lately had the pleasure of reading you

### Collate function

In [22]:
import torch
from torch.nn.utils.rnn import pad_sequence

# Padding id used by collate_fn: the tokenizer's pad token if it defines one, else EOS
# (the printed output shows both resolve to 50256 here).
if tokenizer.pad_token_id is not None:
    id_pad_value = tokenizer.pad_token_id
else:
    id_pad_value = tokenizer.eos_token_id
eos_token = tokenizer.eos_token_id
# Separator placed between the input and the answer; only the first sub-token of 'GEN' is used
gen_token = tokenizer.encode('GEN')[0]
print("eos token:", eos_token, "gen token:", gen_token, "id_pad_value:", id_pad_value)

def collate_fn(batch):
    """Collate NoiseInjectionDataset samples into padded tensors for RMT evaluation/training.

    Each sample must provide 'input_tokens' and 'target_tokens' (lists of token ids) and
    'answer', 'background_text', 'facts', 'question'.

    Returns a dict with:
      - 'input_ids' / 'labels': input + [GEN] + target + [EOS], right-padded with id_pad_value
        (the same tensor; which positions are scored is given by 'labels_mask').
      - 'labels_mask': True on the last len(target) + 2 real positions ([GEN], target, [EOS]).
      - 'attention_mask': True on the real (non-padding) positions of 'input_ids'.
      - 'input_ids_generate': input + [GEN] only, right-padded; the prompt for generate().
      - 'attention_mask_generate': True on the real positions of 'input_ids_generate'.
      - 'target_text', 'background_text', 'facts', 'question': per-sample passthrough lists.
    """
    targets = [torch.tensor(b['target_tokens']) for b in batch]
    input_ids = [torch.tensor(b['input_tokens'] + [gen_token] + b['target_tokens'] + [eos_token]) for b in batch]
    gen_inputs = [torch.tensor(b['input_tokens'] + [gen_token]) for b in batch]

    attention_mask = [torch.ones_like(ids, dtype=torch.long) for ids in input_ids]
    # Build the generation mask from the real lengths rather than `gen_inputs != id_pad_value`:
    # the pad id falls back to EOS, so an EOS token inside a prompt would otherwise be masked out.
    attention_mask_generate = [torch.ones_like(ids, dtype=torch.bool) for ids in gen_inputs]
    labels_mask = [torch.zeros_like(ids, dtype=torch.bool) for ids in input_ids]
    for m, t in zip(labels_mask, targets):
        # mark [GEN] + target tokens + [EOS] at the end of each sequence
        m[-len(t) - 2:] = True

    # NOTE(review): padding is on the right; decoder-only generation with batch size > 1
    # normally needs left padding. The dataloaders in this notebook use batch size 1.
    input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
    gen_inputs = pad_sequence(gen_inputs, padding_value=id_pad_value, batch_first=True)
    attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
    attention_mask_generate = pad_sequence(attention_mask_generate, padding_value=False, batch_first=True)
    labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

    collated = {}
    collated['input_ids'] = collated['labels'] = input_ids
    collated['input_ids_generate'] = gen_inputs
    collated['labels_mask'] = labels_mask
    collated['attention_mask'] = attention_mask.bool()
    collated['attention_mask_generate'] = attention_mask_generate

    collated['target_text'] = [b['answer'] for b in batch]

    collated['background_text'] = [b['background_text'] for b in batch]
    collated['facts'] = [b['facts'] for b in batch]
    collated['question'] = [b['question'] for b in batch]

    return collated

eos token: 50256 gen token: 35353 id_pad_value: 50256


In [23]:
def example():
    """Print the GPT-2 token ids of every fact of the first test sample, then the raw facts."""
    first_sample = test_dataset[0]
    tokenized_facts = tokenizer(list(first_sample['facts']))['input_ids']
    print(tokenized_facts)
    print(first_sample['facts'])

example()

[[24119, 1392, 262, 7545, 612, 13], [7554, 3888, 284, 262, 14043, 13], [50, 15918, 1816, 736, 284, 262, 9592, 13], [24119, 21650, 284, 262, 23959, 13]]
['Mary got the milk there.' 'John moved to the bedroom.'
 'Sandra went back to the kitchen.' 'Mary travelled to the hallway.']


In [24]:
# batch = [dataset_test[i] for i in range(10)]
# collated = collate_fn(batch)
# collated.keys()

In [25]:
# #labels are marked with labels_mask
# tokenizer.batch_decode([c[m][1:-1] for c, m in zip(collated['input_ids'], collated['labels_mask'])])

In [26]:
# different input_ids for .forward() and .generate()
#tokenizer.batch_decode([c[m] for c, m in zip(collated['input_ids'], collated['attention_mask'])])

In [27]:
#tokenizer.batch_decode([c[m] for c, m in zip(collated['input_ids_generate'], collated['attention_mask_generate'])])

### Dataset wrappers from the `run_finetuning_babilong_rmt.py` script

In [28]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# DataLoaders mirroring run_finetuning_babilong_rmt.py for a single process (rank 0 of 1 replica).
kwargs = {'pin_memory': True, 'num_workers': 1, 'collate_fn': collate_fn}
per_worker_batch_size = 1
seed = 43

# Use `seed` instead of repeating the literal, so changing it above actually takes effect.
train_sampler = DistributedSampler(train_dataset, rank=0, num_replicas=1, shuffle=True, drop_last=True, seed=seed)
test_sampler = DistributedSampler(test_dataset, rank=0, num_replicas=1, drop_last=False, shuffle=False)
train_dataloader = DataLoader(batch_size=per_worker_batch_size, dataset=train_dataset, sampler=train_sampler, **kwargs)
test_dataloader = DataLoader(batch_size=per_worker_batch_size, dataset=test_dataset, sampler=test_sampler, **kwargs)

### Testing Pretrained RMT on QA2

In [29]:
from modeling_rmt.language_modeling import MemoryCell
from modeling_rmt.language_modeling import RecurrentWrapper
from transformers import AutoModelForCausalLM
import os
import numpy as np
import math

# Class paths in the "module:ClassName" form used by the training script's CLI options;
# load_pretrained_rmt below imports and uses the classes directly instead of these strings.
memory_cell_cls = "modeling_rmt.language_modeling:MemoryCell"
recurrent_wrapper_cls = "modeling_rmt.language_modeling:RecurrentWrapper"
model_cls = "transformers:AutoModelForCausalLM"
# Directory holding the fine-tuned QA2 RMT checkpoint (expects pytorch_model.bin inside)
pretrained_rmt_path = "../../runs/server/babilong_checkpoints/qa2_test/run_4/model_best"


In [30]:
import os
def load_pretrained_rmt(pretrained_rmt_path, num_memory_tokens=16, from_pretrained_base="gpt2",
                        segment_size=512, max_n_segments=1, strict=False):
    """Build a GPT-2-based Recurrent Memory Transformer and load fine-tuned weights into it.

    Args:
        pretrained_rmt_path: directory containing `pytorch_model.bin` saved from a
            RecurrentWrapper(MemoryCell(...)) model.
        num_memory_tokens, from_pretrained_base, segment_size, max_n_segments: architecture
            hyperparameters. They must match the values used in training; the defaults are the
            values this function previously hard-coded.
        strict: forwarded to `load_state_dict`. It defaults to False as before, but any
            missing/unexpected keys are now printed instead of being silently ignored.

    Returns:
        The RecurrentWrapper model (checkpoint loaded with map_location='cpu').
    """
    model = AutoModelForCausalLM.from_pretrained(from_pretrained_base, use_safetensors=False)
    cell = MemoryCell(model, num_memory_tokens)
    model = RecurrentWrapper(cell, segment_size=segment_size, max_n_segments=max_n_segments, k2=-1)

    model_cpt = os.path.join(pretrained_rmt_path, "pytorch_model.bin")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    cpt = torch.load(model_cpt, map_location='cpu')
    load_result = model.load_state_dict(cpt, strict=strict)
    # With strict=False a key-name mismatch would leave base GPT-2 weights / untrained memory
    # in place without any error, so surface it explicitly.
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"load_pretrained_rmt: {len(load_result.missing_keys)} missing and "
              f"{len(load_result.unexpected_keys)} unexpected keys when loading {model_cpt}")
        print("  missing (first 10):", load_result.missing_keys[:10])
        print("  unexpected (first 10):", load_result.unexpected_keys[:10])
    return model

device = "cuda" if torch.cuda.is_available() else "cpu"
rmt = load_pretrained_rmt(pretrained_rmt_path).to(device)

In [44]:
@torch.no_grad()
def keep_for_metrics_fn(batch, output):
    """Collect from a batch and the model output what metrics_fn needs.

    Always keeps 'labels', 'loss' and 'target_text'. If the output has 'logits', also stores the
    greedy token ids ('predictions') and, per sample, those ids at the positions selected by
    batch['labels_mask'] ('predicted_labels'). 'generation_outputs' is passed through if present.
    """
    kept = {
        'labels': batch['labels'],
        'loss': output['loss'],
        'target_text': batch['target_text'],
    }
    if 'logits' in output:
        greedy_ids = output['logits'].detach().argmax(dim=-1)
        kept['predictions'] = greedy_ids
        kept['predicted_labels'] = [row[mask] for row, mask in zip(greedy_ids, batch['labels_mask'])]
    if 'generation_outputs' in output:
        kept['generation_outputs'] = output['generation_outputs']
    return kept


def _extract_generated_answer(text, eos='<|endoftext|>'):
    """Return the first non-empty, whitespace-stripped EOS-delimited segment of `text` ('' if none).

    Works whether or not the decoded generation starts with an EOS token: both
    '<eos>answer<eos>...' and 'answer<eos><eos>...' yield 'answer'. Always taking segment [1],
    as before, returned '' for the second form.
    """
    for segment in text.split(eos):
        segment = segment.strip()
        if segment:
            return segment
    return ''


def _exact_match(targets, predictions):
    """Mean whitespace-insensitive string equality between targets and predictions."""
    return np.mean([t.strip() == p.strip() for t, p in zip(targets, predictions)])


@torch.no_grad()
def metrics_fn(data, add_generation=True):
    """Compute exact-match accuracy and perplexity from data collected by keep_for_metrics_fn.

    If 'generation_outputs' is present, the generated ids are decoded and compared with
    data['target_text']; otherwise the teacher-forced 'predicted_labels' are used.

    Args:
        data: dict with 'loss' and 'target_text', plus either 'generation_outputs' or
            'labels' / 'predictions' / 'predicted_labels'.
        add_generation: if True and generations exist, also return the post-processed
            decoded generations under 'generation_outputs'.

    Returns:
        dict with 'exact_match' (when predictions or generations exist), 'perplexity',
        and optionally 'generation_outputs'.
    """
    metrics = {}
    generation_outputs = None
    if 'generation_outputs' in data:
        decoded = tokenizer.batch_decode([d for d in data['generation_outputs']], add_special_tokens=False)
        generation_outputs = [_extract_generated_answer(o) for o in decoded]
        metrics['exact_match'] = _exact_match(data['target_text'], generation_outputs)

    elif 'predictions' in data:
        y, p = data['labels'], data['predictions']
        predicted_labels = tokenizer.batch_decode(data['predicted_labels'], add_special_tokens=False)
        # Predictions at the labelled positions run past the answer; keep text before the first EOS.
        predicted_labels = [label.split('<|endoftext|>', 1)[0] for label in predicted_labels]
        metrics['exact_match'] = _exact_match(data['target_text'], predicted_labels)
        # `args` and `logger` are notebook-level globals created in the trainer setup cells.
        if args.show_valid_examples > 0:
            for i in range(min(args.show_valid_examples, len(y))):
                logger.info(f'y: {y[i][-50:]}')
                logger.info(f'p: {p[i][-50:]}')

                logger.info(f"y_text: {data['target_text'][i]}")
                logger.info(f"p_text: {predicted_labels[i]}")

                logger.info('-' * 50)
    try:
        perplexity = math.exp(data["loss"].mean())
    except OverflowError:
        perplexity = float("inf")
    metrics["perplexity"] = perplexity

    # Only attach generations when some were produced (this previously raised NameError
    # for teacher-forced data with add_generation=True).
    if add_generation and generation_outputs is not None:
        metrics['generation_outputs'] = generation_outputs

    return metrics


    

In [32]:
from copy import deepcopy

def prepare_generate_args(default_kwargs, batch, device):
    """Build keyword arguments for model.generate() from defaults and a collated batch.

    Args:
        default_kwargs: base generation kwargs; deep-copied, never modified.
        batch: collated batch dict (see collate_fn).
        device: device that the attention masks are moved to.

    Returns:
        A new kwargs dict.
        - 'max_length' is derived from the labels length only if neither 'max_length' nor
          'max_new_tokens' was given, so the two limits are never both set.
        - The attention mask comes from 'attention_mask_generate' (it matches
          'input_ids_generate'); 'attention_mask' is the fallback.
    """
    generate_kwargs = deepcopy(default_kwargs)
    has_length_limit = 'max_length' in generate_kwargs or 'max_new_tokens' in generate_kwargs
    if not has_length_limit and 'labels' in batch:
        # generate up to the labels length + 1, as special tokens could be generated by the model
        generate_kwargs['max_length'] = batch['labels'].shape[-1] + 1
    for mask_key in ('attention_mask_generate', 'attention_mask'):
        if mask_key in batch:
            generate_kwargs['attention_mask'] = batch[mask_key].to(device)
            break
    if 'global_attention_mask' in batch:
        generate_kwargs['global_attention_mask'] = batch['global_attention_mask'].to(device)

    return generate_kwargs

@torch.no_grad()
def test_model_on_sample(model, batch, generate_kwargs=None):
    """Evaluate `model` on one collated batch: a teacher-forced forward pass plus generation.

    Args:
        model: RMT model; this call switches it to eval mode.
        batch: dict from collate_fn. It is not modified: tensors are moved to the model's
            device in a shallow copy.
        generate_kwargs: optional overrides merged into the default generation kwargs
            (max_new_tokens=10, pad_token_id=50256).

    Returns:
        metrics dict from metrics_fn, plus 'loss' (the forward-pass loss tensor).
    """
    model.eval()
    device = next(model.parameters()).device
    default_generate_kwargs = {'max_new_tokens': 10, 'pad_token_id': 50256}
    if generate_kwargs:
        default_generate_kwargs.update(generate_kwargs)
    # keys of the batch that are passed to the model's forward()
    model_forward_args = {'attention_mask', 'input_ids', 'inputs_embeds', 'labels', 'labels_mask',
                          'output_attentions', 'output_hidden_states'}
    # Shallow copy, so the caller's batch keeps its original tensors.
    batch = {k: (v.to(device) if k in model_forward_args else v) for k, v in batch.items()}

    outputs = model(**{k: v for k, v in batch.items() if k in model_forward_args})
    loss = outputs['loss']

    gen_kwargs = prepare_generate_args(default_generate_kwargs, batch, device)
    outputs['generation_outputs'] = model.generate(batch['input_ids_generate'].to(device), **gen_kwargs)

    data = keep_for_metrics_fn(batch, outputs)
    metric = metrics_fn(data)
    metric['loss'] = loss
    return metric

In [33]:
# Qualitative check: show facts, question, label and prediction for roughly the first
# `samples_limit` test samples, then report their exact-match accuracy.
samples_limit = 10
num_samples, exact_matches = 0, 0.
for batch in test_dataloader:
    batch_len = batch['input_ids'].shape[0]
    metric = test_model_on_sample(rmt, batch)
    for j in range(batch_len):
        print(f'=======SAMPLE #{num_samples + j}==========')
        print("Print ONLY FACTS:")
        for fact in batch['facts'][j]:
            print(fact)
        print("QUESTION:", batch['question'][j])
        print("LABEL:", batch['target_text'][j])
        print("PRED:", metric['generation_outputs'][j])

    exact_matches += metric['exact_match'] * batch_len
    num_samples += batch_len
    if num_samples >= samples_limit:
        break

print("accuracy:", exact_matches / num_samples)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Print ONLY FACTS:
Mary got the milk there.
John moved to the bedroom.
Sandra went back to the kitchen.
Mary travelled to the hallway.
QUESTION: Where is the milk? 
LABEL: hallway
PRED: hallway
Print ONLY FACTS:
Mary got the milk there.
John moved to the bedroom.
Sandra went back to the kitchen.
Mary travelled to the hallway.
John got the football there.
John went to the hallway.
QUESTION: Where is the football? 
LABEL: hallway
PRED: hallway
Print ONLY FACTS:
Mary got the milk there.
John moved to the bedroom.
Sandra went back to the kitchen.
Mary travelled to the hallway.
John got the football there.
John went to the hallway.
John put down the football.
Mary went to the garden.
QUESTION: Where is the football? 
LABEL: hallway
PRED: hallway
Print ONLY FACTS:
Mary got the milk there.
John moved to the bedroom.
Sandra went back to the kitchen.
Mary travelled to the hallway.
John got the football there.
John went to the hallway.
John put down the football.
Mary went to the garden.
John wen

In [47]:
@torch.no_grad()
def total_test(dataloader, model=None):
    """Evaluate a model on every batch of `dataloader`, printing running metrics.

    Args:
        dataloader: yields collated batches (see collate_fn).
        model: model to evaluate; defaults to the notebook-level `rmt` (previous behavior).

    Returns:
        dict with the sample-weighted 'exact_match' and the mean per-batch 'loss'
        (both NaN when the dataloader is empty).
    """
    if model is None:
        model = rmt
    n_batches = len(dataloader)
    exact_matches = 0.
    total_samples = 0
    losses = []
    for i, batch in enumerate(dataloader):
        bs = batch['input_ids'].shape[0]
        metric = test_model_on_sample(model, batch)

        total_samples += bs
        exact_matches += metric['exact_match'] * bs
        losses.append(metric['loss'].item())
        print(f'\rcurr_index={i+1}/{n_batches}, exact_match={exact_matches/total_samples:.3f}, E[loss]={np.mean(losses):.3f}', end="")

    # Guard against an empty dataloader (previously a ZeroDivisionError).
    exact_match = exact_matches / total_samples if total_samples else float('nan')
    mean_loss = float(np.mean(losses)) if losses else float('nan')
    print(f"\nexact_match: {exact_match:.4f} E[loss]={mean_loss:.4f}")
    return {'exact_match': exact_match, 'loss': mean_loss}

total_test(test_dataloader);

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


curr_index=999/999, exact_match=0.622, E[loss]=0.596
exact_match: 0.6216 E[loss]=0.5959


The loss is unexpectedly high and the accuracy unexpectedly low (exact match ≈ 0.62); the cause is not yet clear...

## Let's try to create a trainer like the one in `run_finetuning_babilong_rmt.py`

In [36]:
import accelerate
from transformers import HfArgumentParser
from lm_experiments_tools import Trainer, TrainerArgs, get_optimizer
%load_ext autoreload
%autoreload 2

# Accelerator without gradient accumulation; it is passed to the Trainer constructed below
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
def batch_metrics_fn(_, y):
    """Per-batch metrics for the Trainer: every output entry whose key contains 'loss' or '!log'."""
    return {key: value for key, value in y.items() if 'loss' in key or '!log' in key}


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [37]:
# CLI parser replicating run_finetuning_babilong_rmt.py, so its Trainer can be driven from here.
parser = HfArgumentParser(TrainerArgs)
parser.add_argument('--task_dataset', type=str, help="Task name", default="qa1_single-supporting-fact")
parser.add_argument('--noise_dataset', type=str, help="Name of the background (noise) text dataset", default='wikitext')
parser.add_argument('--noise_dataset_split', type=str, help="Split of the noise dataset", default=None)
parser.add_argument('--babi_path', type=str, help="path to babi folder", default="data/tasks_1-20_v1-2/en-10k")


parser.add_argument('--validate_only', action='store_true', default=False,
                    help='Skip training and run only validation. (default: False)')
parser.add_argument('--working_dir', type=str, default='.',
                    help='working dir, should be a dir with t5-experiments repo (default: .)')
parser.add_argument('--seed', type=int, default=42, help='random seed')
parser.add_argument('--show_valid_examples', type=int, default=0,
                    help='how many valid examples to show during training (default: 0)')
parser.add_argument('--block_size', type=int, default=128, help='max size of language modeling block')
parser.add_argument('--history_size', type=int, default=0, help='max number of past tokens for each block')
parser.add_argument('--data_n_workers', type=int, default=2, help='number of dataloader workers (default: 2)')

parser.add_argument('--input_prefix', type=str, default='', help='add task prefix to an input string (default: "")')

# model args
parser.add_argument('--from_pretrained', type=str, help='model name in HF Model Hub (default: "")')
parser.add_argument('--model_cfg', type=str, help='path to model configuration file (default: "")')
parser.add_argument('--model_cls', type=str, default='transformers:BertForPreTraining',
                    help='model class name to use (default: transformers:BertForPreTraining)')
parser.add_argument('--memory_cell_cls', type=str, default=None, help='cell class for RMT')
parser.add_argument('--recurrent_wrapper_cls', type=str, default=None, help='recurrent wrapper class for RMT')
parser.add_argument('--model_cpt', type=str, default=None, help='pretrained model checkpoint path')
parser.add_argument('--model_type', type=str, default='encoder-decoder',
                    help='model type, encoder, encoder-decoder, decoder, affects preprocessing '
                         '(default: encoder-decoder)')

# Babilong parameters
parser.add_argument('--sample_size', type=int, default=None, help='max number of tokens in sample')
parser.add_argument('--max_n_facts', type=int, default=None, help='drop samples with higher number of facts')
parser.add_argument('--task_start_pct', type=float, default=None, help='left border of facts in sample, between 0 and 1')
parser.add_argument('--task_end_pct', type=float, default=None, help='right border of facts in sample, between task_start_pct and 1')


# RMT args
parser.add_argument('--segment_size', type=int, default=None, help='maximal input size of the backbone model')
parser.add_argument('--num_mem_tokens', type=int, default=None, help='number of memory tokens.')
parser.add_argument('--max_n_segments', type=int, default=1, help='maximal segment number')
parser.add_argument('--vary_n_segments', action='store_true', default=False, help='randomly sample input size for each batch')
parser.add_argument('--mixed_length_ratio', type=float, default=0.0, help='used for mixed length curriculum. '
                    'r > 0.0 means that we will start to sample batches with lengths <= max_n_segments')
parser.add_argument('--bptt_depth', type=int, default=-1, help='max number of previous segments in gradient computation.')
parser.add_argument('--segment_alignment', type=str, help='way of aligning segments, one of right, left, center', default=None)
parser.add_argument('--k2', type=int, default=-1, help='number of last segments used by backward')
parser.add_argument('--freeze_model_weights', action='store_true', default=False,
                    help='Stop training all model weights except memory layers')
parser.add_argument('--backbone_cpt', type=str, default=None, help='backbone model checkpoint path')

# tokenizer
# todo: add wordpiece tokenizers support?
parser.add_argument('--tokenizer', type=str, default=None, help='path or name of pre-trained HF Tokenizer')

# optimizer args
parser.add_argument('--optimizer', type=str, default='AdamW', help='optimizer name: AdamW, Adafactor. (default: AdamW)')
parser.add_argument('--weight_decay', type=float, default=0.0, help='optimizer weight decay (default: 0.0)')
parser.add_argument('--scale_parameter', action='store_true', default=False,
                    help='Adafactor scale_parameter (default: False)')
parser.add_argument('--relative_step', action='store_true', default=False,
                    help='Adafactor relative_step (default: False)')
parser.add_argument('--warmup_init', action='store_true', default=False,
                    help='Adafactor warmup_init (default: False)')

# LoRA args
parser.add_argument('--use_lora', action='store_true', default=False, help='')
parser.add_argument('--lora_attn_dim', type=int, default=8, help='')
parser.add_argument('--lora_attn_alpha', type=int, default=32, help='')
parser.add_argument('--lora_dropout', type=float, default=0.1, help='')
parser.add_argument('--layers_pattern', type=str, default=None, help='')

# Parallel Adapter args
parser.add_argument('--use_adapter', action='store_true', default=False, help='')
parser.add_argument('--adapter_bottleneck_dim', type=int, default=512, help='')
parser.add_argument('--adapter_dropout', type=float, default=0.1, help='')
parser.add_argument('--adapter_scale', type=float, default=4.0, help='')

# Dataset args
parser.add_argument('--pile_subset_names', type=str, default=None, help='use only these subsets of The PILE, separated by ;')
parser.add_argument('--min_tokens_in_document', type=int, default=None, help='do not use documents shorter than this value')
parser.add_argument('--max_tokens_in_document', type=int, default=None, help='do not use documents longer than this value')

# NOTE(review): --model_path points at a qa1 run directory although --task_dataset is qa2, and
# --model_cpt uses ../runs/... while pretrained_rmt_path uses ../../runs/... — confirm both refer
# to the intended checkpoint.
arg_string = "--validate_only --task_dataset qa2_two-supporting-facts --noise_dataset pg19 --babi_path data/tasks_1-20_v1-2/en-10k --model_path ../runs/babilong/qa1_single-supporting-fact/gpt2/lr1e-04_linear_adamw_wd1e-03_2x128_mem10_bs16_bptt--1_vary/run_1 --from_pretrained gpt2 --model_type decoder --memory_cell_cls modeling_rmt.language_modeling:MemoryCell --recurrent_wrapper_cls modeling_rmt.language_modeling:RecurrentWrapper --model_cls transformers:AutoModelForCausalLM --segment_size 512 --sample_size 480 --num_mem_tokens 16 --max_n_segments 1 --batch_size 2 --gradient_accumulation_steps 1 --num_training_steps 4000 --iters 5000 --save_best --k2 -1 --optimizer AdamW --weight_decay 0.01 --lr 1e-05 --lr_scheduler linear --num_warmup_steps 500 --data_n_workers 1 --log_interval 50 --valid_interval 250 --optimize_metric exact_match --optimize_mode max --show_valid_examples 5 --early_stopping_patience 15 --seed 43 --clip_grad_norm 1.0 --model_cpt ../runs/server/babilong_checkpoints/qa2_test/run_4/model_best --vary_n_segments --use_generate_on_valid"
args = parser.parse_args(arg_string.split())

In [38]:
optimizer_name = args.optimizer
optimizer_cls = get_optimizer(optimizer_name)
if optimizer_cls is None:
    raise RuntimeError(f'{optimizer_name} was not found in optimizers, torch.optim, transformers.optimization')

# todo: group optimizer params (weight decay is currently applied to every parameter alike)
optimizer_hparams = dict(lr=args.lr, weight_decay=args.weight_decay)
optimizer = optimizer_cls(rmt.parameters(), **optimizer_hparams)

In [45]:
def _metrics_without_generation(d):
    """metrics_fn variant that leaves decoded generations out of the returned metrics."""
    return metrics_fn(d, add_generation=False)

trainer_generate_kwargs = {"pad_token_id": id_pad_value, "max_new_tokens": 10}
trainer = Trainer(
    args,
    accelerator,
    rmt,
    optimizer,
    train_dataloader,
    test_dataloader,
    keep_for_metrics_fn=keep_for_metrics_fn,
    metrics_fn=_metrics_without_generation,
    batch_metrics_fn=batch_metrics_fn,
    generate_kwargs=trainer_generate_kwargs,
)

In [46]:
from accelerate.logging import get_logger
# metrics_fn logs validation examples through a notebook-level `logger` (teacher-forced branch,
# when args.show_valid_examples > 0), so it has to exist before validation runs.
logger = get_logger('')
trainer.validate(test_dataloader, write_tb=True, split='test')

Validation:   0%|                                       | 0/999 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Validation: 100%|█████████████████████████████| 999/999 [01:09<00:00, 14.46it/s]


{'loss': 0.5959115713413048,
 'exact_match': 0.6216216216216216,
 'perplexity': 1.8146843893715991}