In [1]:
import os

import numpy as np
import torch

from trainer import CaptchaTrainer

In [2]:
# Build the trainer, then print a short sanity-check summary of what it loaded.
trainer = CaptchaTrainer()
print("Trainer initialized successfully!")
print(
    f"Vocabulary size: {len(trainer.vocab)}",
    f"Training data size: {len(trainer.train_gen)}",
    # Device is read from the first model parameter.
    f"Device: {next(trainer.model.parameters()).device}",
    sep="\n",
)

Trainer initialized successfully!
Vocabulary size: 14
Training data size: 510
Device: cpu


In [3]:
# Training schedule: total number of epochs, and how often (in epochs) the loop
# below saves a checkpoint and runs evaluation.
num_epochs, save_every, eval_every = 10, 1, 1

In [4]:
# Validation loader: wraps trainer.valid_gen and serves it in its original order
# (no shuffling) in batches of 8, using 2 loader workers.
val_loader = torch.utils.data.DataLoader(
    dataset=trainer.valid_gen,
    batch_size=8,
    shuffle=False,
    num_workers=2,
)

In [9]:
# Per-epoch loss history, kept so the curves can be inspected after training.
train_losses = []
eval_losses = []

# Make sure the checkpoint directory exists before the first save; otherwise a
# fresh checkout could abort the loop at the first checkpoint.
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs} \n")

    # Training: run one epoch, then read back the most recently recorded loss.
    trainer.train_epoch(epoch)
    epoch_loss = trainer.train_loss[-1]
    train_losses.append(epoch_loss)
    print(f"Training Loss: {epoch_loss:.4f}")

    # Evaluation and the prediction sample share the same schedule.
    run_eval = epoch % eval_every == 0

    # Validation metrics. Best-effort: a failure is reported and training continues.
    if run_eval:
        try:
            eval_loss, eval_accuracy = trainer.evaluate(val_loader)
            eval_losses.append(eval_loss)
            print(f"Validation Loss: {eval_loss:.4f}, Accuracy: {eval_accuracy:.2f}% \n")
        except Exception as e:
            print(f"Evaluation failed: {e} \n")

    # Save checkpoint
    if epoch % save_every == 0:
        checkpoint_path = f"{checkpoint_dir}/model_epoch_{epoch}.pth"
        trainer.save_checkpoint(checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path} \n")

    # Prediction sample: exact-match accuracy over a few decoded validation items.
    # NOTE(review): in the recorded run, beam search decoded to empty strings in
    # most epochs, while greedy decoding of the epoch-9 checkpoint matched 8/10 —
    # the beam-search path inside trainer.predict looks worth investigating.
    if run_eval:
        try:
            pred_sents, actual_sents = trainer.predict(data_loader=val_loader, sample=5, use_beam_search=True)
            correct = sum(pred.strip() == actual.strip()
                          for pred, actual in zip(pred_sents, actual_sents))
            total = len(pred_sents)
            # Guard the division so an empty sample reports 0% rather than
            # surfacing as a misleading "Prediction failed: division by zero".
            accuracy = correct / total * 100 if total else 0.0
            print(f"Sample Accuracy: {accuracy:.2f}% ({correct}/{total})")
            for i, (pred, actual) in enumerate(zip(pred_sents, actual_sents), start=1):
                print(f"  {i}. Pred: '{pred}' | Actual: '{actual}'")
        except Exception as e:
            print(f"Prediction failed: {e}")

Epoch 1/10 



                                                                            

Training Loss: 0.6648
Validation Loss: 0.8166, Accuracy: 90.00% 

Checkpoint saved: checkpoints/model_epoch_1.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 2/10 



                                                                            

Training Loss: 0.6615
Validation Loss: 0.9659, Accuracy: 88.33% 

Checkpoint saved: checkpoints/model_epoch_2.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 3/10 



                                                                            

Training Loss: 0.6559
Validation Loss: 0.8275, Accuracy: 88.33% 

Checkpoint saved: checkpoints/model_epoch_3.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '6668' | Actual: '96997'
  2. Pred: '66668' | Actual: '66704'
  3. Pred: '6668' | Actual: '50309'
  4. Pred: '66666' | Actual: '76277'
  5. Pred: '86' | Actual: '63932'
Epoch 4/10 



                                                                            

Training Loss: 0.6556
Validation Loss: 0.9881, Accuracy: 85.00% 

Checkpoint saved: checkpoints/model_epoch_4.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 5/10 



                                                                            

Training Loss: 0.6744
Validation Loss: 0.8519, Accuracy: 88.33% 

Checkpoint saved: checkpoints/model_epoch_5.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 6/10 



                                                                            

Training Loss: 0.6584
Validation Loss: 0.7759, Accuracy: 91.67% 

Checkpoint saved: checkpoints/model_epoch_6.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 7/10 



                                                                            

Training Loss: 0.6517
Validation Loss: 0.7874, Accuracy: 91.67% 

Checkpoint saved: checkpoints/model_epoch_7.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 8/10 



                                                                            

Training Loss: 0.6391
Validation Loss: 0.8174, Accuracy: 88.33% 

Checkpoint saved: checkpoints/model_epoch_8.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 9/10 



                                                                            

Training Loss: 0.6431
Validation Loss: 0.7208, Accuracy: 95.00% 

Checkpoint saved: checkpoints/model_epoch_9.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'
Epoch 10/10 



                                                                             

Training Loss: 0.6504
Validation Loss: 0.9120, Accuracy: 88.33% 

Checkpoint saved: checkpoints/model_epoch_10.pth 

Sample Accuracy: 0.00% (0/5)
  1. Pred: '' | Actual: '96997'
  2. Pred: '' | Actual: '66704'
  3. Pred: '' | Actual: '50309'
  4. Pred: '' | Actual: '76277'
  5. Pred: '' | Actual: '63932'


In [11]:
# Reload the epoch-9 checkpoint (the best validation figures in the recorded run:
# loss 0.7208, accuracy 95.00%) and score 10 validation samples with beam search off.
trainer.load_checkpoint('./checkpoints/model_epoch_9.pth')
pred_sents, actual_sents = trainer.predict(data_loader=val_loader, sample=10, use_beam_search=False)

# Exact-match accuracy after stripping surrounding whitespace from both strings.
correct = sum(pred.strip() == actual.strip()
              for pred, actual in zip(pred_sents, actual_sents))
total = len(pred_sents)
# Guard the division: an empty prediction list reports 0% instead of raising
# ZeroDivisionError.
accuracy = correct / total * 100 if total else 0.0
print(f"Sample Accuracy: {accuracy:.2f}% ({correct}/{total})")
for i, (pred, actual) in enumerate(zip(pred_sents, actual_sents), start=1):
    print(f"  {i}. Pred: '{pred}' | Actual: '{actual}'")

  pe = torch.tensor(self.pe[:,:seq_length, :self.d_model],requires_grad=False)


Step 0, next_tokens: tensor([13, 10,  9, 11, 10,  6,  4,  9])
Step 1, next_tokens: tensor([13, 10,  4, 10,  7,  9, 11,  9])
Step 2, next_tokens: tensor([11, 11,  7,  6, 13, 13, 11, 13])
Step 3, next_tokens: tensor([13,  4,  4, 11,  7,  6, 11, 13])
Step 4, next_tokens: tensor([11,  8, 13,  2,  6,  7,  6, 10])
Step 0, next_tokens: tensor([9, 6])
Step 1, next_tokens: tensor([ 4, 13])
Step 2, next_tokens: tensor([13, 13])
Step 3, next_tokens: tensor([ 7, 10])
Step 4, next_tokens: tensor([ 4, 11])
Sample Accuracy: 80.00% (8/10)
  1. Pred: '99797' | Actual: '96997'
  2. Pred: '66704' | Actual: '66704'
  3. Pred: '50309' | Actual: '50309'
  4. Pred: '7627' | Actual: '76277'
  5. Pred: '63932' | Actual: '63932'
  6. Pred: '25923' | Actual: '25923'
  7. Pred: '07772' | Actual: '07772'
  8. Pred: '55996' | Actual: '55996'
  9. Pred: '50930' | Actual: '50930'
  10. Pred: '29967' | Actual: '29967'


In [22]:
import os
import sys
import numpy as np
from torch.utils.data import Dataset, DataLoader
from utils.vocab import Vocab
from PIL import Image
import torch
import albumentations as A
def process_image(img, size=(200, 80)):
    """Resize a PIL image and return it as a channel-first array scaled by 1/255.

    Args:
        img: PIL image. Images whose mode is not 'RGB' are converted to RGB first,
            because the channel transpose below needs a three-axis array
            (single-channel modes would otherwise raise).
        size: Target size passed straight to ``Image.resize``. Defaults to
            ``(200, 80)``, the fixed size this function previously hard-coded.

    Returns:
        numpy array with the channel axis moved to the front and every value
        divided by 255.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    resized = img.resize(size, Image.LANCZOS)
    # Move the last (channel) axis to the front.
    channels_first = np.asarray(resized).transpose(2, 0, 1)
    return channels_first / 255
class PrivateDataset(Dataset):
    """Unlabeled image dataset driven by an annotation file of image paths.

    Every non-blank line of ``annote_file`` is taken, after stripping surrounding
    whitespace, as an image path relative to ``image_dir``. Indexing yields an
    ``(image_tensor, filename)`` pair; the filename stands in as the identifier
    because no label is available.
    """

    def __init__(self, annote_file, image_dir):
        """Read the list of image paths.

        Args:
            annote_file: Text file with one image path per line.
            image_dir: Directory the listed paths are joined onto.
        """
        self.image_dir = image_dir

        # Blank lines are skipped: an empty entry would resolve to image_dir
        # itself and only fail later, when the item is loaded.
        with open(annote_file, 'r') as f:
            self.image_files = [entry for entry in (line.strip() for line in f) if entry]

    def __len__(self):
        """Return the number of listed images."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(float32 tensor, filename)``."""
        filename = self.image_files[idx]
        img_path = os.path.join(self.image_dir, filename)

        # The context manager closes the underlying file as soon as the RGB copy
        # exists, instead of leaving the handle open until garbage collection.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')

        # Preprocess with process_image (no augmentation).
        processed_img = process_image(image)

        return torch.as_tensor(processed_img, dtype=torch.float32), filename

# Create private dataset and dataloader
def create_private_dataloader(annote_file, image_dir, batch_size=8):
    """Return a DataLoader over a PrivateDataset built from the given paths.

    Args:
        annote_file: Annotation file handed to PrivateDataset.
        image_dir: Image directory handed to PrivateDataset.
        batch_size: Number of items per batch (default 8).

    Returns:
        A DataLoader configured with shuffle=False and num_workers=0.
    """
    return DataLoader(
        PrivateDataset(annote_file, image_dir),
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
    )

# Prediction function for private set (no labels)
def predict_private_set(trainer, annote_file, image_dir, use_beam_search=True,
                        results_file="private_predictions.txt", batch_size=8):
    """Generate predictions for the private test set (no ground-truth labels).

    Args:
        trainer: Object exposing ``model``, ``device`` and
            ``predict(img=..., use_beam_search=...)``.
        annote_file: Annotation file listing the images to predict.
        image_dir: Directory the listed image paths are joined onto.
        use_beam_search: Forwarded unchanged to ``trainer.predict``.
        results_file: Path of the tab-separated ``filename<TAB>prediction`` file
            to write. Defaults to the previously hard-coded
            ``"private_predictions.txt"``.
        batch_size: Batch size for the private dataloader. Defaults to 8, the
            dataloader's own default.

    Returns:
        ``(predictions, filenames)``: two parallel lists in dataset order.
    """
    print("Creating private dataset...")
    private_loader = create_private_dataloader(annote_file, image_dir, batch_size=batch_size)

    print(f"Private dataset size: {len(private_loader.dataset)}")

    # Predictions only; there are no labels to score against.
    print("Running predictions...")
    predictions = []
    filenames = []

    trainer.model.eval()

    with torch.no_grad():
        for batch_img, batch_filenames in private_loader:
            batch_img = batch_img.to(trainer.device)
            n_items = batch_img.size(0)

            # trainer.predict handles both decoding modes, selected by the flag.
            batch_predictions = trainer.predict(img=batch_img, use_beam_search=use_beam_search)

            # Index explicitly so a short prediction list still fails loudly
            # rather than silently misaligning predictions and filenames.
            predictions.extend(batch_predictions[b] for b in range(n_items))
            filenames.extend(batch_filenames[b] for b in range(n_items))

    print("\n=== PRIVATE SET PREDICTIONS ===")
    print(f"Total samples: {len(predictions)}")

    # Preview the first ten results.
    print("\nSample predictions:")
    for i, (filename, pred) in enumerate(zip(filenames[:10], predictions[:10])):
        print(f"  {i+1:2d}. {filename} -> '{pred}'")

    # Save results to the submission file, one "filename<TAB>prediction" per line.
    with open(results_file, 'w') as f:
        for filename, pred in zip(filenames, predictions):
            f.write(f"{filename}\t{pred}\n")

    print(f"\nPredictions saved to: {results_file}")

    return predictions, filenames


In [24]:
# NOTE(review): a wildcard import can rebind names already defined in this
# notebook (e.g. process_image) if data.dataloader exports them — confirm, and
# prefer importing only the names actually needed.
from data.dataloader import *

# Predict the images listed in private_annote.txt with beam search disabled.
# image_dir is empty, so the listed paths are used as given.
predictions, filenames = predict_private_set(
    trainer,
    'private_annote.txt',
    '',
    use_beam_search=False,
)

Creating private dataset...
Private dataset size: 20
Running predictions...

=== PRIVATE SET PREDICTIONS ===
Total samples: 20

Sample predictions:
   1. private/captcha_0000.jpg -> '290'
   2. private/captcha_0001.jpg -> '37774'
   3. private/captcha_0002.jpg -> '65236'
   4. private/captcha_0003.jpg -> '39468'
   5. private/captcha_0004.jpg -> '23039'
   6. private/captcha_0005.jpg -> '034'
   7. private/captcha_0006.jpg -> '63890'
   8. private/captcha_0007.jpg -> '02788'
   9. private/captcha_0008.jpg -> '64040'
  10. private/captcha_0009.jpg -> '8766'

Predictions saved to: private_predictions.txt
