In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from trainer import CaptchaTrainer

In [2]:
# Build the trainer first; the summary lines below are only reached when
# construction succeeds.
trainer = CaptchaTrainer()
print("Trainer initialized successfully!")

# Summarise what the trainer exposes: vocabulary length, length of its
# training generator, and the device holding the model's first parameter.
vocab_size = len(trainer.vocab)
print(f"Vocabulary size: {vocab_size}")

train_size = len(trainer.train_gen)
print(f"Training data size: {train_size}")

model_device = next(trainer.model.parameters()).device
print(f"Device: {model_device}")

# Run configuration read by the training-loop and plotting cells:
# total epochs, checkpoint cadence, evaluation cadence (all in epochs).
num_epochs, save_every, eval_every = 10, 1, 1

AttributeError: 'OCRData' object has no attribute 'aug'

In [None]:
# Train for `num_epochs` epochs. Every epoch records the training loss; every
# `eval_every` epochs an evaluation pass and a small prediction spot-check are
# run; every `save_every` epochs a checkpoint is written.
train_losses = []
# Exactly one entry per evaluation epoch. A failed evaluation is recorded as
# NaN (instead of being skipped) so that entry i always corresponds to epoch
# (i + 1) * eval_every, which is what the plotting code assumes.
eval_losses = []

# Create the checkpoint directory up front so the first save cannot fail on a
# missing folder.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs} \n")

    # Training: run one epoch, then read the most recent entry of
    # trainer.train_loss as this epoch's loss.
    trainer.train_epoch(epoch)
    epoch_loss = trainer.train_loss[-1]
    train_losses.append(epoch_loss)

    print(f"Training Loss: {epoch_loss:.4f}")

    is_eval_epoch = epoch % eval_every == 0

    if is_eval_epoch:
        # NOTE(review): the loader wraps trainer.train_gen, so the number
        # printed as "Validation Loss" is measured on the training data.
        # Confirm whether the trainer exposes a held-out set to use instead.
        eval_loss = float("nan")  # placeholder kept if evaluation fails
        try:
            val_loader = torch.utils.data.DataLoader(
                trainer.train_gen,
                batch_size=8,
                shuffle=False,
                num_workers=2
            )

            eval_loss = trainer.evaluate(val_loader)
            print(f"Validation Loss: {eval_loss:.4f} \n")
        except Exception as e:
            # Best-effort: an evaluation failure must not abort training.
            print(f"Evaluation failed: {e} \n")
        eval_losses.append(eval_loss)

    # Save checkpoint
    if epoch % save_every == 0:
        checkpoint_path = f"{checkpoint_dir}/model_epoch_{epoch}.pth"
        trainer.save_checkpoint(checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path} \n")

    # Prediction sample: exact-match accuracy (after stripping whitespace)
    # over the sentences returned by trainer.predict.
    if is_eval_epoch:
        try:
            pred_sents, actual_sents, _ = trainer.predict(sample=10)

            correct = sum(
                pred.strip() == actual.strip()
                for pred, actual in zip(pred_sents, actual_sents)
            )
            total = len(pred_sents)
            # Guard the division: an empty sample is reported as 0% rather
            # than surfacing as a ZeroDivisionError "prediction failure".
            accuracy = correct / total * 100 if total else 0.0

            print(f"Sample Accuracy: {accuracy:.2f}% ({correct}/{total}) \n")
        except Exception as e:
            # Best-effort: the spot-check is informational only.
            print(f"Prediction failed: {e}")

print("Training completed! \n")

In [None]:
# Two side-by-side panels: training vs. validation loss on the left, the
# training loss on its own on the right.
fig, (ax_progress, ax_detail) = plt.subplots(1, 2, figsize=(12, 4))
train_epochs = range(1, len(train_losses) + 1)

# Left panel: both curves on a shared epoch axis.
ax_progress.plot(train_epochs, train_losses, 'b-', label='Training Loss')
if eval_losses:
    # x positions assume one recorded loss per evaluation epoch, i.e. entry i
    # belongs to epoch (i + 1) * eval_every.
    eval_epochs = list(range(eval_every, len(eval_losses) * eval_every + 1, eval_every))
    ax_progress.plot(eval_epochs, eval_losses, 'r-', label='Validation Loss')
ax_progress.set_xlabel('Epoch')
ax_progress.set_ylabel('Loss')
ax_progress.set_title('Training Progress')
ax_progress.legend()
ax_progress.grid(True)

# Right panel: training loss only.
ax_detail.plot(train_epochs, train_losses, 'b-')
ax_detail.set_xlabel('Epoch')
ax_detail.set_ylabel('Training Loss')
ax_detail.set_title('Training Loss Detail')
ax_detail.grid(True)

fig.tight_layout()
plt.show()
