In [None]:
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm
from matplotlib import pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

import impulsegpt_sdpa
#import char_tokenizer
# Report which PyTorch build this kernel is running.
torch_version = torch.__version__
print(torch_version)

In [None]:
# Compute capability (major, minor) of the current CUDA device, or None when CUDA is
# unavailable -- the unguarded call raises there, breaking Run All on CPU/MPS machines.
cuda_capability = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
cuda_capability

In [None]:
# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
_backend = ('cuda' if torch.cuda.is_available()
            else 'mps' if torch.backends.mps.is_available()
            else 'cpu')
device = torch.device(_backend)
print({'cuda': "Using CUDA", 'mps': "Using MPS", 'cpu': "Using CPU"}[_backend])

In [None]:
# Model hyper-parameters, applied in this order on top of the library defaults.
MODEL_OVERRIDES = {
    'ctx_len': 128,      # max tokens per sample (also the tokenizer truncation length)
    'n_layers': 12,
    'd_model': 768,
    'n_heads': 12,
    'n_kv_heads': 4,     # fewer KV heads than query heads -- presumably grouped-query attention
    'vocab': 50000,      # placeholder; overwritten with the tokenizer's size in the tokenizer cell
    'gpa': True,         # NOTE(review): meaning not visible here (typo for `gqa`?) -- confirm
}
config = impulsegpt_sdpa.Config()
for name, value in MODEL_OVERRIDES.items():
    setattr(config, name, value)

enable_mixed_pricision = True

In [None]:
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')
#tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
# GPT-style tokenizers ship without a pad token, which max_length padding requires;
# fall back to EOS in that case (no-op for tokenizers that already define one).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pad every batch to exactly ctx_len tokens (keywords make clear 'max_length' is the
# padding strategy, not the length).
collator = DataCollatorWithPadding(tokenizer,
                                   padding='max_length',
                                   max_length=config.ctx_len,
                                   return_tensors='pt')

# len(tokenizer) also counts tokens added on top of the base vocabulary, so the
# embedding table covers every id the tokenizer can emit.
config.vocab = len(tokenizer)
print(f"Model vocab set to: {config.vocab}, Embedding size: {config.d_model * config.vocab}")


In [None]:
# TinyStories corpus; only the training split is used below.
ds = load_dataset("roneneldan/TinyStories")['train']

In [None]:
print(f"Length before filter: {len(ds)}")
# Optional: keep only stories short enough to (roughly) fit the context window.
#ds = ds.filter(lambda t: len(t['text']) < (config.ctx_len*2))

# Tokenize and truncate to the context length.
ds = ds.map(lambda t: tokenizer(t['text'],
                                truncation=True,
                                max_length=config.ctx_len,
                                return_overflowing_tokens=False),
            batched=True)
# Keep only `input_ids`: the collator rebuilds the attention mask when padding. Dropping
# "everything else" (instead of a fixed list) keeps this working for tokenizers that do
# not emit `token_type_ids`.
ds = ds.remove_columns([col for col in ds.column_names if col != 'input_ids'])
ds = ds.with_format('torch')

print(ds[0])
print(f"Train data length: {len(ds)}")
#print(f"Validation data length: {len(ds['validation'])}")

In [None]:
model = impulsegpt_sdpa.ImpulseGPT(config=config)
model = model.to(device)
# Compile in place (nn.Module.compile) so subsequent calls go through torch.compile.
model.compile()
# To resume from an earlier pickled checkpoint instead:
#model = torch.load('ckpt/ts-64-1.pt')
summary(model)

In [None]:
def train(dataset, model, loss_fn, optimizer, epochs:int=1, batch_size:int=16, training_divides:int=10, scaler:torch.amp.GradScaler=None, logger:SummaryWriter=None, num_workers:int=16):
    """Train `model` on next-token prediction, unrolling each batch position by position.

    For every batch, the model sees tokens ``[0..t]`` and is optimised to predict token
    ``t+1``, with one optimizer step per position. Each epoch the dataset is split into
    ``training_divides`` shards, each with its own DataLoader and progress bar.

    Relies on notebook globals: ``collator`` (pads batches), ``device``,
    ``enable_mixed_pricision`` and ``config`` (used in the checkpoint filename).

    Args:
        dataset: Hugging Face ``Dataset`` with an ``input_ids`` column (``.shard`` is used).
        model: module whose output for a ``(batch, t+1)`` context is scored by ``loss_fn``
            against ``(batch,)`` next-token targets -- presumably last-position logits of
            shape ``(batch, vocab)``; TODO confirm.
        loss_fn: loss comparing model output and targets (e.g. ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model.parameters()``.
        epochs: number of passes over the whole dataset.
        batch_size: rows per batch.
        training_divides: number of shards each epoch is split into.
        scaler: optional GradScaler; when ``None`` a plain backward/step is used.
        logger: optional TensorBoard writer; it is closed when training finishes.
        num_workers: DataLoader worker processes.
    """
    model.train()
    # bf16 autocast is only enabled on CUDA; elsewhere train in fp32 rather than letting
    # autocast('cuda') warn and disable itself.
    use_amp = enable_mixed_pricision and device.type == 'cuda'
    os.makedirs("ckpt", exist_ok=True)
    global_step = 0  # monotonically increasing TensorBoard x-axis across chunks/epochs
    print(f"Start training for {epochs} epochs with {len(dataset)} rows of data each.")
    for s in range(epochs):
        for chunk in range(training_divides):
            print(f"Training on {chunk+1} of {training_divides} data chunks")
            dataloader = DataLoader(dataset=dataset.shard(num_shards=training_divides, index=chunk),
                                    collate_fn=collator,
                                    batch_size=batch_size,
                                    num_workers=num_workers)
            pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {s+1} of {epochs}")
            for batch, row in pbar:
                input_ids = row['input_ids']
                seq_len = input_ids.shape[1]
                # Rows are padded to ctx_len; positions past the longest real sequence in
                # the batch are pure padding, so there is nothing to learn there.
                if 'attention_mask' in row:
                    seq_len = min(seq_len, int(row['attention_mask'].sum(dim=1).max()))
                num_rows = seq_len - 1
                step_loss = 0.0
                for t in range(num_rows):
                    context = input_ids[..., :t+1].to(device)
                    y = input_ids[..., t+1].to(device)

                    # Only the forward pass and loss run under autocast; backward and the
                    # optimizer step run outside it, as the PyTorch AMP docs recommend.
                    with torch.autocast(device_type='cuda',
                                        dtype=torch.bfloat16,
                                        enabled=use_amp):
                        y_hat = model(context)
                        loss = loss_fn(y_hat, y)

                    if scaler is not None:
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        loss.backward()
                        optimizer.step()
                    optimizer.zero_grad()
                    step_loss += loss.item()
                # Guard against batches without any next-token target (length <= 1).
                step_loss /= max(num_rows, 1)
                global_step += 1
                if logger:
                    logger.add_scalar('Loss', step_loss, global_step)
                pbar.set_postfix({'Loss':step_loss})
        # NOTE(review): saved once per epoch but named after the last chunk index; move
        # inside the chunk loop if per-chunk checkpoints were intended.
        torch.save(model, f"ckpt/impgpt-{config.ctx_len}-{chunk}.pt")
    if logger:
        logger.close()


In [None]:
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=5e-4,
                              betas=(0.9, 0.95),
                              weight_decay=0.1)
# Loss scaling for mixed precision. With bfloat16 autocast it is not strictly needed,
# but it is harmless; without CUDA the scaler disables itself (with a warning).
scaler = torch.amp.GradScaler('cuda')

writer = SummaryWriter()

# NOTE(review): FLASH_ATTENTION is the only SDPA backend allowed here, which presumably
# requires a CUDA device -- on CPU/MPS this context may need to be removed; confirm.
with nn.attention.sdpa_kernel(nn.attention.SDPBackend.FLASH_ATTENTION):
    # Pass the writer so losses actually reach TensorBoard (train() closes it when done).
    train(ds, model, loss_fn, optimizer, epochs=1, batch_size=16, scaler=scaler,
          training_divides=100, logger=writer)

In [None]:
# Final checkpoint of the whole pickled module (reload with torch.load).
# NOTE(review): a full-module pickle is tied to this code layout; model.state_dict() is
# more portable, but the torch.load in the model cell expects a full module.
os.makedirs("ckpt", exist_ok=True)
torch.save(model, "ckpt/impgpt-final-1.pt")

In [None]:
#start_x = torch.tensor(tokenizer.encode('Once upon a time')).unsqueeze(dim=0).to(device=device)
#print(start_x)
# Prompt ids for the tokenizer above, presumably derived via the commented line
# (101 looks like BERT's [CLS]) -- TODO confirm they still match the tokenizer in use.
start_ids = torch.tensor([[ 101,  100, 8644, 8224,  143, 8759]]).to(device)
max_length = 64
# train() leaves the model in training mode; switch to eval so train-only behaviour
# (e.g. dropout, if present) is off, and skip autograd bookkeeping while sampling.
model.eval()
with torch.no_grad():
    generated = model.generate(start_ids, max_length=max_length, top_k=64, temp=0.75)
print(generated)
txt = tokenizer.decode(generated[0].tolist(), skip_special_tokens=True)
print(generated.shape)
print(txt)

In [None]:
# Inspect the model's next-token distribution for the prompt, in eval mode without autograd.
model.eval()
with torch.no_grad():
    logits = model(start_ids)
# Assumes logits are (1, vocab) so squeeze yields a 1-D distribution -- TODO confirm.
prob = nn.functional.softmax(logits, dim=-1).cpu().squeeze()
token_max = torch.argmax(prob)
print(token_max)
fig, ax = plt.subplots()
ax.plot(prob)
ax.set(title="Next-token probabilities", xlabel="token id", ylabel="probability")
tokenizer.decode([token_max.item()])