<a href="https://colab.research.google.com/github/hhubert14/chess-ai/blob/main/chess_model_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets==3.2.0
!pip install transformers==4.47.1
!pip install peft==0.14.0

Collecting datasets==3.2.0
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets==3.2.0)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets==3.2.0)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets==3.2.0)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets==3.2.0)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloadin

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from accelerate.test_utils.testing import get_backend
from tqdm.auto import tqdm
from torch.amp import autocast, GradScaler
import torch
import math

# Adjustable variables
model_name = "openai-community/gpt2"
model_dir = "gpt2"
batch_size = 64
train_dataset_path = "/content/train_puzzles_lg.csv"
test_dataset_path = "/content/test_puzzles_lg.csv"
num_epochs = 6
learning_rate = 2e-5
max_length = 128
gradient_accumulation_steps = 2
dataset_size = 134583

# Initialize tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Basic tokenizer setup
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
model.config.pad_token_id = tokenizer.pad_token_id

def tokenize_batch(batch):
    inputs = tokenizer(
        [example['inputs'] for example in batch],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    labels = tokenizer(
        [example['label'] for example in batch],
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    inputs["labels"] = labels["input_ids"]
    return inputs

# Load datasets with shuffling
train_dataset = load_dataset(
    "csv",
    data_files={"train": train_dataset_path},
    streaming=True
)["train"].shuffle(seed=42, buffer_size=dataset_size)

eval_dataset = load_dataset(
    "csv",
    data_files={"test": test_dataset_path},
    streaming=True
)["test"].shuffle(seed=42, buffer_size=dataset_size)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=tokenize_batch)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, collate_fn=tokenize_batch)

# Setup optimizer and scheduler
total_training_steps = math.ceil((dataset_size / batch_size) * num_epochs / gradient_accumulation_steps)
num_warmup_steps = math.ceil(total_training_steps * 0.05)

optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_training_steps
)

# Device setup
device, _, _ = get_backend()
print(f"Using device: {device}")
model.to(device)

scaler = GradScaler('cuda')

# Training loop
print(f"Starting training for {num_epochs} epochs, {total_training_steps} total steps")
print(f"Warmup steps: {num_warmup_steps}")

model.train()
total_loss = 0
step_count = 0
best_eval_loss = float('inf')
patience = 5
patience_counter = 0

for epoch in range(num_epochs):
    print(f"\nEpoch: {epoch + 1}/{num_epochs}")
    progress_bar = tqdm(total=None)

    for batch_idx, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}

        with autocast('cuda'):
            outputs = model(**batch)
            loss = outputs.loss / gradient_accumulation_steps

        scaler.scale(loss).backward()
        total_loss += loss.item()

        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()

            if step_count % 50 == 0:
                avg_loss = total_loss / gradient_accumulation_steps
                print(f"\nStep: {step_count}, Average Loss: {avg_loss:.4f}")
                total_loss = 0

            step_count += 1

            # Validation every 100 steps
            if step_count % 100 == 0:
                model.eval()
                eval_loss = 0
                eval_steps = 0

                print("\nRunning validation...")
                with torch.no_grad():
                    for eval_batch in eval_dataloader:
                        eval_batch = {k: v.to(device) for k, v in eval_batch.items()}
                        with autocast('cuda'):
                            eval_outputs = model(**eval_batch)
                            eval_loss += eval_outputs.loss.item()
                        eval_steps += 1
                        if eval_steps >= 50:
                            break

                avg_eval_loss = eval_loss / eval_steps
                print(f"Validation Loss: {avg_eval_loss:.4f}")

                if avg_eval_loss < best_eval_loss:
                    best_eval_loss = avg_eval_loss
                    model.save_pretrained(f"{model_dir}_best")
                    tokenizer.save_pretrained(f"{model_dir}_best")
                    print(f"New best model saved with loss: {best_eval_loss:.4f}")
                    patience_counter = 0
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print("Early stopping triggered!")
                        break

                model.train()

        progress_bar.update(1)

    if patience_counter >= patience:
        break

# Final evaluation
print("\nRunning final evaluation...")
model.eval()
predictions_texts = []

with torch.no_grad():
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in eval_batch.items()}

        generated_ids = model.generate(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            max_new_tokens=32,  # Shorter, since chess moves are short
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            top_p=0.7,
            repetition_penalty=1.5,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

        decoded_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        predictions_texts.extend(decoded_preds)

        if "labels" in batch:
            labels = tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)
            for pred, label in zip(decoded_preds, labels):
                if len(predictions_texts) <= 5:  # Print first few examples
                    print(f"\nInput: {tokenizer.decode(batch['input_ids'][0], skip_special_tokens=True)}")
                    print(f"Prediction: {pred}")
                    print(f"Actual: {label}")

print("\nSample Predictions:")
for i, prediction in enumerate(predictions_texts[:5]):
    print(f"Prediction {i + 1}: {prediction}")

print(f"\nTotal predictions: {len(predictions_texts)}")
print(f"Best validation loss: {best_eval_loss:.4f}")

Using device: cuda
Starting training for 6 epochs, 6309 total steps
Warmup steps: 316

Epoch: 1/6


0it [00:00, ?it/s]


Step: 0, Average Loss: 5.3494

Step: 50, Average Loss: 149.3044

Running validation...
Validation Loss: 0.2961
New best model saved with loss: 0.2961

Step: 100, Average Loss: 14.4420

Step: 150, Average Loss: 6.0059

Running validation...
Validation Loss: 0.1445
New best model saved with loss: 0.1445

Step: 200, Average Loss: 4.2318

Step: 250, Average Loss: 3.7754

Running validation...
Validation Loss: 0.1321
New best model saved with loss: 0.1321

Step: 300, Average Loss: 3.6072

Step: 350, Average Loss: 3.4893

Running validation...
Validation Loss: 0.1268
New best model saved with loss: 0.1268

Step: 400, Average Loss: 3.3874

Step: 450, Average Loss: 3.3056

Running validation...
Validation Loss: 0.1242
New best model saved with loss: 0.1242

Step: 500, Average Loss: 3.2448

Step: 550, Average Loss: 3.2023

Running validation...
Validation Loss: 0.1238
New best model saved with loss: 0.1238

Step: 600, Average Loss: 3.2140

Step: 650, Average Loss: 3.1670

Running validation...

0it [00:00, ?it/s]


Running validation...
Validation Loss: 0.1208

Step: 1100, Average Loss: 3.0774

Step: 1150, Average Loss: 3.0047

Running validation...
Validation Loss: 0.1176
New best model saved with loss: 0.1176

Step: 1200, Average Loss: 3.0121

Step: 1250, Average Loss: 3.0038

Running validation...
Validation Loss: 0.1177

Step: 1300, Average Loss: 2.9816

Step: 1350, Average Loss: 2.9952

Running validation...
Validation Loss: 0.1186

Step: 1400, Average Loss: 2.9969

Step: 1450, Average Loss: 2.9853

Running validation...
Validation Loss: 0.1181

Step: 1500, Average Loss: 2.9805

Step: 1550, Average Loss: 2.9551

Running validation...
Validation Loss: 0.1173
New best model saved with loss: 0.1173

Step: 1600, Average Loss: 2.9468

Step: 1650, Average Loss: 2.9895

Running validation...
Validation Loss: 0.1160
New best model saved with loss: 0.1160

Step: 1700, Average Loss: 2.9540

Step: 1750, Average Loss: 2.9351

Running validation...
Validation Loss: 0.1156
New best model saved with loss:

0it [00:00, ?it/s]


Step: 2150, Average Loss: 2.9622

Running validation...
Validation Loss: 0.1147
New best model saved with loss: 0.1147

Step: 2200, Average Loss: 2.8768

Step: 2250, Average Loss: 2.8943

Running validation...
Validation Loss: 0.1144
New best model saved with loss: 0.1144

Step: 2300, Average Loss: 2.8815

Step: 2350, Average Loss: 2.8674

Running validation...
Validation Loss: 0.1142
New best model saved with loss: 0.1142

Step: 2400, Average Loss: 2.8981

Step: 2450, Average Loss: 2.8867

Running validation...
Validation Loss: 0.1141
New best model saved with loss: 0.1141

Step: 2500, Average Loss: 2.8793

Step: 2550, Average Loss: 2.8769

Running validation...
Validation Loss: 0.1135
New best model saved with loss: 0.1135

Step: 2600, Average Loss: 2.8526

Step: 2650, Average Loss: 2.8443

Running validation...
Validation Loss: 0.1132
New best model saved with loss: 0.1132

Step: 2700, Average Loss: 2.8780

Step: 2750, Average Loss: 2.8588

Running validation...
Validation Loss: 0.

0it [00:00, ?it/s]


Running validation...
Validation Loss: 0.1121
New best model saved with loss: 0.1121

Step: 3200, Average Loss: 2.8821

Step: 3250, Average Loss: 2.8105

Running validation...
Validation Loss: 0.1126

Step: 3300, Average Loss: 2.8164

Step: 3350, Average Loss: 2.8120

Running validation...
Validation Loss: 0.1117
New best model saved with loss: 0.1117

Step: 3400, Average Loss: 2.7983

Step: 3450, Average Loss: 2.8236

Running validation...
Validation Loss: 0.1123

Step: 3500, Average Loss: 2.8090

Step: 3550, Average Loss: 2.8064

Running validation...
Validation Loss: 0.1119

Step: 3600, Average Loss: 2.8210

Step: 3650, Average Loss: 2.7814

Running validation...
Validation Loss: 0.1116
New best model saved with loss: 0.1116

Step: 3700, Average Loss: 2.7782

Step: 3750, Average Loss: 2.8155

Running validation...
Validation Loss: 0.1113
New best model saved with loss: 0.1113

Step: 3800, Average Loss: 2.7970

Step: 3850, Average Loss: 2.7909

Running validation...
Validation Loss:

0it [00:00, ?it/s]


Step: 4250, Average Loss: 2.8101

Running validation...
Validation Loss: 0.1110

Step: 4300, Average Loss: 2.7513

Step: 4350, Average Loss: 2.7611

Running validation...
Validation Loss: 0.1119

Step: 4400, Average Loss: 2.7543

Step: 4450, Average Loss: 2.7492

Running validation...
Validation Loss: 0.1101
New best model saved with loss: 0.1101

Step: 4500, Average Loss: 2.7612

Step: 4550, Average Loss: 2.7694

Running validation...
Validation Loss: 0.1099
New best model saved with loss: 0.1099

Step: 4600, Average Loss: 2.7570

Step: 4650, Average Loss: 2.7712

Running validation...
Validation Loss: 0.1101

Step: 4700, Average Loss: 2.7312

Step: 4750, Average Loss: 2.7419

Running validation...
Validation Loss: 0.1098
New best model saved with loss: 0.1098

Step: 4800, Average Loss: 2.7650

Step: 4850, Average Loss: 2.7604

Running validation...
Validation Loss: 0.1099

Step: 4900, Average Loss: 2.7508

Step: 4950, Average Loss: 2.7583

Running validation...
Validation Loss: 0.10

0it [00:00, ?it/s]


Running validation...
Validation Loss: 0.1090
New best model saved with loss: 0.1090

Step: 5300, Average Loss: 2.7668

Step: 5350, Average Loss: 2.7101

Running validation...
Validation Loss: 0.1090

Step: 5400, Average Loss: 2.7317

Step: 5450, Average Loss: 2.7191

Running validation...
Validation Loss: 0.1091

Step: 5500, Average Loss: 2.7057

Step: 5550, Average Loss: 2.7343

Running validation...
Validation Loss: 0.1090

Step: 5600, Average Loss: 2.7275

Step: 5650, Average Loss: 2.7229

Running validation...
Validation Loss: 0.1088
New best model saved with loss: 0.1088

Step: 5700, Average Loss: 2.7387

Step: 5750, Average Loss: 2.7037

Running validation...
Validation Loss: 0.1090

Step: 5800, Average Loss: 2.7076

Step: 5850, Average Loss: 2.7284

Running validation...
Validation Loss: 0.1089

Step: 5900, Average Loss: 2.7260

Step: 5950, Average Loss: 2.7187

Running validation...
Validation Loss: 0.1093

Step: 6000, Average Loss: 2.7421

Step: 6050, Average Loss: 2.7088

R