In [None]:
!pip install colt5-attention
!pip install torch 


import torch 
from colt5_attention import (
    ConditionalRoutedFeedForward,
    ConditionalRoutedAttention,
    ConditionalRoutedTransformerBlock
)


#mock unout
tokens = torch.randn(2, 50000, 512)
print(f"tokens: {tokens}")

mask = torch.ones(2, 50000).bool() #variable lenthed sequences
print(f"mask: {mask}")


ff = ConditionalRoutedFeedForward(
    dim= 512,
    light_ff_mult= 0.5,
    heavy_ff_mult = 4,
    num_heavy_tokens = 1024
)

ff_out = ff(tokens, mask=mask)
print(f"feed forward: {ff}")


attn = ConditionalRoutedAttention(
    dim=512,
    light_dim_head=64,
    light_heads = 8,
    light_window_size=128,
    heavy_dim_head = 64,
    heavy_heads=8,
    num_heavy_tokens_q=1024,
    num_heavy_tokens_kv=1024,
    use_flash_attn = True
)

attn_out = attn(tokens, mask=mask)
print(f"attn out: {attn_out}")

block = ConditionalRoutedTransformerBlock(
    dim=512,
    light_dim_head=64,
    light_heads=8,
    light_window_size=128,
    heavy_dim_head=64,
    heavy_heads=8,
    light_ff_mult=0.5,
    heavy_ff_mult = 4,
    num_heavy_ff_tokens=1024,
    num_heavy_attn_tokens_q= 1024,
    num_heavy_attn_tokens_kv = 1024
)

block_out = block(tokens, mask=mask)

print(f"block out: {block_out}")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
tokens: tensor([[[-0.1166,  2.7874, -1.0335,  ...,  0.4842, -0.2701,  1.9254],
         [ 0.3403, -0.5440,  1.1186,  ..., -0.2711,  1.1225,  0.3496],
         [ 0.5147, -3.2057, -0.0480,  ..., -0.3550,  0.6639,  0.2929],
         ...,
         [ 0.8364,  0.6330, -0.8668,  ...,  2.2426, -0.0989, -0.8650],
         [-0.3178, -1.0713,  0.8390,  ..., -1.8518,  0.4022, -0.1280],
         [-0.1367,  0.5613,  0.4291,  ..., -2.0379, -0.7547, -0.0296]],

        [[ 0.3970, -0.2397, -0.4050,  ..., -1.8274, -0.8286, -0.4790],
         [ 1.6024,  2.0110, -0.5000,  ...,  0.4076,  0.6969, -0.3910],
         [ 2.3210, -0.7658, -1.0843,  ...,  0.0580,  1.0062,  0.1042],
         ...,
         [ 0.1934, -0.5202,  0.1969,  ...,  1.5385,  1.4037, -1.6028],
         [-1.9333,  0.2146,  0.2495,  ..., -0.1534,  0

In [None]:
# Conditionally routed attention used as cross-attention
import torch
from colt5_attention import ConditionalRoutedCrossAttention

# Mock input: a 1024-token sequence attending to 1_048_576 (~1 million) context memories.
# Everything is moved to the GPU, so this cell needs a CUDA runtime.


tokens = torch.randn(1, 1024, 512).cuda()
print(f"tokens: {tokens}")

tokens_mask = torch.ones(1, 1024).bool().cuda()
print(f"tokens mask: {tokens_mask}")

memories = torch.randn(1, 1_048_576, 512).cuda()
print(f"memories: {memories}")

memories_mask = torch.ones(1, 1_048_576).bool().cuda()
# Label fixed: this prints the mask, not the memories tensor printed just above.
print(f"memories mask: {memories_mask}")


# Conditionally routed cross attention
cross_attn = ConditionalRoutedCrossAttention(
    dim=512,
    dim_head=64,
    heads=8,
    num_tokens_q=512,
    num_tokens_kv=1024,
    kv_routing_tokens=2,
    use_triton=True,
    route_block_size=131072
).cuda()

cross_attn_out = cross_attn(
    tokens,
    context=memories,
    mask=tokens_mask,
    context_mask=memories_mask
)

shape = cross_attn_out.shape
print(f"shape {shape}")

tokens: tensor([[[ 1.0263,  0.0444,  0.4157,  ...,  1.9324,  0.2600,  0.0207],
         [-0.0459,  1.0451,  1.1374,  ...,  0.7586, -0.0329,  0.6849],
         [ 1.3423, -0.9727,  1.0685,  ...,  0.0900,  0.1832, -0.1667],
         ...,
         [ 0.2739,  1.6226,  0.5003,  ...,  0.3694,  0.3642,  0.2331],
         [ 0.7136, -0.2688, -0.0857,  ...,  0.2841,  0.3298,  0.0403],
         [ 1.6305,  0.0530,  0.3946,  ...,  0.5485, -0.5693,  0.0652]]],
       device='cuda:0')
tokens mask: tensor([[True, True, True,  ..., True, True, True]], device='cuda:0')
memories: tensor([[[ 0.0280,  0.6959, -0.1431,  ..., -0.6348,  0.9171, -0.3882],
         [-0.0847,  0.1931, -0.6492,  ...,  0.6191, -0.3137,  1.0662],
         [-2.1065, -0.4429,  0.0470,  ..., -0.0771,  0.6260,  0.5998],
         ...,
         [-0.0182,  0.8396, -0.4933,  ...,  1.0157,  1.1854, -0.2614],
         [-0.3336,  1.1808,  0.5988,  ..., -1.2341,  0.4146, -2.2379],
         [ 1.8775, -0.3667,  0.7416,  ..., -1.0526,  0.5791, -0.

In [None]:
!pip install x-transformers

import torch 
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens=20000,
    max_seq_len=5000,
    use_abs_pos_emb = False,
    attn_layers = Decoder(
        dim=512,
        depth=6,
        heads=8,
        alibi_pos_bias=True,
        alibi_num_heads=4,
        rotary_xpos=True,
        attn_flash = True,
        deepnorm=True,
        # dynamic_pos_bias=True,
        # dynamic_pos_bias_log_distance=False,
        shift_tokens=1,
        # rel_pos_bias=True
    )
)


model = AutoregressiveWrapper(
    model,
    mask_prob=0.15,
)

x = torch.randint(0, 20000, (1, 5000))
print(f"tokens: {x}")

# results = model(x)

loss = model(x)
print(f"loss: {loss}")


result = loss.backward()
print(f"Result: {result}")
print(f"results {result}")

#generate
# model.generate(seq_len=2000)

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
tokens: tensor([[  958,   995, 18363,  ...,  2906, 13673,  3010]])
loss: 10.073033332824707
Result: None
results None


TypeError: ignored

In [None]:
import time

import torch
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import get_scheduler, default_data_collator, get_linear_schedule_with_warmup
from torch.optim import AdamW

from kosmos import Kosmos, KosmosTokenizer
from accelerate import Accelerator

from rich.progress import Progress
from datasets import Image
from bitsandbytes.optim import AdamW8bit


def count_number_of_parameters(model, only_trainable: bool = True) -> int:
    """Count the scalar parameters of ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
        only_trainable: When True (default) only parameters with
            ``requires_grad`` set are counted; when False every parameter is.

    Returns:
        The total number of elements as a plain ``int``.
    """
    # The former `if p` filter in the non-trainable branch evaluated the truth
    # value of a whole tensor, which raises for any multi-element parameter.
    return int(sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not only_trainable
    ))



#change to enwiki8
def prep_sample(sample):
    """Collapse one question/answer/explanation record into an image plus a target string.

    Args:
        sample: Mapping with the keys ``"question"``, ``"answer"``,
            ``"explanation"`` and ``"image"``. ``"answer"`` must contain the
            ``"|!+"`` separator at least once.

    Returns:
        A dict with ``"image"`` (passed through unchanged) and ``"target_text"``.
    """
    q_text = sample["question"]
    # Use the segment that follows the first "|!+" separator.
    answer_parts = sample["answer"].split("|!+")
    a_text = answer_parts[1]
    e_text = sample["explanation"]
    target = f"Question: {q_text} Answer: {a_text} Explanation: {e_text}"
    return dict(image=sample["image"], target_text=target)


def train(args):
    """Run the Kosmos fine-tuning loop under Accelerate with fp16 mixed precision.

    Args:
        args: Namespace-like object. The attributes read here are ``seed``,
            ``learning_rate``, ``weight_decay``, ``warmup_steps``,
            ``max_steps``, ``batch_size``, ``log_every``, ``save_every`` and
            ``checkpoint_dir``.

    Returns:
        None. Metrics are sent through ``accelerator.log`` and the training
        state is saved under ``args.checkpoint_dir`` every ``args.save_every``
        steps.
    """
    accelerator = Accelerator(
        mixed_precision="fp16"
    )

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # change to andromeda
    model = Kosmos()
    model = model.to(accelerator.device)

    optimizer = AdamW8bit(model.parameters(), lr=args.learning_rate,
                          weight_decay=args.weight_decay)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps,
    )

    #change to andromeda tokenizer
    tokenizer = KosmosTokenizer()

    #change to enwiki
    dataset = load_dataset("bjoernp/vqax", split="test")
    #dataset = dataset.cast_column("URL", Image)
    #map
    dataset = dataset.map(prep_sample, num_proc=8)

    #change to enwiki
    remove_columns = ['id', 'img_id', 'question', 'answer',
                      'explanation', 'none', 'image', 'target_text']

    #change batch_size
    dataset = dataset.map(tokenizer.tokenize, batched=True,
                          batch_size=128, remove_columns=remove_columns)

    train_dataloader = DataLoader(
        dataset, collate_fn=default_data_collator, batch_size=args.batch_size, pin_memory=True
    )

    model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(model, train_dataloader, optimizer,
                                                                           lr_scheduler)
    model.train()
    accelerator.register_for_checkpointing(lr_scheduler)

    # Freeze the CLIP model except for its last encoder layer.
    model.clip_model.requires_grad_(False)
    model.clip_model.encoder.layers[-1].requires_grad_(True)

    # Count every parameter here; the helper's default only counts trainable
    # ones, which made both lines report the same number.
    total_params = sum(p.numel() for p in model.parameters())
    accelerator.print(f"Number of parameters: {total_params:,}")
    accelerator.print(
        f"Number of trainable parameters: {count_number_of_parameters(model, only_trainable=True):,}")

    # Log model and optimizer parameters to wandb
    accelerator.init_trackers(project_name="kosmos")

    train_loader = iter(train_dataloader)
    epoch_loss = 0
    total_loss = 0
    start_time = time.time()
    # The criterion holds no per-step state, so build it once.
    loss_fct = CrossEntropyLoss()

    with Progress() as progress:
        task = progress.add_task("[red]Training...", total=args.max_steps)
        for step in range(0, args.max_steps):
            batch_start = time.time()
            try:
                batch = next(train_loader)
            except StopIteration:
                # max_steps can exceed one pass over the data: start another pass
                # instead of letting StopIteration abort training.
                train_loader = iter(train_dataloader)
                batch = next(train_loader)
            outputs = model(**batch, self_attn_padding_mask=batch["attention_mask"])
            # Keep position 0 and drop positions 1..66 of the sequence axis
            # (presumably the image-embedding slots; TODO confirm against the model).
            outputs = torch.cat([outputs[:, :1], outputs[:, 67:]], dim=1).contiguous()
            # Shift so that tokens < n predict n
            logits = outputs[:, :-1]
            labels = batch["labels"][:, 1:]
            # CrossEntropyLoss treats dim 1 as the class axis. Passing (batch, seq, vocab)
            # logits with same-shaped one-hot targets therefore normalised over the
            # sequence axis; flatten to (batch*seq, vocab) with class-index targets so the
            # softmax runs over the vocabulary.
            loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            epoch_loss += loss.detach().float()

            accelerator.backward(loss)
            optimizer.step()
            # Without this the warmup/decay schedule never advances and the LR stays
            # at the scheduler's initial value.
            lr_scheduler.step()
            optimizer.zero_grad()

            batch_end = time.time()
            if step % args.log_every == args.log_every - 1:
                logs = {
                    "loss": loss.item(),
                    "perplexity": torch.exp(loss).item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                    "examples": args.batch_size * (step + 1),
                    "examples_per_second": args.batch_size / (batch_end - batch_start),
                }
                accelerator.log(logs, step=step)
                # step is zero-based, so step + 1 losses have been accumulated so far.
                mean_loss = (total_loss + epoch_loss) / (step + 1)
                progress.update(task, description=f"Step Loss: {loss.item():.5f} "
                                                  f"| Mean Loss: {mean_loss:.5f} "
                                                  f"| Mean PPL: {torch.exp(mean_loss):.2f} "
                                                  f"| Examples: {args.batch_size * (step + 1)} "
                                                  f"| Examples/s: {args.batch_size / (batch_end - batch_start):.2f} "
                                                  f"| Elapsed: {time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))}")
            # Advance once per step (total=max_steps); previously the bar only moved
            # on logging steps and could never reach 100%.
            progress.update(task, advance=1)

            if step % args.save_every == args.save_every - 1:
                train_epoch_loss = epoch_loss / args.save_every
                total_loss += epoch_loss
                epoch_loss = 0

                accelerator.log({
                    "train_ppl": torch.exp(train_epoch_loss),
                    "train_epoch_loss": train_epoch_loss,
                }, step=step)

                progress.print(f"Saving checkpoint at step {step}...")
                accelerator.save_state(
                    f"{args.checkpoint_dir}/checkpoint_at_step_{step}/")



In [None]:
!git clone https://github.com/kyegomez/Optimus-Prime.git
%cd Optimus-Prime
%cd examples
%cd enwiki8_simple

from torch.serialization import load
import torch 
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

#training
import random
import tqdm
import gzip
import numpy as np
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)  # iterations of the outer training loop
BATCH_SIZE = 4  # batch_size of the train and validation DataLoaders
GRADIENT_ACCUMULATE_EVERY = 4  # backward passes accumulated per optimizer step
LEARNING_RATE = 1e-4  # lr passed to Adam
VALIDATE_EVERY  = 100  # print a validation loss when i % VALIDATE_EVERY == 0
GENERATE_EVERY  = 500  # sample from the model when i % GENERATE_EVERY == 0
GENERATE_LENGTH = 1024  # length argument passed to model.generate
SEQ_LEN = 1024  # seq_len of TextSamplerDataset (each item holds SEQ_LEN + 1 tokens)

# helpers

def cycle(loader):
    """Yield items from ``loader`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from loader

def decode_token(token):
    """Turn a token id into a one-character string; ids below 32 are shown as a space."""
    code_point = max(32, token)
    return str(chr(code_point))

def decode_tokens(tokens):
    """Decode an iterable of token ids into a string via ``decode_token``."""
    return ''.join(decode_token(tok) for tok in tokens)

# Build the transformer and wrap it for autoregressive training in one expression,
# then move it to the GPU.
model = AutoregressiveWrapper(
    TransformerWrapper(
        num_tokens=20000,
        max_seq_len=5000,
        use_abs_pos_emb=False,
        attn_layers=Decoder(
            dim=512,
            depth=6,
            heads=8,
            alibi_pos_bias=True,
            alibi_num_heads=4,
            rotary_xpos=True,
            attn_flash=True,
            deepnorm=True,
            shift_tokens=1,
        ),
    )
)
model.cuda()

# Read the first 95e6 bytes of enwik8 as uint8; the file is only needed for the read.
with gzip.open('./data/enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()

# First 90e6 bytes are for training, the remainder for validation.
train_x, valid_x = np.split(data, [int(90e6)])
data_train = torch.from_numpy(train_x)
data_val = torch.from_numpy(valid_x)

class TextSamplerDataset(Dataset):
    """Serve randomly positioned windows of ``seq_len + 1`` tokens from a 1-D tensor.

    The ``index`` given to ``__getitem__`` is ignored: every access draws a new
    random start offset. Items are returned as ``long`` tensors on the GPU.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.data = data

    def __len__(self):
        # Number of non-overlapping seq_len-sized windows that fit in the data.
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1
        offset = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        chunk = self.data[offset: offset + window]
        return chunk.long().cuda()

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
# cycle() turns each finite DataLoader into an endless batch stream for next().
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))

# optimizer
# NOTE: this name shadows the `torch.optim as optim` module alias imported above.

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Accumulate gradients over several micro-batches; each loss is scaled so
    # the summed gradient corresponds to the mean over the accumulated batches.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # Reports the loss of the last micro-batch only.
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Drop the last token so the prompt is SEQ_LEN tokens long.
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # Apply %-formatting: the old call passed the tuple as a second print()
        # argument, so the literal "%s" placeholders were printed.
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

Cloning into 'Optimus-Prime'...
remote: Enumerating objects: 1614, done.[K
remote: Counting objects: 100% (1614/1614), done.[K
remote: Compressing objects: 100% (531/531), done.[K
remote: Total 1614 (delta 1113), reused 1529 (delta 1073), pack-reused 0[K
Receiving objects: 100% (1614/1614), 37.48 MiB | 47.50 MiB/s, done.
Resolving deltas: 100% (1113/1113), done.
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enw

In [None]:
!git clone https://github.com/kyegomez/Optimus-Prime.git
%cd Optimus-Prime
%cd examples
%cd enwik8_simple
!python3 trainandromeda.py

Cloning into 'Optimus-Prime'...
remote: Enumerating objects: 1614, done.[K
remote: Counting objects: 100% (453/453), done.[K
remote: Compressing objects: 100% (191/191), done.[K
remote: Total 1614 (delta 309), reused 367 (delta 260), pack-reused 1161[K
Receiving objects: 100% (1614/1614), 37.51 MiB | 45.67 MiB/s, done.
Resolving deltas: 100% (1100/1100), done.
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/enwik8_simple/Optimus-Prime
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/Optimus-Prime/examples/enwik8_simple/Optimus-Prime/examples/enwik8_simple
  File "/content/Optimus-Prime/Optimus-Prime/Optimus-Prime/ex

In [3]:
!git clone https://github.com/kyegomez/Optimus-Prime.git
%cd Optimus-Prime
!pip install --upgrade torch
# !pip install -r requirements.txt
!pip install einops
# !pip install --upgrade torch

# %cd Optimus-Prime
# # %cd examples
# # !ls
# !python3 trainandromeda.py 
# #%cd enwik8_simple
# # !python trainandromeda.py


from torch.serialization import load
import torch 
from optimus_prime import TransformerWrapper, Decoder, AutoregressiveWrapper

#training
import random
import tqdm
import gzip
import numpy as np
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import os
# from torch.utils.tensorboard import SummaryWriter
# from torchmetrics import MetricCollection, Accuracy


# constants

NUM_BATCHES = int(1e5)  # iterations of the outer training loop
BATCH_SIZE = 4  # batch_size of the train and validation DataLoaders
GRADIENT_ACCUMULATE_EVERY = 4  # backward passes accumulated per optimizer step
LEARNING_RATE = 1e-4  # lr passed to Adam
VALIDATE_EVERY  = 100  # print a validation loss when i % VALIDATE_EVERY == 0
GENERATE_EVERY  = 500  # sample from the model when i % GENERATE_EVERY == 0
GENERATE_LENGTH = 1024  # length argument passed to model.generate
SEQ_LEN = 1024  # seq_len of TextSamplerDataset (each item holds SEQ_LEN + 1 tokens)
SAVE_EVERY=500  # write a checkpoint when i % SAVE_EVERY == 0


# helpers

def cycle(loader):
    """Yield items from ``loader`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from loader

def decode_token(token):
    """Turn a token id into a one-character string; ids below 32 are shown as a space."""
    code_point = max(32, token)
    return str(chr(code_point))

def decode_tokens(tokens):
    """Decode an iterable of token ids into a string via ``decode_token``."""
    return ''.join(decode_token(tok) for tok in tokens)

# Build the transformer and wrap it for autoregressive training in one expression,
# then move it to the GPU.
model = AutoregressiveWrapper(
    TransformerWrapper(
        num_tokens=20000,
        max_seq_len=5000,
        use_abs_pos_emb=False,
        attn_layers=Decoder(
            dim=512,
            depth=6,
            heads=8,
            alibi_pos_bias=True,
            alibi_num_heads=4,
            rotary_xpos=True,
            attn_flash=True,
            deepnorm=True,
            shift_tokens=1,
        ),
    )
)
model.cuda()

# Read the first 95e6 bytes of enwik8 as uint8; the file is only needed for the read.
with gzip.open('./enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()

# First 90e6 bytes are for training, the remainder for validation.
# The tensors stay on the CPU; TextSamplerDataset moves each sample to the GPU.
train_x, valid_x = np.split(data, [int(90e6)])
data_train = torch.from_numpy(train_x)
data_val = torch.from_numpy(valid_x)

class TextSamplerDataset(Dataset):
    """Serve randomly positioned windows of ``seq_len + 1`` tokens from a 1-D tensor.

    The ``index`` given to ``__getitem__`` is ignored: every access draws a new
    random start offset. Items are returned as ``long`` tensors on the GPU.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.data = data

    def __len__(self):
        # Number of non-overlapping seq_len-sized windows that fit in the data.
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1
        offset = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        chunk = self.data[offset: offset + window]
        return chunk.long().cuda()

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
# cycle() turns each finite DataLoader into an endless batch stream for next().
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))

# optimizer
# NOTE: this name shadows the `torch.optim as optim` module alias imported above.

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

# #init tensorboard 
# writer = SummaryWriter(log_dir="./log")

# #define metrics
# metrics = MetricCollection({'accuracy': Accuracy(num_classes=num_classes, task='classification')})
device="cuda"
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Accumulate gradients over several micro-batches; each loss is scaled so
    # the summed gradient corresponds to the mean over the accumulated batches.
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # Reports the loss of the last micro-batch only.
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()


    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

            # # Calculate validation metrics
            # val_metrics = MetricCollection({'val_accuracy': Accuracy()})
            # val_metrics(loss, model(next(val_loader)).argmax(dim=-1))

            # # Add validation metrics to the SummaryWriter
            # writer.add_scalar('Validation/Accuracy', val_metrics['val_accuracy'].compute(), global_step=i)

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Drop the last token so the prompt is SEQ_LEN tokens long.
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # Apply %-formatting: the old call passed the tuple as a second print()
        # argument, so the literal "%s" placeholders were printed.
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

    # Save the model every save_every iterations
    if i % SAVE_EVERY == 0:
        # Specify the directory and filename to save the model
        save_dir = './saved_models/'
        save_filename = 'model_checkpoint.pt'
        checkpoint_path = os.path.join(save_dir, save_filename)

        # Create the save directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

        # Save the model checkpoint
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Model saved at iteration {i}")

        # Download only after the file has been written, using the path it was
        # written to (the old code requested the bare filename before saving).
        from google.colab import files
        files.download(checkpoint_path)

#     # Add training metrics to the SummaryWriter
#     writer.add_scalar('Training/Accuracy', metrics['accuracy'].compute(), global_step=i)

#     # Close the SummaryWriter
# writer.close()

Cloning into 'Optimus-Prime'...
remote: Enumerating objects: 1674, done.[K
remote: Counting objects: 100% (513/513), done.[K
remote: Compressing objects: 100% (231/231), done.[K
remote: Total 1674 (delta 341), reused 412 (delta 277), pack-reused 1161[K
Receiving objects: 100% (1674/1674), 37.52 MiB | 17.27 MiB/s, done.
Resolving deltas: 100% (1132/1132), done.
/content/Optimus-Prime/Optimus-Prime/Optimus-Prime
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


ImportError: ignored

In [None]:
# Shell out to nvidia-smi to report the GPU(s) available to this runtime.
!nvidia-smi



```
# with multiquery + flash + alibi + xpos + deepnorm
```

# New Section

In [None]:
# Mount Google Drive at /content/drive (Colab-only API).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#concurrency
!git clone https://github.com/kyegomez/Optimus-Prime.git
!pip install einops
!pip install torchmetrics
!pip install tensorboard

!pip install -r requirements.txt
%cd Optimus-Prime

import threading

def run_tensorboard():
    !tensorboard --logdir=./logs --port=6006

def run_training():
    !python3 trainandromeda.py

# Start TensorBoard in a separate thread
tensorboard_thread = threading.Thread(target=run_tensorboard)
tensorboard_thread.start()

# Run the training script
run_training()

# Wait for the TensorBoard thread to finish
tensorboard_thread.join()

In [None]:
!git clone https://github.com/kyegomez/Optimus-Prime.git
%cd Optimus-Prime
!pip install --upgrade torch
# !pip install -r requirements.txt
!pip install einops
# !pip install --upgrade torch

# %cd Optimus-Prime
# # %cd examples
# # !ls
# !python3 trainandromeda.py 
# #%cd enwik8_simple
# # !python trainandromeda.py


from torch.serialization import load
import torch 
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

#training
import random
import tqdm
import gzip
import numpy as np
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import os
# from torch.utils.tensorboard import SummaryWriter
# from torchmetrics import MetricCollection, Accuracy


# constants

NUM_BATCHES = int(1e5)  # iterations of the outer training loop
BATCH_SIZE = 4  # batch_size of the train and validation DataLoaders
GRADIENT_ACCUMULATE_EVERY = 4  # backward passes accumulated per optimizer step
LEARNING_RATE = 1e-4  # lr passed to Adam
VALIDATE_EVERY  = 100  # print a validation loss when i % VALIDATE_EVERY == 0
GENERATE_EVERY  = 500  # sample from the model when i % GENERATE_EVERY == 0
GENERATE_LENGTH = 1024  # length argument passed to model.generate
SEQ_LEN = 1024  # seq_len of TextSamplerDataset (each item holds SEQ_LEN + 1 tokens)
SAVE_EVERY=500  # write a checkpoint when i % SAVE_EVERY == 0


# helpers

def cycle(loader):
    """Yield items from ``loader`` forever, re-iterating it each time it is exhausted."""
    while True:
        yield from loader

def decode_token(token):
    """Turn a token id into a one-character string; ids below 32 are shown as a space."""
    code_point = max(32, token)
    return str(chr(code_point))

def decode_tokens(tokens):
    """Decode an iterable of token ids into a string via ``decode_token``."""
    return ''.join(decode_token(tok) for tok in tokens)

# Build the transformer and wrap it for autoregressive training in one expression,
# then move it to the GPU. This variant additionally sets attn_one_kv_head=True.
model = AutoregressiveWrapper(
    TransformerWrapper(
        num_tokens=20000,
        max_seq_len=5000,
        use_abs_pos_emb=False,
        attn_layers=Decoder(
            dim=512,
            depth=6,
            heads=8,
            alibi_pos_bias=True,
            alibi_num_heads=4,
            rotary_xpos=True,
            attn_flash=True,
            deepnorm=True,
            shift_tokens=1,
            attn_one_kv_head=True,
        ),
    )
)
model.cuda()

# Read the first 95e6 bytes of enwik8 as uint8; the file is only needed for the read.
with gzip.open('./enwik8.gz') as file:
    data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()

# First 90e6 bytes are for training, the remainder for validation.
# The tensors stay on the CPU; TextSamplerDataset moves each sample to the GPU.
train_x, valid_x = np.split(data, [int(90e6)])
data_train = torch.from_numpy(train_x)
data_val = torch.from_numpy(valid_x)

class TextSamplerDataset(Dataset):
    """Serve randomly positioned windows of ``seq_len + 1`` tokens from a 1-D tensor.

    The ``index`` given to ``__getitem__`` is ignored: every access draws a new
    random start offset. Items are returned as ``long`` tensors on the GPU.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.data = data

    def __len__(self):
        # Number of non-overlapping seq_len-sized windows that fit in the data.
        return self.data.size(0) // self.seq_len

    def __getitem__(self, index):
        window = self.seq_len + 1
        offset = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        chunk = self.data[offset: offset + window]
        return chunk.long().cuda()

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
val_dataset   = TextSamplerDataset(data_val, SEQ_LEN)
# cycle() turns each finite DataLoader into an endless batch stream for next().
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE, drop_last = True))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE, drop_last = True))

# optimizer
# NOTE: this name shadows the `torch.optim as optim` module alias imported above.

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

# #init tensorboard 
# writer = SummaryWriter(log_dir="./log")

# #define metrics
# metrics = MetricCollection({'accuracy': Accuracy(num_classes=num_classes, task='classification')})
device="cuda"
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Accumulate gradients over several micro-batches; each loss is scaled so
    # the summed gradient corresponds to the mean over the accumulated batches.
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # Reports the loss of the last micro-batch only.
    print(f'training loss: {loss.item()}')
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()


    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            print(f'validation loss: {loss.item()}')

            # # Calculate validation metrics
            # val_metrics = MetricCollection({'val_accuracy': Accuracy()})
            # val_metrics(loss, model(next(val_loader)).argmax(dim=-1))

            # # Add validation metrics to the SummaryWriter
            # writer.add_scalar('Validation/Accuracy', val_metrics['val_accuracy'].compute(), global_step=i)

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Drop the last token so the prompt is SEQ_LEN tokens long.
        inp = random.choice(val_dataset)[:-1]
        prime = decode_tokens(inp)
        # Apply %-formatting: the old call passed the tuple as a second print()
        # argument, so the literal "%s" placeholders were printed.
        print('%s \n\n %s' % (prime, '*' * 100))

        sample = model.generate(inp, GENERATE_LENGTH)
        output_str = decode_tokens(sample)
        print(output_str)

    # Save the model every save_every iterations
    if i % SAVE_EVERY == 0:
        # Specify the directory and filename to save the model
        save_dir = './saved_models/'
        save_filename = 'model_checkpoint.pt'

        # Create the save directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

        # Save the model checkpoint (each save overwrites the previous file)
        torch.save(model.state_dict(), os.path.join(save_dir, save_filename))
        print(f"Model saved at iteration {i}")

#     # Add training metrics to the SummaryWriter
#     writer.add_scalar('Training/Accuracy', metrics['accuracy'].compute(), global_step=i)

#     # Close the SummaryWriter
# writer.close()