In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

In [3]:
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from transformers import AutoConfig
from transformers.optimization import get_scheduler

from src.datasets import ARDataset
from src.utils import ObjectView, get_cls_by_name, get_optimizer, get_fn_param_names

from collections import defaultdict
from copy import deepcopy
import contextlib
import numpy as np
from tqdm import tqdm
import time
from itertools import chain


In [None]:
args = ObjectView(dict(
    seed = 42,
    save_path = "/home/user36/metamem/runs",

    # LM model args
    arch = "gpt_neox",
    hidden_size = 128,
    num_hidden_layers = 4,
    num_attention_heads = 4,

    # AR dataset args
    num_symbols = 16,
    key_size = 2,
    value_size = 1,
    num_pairs = 16,
    ar_mode = "remember",

    pretrain_size = 100000,
    train_size = 100000,
    valid_size = 1000,
    test_size = 10000,
    data_n_workers = 4,

    # meta memory args
    num_mem_tokens = 4,
    use_lora = False,
    max_inner_iter = 1000,
    inner_target_loss = 0.0,

    # train args
    iters = 10000,
    log_interval = 100,
    valid_interval = 500,
    batch_size = 128,
    gradient_accumulation_steps = 1,

    # optimizer args
    inner_optimizer = "SGD",
    inner_lr = 1e-3,
    inner_momentum = 0.9,
    inner_weight_decay = 1e-2,
    nesterov = True,

    optimizer = "AdamW",
    lr = 3e-4,
    weight_decay = 1e-2,
    lr_scheduler = "linear",
    # num_warmup_steps = 1000,

    best_metric_value = 1.0,
    optimize_mode = 'max',
))

args['num_warmup_steps'] = args['iters'] // 10
args_cp = args

In [5]:
sep_token, impl_token, gen_token, eos_token = 100, 101, 102, 103

def collate_fn(batch, valid=False):
    keys = [b['keys'] for b in batch]
    values = [b['values'] for b in batch]
    tgt_inds = [b['target_key_idx'].item() for b in batch]
    n = len(keys[0])

    bs = len(keys)
    sep_tokens = torch.ones(bs, 1) * sep_token
    impl_tokens = torch.ones(bs, 1) * impl_token
    eos_tokens = torch.ones(bs, 1) * eos_token
    gen_tokens = torch.ones(bs, 1) * gen_token
    sample = []

    for i in range(n):
        sample.append(torch.stack([k[i] for k in keys]))
        sample.append(impl_tokens)
        sample.append(torch.stack([v[i] for v in values]))
        sample.append(sep_tokens)

    target_keys = torch.stack([k[i] for i, k in zip(tgt_inds, keys)])
    target_values = torch.stack([k[i] for i, k in zip(tgt_inds, values)])

    sample.append(target_keys)
    sample.append(gen_tokens)

    input_ids_generate = torch.cat(sample, dim=1)

    sample.append(target_values)
    sample.append(eos_tokens)
    input_ids = torch.cat(sample, dim=1)

    labels_mask = torch.zeros_like(input_ids).bool()
    labels_mask[:, -args.value_size - 2:] = True

    collated = {
        'input_ids': input_ids.long(), 
        'input_ids_generate': input_ids_generate.long(), 
        'attention_mask': torch.ones_like(input_ids).bool(),
        'attention_mask_generate': torch.ones_like(input_ids_generate).bool(),
        'labels': input_ids.long(), 
        'labels_mask': labels_mask, 
        'target_values': target_values,
    }
    return collated

In [6]:
kwargs = {'pin_memory': True, 'num_workers': args.data_n_workers}

pretrain_dataset = ARDataset(args.key_size, args.value_size, num_pairs=args.num_pairs, num_samples=args.pretrain_size)
# train_dataset = ARDataset(args.key_size, args.value_size, sample_len=args.num_pairs, num_samples=args.train_size)
valid_dataset = ARDataset(args.key_size, args.value_size, num_pairs=args.num_pairs, num_samples=args.valid_size)
# test_dataset = ARDataset(args.key_size, args.value_size, sample_len=args.num_pairs, num_samples=args.test_size)

train_rnd_generator = torch.Generator()
train_rnd_generator.manual_seed(args.seed)
per_worker_batch_size = args.batch_size * args.gradient_accumulation_steps
kwargs = {'pin_memory': True, 'num_workers': args.data_n_workers}

pretrain_dataloader = DataLoader(pretrain_dataset, batch_size=per_worker_batch_size, generator=train_rnd_generator,
                                 collate_fn=collate_fn, **kwargs)
# train_dataloader = DataLoader(train_dataset, batch_size=per_worker_batch_size,  generator=train_rnd_generator,
#                         collate_fn=collate_fn, **kwargs)
valid_dataloader = DataLoader(valid_dataset, batch_size=per_worker_batch_size,
                        collate_fn=collate_fn, **kwargs)
# test_dataloader = DataLoader(test_dataset, batch_size=per_worker_batch_size,
#                         collate_fn=collate_fn, **kwargs)

  0%|          | 0/100000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [12]:
%run ../src/configs/base_models/create_config.py --arch gpt_neox --hidden_size 128 --num_hidden_layers 4 --num_attention_heads 4

# model_config = json.load(f"{args.arch}_tiny_{args.num_hidden_layers}l{args.num_attention_heads}hd{args.hidden_size}")
cfg_path = f"/home/user36/metamem/src/configs/base_models/exp/{args.arch}_tiny_{args.num_hidden_layers}l{args.num_attention_heads}hd{args.hidden_size}.json"
model_cfg = AutoConfig.from_pretrained(cfg_path)

model_cls = get_cls_by_name(f"transformers:{model_cfg.architectures[0]}")
model = model_cls(config=model_cfg)
args = args_cp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer_cls = get_optimizer(args.optimizer)
optimizer = optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer, args.num_warmup_steps, args.iters * 2)

Saving config gpt_neox_tiny_4l4hd128


In [13]:
def compute_accuracy(logits, labels_mask, target_values):
    preds = torch.argmax(logits.detach(), dim=-1)
    pred_labels = [p[m] for p, m in zip(preds, labels_mask)]

    for i, l in enumerate(pred_labels):
        if eos_token in l:
            pl = pred_labels[i].tolist()
            eos_ind = pl.index(eos_token)
            pred_labels[i] = pl[:eos_ind]

    correct = np.sum([text == pred for text, pred in zip (target_values.tolist(), pred_labels)])
    total = target_values.shape[0]
    acc = 100.0 * correct / total if total > 0 else 0.0

    return acc, correct, total

In [None]:
def validate(model, dataloader, device):
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for batch in dataloader:
            # move to device
            for k, v in batch.items():
                batch[k] = v.to(device)
            outputs = model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
                labels_mask=batch['labels_mask'],
            )
            
            logits = outputs.logits

            shift_labels = batch['labels'][..., 1:].contiguous()
            shift_logits = logits[..., :-1, :].contiguous()
            flat_labels = shift_labels.view(-1)
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))

            loss_fct = CrossEntropyLoss()
            shift_mask = batch['labels_mask'][..., :-1].contiguous()
            flat_labels = flat_labels[shift_mask.view(-1)]
            flat_logits = flat_logits[shift_mask.view(-1)]
        
            loss = loss_fct(flat_logits, flat_labels)

            acc, correct, total = compute_accuracy(logits, batch['labels_mask'], batch['target_values'])

            val_loss += loss.item()
            val_correct += correct
            val_total += total
        
        # idx = torch.randint(total, (5,))
        # print("Showing valid examples")
        # for i in idx:
        #     print(f"y = {batch['target_values'][i]}")
        #     print(f"p = {pred_labels[i]}")
        # print("-" * 30)

    avg_loss = val_loss / len(dataloader)
    avg_acc = 100.0 * val_correct / val_total
    return avg_loss, avg_acc

In [15]:
def fit(
    model,
    train_dataloader,
    valid_dataloader,
    optimizer,
    scheduler,
    device,
    args
):
    """
    Runs exactly args.iters iterations over train_dataloader,
    validates every args.validate_interval steps (and once at the end),
    and logs train loss & acc every args.log_interval steps.
    """
    model.to(device)
    model.train()

    train_iter = iter(train_dataloader)
    best_val_loss = float('inf')

    running_loss = 0.0
    running_correct = 0
    running_total = 0

    start_time = time.time()

    pbar = tqdm(total=args.iters, desc='Train')
    for step in range(args.iters):
        pbar.update(1)
        # fetch next batch (restarting the iterator as needed)
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_dataloader)
            batch = next(train_iter)

        # move to device
        for k, v in batch.items():
            batch[k] = v.to(device)

        # forward + backward + step
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
            labels_mask=batch['labels_mask'],
        )
        # loss = outputs.loss
        logits = outputs.logits

        shift_labels = batch['labels'][..., 1:].contiguous()
        shift_logits = logits[..., :-1, :].contiguous()
        flat_labels = shift_labels.view(-1)
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))

        loss_fct = CrossEntropyLoss()
        shift_mask = batch['labels_mask'][..., :-1].contiguous()
        flat_labels = flat_labels[shift_mask.view(-1)]
        flat_logits = flat_logits[shift_mask.view(-1)]
    
        loss = loss_fct(flat_logits, flat_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # accumulate for logging
        acc, correct, total = compute_accuracy(logits, batch['labels_mask'], batch['target_values'])
        running_loss += loss.item()
        running_correct += correct
        running_total += total

        # train-side logging
        if step % args.log_interval == 0 or step == args.iters:
            elapsed = (time.time() - start_time) / step if step > 0 else 0
            avg_loss = running_loss / args.log_interval
            avg_acc = 100.0 * running_correct / running_total if running_total > 0 else 0.0
            print(f"[Train] Step {step:5d}/{args.iters:5d} • "
                  f"Loss: {avg_loss:.4f} • Acc: {avg_acc:5.2f}% • "
                  f"{elapsed:.3f}s/step")
            running_loss = 0.0
            running_correct = 0
            running_total = 0

        # periodic validation
        if step % args.valid_interval == 0:
            val_loss, val_acc = validate(model, valid_dataloader, device)
            print(f"⏸ [Valid] after {step} steps → "
                  f"Val Loss: {val_loss:.4f} • Val Acc: {val_acc:5.2f}%")
            # save best
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), f"{args.save_path}/model.pth")
                print(f"  ✔ New best model saved at step {step} → {args.save_path}")

    # final validation at end if not aligned to interval
    if args.iters % args.valid_interval != 0:
        val_loss, val_acc = validate(model, valid_dataloader, device)
        print(f"⏸ [Valid] final → Val Loss: {val_loss:.4f} • Val Acc: {val_acc:5.2f}%")
        if val_loss < best_val_loss:
            torch.save(model.state_dict(), f"{args.save_path}/model.pth")
            print(f"  ✔ New best model saved at final step → {args.save_path}")


In [16]:
fit(model, pretrain_dataloader, valid_dataloader, optimizer, lr_scheduler, device, args)



[Train] Step     0/10000 • Loss: 0.0485 • Acc:  0.00% • 0.000s/step




⏸ [Valid] after 0 steps → Val Loss: 4.8345 • Val Acc:  0.00%
  ✔ New best model saved at step 0 → /home/user36/metamem/runs




[Train] Step   100/10000 • Loss: 3.9661 • Acc:  0.09% • 0.321s/step




[Train] Step   200/10000 • Loss: 2.8476 • Acc:  3.65% • 0.298s/step




[Train] Step   300/10000 • Loss: 2.1671 • Acc:  6.15% • 0.242s/step




[Train] Step   400/10000 • Loss: 1.7210 • Acc:  6.92% • 0.208s/step




[Train] Step   500/10000 • Loss: 1.5324 • Acc:  7.27% • 0.215s/step




⏸ [Valid] after 500 steps → Val Loss: 1.4871 • Val Acc:  0.05%
  ✔ New best model saved at step 500 → /home/user36/metamem/runs




[Train] Step   600/10000 • Loss: 1.4595 • Acc:  7.67% • 0.229s/step




[Train] Step   700/10000 • Loss: 1.4204 • Acc:  9.36% • 0.228s/step




[Train] Step   800/10000 • Loss: 1.3946 • Acc: 10.85% • 0.219s/step




[Train] Step   900/10000 • Loss: 1.3758 • Acc: 11.25% • 0.207s/step




[Train] Step  1000/10000 • Loss: 1.3593 • Acc: 11.96% • 0.217s/step




⏸ [Valid] after 1000 steps → Val Loss: 1.3252 • Val Acc:  0.16%
  ✔ New best model saved at step 1000 → /home/user36/metamem/runs




[Train] Step  1100/10000 • Loss: 1.3197 • Acc: 14.61% • 0.225s/step




[Train] Step  1200/10000 • Loss: 1.2823 • Acc: 15.72% • 0.217s/step




[Train] Step  1300/10000 • Loss: 1.2596 • Acc: 17.25% • 0.210s/step




[Train] Step  1400/10000 • Loss: 1.2530 • Acc: 16.91% • 0.203s/step




[Train] Step  1500/10000 • Loss: 1.2325 • Acc: 17.67% • 0.208s/step




⏸ [Valid] after 1500 steps → Val Loss: 1.2141 • Val Acc:  0.17%
  ✔ New best model saved at step 1500 → /home/user36/metamem/runs




[Train] Step  1600/10000 • Loss: 1.2310 • Acc: 17.56% • 0.214s/step




[Train] Step  1700/10000 • Loss: 1.2299 • Acc: 17.22% • 0.211s/step




[Train] Step  1800/10000 • Loss: 1.2236 • Acc: 17.49% • 0.206s/step




[Train] Step  1900/10000 • Loss: 1.2109 • Acc: 18.02% • 0.207s/step




[Train] Step  2000/10000 • Loss: 1.2011 • Acc: 18.27% • 0.211s/step




⏸ [Valid] after 2000 steps → Val Loss: 1.1851 • Val Acc:  0.17%
  ✔ New best model saved at step 2000 → /home/user36/metamem/runs




[Train] Step  2100/10000 • Loss: 1.1966 • Acc: 18.90% • 0.212s/step




[Train] Step  2200/10000 • Loss: 1.1884 • Acc: 18.58% • 0.209s/step




[Train] Step  2300/10000 • Loss: 1.1799 • Acc: 18.56% • 0.204s/step




[Train] Step  2400/10000 • Loss: 1.1790 • Acc: 18.82% • 0.208s/step




[Train] Step  2500/10000 • Loss: 1.1793 • Acc: 18.05% • 0.211s/step




⏸ [Valid] after 2500 steps → Val Loss: 1.1497 • Val Acc:  0.20%
  ✔ New best model saved at step 2500 → /home/user36/metamem/runs




[Train] Step  2600/10000 • Loss: 1.1758 • Acc: 18.10% • 0.210s/step




[Train] Step  2700/10000 • Loss: 1.1702 • Acc: 18.29% • 0.207s/step




[Train] Step  2800/10000 • Loss: 1.1597 • Acc: 19.05% • 0.205s/step




[Train] Step  2900/10000 • Loss: 1.1633 • Acc: 19.03% • 0.207s/step




[Train] Step  3000/10000 • Loss: 1.1572 • Acc: 18.85% • 0.210s/step




⏸ [Valid] after 3000 steps → Val Loss: 1.1486 • Val Acc:  0.18%
  ✔ New best model saved at step 3000 → /home/user36/metamem/runs




[Train] Step  3100/10000 • Loss: 1.1499 • Acc: 19.26% • 0.208s/step




[Train] Step  3200/10000 • Loss: 1.1579 • Acc: 18.20% • 0.205s/step




[Train] Step  3300/10000 • Loss: 1.1543 • Acc: 18.40% • 0.207s/step




[Train] Step  3400/10000 • Loss: 1.1551 • Acc: 17.88% • 0.210s/step




[Train] Step  3500/10000 • Loss: 1.1481 • Acc: 18.80% • 0.210s/step




⏸ [Valid] after 3500 steps → Val Loss: 1.1328 • Val Acc:  0.18%
  ✔ New best model saved at step 3500 → /home/user36/metamem/runs




[Train] Step  3600/10000 • Loss: 1.1431 • Acc: 19.08% • 0.208s/step




[Train] Step  3700/10000 • Loss: 1.1475 • Acc: 19.54% • 0.207s/step




[Train] Step  3800/10000 • Loss: 1.1417 • Acc: 18.92% • 0.209s/step




[Train] Step  3900/10000 • Loss: 1.1412 • Acc: 19.02% • 0.211s/step




[Train] Step  4000/10000 • Loss: 1.1437 • Acc: 18.50% • 0.210s/step




⏸ [Valid] after 4000 steps → Val Loss: 1.1291 • Val Acc:  0.18%
  ✔ New best model saved at step 4000 → /home/user36/metamem/runs




[Train] Step  4100/10000 • Loss: 1.1473 • Acc: 18.18% • 0.208s/step




[Train] Step  4200/10000 • Loss: 1.1420 • Acc: 18.74% • 0.211s/step




[Train] Step  4300/10000 • Loss: 1.1385 • Acc: 18.97% • 0.213s/step




[Train] Step  4400/10000 • Loss: 1.1357 • Acc: 19.30% • 0.211s/step




[Train] Step  4500/10000 • Loss: 1.1383 • Acc: 19.30% • 0.209s/step




⏸ [Valid] after 4500 steps → Val Loss: 1.1304 • Val Acc:  0.18%




[Train] Step  4600/10000 • Loss: 1.1338 • Acc: 19.38% • 0.209s/step




[Train] Step  4700/10000 • Loss: 1.1341 • Acc: 19.24% • 0.211s/step




[Train] Step  4800/10000 • Loss: 1.1402 • Acc: 18.64% • 0.211s/step




[Train] Step  4900/10000 • Loss: 1.1383 • Acc: 18.30% • 0.209s/step




[Train] Step  5000/10000 • Loss: 1.1351 • Acc: 18.97% • 0.207s/step




⏸ [Valid] after 5000 steps → Val Loss: 1.1225 • Val Acc:  0.19%
  ✔ New best model saved at step 5000 → /home/user36/metamem/runs




[Train] Step  5100/10000 • Loss: 1.1328 • Acc: 19.09% • 0.207s/step




[Train] Step  5200/10000 • Loss: 1.1298 • Acc: 19.19% • 0.208s/step




[Train] Step  5300/10000 • Loss: 1.1351 • Acc: 19.60% • 0.209s/step




[Train] Step  5400/10000 • Loss: 1.1247 • Acc: 19.55% • 0.207s/step




[Train] Step  5500/10000 • Loss: 1.1314 • Acc: 19.25% • 0.206s/step




⏸ [Valid] after 5500 steps → Val Loss: 1.1234 • Val Acc:  0.18%




[Train] Step  5600/10000 • Loss: 1.1369 • Acc: 18.80% • 0.206s/step




[Train] Step  5700/10000 • Loss: 1.1327 • Acc: 18.66% • 0.208s/step




[Train] Step  5800/10000 • Loss: 1.1307 • Acc: 18.79% • 0.208s/step




[Train] Step  5900/10000 • Loss: 1.1251 • Acc: 19.63% • 0.207s/step




[Train] Step  6000/10000 • Loss: 1.1277 • Acc: 19.12% • 0.205s/step
⏸ [Valid] after 6000 steps → Val Loss: 1.1270 • Val Acc:  0.18%




[Train] Step  6100/10000 • Loss: 1.1276 • Acc: 19.56% • 0.206s/step




[Train] Step  6200/10000 • Loss: 1.1220 • Acc: 20.00% • 0.207s/step




[Train] Step  6300/10000 • Loss: 1.1289 • Acc: 19.25% • 0.207s/step




[Train] Step  6400/10000 • Loss: 1.1309 • Acc: 19.05% • 0.206s/step




[Train] Step  6500/10000 • Loss: 1.1281 • Acc: 18.98% • 0.204s/step




⏸ [Valid] after 6500 steps → Val Loss: 1.1233 • Val Acc:  0.19%




[Train] Step  6600/10000 • Loss: 1.1262 • Acc: 18.97% • 0.205s/step




[Train] Step  6700/10000 • Loss: 1.1177 • Acc: 20.34% • 0.206s/step




[Train] Step  6800/10000 • Loss: 1.0999 • Acc: 21.32% • 0.206s/step




[Train] Step  6900/10000 • Loss: 1.0226 • Acc: 24.80% • 0.205s/step




[Train] Step  7000/10000 • Loss: 0.8661 • Acc: 33.30% • 0.204s/step




⏸ [Valid] after 7000 steps → Val Loss: 0.7867 • Val Acc:  0.36%
  ✔ New best model saved at step 7000 → /home/user36/metamem/runs




[Train] Step  7100/10000 • Loss: 0.6059 • Acc: 49.93% • 0.205s/step




[Train] Step  7200/10000 • Loss: 0.2856 • Acc: 76.12% • 0.206s/step




[Train] Step  7300/10000 • Loss: 0.1590 • Acc: 87.12% • 0.205s/step




[Train] Step  7400/10000 • Loss: 0.0960 • Acc: 92.49% • 0.204s/step




[Train] Step  7500/10000 • Loss: 0.0512 • Acc: 96.27% • 0.204s/step




⏸ [Valid] after 7500 steps → Val Loss: 0.0314 • Val Acc:  0.98%
  ✔ New best model saved at step 7500 → /home/user36/metamem/runs




[Train] Step  7600/10000 • Loss: 0.0386 • Acc: 97.44% • 0.206s/step




[Train] Step  7700/10000 • Loss: 0.0310 • Acc: 97.99% • 0.206s/step




[Train] Step  7800/10000 • Loss: 0.0285 • Acc: 97.91% • 0.205s/step




[Train] Step  7900/10000 • Loss: 0.0272 • Acc: 97.94% • 0.204s/step




[Train] Step  8000/10000 • Loss: 0.0249 • Acc: 98.17% • 0.205s/step




⏸ [Valid] after 8000 steps → Val Loss: 0.0246 • Val Acc:  0.98%
  ✔ New best model saved at step 8000 → /home/user36/metamem/runs




[Train] Step  8100/10000 • Loss: 0.0312 • Acc: 97.74% • 0.206s/step




[Train] Step  8200/10000 • Loss: 0.0260 • Acc: 97.85% • 0.206s/step




[Train] Step  8300/10000 • Loss: 0.0195 • Acc: 98.69% • 0.205s/step




[Train] Step  8400/10000 • Loss: 0.0148 • Acc: 98.89% • 0.204s/step




[Train] Step  8500/10000 • Loss: 0.0123 • Acc: 99.25% • 0.205s/step




⏸ [Valid] after 8500 steps → Val Loss: 0.0233 • Val Acc:  0.98%
  ✔ New best model saved at step 8500 → /home/user36/metamem/runs




[Train] Step  8600/10000 • Loss: 0.0109 • Acc: 99.38% • 0.206s/step




[Train] Step  8700/10000 • Loss: 0.0091 • Acc: 99.54% • 0.205s/step




[Train] Step  8800/10000 • Loss: 0.0098 • Acc: 99.34% • 0.204s/step




[Train] Step  8900/10000 • Loss: 0.0152 • Acc: 98.88% • 0.205s/step




[Train] Step  9000/10000 • Loss: 0.0094 • Acc: 99.41% • 0.206s/step




⏸ [Valid] after 9000 steps → Val Loss: 0.0100 • Val Acc:  0.99%
  ✔ New best model saved at step 9000 → /home/user36/metamem/runs




[Train] Step  9100/10000 • Loss: 0.0128 • Acc: 99.10% • 0.206s/step




[Train] Step  9200/10000 • Loss: 0.0091 • Acc: 99.43% • 0.205s/step




[Train] Step  9300/10000 • Loss: 0.0090 • Acc: 99.42% • 0.204s/step




[Train] Step  9400/10000 • Loss: 0.0083 • Acc: 99.42% • 0.205s/step




[Train] Step  9500/10000 • Loss: 0.0081 • Acc: 99.49% • 0.206s/step




⏸ [Valid] after 9500 steps → Val Loss: 0.0068 • Val Acc:  1.00%
  ✔ New best model saved at step 9500 → /home/user36/metamem/runs




KeyboardInterrupt: 

In [18]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
ds = load_dataset("yurakuratov/N1-K4V4-S1_16-32_1M")

Train:  96%|█████████▌| 9575/10000 [1:32:47<04:07,  1.72it/s]


README.md:   0%|          | 0.00/453 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/43.6M [00:00<?, ?B/s]

valid-00000-of-00001.parquet:   0%|          | 0.00/218k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [23]:
ds['train'][0]

{'context': 'mVne8R!TGcU:ghK9!DnW|', 'query': '?!TGcU:', 'target': 'ghK9!|'}

In [None]:
import logging

logger_fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=logger_fmt, level=logging.INFO)
logger = logging.getLogger('')

In [None]:
def keep_for_metrics_fn(batch, output):
    # select data from batch and model output that would be used to compute metrics
    data = {}
    if 'generation_outputs' in output:
        data['labels'] = batch['labels']
        data['labels_mask'] = batch['labels_mask']
        data['generation_outputs'] = output['generation_outputs']

    for key in batch.keys():
        if 'loss' in key: 
            data[key] = batch[key]

    return data

def metrics_fn(data):
    metrics = {}
    y, p = None, None
    if 'generation_outputs' in data:
        y = data['labels']
        p = data['generation_outputs']

        metrics['exact_match'] = np.mean([(len(p_) >= args.value_size + 1) and torch.all(torch.tensor(y_)[-args.value_size - 1:] == torch.tensor(p_[-args.value_size - 1:])) \
                                            for p_, y_ in zip (p, y)])
        if args.show_valid_examples > 0:
            for i in range(min(args.show_valid_examples, len(y))):
                logger.info(f"labels: {data['labels'][i]}")
                logger.info(f"gen: {data['generation_outputs'][i]}")
                logger.info(f'y: {y[i][-args.value_size - 1:]}')
                logger.info(f'p: {p[i][-args.value_size - 1:]}')

                logger.info('-' * 50)
    return metrics

batch_metrics_fn = lambda _, y: {key: y[key] for key in y.keys() if (('loss' in key) or ('!log' in key))}

In [None]:
args

{'seed': 42,
 'save_path': '/home/user36/metamem/runs',
 'arch': 'gpt_neox',
 'hidden_size': 128,
 'num_hidden_layers': 4,
 'num_attention_heads': 4,
 'num_symbols': 16,
 'key_size': 2,
 'value_size': 1,
 'num_pairs': 16,
 'ar_mode': 'remember',
 'pretrain_size': 100000,
 'train_size': 100000,
 'valid_size': 1000,
 'test_size': 10000,
 'data_n_workers': 4,
 'num_mem_tokens': 4,
 'use_lora': False,
 'max_inner_iter': 1000,
 'inner_target_loss': 0.0,
 'iters': 10000,
 'log_interval': 5,
 'valid_interval': 25,
 'batch_size': 128,
 'gradient_accumulation_steps': 1,
 'inner_optimizer': 'SGD',
 'inner_lr': 0.001,
 'inner_momentum': 0.9,
 'inner_weight_decay': 0.01,
 'nesterov': True,
 'optimizer': 'AdamW',
 'lr': 0.0003,
 'weight_decay': 0.01,
 'lr_scheduler': 'linear',
 'num_warmup_steps': 1000,
 '_ipython_canary_method_should_not_exist_': {}}

In [None]:
model_forward_args = set(get_fn_param_names(model.forward))
forward_kwargs = {}
generate_kwargs={'max_new_tokens': int(args.value_size * 2), 'pad_token_id': 103}

                      

def step(model, batch, optimizer, args, is_train_mode=True):
    if is_train_mode:
        model.train()
        optimizer.zero_grad()
    else:
        model.eval()

    batch_sizes = []
    for k in batch:
        # filter keys in batch to pass to model only supported arguments
        if k in batch.keys():
            batch[k] = batch[k].to(device)
            batch_sizes += [batch[k].size(dim=0)]
    if not np.all(np.array(batch_sizes) == batch_sizes[0]):
        raise RuntimeError(f'not all elements in a batch have equal dim 0 size: {batch_sizes}')
    batch_size = batch_sizes[0]

    batch_metrics = defaultdict(lambda: 0.0)
    batch_metrics_data = defaultdict(lambda: [])
    with torch.set_grad_enabled(is_train_mode):
        for j in range(0, batch_size, args.batch_size):
            is_last_batch = (j == (batch_size // args.batch_size - 1) * args.batch_size)
            # grad_sync_context = contextlib.nullcontext if is_last_batch else self.accelerator.no_sync
            grad_sync_context = contextlib.nullcontext
            with grad_sync_context(model):
                subbatch = {k: batch[k][j: j + args.batch_size].to('cuda') for k in batch}
                # filter items from batch that are not used by model forward
                outputs = model(**{k: subbatch[k] for k in subbatch if k in model_forward_args},
                                        **forward_kwargs)
                loss = outputs['loss']
                # divide loss on gradient_accumulation_steps to get average loss for sub-batches
                # no need, accelerate does it internally (need to pass gradient_accumulation_steps to accelerator)
                # loss = loss / self.args.gradient_accumulation_steps

                if not is_train_mode and args.use_generate_on_valid:
                    generate_kwargs = deepcopy(generate_kwargs)
                    if 'max_length' not in generate_kwargs and 'labels' in subbatch:
                        # if max_length is not set and labels are in subbatch, generate to the length of labels+1
                        # +1 as special tokens could be generated by the model
                        generate_kwargs['max_length'] = subbatch['labels'].shape[-1] + 1
                    if 'attention_mask' in subbatch:
                        generate_kwargs['attention_mask'] = subbatch['attention_mask']
                    if 'global_attention_mask' in subbatch:
                        generate_kwargs['global_attention_mask'] = subbatch['global_attention_mask']
                    generation_outputs = model.generate(subbatch['input_ids'], **generate_kwargs)
                    outputs['generation_outputs'] = generation_outputs

                metrics = batch_metrics_fn(subbatch, outputs)

                for k in metrics:
                    metrics[k] = metrics[k] / args.gradient_accumulation_steps
                    if isinstance(metrics[k], torch.Tensor):
                        metrics[k] = metrics[k].detach().item()
                    batch_metrics[k] += metrics[k]

                if keep_for_metrics_fn and metrics_fn:
                    for k, v in keep_for_metrics_fn(subbatch, outputs).items():
                        batch_metrics_data[k] += [v.detach().cpu() if isinstance(v, torch.Tensor) else v]

                if is_train_mode:
                    loss.backward()

        if is_train_mode:
            # log gradients norm, clip gradients and perform opt.step(), lr_scheduler.step()
            # if self.clip_grad:
            #     global_grad_norm = self._clip_gradients()
            # else:
            #     global_grad_norm = self._get_gradients_global_norm()
            # track clipped grad norms
            # global_grad_norms += [global_grad_norm]

            optimizer.step()

            if lr_scheduler:
                lr_scheduler.step()
        return batch_metrics, batch_metrics_data

In [None]:
def _train_batch_generator(train_dataloader, n_iter, args):
    for batch in train_dataloader:
        if n_iter > args.iters:
            return
        yield batch

In [None]:
batch_metrics = defaultdict(lambda: defaultdict(list))
metrics_data = defaultdict(lambda: defaultdict(list))


def collect_metrics(split: str) -> dict:
    if split is None:
        batch_metrics = defaultdict(lambda: defaultdict(list))
    else:
        batch_metrics[split] = defaultdict(list)
    if split is None:
        metrics_data = defaultdict(lambda: defaultdict(list))
    else:
        metrics_data[split] = defaultdict(list)


    metrics = {}
    metrics_keys = set(list(batch_metrics[split].keys()))
    
    # if metrics_keys != batch_metrics[split].keys():
    #     missing_metrics_keys = metrics_keys - batch_metrics[split].keys()

    metrics_keys = sorted(metrics_keys)
    for k in metrics_keys:
        metrics[k] = batch_metrics[split][k]
        metrics[k] = np.mean(metrics[k])
    # compute metrics from metrics data
    if keep_for_metrics_fn and metrics_fn:
        metrics_data = {}
        data_keys = set(list(metrics_data[split].keys()))

        # if data_keys != metrics_data[split].keys():
            # missing_data_keys = data_keys - metrics_data[split].keys()
    
        data_keys = sorted(data_keys)
        for k in data_keys:
            metrics_data[k] = metrics_data[split][k]
            m_shape = getattr(metrics_data[k][0], 'shape', None)
            if m_shape is None:
                # data is not a tensor, collect it into python list
                metrics_data[k] = list(chain.from_iterable(metrics_data[k]))
            elif len(m_shape) == 0:
                # if scalars
                metrics_data[k] = torch.stack(metrics_data[k])
            elif all(m_shape[1:] == t.shape[1:] for t in metrics_data[k]):
                # concat tensors if all shapes are equal except the first
                metrics_data[k] = torch.cat(metrics_data[k])
            else:
                # can't concat tensors with diff last shapes, so collecting them into python list
                metrics_data[k] = list(chain.from_iterable([t.tolist() for t in metrics_data[k]]))
        m = metrics_fn(metrics_data)
        if len(metrics.keys() & m.keys()) != 0:
            logger.warning(f'metrics ({m.keys()}) and batch-lvl metrics ({metrics.keys()}) have common names. '
                            f'Batch-lvl metric value would be overwritten.')
        metrics.update(m)

    metrics[split] = metrics
    return metrics

In [None]:
from typing import Dict, Union

def _add_batch_metrics(batch_metrics_add: Dict[str, Union[float, torch.Tensor]], split: str):
    for k in batch_metrics:
        batch_metrics[split][k] += [batch_metrics_add[k]]

def _add_metrics_data(metrics_data_add: Dict[str, torch.Tensor], split: str):
    for k in metrics_data:
        metrics_data[split][k] += metrics_data_add[k]

In [None]:

def validate(dataloader, split='valid') -> Dict[str, float]:
    logger.info('start validation')
    
    if split is None:
        batch_metrics = defaultdict(lambda: defaultdict(list))
    else:
        batch_metrics[split] = defaultdict(list)
    if split is None:
        metrics_data = defaultdict(lambda: defaultdict(list))
    else:
        metrics_data[split] = defaultdict(list)

    n_valid_batches = None
    try:
        n_valid_batches = len(dataloader)
    except TypeError:
        # in case if dataset has no len() method (IterableDataset?)
        n_valid_batches = None

    pbar = tqdm(total=n_valid_batches, desc='Validation')
    for batch in dataloader:
        batch_metrics, batch_metrics_data = step(batch, is_train_mode=False)
        _add_batch_metrics(batch_metrics, split=split)
        if keep_for_metrics_fn and metrics_fn:
            _add_metrics_data(batch_metrics_data, split=split)
        pbar.update()
    pbar.close()

    # metrics = collect_metrics(split=split)

    return metrics

In [None]:
def train(model, train_dataloader, valid_dataloader, optimizer) -> None:
    pbar = tqdm(total=args.iters, desc='Train')
    # pbar.update(n_iter)

    train_batches = _train_batch_generator(train_dataloader, 0, args)
    train_size = len(train_dataloader)

    metric_improved_fn = lambda old_m, new_m: old_m < new_m

    best_valid_metric = np.inf if args.optimize_mode == 'min' else -np.inf
    valid_metric = best_valid_metric
    valid_loss = np.inf
    train_loss = np.inf
    early_stopping_counter = 0
    best_metric_trigger = False

    for n_iter, batch in enumerate(train_batches):
        iteration_start = time.time()
        batch_metrics, batch_metrics_data = step(model, batch, optimizer, args, is_train_mode=True)
        iteration_time = time.time() - iteration_start
        _add_batch_metrics(batch_metrics, split='train')
        if keep_for_metrics_fn and metrics_fn:
            _add_metrics_data(batch_metrics_data, split='train')

        # logging
        # if args.log_interval and n_iter % args.log_interval == 0:
            # batch-lvl averaged metrics:
            # train_metrics = collect_metrics(split='train')
            # train_loss = train_metrics['loss']

        # validation
        # if valid_dataloader is not None and n_iter % args.valid_interval == 0:
        #     # todo: we can use other metrics than loss here
        #     valid_metrics = validate(valid_dataloader)
        #     valid_loss = valid_metrics['loss']
        #     valid_metric = valid_metrics[args.optimize_metric]
        #     if metric_improved_fn(best_valid_metric, valid_metric):
        #         best_valid_metric = valid_metric
        #         early_stopping_counter = 0
        #         logger.info(f'The best {args.optimize_metric} metric was improved to: {best_valid_metric}')
        #         if args.save_best:
        #             torch.save(model, f"{args.model_path}.pth")
        #     else:
        #         early_stopping_counter += 1
        #         logger.info(f'Metric was not improved for the last #{early_stopping_counter} evaluations')
        #     if best_valid_metric == args.best_metric_value:
        #         best_metric_trigger = True
        
        # pbar.update(1)
        # pbar.set_postfix({'train_loss': f'{train_loss:.3f}',
        #                     'valid_loss': f'{valid_loss:.3f}',
        #                     f'best_valid_{args.optimize_metric}': f'{best_valid_metric:.3f}'
        #                     })

        # if args.early_stopping_patience is not None and \
        #         early_stopping_counter > args.early_stopping_patience or best_metric_trigger:
        #     logger.info('Early stopping triggered: stopping training...')
        #     break

    # clean-up
    pbar.close()
    logger.info('Done!')

In [None]:
model.device

device(type='cpu')

In [None]:
train(model, pretrain_dataloader, valid_dataloader, optimizer)



[A[A

UnboundLocalError: cannot access local variable 'batch_metrics' where it is not associated with a value