In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
from torch.utils.data import Dataset
import os

In [None]:
# Seed torch before any stochastic step (random weight init below, DataLoader
# shuffling later) so a Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# GPT-2 ships without a padding token; add one so examples can be padded to a
# common length for batching.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

save_dir = "./gpt2_token_raw_model/checkpoints/run_1"
os.makedirs(save_dir, exist_ok=True)

# Build the config *after* adding [PAD]: len(tokenizer) counts added special
# tokens, whereas tokenizer.vocab_size reports only the base vocabulary. Sizing
# the embeddings correctly up front removes the need for
# resize_token_embeddings().
config = GPT2Config(
    n_embd=768,
    n_layer=12,
    n_head=12,
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
)
model = GPT2LMHeadModel(config).to(device)

In [None]:
class BinaryToDecimalDataset(Dataset):
    """Causal-LM dataset built from a text file with one training example per line.

    Every non-blank line is tokenized and truncated to ``block_size`` tokens.
    All lines are tokenized in a single call with ``padding=True``, so every
    example is padded to the length of the longest line in the file.

    Items are ``(input_ids, labels)`` pairs. ``labels`` equals ``input_ids``,
    except that padded positions are set to ``-100``, the index ignored by the
    language-model loss. This stops the model from being trained to predict
    padding.

    Args:
        file_path: Path to a UTF-8 text file, one example per line.
        tokenizer: Hugging Face tokenizer with a pad token configured.
        block_size: Maximum number of tokens kept per example.

    Raises:
        ValueError: If the file contains no non-blank lines.
    """

    def __init__(self, file_path, tokenizer, block_size=128):
        with open(file_path, 'r', encoding='utf-8') as f:
            # Skip blank lines: they would become all-padding rows whose labels
            # are entirely ignored, yielding a NaN loss for an all-blank batch.
            lines = [line for line in f.read().splitlines() if line.strip()]
        if not lines:
            raise ValueError(f"no non-empty lines found in {file_path!r}")

        encoded = tokenizer(lines, truncation=True, padding=True, max_length=block_size, return_tensors="pt")
        self.examples = encoded["input_ids"]
        # Mask padded positions (attention_mask == 0) out of the loss.
        labels = self.examples.clone()
        labels[encoded["attention_mask"] == 0] = -100
        self.labels = labels

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx], self.labels[idx]

# Tokenize the training file (one example per line) and wrap it in a
# DataLoader that reshuffles the examples every epoch.
dataset = BinaryToDecimalDataset(
    file_path="./gpt2_token_raw_model/10k_data.txt",
    tokenizer=tokenizer,
)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
)

In [None]:
epochs = 20
learn_rate = 1e-3
# A checkpoint is written every `checkpoint_every` epochs, and the final epoch
# is always saved too, so a run shorter than the interval still yields a model.
checkpoint_every = 50
losses_path = "./gpt2_token_raw_model/training_losses.txt"

# torch.optim.AdamW replaces transformers.AdamW, which is deprecated and removed
# in newer transformers releases. eps and weight_decay are set to the values
# transformers.AdamW used by default, so the optimiser behaves the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=learn_rate, eps=1e-6, weight_decay=0.0)
model.train()
losses = []  # mean training loss per epoch

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(input_ids=inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("dataloader yielded no batches; check the training data file")

    avg_loss = epoch_loss / num_batches
    losses.append(avg_loss)
    print(f"Epoch {epoch+1} has Avg Loss: {avg_loss:.4f}")

    is_last_epoch = (epoch + 1) == epochs
    if (epoch + 1) % checkpoint_every == 0 or is_last_epoch:
        checkpoint_path = os.path.join(save_dir, f"model_epoch_{epoch+1}")
        model.save_pretrained(checkpoint_path)
        tokenizer.save_pretrained(checkpoint_path)
        print(f"Checkpoint to {checkpoint_path}")

# Persist the per-epoch average losses, one value per line, for later plotting.
with open(losses_path, "w", encoding="utf-8") as f:
    for l in losses:
        f.write(f"{l}\n")