In [None]:
from transformer_classifier_util import MyDataset
from transformer_classifier import TransformerClassifier
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np

# Pick the fastest available accelerator: CUDA first, then Apple-silicon 'mps', else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using '{device.type}' device")

# Seed torch's default RNG so weight initialization and the DataLoader's
# shuffle order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# NOTE(review): machine-specific absolute path — adjust for your checkout.
dataset_dir = '/Users/neoneye/git/python_arc/run_tasks_result/20241119_151040_jsonl_trainingdata'

# NOTE(review): the meaning of the second argument (3) is defined by
# MyDataset.load_jsonl_files — confirm and give it a named constant.
dataset = MyDataset.load_jsonl_files(dataset_dir, 3)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize the model, loss function, and optimizer.
# NOTE(review): src_vocab_size=528 presumably must cover every token id MyDataset
# emits, and num_classes=10 every label value — confirm against the dataset.
model = TransformerClassifier(src_vocab_size=528, num_classes=10)
model.to(device)  # move parameters to the selected device

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

num_epochs = 10
count_limit = 60000  # maximum number of batches processed per epoch
log_interval = 100   # print the latest batch loss every this many batches

# Training loop
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    # Training mode matters for layers such as dropout/batch-norm (if the model has
    # any); set it explicitly so re-running this cell after an eval() call is safe.
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch_idx, (src, ys) in enumerate(dataloader):
        # Move the batch to the same device as the model.
        src = src.to(device)
        ys = ys.to(device)

        optimizer.zero_grad()
        output = model(src)
        loss = criterion(output, ys)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        # Log periodically instead of every batch to avoid flooding the notebook.
        if num_batches % log_interval == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}], Loss: {batch_loss:.4f}')

        if num_batches >= count_limit:
            break

    # Guard against an empty dataloader to avoid division by zero.
    mean_loss = running_loss / num_batches if num_batches else float('nan')
    print(f'Epoch [{epoch+1}/{num_epochs}] finished: {num_batches} batches, mean loss: {mean_loss:.4f}')


In [None]:
# Persist only the learned parameters (the state_dict), not the pickled model object.
# The tensors keep the device they were trained on; pass map_location to torch.load
# when reloading on a machine with a different device.
checkpoint_path = '/Users/neoneye/git/python_arc/run_tasks_result/transformer_classifier.pth'
trained_weights = model.state_dict()
torch.save(trained_weights, checkpoint_path)
