- в каждом блоке трансформера несколько голов, тогда будем в части голов считать атеншн по текущим хидденам (SA(QKV)), 
а в части голов KV считаем по хидденам после предыдущих блоков и эмбеддингам токенов, и Q по текущим хидденам (CA(KV, Q))
- $Attn_i = Cat[SA(h_i), CA(h_{i-1}, h_i), CA(h_{i-2}, h_i)]$
- в первом и втором слое все головы считаются по текущему контексту, начиная с 3 делаем reflex attention
- в этой секции зафиксируем, что на SA и на каждый из CA по 2 головы (всего 6)

# The Architecture

In [1]:
import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

/kaggle/input/mingpt/mingpt/trainer.py
/kaggle/input/mingpt/mingpt/bpe.py
/kaggle/input/mingpt/mingpt/model.py
/kaggle/input/mingpt/mingpt/utils.py
/kaggle/input/mingpt/mingpt/__init__.py
/kaggle/input/openwebtext-data-prepared-for-nanogpt/train.bin
/kaggle/input/openwebtext-data-prepared-for-nanogpt/val.bin


In [2]:
import sys
sys.path.append('/kaggle/input/mingpt/')

In [3]:
import mingpt.bpe
import mingpt.utils
import mingpt.model
import mingpt.trainer

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [5]:
import math

In [6]:
from mingpt.utils import set_seed
from mingpt.bpe import BPETokenizer
set_seed(3407)

In [7]:
# let's run to see if layers can be accessed by names and added to a list

In [8]:
""" Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask

    if enable_gqa:
        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value"""

' Efficient implementation equivalent to the following:\ndef scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,\n        is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:\n    L, S = query.size(-2), key.size(-2)\n    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale\n    attn_bias = torch.zeros(L, S, dtype=query.dtype)\n    if is_causal:\n        assert attn_mask is None\n        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)\n        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))\n        attn_bias.to(query.dtype)\n\n    if attn_mask is not None:\n        if attn_mask.dtype == torch.bool:\n            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))\n        else:\n            attn_bias += attn_mask\n\n    if enable_gqa:\n        key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)\n        value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)\n\n    a

In [9]:
class NewGELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

In [10]:
# It is easier to add 3 separate W_q, W_k, W_v for now
# models in the experiments will be relatively small anyway
# [ ] do we need tril? I guess yes, because otherwise query would look into the future
# [ ] tried to make make as close to scaled_dot_product_attention
class ReflexAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        
        self.key = nn.Linear(config.n_embd, config.head_size, bias=False)
        self.query = nn.Linear(config.n_embd, config.head_size, bias=False)
        self.value = nn.Linear(config.n_embd, config.head_size, bias=False)

        #self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))
        self.register_buffer('tril', torch.tril(torch.ones(config.block_size, config.block_size)))

        self.dropout = nn.Dropout(config.attn_pdrop)
        #self.n_head = config.n_head [ ] do i need this
        #self.n_embd = config.n_embd

    def forward(self, x, x_prev = None):
        B,T,C = x.shape
        q = self.query(x) # (B,T,hs)
        if x_prev is not None:
            B_prev, T_prev, C_prev = x_prev.size()
            # DEBUG print(B == B_prev, T == T_prev, C, C_prev)
            k = self.key(x_prev)   # (B,T,hs)
            v = self.value(x_prev) # (B,T,hs)
        else:
            k = self.key(x)   # (B,T,hs)
            v = self.value(x) # (B,T,hs)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

In [11]:
# [ ] there is a mess with drop out
class MultiHeadReflexAttention(nn.Module):
    # heads in parallel
    # [ ] we just split and than concat, they are independent
    def __init__(self, config):
        super().__init__()
        self.heads = nn.ModuleList([ReflexAttention(config) for _ in range(config.n_head)])
        self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd)
        self.dropout = nn.Dropout(config.resid_pdrop)

    """ self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)"""
    def forward(self, x, b_i):
        out = []
        for h_i, h in enumerate(self.heads):
            # DEBUG print('forward', x.shape, b_i, h_i)
            if b_i == 0 or b_i == 1:
                t = h(x)
                out.append(t)
                continue
            if h_i == 0 or h_i == 1:
                out.append(h(x))
            else:
                # DEBUG print(hdn.hiddens[b_i-1].shape, embed.x_embed.shape)
                if h_i == 2 or h_i == 3:
                    out.append(h(x=x, x_prev=hdn.hiddens[b_i-1]+embed.x_embed))
                elif h_i == 4 or h_i ==5:
                    out.append(h(x=x, x_prev=hdn.hiddens[b_i-2]+embed.x_embed))
        # DEBUG print('cat', [r.shape for r in out])       
        # DEBUG print('proj', config.head_size * config.n_head, config.n_embd)
        out = torch.cat(out, dim=-1) # [ ] check dim
        out = self.dropout(self.proj(out))
        return out

In [12]:
# check head size
# add config
class Block(nn.Module):

    def __init__(self, config):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        # head_size = n_embd // n_head
        self.ra = MultiHeadReflexAttention(config)
        self.mlp = nn.ModuleDict(dict(
            c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd),
            c_proj  = nn.Linear(4 * config.n_embd, config.n_embd),
            act     = NewGELU(),
            dropout = nn.Dropout(config.resid_pdrop),
        ))
        m = self.mlp
        self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x))))
        
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)

    def forward(self, x, b_i):
        x = x + self.ra(self.ln1(x), b_i)
        x = x + self.mlpf(self.ln2(x))
        return x

In [13]:
# [ ] I try to get all possible params considered, but it is not that easy
# [ ] nano has blocksize 1024
# nano uses 0 dropout
class Config(): 
        # either model_type or (n_layer, n_head, n_embd) must be given in the config
        # C.model_type = 'gpt'
        n_layer = 6
        n_head = 6
        n_embd = 64*6
        # these options must be filled in externally
        vocab_size = 50257
        block_size = 1024
        head_size = n_embd // n_head
        # dropout hyperparameters
        embd_pdrop = 0
        resid_pdrop = 0
        attn_pdrop = 0

In [14]:
config = Config()

In [15]:
class Hiddens():
    hiddens = []

In [16]:
hdn = Hiddens()

In [17]:
class Embed():
    x_embed = None

In [18]:
embed = Embed()

In [19]:
class ReflexTransformer(nn.Module):
    """ Transformer with reflex attention """

    def __init__(self, config):
        super().__init__()
        
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.embd_pdrop),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters (note we don't count the decoder parameters in lm_head)
        n_params = sum(p.numel() for p in self.transformer.parameters())
        print("number of parameters: %.2fM" % (n_params/1e6,))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
                # random note: because named_modules and named_parameters are recursive
                # we will see the same tensors p many many times. but doing it this way
                # allows us to know which parent module any tensor p belongs to...
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        #assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        embed.x_embed = self.transformer.drop(tok_emb + pos_emb)
       
        hdn.hiddens = []
        # DEBUG print(x.shape, embed.x_embed.shape)
        for b_i, block in enumerate(self.transformer.h):
            x = block(x, b_i)
            hdn.hiddens.append(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -config.block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # becomes (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    
    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = config # [ ] was self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

In [20]:
device = 'cuda'

In [21]:
model = ReflexTransformer(config)

number of parameters: 30.33M


In [22]:
model.to(device)
model.eval();

In [23]:
tokenizer = BPETokenizer()

downloading https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json to /root/.cache/mingpt/encoder.json
downloading https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe to /root/.cache/mingpt/vocab.bpe


In [24]:
x = tokenizer('test test 1 2 3').to(device)

In [25]:
y = model.generate(x, 10)

In [26]:
for i in range(len(y)):
    out = tokenizer.decode(y[0].cpu().squeeze())
    print('-'*10)
    print(out)

----------
test test 1 2 3 E Unleashedgame sorcery antennas raplaughter trout Anchorage Scotland


# Training on openwebtext

In [27]:
max_iters=1000
log_interval=1
eval_interval=200
eval_iters=20
learning_rate=0.00008
gradient_accumulation_steps=4
batch_size=8
compile=False

In [28]:
device = 'cuda' 
dtype = 'bfloat16'
compile = True

In [29]:
torch.manual_seed(1337)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

In [30]:
import os
# poor man's data loader
data_dir = os.path.join('/kaggle/input/openwebtext-data-prepared-for-nanogpt') # [ ] Removed ,dataset
def get_batch(split):
    # We recreate np.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == 'train':
        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    else:
        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [31]:
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
iter_num = 0
best_val_loss = 1e9

In [32]:
device_type = 'cuda'

In [33]:
class Train_config():
    weight_decay = 1e-1
    betas = [0.9, 0.95]
    learning_rate = learning_rate

In [34]:
optimizer = model.configure_optimizers(Train_config()) # [ ] Remvoed device_type)

In [35]:
# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


In [36]:
# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            with ctx:
                logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [37]:
# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

In [38]:
import time
import numpy as np

In [39]:
ddp = False

In [40]:
master_procesc = True

In [41]:
wandb_log = False

In [42]:
eval_only = False

In [43]:
block_size = config.block_size

In [44]:
decay_lr = True

In [45]:
warmup_iters = 2000

In [46]:
master_process = True

In [47]:
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

In [48]:
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [49]:
grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0

In [50]:
model_args = dict(n_layer=config.n_layer, n_head=config.n_head, n_embd=config.n_embd, block_size=config.block_size,
                  bias=False, vocab_size=None, dropout=0) # start with model_args from command line
# [ ] dropout is set to 0
# [ ] bias is set to False
out_dir = '/kaggle/working/'

In [51]:
# training loop
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
local_iter_num = 0 # number of iterations in the lifetime of this process

raw_model = model.module if ddp else model # unwrap DDP container if needed
running_mfu = -1.0
while True:
    # determine and set the learning rate for this iteration
    lr = get_lr(iter_num) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # evaluate the loss on train/val sets and write checkpoints
    if iter_num % eval_interval == 0 and master_process:
        losses = estimate_loss()
        print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if wandb_log:
            wandb.log({
                "iter": iter_num,
                "train/loss": losses['train'],
                "val/loss": losses['val'],
                "lr": lr,
                "mfu": running_mfu*100, # convert to percentage
            })
        if losses['val'] < best_val_loss or always_save_checkpoint:
            best_val_loss = losses['val']
            if iter_num > 0:
                checkpoint = {
                    'model': raw_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'model_args': model_args,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                    'config': config,
                }
                print(f"saving checkpoint to {out_dir}")
                torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt'))
    if iter_num == 0 and eval_only:
        break

    # forward backward update, with optional gradient accumulation to simulate larger batch size
    # and using the GradScaler if data type is float16
    for micro_step in range(gradient_accumulation_steps):
        if ddp:
            # in DDP training we only need to sync gradients at the last micro step.
            # the official way to do this is with model.no_sync() context manager, but
            # I really dislike that this bloats the code and forces us to repeat code
            # looking at the source of that context manager, it just toggles this variable
            model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1)
        with ctx:
            logits, loss = model(X, Y)
            loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
        # immediately async prefetch next batch while model is doing the forward pass on the GPU
        X, Y = get_batch('train')
        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
    # clip the gradient
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # step the optimizer and scaler if training in fp16
    scaler.step(optimizer)
    scaler.update()
    # flush the gradients as soon as we can, no need for this memory anymore
    optimizer.zero_grad(set_to_none=True)

    # timing and logging
    t1 = time.time()
    dt = t1 - t0
    t0 = t1
    if iter_num % log_interval == 0 and master_process:
        # get loss as float. note: this is a CPU-GPU sync point
        # scale up to undo the division above, approximating the true total loss (exact would have been a sum)
        lossf = loss.item() * gradient_accumulation_steps
        if local_iter_num >= 5: # let the training loop settle a bit
            mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
            running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
        print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%")
    iter_num += 1
    local_iter_num += 1

    # termination conditions
    if iter_num > max_iters:
        break

step 0: train loss 10.9165, val loss 10.9206
iter 0: loss 10.9336, time 10809.00ms, mfu -100.00%
iter 1: loss 10.9224, time 2077.51ms, mfu -100.00%
iter 2: loss 10.9180, time 2265.09ms, mfu -100.00%
iter 3: loss 10.9296, time 2266.94ms, mfu -100.00%
iter 4: loss 10.9179, time 2268.25ms, mfu -100.00%
iter 5: loss 10.9165, time 2264.42ms, mfu 1.50%
iter 6: loss 10.9161, time 2265.29ms, mfu 1.50%
iter 7: loss 10.9019, time 2266.16ms, mfu 1.50%
iter 8: loss 10.9092, time 2265.21ms, mfu 1.50%
iter 9: loss 10.9201, time 2265.88ms, mfu 1.50%
iter 10: loss 10.9156, time 2268.37ms, mfu 1.50%
iter 11: loss 10.8911, time 2268.89ms, mfu 1.50%
iter 12: loss 10.9125, time 2268.05ms, mfu 1.50%
iter 13: loss 10.9153, time 2265.26ms, mfu 1.50%
iter 14: loss 10.8818, time 2266.48ms, mfu 1.50%
iter 15: loss 10.8835, time 2268.18ms, mfu 1.50%
iter 16: loss 10.8849, time 2265.94ms, mfu 1.50%
iter 17: loss 10.8742, time 2266.74ms, mfu 1.50%
iter 18: loss 10.8646, time 2268.11ms, mfu 1.50%
iter 19: loss 10.8

- обучение и сравнение качества обычного небольшого трансформера (например по 6 голов и 6 слоев) и reflex attention (в разных сетапах)
- любые изменения/дополнения/улучшения, которые по-вашему могут работать
- отчет об экспериментах, что получилось и что нет

# Experminets
- Add a bit of dropout
- 

# Evaluation

# Metrics ???
- MinGPT's Dataset for the Sort problem. E.g. for problem length 6 Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
- Long context find max element
- 
- Language stuff
    - Perplexity
    - BLEU
- GLUE
- SuperGLUE (reasoning and understanding in complex contexts)
- Long Range Arena (LRA)
- MNLI: For natural language inference tasks ???
- CoNLL: For named entity recognition tasks