In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from transformers import AutoConfig
from transformers.optimization import get_scheduler

# from src.datasets import ARDataset
import datasets
from src.utils import create_noisy_ar_tokenizer

from src.utils import ObjectView, get_cls_by_name, get_optimizer, get_fn_param_names

from collections import defaultdict
from copy import deepcopy
import contextlib
import numpy as np
import tqdm
import time
from itertools import chain


In [3]:
args = ObjectView(dict(
    seed = 42,
    save_path = "/home/user36/metamem/runs",

    # LM model args
    arch = "gpt_neox",
    hidden_size = 128,
    num_hidden_layers = 4,
    num_attention_heads = 4,

    # AR dataset args
    num_symbols = 16,
    key_size = 2,
    value_size = 1,
    num_pairs = 16,
    ar_mode = "remember",

    pretrain_size = 100000,
    train_size = 100000,
    valid_size = 1000,
    test_size = 10000,
    data_n_workers = 4,

    # meta memory args
    num_mem_tokens = 4,
    use_lora = False,
    max_inner_iter = 1000,
    inner_target_loss = 0.0,

    # train args
    iters = 10000,
    log_interval = 100,
    valid_interval = 500,
    batch_size = 128,
    gradient_accumulation_steps = 1,

    # optimizer args
    inner_optimizer = "SGD",
    inner_lr = 1e-3,
    inner_momentum = 0.9,
    inner_weight_decay = 1e-2,
    nesterov = True,

    optimizer = "AdamW",
    lr = 3e-4,
    weight_decay = 1e-2,
    lr_scheduler = "linear",
    # num_warmup_steps = 1000,

    best_metric_value = 1.0,
    optimize_mode = 'max',
))

args['num_warmup_steps'] = args['iters'] // 10
args_cp = args

In [4]:
def collate_fn(batch, tokenizer):

    query = [item['context'] + item['query'] + item['target'] for item in batch]
    query_input_ids = tokenizer(query, return_tensors="pt", add_special_tokens=False,
                                padding=True, pad_to_multiple_of=8).input_ids
    # target_ids = tokenizer([item['target'] for item in batch], return_tensors="pt", add_special_tokens=False)

    # add labels_mask
    # input_seq: 0, target_seq: 1, seq = input_seq + target_seq
    labels_mask = torch.zeros_like(query_input_ids)
    for i, item in enumerate(batch):
        # context_seq_len = len(item['context'])
        query_seq_len = len(item['query']) + len(item['context'])
        target_seq_len = len(item['target'])
        labels_mask[i, query_seq_len:query_seq_len + target_seq_len] = 1

    labels = query_input_ids * labels_mask + (1 - labels_mask) * -100
    return {
        'input_ids': query_input_ids,
        'labels': labels,
        'labels_mask': labels_mask.bool(),
        # 'target_ids': target_ids,
    }

In [None]:
kwargs = {'pin_memory': True, 'num_workers': args.data_n_workers}

ar_dataset = datasets.load_dataset("yurakuratov/N2-K4V4-S1_16-32_1M")
pretrain_dataset = ar_dataset["train"].select(range(100_000))
valid_dataset = ar_dataset["valid"]

train_rnd_generator = torch.Generator()
train_rnd_generator.manual_seed(args.seed)
per_worker_batch_size = args.batch_size * args.gradient_accumulation_steps
kwargs = {'pin_memory': True, 'num_workers': args.data_n_workers}


tokenizer = create_noisy_ar_tokenizer()
pad_token_id = tokenizer.convert_tokens_to_ids('[PAD]')
bos_token_id = tokenizer.convert_tokens_to_ids('[BOS]')
eos_token_id = tokenizer.convert_tokens_to_ids('[EOS]')

collate_fn_caller = lambda batch: collate_fn(batch, tokenizer)

pretrain_dataloader = DataLoader(pretrain_dataset, batch_size=per_worker_batch_size, generator=train_rnd_generator,
                                 collate_fn=collate_fn_caller, **kwargs)
valid_dataloader = DataLoader(valid_dataset, batch_size=per_worker_batch_size,
                        collate_fn=collate_fn_caller, **kwargs)

Downloading readme:   0%|          | 0.00/453 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/45.9M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/231k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [37]:

# %run ../src/configs/base_models/create_config.py --arch gpt_neox --hidden_size 128 --num_hidden_layers 4 --num_attention_heads 4

cfg_path = f"/home/user36/metamem/src/configs/base_models/exp/{args.arch}_tiny_{args.num_hidden_layers}l{args.num_attention_heads}hd{args.hidden_size}.json"
model_cfg = AutoConfig.from_pretrained(cfg_path)

model_cls = get_cls_by_name(f"transformers:{model_cfg.architectures[0]}")
model = model_cls(config=model_cfg)
# sd = torch.load("/home/user36/metamem/runs/neox_ar_simple.pth")
# model.load_state_dict(sd)
args = args_cp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer_cls = get_optimizer(args.optimizer)
optimizer = optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer, args.num_warmup_steps, args.iters * 2)

In [38]:
# for batch in pretrain_dataloader:
#     break

# for k, v in batch.items():
#     batch[k] = v.to(device)

# outputs = model(
#     input_ids=batch['input_ids'],
#     labels=batch['labels'],
#     labels_mask=batch['labels_mask'],
# )

In [39]:
ignore_token_ids = [tokenizer.convert_tokens_to_ids(t) for t in ['!', '|']] + [-100]

def compute_accuracy(logits, input_ids, labels_mask):
    preds = torch.argmax(logits.detach(), dim=-1)
    for t_id in ignore_token_ids:
        labels_mask &= (input_ids != t_id)

    
    shift_labels = input_ids[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous().argmax(dim=-1)

    shift_mask = labels_mask[..., 1:].contiguous()

    # target_values = [p[m] for p, m in zip(input_ids.cpu(), labels_mask.cpu())]
    # pred_labels = [p[m] for p, m in zip(preds.cpu(), labels_mask.cpu())]


    target_values = [p[m] for p, m in zip(shift_labels.cpu(), shift_mask.cpu())]
    pred_labels = [p[m] for p, m in zip(shift_logits.cpu(), shift_mask.cpu())]

    # print(target_values[:16])
    # print(pred_labels[:16])

    correct = np.sum([torch.all(text == pred).cpu().item() for text, pred in zip(target_values, pred_labels)])
    total = len(target_values)
    acc = 100.0 * correct / total if total > 0 else 0.0

    return acc, correct, total

In [40]:
def validate(model, dataloader, device):
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in dataloader:
            # move to device
            for k, v in batch.items():
                batch[k] = v.to(device)
            outputs = model(
                input_ids=batch['input_ids'],
                labels=batch['labels'],
                labels_mask=batch['labels_mask'],
            )
            
            logits = outputs.logits
            labels_mask = batch['labels_mask']
            shift_labels = batch['labels'][..., 1:].contiguous()
            shift_logits = logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            shift_mask = labels_mask[..., 1:].contiguous()
            flat_labels = flat_labels[shift_mask.view(-1)]
            flat_logits = flat_logits[shift_mask.view(-1)]
        
            loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

            acc, correct, total = compute_accuracy(logits, batch['input_ids'], labels_mask)

            val_loss += loss.item()
            val_correct += correct
            val_total += total

    avg_loss = val_loss / len(dataloader)
    avg_acc = 100.0 * val_correct / val_total
    return avg_loss, avg_acc

In [41]:
from itertools import repeat

def inf_loop(dataloader):
    for loader in repeat(dataloader):
        yield from loader

def fit(model, train_dataloader, valid_dataloader, optimizer, scheduler, device, args):
    model.to(device)
    model.train()

    best_val_loss = float('inf')
    running_loss = 0.0
    running_correct = 0
    running_total = 0

    start_time = time.time()

    pbar = tqdm.auto.tqdm(total=args.iters, desc='Train')
    for step, batch in enumerate(inf_loop(train_dataloader)):
    # for step, batch in 
        pbar.update(1)

        # move to device
        for k, v in batch.items():
            batch[k] = v.to(device)

        # forward + backward + step
        outputs = model(
            input_ids=batch['input_ids'],
            labels=batch['labels'],
            labels_mask=batch['labels_mask'],
        )

        logits = outputs.logits
        labels_mask = batch['labels_mask']
        shift_labels = batch['labels'][..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()
        flat_labels = shift_labels.view(-1)
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))

        shift_mask = labels_mask[..., 1:].contiguous()
        flat_labels = flat_labels[shift_mask.view(-1)]
        flat_logits = flat_logits[shift_mask.view(-1)]
    
        loss = F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # accumulate for logging
        acc, correct, total = compute_accuracy(logits, batch['input_ids'], labels_mask)
        running_loss += loss.item()
        running_correct += correct
        running_total += total

        # train-side logging
        if step % args.log_interval == 0 or step == args.iters:
            elapsed = (time.time() - start_time) / step if step > 0 else 0
            avg_loss = running_loss / args.log_interval
            avg_acc = 100.0 * running_correct / running_total if running_total > 0 else 0.0
            print(f"[Train] Step {step:5d}/{args.iters:5d} • "
                  f"Loss: {avg_loss:.4f} • Acc: {avg_acc:5.2f}% • "
                  f"{elapsed:.3f}s/step")
            running_loss = 0.0
            running_correct = 0
            running_total = 0

        # periodic validation
        if step % args.valid_interval == 0:
            val_loss, val_acc = validate(model, valid_dataloader, device)
            print(f"⏸ [Valid] after {step} steps → "
                  f"Val Loss: {val_loss:.4f} • Val Acc: {val_acc:5.2f}%")
            # save best
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), f"{args.save_path}/model.pth")
                print(f"  ✔ New best model saved at step {step} → {args.save_path}")
        # break
        if step == args.iters:
            break

    # final validation at end if not aligned to interval
    if args.iters % args.valid_interval != 0:
        val_loss, val_acc = validate(model, valid_dataloader, device)
        print(f"⏸ [Valid] final → Val Loss: {val_loss:.4f} • Val Acc: {val_acc:5.2f}%")
        if val_loss < best_val_loss:
            torch.save(model.state_dict(), f"{args.save_path}/model.pth")
            print(f"  ✔ New best model saved at final step → {args.save_path}")


In [110]:
print(model.generate.__annotations__)

{'inputs': typing.Optional[torch.Tensor], 'generation_config': typing.Optional[transformers.generation.configuration_utils.GenerationConfig], 'logits_processor': typing.Optional[transformers.generation.logits_process.LogitsProcessorList], 'stopping_criteria': typing.Optional[transformers.generation.stopping_criteria.StoppingCriteriaList], 'prefix_allowed_tokens_fn': typing.Optional[typing.Callable[[int, torch.Tensor], list[int]]], 'synced_gpus': typing.Optional[bool], 'assistant_model': typing.Optional[ForwardRef('PreTrainedModel')], 'streamer': typing.Optional[ForwardRef('BaseStreamer')], 'negative_prompt_ids': typing.Optional[torch.Tensor], 'negative_prompt_attention_mask': typing.Optional[torch.Tensor], 'use_model_defaults': typing.Optional[bool], 'custom_generate': typing.Optional[str], 'return': typing.Union[transformers.generation.utils.GenerateDecoderOnlyOutput, transformers.generation.utils.GenerateEncoderDecoderOutput, transformers.generation.utils.GenerateBeamDecoderOnlyOutpu

In [112]:
input_ids

{'input_ids': tensor([[16, 51, 17,  8, 64, 47, 66, 49, 36,  6, 50, 68, 10, 11, 40, 65, 66, 33,
         17, 52, 69, 67, 66, 49, 36,  6, 50, 68]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1]])}

In [116]:
pretrain_dataset[0]['target']

'ghK9!|'

In [117]:
pretrain_dataset[0]["context"] + pretrain_dataset[0]["query"] + pretrain_dataset[0]['target']

'mVne8R!TGcU:ghK9!DnW|?!TGcU:ghK9!|'

In [13]:
query = pretrain_dataset[0]["context"] + pretrain_dataset[0]["query"]
input_ids = tokenizer(query, return_tensors="pt")
generate_kwargs = {"pad_token_id": -100, "max_new_tokens":10}
tokenizer.decode(model.generate(inputs=input_ids['input_ids'].cuda(), max_new_tokens=10).tolist()[0]).replace(' ', '')

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:102 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'mVne8R!TGcU:ghK9!DnW|?!TGcU:ghK9!|hhh!'

In [12]:
query = pretrain_dataset[0]["context"] + pretrain_dataset[0]["query"] + pretrain_dataset[0]['target']

input_ids = tokenizer(query, return_tensors="pt")
generate_kwargs = {"pad_token_id": -100, "max_new_tokens":10}
tokenizer.decode(model(input_ids['input_ids'].cuda()).logits.argmax(dim=-1).tolist()[0]).replace(' ', '')

'!!!!!!|!!!!m!!!!|!!!!!|K!!!ghK9!|h'

In [None]:
model.generate()

In [14]:
for batch in valid_dataloader:
    break

for k, v in batch.items():
    batch[k] = v.to(device)

outputs = model(
    input_ids=batch['input_ids'],
    labels=batch['labels'],
    labels_mask=batch['labels_mask'],
)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av