In [None]:
import impulsegpt
import torch
from torch import nn
from torchinfo import summary
from tqdm import tqdm
from matplotlib import pyplot as plt

In [None]:
# device = torch.device('cpu')
# if torch.cuda.is_available():
#     device = torch.device('cuda')
#     print("Using CUDA")
# elif torch.backends.mps.is_available():
#     device = torch.device('mps')
#     print("Using MPS")
# else:
#     print("Using CPU")

In [None]:
# Pick the best available accelerator instead of hard-coding 'cuda', which makes
# every later `.to(device)` fail on machines without an NVIDIA GPU.
# Order of preference: CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Model hyper-parameters: 64-token context window, 6 layers, width 512, 8 heads.
config = impulsegpt.Config()
for field_name, field_value in (
    ('ctx_len', 64),
    ('n_layers', 6),
    ('d_model', 512),
    ('n_heads', 8),
):
    setattr(config, field_name, field_value)

In [None]:
# Alternative corpus previously used: 'D:/dataset/tiny.txt'
data_dir = './train_data/wkz8.txt'

with open(data_dir, 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: one token per distinct character of the raw text;
# its size is what the model's vocabulary is configured with.
unique_chars = set(text)
config.vocab = num_unique_chars = len(unique_chars)

print(f'Length of text: {len(text)}')
print(f"Number of unique characters in the file: {num_unique_chars}")
print("Unique characters:", ''.join(sorted(unique_chars)))

In [None]:
# Collapse every run of whitespace (newlines, tabs, full-width U+3000 spaces, ...)
# into a single ASCII space so the model sees one uniform separator.
text = " ".join(text.split())
print(len(text))

# The vocabulary was first built from the *raw* text. Normalisation removes some
# characters (e.g. '\n') and can introduce ' ' when the raw text had none, which
# would make `encode` raise KeyError. Rebuild it from the text we actually train on
# so the vocabulary (and the model's configured vocab size) match exactly.
unique_chars = set(text)
num_unique_chars = len(unique_chars)
config.vocab = num_unique_chars
print(f"Vocabulary size after whitespace normalisation: {num_unique_chars}")

char = ' '
count = text.count(char)
print(f"number of char:{char} is {count}")


In [None]:
# Sort the vocabulary so the char <-> id mapping is identical on every kernel
# restart: iteration order of a `set` of str depends on hash randomisation
# (PYTHONHASHSEED), so enumerating the set directly yields a different mapping each
# session and previously trained weights would no longer line up with it.
vocab_chars = sorted(unique_chars)
character_to_index = {ch: idx for idx, ch in enumerate(vocab_chars)}
index_to_character = dict(enumerate(vocab_chars))


def encode(x):
    """Map a string (or iterable of characters) to a list of integer token ids.

    Raises:
        KeyError: if a character is not in the vocabulary.
    """
    return [character_to_index[ch] for ch in x]


def decode(x):
    """Map an iterable of token ids (ints or 0-d tensors) to a list of characters."""
    return [index_to_character[int(idx)] for idx in x]


def decode_tensor(x):
    """Decode a 1-D tensor (or a list of 0-d tensors / ints) of token ids into a string."""
    # int() works for both Python ints and 0-d tensors, unlike `.item()`.
    return ''.join(index_to_character[int(idx)] for idx in x)

In [None]:
encoded_text = torch.tensor(encode(text), dtype=torch.int, device=device)

# Hold out the final 10% of the corpus for validation (90/10 split, no shuffling).
split_at = int(0.9 * len(encoded_text))
train_data, val_data = encoded_text[:split_at], encoded_text[split_at:]

for split_name, split_tensor in (("Train", train_data), ("Validation", val_data)):
    print(f"{split_name} data length: {len(split_tensor)}")

In [None]:
def get_batch(data, batch_size, context_length):
    """Sample `batch_size` random contiguous windows of `context_length + 1` tokens.

    Args:
        data: 1-D tensor of token ids.
        batch_size: number of windows to sample.
        context_length: tokens of context per window; each window carries one extra
            token so the final position has a next-token target.

    Returns:
        Tensor of shape (batch_size, context_length + 1) with `data`'s dtype/device.

    Raises:
        ValueError: if `data` is shorter than one window.
    """
    # A window starting at i covers i .. i + context_length, so valid starts are
    # 0 .. len(data) - context_length - 1 (randint's upper bound is exclusive).
    num_starts = len(data) - context_length
    if num_starts <= 0:
        raise ValueError(
            f"data of length {len(data)} is too short for windows of {context_length + 1} tokens"
        )
    # Sample on the default CPU generator (seedable via torch.manual_seed), then move.
    starts = torch.randint(0, num_starts, (batch_size,)).to(data.device)
    offsets = torch.arange(context_length + 1, device=data.device)
    # Broadcast to a (batch_size, context_length + 1) index grid and gather in one op.
    return data[starts.unsqueeze(1) + offsets]

def get_data(batches):
    """Expand windows into (prefix, next-token) training pairs, one per prefix length.

    Args:
        batches: 2-D tensor of shape (num_batch, width) of token ids.

    Returns:
        (context, label): two lists of length width - 1. context[t] is
        batches[:, :t+1] (same dtype as `batches`) and label[t] is batches[:, t+1]
        as int64; all tensors are placed on the global `device`.
    """
    _, seq_len = batches.shape  # also enforces a 2-D input
    context = []
    label = []
    for t in range(seq_len - 1):
        # .contiguous() keeps the prefixes dense copies, as the original stacking produced.
        context.append(batches[:, :t + 1].contiguous().to(device))
        # Convert dtype and device in one call; `.type(torch.LongTensor)` would first
        # copy a GPU tensor to the CPU and then back again.
        label.append(batches[:, t + 1].to(device=device, dtype=torch.long))
    return context, label


In [None]:
# Seed torch's global RNG so weight initialisation and batch sampling are
# reproducible under Restart & Run All.
SEED = 1337
torch.manual_seed(SEED)

LEARNING_RATE = 5e-5

model = impulsegpt.ImpulseGPT(config=config).to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
summary(model)

In [None]:
def train(dataset, model, loss_fn, optimizer, steps: int, num_batch=4, context_length=None):
    """Train `model` on random windows of `dataset` with next-token prediction.

    Each step samples `num_batch` windows via `get_batch`, expands them with
    `get_data` into one (prefix, next-token) pair per prefix length, and takes one
    optimizer step per pair.

    Args:
        dataset: 1-D tensor of token ids to sample windows from.
        model: module called on each (num_batch, prefix_len) prefix; its output is
            passed to `loss_fn` with a (num_batch,) label tensor (presumably logits
            of shape (num_batch, vocab) — confirm against ImpulseGPT).
        loss_fn: callable(pred, label) -> scalar loss tensor.
        optimizer: optimizer over `model`'s parameters.
        steps: number of sampled batches.
        num_batch: windows per batch.
        context_length: window length; defaults to the global `config.ctx_len`.

    Returns:
        List with one entry per step: the mean loss over that step's prefixes.
    """
    if context_length is None:
        context_length = config.ctx_len
    model.train()
    print(f"Start training with {steps} steps")
    history = []
    pbar = tqdm(range(steps))
    for _ in pbar:
        batch = get_batch(dataset, num_batch, context_length)
        context, label = get_data(batch)
        step_loss = 0.0
        for prefix, target in zip(context, label):
            # Clear gradients before backward (not only after step): if an earlier run
            # was interrupted between backward() and zero_grad(), stale gradients
            # would otherwise be folded into this update.
            optimizer.zero_grad()
            loss = loss_fn(model(prefix), target)
            loss.backward()
            optimizer.step()
            step_loss += loss.item()
        step_loss /= len(label)
        history.append(step_loss)
        pbar.set_postfix({'Loss:': step_loss})
    return history
        
    


In [None]:
loss_history = train(
    train_data,
    model,
    loss_fn,
    optimizer,
    steps=128,
    num_batch=64,
)

In [None]:
# train() leaves the model in training mode; switch to eval so any training-only
# behaviour (e.g. dropout, if ImpulseGPT uses it) is disabled while sampling.
model.eval()
# Prompt of shape (1, prompt_len): a batch containing one sequence.
start_x = torch.tensor([encode('我')], device=device)

max_length = 60
with torch.no_grad():  # inference only: no autograd graph needed
    y = model.generate(start_x, max_length=max_length, top_k=32)
txt = decode_tensor(y[0])
print(y.shape)
print(txt)

In [None]:
# Inspect the model's next-token distribution for the prompt. Inference only, so
# use eval mode and skip autograd tracking.
model.eval()
with torch.no_grad():
    y = model(start_x)
# Softmax over the last axis gives probabilities; squeeze drops size-1 axes
# (presumably (1, vocab) -> (vocab,) — confirm against ImpulseGPT's output shape).
prob = nn.functional.softmax(y, dim=-1).squeeze().cpu()
token_max = torch.argmax(prob)
print(token_max)

fig, ax = plt.subplots()
ax.plot(prob)
ax.set(title="Next-token probabilities for the prompt", xlabel="token id", ylabel="probability")
plt.show()

decode_tensor([token_max])