In [None]:
import torch
from leap import LeapForCausalLM, LeapConfig
from transformers import TrainingArguments, Trainer, default_data_collator
from datasets import load_dataset

from itertools import chain
import logging
logging.disable(logging.INFO)

In [None]:
# T5 tokenizer; the warning is nothing to worry about since we will group the texts
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("t5-small")

In [None]:
# Number of tokens per training example; also used as the model's n_positions.
block_size = 2048

raw_datasets = load_dataset("wikitext", "wikitext-2-v1")
column_names = raw_datasets["train"].column_names

# Prefer a column literally called "text"; otherwise fall back to the first column.
if "text" in column_names:
    text_column_name = "text"
else:
    text_column_name = column_names[0]

def tokenize_function(examples):
    """Run the notebook-level ``tokenizer`` over the text column of a batch.

    Args:
        examples: batch mapping that contains the ``text_column_name`` column.

    Returns:
        Whatever ``tokenizer`` returns for that column.
    """
    texts = examples[text_column_name]
    return tokenizer(texts)

def group_texts(examples):
    """Concatenate a batch of tokenized texts and re-split it into fixed-size blocks.

    Every column in ``examples`` is flattened into one long list and then cut
    into consecutive chunks of exactly ``block_size`` (module-level) items.
    Trailing items that do not fill a whole block are dropped, so a batch with
    fewer than ``block_size`` items in total yields no blocks at all rather
    than a single short one that could not be batched with the others.

    Args:
        examples: mapping of column name to a list of per-example lists.
            Must contain an ``"input_ids"`` column.

    Returns:
        dict with the same columns, each a list of ``block_size``-long chunks,
        plus a ``"labels"`` column that mirrors ``"input_ids"``.
    """
    # concatenate text
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    # drop the trailing partial block (this rounds down to 0 for a short batch)
    total_length = (total_length // block_size) * block_size

    # split by chunks of block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # labels mirror the inputs; no shifting is done here
    result["labels"] = result["input_ids"].copy()
    return result

# Tokenize every split in batches; the original columns are removed so only
# the tokenizer's outputs remain.
tokenized_datasets = raw_datasets.map(
    tokenize_function,
    remove_columns=column_names,
    num_proc=1,
    batched=True,
)

# Re-chunk the tokenized text into block_size-long examples with labels.
lm_dataset = tokenized_datasets.map(
    group_texts,
    desc=f"Grouping texts in chunks of {block_size}",
    num_proc=1,
    batched=True,
)

# Return PyTorch tensors when the dataset is indexed.
lm_dataset.set_format('pt')

In [None]:
# hyperparameters (shared by every run_leap experiment below)
training_args = TrainingArguments(
    # output location; no external experiment tracker
    output_dir="./results",
    report_to="none",
    # log, evaluate and checkpoint once per epoch
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # select the best checkpoint by validation loss
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    # optimisation settings
    learning_rate=5e-4,
    num_train_epochs=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    max_grad_norm=1,
    fp16=True,
)

In [None]:
def run_leap(window_sizes):
    """Train a fresh LEAP causal LM whose layers use the given window sizes.

    One layer is created per entry of ``window_sizes``. The model is trained
    and evaluated on ``lm_dataset`` with the shared ``training_args``; the
    resolved window sizes are printed and the Trainer reports the losses.
    Nothing is returned.

    Args:
        window_sizes: list of per-layer attention window sizes, in layer order
            (the experiments use ``block_size`` for a global-attention layer).
    """
    # Local import: the notebook's import cell does not bring in set_seed.
    from transformers import set_seed

    # Trainer only seeds the RNG when it is constructed, which happens after
    # the model weights already exist. Seed here so every configuration is
    # compared from the same RNG state instead of whatever the previous run left.
    set_seed(training_args.seed)

    config = LeapConfig(
        hidden_size = 128,
        vocab_size = len(tokenizer),
        n_positions = block_size,
        n_heads = 4,
        n_layer = len(window_sizes),
        use_local_att = True,
        window_sizes = window_sizes, # one window size per layer
        hidden_dropout_prob = .1,
        rescale = 10
    )
    print(config.window_sizes)
    window_model = LeapForCausalLM(config)

    window_trainer = None
    try:
        window_trainer = Trainer(
            model=window_model,
            args=training_args,
            data_collator=default_data_collator,
            train_dataset=lm_dataset["train"],
            eval_dataset=lm_dataset["validation"]
        )

        window_trainer.train()
    finally:
        # free gpu memory, even when training fails, so the next experiment
        # cell does not start with this run's model still on the device
        del window_trainer
        window_model.cpu()
        torch.cuda.empty_cache()

# Simple experiments to find a heuristic for window values

Making the first and last layers global attention is a given. The first global attention layer should establish "context", e.g. if there is task metadata at the start of the sequence. The last global attention layer is there to accumulate all of the information.

## 1. Let's start with a baseline

The first and last layers are global attention (as stated before) and of course we need local attention so let's try a window size of 4.

In [None]:
run_leap([2048, 4, 2048])

[2048, 4, 2048]


Epoch,Training Loss,Validation Loss
1,6.4034,5.614212
2,5.7185,5.347041
3,5.4739,5.14035
4,5.2813,5.001758
5,5.1452,4.916826
6,5.0444,4.843692
7,4.9704,4.792968
8,4.9158,4.760973
9,4.8776,4.74264
10,4.8556,4.734987


## 2. Now try a local window size of 8 to compare

In [None]:
run_leap([2048, 8, 2048])

[2048, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.4172,5.626072
2,5.7306,5.369969
3,5.4919,5.174758
4,5.3106,5.040751
5,5.1804,4.953936
6,5.0831,4.892673
7,5.0105,4.835805
8,4.9577,4.805156
9,4.9205,4.786395
10,4.898,4.778767


## 3. Seems slightly worse, let's move on. Try adding a second local attention layer vs adding another global attention layer

In [None]:
run_leap([2048, 4, 4, 2048])

[2048, 4, 4, 2048]




Epoch,Training Loss,Validation Loss
1,6.4084,5.6144
2,5.7154,5.35498
3,5.4763,5.146724
4,5.2777,4.997527
5,5.1329,4.891938
6,5.0296,4.823465
7,4.9513,4.765365
8,4.8945,4.735488
9,4.8553,4.713768
10,4.8318,4.704624


In [None]:
run_leap([2048, 4, 2048, 2048])

[2048, 4, 2048, 2048]


Epoch,Training Loss,Validation Loss
1,6.3976,5.615664
2,5.7204,5.347426
3,5.4685,5.142619
4,5.2788,5.005457
5,5.1456,4.910377
6,5.047,4.841533
7,4.9721,4.789617
8,4.9167,4.759144
9,4.8774,4.736964
10,4.8545,4.729196


## 4. How about also trying 8

In [None]:
run_leap([2048, 4, 8, 2048])

[2048, 4, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.4074,5.610317
2,5.7107,5.349531
3,5.4688,5.135484
4,5.2682,4.98766
5,5.1248,4.884447
6,5.0213,4.814922
7,4.9423,4.753984
8,4.8852,4.72265
9,4.8455,4.700609
10,4.8216,4.691215


## 5. Okay, the 4 and 8 local windows are the best so far. Let's try adding another local window of double the size of the previous one, and compare that against just adding another 8 layer

In [None]:
run_leap([2048, 4, 8, 16, 2048])

[2048, 4, 8, 16, 2048]




Epoch,Training Loss,Validation Loss
1,6.4134,5.598911
2,5.6937,5.320351
3,5.4386,5.114291
4,5.2452,4.965509
5,5.1035,4.867014
6,4.9994,4.791622
7,4.9208,4.734686
8,4.8637,4.698231
9,4.8227,4.677152
10,4.8006,4.668426


In [None]:
run_leap([2048, 4, 8, 8, 2048])

[2048, 4, 8, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.4155,5.602804
2,5.6987,5.326443
3,5.445,5.120928
4,5.2519,4.972217
5,5.1105,4.872621
6,5.0067,4.796437
7,4.9279,4.740082
8,4.8709,4.704421
9,4.8301,4.683311
10,4.808,4.674782


## 6. Adding 16 is best, let's continue the pattern of adding a local window of double the size. Let's compare against adding a 4 instead, and try also another global layer

In [None]:
run_leap([2048, 4, 8, 16, 32, 2048])

[2048, 4, 8, 16, 32, 2048]




Epoch,Training Loss,Validation Loss
1,6.3686,5.593536
2,5.6767,5.310952
3,5.4253,5.093761
4,5.2309,4.957909
5,5.0914,4.851718
6,4.9864,4.780315
7,4.9075,4.719919
8,4.8491,4.683742
9,4.8092,4.665123
10,4.7845,4.655414


In [None]:
run_leap([2048, 4, 8, 16, 4, 2048])

[2048, 4, 8, 16, 4, 2048]




Epoch,Training Loss,Validation Loss
1,6.37,5.593909
2,5.6773,5.309531
3,5.4237,5.092208
4,5.2286,4.955739
5,5.0898,4.849909
6,4.9851,4.780025
7,4.9065,4.718997
8,4.8478,4.682447
9,4.8079,4.663561
10,4.7829,4.653922


In [None]:
run_leap([2048, 4, 8, 16, 2048, 2048])

[2048, 4, 8, 16, 2048, 2048]




Epoch,Training Loss,Validation Loss
1,6.3701,5.590558
2,5.6756,5.310473
3,5.4233,5.092281
4,5.2285,4.954886
5,5.0894,4.850015
6,4.9848,4.778233
7,4.9067,4.720803
8,4.8491,4.686124
9,4.8095,4.666592
10,4.7849,4.657202


## 7. They all seem to work well, though restarting back at 4 seems to work best. Let's continue by adding an 8 vs adding a global layer BEFORE the 4 (the idea being to alternate between global and local attention)

In [None]:
run_leap([2048, 4, 8, 16, 4, 8, 2048])

[2048, 4, 8, 16, 4, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.3577,5.587782
2,5.6732,5.30924
3,5.4139,5.093601
4,5.2145,4.946199
5,5.0756,4.841637
6,4.9735,4.767156
7,4.8951,4.711235
8,4.837,4.674083
9,4.7966,4.652811
10,4.7722,4.644971


In [None]:
run_leap([2048, 4, 8, 16, 2048, 4, 2048])

[2048, 4, 8, 16, 2048, 4, 2048]




Epoch,Training Loss,Validation Loss
1,6.3562,5.586043
2,5.6719,5.305367
3,5.4082,5.082562
4,5.2076,4.938675
5,5.0689,4.835489
6,4.9666,4.76275
7,4.8888,4.705896
8,4.8306,4.670035
9,4.7906,4.64848
10,4.7664,4.639331


## 8. So alternation might work! If that's true let's try getting rid of the first global attention layer (and using the alternating strategy), vs leaving it and not alternating

In [None]:
run_leap([4, 8, 16, 2048, 4, 8, 16, 2048])

[4, 8, 16, 2048, 4, 8, 16, 2048]


Epoch,Training Loss,Validation Loss
1,6.4476,5.701015
2,5.7989,5.437718
3,5.5548,5.219329
4,5.3603,5.061584
5,5.2104,4.942722
6,5.0945,4.855447
7,5.0056,4.793694
8,4.9413,4.749907
9,4.8972,4.725306
10,4.8717,4.716165


In [None]:
run_leap([2048, 4, 8, 16, 4, 8, 16, 2048])

[2048, 4, 8, 16, 4, 8, 16, 2048]




Epoch,Training Loss,Validation Loss
1,6.3797,5.578846
2,5.6657,5.309939
3,5.4097,5.077781
4,5.2035,4.924696
5,5.0578,4.821679
6,4.9509,4.74357
7,4.8706,4.690719
8,4.8129,4.654828
9,4.7719,4.633359
10,4.7471,4.624532


## 9. Okay, so it looks like there should be global attention up front (good to clarify that at least). Let's back up and try just the alternating strategy with global attention up front (compare against the non-alternating run directly above)

In [None]:
run_leap([2048, 4, 8, 16, 2048, 4, 8, 2048])

[2048, 4, 8, 16, 2048, 4, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.3793,5.57653
2,5.6662,5.304617
3,5.4079,5.078589
4,5.2003,4.926067
5,5.0535,4.82244
6,4.946,4.74207
7,4.8651,4.686777
8,4.8069,4.651878
9,4.7657,4.628971
10,4.7408,4.619902


## 10. That seems to work well, let's compare that to maybe trying faster cycle rates (compared to baseline 4, 8, 16)

In [None]:
run_leap([2048, 4, 16, 2048, 4, 16, 2048, 4, 2048])

[2048, 4, 16, 2048, 4, 16, 2048, 4, 2048]




Epoch,Training Loss,Validation Loss
1,6.3374,5.570338
2,5.6543,5.292899
3,5.3967,5.068539
4,5.1968,4.919131
5,5.0542,4.819403
6,4.9491,4.747835
7,4.8694,4.690602
8,4.8108,4.650536
9,4.7685,4.629301
10,4.7442,4.620091


In [None]:
run_leap([2048, 4, 8, 2048, 4, 8, 2048, 4, 2048])

[2048, 4, 8, 2048, 4, 8, 2048, 4, 2048]




Epoch,Training Loss,Validation Loss
1,6.3411,5.579933
2,5.6623,5.301711
3,5.4066,5.078926
4,5.2032,4.92617
5,5.0615,4.829375
6,4.956,4.757246
7,4.8776,4.698025
8,4.8184,4.658959
9,4.7767,4.636538
10,4.7524,4.627797


In [None]:
run_leap([2048, 8, 16, 2048, 8, 16, 2048, 8, 2048])

[2048, 8, 16, 2048, 8, 16, 2048, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.3412,5.585158
2,5.6727,5.32093
3,5.4244,5.105385
4,5.2311,4.962701
5,5.0942,4.868821
6,4.993,4.796904
7,4.9172,4.742193
8,4.8608,4.706789
9,4.8212,4.685561
10,4.7974,4.676823


In [None]:
run_leap([2048, 4, 8, 16, 2048, 4, 8, 16, 2048])

[2048, 4, 8, 16, 2048, 4, 8, 16, 2048]




Epoch,Training Loss,Validation Loss
1,6.337,5.573453
2,5.6518,5.286858
3,5.393,5.06443
4,5.1925,4.914081
5,5.0499,4.814898
6,4.9445,4.743355
7,4.8649,4.685818
8,4.8057,4.646133
9,4.764,4.622876
10,4.7399,4.614115


## 11. Looks like 4, 8, 16 is best. Though, what if the last layer wasn't global?

In [None]:
run_leap([2048, 4, 8, 16, 2048, 4, 8, 16])

[2048, 4, 8, 16, 2048, 4, 8, 16]




Epoch,Training Loss,Validation Loss
1,6.378,5.577907
2,5.665,5.303333
3,5.4058,5.074307
4,5.1986,4.922642
5,5.0526,4.819136
6,4.9446,4.740305
7,4.8637,4.685725
8,4.8056,4.650772
9,4.7642,4.627913
10,4.7394,4.61888


In [None]:
run_leap([2048, 4, 8, 16, 2048, 4, 8, 2048])

[2048, 4, 8, 16, 2048, 4, 8, 2048]




Epoch,Training Loss,Validation Loss
1,6.3793,5.57653
2,5.6662,5.304617
3,5.4079,5.078589
4,5.2003,4.926067
5,5.0535,4.82244
6,4.946,4.74207
7,4.8651,4.686777
8,4.8069,4.651878
9,4.7657,4.628971
10,4.7408,4.619902


# Conclusion for now: first layer should be global attention, layers of sizes 4, then 8, then 16, then global attention, and repeat.