# Real-ESRGAN x4 — ONNX + TensorRT 10 (v15)

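**v15: Construye el engine con la API de Python de TensorRT** (paquete `tensorrt-cuXX` elegido según la versión de CUDA que reporta `nvidia-smi`), en lugar de usar el `trtexec` del sistema, para evitar el desajuste entre el driver y el runtime de CUDA.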


In [None]:
# CONFIG
# Input frames and output root on the mounted Google Drive.
INPUT_DRIVE_DIR   = "/content/drive/MyDrive/zanelli_miranda"
OUTPUT_DRIVE_ROOT = "/content/drive/MyDrive/zanelli"
OUTPUT_SUBDIR     = "upscaled_x4_batches"
FILES_PER_BATCH   = 600     # number of frames packed into one batch archive
ARCHIVE_ZSTD_LVL  = 3       # zstd compression level passed to tar
MAX_PARALLEL_IO   = 2       # NOTE(review): not referenced by the visible cells
TRT_PRECISION     = "fp16"  # only used in the engine cache file name below
INPUT_W, INPUT_H  = 720, 480

# === NEW SETTINGS ===
ENABLE_FP16       = True    # re-assigned in the engine build cell
JPG_QUALITY       = 95      # JPEG quality used by the writer threads
WRITER_THREADS    = 4       # number of background image-writer threads
# =====================

# Cache locations (Drive) and scratch folders (local disk); all derived from
# the values above, so the assignment order matters.
CACHE_DRIVE_DIR   = f"{OUTPUT_DRIVE_ROOT}/.cache_realesrgan_onnx"
ONNX_CACHE_NAME   = f"realesrgan_{INPUT_W}x{INPUT_H}.onnx"
ENGINE_CACHE_NAME = f"realesrgan_{INPUT_W}x{INPUT_H}_{TRT_PRECISION}.engine"
WORK_LOCAL_ROOT   = "/content/work_realesrgan"
LOCAL_IN_DIR      = f"{WORK_LOCAL_ROOT}/in"
LOCAL_OUT_DIR     = f"{WORK_LOCAL_ROOT}/out"
LOCAL_TMP_DIR     = f"{WORK_LOCAL_ROOT}/tmp"
MANIFEST_PATH     = f"{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}/manifest_{INPUT_W}x{INPUT_H}.txt"
CKPT_PATH         = f"{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}/checkpoint_{INPUT_W}x{INPUT_H}.json"

import os
# Copy every upper-case str/int variable (bools included, since bool is an int
# subclass) from the current namespace into the process environment as text --
# presumably so the `!` shell cells can read them.
for k,v in list(locals().items()):
    if isinstance(v, (str,int)) and k.isupper(): os.environ[k] = str(v)
print("‚úÖ Config")

In [None]:
# Mount Google Drive at /content/drive; the /content/drive/MyDrive paths used
# throughout this notebook live under that mount point.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# List the GPUs visible to the runtime, then report whether the input folder
# on Drive is reachable (relies on `os` imported by the config cell).
!nvidia-smi -L
print(f"Input dir exists: {os.path.isdir(INPUT_DRIVE_DIR)}")

In [None]:
import os, json, glob
# Create the Drive output/cache folders and the local scratch folders.
os.makedirs(f"{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}", exist_ok=True)
os.makedirs(CACHE_DRIVE_DIR, exist_ok=True)
for d in (WORK_LOCAL_ROOT, LOCAL_IN_DIR, LOCAL_OUT_DIR, LOCAL_TMP_DIR):
    os.makedirs(d, exist_ok=True)

# The manifest freezes the (sorted) list of input PNG names so that batch
# boundaries stay identical across sessions. Files are opened with `with` so
# the data is flushed and the handle closed before anything else reads it.
if not os.path.isfile(MANIFEST_PATH):
    files = sorted(os.path.basename(p) for p in glob.glob(f"{INPUT_DRIVE_DIR}/*.png"))
    if not files:
        raise RuntimeError(f"No PNG in {INPUT_DRIVE_DIR}")
    with open(MANIFEST_PATH, "w") as fh:
        fh.write("\n".join(files) + "\n")
    print(f"Manifest: {len(files)} files")
else:
    with open(MANIFEST_PATH) as fh:
        files = [line.strip() for line in fh if line.strip()]
    print(f"Manifest exists: {len(files)} files")

# Create the checkpoint on first run, then report the stored position and the
# number of batch archives already present on Drive.
if not os.path.isfile(CKPT_PATH):
    with open(CKPT_PATH, "w") as fh:
        json.dump({"next_index": 0}, fh)
with open(CKPT_PATH) as fh:
    print(f"Checkpoint: {json.load(fh).get('next_index', 0)}")
print(f"Batches: {len(glob.glob(f'{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}/batch_*.tar.zst'))}")

In [None]:
# Install the zstd and rsync command-line tools; later cells call them through
# the shell (tar -I 'zstd ...' for archives, rsync for file copies).
!apt-get update -qq && apt-get install -y -qq zstd rsync >/dev/null && echo "‚úÖ APT"

In [None]:
# (Cell 6) TensorRT on Colab: avoid a driver / CUDA runtime mismatch.
# This notebook may run under different "CUDA Version" values depending on the runtime.
# If the installed TensorRT/CUDA is too new, this error appears:
#   "CUDA driver version is insufficient for CUDA runtime version"
#
# Recommendation: do NOT use the system /usr/bin/trtexec if it ships a CUDA newer than the driver supports.
# Instead, install the TensorRT Python package in the matching variant (cu12 or cu13) according to `nvidia-smi`.

import re, subprocess, sys

print("üîé GPU / Driver / CUDA (seg√∫n nvidia-smi):")
try:
    # Parse the "CUDA Version: X.Y" field from the nvidia-smi banner.
    smi = subprocess.check_output(["nvidia-smi"]).decode("utf-8", errors="ignore")
    print(smi.splitlines()[0])
    m = re.search(r"CUDA Version:\s+(\d+)\.(\d+)", smi)
    if not m:
        raise RuntimeError("No pude parsear CUDA Version desde nvidia-smi")
    CUDA_MAJOR = int(m.group(1))
    CUDA_MINOR = int(m.group(2))
except Exception as e:
    # Reached both when nvidia-smi cannot be run and when its output cannot be
    # parsed; execution continues with the fallback version below.
    print("‚ö†Ô∏è No pude ejecutar nvidia-smi:", e)
    print("   Si no tienes GPU habilitada: Runtime -> Change runtime type -> GPU")
    CUDA_MAJOR, CUDA_MINOR = 12, 0  # reasonable fallback

print(f"‚úÖ Detectado CUDA Version ~ {CUDA_MAJOR}.{CUDA_MINOR} (m√°ximo soportado por el driver).")

# Install the TensorRT Python package in the matching variant:
# - If nvidia-smi reports CUDA 12.x -> tensorrt-cu12
# - If it reports CUDA 13.x -> tensorrt-cu13
# Note: by default some packages pull in CUDA 13.x; here we force the variant that matches the driver.
pkg = f"tensorrt-cu{CUDA_MAJOR}"
print("üì¶ Instalando:", pkg)
!pip -q install --upgrade "{pkg}" onnx

import tensorrt as trt
print("‚úÖ TensorRT Python:", trt.__version__)


In [None]:
# RRDBNet architecture (compatibles con pesos RealESRGAN_x4plus.pth)
import torch, torch.nn as nn, torch.nn.functional as F

class ResidualDenseBlock5C(nn.Module):
    """Residual dense block made of five 3x3 convolutions.

    Each of the first four convolutions sees the block input concatenated
    (along the channel axis) with the activations of all previous
    convolutions and emits ``gc`` channels. The fifth maps everything back to
    ``nf`` channels; its output is scaled by 0.2 and added to the input.

    The attribute names ``conv1`` .. ``conv5`` are state_dict keys and must
    not be renamed.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(nf + gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, stride=1, padding=1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Grow the list of feature maps; every conv consumes all of them.
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        residual = self.conv5(torch.cat(features, 1))
        return residual * 0.2 + x

class RRDB(nn.Module):
    """Residual-in-residual dense block: three dense blocks plus a scaled skip.

    The input is passed through ``rdb1``, ``rdb2`` and ``rdb3`` in sequence;
    the result is scaled by 0.2 and added back to the input.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        # These names matter: rdb1/rdb2/rdb3 (and conv1..conv5 inside each)
        # must match the keys of the pretrained state_dict.
        self.rdb1 = ResidualDenseBlock5C(nf, gc)
        self.rdb2 = ResidualDenseBlock5C(nf, gc)
        self.rdb3 = ResidualDenseBlock5C(nf, gc)

    def forward(self, x):
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        return out * 0.2 + x

class RRDBNet(nn.Module):
    """x4 super-resolution network built from a trunk of RRDB blocks.

    Layout: first conv -> ``nb`` RRDB blocks -> trunk conv (added back to the
    first conv's output) -> two "nearest x2 upsample + conv + LeakyReLU"
    stages -> HR conv + LeakyReLU -> output conv.

    Attribute names are state_dict keys and must stay as they are.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, nf=64, nb=23, gc=32):
        super().__init__()
        self.conv_first = nn.Conv2d(num_in_ch, nf, kernel_size=3, stride=1, padding=1)
        self.body = nn.Sequential(*(RRDB(nf, gc) for _ in range(nb)))
        self.conv_body = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)

        # x4 upsampling as two (nearest x2 + conv) stages, as in ESRGAN / Real-ESRGAN.
        self.conv_up1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.conv_up2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.conv_hr = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1)
        self.conv_last = nn.Conv2d(nf, num_out_ch, kernel_size=3, stride=1, padding=1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        shallow = self.conv_first(x)
        # Long skip connection around the whole RRDB trunk.
        feat = shallow + self.conv_body(self.body(shallow))
        for up_conv in (self.conv_up1, self.conv_up2):
            feat = F.interpolate(feat, scale_factor=2, mode='nearest')
            feat = self.lrelu(up_conv(feat))
        hr = self.lrelu(self.conv_hr(feat))
        return self.conv_last(hr)

# Confirmation message printed once the architecture classes above are defined.
print("‚úÖ RRDBNet (nombres compatibles con RealESRGAN_x4plus)")


In [None]:
# Export ONNX (CPU only) ‚Äî forzar pesos embebidos (sin external_data)
import os, torch
# Asegurar dependencias para exporter ONNX (PyTorch reciente usa onnxscript)
import importlib, sys, subprocess, inspect

def _pip_install(*pkgs):
    """Quietly pip-install *pkgs* using the interpreter running this notebook.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", "-q"]
    command.extend(pkgs)
    subprocess.check_call(command)

# Try to import each ONNX-export dependency and install it only when the
# import fails.
for pkg in ("onnx", "onnxscript"):
    try:
        importlib.import_module(pkg)
    except ImportError:
        _pip_install(pkg)

ONNX_LOCAL = f"{LOCAL_TMP_DIR}/realesrgan.onnx"
ONNX_CACHE = f"{CACHE_DRIVE_DIR}/{ONNX_CACHE_NAME}"

if os.path.isfile(ONNX_CACHE):
    print("üîÅ ONNX cached")
    !cp "{ONNX_CACHE}" "{ONNX_LOCAL}"
else:
    print("Exporting ONNX...")
    m = RRDBNet()
    pth = f"{LOCAL_TMP_DIR}/weights.pth"
    if not os.path.isfile(pth):
        !wget -q -O "{pth}" "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    sd = torch.load(pth, map_location='cpu', weights_only=True)
    state = sd.get('params_ema', sd.get('params', sd))

    # Algunos checkpoints vienen con prefijo 'module.' (DataParallel)
    if isinstance(state, dict) and any(k.startswith('module.') for k in state.keys()):
        state = {k.replace('module.', '', 1): v for k, v in state.items()}

    try:
        m.load_state_dict(state, strict=True)
    except RuntimeError:
        print("‚ùå load_state_dict strict=True fall√≥. Diagn√≥stico con strict=False:")
        missing, unexpected = m.load_state_dict(state, strict=False)
        print("  missing keys:", len(missing))
        print("  unexpected keys:", len(unexpected))
        print("  ejemplo missing:", missing[:5])
        print("  ejemplo unexpected:", unexpected[:5])
        raise

    # Asegurar float32 en export (evita pesos BF16/FP16 raros)
    m = m.float().cpu().eval()

    dummy = torch.randn(1, 3, INPUT_H, INPUT_W, dtype=torch.float32)

    # Export: usar exporter cl√°sico si est√° disponible (evita depender de dynamo)
    _export_kwargs = dict(
        opset_version=17,
        input_names=['input'],
        output_names=['output'],
        export_params=True,
        do_constant_folding=True,
        keep_initializers_as_inputs=False,
    )
    _sig = inspect.signature(torch.onnx.export)
    if 'dynamo' in _sig.parameters:
        _export_kwargs['dynamo'] = False  # exporter cl√°sico

    torch.onnx.export(
        m,
        dummy,
        ONNX_LOCAL,
        **_export_kwargs,
    )

    print(f"‚úÖ ONNX {os.path.getsize(ONNX_LOCAL)/1e6:.1f}MB")

    # Re-guardar con onnx para garantizar que NO use external data (opcional)
    try:
        import onnx
        mm = onnx.load_model(ONNX_LOCAL, load_external_data=True)
        onnx.save_model(mm, ONNX_LOCAL, save_as_external_data=False)
        print("‚úÖ ONNX re-guardado sin external data")
    except Exception as e:
        print("‚ö†Ô∏è  No se pudo re-guardar ONNX con el paquete onnx (opcional).", type(e).__name__)

    !cp "{ONNX_LOCAL}" "{ONNX_CACHE}"
    del m


In [None]:
# (Celda 9) Build TensorRT engine (TensorRT Python) ‚Äî robusto para ONNX con pesos externos
import os, time, glob

ENABLE_FP16 = True  # Use FP16 for maximum speed (overrides the config-cell value)

# Local engine file and its cached copy on Drive.
ENGINE_LOCAL = f"{LOCAL_TMP_DIR}/realesrgan.engine"
ENGINE_CACHE = f"{CACHE_DRIVE_DIR}/{ENGINE_CACHE_NAME}"

def _file_mb(p):
    try: return os.path.getsize(p)/1024/1024
    except: return -1

# Reuse the engine cached on Drive when present; otherwise build it from the
# ONNX file with the TensorRT Python API and cache the result.
if os.path.isfile(ENGINE_CACHE):
    print("üîÅ Engine cached")
    !cp "{ENGINE_CACHE}" "{ENGINE_LOCAL}"
else:
    print("Building TensorRT engine with TensorRT Python (puede tardar 2-15 min)...")

    # 1) Quick diagnostics of the ONNX file (especially if it was exported with external data)
    onnx_dir = os.path.dirname(ONNX_LOCAL)
    onnx_base = os.path.splitext(os.path.basename(ONNX_LOCAL))[0]
    sidecars = sorted(glob.glob(os.path.join(onnx_dir, onnx_base + "*")))
    print(f"ONNX: {ONNX_LOCAL}  ({_file_mb(ONNX_LOCAL):.1f} MB)")
    print("Sidecar files:", [os.path.basename(x) for x in sidecars])

    # 2) If the ONNX uses external weights (a .data/.onnx_data/etc. file), create an EMBEDDED copy.
    #    This avoids the typical failure: "Failed to import initializer: <weight>"
    ONNX_TO_PARSE = ONNX_LOCAL
    ONNX_EMBED = os.path.join(onnx_dir, onnx_base + "_embedded.onnx")

    try:
        import onnx
        try:
            m = onnx.load_model(ONNX_LOCAL, load_external_data=True)
            # Save embedded (without external data)
            onnx.save_model(m, ONNX_EMBED, save_as_external_data=False)
            if os.path.isfile(ONNX_EMBED) and os.path.getsize(ONNX_EMBED) > 0:
                ONNX_TO_PARSE = ONNX_EMBED
                print(f"‚úÖ ONNX embebido creado: {ONNX_TO_PARSE}  ({_file_mb(ONNX_TO_PARSE):.1f} MB)")
        except Exception as e:
            # Best effort: on failure the original ONNX file is parsed as-is.
            print("‚ö†Ô∏è  No pude cargar/embeber external data con onnx.load_model(load_external_data=True).")
            print("    Si tu export gener√≥ un archivo de pesos externo (p.ej. *.data / *.onnx_data), aseg√∫rate que exista en el mismo directorio.")
            print("    Error:", type(e).__name__, str(e)[:300])
    except Exception as e:
        print("‚ö†Ô∏è  Paquete 'onnx' no disponible; se intentar√° parsear el ONNX tal cual.")
        print("    Error:", type(e).__name__, str(e)[:200])

    import tensorrt as trt

    t0 = time.time()
    logger = trt.Logger(trt.Logger.WARNING)

    # Create builder / network / parser
    builder = trt.Builder(logger)
    explicit = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(explicit)
    parser = trt.OnnxParser(network, logger)

    # 3) Parse: prefer parse_from_file (handles external data better than parse(bytes))
    ok = False
    if hasattr(parser, "parse_from_file"):
        ok = parser.parse_from_file(ONNX_TO_PARSE)
    else:
        with open(ONNX_TO_PARSE, "rb") as f:
            ok = parser.parse(f.read())

    if not ok:
        print("‚ùå Error parseando ONNX en TensorRT:")
        for i in range(parser.num_errors):
            print(parser.get_error(i))

        # Specific hints for the typical initializer error
        print("\nSugerencias r√°pidas:")
        print(" - Si el error menciona 'Failed to import initializer: ...', casi siempre es porque el ONNX fue guardado con pesos externos y falta el archivo sidecar.")
        print(" - Revisa que junto a realesrgan.onnx exista el archivo de pesos (p.ej. realesrgan.onnx.data / realesrgan.data / *.onnx_data).")
        print(" - Si existe, aseg√∫rate que est√© en el MISMO directorio que el .onnx.")
        print(" - Alternativa: re-exporta el ONNX con pesos embebidos (sin external_data).")
        raise RuntimeError("ONNX parse failed")

    config = builder.create_builder_config()

    # Workspace limit; falls back to the older attribute when the
    # memory-pool call raises.
    workspace_gib = 4
    try:
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_gib * (1 << 30))
    except Exception:
        config.max_workspace_size = workspace_gib * (1 << 30)

    # FP16: enabled when the platform reports fast FP16; if the capability
    # check itself raises, the flag is set anyway. Note the engine is cached
    # under ENGINE_CACHE_NAME regardless of which precision was actually built.
    if ENABLE_FP16:
        try:
            if builder.platform_has_fast_fp16:
                config.set_flag(trt.BuilderFlag.FP16)
                print("  FP16 enabled")
            else:
                print("  FP16 requested but platform_has_fast_fp16=False; building FP32")
        except Exception:
            config.set_flag(trt.BuilderFlag.FP16)
            print("  FP16 enabled (sin check)")

    # Build
    engine_bytes = builder.build_serialized_network(network, config)
    if engine_bytes is None:
        raise RuntimeError("TensorRT build failed (engine_bytes is None)")

    with open(ENGINE_LOCAL, "wb") as f:
        f.write(engine_bytes)

    print(f"‚úÖ Engine built: {ENGINE_LOCAL}  ({time.time()-t0:.1f} s)")
    # Cache the engine on Drive for future sessions.
    !cp "{ENGINE_LOCAL}" "{ENGINE_CACHE}"

# Sanity check: the local engine file must exist and be at least 1 MiB.
if not os.path.isfile(ENGINE_LOCAL) or os.path.getsize(ENGINE_LOCAL) < 1024*1024:
    raise RuntimeError("Engine file missing or too small ‚Äî build likely failed")


In [None]:
# Install pycuda; the inference cell uses it for device buffers, the CUDA
# stream and host<->device copies.
!pip install -q pycuda
print("‚úÖ pycuda")

In [None]:
# Inference class using TensorRT Runtime (not Builder)
import numpy as np, cv2
import pycuda.driver as cuda
import pycuda.autoinit

# Import tensorrt AFTER pycuda.autoinit
import tensorrt as trt

class Model:
    """Run a serialized TensorRT engine on one BGR image per call.

    The engine is deserialized once; one device buffer per I/O tensor and one
    host output array are allocated up front and reused for every call.
    Buffers are sized at 4 bytes per element, i.e. float32 engine I/O is
    assumed.
    """

    def __init__(self, path):
        """Load the engine stored at *path* and allocate its I/O buffers.

        Raises:
            RuntimeError: if the engine cannot be deserialized, does not have
                exactly one input and one output tensor, or has dynamic
                (negative) dimensions.
        """
        self.logger = trt.Logger(trt.Logger.WARNING)

        # Load the engine (Runtime only, no Builder).
        with open(path, 'rb') as f:
            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())

        if self.engine is None:
            raise RuntimeError("Failed to load engine")

        self.ctx = self.engine.create_execution_context()
        self.stream = cuda.Stream()

        # Resolve the tensors by I/O mode instead of assuming that index 0 is
        # the input and index 1 the output.
        names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
        inputs = [n for n in names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
        outputs = [n for n in names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
        if len(inputs) != 1 or len(outputs) != 1:
            raise RuntimeError(
                f"Expected exactly 1 input and 1 output tensor, got inputs={inputs}, outputs={outputs}"
            )
        self.iname = inputs[0]
        self.oname = outputs[0]
        self.ishape = tuple(self.engine.get_tensor_shape(self.iname))
        self.oshape = tuple(self.engine.get_tensor_shape(self.oname))
        if any(d < 0 for d in self.ishape + self.oshape):
            raise RuntimeError(
                f"Dynamic shapes are not supported: in={self.ishape}, out={self.oshape}"
            )

        # Allocate GPU memory (4 bytes per element) and the host output array.
        self.d_in = cuda.mem_alloc(int(np.prod(self.ishape)) * 4)
        self.d_out = cuda.mem_alloc(int(np.prod(self.oshape)) * 4)
        self.h_out = np.empty(self.oshape, np.float32)

        print(f"In: {self.ishape}, Out: {self.oshape}")

    def __call__(self, img):
        """Upscale one BGR uint8 image and return the BGR uint8 result.

        Args:
            img: array of shape (H, W, C) matching the engine input
                (1, C, H, W).

        Raises:
            ValueError: if *img* does not have the shape the engine was built
                for. Without this check a mismatched image would be copied
                into a device buffer of a different size.
        """
        expected_hwc = (self.ishape[2], self.ishape[3], self.ishape[1])
        actual = getattr(img, "shape", None)
        if actual != expected_hwc:
            raise ValueError(
                f"Expected image of shape {expected_hwc} (H, W, C), got {actual}"
            )

        # Preprocess: BGR->RGB, HWC->NCHW, scale to [0, 1]
        x = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        x = np.ascontiguousarray(x.transpose(2, 0, 1)[None])

        # Copy to GPU
        cuda.memcpy_htod_async(self.d_in, x, self.stream)

        # Execute
        self.ctx.set_tensor_address(self.iname, int(self.d_in))
        self.ctx.set_tensor_address(self.oname, int(self.d_out))
        self.ctx.execute_async_v3(self.stream.handle)

        # Copy back and wait for the stream to finish
        cuda.memcpy_dtoh_async(self.h_out, self.d_out, self.stream)
        self.stream.synchronize()

        # Postprocess: clip to [0, 1], NCHW->HWC, scale to uint8, RGB->BGR
        y = np.clip(self.h_out[0], 0, 1).transpose(1, 2, 0)
        return cv2.cvtColor((y * 255).round().astype(np.uint8), cv2.COLOR_RGB2BGR)

# Load the local engine file into the inference wrapper used by later cells.
model = Model(ENGINE_LOCAL)
print("‚úÖ Model loaded")

In [None]:
# Smoke test: run the first five manifest entries through the model, time the
# run, and print the input/output image shapes.
import shutil, time, glob
# Start from empty local input/output folders.
shutil.rmtree(LOCAL_IN_DIR, True); shutil.rmtree(LOCAL_OUT_DIR, True)
os.makedirs(LOCAL_IN_DIR); os.makedirs(LOCAL_OUT_DIR)

for f in files[:5]: 
    shutil.copy2(f"{INPUT_DRIVE_DIR}/{f}", LOCAL_IN_DIR)

print(f"Processing {len(files[:5])} images...")
t0 = time.time()
for f in files[:5]:
    img = cv2.imread(f"{LOCAL_IN_DIR}/{f}")
    # Unreadable images are skipped silently.
    if img is not None: 
        cv2.imwrite(f"{LOCAL_OUT_DIR}/{f}", model(img))

dt = time.time() - t0
outs = glob.glob(LOCAL_OUT_DIR + '/*.png')
print(f"‚úÖ {len(outs)} imgs, {dt:.1f}s total, {dt/max(1,len(outs)):.2f}s/img")

if outs:
    # NOTE(review): outs[0] comes from glob, so it is not guaranteed to be
    # the output of files[0]; only the shapes are compared here.
    sin = cv2.imread(f"{LOCAL_IN_DIR}/{files[0]}")
    sout = cv2.imread(outs[0])
    print(f"   Input:  {sin.shape}")
    print(f"   Output: {sout.shape}")

In [None]:
# Main loop - CORREGIDO
import json, subprocess, shutil, time, glob, threading, queue, os
from concurrent.futures import ThreadPoolExecutor

# Drive folder that receives the batch archives.
OUT = f"{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}"

def bash(c):
    """Run the shell command string *c* and return its stdout as text.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    return subprocess.check_output(c, shell=True, text=True)


def ckpt():
    """Load and return the checkpoint dict stored at CKPT_PATH."""
    with open(CKPT_PATH) as fh:
        return json.load(fh)


def save_ckpt(i, extra=None):
    """Store ``next_index = i`` (plus optional *extra* keys) in the checkpoint.

    Existing keys in the checkpoint file are preserved. The JSON is written
    to a sibling temporary file and then moved over CKPT_PATH with
    ``os.replace`` so an interruption in the middle of the write cannot leave
    a truncated checkpoint behind.
    """
    d = ckpt()
    d['next_index'] = i
    if extra:
        d.update(extra)
    tmp_path = f"{CKPT_PATH}.tmp"
    with open(tmp_path, 'w') as fh:
        json.dump(d, fh, indent=2)
    os.replace(tmp_path, CKPT_PATH)

def bname(s, e):
    """Return the archive file name for the batch with bounds *s* and *e*.

    Both indices are zero-padded to eight digits.
    """
    return "batch_{:08d}_{:08d}.tar.zst".format(s, e)


def done(s, e):
    """Tell whether the archive for batch (*s*, *e*) already exists in OUT."""
    archive_path = f"{OUT}/" + bname(s, e)
    return os.path.isfile(archive_path)

def format_time(seconds):
    """Format a duration given in seconds as a short human-readable string.

    Below one minute: whole seconds ("42s"); below one hour: minutes with one
    decimal ("1.5 min"); otherwise whole hours and minutes ("1h 2m").
    """
    if seconds < 60:
        return f"{seconds:.0f}s"
    if seconds < 3600:
        minutes = seconds / 60
        return f"{minutes:.1f} min"
    hours, leftover = divmod(seconds, 3600)
    return f"{int(hours)}h {int(leftover // 60)}m"

# === FIX #2: Recompute idx from the first missing batch ===
# Walk the batch grid from 0 and stop at the first batch whose archive is not
# on Drive yet; the stored checkpoint value is overwritten with that position.
real_idx = 0
while real_idx < len(files):
    s, e = real_idx, min(len(files), real_idx + FILES_PER_BATCH)
    if not done(s, e):
        break
    real_idx = e
idx = real_idx
total = len(files)
save_ckpt(idx)
print(f"Start: {idx}/{total} (verificado)")

# Single worker so that only one prefetch copy runs at a time.
prefetch_pool = ThreadPoolExecutor(1)
# NOTE(review): upload_pool is not referenced by the visible code; uploads go
# through a dedicated thread and queue instead.
upload_pool = ThreadPoolExecutor(2)

# === WRITER THREADS ===
# Bounded queue: the producer blocks once WRITER_THREADS * 2 images are pending.
write_queue = queue.Queue(maxsize=WRITER_THREADS * 2)

def writer_thread():
    """Consume (path, image) items from write_queue and save them as JPEG.

    A ``None`` item stops the thread. ``task_done()`` is called in a
    ``finally`` block for every item, so a failing write can never leave
    ``write_queue.join()`` waiting forever; failures are reported and the
    thread keeps serving the queue.
    """
    while True:
        item = write_queue.get()
        try:
            if item is None:
                break
            path, img = item
            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(path, img, [cv2.IMWRITE_JPEG_QUALITY, JPG_QUALITY]):
                print(f"   Write failed: {path}")
        except Exception as e:
            print(f"   Write error: {e}")
        finally:
            write_queue.task_done()

writers = []
for _ in range(WRITER_THREADS):
    t = threading.Thread(target=writer_thread, daemon=True)
    t.start()
    writers.append(t)

# === ASYNC UPLOAD (FIX #2: sequential checkpoint) ===
upload_queue = queue.Queue()
last_uploaded_idx = idx  # start of the first batch not yet archived on Drive

def upload_worker():
    """Copy finished archives to Drive and keep the checkpoint up to date.

    Items are ``(archive_path, archive_name, end_index)`` tuples; ``None``
    stops the worker. After each successful copy the checkpoint is moved to
    the end of the contiguous run of batches whose archives exist on Drive.
    This also covers a shorter final batch and batches that were already
    archived by an earlier session. Errors are reported and the worker keeps
    running; ``task_done()`` is always called so ``upload_queue.join()``
    cannot hang.
    """
    global last_uploaded_idx
    while True:
        item = upload_queue.get()
        try:
            if item is None:
                break
            arc, name, _end_idx = item
            bash(f"rsync -a '{arc}' '{OUT}/{name}'")

            # Advance over every consecutive batch that is now on Drive.
            frontier = last_uploaded_idx
            while frontier < total:
                batch_end = min(total, frontier + FILES_PER_BATCH)
                if not done(frontier, batch_end):
                    break
                frontier = batch_end
            if frontier != last_uploaded_idx:
                last_uploaded_idx = frontier
                save_ckpt(frontier)

            # Best-effort removal of the local archive.
            try:
                os.remove(arc)
            except OSError:
                pass
        except Exception as e:
            print(f"   ‚ö†Ô∏è Upload error: {e}")
        finally:
            upload_queue.task_done()

upload_thread = threading.Thread(target=upload_worker, daemon=True)
upload_thread.start()

# === PREFETCH (FIX #1: correct alternation) ===
# Two local input folders so one batch can be copied while another is processed.
LOCAL_IN_A = f"{WORK_LOCAL_ROOT}/in_A"
LOCAL_IN_B = f"{WORK_LOCAL_ROOT}/in_B"
os.makedirs(LOCAL_IN_A, exist_ok=True)
os.makedirs(LOCAL_IN_B, exist_ok=True)

def copy_batch_to_dir(file_list, target_dir):
    """Copy the named input files from Drive into an emptied *target_dir*.

    Args:
        file_list: file names relative to INPUT_DRIVE_DIR.
        target_dir: local folder; it is deleted and recreated first.

    Returns:
        *target_dir*.

    Raises:
        subprocess.CalledProcessError: if rsync exits non-zero.
    """
    shutil.rmtree(target_dir, True)
    os.makedirs(target_dir, exist_ok=True)
    lst = f"{LOCAL_TMP_DIR}/pf_{os.path.basename(target_dir)}.txt"
    # Close the list file explicitly so it is fully flushed before rsync reads it.
    with open(lst, 'w') as fh:
        fh.write("\n".join(file_list) + "\n")
    bash(f"rsync -a --files-from='{lst}' '{INPUT_DRIVE_DIR}/' '{target_dir}/'")
    return target_dir

# Compute the pending batches: walk the batch grid from idx to the end and
# keep only the ranges whose archive is not on Drive yet.
batches = []
temp_idx = idx
while temp_idx < total:
    s, e = temp_idx, min(total, temp_idx + FILES_PER_BATCH)
    if not done(s, e):
        batches.append((s, e))
    temp_idx = e

print(f"Batches pendientes: {len(batches)}")

# Main processing loop: for each pending batch, wait for its prefetched
# inputs, start prefetching the next batch, upscale and write the frames,
# archive the outputs and queue the archive for upload.
if not batches:
    print("‚úÖ Todo ya procesado!")
else:
    global_start = time.time()

    # Initial prefetch into LOCAL_IN_A
    s0, e0 = batches[0]
    print(f"üì• Prefetch inicial [{s0}:{e0})...")
    prefetch_future = prefetch_pool.submit(copy_batch_to_dir, files[s0:e0], LOCAL_IN_A)

    # FIX #1: explicit flag for the A/B alternation
    use_A_for_current = True  # Current batch uses A, prefetch goes to B

    for batch_idx, (s, e) in enumerate(batches):
        batch_num = batch_idx + 1
        total_batches = len(batches)

        print(f"\n{'='*60}")
        print(f"üé¨ Batch {batch_num}/{total_batches} [{s}:{e}) ({e-s} frames)")
        print(f"{'='*60}")

        # Pick the input folder for this batch and the one for the next prefetch
        current_in_dir = LOCAL_IN_A if use_A_for_current else LOCAL_IN_B
        prefetch_target = LOCAL_IN_B if use_A_for_current else LOCAL_IN_A

        # Wait for the prefetch of THIS batch (re-raises any copy error)
        t0 = time.time()
        prefetch_future.result()
        wait_time = time.time() - t0
        if wait_time > 1:
            print(f"   üì• Esper√≥ prefetch: {wait_time:.1f}s")
        else:
            print(f"   üì• Prefetch listo ‚úì")

        # Start the prefetch of the NEXT batch (into the other folder)
        if batch_idx + 1 < len(batches):
            next_s, next_e = batches[batch_idx + 1]
            prefetch_future = prefetch_pool.submit(copy_batch_to_dir, files[next_s:next_e], prefetch_target)

        # Flip for the next iteration
        use_A_for_current = not use_A_for_current

        # Prepare an empty output folder
        shutil.rmtree(LOCAL_OUT_DIR, True)
        os.makedirs(LOCAL_OUT_DIR)

        # Process: unreadable images are skipped silently; outputs are queued
        # for the writer threads with a .jpg extension.
        t1 = time.time()
        n = 0
        for f in files[s:e]:
            img = cv2.imread(f"{current_in_dir}/{f}")
            if img is not None:
                out_img = model(img)
                base, _ = os.path.splitext(f)
                write_queue.put((f"{LOCAL_OUT_DIR}/{base}.jpg", out_img))
                n += 1

        # Wait until every queued image has been handled by a writer thread
        write_queue.join()
        dt = time.time() - t1
        print(f"   ‚ö° TRT+JPG: {dt:.1f}s ({n} imgs, {n/dt:.2f} fps)")

        # Compress the output folder into a zstd tar archive
        t2 = time.time()
        arc = f"{WORK_LOCAL_ROOT}/{bname(s,e)}"
        bash(f"tar -I 'zstd -{ARCHIVE_ZSTD_LVL} -T0' -cf '{arc}' -C '{LOCAL_OUT_DIR}' .")
        print(f"   üì¶ Zip: {time.time()-t2:.1f}s")

        # Queue the upload
        upload_queue.put((arc, bname(s,e), e))
        print(f"   üì§ Upload encolado")

        # Progress monitor; fps is measured from idx, so ranges skipped
        # because they were already archived count as done.
        frames_done = e
        elapsed = time.time() - global_start
        fps = (frames_done - idx) / elapsed
        remaining = total - frames_done
        eta = remaining / fps if fps > 0 else 0

        print(f"\n   üìä Progreso: {frames_done:,}/{total:,} ({100*frames_done/total:.1f}%)")
        print(f"   ‚è±Ô∏è  Tiempo: {format_time(elapsed)} | FPS: {fps:.2f}")
        print(f"   üèÅ ETA: {format_time(eta)}")

    print("\n‚è≥ Esperando uploads...")
    upload_queue.join()

# Cleanup: send the stop signal to the upload worker and to every writer
# thread, then wait until the writers have consumed theirs.
upload_queue.put(None)
for _ in writers:
    write_queue.put(None)
write_queue.join()

# `global_start` is only assigned when there were pending batches, so the
# timing summary is printed only in that case (otherwise this would raise
# NameError, or reuse a stale value from an earlier run of the cell).
if batches:
    total_time = time.time() - global_start
    print("\n" + "="*60)
    print("‚úÖ COMPLETADO")
    print(f"   Frames: {total:,}")
    print(f"   Tiempo: {format_time(total_time)}")
    # max() guards against a zero elapsed time.
    print(f"   FPS: {(total-idx)/max(total_time, 1e-9):.2f}")
    print("="*60)

In [None]:
# Extract batch (optional): unpack the first archive (in sorted name order)
# into an "extracted" folder next to the archives and count the JPEG files there.
bs = sorted(glob.glob(f"{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}/batch_*.tar.zst"))
print(f"{len(bs)} batches")
if bs:
    out = f"{OUTPUT_DRIVE_ROOT}/{OUTPUT_SUBDIR}/extracted"
    os.makedirs(out, exist_ok=True)
    !tar -I 'zstd -d' -xf "{bs[0]}" -C "{out}"
    print(f"‚úÖ {len(glob.glob(out+'/*.jpg'))} files")