<a href="https://colab.research.google.com/github/dav-2/Text-to-Image-Generator-Flask-App-Created-Using-the-Stable-Diffusion-Model/blob/main/text_to_image_generator_app.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1>Text to Image Generator Flask App Created in Google Colab Using the Stable Diffusion Model.</h1>

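Google Colab runs in a cloud-based environment where each notebook is executed on a virtual machine. This VM can reach the internet (that is how the `pip install` commands below work), but it is not directly reachable from the internet, so a web server started inside it cannot be opened from your browser on its own.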

Ngrok is a tool that creates secure tunnels to your localhost, allowing you to expose a local development server to the internet. When working in Google Colab, ngrok can be particularly useful for creating web applications or APIs that you want to test or share with others, as it creates tunnels that securely expose the services running inside these VMs to the internet.

To use ngrok, you need to create an ngrok account and paste the authtoken shown in your account dashboard after the `!ngrok authtoken` command in the first code cell.

To open the app, click the ngrok-free.app link printed in the output of the last code cell.

In [None]:
# Install the packages.
# flask serves the web UI and pyngrok opens the public tunnel to it; diffusers
# provides StableDiffusionPipeline. transformers, accelerate, scipy and
# safetensors are presumably runtime dependencies of the diffusers pipeline —
# TODO confirm which of them are strictly required.
!pip install flask pyngrok
!pip install diffusers
!pip install transformers
!pip install accelerate scipy safetensors

# Authenticate ngrok.
# Append your personal authtoken to the command below, e.g.
# `!ngrok authtoken <YOUR_TOKEN>`; as written, no token is passed.
# NOTE(review): the `ngrok` command here is presumably the one installed by
# pyngrok above — confirm.
!ngrok authtoken  # Write your ngrok authtoken here

In [None]:
# Import the libraries
import os
import re
from flask import Flask, request, send_file, render_template_string, jsonify
import io
import torch
from diffusers import StableDiffusionPipeline
import time
from pyngrok import ngrok
import subprocess
import base64

# Create the pipeline. stabilityai/stable-diffusion-2-1 offers a good balance between precision,
# speed, and resources needed, but there are other Stable Diffusion models available.
MODEL_ID = "stabilityai/stable-diffusion-2-1"

# Use the GPU in half precision when one is available (float16 roughly halves
# the model's memory footprint). Otherwise fall back to the CPU in float32,
# because float16 inference on CPU may be unsupported or extremely slow.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

pipeline = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=TORCH_DTYPE)
pipeline = pipeline.to(DEVICE)

# Create the Flask app
app = Flask(__name__)

# Directory (relative to the current working directory) where generated images are saved
SAVE_DIR = 'saved_images'
os.makedirs(SAVE_DIR, exist_ok=True)

def sanitize_filename(filename, max_bytes=200):
    """Turn arbitrary text (e.g. a user prompt) into a safe file-name stem.

    Path separators and other characters that are invalid on common
    filesystems, plus ASCII control characters (including NUL and newlines),
    are removed; spaces become underscores; and the result is truncated so
    that its UTF-8 encoding is at most ``max_bytes`` bytes.

    Args:
        filename: Raw text to sanitise.
        max_bytes: Upper bound on the UTF-8 length of the result. The default
            leaves room for a ``_<timestamp>.png`` suffix while staying under
            the 255-byte name limit of common filesystems such as ext4.

    Returns:
        The sanitised string (possibly empty).
    """
    # Control characters are stripped too: a NUL byte makes file APIs raise
    # "embedded null byte", and newlines/tabs make awkward file names.
    filename = re.sub(r'[\\/*?:"<>|\x00-\x1f\x7f]', "", filename)
    filename = filename.replace(' ', '_')
    # Truncate on the encoded form because non-ASCII characters take several
    # bytes; errors="ignore" drops a multi-byte character cut at the boundary.
    return filename.encode('utf-8')[:max_bytes].decode('utf-8', 'ignore')

@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the UI on GET; on POST, generate an image from the submitted prompt.

    A POST reads the ``prompt`` form field, runs it through the Stable
    Diffusion pipeline, writes the first resulting image as a PNG into
    ``SAVE_DIR`` and answers with JSON ``{"filename": ...}`` so the page can
    fetch the picture from ``/image/<filename>``. Any other request renders
    ``html_template``.
    """
    if request.method != 'POST':
        return render_template_string(html_template)

    user_prompt = request.form['prompt']
    generated = pipeline(prompt=user_prompt).images[0]

    # File name = sanitised prompt + Unix timestamp (whole seconds).
    stem = sanitize_filename(user_prompt)
    out_name = "{}_{}.png".format(stem, int(time.time()))
    generated.save(os.path.join(SAVE_DIR, out_name), 'PNG')

    # The client uses this name to request the stored image.
    return jsonify(filename=out_name)

@app.route('/image/<filename>', methods=['GET'])
def get_image(filename):
    """Serve a previously generated PNG from ``SAVE_DIR``.

    ``send_from_directory`` joins *filename* onto the directory with
    werkzeug's ``safe_join``, so names that would resolve outside
    ``SAVE_DIR`` are rejected rather than joined blindly.

    Returns:
        The file with mimetype ``image/png``, or JSON ``{"error": ...}`` with
        HTTP 404 when no such file exists (or the name is unsafe).
    """
    from flask import send_from_directory
    from werkzeug.exceptions import NotFound

    try:
        # Pass an absolute directory: Flask resolves a relative one against
        # app.root_path, whereas SAVE_DIR was created relative to the cwd.
        return send_from_directory(
            os.path.abspath(SAVE_DIR), filename, mimetype='image/png'
        )
    except NotFound:
        return jsonify({'error': 'File not found'}), 404

@app.route('/save', methods=['POST'])
def save_image():
    """Persist an image posted as a base64 data URL and echo it back.

    Expects a JSON body ``{"image": "data:image/png;base64,<payload>"}``; a
    bare base64 payload without the ``data:...,`` prefix is accepted too.
    The decoded bytes are written to ``SAVE_DIR`` as ``upload_<ns>.png``.

    Returns:
        JSON ``{"image": <PNG data URL>, "filename": <saved name>}`` on
        success, or ``{"error": ...}`` with HTTP 400 when the body carries no
        ``image`` string or its payload is not valid, non-empty base64.
    """
    data = request.get_json(silent=True)
    image_data = data.get('image') if isinstance(data, dict) else None
    if not isinstance(image_data, str):
        return jsonify({'error': "Missing 'image' field"}), 400

    # rpartition returns everything after the last comma (the payload of a
    # data URL) and the whole string when there is no comma at all.
    payload = image_data.rpartition(',')[2]
    try:
        img_data = base64.b64decode(payload, validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
        return jsonify({'error': 'Invalid base64 image data'}), 400
    if not img_data:
        return jsonify({'error': 'Empty image data'}), 400

    # Store on disk next to the generated images instead of in process memory.
    filename = f"upload_{time.time_ns()}.png"
    with open(os.path.join(SAVE_DIR, filename), 'wb') as fh:
        fh.write(img_data)

    img_url = f'data:image/png;base64,{base64.b64encode(img_data).decode()}'
    return jsonify({'image': img_url, 'filename': filename})

# HTML template with inline CSS and JavaScript.
# It is rendered with render_template_string (i.e. through Jinja), so the markup
# must not contain Jinja delimiters: doubled curly braces, or a curly brace
# immediately followed by a percent sign or a hash.
html_template = '''
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Image Generator</title>
    <style>
        body, html {
            margin: 0;
            padding: 0;
            height: 100%;
            font-family: Arial, sans-serif;
            display: flex;
            overflow: hidden;
        }

        .black-column {
            width: 350px;
            background-color: #000;
            display: flex;
            flex-direction: column;
            align-items: center;
            padding: 20px;
            box-sizing: border-box;
            color: white;
            position: relative; /* Positioning context for loading symbol */
        }

        .white-area {
            flex: 1;
            background-color: white;
            position: relative;
            display: flex;
            justify-content: center;
            align-items: center;
            overflow: hidden;
        }

        .top-container {
            display: flex;
            flex-direction: column;
            align-items: center;
            width: 100%;
        }

        .top-container form {
            display: flex;
            flex-direction: column;
            align-items: center;
            width: 100%;
        }

        .top-container input[type="text"] {
            padding: 10px;
            font-size: 16px;
            width: 100%;
            max-width: 300px;
            margin-bottom: 10px;
            background: white;
            color: #333;
            border: 1px solid #ccc;
            border-radius: 4px;
        }

        .top-container input[type="submit"] {
            padding: 10px 20px;
            font-size: 16px;
            cursor: pointer;
            width: 100%;
            max-width: 300px;
            background: #007bff; /* Blue color matching the button */
            color: white;
            border: none;
            border-radius: 4px;
        }

        .top-container input[type="submit"]:disabled {
            opacity: 0.6;
            cursor: wait;
        }

        .loading-symbol {
            position: absolute;
            top: 50%;
            left: 50%;
            /* Centre the spinner on its anchor point. The spin keyframes repeat
               this translate because an animated transform replaces the static
               one completely. */
            transform: translate(-50%, -50%);
            border: 16px solid rgba(255, 255, 255, 0.1);
            border-radius: 50%;
            border-top: 16px solid #fff; /* White color */
            width: 120px; /* Large width */
            height: 120px; /* Large height */
            animation: spin 1s linear infinite;
            display: none; /* Hidden by default */
            z-index: 1000; /* Ensure it is above other elements */
        }

        @keyframes spin {
            0% { transform: translate(-50%, -50%) rotate(0deg); }
            100% { transform: translate(-50%, -50%) rotate(360deg); }
        }

        .generated-image {
            display: none; /* Hidden until the script shows it as a flex container */
            position: fixed;
            top: 0;
            left: 0;
            width: 100%;
            height: 100%;
            background: rgba(0, 0, 0, 0.7);
            justify-content: center;
            align-items: center;
            flex-direction: column;
            z-index: 1000;
        }

        .generated-image img {
            max-width: 90%;
            max-height: 80%;
            object-fit: contain;
        }

        .generated-image button {
            position: absolute;
            background: none;
            border: none;
            font-size: 20px;
            color: white;
            cursor: pointer;
        }

        .generated-image .close {
            top: 20px;
            right: 20px;
            font-size: 30px;
        }

        .generated-image .download {
            bottom: 20px;
            right: 20px;
        }
    </style>
</head>
<body>
    <div class="black-column">
        <div class="top-container">
            <form id="promptForm" method="post">
                <input type="text" name="prompt" placeholder="Please enter a prompt" required>
                <input type="submit" value="Generate Image">
            </form>
        </div>
        <div class="loading-symbol" id="loadingSymbol"></div>
    </div>
    <div class="white-area">
        <div class="generated-image" id="generatedImageContainer">
            <button class="close" id="closeButton">&times;</button>
            <img id="generatedImage" src="" alt="Generated">
            <button class="download" id="downloadButton">Download</button>
        </div>
    </div>

    <script>
        const form = document.getElementById('promptForm');
        const submitButton = form.querySelector('input[type="submit"]');
        const loadingSymbol = document.getElementById('loadingSymbol');
        const popup = document.getElementById('generatedImageContainer');
        const generatedImage = document.getElementById('generatedImage');
        let currentFilename = null; // Server-side name of the image being shown

        // Handle form submission: ask the server to generate an image, then display it
        form.addEventListener('submit', function(event) {
            event.preventDefault();
            loadingSymbol.style.display = 'block'; // Show loading symbol
            popup.style.display = 'none'; // Ensure popup is hidden
            submitButton.disabled = true; // Avoid starting a second generation meanwhile

            fetch('/', {
                method: 'POST',
                body: new FormData(form)
            })
            .then(response => {
                if (!response.ok) {
                    throw new Error('Server responded with status ' + response.status);
                }
                return response.json();
            })
            .then(data => {
                if (!data.filename) {
                    throw new Error('Server response did not include a filename');
                }
                currentFilename = data.filename;
                // Encode the name: prompts may contain characters such as '#' or '%'
                // that would otherwise be parsed as URL syntax.
                generatedImage.src = '/image/' + encodeURIComponent(currentFilename);
                popup.style.display = 'flex'; // Show popup
            })
            .catch(error => {
                console.error('Error generating image:', error);
            })
            .finally(() => {
                loadingSymbol.style.display = 'none'; // Hide loading symbol
                submitButton.disabled = false;
            });
        });

        // Handle close button click
        document.getElementById('closeButton').addEventListener('click', function() {
            popup.style.display = 'none'; // Hide popup
        });

        // Handle download button click
        document.getElementById('downloadButton').addEventListener('click', function() {
            const a = document.createElement('a');
            a.href = generatedImage.src;
            a.download = currentFilename || 'generated.png'; // Suggest the server-side filename
            document.body.appendChild(a);
            a.click();
            document.body.removeChild(a);
        });
    </script>
</body>
</html>
'''

# Port the Flask app listens on and that ngrok forwards public traffic to.
PORT = 8081

# Expose the app through ngrok and run it.
# pyngrok's ngrok.connect() starts and manages its own ngrok agent, so no
# separate `ngrok start` process is launched here. A second agent session may
# be refused for accounts limited to one, and this notebook never creates an
# ngrok.yml config for one to use anyway.
if __name__ == "__main__":
    # Open the tunnel before the try: if it fails there is nothing to disconnect.
    public_url = ngrok.connect(PORT).public_url
    print(public_url)  # Open this URL in a browser to use the app
    try:
        app.run(host="0.0.0.0", port=PORT)
    finally:
        # public_url is always bound here because the tunnel was created above.
        ngrok.disconnect(public_url)
