In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from data.dataloaderyolo import CaptchaYOLODataset
from model.yolo import MiniYolo, MiniYoloMultiScale
from loss.yololoss import YOLOLoss, YOLOMultiScaleLoss
import os
from PIL import Image, ImageDraw
from utils.yoloinfer import decode_predictions
from torchvision import transforms
from torch.optim.lr_scheduler import StepLR



In [None]:
def yolo_collate(batch):
    imgs = torch.stack([b[0] for b in batch])
    targets = [b[1] for b in batch]  # list of (Ni,4)
    classes = [b[2] for b in batch]  # list of (Ni,)
    return imgs, targets, classes

In [None]:
ANNOTATION_FILE = "/kaggle/input/yolocaptcha/eval_annote.txt"
IMG_SIZE = (80, 200)  # (H, W) unified
BATCH_SIZE = 16
EPOCHS = 20
OUTPUT_DIR = "inference_results"
LEARNING_RATE = 1e-1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
anchors = [
(0.11500,0.51250),
(0.11000,0.42500),
(0.12000,0.42500),
(0.12000,0.53750),
(0.13000,0.55000),
(0.14500,0.60000),
(0.15000,0.8000),
(0.17000,0.8000),
]

dataset = CaptchaYOLODataset("/kaggle/input/yolocaptcha/yolo.txt", img_dir="/kaggle/input/yolocaptcha", img_size=(200,80))
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=yolo_collate)
model = MiniYoloMultiScale(in_chan=1, num_class=10, anchors=len(anchors)).to(DEVICE)
criterion = YOLOMultiScaleLoss(
    anchors, 
    num_class=10, 
    lambda_coord=5, 
    lambda_cls=1,
    scale_weights=[1, 1, 1,1,2]
)
optimizer = optim.Adam(model.parameters(), lr=1e-2)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# checkpoint = torch.load("/kaggle/input/yolocaptcha/mini_yolo_multiscale_epoch_10.pth", map_location=DEVICE)
# model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
print(f"Model: {model.__class__.__name__}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Input: (1, 80, 200) -> Output: {model(torch.randn(1, 1, 80, 200).to(DEVICE)).shape}")

Model: MiniYoloMultiScale
Parameters: 11,422,625


AttributeError: 'tuple' object has no attribute 'shape'

In [None]:
for epoch in range(61):
    model.train()
    running_loss = 0.0
    epoch_loss_breakdown = {
        'loss_scale_1': 0.0,
        'loss_scale_2': 0.0, 
        'loss_scale_3': 0.0,
        'loss_scale_4': 0.0, 
        'loss_scale_5': 0.0,
        'total_loss': 0.0
    }
    
    for batch_idx, (imgs, targets, classes) in enumerate(loader):
        imgs = imgs.to(DEVICE)
        optimizer.zero_grad()
        preds = model(imgs)
        try:
            loss, loss_breakdown = criterion(preds, targets, classes)
            
            if torch.isnan(loss):
                print(f"NaN loss at epoch {epoch+1}, batch {batch_idx+1}")
                continue
                
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            running_loss += loss.item()
            
            for key in epoch_loss_breakdown:
                epoch_loss_breakdown[key] += loss_breakdown.get(key, 0)
            
                
        except Exception as e:
            print(f"Error in batch {batch_idx+1}: {e}")
            continue
    
    # scheduler.step()  # call after epoch
    avg_loss = running_loss / len(loader)
    
    # Average loss breakdown
    for key in epoch_loss_breakdown:
        epoch_loss_breakdown[key] /= len(loader)
    
    print(f"\nEpoch {epoch+1}/{EPOCHS} average Loss: {avg_loss:.6f}")
    
    # Save checkpoint
    if epoch%5 == 0:
        torch.save({
            "epoch": epoch+1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            # "scheduler_state_dict": scheduler.state_dict(),
            "loss": avg_loss,
            "loss_breakdown": epoch_loss_breakdown
        }, f"/kaggle/working/checkpoints/mini_yolo_multiscale_epoch_{epoch+1}.pth")
    
    print("-" * 80)

print("Training completed!")


Epoch 1/20 average Loss: 45.969649
--------------------------------------------------------------------------------

Epoch 2/20 average Loss: 45.666250
--------------------------------------------------------------------------------

Epoch 3/20 average Loss: 44.400334
--------------------------------------------------------------------------------

Epoch 4/20 average Loss: 43.900904
--------------------------------------------------------------------------------

Epoch 5/20 average Loss: 42.791744
--------------------------------------------------------------------------------

Epoch 6/20 average Loss: 42.325467
--------------------------------------------------------------------------------

Epoch 7/20 average Loss: 41.806734
--------------------------------------------------------------------------------

Epoch 8/20 average Loss: 41.923437
--------------------------------------------------------------------------------

Epoch 9/20 average Loss: 41.343360
----------------------------

In [None]:
print("Running inference on evaluation set...")
model.eval()
# Load checkpoint if needed
# checkpoint = torch.load("checkpoints/mini_yolo_multiscale_epoch_1.pth", map_location=DEVICE)
# model.load_state_dict(checkpoint["model_state_dict"])
os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(ANNOTATION_FILE, "r") as f:
    lines = f.readlines()

transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.Grayscale(num_output_channels=1),  # True grayscale
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]) 
])  
inference_anchors = anchors  # do not redefine to single anchor

with torch.no_grad():
    for line in lines:
        img_path, captcha_text = line.strip().split()
        img_pil = Image.open(os.path.join('/kaggle/input/yolocaptcha',img_path)).convert("RGB")
        orig_w, orig_h = img_pil.size
        img_tensor = transform(img_pil).unsqueeze(0).to(DEVICE)
        outputs = model(img_tensor)
        boxes, scores, sel_idx = decode_predictions(
            outputs, (IMG_SIZE[0], IMG_SIZE[1]), inference_anchors,
            conf_thresh=0.45
            , iou_thresh=0.35
        )[0]

        # Debug if empty
        if boxes.numel() == 0:
            if isinstance(outputs, (tuple, list)):
                max_obj = max(output[...,0].sigmoid().max().item() for output in outputs)
                print(f"No boxes for {img_path}; max obj: {max_obj}")
            else:
                print(f"No boxes for {img_path}; max obj:", outputs[...,0].sigmoid().max().item())
            continue

        # Get predicted classes for each box using sel_idx - handle multi-scale
        pred_classes = []
        if isinstance(outputs, (tuple, list)):
            # Multi-scale: use first scale for class prediction (largest resolution)
            cls_logits = outputs[0][...,5:][0]  # (A,H,W,num_class)
            for idx in sel_idx:
                a, y, x = idx.tolist()
                # Ensure indices are within bounds for first scale
                H_scale, W_scale = outputs[0].shape[2], outputs[0].shape[3]
                y_scaled = min(y, H_scale-1)
                x_scaled = min(x, W_scale-1)
                class_prob = cls_logits[a, y_scaled, x_scaled].softmax(dim=-1)
                pred_class = class_prob.argmax().item()
                pred_classes.append(pred_class)
        else:
            # Single scale (backward compatibility)
            cls_logits = outputs[...,5:][0]  # (A,H,W,num_class)
            for idx in sel_idx:
                a, y, x = idx.tolist()
                class_prob = cls_logits[a, y, x].softmax(dim=-1)
                pred_class = class_prob.argmax().item()
                pred_classes.append(pred_class)

        # Sort boxes and classes by x1 (left coordinate)
        boxes_np = boxes.cpu().numpy()
        sort_idx = boxes_np[:,0].argsort()
        boxes_sorted = boxes_np[sort_idx]
        pred_classes_sorted = [pred_classes[int(i)] for i in sort_idx.tolist()]

        # Draw boxes
        draw = ImageDraw.Draw(img_pil)
        scale_x = orig_w / IMG_SIZE[0]
        scale_y = orig_h / IMG_SIZE[1]
        for (x1,y1,x2,y2) in boxes_sorted:
            draw.rectangle([x1*scale_x, y1*scale_y, x2*scale_x, y2*scale_y],
                           outline="red", width=2)
        img_pil.save(os.path.join(OUTPUT_DIR, os.path.basename(img_path)))

        # Print predicted answer (as string of digits)
        answer = ''.join(str(c) for c in pred_classes_sorted)
        print(f"{os.path.basename(img_path)} | GT: {captcha_text} | Pred: {answer}")

Running inference on evaluation set...
captcha_0000.jpg | GT: 96997 | Pred: 0006
captcha_0001.jpg | GT: 66704 | Pred: 00020
captcha_0002.jpg | GT: 50309 | Pred: 0000
captcha_0003.jpg | GT: 76277 | Pred: 0006
captcha_0004.jpg | GT: 63932 | Pred: 0000
captcha_0005.jpg | GT: 25923 | Pred: 000
captcha_0006.jpg | GT: 07772 | Pred: 0060
captcha_0007.jpg | GT: 55996 | Pred: 0000
captcha_0008.jpg | GT: 50930 | Pred: 0000
captcha_0009.jpg | GT: 29967 | Pred: 00000
9_90835.png | GT: 90835 | Pred: 0000
1_45312.png | GT: 45312 | Pred: 000
