In [1]:
import torch
import torch.nn as nn
import os

from tokenizers import ByteLevelBPETokenizer
from tokenizers import AddedToken

from cs336_basics.myModule import toy_Dataloader
from cs336_basics.myModule import toy_Transformer_lm
from cs336_basics.myOptimizer import toy_AdamW
from cs336_basics.myFunctional import toy_cross_entry, slow_generate, save_check_point , load_check_point, cosine_warm_up_lr, toy_grad_clip



In [2]:
# Data / artifact locations. Everything hangs off one project root, which can be
# overridden with the CS336_A1_DIR environment variable on another machine.
project_dir = os.environ.get("CS336_A1_DIR", "/root/workspace/cs336/assignment1")
vocab_path = os.path.join(project_dir, "my_output", "vocab.json")
merges_path = os.path.join(project_dir, "my_output", "merges.txt")
train_data_path = os.path.join(project_dir, "data", "TinyStoriesV2-GPT4-train.txt")
val_data_path = os.path.join(project_dir, "data", "TinyStoriesV2-GPT4-valid.txt")
# Keep the trailing separator: the training cell builds checkpoint paths with
# `weight_path + "model_<n>.pt"`.
weight_path = os.path.join(project_dir, "my_output", "weights", "")


In [3]:
# Optional, one-off step: train a 10k-vocab byte-level BPE tokenizer on the
# TinyStories training split and save its vocab/merges files to my_output/.
# Disabled because the next cell loads the saved files from vocab_path /
# merges_path; uncomment only to regenerate them.
# files = ["/root/workspace/cs336/assignment1/data/TinyStoriesV2-GPT4-train.txt"]
# tokenizer = ByteLevelBPETokenizer()
# tokenizer.train(
#     files=files,
#     vocab_size=10_000,
#     min_frequency=2,
#     special_tokens=["<|endoftext|>"],
# )

# tokenizer.save_model("/root/workspace/cs336/assignment1/my_output")

In [4]:
# Load the trained byte-level BPE tokenizer from disk and register the
# document separator as a special token; the cell displays the vocab size.
eot_token = AddedToken("<|endoftext|>", special=True)
tok = ByteLevelBPETokenizer(vocab_path, merges_path)
tok.add_special_tokens([eot_token])
tok.get_vocab_size()

10000

In [5]:
# Model / training hyperparameters.
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Seed torch's RNG before the model is constructed so weight init is
# reproducible across Restart & Run All (and batch sampling too, if
# toy_Dataloader samples with torch — TODO confirm).
seed = 42
torch.manual_seed(seed)

vocab_size = tok.get_vocab_size()
max_context_len = 256
d_model = 512
d_ff = 1344
theta = 10000  # NOTE(review): presumably the RoPE base; confirm in toy_Transformer_lm
num_layer = 4
num_heads = 16
# Multi-head attention splits d_model evenly across heads.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

batch_size = 32
train_round = 40000  # number of training iterations (one optimizer step each)

In [6]:
# Build the train/val batch loaders, the language model and its optimizer.
# NOTE(review): the 4th toy_Dataloader argument (4*batch_size) looks like a
# prefetch/buffer size — confirm against cs336_basics.myModule.
train_data_loader = toy_Dataloader(train_data_path, tok, max_context_len,4*batch_size ,device)
val_data_loader = toy_Dataloader(val_data_path, tok, max_context_len, 4*batch_size, device)
# device is passed to the constructor and the model is additionally moved with .to(device).
model = toy_Transformer_lm(vocab_size,max_context_len,d_model,num_layer,num_heads,d_ff,theta,device).to(device)
# Only the parameters are passed, so every other optimizer setting is the
# toy_AdamW default; the training loop supplies a scheduled lr to each
# opt.step(lr_new=...).
opt = toy_AdamW(model.parameters())

In [7]:
# Learning-rate schedule passed to cosine_warm_up_lr: warm-up to max_lr over
# warm_up_iter steps, then (presumably cosine) decay toward min_lr.
warm_up_iter = 500
max_lr, min_lr = 1e-3, 4e-4
# NOTE(review): if cosine_warm_up_lr treats this argument as the absolute
# iteration at which annealing ends (rather than the cosine duration), the last
# warm_up_iter training steps run at min_lr — confirm this is intended.
cos_end_iter = train_round - warm_up_iter

# Prompt and length handed to slow_generate for the periodic qualitative samples.
test_contetxt = "Once upon a time, there is a cat Tina"
generate_len = 50

In [None]:
# Train loop: one optimizer step per iteration with a scheduled learning rate.
# Every 100 steps: log train/val loss and the current lr, and print a sample.
# Every 10000 steps (including step 0): save a checkpoint; a final one at the end.
# Make sure the checkpoint directory exists (in case save_check_point does not
# create it itself).
os.makedirs(weight_path, exist_ok=True)
# `step` rather than `iter`, so the builtin iter() is not shadowed in later cells.
for step in range(train_round):
    data, target = train_data_loader.get_batch(batch_size)
    out = model(data)
    loss = toy_cross_entry(out, target)
    opt.zero_grad()
    loss.backward()
    # Clip BEFORE the optimizer step. Clipping afterwards has no effect on the
    # update just applied, and those grads are wiped by the next zero_grad().
    toy_grad_clip(model.parameters(), device=model.device)
    opt.step(lr_new=cosine_warm_up_lr(step, max_lr, min_lr, warm_up_iter, cos_end_iter))
    if step % 100 == 0:
        # eval() turns off train-only behaviour (e.g. dropout, if any) for
        # validation and sampling; train() restores it before the next step.
        model.eval()
        with torch.no_grad():
            val_data, val_target = val_data_loader.get_batch(batch_size)
            val_out = model(val_data)
            val_loss = toy_cross_entry(val_out, val_target)
            print("iter:{},loss:{},val_loss:{},lr:{}".format(
                step, loss.item(), val_loss.item(), opt.param_groups[0]["lr"]))
            print(slow_generate(model, tok, test_contetxt, generate_len))
        model.train()
    if step % 10000 == 0:
        save_path = os.path.join(weight_path, "model_{}.pt".format(step))
        save_check_point(model, opt, step, save_path)
train_data_loader.close()
model.eval()
with torch.no_grad():
    print(slow_generate(model, tok, test_contetxt, generate_len))
model.train()
save_path = os.path.join(weight_path, "model_final.pt")
save_check_point(model, opt, train_round, save_path)



iter:0,loss:9.377883911132812,val_loss:9.390636444091797,lr:0.0
Once upon a time, there is a cat Tina travels Buster lar rece pins prevented behavedaf clues natureantha rags resist traveled limp expert che eaten0 jealous pe heels filletsGrandma battery�aier cobwebWh response trembledChloe foroon Ph war bou�byCharlie eveniday sad Playing followed Zoom temucky
iter:100,loss:5.233175277709961,val_loss:5.123861312866211,lr:0.0002
Once upon a time, there is a cat Tina find. yet red to very snakes pr remembering tooth head. From that with you and instr to should the little back. He top a", you see with says found in theShe at idea. Can sorry of the wagged going. oak will
iter:200,loss:3.6209490299224854,val_loss:3.5198559761047363,lr:0.0004
Once upon a time, there is a cat Tina named crayons without Uni what. It lived blow Anna, it flew away. If there was story hard to play with a doll," remain said Sally. "Thank you, only a long train from in the story. weighing to their mom and the months