In [None]:
import impulsegpt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from tqdm import tqdm
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
#import char_tokenizer

In [None]:
# Choose the compute backend: CUDA is preferred, Apple MPS is the fallback,
# and the CPU is used when neither accelerator is available.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
# Report which backend was picked.
print({'cuda': "Using CUDA", 'mps': "Using MPS", 'cpu': "Using CPU"}[device.type])

In [None]:
# Hyper-parameters for the model that is built further down the notebook.
# NOTE: `vocab` is only a placeholder here; the tokenizer cell overwrites it
# with the size of the tokenizer's vocabulary.
config = impulsegpt.Config()
for _name, _value in (
    ('ctx_len', 256),
    ('n_layers', 12),
    ('d_model', 768),
    ('n_heads', 12),
    ('vocab', 50000),
):
    setattr(config, _name, _value)

In [None]:
# Tokenizer used for both training and decoding.
# Alternative that was tried: AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-chinese')

# Batches are padded out to exactly `config.ctx_len` tokens and returned as
# PyTorch tensors.
collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding='max_length',
    max_length=config.ctx_len,
    return_tensors='pt',
)

# Size the model's vocabulary from the tokenizer (replaces the placeholder
# value set in the config cell).
config.vocab = len(tokenizer.vocab)
print(f"Model vocab set to: {config.vocab}, Embedding size: {config.d_model * config.vocab}")


In [None]:
# Load the TinyStories training split and tokenize it up front.
#
# The collator only pads features that are already encoded, and the training
# loop reads `row['input_ids']`, so the raw `text` column has to be converted
# to token ids here. Truncating to `config.ctx_len` keeps every row within the
# padded length, because the collator pads but does not truncate.
ds = load_dataset("roneneldan/TinyStories")
ds = ds['train']
ds = ds.map(
    lambda rows: tokenizer(rows['text'], truncation=True, max_length=config.ctx_len),
    batched=True,
    # Drop the raw columns so each example holds only tokenizer outputs.
    remove_columns=ds.column_names,
)

In [None]:
# Build a fresh model from the config and move it onto the selected device.
# To resume from a saved checkpoint instead: model = torch.load('ckpt/ts-64-1.pt')
model = impulsegpt.ImpulseGPT(config=config)
model = model.to(device)

# Display the layer / parameter overview as the cell output.
summary(model)

In [None]:
def train(dataset, model, loss_fn, optimizer, epochs:int, batch_size:int, training_divides:int, logger:SummaryWriter):
    """Train `model` on `dataset`, one shard at a time, predicting one token per step.

    For every batch the loop walks over the token positions: the prefix
    `input_ids[..., :t+1]` is fed to the model and the output is scored against
    the token at position `t+1`. An optimizer step is taken after every
    position, and the mean loss over the positions of a batch is logged.

    Relies on the notebook-level globals `collator` (batch collation),
    `device` (where tensors are moved) and `config` (checkpoint file name).

    Args:
        dataset: Dataset exposing `shard(num_shards=..., index=...)` and `len()`;
            its rows must collate into a batch with an `input_ids` tensor.
        model: Model called as `model(context)`; presumably returns logits for
            the token following `context` -- TODO confirm against impulsegpt.
        loss_fn: Callable `loss_fn(prediction, target)` returning a scalar loss.
        optimizer: Optimizer over the model's parameters.
        epochs: Number of passes over the full dataset.
        batch_size: Rows per batch.
        training_divides: Number of shards the dataset is split into per epoch;
            must be at least 1.
        logger: TensorBoard writer; it is closed when training finishes.

    Raises:
        ValueError: If `training_divides` is smaller than 1.
    """
    import os  # local import: only needed to create the checkpoint directory

    if training_divides < 1:
        raise ValueError("training_divides must be at least 1")

    model.train()
    print(f"Start training for {epochs} epochs with {len(dataset)} rows of data each.")

    # Create the checkpoint directory before training so the save at the end
    # of an epoch cannot fail on a missing folder.
    os.makedirs("ckpt", exist_ok=True)

    # Monotonic step counter for TensorBoard. The per-shard batch index restarts
    # at 0 for every shard, which would make the logged curves overwrite each other.
    global_step = 0

    for epoch in range(epochs):
        for chunk in range(training_divides):
            print(f"Training on {chunk+1} of {training_divides} data chunks")
            shard = dataset.shard(num_shards=training_divides, index=chunk)
            dataloader = DataLoader(dataset=shard, collate_fn=collator, batch_size=batch_size)
            pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1} of {epochs}")
            for batch, row in pbar:
                # Transfer the batch once; the slices below are taken on-device.
                input_ids = row['input_ids'].to(device)
                num_rows = input_ids.shape[1] - 1
                if num_rows < 1:
                    # Nothing to predict from a sequence of length <= 1.
                    continue

                step_loss = 0.0
                # NOTE(review): every position up to the collated length is used
                # as a target, which presumably includes padding tokens --
                # consider masking them out.
                for t in range(num_rows):
                    context = input_ids[..., :t+1]
                    y = input_ids[..., t+1]
                    y_hat = model(context)
                    loss = loss_fn(y_hat, y)
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()
                    step_loss += loss.item()
                step_loss /= num_rows

                global_step += 1
                logger.add_scalar('Loss', step_loss, global_step)
                pbar.set_postfix({'Loss':step_loss})
        # One checkpoint per epoch; `chunk` is the index of the last shard, so
        # later epochs overwrite the same file.
        torch.save(model, f"ckpt/ts-{config.ctx_len}-{chunk}.pt")

    logger.close()


In [None]:
# Objective, optimizer and TensorBoard writer for the training run.
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=5e-5)
writer = SummaryWriter()

# One epoch over the dataset, split into 1000 shards, 24 rows per batch.
train(
    dataset=ds,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    epochs=1,
    batch_size=24,
    training_divides=1000,
    logger=writer,
)

In [None]:
# Persist the whole trained model object (not just its state dict) as the final checkpoint.
torch.save(obj=model, f="ckpt/impgpt-final-1.pt")

In [None]:
# Sample a continuation from the trained model and decode it back to text.
#
# To build the prompt from text instead of hard-coded ids:
#start_x = torch.tensor(tokenizer.encode('Once upon a time')).unsqueeze(dim=0).to(device=device)
#print(start_x)
# Hard-coded prompt token ids, shape (1, 6); presumably the encoding of the
# prompt above without the trailing separator token -- TODO confirm.
start_ids = torch.tensor([[ 101,  100, 8644, 8224,  143, 8759]]).to(device)
max_length = 256

# train() leaves the model in training mode. Switch to eval mode so layers such
# as dropout (if the model has any) are disabled, and skip autograd bookkeeping
# while sampling.
model.eval()
with torch.no_grad():
    y = model.generate(start_ids, max_length=max_length, top_k=64, temp=0.75)

print(y)
txt = tokenizer.decode(y[0].tolist(), skip_special_tokens=True)
print(y.shape)
print(txt)

In [None]:
# Inspect the model's output for the token that follows the prompt.
#
# Run the forward pass in eval mode and without autograd: train() leaves the
# model in training mode, which would keep layers such as dropout (if any)
# active and make this distribution noisy.
model.eval()
with torch.no_grad():
    y = model(start_ids)

# Turn the output into probabilities over the last axis and move them to the
# CPU for plotting; squeeze() drops the singleton dimensions.
prob = nn.functional.softmax(y, dim=-1).cpu().detach().squeeze()
token_max = torch.argmax(prob)
print(token_max)
plt.plot(prob)

# Cell output: the most probable token, decoded to text.
tokenizer.decode([token_max.tolist()])