<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/HAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title ##**Install** { display-mode: "form" }
%%capture
import os

%cd /content

!pip install basicsr==1.3.4.9 --no-deps


degradations_path = "/usr/local/lib/python3.11/dist-packages/basicsr/data/degradations.py"
if os.path.exists(degradations_path):
    with open(degradations_path, 'r', encoding='utf-8') as f:
        content = f.read()
    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"
    if old_import in content:
        content = content.replace(old_import, new_import)
        with open(degradations_path, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f"Fixed import in {degradations_path}")
    else:
        print(f"Import in {degradations_path} already correct")

if not os.path.exists("/content/HAT"):
    print("Cloning HAT repository...")
    !git clone https://github.com/XPixelGroup/HAT /content/HAT

hat_arch_path = "/content/HAT/hat/archs/hat_arch.py"
with open(hat_arch_path, 'r', encoding='utf-8') as f:
    content = f.read()
if "ARCH_REGISTRY" in content:
    content = content.replace("ARCH_REGISTRY", "MODEL_REGISTRY")
if "from basicsr.utils.registry import MODEL_REGISTRY" not in content:
    content = content.replace(
        "from basicsr.utils.registry import ARCH_REGISTRY",
        "from basicsr.utils.registry import MODEL_REGISTRY"
    )
if "import torch.nn as nn" not in content:
    content = content.replace("import torch", "import torch\nimport torch.nn as nn")
with open(hat_arch_path, 'w', encoding='utf-8') as f:
    f.write(content)
print("Updated hat_arch.py")

os.makedirs("/content/uploads", exist_ok=True)
os.makedirs("/content/results", exist_ok=True)
print("Created folders: /content/uploads, /content/results")

model_path = "/content/HAT/experiments/pretrained_models/Real_HAT_GAN_sharper.pth"
if not os.path.exists(model_path):
    print("Downloading Real_HAT_GAN_sharper.pth...")
    !pip install gdown
    !gdown --id 1EioFq5-mKmv1uqta_Byd9cgXp9SU3zjj -O {model_path}
else:
    print("Real_HAT_GAN_sharper.pth already exists")

print("Setup completed.")

In [None]:
#@title ##**Upload images** { display-mode: "form" }
from google.colab import files
import os
import shutil

upload_folder = "/content/uploads"

# Recreate the upload folder empty so only the images picked in this run are processed.
shutil.rmtree(upload_folder, ignore_errors=True)
os.makedirs(upload_folder, exist_ok=True)

# files.upload() returns a dict keyed by file name; the move below relies on each
# file having been written to the current working directory under that name.
uploaded = files.upload()

for uploaded_name in list(uploaded):
    destination = os.path.join(upload_folder, uploaded_name)
    shutil.move(uploaded_name, destination)

In [34]:
#@title ##**Run** { display-mode: "form" }
%%capture
import os
import yaml
import shutil
import sys
import torch
import logging
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gc
from torchvision import transforms
import traceback

# Silence torch and basicsr log records below ERROR level.
for _logger_name in ("torch", "basicsr"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

# GPU memory report helper.
def check_gpu_memory(stage=""):
    """Print a one-line memory report for CUDA device 0, prefixed with ``stage``.

    Reports the device's total memory and the memory currently allocated by
    PyTorch on it, both in GiB with two decimals. When CUDA is unavailable a
    "GPU not available" line is printed instead. Returns None.
    """
    if not torch.cuda.is_available():
        print(f"{stage} GPU not available")
        return
    gib = 1024 ** 3
    total_gb = torch.cuda.get_device_properties(0).total_memory / gib
    used_gb = torch.cuda.memory_allocated(0) / gib
    print(f"{stage} GPU Memory: Total {total_gb:.2f} GB, Allocated {used_gb:.2f} GB")

# Очистка памяти в начале
print("Initial memory cleanup...")
check_gpu_memory("Initial")
gc.collect()
torch.cuda.empty_cache()
print("Initial GPU cache and RAM cleared")

# Проверка доступных моделей
print("Available pretrained models:")
!ls -lh /content/HAT/experiments/pretrained_models/

# Проверка Set5
print("Checking Set5 dataset...")
!ls /content/uploads/

# Очистка после подготовки данных
check_gpu_memory("Post-data-prep")
gc.collect()
torch.cuda.empty_cache()
print("Post-data-prep GPU cache and RAM cleared")

# Inspect the uploaded input images (size, format, pixel range) as a sanity check.
upload_folder = "/content/uploads"
temp_folder = "/content/temp_uploads"
print("Checking input images:")
for img_name in os.listdir(upload_folder):
    img_path = os.path.join(upload_folder, img_name)
    if not os.path.isfile(img_path):  # only files are inspected
        print(f" - {img_name}: Skipped (not a file)")
        continue
    try:
        # The context manager closes the file handle, including when decoding fails.
        with Image.open(img_path) as img:
            img_array = np.array(img)
            print(f" - {img_name}: {img.size}, format: {img.format}, pixel range: [{img_array.min()}, {img_array.max()}]")
    except Exception as e:
        print(f" - {img_name}: Failed to open ({e})")

# Path to the pretrained weights; fail early if the Install cell did not download them.
weight_path = "/content/HAT/experiments/pretrained_models/Real_HAT_GAN_sharper.pth"
if not os.path.exists(weight_path):
    print(f"Pretrained weights not found at {weight_path}!")
    raise FileNotFoundError("Check available models above")

print(f"Inspecting weights: {weight_path}")
# weights_only=True limits unpickling to tensors and plain containers. The file comes
# from a third-party Drive link, so arbitrary pickled objects must not be executed.
checkpoint = torch.load(weight_path, map_location='cpu', weights_only=True)
print("Available keys in checkpoint:", list(checkpoint.keys()))
del checkpoint

print(f"Weights found: {weight_path}, Size: {os.path.getsize(weight_path) / (1024**2):.2f} MB")

# Free memory before the model is loaded.
check_gpu_memory("Pre-model-load")
gc.collect()
torch.cuda.empty_cache()
print("Pre-model-load GPU cache and RAM cleared")

# Put the cloned HAT repo first on sys.path so `import hat` resolves to it. Earlier
# copies of the entry are dropped so re-running the cell does not stack duplicates.
hat_repo_path = "/content/HAT"
sys.path[:] = [p for p in sys.path if p != hat_repo_path]
sys.path.insert(0, hat_repo_path)

# Import the HAT architectures; with the Install cell's patch, hat_arch.py is
# expected to register HAT in MODEL_REGISTRY.
try:
    import hat.archs
except Exception as e:
    print(f"Failed to import hat.archs: {e}")
    raise

# Make HAT available in ARCH_REGISTRY by copying the class over from MODEL_REGISTRY.
# NOTE(review): presumably basicsr builds network_g from ARCH_REGISTRY — confirm.
from basicsr.utils.registry import ARCH_REGISTRY, MODEL_REGISTRY
if "HAT" not in ARCH_REGISTRY._obj_map:
    try:
        hat_class = MODEL_REGISTRY.get("HAT")
        if hat_class:
            ARCH_REGISTRY.register(hat_class)
            print("HAT architecture manually registered in ARCH_REGISTRY from MODEL_REGISTRY")
        else:
            raise KeyError("HAT not found in MODEL_REGISTRY")
    except KeyError as e:
        print(f"Failed to register HAT: {e}")
        raise
else:
    print("HAT already registered in ARCH_REGISTRY")

# Show what is registered (debug aid).
print("Registered architectures:", list(ARCH_REGISTRY._obj_map.keys()))

result_folder = "/content/results"

# Start from empty result and temporary folders, wiping them if they already exist.
for _folder, _message in (
    (result_folder, "Result folder {} cleared and recreated"),
    (temp_folder, "Temporary folder {} created"),
):
    if os.path.exists(_folder):
        shutil.rmtree(_folder)
    os.makedirs(_folder, exist_ok=True)
    print(_message.format(_folder))

# Image helpers for tiled processing.
def load_and_preprocess_image(image_path, multiple=16):
    """Load an image as RGB and pad it so both sides are multiples of ``multiple``.

    The padding on each axis is split between the two sides. When it is odd,
    the extra pixel goes right/bottom. Padded pixels come from cropping
    outside the image bounds, which fills them black for an RGB image; the
    padding is not transparent.

    Args:
        image_path: path of the image file to load.
        multiple: required alignment of width and height (default 16; the
            HAT config in this notebook uses window_size 16).

    Returns:
        (img, padding, (w_orig, h_orig)): the padded RGB image, the padding as
        (left, top, right, bottom) in pixels, and the original size.
    """
    # The context manager closes the file; convert() yields an in-memory copy.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    w_orig, h_orig = img.size
    print(f"Original image dimensions: {w_orig}x{h_orig}")

    new_w = ((w_orig + multiple - 1) // multiple) * multiple
    new_h = ((h_orig + multiple - 1) // multiple) * multiple

    delta_w = new_w - w_orig
    delta_h = new_h - h_orig

    padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)

    if delta_w != 0 or delta_h != 0:
        print(f"Adding black padding to dimensions: {new_w}x{new_h}")
        img = img.crop((-padding[0], -padding[1], w_orig + padding[2], h_orig + padding[3]))
    else:
        print(f"Image dimensions are already multiples of {multiple}, no padding needed.")

    return img, padding, (w_orig, h_orig)

def split_image_into_tiles(img, tile_size):
    """Cut ``img`` into a row-major grid of tiles of at most ``tile_size`` pixels.

    Tiles on the right and bottom edges are clipped to the image bounds.

    Returns:
        (tiles, positions): the cropped tiles and, index-aligned with them,
        the (x, y) top-left corner of each tile in ``img``.
    """
    width, height = img.size
    corners = [
        (left, top)
        for top in range(0, height, tile_size)
        for left in range(0, width, tile_size)
    ]
    pieces = [
        img.crop((left, top, min(left + tile_size, width), min(top + tile_size, height)))
        for left, top in corners
    ]
    return pieces, corners

def merge_tiles(tiles, positions, img_size):
    """Paste tiles onto a new RGB canvas.

    Args:
        tiles: iterable of images to paste.
        positions: iterable of (x, y) top-left offsets, one per tile, in the
            same order as ``tiles``.
        img_size: (width, height) of the canvas.

    Returns:
        The new RGB image. Areas no tile covers keep Image.new's default fill.

    Raises:
        ValueError: if ``tiles`` and ``positions`` differ in length. Previously
            ``zip`` silently dropped the extras, leaving unfilled holes.
    """
    tiles = list(tiles)
    positions = list(positions)
    if len(tiles) != len(positions):
        raise ValueError(f"Got {len(tiles)} tiles but {len(positions)} positions")
    upscaled_img = Image.new('RGB', img_size)
    for tile, (x, y) in zip(tiles, positions):
        upscaled_img.paste(tile, (x, y))
    return upscaled_img

def postprocess_image(tensor):
    """Convert a model output tensor to a PIL image without modifying the input.

    Args:
        tensor: image tensor; a leading dimension of size 1 is removed with
            ``squeeze(0)``. NOTE(review): assumed channels-first (1, C, H, W)
            or (C, H, W) with values in [0, 1], as ToPILImage expects — confirm.

    Returns:
        The PIL image produced by ``transforms.ToPILImage`` after clamping to [0, 1].

    The tensor is detached and clamped out of place. The former ``clamp_``
    wrote into the caller's tensor whenever it was already on the CPU, because
    ``.cpu()`` then returns the same tensor and ``squeeze`` returns a view.
    Not called elsewhere in the visible notebook.
    """
    output = tensor.detach().squeeze(0).cpu().clamp(0, 1)
    return transforms.ToPILImage()(output)

# Options for HAT with the Real_HAT_GAN_sharper checkpoint (4x upscaling).
config = {
    "name": "HAT_SRx4_Test",
    "model_type": "HATModel",
    "scale": 4,
    # Use the GPU when there is one; 0 requests a CPU run instead of failing on a
    # CPU-only runtime. NOTE(review): relies on basicsr selecting CUDA only when
    # num_gpu != 0 — confirm for basicsr 1.3.4.9.
    "num_gpu": 1 if torch.cuda.is_available() else 0,
    "network_g": {
        "type": "HAT",
        "upscale": 4,
        "in_chans": 3,
        "img_size": 64,
        "window_size": 16,
        "img_range": 1.0,
        "depths": [6, 6, 6, 6, 6, 6],
        "embed_dim": 180,
        "num_heads": [6, 6, 6, 6, 6, 6],
        "mlp_ratio": 2,
        "upsampler": "pixelshuffle",
        "resi_connection": "1conv"
    },
    "path": {
        "pretrain_network_g": weight_path,
        "param_key_g": "params_ema",
        # NOTE(review): non-strict loading presumably tolerates key mismatches
        # between checkpoint and network — confirm this is intended.
        "strict_load_g": False
    },
    "test": {
        "test_y_channel": False,
        "crop_border": 4
    },
    "val": {
        "save_img": True,
        "suffix": None
    },
    "datasets": {
        "test": {
            "name": "test",
            "type": "SingleImageDataset",
            "dataroot_lq": upload_folder,
            "io_backend": {"type": "disk"}
        }
    }
}

# Write the options file read by HAT's test pipeline.
os.makedirs("/content/HAT/options/test", exist_ok=True)
with open("/content/HAT/options/test/HAT_SRx4_Test.yml", "w") as f:
    yaml.dump(config, f)

# Free memory before running the model.
check_gpu_memory("Pre-test-pipeline")
gc.collect()
torch.cuda.empty_cache()
print("Pre-test-pipeline GPU cache and RAM cleared")

# Tiled inference. Each uploaded image is padded and cut into tiles. Every tile is
# upscaled by HAT's own test_pipeline, and the upscaled tiles are stitched back.
# NOTE(review): test_pipeline presumably rebuilds the model and reloads the weights
# on every call (once per tile) — confirm; that would dominate run time.
from hat.test import test_pipeline
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.ToTensor()

hat_root = "/content/HAT"
opt_file = os.path.join(hat_root, "options", "test", "HAT_SRx4_Test.yml")
# Where the pipeline writes this experiment's outputs (experiment "HAT_SRx4_Test", dataset "test").
hat_results_root = os.path.join(hat_root, "results", "HAT_SRx4_Test")
tile_result_dir = os.path.join(hat_results_root, "visualization", "test")
tile_size = 512
scale = config["scale"]

# Stage tiles in temp_folder, never in upload_folder. The staging folder is emptied
# before every tile, and emptying upload_folder (as this code used to) deleted the
# uploaded images that had not been processed yet.
os.makedirs(temp_folder, exist_ok=True)
config['datasets']['test']['dataroot_lq'] = temp_folder
with open(opt_file, "w") as f:
    yaml.dump(config, f)

# Run the pipeline from the repo root, as the former `%cd /content/HAT` did.
os.chdir(hat_root)

# Only regular files in the upload folder are treated as input images.
image_files = [f for f in os.listdir(upload_folder) if os.path.isfile(os.path.join(upload_folder, f))]

for filename in image_files:
    input_path = os.path.join(upload_folder, filename)
    output_path = os.path.join(result_folder, filename)

    try:
        print(f"\nProcessing {filename}")
        img, padding, (w_orig, h_orig) = load_and_preprocess_image(input_path)
        tiles, positions = split_image_into_tiles(img, tile_size)

        upscaled_tiles = []
        for idx, tile in enumerate(tiles):
            print(f"Processing tile {idx+1}/{len(tiles)}")
            # Keep exactly one tile in the staging folder so the pipeline only sees this tile.
            for staged_name in os.listdir(temp_folder):
                staged_path = os.path.join(temp_folder, staged_name)
                if os.path.isfile(staged_path):
                    os.remove(staged_path)
            tile.save(os.path.join(temp_folder, f"tile_{idx}.png"))

            original_argv = sys.argv
            sys.argv = ["hat/test.py", "-opt", opt_file]
            try:
                test_pipeline(hat_root)
                tile_outputs = sorted(f for f in os.listdir(tile_result_dir) if f.startswith("tile_"))
                if not tile_outputs:
                    raise FileNotFoundError(f"No upscaled tile found in {tile_result_dir}")
                # convert() copies the pixels into memory, so the file can be closed and deleted.
                with Image.open(os.path.join(tile_result_dir, tile_outputs[0])) as tile_img:
                    upscaled_tiles.append(tile_img.convert('RGB'))
            except Exception as e:
                print(f"Test pipeline failed for tile {idx+1}: {e}")
                raise
            finally:
                sys.argv = original_argv
                # Remove the whole experiment folder, not only visualization/test.
                # NOTE(review): basicsr appears to rename an existing results folder to an
                # *_archived_* copy on each run, so copies piled up once per tile — confirm.
                shutil.rmtree(hat_results_root, ignore_errors=True)

            torch.cuda.empty_cache()

        # Stitch the upscaled tiles at scaled offsets, then strip the scaled padding.
        upscaled_w = img.size[0] * scale
        upscaled_h = img.size[1] * scale
        upscaled_img = merge_tiles(upscaled_tiles, [(x * scale, y * scale) for x, y in positions], (upscaled_w, upscaled_h))

        pad_left, pad_top, pad_right, pad_bottom = [p * scale for p in padding]
        upscaled_img = upscaled_img.crop((pad_left, pad_top, upscaled_w - pad_right, upscaled_h - pad_bottom))

        # The crop should already be exactly scale x the original size; resize only as a safety net.
        expected_size = (w_orig * scale, h_orig * scale)
        if upscaled_img.size != expected_size:
            upscaled_img = upscaled_img.resize(expected_size, Image.LANCZOS)

        upscaled_img.save(output_path)
        print(f"Image successfully processed and saved to {output_path}")

        del img, upscaled_img, upscaled_tiles
        torch.cuda.empty_cache()

    except Exception as e:
        # One failing image must not stop the others; uploads stay in place untouched.
        print(f"Error processing {filename}: {str(e)}")
        traceback.print_exc()

# Remove the temporary staging folder now that all images are done.
if os.path.exists(temp_folder):
    shutil.rmtree(temp_folder)
print(f"Temporary folder {temp_folder} removed")

# Release memory after processing.
check_gpu_memory("Post-results")
gc.collect()
torch.cuda.empty_cache()
print("Post-results GPU cache and RAM cleared")

# Preview every saved result with its pixel range (debug aid).
# NOTE(review): under this cell's %%capture these prints and figures are captured, not shown.
try:
    for result_name in os.listdir(result_folder):
        result_path = os.path.join(result_folder, result_name)
        if not os.path.exists(result_path):
            print(f"Result image {result_path} not found")
            continue
        img = Image.open(result_path)
        img_array = np.array(img)
        print(f"Visualizing {result_name}: pixel range: [{img_array.min()}, {img_array.max()}]")
        plt.figure(figsize=(10, 10))
        plt.title(result_name)
        plt.imshow(img)
        plt.axis("off")
        plt.show()
        del img, img_array
except Exception as e:
    print(f"Failed to visualize result: {e}")

# Final memory release.
print("Final memory cleanup...")
check_gpu_memory("Final")
gc.collect()
torch.cuda.empty_cache()
print("Final GPU cache and RAM cleared")

In [None]:
#@title ##**Visualize** { display-mode: "form" }
import os
import PIL.Image
import numpy as np
from IPython.display import HTML, display
import base64
from io import BytesIO

def is_image_file(filename):
    """Return True if ``filename`` ends in a common image extension (case-insensitive).

    Recognised: .png, .jpg, .jpeg, .gif, .bmp, .tif, .tiff, .webp. The short
    ``.tif`` spelling used to be missing although ``.tiff`` was accepted. The
    Run cell processes every uploaded file regardless of extension, so such
    results were produced but never shown or downloaded.
    """
    image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tif', '.tiff', '.webp'}
    return os.path.splitext(filename.lower())[1] in image_extensions

def resize_image_maintain_aspect(image, max_width, target_height=None):
    """Shrink ``image`` to at most ``max_width`` wide, then optionally scale it to ``target_height``.

    Both steps preserve the aspect ratio and use LANCZOS resampling. The image
    is returned unchanged when no step applies.

    Computed dimensions are floored as before but never fall below 1 pixel;
    very thin images used to get a 0-pixel side, which ``resize`` rejects.

    Args:
        image: PIL image.
        max_width: maximum width in pixels; wider images are downscaled.
        target_height: if given, the final height in pixels.

    Returns:
        The (possibly new) PIL image.
    """
    width, height = image.size
    if width > max_width:
        new_height = max(1, int(height * max_width / width))
        image = image.resize((max_width, new_height), PIL.Image.Resampling.LANCZOS)
    if target_height is not None and image.size[1] != target_height:
        new_width = max(1, int(image.size[0] * target_height / image.size[1]))
        image = image.resize((new_width, target_height), PIL.Image.Resampling.LANCZOS)
    return image

def image_to_base64(image):
    """Encode ``image`` as PNG and return the bytes as a Base64 text string (for data: URIs)."""
    with BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')

visualization_method = "Slider" #@param ["Side-by-Side", "Slider"]

import uuid

filenames_upload = sorted([f for f in os.listdir(upload_folder) if is_image_file(f)])
filenames_upload_output = sorted([f for f in os.listdir(result_folder) if is_image_file(f)])

if not filenames_upload or not filenames_upload_output:
    print(f"Error: No images found in {upload_folder} or {result_folder}.")
else:
    # The Run cell saves each result under its input's file name, so pair by name.
    # Zipping the two sorted lists by position mismatched every pair after the
    # first input that has no result (e.g. because its processing failed).
    result_names = set(filenames_upload_output)
    image_pairs = [(name, name) for name in filenames_upload if name in result_names]
    if image_pairs:
        for name in filenames_upload:
            if name not in result_names:
                print(f"No result found for {name}; skipping it.")
    else:
        # No shared names (e.g. results renamed by hand): fall back to positional pairing.
        image_pairs = list(zip(filenames_upload, filenames_upload_output))

    for filename, filename_output in image_pairs:
        try:
            image_original = PIL.Image.open(os.path.join(upload_folder, filename))
            image_restore = PIL.Image.open(os.path.join(result_folder, filename_output))

            if visualization_method == "Side-by-Side":
                max_width = 500
                image_original = resize_image_maintain_aspect(image_original, max_width)
                image_restore = resize_image_maintain_aspect(image_restore, max_width)
                target_height = min(image_original.size[1], image_restore.size[1])
                image_original = resize_image_maintain_aspect(image_original, max_width, target_height)
                image_restore = resize_image_maintain_aspect(image_restore, max_width, target_height)

                combined_width = image_original.size[0] + image_restore.size[0]
                combined_image = PIL.Image.new('RGB', (combined_width, target_height))
                combined_image.paste(image_original, (0, 0))
                combined_image.paste(image_restore, (image_original.size[0], 0))
                display(combined_image)

            else:
                max_width = min(image_restore.size[0], 1000)
                image_restore = resize_image_maintain_aspect(image_restore, max_width)
                target_height = image_restore.size[1]
                image_original = resize_image_maintain_aspect(image_original, max_width, target_height)

                if image_original.mode != 'RGB':
                    image_original = image_original.convert('RGB')
                if image_restore.mode != 'RGB':
                    image_restore = image_restore.convert('RGB')

                original_base64 = image_to_base64(image_original)
                restore_base64 = image_to_base64(image_restore)

                # Unique id per comparison. The script binds listeners only to this
                # slider; querySelectorAll('.slider') re-bound every earlier slider too.
                slider_id = f"hat-compare-{uuid.uuid4().hex}"
                html_code = f"""
                <div id="{slider_id}" style="position: relative; width: {image_restore.size[0]}px; height: {image_restore.size[1]}px; margin-bottom: 20px;">
                    <div style="position: relative; width: 100%; height: 100%; overflow: hidden;">
                        <img src="data:image/png;base64,{original_base64}" style="position: absolute; width: 100%; height: 100%;">
                        <div style="position: absolute; width: 100%; height: 100%; overflow: hidden; clip-path: inset(0 0 0 50%);">
                            <img src="data:image/png;base64,{restore_base64}" style="position: absolute; width: 100%; height: 100%;">
                        </div>
                    </div>
                    <div class="slider" style="position: absolute; top: 0; bottom: 0; width: 4px; background: white; cursor: ew-resize; left: 50%; box-shadow: 0 0 5px rgba(0,0,0,0.5);">
                        <div style="position: absolute; top: 50%; transform: translateY(-50%); width: 20px; height: 20px; background: white; border-radius: 50%; left: -8px;"></div>
                    </div>
                </div>
                <script>
                    (() => {{
                        const root = document.getElementById('{slider_id}');
                        const slider = root.querySelector('.slider');
                        let isDragging = false;
                        const container = slider.parentElement.querySelector('div:nth-child(1)');
                        const clipDiv = container.querySelector('div');

                        slider.addEventListener('mousedown', (e) => {{
                            isDragging = true;
                            e.preventDefault();
                        }});

                        document.addEventListener('mouseup', () => {{
                            isDragging = false;
                        }});

                        document.addEventListener('mousemove', (e) => {{
                            if (!isDragging) return;
                            const rect = container.getBoundingClientRect();
                            let x = e.clientX - rect.left;
                            if (x < 0) x = 0;
                            if (x > rect.width) x = rect.width;
                            const percentage = (x / rect.width) * 100;
                            slider.style.left = percentage + '%';
                            clipDiv.style.clipPath = `inset(0 0 0 ${{percentage}}%)`;
                        }});

                        slider.addEventListener('touchstart', (e) => {{
                            isDragging = true;
                            e.preventDefault();
                        }});

                        document.addEventListener('touchend', () => {{
                            isDragging = false;
                        }});

                        document.addEventListener('touchmove', (e) => {{
                            if (!isDragging) return;
                            const rect = container.getBoundingClientRect();
                            let x = e.touches[0].clientX - rect.left;
                            if (x < 0) x = 0;
                            if (x > rect.width) x = rect.width;
                            const percentage = (x / rect.width) * 100;
                            slider.style.left = percentage + '%';
                            clipDiv.style.clipPath = `inset(0 0 0 ${{percentage}}%)`;
                        }});
                    }})();
                </script>
                """

                display(HTML(html_code))

        except Exception as e:
            print(f"Error processing {filename} and {filename_output}: {e}")

In [None]:
#@title ##**Download results** { display-mode: "form" }
import os
from google.colab import files
import zipfile

def is_image_file(filename):
    """Return True if ``filename`` ends in a common image extension (case-insensitive).

    Recognised: .png, .jpg, .jpeg, .gif, .bmp, .tif, .tiff, .webp. The short
    ``.tif`` spelling used to be missing although ``.tiff`` was accepted. The
    Run cell processes every uploaded file regardless of extension, so such
    results were produced but never shown or downloaded.
    """
    image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tif', '.tiff', '.webp'}
    return os.path.splitext(filename.lower())[1] in image_extensions

files_in_folder = sorted(f for f in os.listdir(result_folder) if is_image_file(f))
zip_file = "download.zip"

if not files_in_folder:
    # Nothing to download. Previously this fell through to the zip branch and
    # offered an empty archive.
    print(f"No result images found in {result_folder}.")
elif len(files_in_folder) == 1:
    # A single result is downloaded directly, without zipping.
    file_to_download = os.path.join(result_folder, files_in_folder[0])
    files.download(file_to_download)
else:
    # Several results: bundle them into download.zip inside the results folder,
    # replacing any archive from a previous run, and download that.
    zip_path = os.path.join(result_folder, zip_file)
    if os.path.exists(zip_path):
        os.remove(zip_path)

    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for file in files_in_folder:
            file_path = os.path.join(result_folder, file)
            zipf.write(file_path, file)  # store under the bare file name

    files.download(zip_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>