In [None]:
import tokenizers.processors


tokenizer = tokenizers.Tokenizer.from_file("data/icons/vocab/tokenizer.json")
# Wrap every encoded sequence as "<SOS> ... <EOS>". The special-token ids are
# looked up in the vocabulary rather than hard-coded, so the ids inserted by the
# post-processor always match the ids the model config reads via token_to_id.
sos_id = tokenizer.token_to_id("<SOS>")
eos_id = tokenizer.token_to_id("<EOS>")
assert sos_id is not None and eos_id is not None, "tokenizer.json lacks <SOS>/<EOS>"
tokenizer.post_processor = tokenizers.processors.TemplateProcessing(
    single="<SOS> $A <EOS>",
    special_tokens=[("<SOS>", sos_id), ("<EOS>", eos_id)],
)

In [None]:
# Print all tokens ordered by id. get_vocab() builds a fresh dict, so call it
# once instead of once per token inside the sort key (which was O(n^2)).
print([token for token, _ in sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])])

In [None]:
import torch
import torch.nn as nn

In [None]:
import os

class Dataloader:
    """Flat token-stream loader for the icon outline dataset.

    On construction every file under ``<data_dir>/train`` (or ``<data_dir>/test``)
    is encoded with the global ``tokenizer``. The ids are written to
    ``<data_dir>/<split>-corpus.txt`` and read back into ``self.corpus`` as one
    flat list of ints. ``get_batch`` samples random contiguous windows from it.

    Args:
        use_train: Load the ``train`` split if True, otherwise ``test``.
        data_dir: Directory holding the ``train``/``test`` folders; the corpus
            file is written here too.
    """

    def __init__(self, use_train: bool, data_dir: str = "data/icons/outline"):
        split = "train" if use_train else "test"
        self.root = os.path.join(data_dir, split)
        # os.listdir order is arbitrary; sort so the corpus layout is reproducible.
        self.files = sorted(os.listdir(self.root))
        self.use_train = use_train
        self.max_size = 0  # token length of the longest encoded file
        self.file = None  # name of that longest file
        corpus_path = os.path.join(data_dir, f"{split}-corpus.txt")
        self._make_corpus(file=corpus_path)
        self.corpus = self._load_corpus(file=corpus_path)
        print("Max Size", self.max_size)
        print("Longest file", self.file)

    def _make_corpus(self, file="data/icons/outline/train-corpus.txt"):
        """Encode every file in ``self.files`` and write the ids to ``file``.

        Ids are written space-separated, each file followed by a space, so the
        split becomes one flat stream. Updates ``self.max_size``/``self.file``.
        """
        with open(file, "w", encoding="utf-8") as out:
            # `name` (not `file`) so the output-path parameter is not shadowed.
            for name in self.files:
                with open(os.path.join(self.root, name), encoding="utf-8") as src:
                    encoded = tokenizer.encode(src.read()).ids
                if len(encoded) > self.max_size:
                    self.max_size = len(encoded)
                    self.file = name
                out.write(" ".join(map(str, encoded)) + " ")

    def _load_corpus(self, file="data/icons/outline/train-corpus.txt"):
        """Read a whitespace-separated id file back as a list of ints."""
        with open(file, encoding="utf-8") as f:
            return [int(tok) for tok in f.read().split()]

    def __len__(self):
        # Number of (input, next-token) pairs in the stream.
        return len(self.corpus) - 1

    def get_batch(self, batch_size: int, seq_len: int):
        """Sample ``batch_size`` random contiguous windows of ``seq_len`` ids.

        Returns:
            A ``(batch_size, seq_len)`` int64 tensor.

        Raises:
            ValueError: if the corpus is shorter than ``seq_len``.
        """
        # Valid start offsets are 0 .. len(corpus) - seq_len inclusive;
        # torch.randint's upper bound is exclusive, hence the + 1.
        n_starts = len(self.corpus) - seq_len + 1
        if n_starts <= 0:
            raise ValueError(
                f"corpus has {len(self.corpus)} tokens, fewer than seq_len={seq_len}"
            )
        starts = torch.randint(0, n_starts, (batch_size,)).tolist()
        return torch.stack([torch.tensor(self.corpus[i:i + seq_len]) for i in starts])

In [None]:
# Build both splits; each constructor encodes its split and (re)writes the
# matching corpus file before loading it.
train_loader = Dataloader(use_train=True)
test_loader = Dataloader(use_train=False)

In [None]:
# Prefer Apple-silicon MPS, then CUDA, then CPU, so the notebook still runs on
# machines without an MPS backend (a bare torch.device("mps") fails there).
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# The vocabulary's start token is "<SOS>" (as used by the post-processor and
# the model config); "[SOS]" is not a token and would print None.
print(tokenizer.token_to_id("<SOS>"))

In [None]:
import transformers

# Small Llama decoder for icon token sequences; pad/bos/eos ids are read from
# the tokenizer vocabulary.
config = transformers.LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=8,
    num_attention_heads=8,
    max_position_embeddings=256,
    attention_dropout=0,
    pad_token_id=tokenizer.token_to_id("<PAD>"),
    bos_token_id=tokenizer.token_to_id("<SOS>"),
    eos_token_id=tokenizer.token_to_id("<EOS>"),
)
# Any id >= vocab_size would index past the embedding table, so fail fast if
# the tokenizer has outgrown the hard-coded size.
assert tokenizer.get_vocab_size() <= config.vocab_size, (
    f"tokenizer has {tokenizer.get_vocab_size()} tokens but vocab_size={config.vocab_size}"
)
model = transformers.LlamaForCausalLM(config).to(device)
num_train_steps = 0
print(f"Model using {model.num_parameters():,} parameters.")

In [None]:
# NOTE(review): the training/eval code below uses the HF model's own
# `output.loss`, so `criterion` appears unused; kept in case other cells need it.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.01,
)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.9)


In [None]:
# Validation
@torch.no_grad()
def evaluate(batch_size: int = 64, seq_len: int = 256, num_batches: int = 1):
    """Return the mean causal-LM loss of ``model`` on random test windows.

    Samples ``num_batches`` batches of ``batch_size`` windows of ``seq_len``
    tokens from ``test_loader`` and averages ``output.loss`` over them. Puts
    the model in eval mode and leaves it there.

    The unshifted window is passed as both inputs and ``labels``:
    ``LlamaForCausalLM`` shifts the labels internally, so passing a pre-shifted
    ``batch[:, 1:]`` would score prediction two tokens ahead.

    Raises:
        ValueError: if ``num_batches`` < 1.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")
    model.eval()
    total = 0.0
    for _ in range(num_batches):
        batch = test_loader.get_batch(batch_size, seq_len).to(device)
        output = model(batch, labels=batch)
        total += output.loss.item()
    return total / num_batches

evaluate()

In [None]:
# Peek at the raw token ids of a single random training window.
b = train_loader.get_batch(batch_size=1, seq_len=256)
b[0].tolist()

In [None]:
def decode_f(x):
    """Print the first sequence of ``x`` (token-id tensor) as detokenized text.

    Presumably ``tokenizer.decode`` joins tokens with single spaces, so a run
    of three spaces is taken to mean a real space token. That run is protected
    as a tab, all remaining spaces are dropped, and the tab becomes a space.
    """
    text = tokenizer.decode(x[0].tolist())
    text = text.replace("   ", " \t")
    text = text.replace(" ", "")
    text = text.replace("\t", " ")
    print(text)

In [None]:
import tqdm

In [None]:
from ema_pytorch import EMA
# Track an EMA (beta=0.999) copy of `model`'s weights, advanced by
# `ema.update()` in the training loop.
# NOTE(review): evaluate() and the generation cell both call `model` directly,
# so the EMA weights are never used — consider evaluating/generating with the
# EMA copy (ema_pytorch exposes it as `ema.ema_model`; verify).
ema = EMA(model, beta=0.999)

In [None]:
pbar = tqdm.trange(10_000)
for step in pbar:
    model.train()
    # Sample 64 random 256-token windows from the training stream.
    batch = train_loader.get_batch(64, 256).to(device)

    # Forward pass. LlamaForCausalLM shifts `labels` internally (logits at
    # position t are scored against labels[t + 1]), so the unshifted batch is
    # both input and labels. Passing batch[:, 1:] as labels would train the
    # model to predict two tokens ahead.
    output = model(batch, labels=batch)
    loss = output.loss

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # scheduler.step()

    # Advance the EMA copy of the weights.
    ema.update()

    # Validate every 100 steps; step 0 runs it, so `val_loss` is always set.
    if step % 100 == 0:
        val_loss = evaluate()

    pbar.set_postfix_str(f"Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

In [None]:
model.eval()
with torch.no_grad():
    # Prompt with the start-of-sequence token (config.bos_token_id). Every
    # encoded file begins with <SOS>, and the id 1 registered in the
    # post-processor belongs to <EOS>, not <SOS>.
    prompt = torch.tensor([[config.bos_token_id]], device=device)
    idxs = model.generate(
        prompt,
        attention_mask=torch.ones_like(prompt),
        max_length=256,
    )
    print(idxs.flatten().tolist())
    decode_f(idxs)