In [None]:
# NOTE(review): no visible cell in this notebook imports `google_trans_new`;
# this install looks unused — TODO confirm and drop it if so.
!pip install google-trans-new

Collecting google-trans-new
  Downloading google_trans_new-1.1.9-py3-none-any.whl.metadata (5.2 kB)
Downloading google_trans_new-1.1.9-py3-none-any.whl (9.2 kB)
Installing collected packages: google-trans-new
Successfully installed google-trans-new-1.1.9


In [None]:
!pip install fastapi uvicorn pyngrok python-multipart torch torchvision transformers langchain langchain_google_genai streamlit googletrans

In [None]:
!ngrok authtoken 2kHqCKEBZtYkANyARDYtYfxXGWz_2QSzpWcVfiQE5LH4Akauz

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [22]:
%%writefile app.py
import os

import cv2
import numpy as np
import requests
import streamlit as st
import torch
import torchvision.transforms as T
from diffusers import StableDiffusionPipeline
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from PIL import Image, ImageDraw, ImageFont
from torchvision.models.detection import maskrcnn_resnet50_fpn

# Decorative fonts used to render the tagline; each font that loads produces
# one poster variant.
font_urls = [
    "https://raw.githubusercontent.com/darrinbright/fonts/main/Fancake.ttf",
    "https://raw.githubusercontent.com/darrinbright/fonts/main/Milky%20Boba.ttf",
    "https://raw.githubusercontent.com/darrinbright/fonts/main/gomarice_tofo_steak.ttf",
    "https://raw.githubusercontent.com/darrinbright/fonts/main/Advertising%20Script%20Bold.ttf"
]

# Fetch each font into the working directory, named after the last URL segment.
# Streamlit re-executes this script on every interaction, so files that are
# already on disk are reused instead of being downloaded again.
font_paths = []
for url in font_urls:
    font_name = url.split("/")[-1]
    if not os.path.exists(font_name):
        try:
            response = requests.get(url, timeout=30)
            # Without this check an HTTP error page would be saved as a ".ttf".
            response.raise_for_status()
        except requests.RequestException as exc:
            # Best effort: a missing font only removes one poster variant.
            print(f"Could not download font {url}: {exc}")
            continue
        with open(font_name, 'wb') as f:
            f.write(response.content)
    font_paths.append(font_name)

# Run the models on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@st.cache_resource
def _load_diffusion_pipeline(model_id):
    """Load the Stable Diffusion pipeline for `model_id` onto `device`, cached across reruns."""
    return StableDiffusionPipeline.from_pretrained(model_id).to(device)


def generate_poster(prompt):
    """Generate a poster image from a text prompt with Stable Diffusion 2.

    The pipeline is loaded through a `st.cache_resource` helper, so the
    weights are loaded once rather than on every button click.

    Args:
        prompt: Text description passed to the diffusion pipeline.

    Returns:
        The first image produced by the pipeline (`.images[0]`).
    """
    model_id = "stabilityai/stable-diffusion-2"
    pipe = _load_diffusion_pipeline(model_id)
    image = pipe(prompt).images[0]
    return image

@st.cache_resource
def _load_detector():
    """Load the pretrained Mask R-CNN detector onto `device` in eval mode, cached across reruns."""
    # `weights="DEFAULT"` replaces the deprecated `pretrained=True` flag.
    model = maskrcnn_resnet50_fpn(weights="DEFAULT").to(device)
    model.eval()
    return model


def detect_objects(image_pil):
    """Detect objects in an image with Mask R-CNN and return their bounding boxes.

    Args:
        image_pil: PIL image to analyse. It is converted to RGB first so that
            images in other modes (e.g. RGBA, grayscale) are accepted too.

    Returns:
        The model's `boxes` prediction for the image as a NumPy array on the
        CPU. Callers in this file unpack each row as
        (x_min, y_min, x_max, y_max).
    """
    image_tensor = T.ToTensor()(image_pil.convert("RGB")).unsqueeze(0).to(device)

    model = _load_detector()

    # Inference only: no gradients are needed.
    with torch.no_grad():
        predictions = model(image_tensor)[0]

    boxes = predictions['boxes'].cpu().numpy()
    return boxes

def add_text_outside_box(image_pil, boxes, catchy_text, font_paths):
    """Draw the tagline next to the first detected box, once per font.

    The text is anchored 20 px to the right of the first box, at the box's
    vertical midpoint. If it would run past the right or bottom edge it is
    pulled back inside the image, and never pushed to a negative coordinate.
    The position is computed independently for every font, so an adjustment
    made for one font does not affect the next.

    Args:
        image_pil: Source poster as a PIL image; it is not modified.
        boxes: Sequence of (x_min, y_min, x_max, y_max) boxes; only the first
            one is used.
        catchy_text: Text to draw.
        font_paths: Paths of TrueType fonts to try.

    Returns:
        A list with one annotated RGB copy of the image per font that could
        be loaded (possibly empty), or `[image_pil]` unchanged when `boxes`
        is empty.
    """
    if len(boxes) == 0:
        return [image_pil]

    margin = 20
    font_size = 50
    image_width, image_height = image_pil.size

    _, y_min, x_max, y_max = [int(b) for b in boxes[0]]
    anchor_x = x_max + margin
    anchor_y = (y_min + y_max) // 2

    resulting_images = []
    for font_path in font_paths:
        try:
            font = ImageFont.truetype(font_path, font_size)
        except OSError:
            print(f"Could not load font: {font_path}")
            continue

        # Work on a fresh RGB copy so every font variant starts from the clean poster.
        pil_image = image_pil.convert("RGB")
        draw = ImageDraw.Draw(pil_image)

        text_bbox = draw.textbbox((anchor_x, anchor_y), catchy_text, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Per-font position: start from the anchor and clamp into the image.
        text_x, text_y = anchor_x, anchor_y
        if text_x + text_width > image_width:
            text_x = max(image_width - text_width - margin, 0)
        if text_y + text_height > image_height:
            text_y = max(image_height - text_height - margin, 0)

        draw.text((text_x, text_y), catchy_text, font=font, fill=(255, 255, 255))

        resulting_images.append(pil_image)

    return resulting_images

# Page header followed by the two free-text inputs that drive poster generation.
st.title('Social Spark')

product_description = st.text_input(label='Enter the product description', value='')
product_type = st.text_input(label='Enter a prompt for a catchy tagline', value='')

def generate_catchy_text(tagline_prompt):
    """Ask Gemini for a short slogan for the given product or tagline prompt.

    The Google API key is read from the GOOGLE_API_KEY environment variable
    instead of being hardcoded. A key was previously embedded in this
    function and should be treated as leaked and revoked.

    Args:
        tagline_prompt: Text substituted into the prompt as the product the
            slogan is for.

    Returns:
        The model's answer with surrounding whitespace removed.

    Raises:
        RuntimeError: If GOOGLE_API_KEY is not set.
    """
    api_key = os.environ.get("GOOGLE_API_KEY")
    if not api_key:
        raise RuntimeError("GOOGLE_API_KEY environment variable is not set")

    # Plain string, not an f-string: `{product_type}` must stay a template
    # placeholder so the user's text is inserted as a value by the chain and
    # braces in that text cannot break the template.
    prompt_template = """
    Generate a short 3-4 words catchy text or slogan for the {product_type} which displays in the advertisement poster.

    Answer:
    """

    model = ChatGoogleGenerativeAI(model="gemini-pro", temperature=0.4, google_api_key=api_key)

    prompt = PromptTemplate(template=prompt_template, input_variables=["product_type"])

    chain = LLMChain(llm=model, prompt=prompt)

    # Use the argument rather than the module-level `product_type` global.
    catchy_text = chain.run(product_type=tagline_prompt)
    return catchy_text.strip()

# Runs the whole pipeline on each click: generate the image, detect objects,
# ask the LLM for a tagline, then render one poster variant per font.
if st.button('Generate Poster'):
    if product_description and product_type:
        poster = generate_poster(prompt=product_description)
        boxes = detect_objects(poster)
        catchy_text = generate_catchy_text(product_type)
        posters_with_text = add_text_outside_box(poster, boxes, catchy_text, font_paths)

        if not posters_with_text:
            st.error("No poster generated")
        # Every variant is shown at a fixed 500x500 size.
        for variant in posters_with_text:
            st.image(variant.resize((500, 500)), width=500)
    else:
        st.error("Please provide both a product description and a tagline prompt.")

Overwriting app.py


In [23]:
import os  # no longer used in this cell; kept in case later cells rely on it — TODO confirm
import subprocess

from pyngrok import ngrok

# Single source of truth for the port shared by Streamlit and the tunnel.
STREAMLIT_PORT = 8501

# Drop any tunnel left over from a previous run of this cell.
ngrok.kill()

# Stop a Streamlit server started by an earlier run of this cell so the port is free.
_previous = globals().get("streamlit_process")
if _previous is not None and _previous.poll() is None:
    _previous.terminate()
    _previous.wait(timeout=10)

# Start the app in the background without going through a shell, pinned to the
# tunnelled port so it cannot silently come up on a different one.
streamlit_process = subprocess.Popen(
    ["streamlit", "run", "app.py", "--server.port", str(STREAMLIT_PORT)]
)

public_url = ngrok.connect(STREAMLIT_PORT)
print('Streamlit is accessible at:', public_url)

Streamlit is accessible at: NgrokTunnel: "https://75e6-34-145-122-128.ngrok-free.app" -> "http://localhost:8501"


0