In [None]:
# Import necessary libraries and define utility functions
import psutil
import gc
import sys
from typing import Optional
import os, random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import math
import numpy as np
from collections import OrderedDict
import re
from collections import defaultdict
import glob
from tqdm import tqdm
import pandas as pd
import json

import time
from tqdm.notebook import tqdm
import Levenshtein

# Device and seed setup
# Prefer CUDA, then Apple MPS, then CPU. CUDA is included here because the
# seeding below and cleanup_memory() already handle it explicitly.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Seed every RNG used in this notebook so runs are repeatable.
# NOTE: PYTHONHASHSEED set at runtime only affects child processes, not the
# hash randomisation of the interpreter that is already running.
os.environ['PYTHONHASHSEED'] = '42'
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Trade cuDNN autotuning speed for deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Memory management helpers
def check_memory_usage(context: Optional[str] = None, threshold: float = 0.9) -> bool:
    """Print the current system memory usage and compare it to a threshold.

    Args:
        context: Optional label included in the printed message.
        threshold: Used/total ratio at or above which usage counts as too high.

    Returns:
        True when the used/total ratio is strictly below ``threshold``.
    """
    stats = psutil.virtual_memory()
    fraction_used = stats.used / stats.total

    message = (
        f"Memory usage at '{context}': {fraction_used:.2%} used"
        if context
        else f"Current memory usage: {fraction_used:.2%} used"
    )
    print(message)

    return fraction_used < threshold

def cleanup_memory(aggressive: bool = False):
    """Run garbage collection and release cached accelerator memory.

    Args:
        aggressive: When true, first call ``cache_clear()`` on every
            GC-tracked object that exposes a callable of that name
            (for example ``functools.lru_cache`` wrappers).

    Returns:
        The value returned by ``gc.collect()`` (number of unreachable
        objects found).
    """
    if aggressive:
        for obj in gc.get_objects():
            # Best effort: arbitrary objects may raise on attribute access,
            # or expose a cache_clear that needs arguments (e.g. a class whose
            # instances define the method). One such object must not abort
            # the whole cleanup, so failures are skipped.
            try:
                clear = getattr(obj, 'cache_clear', None)
                if callable(clear):
                    clear()
            except Exception:
                continue

    collected = gc.collect()

    # Hand cached allocator blocks back to the device backend in use.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        # empty_cache is looked up defensively; it may be absent.
        if hasattr(torch.mps, 'empty_cache'):
            torch.mps.empty_cache()

    return collected

def safe_load_images(image_paths: list, max_batch_size: int = 50):
    """Yield ``image_paths`` in consecutive slices of at most ``max_batch_size``.

    After each slice has been consumed (when the caller asks for the next
    one), memory usage is checked and a cleanup is run if the check fails.
    """
    batch_number = 0
    for start in range(0, len(image_paths), max_batch_size):
        batch_number += 1
        yield image_paths[start:start + max_batch_size]

        # Execution resumes here when the consumer requests the next batch.
        memory_ok = check_memory_usage(f"Image batch {batch_number}")
        if memory_ok:
            continue
        print("Performing memory cleanup...")
        cleanup_memory()

# Report the initial memory state and the selected compute device
check_memory_usage("Initial startup")
print(f"Using device: {device}")

In [None]:
# Define paths
# Directories (relative to the working directory) that load_data() scans for
# "*.png" box images and "*.json" label files.
train_boxes_path = "train_extracted_boxes"
train_labels_path = "train_label"
val_boxes_path = "val_extracted_boxes"
val_labels_path = "valid_label"

# Image preprocessing helpers
def load_image_as_grayscale_array(image_path):
    """Load an image file and return it as a grayscale NumPy array.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The image converted to PIL mode ``'L'`` as a NumPy array, or ``None``
        when the file cannot be opened or decoded.
    """
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # closes the handle once the pixels have been converted.
        with Image.open(image_path) as image:
            return np.array(image.convert('L'))
    except Exception as e:
        # Best-effort loader: callers skip None results, but report the
        # failure instead of dropping the image silently.
        print(f"Warning: could not load image {image_path}: {e}")
        return None

def extract_document_id(filename):
    """Return the part of ``filename`` before the first ``'_box_'``.

    Returns ``None`` when the filename contains no ``'_box_'`` marker.
    """
    prefix, marker, _rest = filename.partition('_box_')
    if not marker:
        return None
    return prefix

def _natural_sort_key(path):
    """Sort key that orders digit runs in a file name numerically."""
    # re.split with a capturing group alternates text / digit tokens, so odd
    # positions are always digit runs and the key stays type-consistent.
    tokens = re.split(r'(\d+)', os.path.basename(path))
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]


def load_data(boxes_path, labels_path):
    """Load extracted box images and their per-document JSON labels.

    Box files are grouped by the document id returned by
    ``extract_document_id``; label files are keyed by their base name without
    extension. Only documents present in both sets are loaded.

    Args:
        boxes_path: Directory containing ``*.png`` box images.
        labels_path: Directory containing ``*.json`` label files.

    Returns:
        A tuple ``(li_X, df_y)``: ``li_X`` is a list with one list of
        grayscale arrays per document, and ``df_y`` is a DataFrame with one
        row per document (id, bbox list, metadata dicts, number of boxes).
    """
    document_boxes = defaultdict(list)
    for box_file in glob.glob(os.path.join(boxes_path, "*.png")):
        doc_id = extract_document_id(os.path.basename(box_file))
        if doc_id:
            document_boxes[doc_id].append(box_file)

    # splitext removes only the extension; str.replace would also strip a
    # '.json' that appears in the middle of the name.
    json_dict = {
        os.path.splitext(os.path.basename(json_file))[0]: json_file
        for json_file in glob.glob(os.path.join(labels_path, "*.json"))
    }

    li_X = []
    df_y_data = []

    valid_doc_ids = document_boxes.keys() & json_dict.keys()
    print(f"Processing {len(valid_doc_ids)} documents with both boxes and labels")

    for doc_id in tqdm(sorted(valid_doc_ids)):
        try:
            with open(json_dict[doc_id], 'r', encoding='utf-8') as f:
                label_data = json.load(f)

            # Natural ordering keeps "..._box_2" ahead of "..._box_10"; a plain
            # string sort would not, and the dataset pairs the i-th image with
            # the i-th bbox row.
            # NOTE(review): assumes the number after '_box_' is the bbox index
            # -- confirm against the box-extraction step.
            box_paths = sorted(document_boxes[doc_id], key=_natural_sort_key)

            # NOTE(review): unreadable images are skipped, which shifts every
            # later image of the document relative to its bbox row.
            document_images = []
            for box_path in box_paths:
                img_array = load_image_as_grayscale_array(box_path)
                if img_array is not None:
                    document_images.append(img_array)

            if document_images:
                li_X.append(document_images)
                df_y_data.append({
                    'document_id': doc_id,
                    'bbox': label_data.get('bbox', []),
                    'images_metadata': label_data.get('Images', {}),
                    'dataset_metadata': label_data.get('Dataset', {}),
                    'annotation_metadata': label_data.get('Annotation', {}),
                    'num_boxes': len(document_images)
                })

        except Exception as e:
            # Best effort per document: report and move on to the next one.
            print(f"Error processing document {doc_id}: {e}")
            continue

    # Create DataFrame
    df_y = pd.DataFrame(df_y_data)

    return li_X, df_y

# Load the training and validation splits
li_train_X, df_train_y = load_data(train_boxes_path, train_labels_path)
li_val_X, df_val_y = load_data(val_boxes_path, val_labels_path)

# One DataFrame of bbox annotations per document, in the row order of df_*_y
li_df_bbox_train = [
    pd.DataFrame(df_train_y.iloc[row]['bbox'])
    for row in range(df_train_y.shape[0])
]
li_df_bbox_val = [
    pd.DataFrame(df_val_y.iloc[row]['bbox'])
    for row in range(df_val_y.shape[0])
]

In [None]:
# Vocabulary definition
def build_korean_vocab():
    """Build the character vocabulary (Hangul, Latin letters, digits, symbols).

    Returns:
        A tuple ``(char_to_idx, idx_to_char, vocab_size)``. Indices 0-3 are
        the special tokens ``<PAD>``, ``<BOS>``, ``<EOS>``, ``<UNK>``; the
        remaining characters follow in sorted order.
    """
    hangul_samples = (
        "가나다라마바사아자차카타파하거너더러머버서어저처커터퍼허갸냐댜랴먀뱌샤야쟈챠캬탸퍄햐",
        "개내대래매배새애재채캐태패해걔냬댸럐먜뱨섀얘쟤챼캬탸퍄혜",
        "고노도로모보소오조초코토포호교뇨됴료묘뵤쇼요죠쵸쿄툐표효",
        "구누두루무부수우주추쿠투푸후규뉴듀류뮤뷰슈유쥬츄큐튜퓨휴",
        "그느드르므브스으즈츠크트프흐긔늬디리미비시이지치키티피히",
        "각간갇갈갉갊감갑값갓강갖갗같갚갛개객간갤갬갭갯갰갱",
        "국굳굴굵굶굼굽굿궁궁군궐궤권궷궴궵귀귄귐귑귓규균귤귀규",
    )
    other_groups = (
        "ㄱㄲㄴㄷㄸㄹㄻㄼㄽㄾㄿㅀㅁㅂㅃㅄㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ",  # compatibility jamo consonants
        "ㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ",  # compatibility jamo vowels
        "abcdefghijklmnopqrstuvwxyz",
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
        "0123456789",
        "().,;:!?-",  # punctuation
        "@_./",  # e-mail / web characters
        "[]{}\"'`~#$%^&*+=<>|\\~",  # additional symbols
        " ",
        "±×÷=≠≤≥∞∑∫√π°℃℉€$¥₩%",  # math and currency
    )

    # Every code point in the precomposed Hangul syllable block U+AC00..U+D7A3.
    alphabet = {chr(code_point) for code_point in range(0xAC00, 0xD7A4)}
    # set.update accepts several iterables; each string contributes its chars.
    alphabet.update(*hangul_samples)
    alphabet.update(*other_groups)

    special_tokens = ('<PAD>', '<BOS>', '<EOS>', '<UNK>')
    char_to_idx = {token: index for index, token in enumerate(special_tokens)}
    first_char_index = len(special_tokens)
    for offset, char in enumerate(sorted(alphabet)):
        char_to_idx[char] = first_char_index + offset

    idx_to_char = {index: char for char, index in char_to_idx.items()}

    return char_to_idx, idx_to_char, len(char_to_idx)

# Build the vocabulary once; these module-level names are read by the model,
# the dataset and the decoding helpers defined later in the file.
char_to_idx, idx_to_char, vocab_size = build_korean_vocab()

# Swin Transformer building blocks
class PatchEmbed(nn.Module):
    """Embed an image as a sequence of non-overlapping patch tokens.

    A convolution whose kernel size equals its stride projects each
    ``patch_size`` x ``patch_size`` patch to ``embed_dim`` channels; the
    result is flattened to ``(B, num_patches, embed_dim)`` and layer-normed.
    """
    def __init__(self, img_size=(112, 448), patch_size=4, in_chans=1, embed_dim=96):
        super().__init__()
        grid_h = img_size[0] // patch_size
        grid_w = img_size[1] // patch_size

        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = [grid_h, grid_w]
        self.num_patches = grid_h * grid_w

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Unpacking requires a 4-D (B, C, H, W) input.
        _batch, _channels, _height, _width = x.shape
        feature_map = self.proj(x)  # B, embed_dim, H/patch, W/patch
        tokens = feature_map.flatten(2).transpose(1, 2)  # B, Ph*Pw, embed_dim
        return self.norm(tokens)

class WindowAttention(nn.Module):
    """Window based multi-head self attention (W-MSA) module with relative position bias.

    Args:
        dim: Channel dimension of the input tokens.
        window_size: ``(Wh, Ww)`` height and width of the attention window.
        num_heads: Number of attention heads; ``dim // num_heads`` is the
            per-head dimension.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) scaling, applied to q in forward().
        self.scale = head_dim ** -0.5

        # Learnable bias table with one row per possible (dh, dw) offset
        # between two positions in a window, and one column per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Get pair-wise relative position bias
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij'))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Scale the row offset so that (row, col) pairs map to unique flat indices.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # Registered as a buffer: moves with the module and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)

    def forward(self, x, mask=None):
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape ``(B_, N, C)`` -- windows, tokens per window,
                channels.
            mask: Optional additive mask whose first dimension is the number
                of windows per image (presumably ``(nW, N, N)``, as built by
                SwinTransformerBlock -- confirm), or ``None``.

        Returns:
            Tensor of shape ``(B_, N, C)``.
        """
        B_, N, C = x.shape
        # Fused projection, then split into q, k, v of shape (B_, nH, N, C // nH).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # Look up the learned bias for every pair of positions in the window.
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Windows are laid out as (image, window); add each window's mask,
            # broadcast over images and heads, before the softmax.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = F.softmax(attn, dim=-1)
        else:
            attn = F.softmax(attn, dim=-1)

        attn = self.attn_drop(attn)

        # Weighted sum of values, heads merged back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

def window_partition(x, window_size):
    """Cut a ``(B, H, W, C)`` tensor into non-overlapping square windows.

    Returns a tensor of shape ``(num_windows * B, window_size, window_size, C)``
    with windows ordered by image, then window row, then window column.
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    tiled = x.view(B, rows, window_size, cols, window_size, C)
    # Bring the two window-grid axes next to each other before flattening.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into ``(B, H, W, C)``."""
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # Interleave grid and in-window axes back into image rows and columns.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)

class SwinTransformerBlock(nn.Module):
    """Swin Transformer block: (shifted) window attention followed by an MLP.

    Args:
        dim: Channel dimension of the tokens.
        input_resolution: ``(H, W)`` token grid size; ``forward`` views its
            input as ``(B, H, W, C)`` using these values.
        num_heads: Number of attention heads.
        window_size: Side length of the square attention window.
        shift_size: Cyclic shift applied before windowing (0 disables it).
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the q/k/v projection has a bias.
        drop: Dropout probability for the attention projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        drop_path: Accepted but not used anywhere in this class.
            NOTE(review): no stochastic depth is applied to the residual
            branches -- confirm whether that is intended.
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # When the grid is no larger than one window, use a single window
        # covering the smaller side and disable shifting.
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop)
        )

        if self.shift_size > 0:
            # Label each region of the grid with a distinct id, using three
            # bands per axis (bulk, last window minus shift, last shift rows
            # or columns). Tokens with different ids are masked from each other.
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                       slice(-self.window_size, -self.shift_size),
                       slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                       slice(-self.window_size, -self.shift_size),
                       slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Pairwise id difference: non-zero means the two tokens belong to
            # different regions and get a large negative additive bias.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # Buffer (possibly None) so the mask follows the module across devices.
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Apply the block to tokens of shape ``(B, H*W, C)``; returns the same shape."""
        H, W = self.input_resolution
        B, L, C = x.shape
        
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Partition windows
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # Residual connection around attention, then residual MLP (FFN)
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x

class PatchMerging(nn.Module):
    """Downsample the token grid 2x by merging each 2x2 neighbourhood.

    The four tokens of every 2x2 cell are concatenated (4*C channels),
    layer-normed and linearly reduced to 2*C channels.
    """
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape

        grid = x.view(B, H, W, C)

        # Concatenation order is (row, col) offset (0,0), (1,0), (0,1), (1,1);
        # each slice has shape B, H/2, W/2, C.
        corners = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B, H/2*W/2, 4*C

        return self.reduction(self.norm(merged))

class BasicLayer(nn.Module):
    """One Swin stage: ``depth`` transformer blocks plus optional downsampling.

    Blocks alternate between unshifted windows (even index) and windows
    shifted by ``window_size // 2`` (odd index). ``drop_path`` may be a list
    with one value per block or a single value shared by all blocks.
    """
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
                 drop_path=0., downsample=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        per_block_rates = isinstance(drop_path, list)
        stage_blocks = []
        for index in range(depth):
            is_shifted = index % 2 != 0
            stage_blocks.append(SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=window_size // 2 if is_shifted else 0,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[index] if per_block_rates else drop_path))
        self.blocks = nn.ModuleList(stage_blocks)

        # Patch merging layer (absent on the last stage)
        self.downsample = None if downsample is None else downsample(input_resolution, dim=dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        if self.downsample is None:
            return x
        return self.downsample(x)

class SwinTransformer(nn.Module):
    """Swin Transformer backbone: patch embedding followed by hierarchical stages.

    Each stage after the first halves the token grid and doubles the channel
    width; the output is the layer-normed token sequence of the last stage,
    shape ``(B, L, C)``.
    """
    def __init__(self, img_size=(112, 448), patch_size=4, in_chans=1, embed_dim=96,
                 depths=[2, 6, 2], num_heads=[6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1):
        super().__init__()

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio

        # Patch embedding
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        self.patches_resolution = self.patch_embed.patches_resolution
        grid_h, grid_w = self.patches_resolution

        # Drop-path rate per block, increasing linearly over all blocks
        drop_path_schedule = [
            rate.item() for rate in torch.linspace(0, drop_path_rate, sum(depths))
        ]

        # Build stages; first_block tracks where each stage's slice of the
        # schedule begins.
        self.layers = nn.ModuleList()
        final_stage = self.num_layers - 1
        first_block = 0
        for stage in range(self.num_layers):
            scale = 2 ** stage
            stage_depth = depths[stage]
            self.layers.append(BasicLayer(
                dim=int(embed_dim * scale),
                input_resolution=(grid_h // scale, grid_w // scale),
                depth=stage_depth,
                num_heads=num_heads[stage],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate,
                drop_path=drop_path_schedule[first_block:first_block + stage_depth],
                downsample=PatchMerging if stage < final_stage else None))
            first_block += stage_depth

        self.norm = nn.LayerNorm(int(embed_dim * 2 ** final_stage))

    def forward(self, x):
        tokens = self.patch_embed(x)
        for stage in self.layers:
            tokens = stage(tokens)
        return self.norm(tokens)  # B L C

class SwinTransformerOCR(nn.Module):
    """Swin Transformer encoder with a Transformer decoder for text recognition.

    The encoder output is average-pooled over all tokens into one vector per
    image, projected to the decoder width and used as a single-token memory
    for the decoder.

    Args:
        img_size: ``(height, width)`` of the input images.
        vocab_size: Number of token ids (defaults to the module-level value).
        max_length: Maximum sequence length; also the number of learned
            positional-encoding rows.
        embed_dim: Channel width of the first Swin stage.
        depths: Number of blocks per Swin stage.
        num_heads: Number of attention heads per Swin stage.
    """
    def __init__(self, img_size=(112, 448), vocab_size=vocab_size, max_length=32,
                 embed_dim=96, depths=[2, 6, 2], num_heads=[6, 12, 24]):
        super().__init__()

        self.max_length = max_length
        self.vocab_size = vocab_size

        # Swin Transformer encoder
        self.swin = SwinTransformer(
            img_size=img_size,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=num_heads
        )

        # Channel width after the last stage (doubles at every downsample)
        final_dim = int(embed_dim * 2 ** (len(depths) - 1))

        self.global_pool = nn.AdaptiveAvgPool1d(1)

        decoder_dim = 384
        self.feature_proj = nn.Linear(final_dim, decoder_dim)

        # Transformer decoder (default layout: sequence-first tensors)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=decoder_dim,
            nhead=8,
            dim_feedforward=decoder_dim * 4,
            dropout=0.1
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=4)

        self.output_proj = nn.Linear(decoder_dim, vocab_size)

        self.pos_encoding = nn.Parameter(torch.randn(max_length, decoder_dim))  # learned positional encoding
        self.text_embedding = nn.Embedding(vocab_size, decoder_dim)

    def forward(self, images, target_sequences=None):
        """Run the model.

        Args:
            images: Image batch fed to the Swin encoder.
            target_sequences: Token ids ``(B, seq_len)`` used for teacher
                forcing. Only used while the module is in training mode.

        Returns:
            In training mode with targets: logits ``(B, seq_len, vocab_size)``.
            Otherwise: greedily generated token ids from ``generate``.
        """
        features = self.swin(images)  # B, H*W, C

        # Pool over the token axis to get one context vector per image.
        global_context = self.global_pool(features.transpose(1, 2)).squeeze(-1)  # B, C
        global_context = self.feature_proj(global_context).unsqueeze(0)  # 1, B, decoder_dim

        if self.training and target_sequences is not None:
            # Teacher forcing
            seq_len = target_sequences.size(1)
            target_embeds = self.text_embedding(target_sequences)  # B, seq_len, decoder_dim
            target_embeds += self.pos_encoding[:seq_len].unsqueeze(0)  # Add positional encoding

            # True above the diagonal: a position may not attend to later ones.
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=images.device), diagonal=1).bool()

            target_embeds = target_embeds.transpose(0, 1)  # seq_len, B, decoder_dim
            output = self.transformer_decoder(
                target_embeds,
                global_context,
                tgt_mask=causal_mask
            )

            output = self.output_proj(output.transpose(0, 1))  # B, seq_len, vocab_size
            return output
        else:
            # Autoregressive inference
            return self.generate(global_context, max_length=self.max_length)

    def generate(self, memory, max_length=32):
        """Greedy autoregressive decoding.

        Args:
            memory: Decoder memory of shape ``(1, B, decoder_dim)``.
            max_length: Maximum output length including ``<BOS>``; capped at
                the number of positional-encoding rows.

        Returns:
            Token ids ``(B, T)`` with ``T <= max_length``, starting with
            ``<BOS>``. After a sequence emits ``<EOS>`` its remaining
            positions are ``<PAD>``.
        """
        batch_size = memory.size(1)
        device = memory.device
        bos_id = char_to_idx['<BOS>']
        eos_id = char_to_idx['<EOS>']
        pad_id = char_to_idx['<PAD>']

        # The positional table cannot cover longer sequences.
        max_length = min(max_length, self.pos_encoding.size(0))

        generated = torch.full((batch_size, 1), bos_id, device=device, dtype=torch.long)
        # Tracks which sequences have already produced <EOS>.
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            current_embeds = self.text_embedding(generated)  # B, current_len, decoder_dim
            seq_len = current_embeds.size(1)
            current_embeds += self.pos_encoding[:seq_len].unsqueeze(0)

            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()

            # Decoder forward
            current_embeds = current_embeds.transpose(0, 1)  # current_len, B, decoder_dim
            output = self.transformer_decoder(
                current_embeds,
                memory,
                tgt_mask=causal_mask
            )

            next_token_logits = self.output_proj(output[-1])  # B, vocab_size
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # B, 1
            # Finished sequences only receive padding from here on.
            next_token = next_token.masked_fill(finished.unsqueeze(1), pad_id)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every sequence has emitted <EOS> at some step, rather
            # than requiring all of them to emit it on the same step.
            finished = finished | (next_token.squeeze(-1) == eos_id)
            if bool(finished.all()):
                break

        return generated

In [None]:
# Dataset class for OCR training
class OCRDataset(Dataset):
    """Pairs each box image with the label text from its document's bbox frame.

    Args:
        li_images: One list of images per document.
        li_df_bbox: One DataFrame per document; the ``'data'`` column (if
            present) holds the label for the box at the same row index.
        transform: Stored on the instance but not applied by this class.
        max_length: Fixed token sequence length, including ``<BOS>``/``<EOS>``.
    """
    def __init__(self, li_images, li_df_bbox, transform=None, max_length=32):
        self.li_images = li_images
        self.li_df_bbox = li_df_bbox
        self.transform = transform
        self.max_length = max_length

        # (image, text, document index, box index) for every labelled box
        self.data_pairs = []
        for doc_idx, (images, df_bbox) in enumerate(zip(li_images, li_df_bbox)):
            has_text_column = 'data' in df_bbox.columns
            for img_idx, image in enumerate(images):
                if img_idx >= len(df_bbox):
                    # Images beyond the last label row have no text.
                    break
                raw_label = df_bbox.iloc[img_idx]['data'] if has_text_column else ""
                text = self._normalize_label(raw_label)
                self.data_pairs.append((image, text, doc_idx, img_idx))

    @staticmethod
    def _normalize_label(raw_label):
        """Convert a raw label cell into a plain string."""
        if isinstance(raw_label, (list, tuple)):
            # Use the first entry; an empty list means "no text", not "[]".
            raw_label = raw_label[0] if len(raw_label) > 0 else ""
        if raw_label is None or (isinstance(raw_label, float) and math.isnan(raw_label)):
            # Missing values must not turn into the literal "None" / "nan".
            return ""
        return raw_label if isinstance(raw_label, str) else str(raw_label)

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        """Return ``(image_tensor, token_tensor, text)`` for sample ``idx``."""
        image, text, doc_idx, img_idx = self.data_pairs[idx]

        # Preprocess image
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype(np.uint8))

        # PIL takes (width, height): 448 wide by 112 high.
        image = image.resize((448, 112), Image.Resampling.LANCZOS)

        image = np.array(image).astype(np.float32) / 255.0  # Normalize to [0, 1]
        image = torch.from_numpy(image).unsqueeze(0)  # add a leading channel dimension

        tokens = self.text_to_tokens(text)

        return image, tokens, text

    def text_to_tokens(self, text):
        """Convert text to a fixed-length tensor of token indices.

        The text is truncated to ``max_length - 2`` characters, wrapped in
        ``<BOS>``/``<EOS>`` and right-padded with ``<PAD>``; characters that
        are not in the vocabulary map to ``<UNK>``.
        """
        unk_id = char_to_idx['<UNK>']
        tokens = [char_to_idx['<BOS>']]
        tokens.extend(char_to_idx.get(char, unk_id) for char in text[:self.max_length - 2])
        tokens.append(char_to_idx['<EOS>'])

        # Right-pad to the fixed length
        tokens.extend([char_to_idx['<PAD>']] * (self.max_length - len(tokens)))

        return torch.tensor(tokens[:self.max_length], dtype=torch.long)

def collate_fn(batch):
    """Collate ``(image, tokens, text)`` samples into batched tensors.

    Images and token tensors are stacked along a new leading dimension; the
    texts are returned as a tuple of strings.
    """
    image_list, token_list, texts = zip(*batch)
    return torch.stack(image_list, 0), torch.stack(token_list, 0), texts


# Wrap the loaded documents in datasets (default max_length, no transform)
train_dataset = OCRDataset(li_train_X, li_df_bbox_train)
val_dataset = OCRDataset(li_val_X, li_df_bbox_val)

# Only the training loader is shuffled; both use the custom collate function
batch_size = 8  
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# Helper functions
def tokens_to_text(tokens, remove_special=True):
    """Decode a sequence of token indices into a string.

    Decoding stops at the first ``<PAD>`` or ``<EOS>``. When
    ``remove_special`` is true, ``<BOS>`` and ``<UNK>`` are skipped. Indices
    that are not in the vocabulary are ignored.
    """
    stop_ids = (char_to_idx['<PAD>'], char_to_idx['<EOS>'])
    skipped_ids = [char_to_idx['<BOS>'], char_to_idx['<UNK>']]

    pieces = []
    for token in tokens:
        if token in stop_ids:
            break
        if remove_special and token in skipped_ids:
            continue
        if token in idx_to_char:
            pieces.append(idx_to_char[token])
    return "".join(pieces)

def calculate_accuracy(predictions, targets):
    """Calculate position-wise character accuracy over a batch.

    Each prediction/target pair is decoded with ``tokens_to_text`` and
    compared character by character at the same position.

    Args:
        predictions: Iterable of predicted token tensors.
        targets: Iterable of reference token tensors.

    Returns:
        Matching characters divided by the total number of positions, where
        each pair contributes the length of its longer string. Returns 0.0
        when there is nothing to compare.
    """
    pred_texts = [tokens_to_text(pred.cpu().numpy()) for pred in predictions]
    target_texts = [tokens_to_text(target.cpu().numpy()) for target in targets]

    correct_chars = 0
    total_chars = 0

    for pred, target in zip(pred_texts, target_texts):
        correct_chars += sum(p == t for p, t in zip(pred, target))
        # Count over the longer string so missing or extra characters are
        # errors; counting only the overlap would score a one-character
        # correct prefix of a long target as 100% accurate.
        total_chars += max(len(pred), len(target))

    return correct_chars / total_chars if total_chars > 0 else 0.0

In [None]:
# Model initialization
model = SwinTransformerOCR(
    img_size=(112, 448),
    vocab_size=vocab_size,
    max_length=32,
    embed_dim=96,
    depths=[2, 6, 2],
    num_heads=[6, 12, 24]
).to(device)

# Loss function and optimizer for enhanced model
# Padding positions are excluded from the loss via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=char_to_idx['<PAD>'])
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
# NOTE(review): T_max=3 is much shorter than the number of epochs trained
# below, and the scheduler is stepped once per epoch -- confirm that a
# 3-epoch cosine half-period is intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=3)

In [None]:
# Training
num_epochs = 20
best_val_accuracy = 0
epoch_start_time = time.time()  # taken once, before the first epoch

for epoch in range(1, num_epochs + 1):

    # ---- Training phase ----
    model.train()
    total_loss = 0
    total_accuracy = 0
    batch_count = 0

    text_samples = []

    # Create training progress bar
    train_pbar = tqdm(train_loader, 
                      desc=f"Epoch {epoch} Training", 
                      total=len(train_loader), 
                      unit="batch",
                      leave=True,
                      position=0,
                      file=sys.stdout)
    
    for batch_idx, (images, targets, texts) in enumerate(train_pbar):
        images = images.to(device)
        targets = targets.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass (teacher forcing): feed tokens [0..T-2], predict [1..T-1]
        outputs = model(images, targets[:, :-1])
        loss = criterion(outputs.reshape(-1, vocab_size), targets[:, 1:].reshape(-1))

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Calculate accuracy
        with torch.no_grad():
            predictions = torch.argmax(outputs, dim=-1)
            accuracy = calculate_accuracy(predictions, targets[:, 1:])
        
        total_loss += loss.item()
        total_accuracy += accuracy
        batch_count += 1
        
        # Keep a few label texts from the first batches
        if batch_count <= 10:
            text_samples.extend(texts[:2]) 
        
        # Update progress bar with current metrics
        train_pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{accuracy:.4f}',
            'AvgLoss': f'{total_loss / batch_count:.4f}',
            'AvgAcc': f'{total_accuracy / batch_count:.4f}'
        })
        
        if batch_count % 10 == 0:
            train_pbar.refresh()
        
        # Periodically release cached MPS memory
        if batch_count % 500 == 0:
            if torch.backends.mps.is_available() and hasattr(torch.mps, 'empty_cache'):
                torch.mps.empty_cache()
    
    train_pbar.close()
    
    # max(..., 1) keeps an empty loader from raising ZeroDivisionError
    avg_train_loss = total_loss / max(batch_count, 1)
    avg_train_acc = total_accuracy / max(batch_count, 1)

    # ---- Validation phase ----
    model.eval()
    val_total_loss = 0
    val_total_accuracy = 0
    val_batch_count = 0
    val_samples = []

    # Create validation progress bar 
    val_pbar = tqdm(val_loader, 
                    desc=f"Epoch {epoch} Validation", 
                    total=len(val_loader), 
                    unit="batch",
                    leave=True,
                    position=1,
                    file=sys.stdout)
    
    with torch.no_grad():
        for batch_idx, (images, targets, texts) in enumerate(val_pbar):
            images = images.to(device)
            targets = targets.to(device)
            
            # Teacher-forced forward pass for the loss. The model's forward()
            # only takes that branch when its own `training` flag is set, so
            # flip just the top-level flag: every submodule (dropout etc.)
            # stays in eval mode, unlike a full model.train() call.
            model.training = True
            try:
                outputs = model(images, targets[:, :-1])
            finally:
                model.training = False

            loss = criterion(outputs.reshape(-1, vocab_size), targets[:, 1:].reshape(-1))

            # Generate predictions (inference mode)
            predictions = model(images)
            accuracy = calculate_accuracy(predictions, targets)
            
            val_total_loss += loss.item()
            val_total_accuracy += accuracy
            val_batch_count += 1
            
            # Collect validation samples
            if val_batch_count <= 5:
                val_samples.extend(texts[:2])
            
            # Update progress bar with current metrics
            val_pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{accuracy:.4f}',
                'AvgLoss': f'{val_total_loss / val_batch_count:.4f}',
                'AvgAcc': f'{val_total_accuracy / val_batch_count:.4f}'
            })
    
    val_pbar.close()
    
    avg_val_loss = val_total_loss / max(val_batch_count, 1)
    avg_val_acc = val_total_accuracy / max(val_batch_count, 1)
    
    # Update learning rate (once per epoch)
    scheduler.step()

    # Save best model (by validation accuracy)
    if avg_val_acc > best_val_accuracy:
        best_val_accuracy = avg_val_acc
        torch.save(model.state_dict(), 'swin_ocr_full_best.pth')
    
    # Save the latest weights; overwritten every epoch
    torch.save(model.state_dict(), 'swin_ocr_full_final.pth')

    # Save training checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'train_accuracy': avg_train_acc,
        'val_loss': avg_val_loss,
        'val_accuracy': avg_val_acc,
        'vocab_size': vocab_size,
        'char_to_idx': char_to_idx,
        'idx_to_char': idx_to_char
    }, 'swin_ocr_enhanced_full_checkpoint.pth')
    
    # Elapsed time since training started
    total_time = time.time() - epoch_start_time
    
    # Categorize the sampled training texts by character type.
    # NOTE(review): total_time and categories are computed but never used or
    # printed in this cell.
    for i, text in enumerate(text_samples[:8]):
        # Categorize text type
        has_korean = any('\uAC00' <= c <= '\uD7A3' for c in text)
        has_english = any(c.isalpha() and ord(c) < 128 for c in text)
        has_numbers = any(c.isdigit() for c in text)
        has_symbols = any(c in '@._-()' for c in text)
        
        categories = []
        if has_korean: categories.append("Korean")
        if has_english: categories.append("English")  
        if has_numbers: categories.append("Numbers")
        if has_symbols: categories.append("Symbols")

In [None]:
# CER Calculation
def calculate_cer(model_enhanced, val_loader, device, max_samples=None): 
    """Compute the character error rate (CER) of a model over a data loader.

    Each prediction and reference is decoded with ``tokens_to_text`` and
    compared with the Levenshtein edit distance.

    Args:
        model_enhanced: Model that returns generated token ids when called
            with an image batch; it is switched to eval mode.
        val_loader: Iterable yielding ``(images, targets, texts)`` batches.
        device: Device the images are moved to.
        max_samples: Optional cap on the number of evaluated samples; a
            falsy value (``None`` or 0) means no cap.

    Returns:
        Dict with the overall CER (total edit distance / total reference
        characters), counts, per-sample CER statistics and the per-sample
        ``detailed_results``. All statistics are 0.0 when no sample was
        evaluated.
    """
    model_enhanced.eval()
    
    total_char_errors = 0
    total_ref_chars = 0
    sample_count = 0
    detailed_results = []
    
    with torch.no_grad():
        for batch_idx, (images, targets, texts) in enumerate(val_loader):
            if max_samples and sample_count >= max_samples:
                break
                
            images = images.to(device)
            predictions = model_enhanced(images)

            batch_size = images.size(0)
            for i in range(batch_size):
                if max_samples and sample_count >= max_samples:
                    break
                
                # Convert tokens back to text
                reference_text = tokens_to_text(targets[i].cpu().numpy())
                predicted_text = tokens_to_text(predictions[i].cpu().numpy())
                original_text = texts[i]
                
                # Calculate edit distance (Levenshtein distance)
                edit_distance = Levenshtein.distance(predicted_text, reference_text)
                
                # Per-sample CER; an empty reference scores 0 only if the
                # prediction is empty as well.
                ref_length = len(reference_text)
                if ref_length > 0:
                    cer_sample = edit_distance / ref_length
                else:
                    cer_sample = 0.0 if len(predicted_text) == 0 else 1.0
                
                # Accumulate totals
                total_char_errors += edit_distance
                total_ref_chars += ref_length

                detailed_results.append({
                    'sample_id': sample_count,
                    'original': original_text,
                    'reference': reference_text,
                    'predicted': predicted_text,
                    'edit_distance': edit_distance,
                    'ref_length': ref_length,
                    'cer': cer_sample,
                    'perfect_match': reference_text == predicted_text
                })
                
                sample_count += 1
                
                if sample_count % 100 == 0:
                    current_cer = total_char_errors / max(total_ref_chars, 1)
                    print(f"   Processed {sample_count} samples, Current CER: {current_cer:.4f}")
    
    # Calculate overall CER
    overall_cer = total_char_errors / max(total_ref_chars, 1)
    
    # Calculate additional statistics
    perfect_matches = sum(1 for r in detailed_results if r['perfect_match'])
    sample_cers = [r['cer'] for r in detailed_results]

    if sample_cers:
        mean_cer = np.mean(sample_cers)
        median_cer = np.median(sample_cers)
        min_cer = np.min(sample_cers)
        max_cer = np.max(sample_cers)
    else:
        # np.min / np.max raise on an empty list and np.mean returns NaN, so
        # report zeros when nothing was evaluated.
        mean_cer = median_cer = min_cer = max_cer = 0.0
    
    results = {
        'overall_cer': overall_cer,
        'total_samples': sample_count,
        'total_char_errors': total_char_errors,
        'total_ref_chars': total_ref_chars,
        'perfect_matches': perfect_matches,
        'perfect_match_rate': perfect_matches / sample_count if sample_count > 0 else 0,
        'mean_sample_cer': mean_cer,
        'median_sample_cer': median_cer,
        'min_cer': min_cer,
        'max_cer': max_cer,
        'detailed_results': detailed_results
    }
    
    return results

# Calculate CER over the whole validation loader with the model as it stands
# in memory (no checkpoint is reloaded here)
cer_results = calculate_cer(model, val_loader, device)