<a href="https://colab.research.google.com/github/detektor777/colab_list_image/blob/main/sr_former.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ##**Install** { display-mode: "form" }
%%capture

import os

!pip install openmim --no-deps
!pip install model-index --no-deps

!mim --version

!mim install mmcv-full

try:
    #!pip uninstall basicsr -y
    print("System basicsr package removed.")
except:
    print("System basicsr not installed or already removed.")

!pip install torch torchvision gdown --no-cache-dir

!pip install opendatalab --no-deps
!pip install openxlab --no-deps
!pip install pycryptodome --no-deps
!pip install oss2 --no-deps

degradations_path = "/content/SRFormer/basicsr/data/degradations.py"
if os.path.exists(degradations_path):
    with open(degradations_path, 'r', encoding='utf-8') as file:
        filedata = file.read()

    if 'from torchvision.transforms.functional_tensor import rgb_to_grayscale' in filedata:
        filedata = filedata.replace(
            'from torchvision.transforms.functional_tensor import rgb_to_grayscale',
            'from torchvision.transforms.functional import rgb_to_grayscale'
        )
        with open(degradations_path, 'w', encoding='utf-8') as file:
            file.write(filedata)
        print(f"Fixed import in {degradations_path}")
    else:
        print(f"Import in {degradations_path} already fixed or uses different structure")
else:
    print(f"File {degradations_path} not found. Ensure /content/SRFormer is configured correctly.")

#downloads models
import os
from google.colab import files

if not os.path.exists("/content/SRFormer"):
    !git clone https://github.com/HVision-NKU/SRFormer.git /content/SRFormer
else:
    print("SRFormer repository already cloned.")

!python /content/SRFormer/setup.py develop

!mkdir -p /content/SRFormer/input_images
!mkdir -p /content/SRFormer/output_images
!mkdir -p /content/SRFormer/PretrainModel

models = [
    {"id": "1omhiTXRfX5JJiN1GBZNEcPdOYCGV_f_f", "name": "SRFormerLight_SRX4_DIV2K.pth"},
    {"id": "1Eeei_NEjDeni7ysSejmR7AwyG24fO5bp", "name": "SRFormerLight_SRX3_DIV2K.pth"},
    {"id": "1e3So4kYb0JFQAUZ-sYzlGx_KZsa9rSpG", "name": "SRFormerLight_SRX2_DIV2K.pth"},
    {"id": "13_fpD4aDE1wbEYX8yGWA3mVLZOCRWkWv", "name": "SRFormer_SRX4_DF2K.pth"},
    {"id": "1nfKDZvavZEgZEbF6ymyPFZ27lhg_fBjB", "name": "SRFormer_SRX3_DF2K.pth"},
    {"id": "1lU8SsKeaTwBSC5bP69LjBuJs69Qt4Rsf", "name": "SRFormer_SRX2_DF2K.pth"},
    {"id": "1nmLocRyZPfCiQP2_zc1hxIixEzfVyFG6", "name": "SRFormer_S_RealSR.pth"},
    {"id": "1o-uxeR547hyvnkFCv-4BMnkvZ0RYgQYA", "name": "SRFormer_S_RealSR_for_Chainner.pth"},
    {"id": "1apiX03mDAJkaKzGvdJweOTVwL9aiwdI8", "name": "SRFormer_realworld_W22C180.pth"}
]

for model in models:
    model_path = f"/content/SRFormer/PretrainModel/{model['name']}"
    if not os.path.exists(model_path):
        print(f"Downloading {model['name']}...")
        !gdown {model['id']} -O {model_path} || echo "{model['name']} download failed. Please upload it manually."
    else:
        print(f"Model {model['name']} already exists.")

missing_models = []
for model in models:
    model_path = f"/content/SRFormer/PretrainModel/{model['name']}"
    if not os.path.exists(model_path):
        missing_models.append(model)

if missing_models:
    print("The following models were not downloaded automatically:")
    for model in missing_models:
        print(f"- {model['name']}: https://drive.google.com/file/d/{model['id']}/view?usp=drive_link")
    print("Please download the missing files and upload them using the interface below:")

    uploaded = files.upload()

    for filename in uploaded.keys():
        for model in missing_models:
            if filename == model['name']:
                !mv {filename} /content/SRFormer/PretrainModel/{model['name']}
                print(f"File {model['name']} successfully moved.")
                break
        else:
            print(f"Error: Uploaded file '{filename}' does not match any expected filenames.")
else:
    print("All models successfully downloaded.")

print("\nFiles in /content/SRFormer/PretrainModel:")
!ls -lh /content/SRFormer/PretrainModel

#add model
import os
import subprocess

model_dir = "/content/SRFormer/PretrainModel"
config_dir = "/content/SRFormer/options/test/SRFormer"
os.makedirs(model_dir, exist_ok=True)
os.makedirs(config_dir, exist_ok=True)

# Extra Frankendata checkpoints and their test config, fetched straight from
# Google Drive's download endpoint.
files_to_download = [
    {
        "url": "https://drive.usercontent.google.com/download?id=1SaKvpYYIm2Vj2m9GifUMlNCbmkE6JZmr&export=download&authuser=0",
        "name": "FrankendataPretrainer_SRFormer400K_g.pth",
        "dir": model_dir
    },
    {
        "url": "https://drive.usercontent.google.com/download?id=1cfK4EMAUNfmukiFRak0FGK9NTMS0lD62&export=download&authuser=0",
        "name": "FrankendataPretrainer_SRFormer400K_d.pth",
        "dir": model_dir
    },
    {
        "url": "https://drive.usercontent.google.com/download?id=11_mVUusZmFDeeiweklJItFEaYSjHszBp&export=download&authuser=0",
        "name": "4xFrankendataPretrain_SRFormer.yml",
        "dir": config_dir
    }
]

for entry in files_to_download:
    output_path = os.path.join(entry["dir"], entry["name"])
    # Skip files that are already present, like the gdown loop above does.
    if os.path.exists(output_path) and os.path.getsize(output_path) > 0:
        print(f"{entry['name']} already exists, skipping download.")
        continue

    print(f"Downloading {entry['name']}...")
    # Pass an argument list (no shell), so the '&' characters in the URL need
    # no quoting. --fail makes curl exit non-zero on HTTP errors instead of
    # saving the error page as if it were the file.
    result = subprocess.run(
        ["curl", "--fail", "-L", "-o", output_path, f"{entry['url']}&confirm=t"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        # Remove any partial file so the next run retries instead of skipping it.
        if os.path.exists(output_path):
            os.remove(output_path)
        raise RuntimeError(
            f"Download failed for {entry['name']}: {result.stderr.strip()}"
        )
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"File {entry['name']} missing.")
    print(f"File downloaded successfully to {output_path}")

#basicsr
import os

version_path = "/content/SRFormer/basicsr/version.py"
degradations_path = "/content/SRFormer/basicsr/data/degradations.py"

# basicsr/version.py is normally generated by setup.py. If it is absent,
# write a stub with placeholder values (presumably needed by `import basicsr`).
if os.path.exists(version_path):
    print(f"File {version_path} already exists.")
else:
    print(f"File {version_path} not found. Creating a placeholder.")
    with open(version_path, 'w', encoding='utf-8') as handle:
        handle.write('__version__ = "unknown"\n__gitsha__ = "unknown"\n')
    print(f"Created {version_path} with placeholders.")

# Point degradations.py at torchvision.transforms.functional instead of the
# `functional_tensor` module it originally imports rgb_to_grayscale from.
if not os.path.exists(degradations_path):
    print(f"File {degradations_path} not found.")
else:
    with open(degradations_path, 'r', encoding='utf-8') as handle:
        source_text = handle.read()
    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"
    if old_import not in source_text:
        print(f"Import in {degradations_path} is already correct or missing.")
    else:
        with open(degradations_path, 'w', encoding='utf-8') as handle:
            handle.write(source_text.replace(old_import, new_import))
        print(f"Fixed import in {degradations_path}")



In [None]:
#@title ##**Upload Image** { display-mode: "form" }
from google.colab import files
import os

input_dir = "/content/SRFormer/input_images"
os.makedirs(input_dir, exist_ok=True)

# Empty the input folder so only the images uploaded now get processed.
for stale in os.listdir(input_dir):
    os.remove(os.path.join(input_dir, stale))

try:
    uploaded = files.upload()
    if uploaded:
        # The rename uses the bare filename, i.e. the upload is expected in
        # the current working directory.
        for filename in uploaded:
            os.rename(filename, os.path.join(input_dir, filename))
    else:
        print("Error: No files uploaded.")
except Exception as e:
    print(f"Error uploading files: {e}")

In [None]:
#@title ##**Run** { display-mode: "form" }
import os
import sys
import yaml
import warnings
from pathlib import Path
from tqdm import tqdm
import logging

warnings.filterwarnings("ignore", category=UserWarning, module="torch.functional")
logging.getLogger().setLevel(logging.ERROR)

config_dir = "/content/SRFormer/options/test/SRFormer"
model_dir = "/content/SRFormer/PretrainModel"

# The available configs are the choices of the Colab form field below; that
# list is the single source of truth (no separate copy to keep in sync).
selected_yml = "test_SRFormer-S_x4_real.yml" #@param ["test_SRFormer_DF2Ksrx4.yml", "test_SRFormer_light_DIV2Ksrx4.yml", "test_SRFormer_DF2Ksrx2.yml", "test_SRFormer-S_x4_real.yml", "test_SRFormer_light_DIV2Ksrx3.yml", "4xFrankendataPretrain_SRFormer.yml", "test_SRFormer_DF2Ksrx3.yml", "test_SRFormer_light_DIV2Ksrx2.yml"]

config_path = os.path.join(config_dir, selected_yml)


def _find_model_weights(directory, filename):
    """Return the path of *filename* inside *directory*, or None if absent.

    An exact match is tried first; otherwise the directory listing is
    searched case-insensitively. A missing *directory* yields None.
    """
    exact = os.path.join(directory, filename)
    if os.path.exists(exact):
        return exact
    if not os.path.isdir(directory):
        return None
    wanted = filename.lower()
    for candidate in os.listdir(directory):
        if candidate.lower() == wanted:
            return os.path.join(directory, candidate)
    return None


if selected_yml == "4xFrankendataPretrain_SRFormer.yml":
    # This config is fetched separately in the Install cell. Its generator
    # weights use a fixed filename instead of being read from the YAML.
    model_path = "FrankendataPretrainer_SRFormer400K_g.pth"
else:
    model_path = None
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty file.
            config = yaml.safe_load(f) or {}
        # Tolerate a `path:` section that is present but empty.
        model_path = (config.get('path') or {}).get('pretrain_network_g') or 'Unknown'
    except Exception as e:
        print(f"Error reading {selected_yml}: {e}")

if model_path is not None:
    print(f"Selected model: {selected_yml}")
    print(f"Model weights path: {model_path}")

    weights_name = os.path.basename(model_path)
    found_path = _find_model_weights(model_dir, weights_name)
    full_model_path = found_path or os.path.join(model_dir, weights_name)
    if found_path is None:
        print(f"Warning: Model weights file {full_model_path} not found!")
    else:
        print(f"Model weights file found: {full_model_path}")

#run
import shutil
import subprocess

srformer_dir = "/content/SRFormer"
if not os.path.exists(srformer_dir):
    print(f"Error: Directory {srformer_dir} does not exist.")
    raise FileNotFoundError(f"Directory {srformer_dir} not found.")

srformer_dir = os.path.abspath(srformer_dir)
# Make the local checkout importable ahead of any pip-installed basicsr.
if srformer_dir not in sys.path:
    sys.path.insert(0, srformer_dir)

basicsr_dir = os.path.join(srformer_dir, "basicsr")
if not os.path.exists(basicsr_dir):
    print(f"Error: Directory {basicsr_dir} does not exist.")
    raise FileNotFoundError("Local basicsr not found.")

if not os.path.exists(config_path):
    print(f"Error: Config file {config_path} not found.")
    raise FileNotFoundError("Config file not found.")

input_dir = os.path.join(srformer_dir, "input_images")
output_dir = os.path.join(srformer_dir, "output_images")
os.makedirs(input_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)

# Clear results from a previous run. Sub-directories are removed as well,
# because os.remove() raises on them and would abort the cell.
for entry in os.listdir(output_dir):
    entry_path = os.path.join(output_dir, entry)
    if os.path.isdir(entry_path) and not os.path.islink(entry_path):
        shutil.rmtree(entry_path)
    else:
        os.remove(entry_path)

input_images = [f for f in os.listdir(input_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
if not input_images:
    print(f"Error: No images found in {input_dir}.")
    raise FileNotFoundError("No input images found.")

# Fail early, with a readable message, if the local package cannot be imported.
try:
    import basicsr
except ImportError as e:
    print(f"Error importing basicsr: {e}")
    raise

os.chdir(srformer_dir)

# Run inference in a child process using this kernel's interpreter. The
# script arguments are relative to srformer_dir (the cwd set above). The
# checkout is prepended to PYTHONPATH rather than replacing it.
env = dict(os.environ)
env["PYTHONPATH"] = os.pathsep.join(
    p for p in (srformer_dir, env.get("PYTHONPATH")) if p
)
cmd = [
    sys.executable, "basicsr/infer_sr.py",
    "-opt", f"options/test/SRFormer/{selected_yml}",
    "--input_dir", "input_images",
    "--output_dir", "output_images",
]
with tqdm(total=len(input_images), desc="Processing images") as pbar:
    # One call handles all images, so the bar can only jump from 0 to done.
    result = subprocess.run(cmd, env=env, capture_output=True, text=True)
    pbar.update(len(input_images))

if result.returncode != 0:
    # Do not hide failures: show the tail of the script's log and stop.
    log_lines = (result.stdout + result.stderr).strip().splitlines()
    print(f"Error: inference exited with code {result.returncode}. Last log lines:")
    print("\n".join(log_lines[-40:]))
    raise RuntimeError("Inference failed.")

if os.listdir(output_dir):
    print("Inference completed successfully.")
else:
    print(f"Warning: inference exited normally but wrote nothing to {output_dir}.")

In [None]:
#@title ##**Visualize** { display-mode: "form" }
import os
import PIL.Image
from IPython.display import HTML, display
import base64
from io import BytesIO

def is_image_file(filename):
    """Return True if *filename* ends in a common raster-image extension.

    The comparison is case-insensitive and looks only at the final suffix
    as split off by os.path.splitext.
    """
    _, extension = os.path.splitext(filename.lower())
    return extension in {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}

def resize_image_maintain_aspect(image, max_width, target_height=None):
    """Scale *image* down to *max_width*, then optionally to *target_height*.

    Both steps preserve the aspect ratio and use LANCZOS resampling. Computed
    dimensions are clamped to at least 1 px, so extreme aspect ratios (e.g. a
    very wide strip only a few pixels tall) cannot request a zero-sized
    resize.

    Args:
        image: PIL image to resize.
        max_width: Maximum width in pixels. Images no wider than this are not
            shrunk by the first step.
        target_height: If given, the final height in pixels.

    Returns:
        The resized image, or *image* itself when no resize was needed.
    """
    width, height = image.size
    if width > max_width:
        new_height = max(1, int(height * max_width / width))
        image = image.resize((max_width, new_height), PIL.Image.Resampling.LANCZOS)
    if target_height is not None and image.size[1] != target_height:
        new_width = max(1, int(image.size[0] * target_height / image.size[1]))
        image = image.resize((new_width, target_height), PIL.Image.Resampling.LANCZOS)
    return image

def image_to_base64(image):
    """Encode *image* as PNG and return the bytes as a base64 text string."""
    with BytesIO() as stream:
        image.save(stream, format="PNG")
        encoded = base64.b64encode(stream.getvalue())
    return encoded.decode('utf-8')

import uuid

visualization_method = "Side-by-Side" #@param ["Side-by-Side", "Slider"]
input_dir = "/content/SRFormer/input_images"
output_dir = "/content/SRFormer/output_images"

# Side-by-Side puts two images next to each other, so each one gets a smaller
# maximum width than the single image shown by the slider.
max_width = 500 if visualization_method == "Side-by-Side" else 1000


def _load_rgb(path):
    """Open the image at *path*, return an RGB copy, and close the file handle.

    Converting before resizing also lets palette ('P') images be resampled
    with LANCZOS instead of the nearest-neighbour filter Pillow uses for them.
    """
    with PIL.Image.open(path) as img:
        return img.convert('RGB')


input_files = sorted([f for f in os.listdir(input_dir) if is_image_file(f)])
output_files = sorted([f for f in os.listdir(output_dir) if is_image_file(f)])

if not input_files or not output_files:
    print(f"Error: No images found in {input_dir} or {output_dir}.")
else:
    if len(input_files) != len(output_files):
        # Inputs and outputs are paired in sorted filename order; extras are skipped.
        print(f"Warning: {len(input_files)} input image(s) vs {len(output_files)} "
              "output image(s); unmatched files are not shown.")
    for filename_in, filename_out in zip(input_files, output_files):
        try:
            image_original = _load_rgb(os.path.join(input_dir, filename_in))
            image_restore = _load_rgb(os.path.join(output_dir, filename_out))

            # Fit both images to max_width, then bring them to a common height.
            image_original = resize_image_maintain_aspect(image_original, max_width)
            image_restore = resize_image_maintain_aspect(image_restore, max_width)
            target_height = min(image_original.size[1], image_restore.size[1])
            image_original = resize_image_maintain_aspect(image_original, max_width, target_height)
            image_restore = resize_image_maintain_aspect(image_restore, max_width, target_height)

            if visualization_method == "Side-by-Side":
                combined_width = image_original.size[0] + image_restore.size[0]
                combined_image = PIL.Image.new('RGB', (combined_width, target_height))
                combined_image.paste(image_original, (0, 0))
                combined_image.paste(image_restore, (image_original.size[0], 0))

                display(combined_image)

            else:  # visualization_method == "Slider"
                original_base64 = image_to_base64(image_original)
                restore_base64 = image_to_base64(image_restore)
                # A unique id lets each script bind only its own slider. A
                # document-wide '.slider' query would attach another set of
                # listeners to every earlier slider for each image displayed.
                slider_id = f"srformer-slider-{uuid.uuid4().hex}"

                html_code = f"""
                <div id="{slider_id}" style="position: relative; width: {image_original.size[0]}px; height: {image_original.size[1]}px; margin-bottom: 20px;">
                    <div style="position: relative; width: 100%; height: 100%; overflow: hidden;">
                        <img src="data:image/png;base64,{original_base64}" style="position: absolute; width: 100%; height: 100%;">
                        <div style="position: absolute; width: 100%; height: 100%; overflow: hidden; clip-path: inset(0 0 0 50%);">
                            <img src="data:image/png;base64,{restore_base64}" style="position: absolute; width: 100%; height: 100%;">
                        </div>
                    </div>
                    <div class="slider" style="position: absolute; top: 0; bottom: 0; width: 4px; background: white; cursor: ew-resize; left: 50%; box-shadow: 0 0 5px rgba(0,0,0,0.5);">
                        <div style="position: absolute; top: 50%; transform: translateY(-50%); width: 20px; height: 20px; background: white; border-radius: 50%; left: -8px;"></div>
                    </div>
                </div>
                <script>
                    (() => {{
                        const root = document.getElementById('{slider_id}');
                        if (!root) return;
                        const slider = root.querySelector('.slider');
                        const container = root.firstElementChild;
                        const clipDiv = container.querySelector('div');
                        let isDragging = false;

                        const moveTo = (clientX) => {{
                            const rect = container.getBoundingClientRect();
                            const x = Math.min(Math.max(clientX - rect.left, 0), rect.width);
                            const percentage = (x / rect.width) * 100;
                            slider.style.left = percentage + '%';
                            clipDiv.style.clipPath = `inset(0 0 0 ${{percentage}}%)`;
                        }};
                        const start = (e) => {{
                            isDragging = true;
                            e.preventDefault();
                        }};
                        const stop = () => {{
                            isDragging = false;
                        }};

                        slider.addEventListener('mousedown', start);
                        slider.addEventListener('touchstart', start);
                        document.addEventListener('mouseup', stop);
                        document.addEventListener('touchend', stop);
                        document.addEventListener('mousemove', (e) => {{
                            if (isDragging) moveTo(e.clientX);
                        }});
                        document.addEventListener('touchmove', (e) => {{
                            if (isDragging) moveTo(e.touches[0].clientX);
                        }});
                    }})();
                </script>
                """

                display(HTML(html_code))

        except Exception as e:
            print(f"Error processing {filename_in} and {filename_out}: {e}")

In [None]:
#@title ##**Download Results** { display-mode: "form" }
import os
from google.colab import files
import zipfile

def is_image_file(filename):
    """Return True if *filename* ends in a common raster-image extension.

    The comparison is case-insensitive and looks only at the final suffix
    as split off by os.path.splitext.
    """
    _, extension = os.path.splitext(filename.lower())
    return extension in {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'}

output_dir = "/content/SRFormer/output_images"
# A missing output folder is reported like an empty one instead of raising.
# Sorting keeps the archive order deterministic.
output_files = (
    sorted(f for f in os.listdir(output_dir) if is_image_file(f))
    if os.path.isdir(output_dir)
    else []
)

if not output_files:
    print(f"Error: No images found in {output_dir}.")
elif len(output_files) == 1:
    files.download(os.path.join(output_dir, output_files[0]))
else:
    # Bundle several results into one archive. It is written into output_dir
    # and skipped by is_image_file on later listings.
    zip_path = os.path.join(output_dir, "download.zip")
    try:
        if os.path.exists(zip_path):
            os.remove(zip_path)
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for name in output_files:
                zipf.write(os.path.join(output_dir, name), name)
        files.download(zip_path)
    except Exception as e:
        print(f"Error creating zip file: {e}")