In [1]:
import impulsegpt
import torch
from torch import nn, functional


In [2]:
# Choose the compute device in order of preference: CUDA GPU, then Apple
# Metal (MPS), and finally the CPU when no accelerator is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
elif torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
else:
    device = torch.device('cpu')
    print("Using CPU")

Using CUDA


In [3]:
# Hyperparameters for the model built later in the notebook.
# NOTE(review): the meaning of each field is defined by impulsegpt.Config,
# which is not visible here -- confirm the notes below against that module.
config = impulsegpt.Config()
# Context length; train() passes this value to get_batch() as the window size.
config.ctx_len = 64
# Presumably the number of stacked transformer layers -- TODO confirm.
config.n_layers = 6
# Presumably the embedding / hidden width -- TODO confirm.
config.d_model = 512
# Presumably the number of attention heads -- TODO confirm.
config.n_heads = 8

In [4]:
data_dir = './train_data/wkz8.txt'

# Read the whole training corpus into memory as a single string.
with open(data_dir, 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Character-level vocabulary: one entry per distinct character in the corpus.
unique_chars = set(text)
num_unique_chars = len(unique_chars)

# Record the vocabulary size on the model configuration.
config.vocab = num_unique_chars

# Summarize the corpus; the characters are sorted only for display.
print(f'Length of text: {len(text)}')
print(f"Number of unique characters in the file: {num_unique_chars}")
print("Unique characters:", ''.join(sorted(unique_chars)))

Length of text: 87542
Number of unique characters in the file: 2149
Unique characters: 
 *.08altwx—‘’“”…、。一丁七万丈三上下不与丑专且世丘业东丝丢两严个丫中丸丹为主丽举乃久么义之乌乍乎乐乔乖乘九也书买乱了予争事二于亏云互五井些亡交亮亲人亿什仇今仍从仑仔他仗仙仞代令以仨仪们仰件价任份仿伍伏伐休众优伙会伟传伤伦伶伸似但位低住佑体何余佛作你佩佼使依侠侧便促俊俏俗保信俩修俯俱俺倒候倚借倦值倾假偈偎偏做停偷偿傲傻像僧儿兀元兄充先光克免兔入全八公六兮兰共关兴兵其具养兽内冉再冒冕写冠冤冥冬冰冲决况冷冻净凄准凉凌减凑凛凝几凡凤凭凰凳凶出击刀分切刑划刚初判利别刮到制刹刺刻前剑剥剧剩副割劈力劝办功加务劣动助努劫劲劳势勇勒勾包匆化北匹区医十千升午半华卑卒单卖南博占卫卯印危即却卷历厉压厌厚原厮去参又及友双反发取受变叛叠口古句另只叫叭叮可台史右叶号叹叽吁吃各吆合同名后吏吐向吓吗君吞吟否吧吩含听吭启吵吸吹吻吼吽呀呆告呜呢呦周味呵呸呼命咆咋和咏咐咒咔咕咚咦咧咪咬咱咳咽哀品哄哇哈哉响哎哑哗哟哥哦哧哩哪哭哮哲哼哽唉唐唠唤唬唯唱唳唵啃啄啊啥啦啪啷啸啼喂喃善喊喘喜喝喧喳喷喽嗔嗝嗡嗨嗬嗯嗵嗷嘎嘘嘛嘟嘤嘲嘴嘶嘻嘿噙噢器嚎嚓嚣嚼囚四回因团园困围固国图圆圈土圣在地场均坏坐坑块坚坠坡垂型垒垫埋城堂堆堵塌塔塞填境墓墙增墟壁壑壤士声壳壶处备复夏夕外多夜够大天太夫央失头夹夺奇奈奋奔套奘奥女奶她好如妃妄妇妈妖妙妨妹始姐姑姓姿威娘娲婉婢婶媳嫁嫩子孔字存孙季孤学孩孱宁它宇守安完宗官宙定宛宜宝实客宫宰害宴家容宽宿寂密寒寞察寸对寺寻导封射将尊小少尖尘尚尝尤就尸尺尽尾局屁层居屈屋屏屑展山岁岂岐岔岚岛岩岭岸峰峻崖崩崽嶽巉巡左巧巨差己已巴巾市布帅师希帘帝带席帮常帽幅幕干平年并幸幻幽广庄庆床序应底庙府庞废度座庭廊廖延建开异弃弄式引弟张弥弧弯弱弹强弼归当形彩影彼往径待很律徐徒得御微德心必忆忌忍忒志忘忙忧快念忽怀态怎怏怒怔怕怖怜思急性怨怪怯总恋恍恐恒恙恨恩恬恭息恳恶恼悄悉悔悟悠悦您悬悲情惊惑惜惧惨惬惮想惶惹愁愉意愕感愣愤愿慌慕慢慧憧憬憾懂懒戏成我戒或战戟戳戴所扁扇手才扎扑打扔托执扫扬扮扯扰扶找承把抓投抖抗折抚抛抢护报披抬抱抵抹押抽拂拄担拆拉拍拎拐拒拖拘招拜拣拥拦拨择拯拱拳拼拽拾拿持挂指按挑挖挚挠挡挣挤挥挨振挺挽捂捅捉捏捕捞捡换捣捧据捶捷掀

In [5]:
# Assign token ids from the *sorted* character list so that every character
# receives the same id on every run. Enumerating the set directly follows its
# hash order, which Python randomizes per interpreter process for strings, so
# ids would change after a kernel restart and no longer match saved weights.
character_to_index = {char: i for i, char in enumerate(sorted(unique_chars))}
# Inverse lookup, derived from the forward mapping so the two cannot disagree.
index_to_character = {i: char for char, i in character_to_index.items()}

def encode(x):
    """Translate each character of ``x`` into its integer id.

    Looks every character up in the module-level ``character_to_index``
    mapping and returns the ids as a list, in input order. A character that
    is missing from the mapping raises ``KeyError``.
    """
    token_ids = []
    for symbol in x:
        token_ids.append(character_to_index[symbol])
    return token_ids

def decode(x):
    """Translate each id in ``x`` back into its character.

    Looks every id up in the module-level ``index_to_character`` mapping and
    returns a list of characters (not a joined string), in input order. An id
    that is missing from the mapping raises ``KeyError``.
    """
    symbols = []
    for token_id in x:
        symbols.append(index_to_character[token_id])
    return symbols

def decode_tensor(x):
    """Decode an iterable of single-element tensors into one string.

    Each element is converted to a Python scalar with ``.item()``, looked up
    in the module-level ``index_to_character`` mapping, and the resulting
    characters are concatenated in order.
    """
    symbols = []
    for element in x:
        symbols.append(index_to_character[element.item()])
    return ''.join(symbols)

In [6]:
# Encode the full corpus once and keep the resulting id tensor on the
# selected device.
encoded_text = torch.tensor(encode(text), dtype=torch.int, device=device)

# Create train-validation split (90-10): the first 90% of the token stream is
# used for training and the remaining tail is held out for validation.
n = int(0.9 * len(encoded_text))
train_data, val_data = encoded_text[:n], encoded_text[n:]

print(f"Train data length: {len(train_data)}")
print(f"Validation data length: {len(val_data)}")

Train data length: 78787
Validation data length: 8755


In [7]:
def get_batch(data, batch_size, context_length):
    """Sample ``batch_size`` random windows of ``context_length + 1`` tokens.

    Args:
        data: 1-D tensor of token ids to sample from.
        batch_size: number of windows to draw (with replacement).
        context_length: number of context tokens per window; each window
            carries one extra trailing token so that every context position
            has a following token available.

    Returns:
        A tensor of shape ``(batch_size, context_length + 1)`` whose rows are
        contiguous slices of ``data``.

    Raises:
        RuntimeError: propagated from ``torch.randint`` / ``torch.stack`` when
            ``data`` holds fewer than ``context_length + 1`` tokens or
            ``batch_size`` is 0.
    """
    window = context_length + 1
    # torch.randint's upper bound is exclusive, so `len(data) - context_length`
    # makes every start whose window still fits eligible, including the final
    # one (start == len(data) - window). All starts are drawn in one call and
    # converted to plain ints so slicing never needs a per-row tensor index.
    starts = torch.randint(0, len(data) - context_length, (batch_size,)).tolist()
    return torch.stack([data[start:start + window] for start in starts])

def get_data(batches):
    """Expand a batch of token windows into per-position training pairs.

    Args:
        batches: 2-D tensor of token ids with one window per row.

    Returns:
        A ``(context, label)`` tuple of two equal-length lists with one entry
        for each ``t`` in ``range(width - 1)``, where ``width`` is the number
        of columns of ``batches``:

        * ``context[t]`` holds the first ``t + 1`` tokens of every row,
          shape ``(rows, t + 1)``, same dtype as ``batches``;
        * ``label[t]`` holds token ``t + 1`` of every row, shape ``(rows,)``,
          as int64.

        All returned tensors live on the module-level ``device``.

    Raises:
        ValueError: if ``batches`` is not 2-D.
    """
    # Move the batch once instead of once per time step.
    batches = batches.to(device)
    _, width = batches.shape
    context = []
    label = []
    for t in range(width - 1):
        # Column slicing replaces stacking row by row in Python; the explicit
        # contiguous clone keeps each context an independent, densely laid
        # out tensor, as torch.stack produced before.
        context.append(batches[:, :t + 1].clone(memory_format=torch.contiguous_format))
        # Cast on the current device. `.type(torch.LongTensor)` always yields
        # a CPU tensor, which forced a device -> CPU -> device copy per step.
        label.append(batches[:, t + 1].to(torch.long).contiguous())
    return context, label


In [8]:
# Build the network from the notebook configuration, then move it to the
# selected device.
model = impulsegpt.ImpulseGPT(config=config)
model = model.to(device)

# Cross-entropy loss, placed on the same device as the model.
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [9]:
def train(dataset, model, loss_fn, optimizer, steps:int, num_batch=4, context_length=None):
    """Train ``model`` on random windows drawn from ``dataset``.

    Every step samples one batch with ``get_batch``, expands it into
    per-position (context, label) pairs with ``get_data``, and performs one
    optimizer update per pair. The mean loss over the pairs of a step is
    printed for every 10th step.

    Args:
        dataset: 1-D tensor of token ids, as accepted by ``get_batch``.
        model: module called as ``model(context)``; its output is passed to
            ``loss_fn`` together with the matching label tensor.
        loss_fn: callable ``loss_fn(pred, label)`` returning a scalar loss.
        optimizer: optimizer over ``model``'s parameters.
        steps: number of batches to train on.
        num_batch: number of windows per batch.
        context_length: window size passed to ``get_batch``; defaults to the
            module-level ``config.ctx_len`` when left as ``None``.

    Returns:
        None. Progress is reported on stdout only.
    """
    if context_length is None:
        context_length = config.ctx_len
    model.train()
    print(f"Start training with {steps} steps")
    for step in range(steps):
        batch = get_batch(dataset, num_batch, context_length)
        context, label = get_data(batch)
        step_loss = 0
        for ctx, target in zip(context, label):
            # Clear gradients *before* the backward pass. Clearing only after
            # optimizer.step() lets gradients left behind by an interrupted
            # run of this cell leak into the first update.
            optimizer.zero_grad()
            pred = model(ctx)
            loss = loss_fn(pred, target)

            loss.backward()
            optimizer.step()
            step_loss += loss.item()
        # Mean loss over all per-position updates of this step.
        step_loss = step_loss / len(label)
        if step % 10 == 0:
            print(f"Step {step}: Loss = {step_loss}")
        
    


In [11]:
# 101 steps so that the final logged step is step 100 (train() prints the
# loss on every 10th step, starting from step 0).
train(train_data, model, loss_fn, optimizer, steps=101, num_batch=4)

Start training with 101 steps
Step 0: Loss = 5.827203020453453
Step 10: Loss = 5.767085198312998
Step 20: Loss = 6.020226635038853
Step 30: Loss = 5.755541533231735
Step 40: Loss = 6.043134070932865
Step 50: Loss = 5.698156736791134
Step 60: Loss = 5.335667874664068
Step 70: Loss = 5.3639691062271595
Step 80: Loss = 5.080534815788269
Step 90: Loss = 5.090525045990944
Step 100: Loss = 5.5048802718520164
