<a href="https://colab.research.google.com/github/kimhanbut/AI_Project/blob/main/AI_CRNN_OCR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision torchaudio --quiet
!pip install matplotlib opencv-python python-bidi arabic-reshaper --quiet
!pip install tqdm --quiet

In [3]:
import os

# Google Drive locations of the plate images and their per-image JSON labels
# (readable only after Drive has been mounted).
image_dir, json_dir = (
    "/content/drive/MyDrive/Colab Notebooks/car_num/car_num_img",
    "/content/drive/MyDrive/Colab Notebooks/car_num/car_num_json",
)

In [98]:
import json
import cv2
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
from PIL import Image
import torch.nn as nn

# Default preprocessing applied to every crop:
#   1. resize to (H, W) = (32, 100), the input size the CRNN is built around,
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalize with mean 0.5 / std 0.5, i.e. map [0, 1] -> [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

class OCRDataset(Dataset):
    """License-plate OCR dataset built from paired image / JSON label files.

    Every ``*.json`` file in ``json_dir`` is read for an ``"imagePath"`` key
    (joined onto ``image_dir``) and a ``"value"`` key (the label text).
    Entries whose image file does not exist are skipped.

    Items are ``(image, label_indices, label_text)``:
      * ``image``: the grayscale image after ``transform`` (a tensor with the
        default transform);
      * ``label_indices``: 1-D LongTensor of character indices, starting at 1;
      * ``label_text``: the original label string.
    Index 0 is reserved for the CTC blank.
    """

    def __init__(self, image_dir, json_dir, transform=transform):
        self.image_dir = image_dir
        self.samples = []
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order (and therefore any seeded random_split) reproducible.
        for file_name in sorted(os.listdir(json_dir)):
            if not file_name.endswith(".json"):
                continue
            with open(os.path.join(json_dir, file_name), "r", encoding="utf-8") as f:
                data = json.load(f)
            image_path = os.path.join(image_dir, data["imagePath"])
            label = data["value"]
            if os.path.exists(image_path):
                self.samples.append((image_path, label))

        # Character vocabulary = every character seen in the kept labels.
        self.charset = sorted(set(char for _, label in self.samples for char in label))
        self.char2idx = {char: idx + 1 for idx, char in enumerate(self.charset)}  # 0 = blank for CTC
        self.idx2char = {idx: char for char, idx in self.char2idx.items()}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as (image, label index tensor, label text)."""
        image_path, label = self.samples[idx]
        # Read as single-channel grayscale with OpenCV.
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports unreadable/corrupt files by returning None
            # rather than raising; fail here with the offending path.
            raise IOError(f"Could not read image: {image_path}")
        # numpy -> PIL.Image so torchvision transforms (incl. Resize) apply.
        image = Image.fromarray(image)
        if self.transform:
            image = self.transform(image)

        label_idx = [self.char2idx[char] for char in label]
        return image, torch.tensor(label_idx, dtype=torch.long), label

    def decode(self, preds):
        """Greedy CTC decoding of one best-path index sequence.

        Collapses consecutive repeats, drops the blank (0) and maps the
        remaining indices to characters (unknown indices become '').

        Args:
            preds: iterable of class indices -- a 1-D tensor, numpy array or
                plain list of ints.

        Returns:
            The decoded string.
        """
        chars = []
        prev_idx = -1
        for raw in preds:
            idx = int(raw)  # accepts 0-d tensors, numpy scalars and plain ints
            if idx != prev_idx and idx != 0:
                chars.append(self.idx2char.get(idx, ''))
            prev_idx = idx
        return ''.join(chars)

In [99]:
class CRNN(nn.Module):
    """CNN + stacked bidirectional LSTM text recognizer producing per-frame class scores.

    The convolutional stack pools height by 2 four times and ends with a 2x2
    valid convolution, so a 32-pixel-high input collapses to a height-1
    feature map (enforced by the assert in ``forward``).
    A 100-pixel-wide input yields 24 time steps.

    Args:
        num_classes: size of the output layer, CTC blank included.
    """

    def __init__(self, num_classes):
        super().__init__()
        conv_stack = [
            # 1 -> 64 channels, halve H and W
            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            # 64 -> 128 channels, halve H and W
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            # 128 -> 256 -> 256 channels, then halve H only
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(), nn.MaxPool2d((2, 1), (2, 1)),
            # 256 -> 512 -> 512 channels with batch norm, then halve H only
            nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d((2, 1), (2, 1)),
            # 2x2 valid conv: height 2 -> 1  => [B, 512, 1, W]
            nn.Conv2d(512, 512, 2, 1, 0), nn.ReLU(),
        ]
        # Unpacking the list keeps the positional indices (cnn.0 ... cnn.19)
        # and therefore the state_dict keys unchanged.
        self.cnn = nn.Sequential(*conv_stack)

        # Originally LSTM(512, 256) for both layers and Linear(512, ...) for
        # the head; widened to hidden size 512, so each BiLSTM step emits 1024.
        self.rnn1 = nn.LSTM(512, 512, bidirectional=True, batch_first=True)
        self.rnn2 = nn.LSTM(1024, 512, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Map an image batch ``[B, 1, H, W]`` to raw scores ``[T, B, num_classes]``.

        The output is NOT softmax-normalized; apply ``log_softmax(2)``
        before feeding it to ``CTCLoss``.
        """
        feats = self.cnn(x)
        _, _, height, _ = feats.size()
        assert height == 1, f"Expected height=1, got {height}"
        seq = feats.squeeze(2).transpose(1, 2)   # [B, W, C]: one step per column

        seq, _ = self.rnn1(seq)
        seq, _ = self.rnn2(seq)

        scores = self.fc(seq)                    # [B, T, num_classes]
        # [T, B, num_classes]: time-major layout expected by CTCLoss.
        # Classes are the plate characters plus the CTC blank.
        return scores.transpose(0, 1)

In [100]:
from torch.nn import CTCLoss
from torch.utils.data import DataLoader, random_split
from torch.optim import Adam
from tqdm import tqdm
from google.colab import drive
# Mount Google Drive so image_dir / json_dir (paths under /content/drive) are
# readable before OCRDataset is constructed further down in this cell.
drive.mount('/content/drive')


def custom_collate(batch):
    """Collate (image, label_indices, label_text) samples into a CTC-ready batch.

    Returns:
        images: stacked image tensor, shape (B, C, H, W).
        targets: every sample's label indices concatenated into one 1-D tensor
            (the flat "targets" form accepted by CTCLoss).
        target_lengths: LongTensor with each sample's label length, used to
            split ``targets`` back into per-sample labels.
        labels_str: tuple of the original label strings.
    """
    image_list, index_seqs, raw_texts = zip(*batch)
    lengths = [len(seq) for seq in index_seqs]
    return (
        torch.stack(image_list),
        torch.cat(index_seqs),
        torch.tensor(lengths, dtype=torch.long),
        raw_texts,
    )


SEED = 42  # seeds model init, loader shuffling and the train/val split
torch.manual_seed(SEED)

dataset = OCRDataset(image_dir, json_dir)
assert len(dataset) > 0, f"No labelled images found (json_dir={json_dir!r}, image_dir={image_dir!r})"

train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# Dedicated generator: the split is reproducible regardless of how much the
# global RNG was used earlier in the kernel session.
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=custom_collate)


model = CRNN(num_classes=len(dataset.char2idx) + 1)  # +1 for CTC blank
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Vocabulary actually observed in the labels (char2idx maps charset[i] -> i + 1).
charset = dataset.charset
char2idx = dataset.char2idx
idx_to_char = {idx: char for char, idx in char2idx.items()}

criterion = CTCLoss(blank=0, zero_infinity=True)
optimizer = Adam(model.parameters(), lr=0.0005)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def _greedy_ctc_collapse(path, blank=0):
    """Collapse consecutive repeats and drop blanks from one best-path index list."""
    collapsed = []
    prev = -1
    for p in path:
        if p != prev and p != blank:
            collapsed.append(p)
        prev = p
    return collapsed


def evaluate_model(model, dataloader, criterion, idx_to_char, device):
    """Evaluate a CTC text-recognition model on a loader of collated batches.

    Args:
        model: module returning per-frame raw scores shaped [T, B, C].
        dataloader: iterable of (images, targets, target_lengths, labels_str)
            batches, where ``targets`` concatenates every sample's label
            indices (the layout produced by ``custom_collate``).
        criterion: CTC loss called as (log_probs, targets, input_lengths, target_lengths).
        idx_to_char: class index -> character; index 0 is the CTC blank.
        device: device images/targets are moved to.

    Returns:
        (total_loss, acc, prec, rec):
          * total_loss: sum of the per-batch losses;
          * acc: position-wise matching characters / ground-truth characters;
          * prec, rec: micro-averaged over position-aligned (true, pred)
            character pairs. In the multiclass setting micro precision and
            recall both equal the fraction of aligned pairs that match, so
            they are computed directly (no scikit-learn dependency).
    """
    model.eval()
    total_loss = 0.0
    total_chars = 0     # ground-truth characters seen
    correct_chars = 0   # aligned positions where prediction == truth
    total_pairs = 0     # aligned (true, pred) positions compared

    with torch.no_grad():
        for images, targets, target_lengths, _ in dataloader:
            images = images.to(device)
            targets = targets.to(device)
            target_lengths = target_lengths.to(device)

            outputs = model(images)  # [T, B, C]
            # Every sample uses all T output frames.
            input_lengths = torch.full(size=(images.size(0),), fill_value=outputs.size(0), dtype=torch.long)

            loss = criterion(outputs.log_softmax(2), targets, input_lengths, target_lengths)
            total_loss += loss.item()

            # Greedy decode the whole batch: best class per frame, as [B, T].
            best_paths = outputs.argmax(dim=2).transpose(0, 1).cpu().tolist()
            # Pull labels to Python once instead of a device sync per character.
            flat_targets = targets.cpu().tolist()
            lengths = target_lengths.cpu().tolist()

            offset = 0  # start of the current sample inside the flat targets
            for path, length in zip(best_paths, lengths):
                pred_text = ''.join(idx_to_char.get(i, '') for i in _greedy_ctc_collapse(path))
                true_text = ''.join(idx_to_char[i] for i in flat_targets[offset:offset + length])
                offset += length

                # zip stops at the shorter string: only aligned positions count.
                aligned = list(zip(true_text, pred_text))
                correct_chars += sum(t == p for t, p in aligned)
                total_chars += len(true_text)
                total_pairs += len(aligned)

    acc = correct_chars / total_chars if total_chars else 0.0
    prec = correct_chars / total_pairs if total_pairs else 0.0
    rec = prec  # micro recall == micro precision for single-label multiclass pairs

    return total_loss, acc, prec, rec

In [None]:
def pad_labels(labels):
    max_len = max(len(label) for label in labels)
    padded = torch.zeros(len(labels), max_len, dtype=torch.long)
    for i, label in enumerate(labels):
        padded[i, :len(label)] = label
    return padded



# 에폭마다 기록된 값들을 저장할 리스트
train_losses = []
val_losses = []
val_accuracies = []
val_precisions = []
val_recalls = []

for epoch in range(200):
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader):
        images, targets, target_lengths, _ = batch
        images = images.to(device)
        targets = targets.to(device)
        target_lengths = target_lengths.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        input_lengths = torch.full(size=(images.size(0),), fill_value=outputs.size(0), dtype=torch.long)

        loss = criterion(outputs, targets, input_lengths, target_lengths)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    val_loss, val_acc, val_prec, val_rec = evaluate_model(model, val_loader, criterion, idx_to_char, device)

    # 결과 저장
    train_losses.append(total_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    val_precisions.append(val_prec)
    val_recalls.append(val_rec)

    print(f"[Epoch {epoch+1}] Train Loss: {total_loss:.4f} | Val Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | Precision: {val_prec:.4f} | Recall: {val_rec:.4f}")


In [None]:
# Print greedy predictions next to ground truth for held-out (validation)
# samples. The first items of `dataset` would mostly be training samples.
NUM_PREVIEW = 30

model.eval()
with torch.no_grad():
    # min() avoids IndexError when the validation split is small.
    for i in range(min(NUM_PREVIEW, len(val_dataset))):
        image, label_idx, label = val_dataset[i]
        image = image.unsqueeze(0).to(device)  # add batch dim -> [1, 1, H, W]
        output = model(image)                   # [T, 1, C]
        pred = output.argmax(2)[:, 0]           # best class per frame
        pred_text = dataset.decode(pred)
        print(f"[GT] {label} → [Pred] {pred_text}")

In [103]:
def ctc_decode(pred_indices, blank=0):
    """Greedy CTC collapse of one best path.

    Drops ``blank`` and keeps an index only when it differs from the one
    immediately before it.

    Args:
        pred_indices: sequence of class indices (one per time step).
        blank: index of the CTC blank symbol.

    Returns:
        List of the surviving indices, in order.
    """
    seq = list(pred_indices)
    # Pair every element with its predecessor (None for the first element).
    predecessors = [None] + seq[:-1]
    return [tok for prev, tok in zip(predecessors, seq) if tok != blank and tok != prev]

# `pred_indices` was never defined in the notebook (it only existed as leftover
# kernel state). Compute it explicitly from one validation sample:
# argmax over classes gives the best path, shaped [T, B].
model.eval()
with torch.no_grad():
    sample_image, _, _ = val_dataset[0]
    pred_indices = model(sample_image.unsqueeze(0).to(device)).argmax(2)  # [T, 1]

decoded_indices = ctc_decode(pred_indices[:, 0].tolist(), blank=0)
print("Decoded indices:", decoded_indices)

Decoded indices: [1, 2, 4, 5]


In [None]:
# matplotlib was pip-installed but never imported, so `plt` raised NameError on a fresh kernel.
import matplotlib.pyplot as plt

epochs = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(2, 2, figsize=(14, 8))

# Loss: train and validation curves share one panel.
ax_loss = axes[0, 0]
ax_loss.plot(epochs, train_losses, label='Train Loss', marker='o')
ax_loss.plot(epochs, val_losses, label='Val Loss', marker='o')
ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")
ax_loss.legend()
ax_loss.grid(True)

# Remaining panels: one validation metric each, same layout.
metric_panels = [
    (axes[0, 1], val_accuracies, 'Accuracy', 'green'),
    (axes[1, 0], val_precisions, 'Precision', 'orange'),
    (axes[1, 1], val_recalls, 'Recall', 'red'),
]
for ax, values, name, color in metric_panels:
    ax.plot(epochs, values, label=name, color=color, marker='o')
    ax.set(title=name, xlabel="Epoch", ylabel=name)
    ax.grid(True)

fig.tight_layout()
plt.show()