In [1]:
import os
import torch
import pickle
import numpy as np
import torch.nn as nn

from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import PreTrainedTokenizerFast

MAX_SEQ_LEN = 256
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="models/tokenizer.json",
    pad_token="[PAD]",
    unk_token="[UNK]",
    eos_token="<|endoftext|>",
    max_len = MAX_SEQ_LEN,
    add_prefix_space=False
)

BATCH_SIZE = 64
VOCAB_SIZE = tokenizer.vocab_size
DEVICE="cuda"

train_file = "data/train-sampled"
test_file = "data/valid"

In [2]:
class TinyStoriesDataset(Dataset):
    def __init__(
        self,  
        input_file,
        tokenizer,
        seq_len,
        device="cuda",
        lazy_load=True
    ):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.DEVICE = device
        self.lazy_load = lazy_load
        
        self._load_data(input_file, lazy_load)

    def _load_data(self, file, lazy_load):
        memmap_flag = "r" if lazy_load else None
        self.data = np.load(f"{file}.npy", mmap_mode = memmap_flag)
        with open(f"{file}.pickle", "rb") as f:
            self.idx2pos = pickle.load(f)
    
    def __len__(self):
        return len(self.idx2pos)

    def __getitem__(self, idx):
        i = self.idx2pos[idx]
        x = self.data[i:i+self.seq_len]
        
        if (i+self.seq_len+1)>=len(self.data):
            y = np.pad(x[1:], pad_width=(0,1), mode="constant", constant_values=0)
        else:
            next_token = self.data[i+self.seq_len+1]
            if (
                x[-1] in (self.tokenizer.pad_token_id, self.tokenizer.eos_token_id) and 
                next_token!=self.tokenizer.pad_token_id
            ):
                y = np.pad(x[1:], pad_width=(0,1), mode="constant", constant_values=0)
            else:
                y = self.data[i+1:i+1+self.seq_len]

        return (
            torch.from_numpy(x.astype(np.int64)).to(self.DEVICE),
            torch.from_numpy(y.astype(np.int64)).to(self.DEVICE)
        )
        

In [3]:
test_loader = DataLoader(
    TinyStoriesDataset(
        input_file=test_file, 
        tokenizer=tokenizer,
        seq_len=MAX_SEQ_LEN,
        device=DEVICE,
        lazy_load=True
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

train_loader = DataLoader(
        TinyStoriesDataset(
        input_file=train_file,
        tokenizer=tokenizer,
        seq_len=MAX_SEQ_LEN,
        device=DEVICE,
        lazy_load=False
    ),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads):
        super(TransformerBlock, self).__init__()
        self.norm_layer_1 = nn.LayerNorm(embed_dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim, 
            num_heads=n_heads,
            bias=False,
            dropout=0.1,
            batch_first=True
        )

        self.norm_layer_2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*8//3,),
            nn.Linear(embed_dim*8//3, embed_dim,),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # X --> (BATCH_SIZE, CONTEXT_LENGTH, EMBED_DIM) 
        x_norm = self.norm_layer_1(x)
        attn_mask = torch.triu(
            torch.zeros(
                (x_norm.size(1), x_norm.size(1))
            ), 
            diagonal=0
        ).to(x.device)
        attn_mask[attn_mask>0] = -torch.inf
        x_norm, _ = self.attention(
            x_norm, 
            x_norm, 
            x_norm, 
            attn_mask=attn_mask, 
            is_causal=True
        )
        x = x + x_norm

        x_norm = self.norm_layer_2(x)
        x_norm = self.ffn(x_norm)
        x_norm = self.dropout(x_norm)

        return x + x_norm

In [5]:
class TransformerLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len=MAX_SEQ_LEN, n_layers=5, n_heads=4):
        super(TransformerLM, self).__init__()
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.position_emb = nn.Embedding(max_seq_len, embed_dim)

        self.dropout = nn.Dropout(0.1)

        self.transfomers = nn.Sequential(
            *[
                TransformerBlock(
                    embed_dim=embed_dim,
                    n_heads=n_heads
                ) for _ in range(n_layers)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x:torch.Tensor):
        x = self.dropout(
            self.token_emb(x) + self.position_emb(torch.arange(x.size(1), device=x.device))
        )
        x = self.transfomers(x)
        x = self.norm(x)
        x = self.out(x)

        return x.reshape((x.shape[0], x.shape[2], x.shape[1]))

In [6]:
def init_weights(layer_in):
    if isinstance(layer_in, nn.Linear):
        nn.init.xavier_uniform_(layer_in.weight, gain=nn.init.calculate_gain('relu'))
        if layer_in.bias != None:
            nn.init.zeros_(layer_in.bias)

In [7]:
tlm = TransformerLM(
    vocab_size=VOCAB_SIZE,
    embed_dim=256,
    max_seq_len=MAX_SEQ_LEN,
    n_layers=4,
    n_heads=8,
).to(DEVICE)

tlm.apply(init_weights)
# tlm.load_state_dict(
#     torch.load("models/model_0_0", weights_only=True)
# )
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(tlm.parameters(), lr=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20000, eta_min=6*1e-6)

In [8]:
def train_model(
    model,
    train_loader,
    test_loader,
    loss_fn,
    optimizer, 
    scheduler,
    n_epochs,
    last_epoch=0,
    best_vloss=1_000_000,
    model_id=0,
    callback_no_upgrade=10,
    n_accumulation_steps = 4,
):
    try:
        n_epochs_no_upgrade = 0
        scaler = torch.cuda.amp.GradScaler()
        
        for epoch in range(last_epoch, last_epoch + n_epochs):
            model.train()
            running_loss = 0.0
            
            n_batches = len(train_loader)
            progress_bar = tqdm(
                enumerate(train_loader), 
                total=n_batches, 
                leave=True,
                desc=f"Epoch [({epoch + 1} / {last_epoch + n_epochs})]"
            )
            
            for i, (inputs, targets) in progress_bar:
                

                with torch.cuda.amp.autocast():
                    outputs = model(inputs)  # Forward pass
                    loss = loss_fn(outputs, targets) / n_accumulation_steps  # Compute loss
                # outputs = model(inputs)
                # loss = loss_fn(outputs, targets) / n_accumulation_steps

                # loss.backward()
                scaler.scale(loss).backward()  # Backpropagation
                if (i+1) % n_accumulation_steps == 0 or (i+1)==n_batches:
                    scaler.step(optimizer)#.step()  # Optimization step
                    if scheduler!=None: scheduler.step()
                    scaler.update()

                    # optimizer.step()
                    # if scheduler!=None: scheduler.step()
                    # optimizer.zero_grad()  # Zero the gradients
                
                running_loss+=loss.item()*n_accumulation_steps
                    
                if (i+1) == n_batches:
                    # finishing last batch ... calc validation loss
                    progress_bar.set_postfix(loss=running_loss/(i+1), val_loss="Calculating...")
                    vloss = 0.0
                    model.eval()
                    with torch.no_grad():
                        for inputs, targets in test_loader:
                            with torch.amp.autocast("cuda"):
                                outputs = model(inputs)
                                loss = loss_fn(outputs, targets)

                            vloss += loss.item()
                            
                    vloss=vloss/len(test_loader)
                    progress_bar.set_postfix(loss=running_loss/(i+1), val_loss=vloss)
                    
                else:
                    progress_bar.set_postfix(loss=running_loss/(i+1))

            if vloss < best_vloss:
                best_vloss = vloss
                os.makedirs("models", exist_ok=True)
                model_path = 'models//model_{}_{}'.format(model_id, epoch)
                torch.save(model.state_dict(), model_path)

            else:
                n_epochs_no_upgrade += 1

            if n_epochs_no_upgrade >= callback_no_upgrade:
                break

        return epoch, best_vloss
    except KeyboardInterrupt:
        return epoch, best_vloss

In [9]:
last_epoch=0;best_loss=1_000_000

In [10]:
last_epoch, best_loss = train_model(
    model=tlm,
    train_loader=train_loader,
    test_loader=test_loader,
    loss_fn=criterion,
    optimizer=optimizer,
    scheduler=None,
    n_epochs=3,
    last_epoch=last_epoch,
    best_vloss=best_loss,
    n_accumulation_steps=8,
    callback_no_upgrade=3
)
print(last_epoch, best_loss)

Epoch [(1 / 3)]: 100%|██████████| 64940/64940 [11:26:23<00:00,  1.58it/s, loss=3.5, val_loss=1.87]          
Epoch [(2 / 3)]: 100%|██████████| 64940/64940 [11:21:16<00:00,  1.59it/s, loss=1.67, val_loss=0.783]         
Epoch [(3 / 3)]: 100%|██████████| 64940/64940 [11:17:16<00:00,  1.60it/s, loss=1.21, val_loss=0.783]         


2 0.7830687470494941


In [11]:
best_loss

0.7830687470494941

In [12]:

def _get_next_token(
        model,
        tokens,
        next_token_pos,
        k=10
    ):
    
    with torch.no_grad():
        pred_tokens = model(tokens)[:,:,next_token_pos]
        top_k = torch.argsort(pred_tokens,dim=1,descending=True)[:,:k]
        pred_tokens = pred_tokens[:, top_k[0]]
        probs = torch.softmax(pred_tokens, dim=1)[0]
        next_token = top_k[:, probs.multinomial(1)[0]]
        #next_token = torch.softmax(pred_tokens[:, :, next_token_pos], dim=1).argmax().reshape(-1,1)
    return next_token

def get_story(model, tokenizer, seed_text, max_generated,device="cuda", k=10):
    base = tokenizer(
        seed_text,
        padding="max_length",
        truncation=False,
        max_length=MAX_SEQ_LEN-1,
        return_tensors="pt",
    )
    
    n_generated = 0
    next_token_pos = base["attention_mask"].sum().item() - 1

    tokens = (base.input_ids * base.attention_mask).to(device)
    window_filter_pos = max(0, next_token_pos - MAX_SEQ_LEN + 1)
    
    while n_generated < max_generated:
        next_token = _get_next_token(
            model,
            tokens[window_filter_pos:],
            next_token_pos,
            k=k
        )

        if next_token == tokenizer.eos_token_id:
            break

        if next_token_pos < MAX_SEQ_LEN - 1:
            tokens[0, next_token_pos+1] = next_token
            next_token_pos += 1
        else:
            tokens = torch.cat([tokens, next_token.reshape(-1,1)], dim=1)
            window_filter_pos += 1
        n_generated += 1
        
    return tokenizer.decode(tokens.squeeze()[:next_token_pos+1])
    

In [19]:
temp = get_story(tlm, tokenizer, "once upon a time", 30, k=2)
with open("temp.txt", "w", encoding="utf-8") as f:
    f.write(temp)

In [14]:
tokenizer.pad_token_id

2

In [20]:
ord("恩")

24681

In [21]:
tokenizer.encode("恩")

[3, 461, 3]

In [26]:
tokenizer.decode([3,461,3])

' 恩  '

In [24]:
tokenizer.decode([3])

' '

In [25]:
tokenizer.decode(461)

'##恩'