<a href="https://colab.research.google.com/github/codebyrohith/Integration-Of-Text-Difuse/blob/main/Integration_of_Text_Difuse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone the Text-DiFuse repository into the current working directory
# (git creates a 'Text-DiFuse' folder there).
!git clone https://github.com/Leiii-Cao/Text-DiFuse.git

Cloning into 'Text-DiFuse'...
remote: Enumerating objects: 97, done.[K
remote: Counting objects: 100% (97/97), done.[K
remote: Compressing objects: 100% (96/96), done.[K
remote: Total 97 (delta 7), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (97/97), 4.06 MiB | 8.83 MiB/s, done.
Resolving deltas: 100% (7/7), done.


In [3]:
# Enter the cloned repository. An absolute path keeps this cell safe to re-run:
# the checkout contains a nested 'Text-DiFuse' folder, so a relative
# `%cd Text-DiFuse/` would descend one level deeper on every extra execution.
%cd /content/Text-DiFuse

# Install dependencies (%pip targets the environment of the running kernel)
%pip install -r /content/Text-DiFuse/Text-DiFuse/requirements.txt

/content/Text-DiFuse


In [4]:
# Pin NumPy after the requirements install (%pip targets the running kernel's environment).
# NOTE(review): presumably required for compatibility with the repo's dependencies - confirm.
# If NumPy was already imported in this session, restart the runtime before continuing.
%pip install numpy==1.24.4




In [21]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from flask import Flask, request, jsonify, send_file
import io
import torch
import argparse
import threading
import sys
from PIL import Image
import numpy as np
import cv2

# Add your model directory to Python path
sys.path.append('/content/Text-DiFuse/Text-DiFuse')

from diffusion_fusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults, args_to_dict
from diffusion_fusion.unet import Get_Fusion_Control_Model
from diffusion_fusion.util import to_numpy_image

# Flask application object; the route handler below registers itself on it.
app = Flask(__name__)

# Module-level model handles. They start out unset and are assigned by load_model().
diffusion_stage1 = diffusion_stage2 = diffusion = Fusion_Control_Model = None

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def load_model():
    """Build the Text-DiFuse networks, load their checkpoints and publish them as globals.

    Assigns the module-level ``diffusion_stage1``, ``diffusion_stage2``,
    ``diffusion`` and ``Fusion_Control_Model`` names. The two diffusion stages and
    the fusion control model are moved to ``device`` and put in eval mode.

    Raises:
        Any exception raised by ``torch.load`` or ``load_state_dict`` when a
        checkpoint file is missing or does not match its model.
    """
    global diffusion_stage1, diffusion_stage2, diffusion, Fusion_Control_Model

    print("🔵 Loading model...")

    # Start from the library defaults and expose them as attributes on a namespace.
    defaults = model_and_diffusion_defaults()
    args = argparse.Namespace(**defaults)

    # NOTE(review): args_to_dict is only given the default keys, so this value
    # reaches create_model_and_diffusion only if 'device' is one of them - confirm.
    args.device = str(device)

    diffusion_stage1, diffusion_stage2, diffusion = create_model_and_diffusion(
        **args_to_dict(args, defaults.keys())
    )

    Fusion_Control_Model = Get_Fusion_Control_Model()

    diffusion_stage1_path = "/content/diffusion_stage1.pth"
    diffusion_stage2_path = "/content/diffusion_stage2.pth"
    FCM_path = "/content/FCM-VIS-IR.pt"

    # NOTE(review): torch.load unpickles the file contents; only load checkpoints
    # obtained from a trusted source.
    diffusion_stage1.load_state_dict(torch.load(diffusion_stage1_path, map_location=device))
    diffusion_stage2.load_state_dict(torch.load(diffusion_stage2_path, map_location=device))
    # strict=False: key mismatches between checkpoint and model are tolerated here.
    Fusion_Control_Model.load_state_dict(torch.load(FCM_path, map_location=device), strict=False)

    diffusion_stage1 = diffusion_stage1.to(device).eval()
    diffusion_stage2 = diffusion_stage2.to(device).eval()
    Fusion_Control_Model = Fusion_Control_Model.to(device).eval()

    # Release cached GPU memory right after loading. Guarded because
    # torch.cuda.ipc_collect() initializes CUDA and therefore raises on a
    # CPU-only host, which would abort startup despite the CPU fallback.
    if device.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    print("✅ Model loaded and memory cleaned successfully!")

def preprocess_image(img):
    """Turn an image into a single-channel float tensor with two leading singleton axes.

    The image is converted to mode 'L', resized to 512x512 (kept small to avoid
    running out of memory), scaled by 1/255 as float32 and returned with two
    extra leading dimensions.

    Args:
        img: An object offering PIL's ``convert`` / ``resize`` interface.

    Returns:
        A float32 ``torch.Tensor`` built from the resized pixel array.
    """
    gray = img.convert('L')
    resized = gray.resize((512, 512))
    scaled = np.array(resized).astype(np.float32) / 255.0
    # Indexing with None twice prepends the two singleton axes: (H, W) -> (1, 1, H, W).
    return torch.from_numpy(scaled)[None, None]

@app.route('/api/fuse', methods=['POST'])
def fuse_images():
    """Fuse an uploaded visible/infrared image pair and return the result as a PNG.

    Expects a POST request carrying two uploaded files named ``vis`` and ``ir``.
    Both are preprocessed with ``preprocess_image``, moved to ``device`` and passed
    to ``diffusion.p_sample_loop`` as the conditions of the two sampling branches.

    Returns:
        * HTTP 200 with an ``image/png`` body on success.
        * HTTP 400 with a JSON ``error`` message when ``vis`` or ``ir`` is missing.
        * HTTP 500 with a JSON ``error`` message on any other failure.
    """
    # CUDA housekeeping is skipped on the CPU fallback: torch.cuda.ipc_collect()
    # initializes CUDA and raises on a CPU-only host, failing every request.
    on_gpu = device.type == 'cuda'

    try:
        # A missing upload is a client error; report it as 400 rather than
        # letting the key lookup fall into the generic 500 handler below.
        missing = [field for field in ('vis', 'ir') if field not in request.files]
        if missing:
            return jsonify({"error": "Missing file field(s): " + ", ".join(missing)}), 400

        # Free any unused GPU memory before starting
        if on_gpu:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        vis_image = request.files['vis']
        ir_image = request.files['ir']

        vis = preprocess_image(Image.open(vis_image))
        ir = preprocess_image(Image.open(ir_image))

        vis = vis.to(device)
        ir = ir.to(device)

        cond = {'condition': vis}
        cond1 = {'condition': ir}

        # Inference only: no gradients are needed.
        with torch.no_grad():
            output = diffusion.p_sample_loop(
                diffusion_stage1,
                diffusion_stage2,
                Fusion_Control_Model,
                vis.shape,
                model_kwargs=cond,
                model_kwargs1=cond1,
                progress=False,
            )

        # Clean memory after generation
        if on_gpu:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        # Stack the sampled result with both inputs along dim 1 and convert to a numpy image.
        # NOTE(review): the stacked channels are then interpreted as YCrCb - confirm this
        # matches the layout and dtype that to_numpy_image produces.
        output = to_numpy_image(torch.cat((output, vis, ir), dim=1))
        output_img = cv2.cvtColor(output[0], cv2.COLOR_YCrCb2RGB)
        output_pil = Image.fromarray(output_img)

        # Encode in memory and rewind so send_file reads from the start.
        buf = io.BytesIO()
        output_pil.save(buf, format='PNG')
        buf.seek(0)

        return send_file(buf, mimetype='image/png')

    except Exception as e:
        print("❌ Error during fusion:", e)
        return jsonify({"error": str(e)}), 500

# Thread entry point: load the models, then serve requests
def run_flask():
    """Load the models and then run the Flask server; meant to be a thread target.

    The models are loaded first, so the server only starts accepting requests
    once ``load_model`` has returned.
    """
    load_model()
    server_options = {
        'host': '0.0.0.0',
        'port': 5000,
        'debug': False,
        'use_reloader': False,
    }
    app.run(**server_options)

# Launch the server on a daemon thread so this cell returns immediately
# while the model loads and the server runs in the background.
flask_thread = threading.Thread(target=run_flask, daemon=True)
flask_thread.start()

🔵 Loading model...


In [22]:
# Download the cloudflared Debian package (latest release, quiet mode) and install it with dpkg.
# NOTE(review): the 'latest' URL is unpinned, so the installed version can change between runs.
!wget -q https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64.deb
!dpkg -i cloudflared-linux-amd64.deb


(Reading database ... (Reading database ... 5%(Reading database ... 10%(Reading database ... 15%(Reading database ... 20%(Reading database ... 25%(Reading database ... 30%(Reading database ... 35%(Reading database ... 40%(Reading database ... 45%(Reading database ... 50%(Reading database ... 55%(Reading database ... 60%(Reading database ... 65%(Reading database ... 70%(Reading database ... 75%(Reading database ... 80%(Reading database ... 85%(Reading database ... 90%(Reading database ... 95%(Reading database ... 100%(Reading database ... 126337 files and directories currently installed.)
Preparing to unpack cloudflared-linux-amd64.deb ...
Unpacking cloudflared (2025.4.0) over (2025.4.0) ...
Setting up cloudflared (2025.4.0) ...
Processing triggers for man-db (2.10.2-1) ...


In [23]:
# Expose the local Flask server (port 5000, as passed to app.run) through a Cloudflare quick tunnel.
# This command keeps running until the cell is interrupted.
# NOTE(review): /api/fuse performs no authentication, so anyone with the tunnel URL can call it.
!cloudflared tunnel --url http://localhost:5000


[90m2025-04-27T20:29:04Z[0m [32mINF[0m Thank you for trying Cloudflare Tunnel. Doing so, without a Cloudflare account, is a quick way to experiment and try it out. However, be aware that these account-less Tunnels have no uptime guarantee, are subject to the Cloudflare Online Services Terms of Use (https://www.cloudflare.com/website-terms/), and Cloudflare reserves the right to investigate your use of Tunnels for violations of such terms. If you intend to use Tunnels in production you should use a pre-created named tunnel by following: https://developers.cloudflare.com/cloudflare-one/connections/connect-apps
[90m2025-04-27T20:29:04Z[0m [32mINF[0m Requesting new quick Tunnel on trycloudflare.com...
[90m2025-04-27T20:29:07Z[0m [32mINF[0m +--------------------------------------------------------------------------------------------+
[90m2025-04-27T20:29:07Z[0m [32mINF[0m |  Your quick Tunnel has been created! Visit it at (it may take some time to be reachable):  |
[90m2025

INFO:werkzeug:127.0.0.1 - - [27/Apr/2025 20:31:24] "POST /api/fuse HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [27/Apr/2025 20:45:16] "POST /api/fuse HTTP/1.1" 200 -


[90m2025-04-27T21:25:17Z[0m [32mINF[0m Initiating graceful shutdown due to signal interrupt ...
^C
