In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import json
import logging
import os
import math
import shutil
from pathlib import Path
from itertools import chain

# from dotenv import load_dotenv
import torch
import numpy as np
import datasets
import transformers
from torch.utils.data import DataLoader
from huggingface_hub import hf_hub_download


from transformers import AutoConfig, AutoTokenizer, HfArgumentParser, AutoModelForCausalLM

from modeling_rmt.language_modeling import MemoryCell, RecurrentWrapper
from modeling_amt.language_modeling import AssociativeMemoryCell, AssociativeRecurrentWrapper
from modeling_rmt.lm_parallel_mem import MemoryCell as PMemoryCell, RecurrentWrapper as PRecurrentWrapper

from torch.nn.utils.rnn import pad_sequence

In [2]:
nums_segments = [1024]

segment_size = 512
block_size = segment_size

test_size = 1000
batch_size = 1

model_name = 'armt'
num_mem_tokens = 10
model_path = './models'
d_mem = 64

task_dataset = "qa1_single-supporting-fact"
max_n_facts = None

In [3]:
if model_name == 'armt':
    mem_cell_cls = AssociativeMemoryCell
    rec_wrap_cls = AssociativeRecurrentWrapper
elif model_name == 'rmt':
    mem_cell_cls = MemoryCell
    rec_wrap_cls = RecurrentWrapper
elif model_name == 'prmt':
    mem_cell_cls = PMemoryCell
    rec_wrap_cls = PRecurrentWrapper
if model_name != 'armt':
    d_mem = None

In [4]:

model = AutoModelForCausalLM.from_pretrained("gpt2")
mem_cell_args = dict(
        base_model=model,
        num_mem_tokens=num_mem_tokens,
)
if d_mem is not None:
    mem_cell_args['d_mem'] = d_mem


cell = mem_cell_cls(**mem_cell_args, wrap_pos=False)
model = rec_wrap_cls(cell, segment_size=block_size, k2=-1)

model_cpt = os.path.join(model_path, "model_best/pytorch_model.bin")
cpt = torch.load(model_cpt, map_location='cuda')
model.load_state_dict(cpt, strict=False)
model.eval().to('cuda')

AssociativeRecurrentWrapper(
  (memory_cell): AssociativeMemoryCell(
    (model): GPT2LMHeadModel(
      (transformer): GPT2Model(
        (wte): Embedding(50257, 768)
        (wpe): Embedding(1024, 768)
        (drop): Dropout(p=0.1, inplace=False)
        (h): ModuleList(
          (0-11): 12 x AssociativeLayerWrapper(
            (W_mq): Linear(in_features=768, out_features=64, bias=False)
            (W_mk): Linear(in_features=768, out_features=64, bias=False)
            (W_mv): Linear(in_features=768, out_features=768, bias=False)
            (W_mb): Linear(in_features=768, out_features=1, bias=True)
            (layer): GPT2Block(
              (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
              (attn): GPT2Attention(
                (c_attn): Conv1D()
                (c_proj): Conv1D()
                (attn_dropout): Dropout(p=0.1, inplace=False)
                (resid_dropout): Dropout(p=0.1, inplace=False)
              )
              (ln_2): LayerNor

In [5]:
from babilong_utils import TaskDataset, NoiseInjectionDataset, SentenceSampler

babi_path = "/home/rodkin/lab/t5-experiments/data/tasks_1-20_v1-2/en-10k"

noise_dataset = datasets.load_dataset("pg19")
noise_dataset_test = noise_dataset['test']
test_path = os.path.join(babi_path, f"{task_dataset}_test.txt")
task_dataset_test = TaskDataset(test_path, max_n_facts=max_n_facts)
qa_margin = 20
tokenizer = AutoTokenizer.from_pretrained("gpt2")
for i, num_segments in enumerate(nums_segments):
    sample_size_ = segment_size * num_segments
    sample_size = sample_size_ - qa_margin
    noise_sampler_test = SentenceSampler(noise_dataset_test, tokenizer=tokenizer, max_sentence_len=None, shuffle=True, random_seed=42)
test_dataset = NoiseInjectionDataset(task_dataset=task_dataset_test,
                                     noise_sampler=noise_sampler_test,
                                     tokenizer=tokenizer,
                                     sample_size=sample_size,
                                     mixed_length_ratio=0,
                                     task_start_pct=None,
                                     task_end_pct=None
                                    )

In [6]:
from torch.utils.data.distributed import DistributedSampler

id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
gen_token = tokenizer.encode('GEN')[0]
eos_token = tokenizer.eos_token_id

def collate_fn(batch):
        targets = [torch.tensor(b['target_tokens']) for b in batch]
        input_ids = [torch.tensor(b['input_tokens'] + b['question_tokens'] + [gen_token] + b['target_tokens'] + [eos_token]) for b in batch]
        gen_inputs = [torch.tensor(b['input_tokens'] + b['question_tokens'] + [gen_token]) for b in batch]

        attention_mask = [torch.ones_like(b, dtype=int) for b in input_ids]
        labels_mask = [torch.zeros_like(b, dtype=bool) for b in input_ids]
        for m, t in zip(labels_mask, targets):
            m[-len(t) - 2:] = True

        input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
        gen_inputs = pad_sequence(gen_inputs, padding_value=id_pad_value, batch_first=True)
        attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
        labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

        collated = {}
        collated['input_ids'] = collated['labels'] = input_ids
        collated['input_ids_generate'] = gen_inputs
        collated['labels_mask'] = labels_mask
        collated['attention_mask'] = attention_mask.bool()
        collated['attention_mask_generate'] = (gen_inputs != id_pad_value).bool()
        collated['target_text'] = [b['answer'] for b in batch]
        return collated

test_dataloader = DataLoader(batch_size=batch_size, dataset=test_dataset, collate_fn=collate_fn, pin_memory=True)


In [7]:
from tqdm.auto import tqdm
device = 'cuda'
metrics = []
bar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
with torch.no_grad():
  for j, batch in bar:
    inp_gen = batch['input_ids_generate'].to(device)
    print(inp_gen.shape)
    gen_mask = batch['attention_mask_generate'].to(device)
    p = model.generate(input_ids=inp_gen, attention_mask=gen_mask, max_new_tokens=10)
    generation_outputs = tokenizer.batch_decode([d for d in p], add_special_tokens=False)
    for i, o in enumerate(generation_outputs):
        if '<|endoftext|>' in o:
            generation_outputs[i] = o.split('<|endoftext|>')[1].strip()
    y = batch['target_text']
    exact_match = np.mean([text == pred for text, pred in zip (y, generation_outputs)])
    metrics.append(exact_match)
    bar.set_description(f"current exact match: {np.mean(metrics)}")

metric = np.mean(metrics)


  0%|          | 0/999 [00:00<?, ?it/s]

torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Token indices sequence length is longer than the specified maximum sequence length for this model (2108 > 1024). Running this sequence through the model will result in indexing errors


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([1, 524274])


In [8]:
def last_log_forward(model, input_ids, attention_mask):
    model.memory_cell.zero_mem()
    segmented = model.segment(input_ids=input_ids, attention_mask=attention_mask)

    for seg_num, segment in enumerate(segmented[:-1]):
        _ = model.memory_cell(**segment, output_hidden_states=True, zero_mem=False)

    final_segment = segmented[-1]
    out = model.memory_cell(**final_segment, zero_mem=False)
    model.memory_cell.zero_mem()
    return out

In [9]:
from tqdm.auto import tqdm

metrics = []
bar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
device = 'cuda'
with torch.no_grad():
  for j, batch in bar:
    inp = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)
    output = last_log_forward(model, inp, mask)
    predictions = torch.argmax(output['logits'], dim=-1)
    labels_mask = model.segment(labels_mask=batch['labels_mask'])[-1]['labels_mask']

    predicted_labels = [p[m] for p, m in zip(predictions, labels_mask)]
    predicted_labels = tokenizer.batch_decode(predicted_labels, add_special_tokens=False)
    for i, l in enumerate(predicted_labels):
        if '<|endoftext|>' in l:
            eos_ind = predicted_labels[i].index('<|endoftext|>')
            predicted_labels[i] = predicted_labels[i][:eos_ind]
    y = batch['target_text']
    exact_match = np.mean([text == pred for text, pred in zip (y, predicted_labels)])
    metrics.append(exact_match)
    bar.set_description(f"current exact match: {np.mean(metrics)}")

metric = np.mean(metrics)

  0%|          | 0/63 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.56 GiB (GPU 0; 10.91 GiB total capacity; 8.58 GiB already allocated; 1016.06 MiB free; 9.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [9]:
import numpy as np
import wandb
run = wandb.init(project='associative_retrieval', entity='irodkin', name=f'{model_name}')
table = wandb.Table(data=np.vstack([nums_pairs, metrics]).T, columns=['evaluated_on', 'test/exact_match'])
mem_cap = wandb.Table(data=np.vstack([nums_pairs, pairs_in_mem]).T, columns=['evaluated_on', 'test/pairs_in_mem'])


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mirodkin[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
line = run.plot_table("wandb/line/v0", table, {"x":'evaluated_on', "y":'test/exact_match'})
line_cap = run.plot_table("wandb/line/v0", mem_cap, {"x":'evaluated_on', "y":'test/pairs_in_mem'})
wandb.log({f'exact_match_on_pairs': line})
wandb.log({f'capacity_on_pairs': line_cap})
wandb.finish()