# Libraries



In [None]:
!pip install torch diffusers  transformers pillow matplotlib mediapipe opencv-python opencv-contrib-python gradio blip  torchvision requests beautifulsoup4 fake_useragent

In [None]:
!apt-get update
!apt-get install -y wget unzip xvfb libxi6 libgconf-2-4
!apt-get install -y libappindicator1 fonts-liberation
!wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
!dpkg -i google-chrome-stable_current_amd64.deb || apt-get -fy install
!rm google-chrome-stable_current_amd64.deb

!pip install selenium requests webdriver-manager

In [None]:
%%capture

!pip install groq langchain_community sentence_transformers
!pip install llama-index-llms-groq
!pip install groq

# Data Scraping

In [None]:
import os
import logging
import requests
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from webdriver_manager.chrome import ChromeDriverManager
from bs4 import BeautifulSoup
import time
import shutil
import urllib.parse
import concurrent.futures
import gradio as gr
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as transforms
from typing import List, Dict
from groq import Groq
from transformers import BlipProcessor, BlipForConditionalGeneration

# Emit INFO-level records so per-image progress from the scrapers is visible.
logging.basicConfig(level=logging.INFO)

# Listing pages to scrape, keyed by the site name that prefixes saved
# filenames and selects the scraper in scrape_data().
websites = {
    'junaidjamshed': (
        'https://www.junaidjamshed.com/womens/kurti.html'
        '?product_list_dir=desc&product_list_order=top_rated'
    ),
    'khaadi': (
        'https://pk.khaadi.com/ready-to-wear/essentials/kurta/kurta/'
        '?prefn1=filter_categories&prefv1=Kurta&srule=most-popular&start=0&sz=96'
    ),
}

# An <img> is kept only if its alt text contains one of these (case-insensitive).
keywords = ['shirt', 'kurta', 'kurti']

# Destination directory for downloaded images.
output_folder = "scraped_images"

# Start each run from an empty directory: wipe leftovers, then recreate it.
if os.path.exists(output_folder):
    shutil.rmtree(output_folder)
os.makedirs(output_folder, exist_ok=True)

# Headless Chrome configured for container/Colab use (no sandbox, no /dev/shm).
options = Options()
for _chrome_flag in ("--headless", "--no-sandbox", "--disable-dev-shm-usage"):
    options.add_argument(_chrome_flag)
options.binary_location = "/usr/bin/google-chrome"

_chromedriver_path = ChromeDriverManager().install()
driver = webdriver.Chrome(service=Service(_chromedriver_path), options=options)

# Function to fetch image using requests
def fetch_image(img_url):
    try:
        if img_url and img_url.startswith('http'):
            img_data = requests.get(img_url, timeout=10).content
            return img_data
        else:
            logging.warning(f"Invalid image URL: {img_url}")
            return None
    except Exception as e:
        logging.error(f"Failed to fetch image {img_url}: {e}")
        return None

def save_images(site_name, images, limit=10):
    """Download up to ``limit`` image URLs concurrently and write them to disk.

    Files are saved in ``output_folder`` as ``{site_name}shirt{n}.jpg``, where
    ``n`` is the 1-based position of the URL among the selected ones. URLs
    whose download fails are skipped, leaving a gap in the numbering.

    Args:
        site_name: Prefix for the saved filenames.
        images: Candidate image URLs, in priority order.
        limit: Maximum number of URLs to download (default 10).
    """
    selected = images[:limit]
    if not selected:
        logging.warning(f"No matching images to save for {site_name}")
        return

    # Downloads are I/O-bound, so threads overlap the network waits.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        img_data_list = list(executor.map(fetch_image, selected))

    for i, img_data in enumerate(img_data_list, start=1):
        if not img_data:
            continue
        img_name = f"{site_name}shirt{i}.jpg"
        img_path = os.path.join(output_folder, img_name)
        try:
            with open(img_path, 'wb') as img_file:
                img_file.write(img_data)
        except OSError as e:
            logging.error(f"Failed to save image {img_name}: {e}")
        else:
            logging.info(f"Saved {img_name}")

def _selenium_image_url(img, base_url):
    """Return the first fetchable URL of a Selenium ``<img>`` element, or None.

    Checks ``src``, then ``data-src``, then the first entry of ``srcset``.
    Empty values and inline ``data:`` URIs (lazy-load placeholders) are
    skipped, so a real ``data-src`` wins over a placeholder ``src``.
    Relative URLs are resolved against ``base_url``.
    """
    for attribute in ('src', 'data-src', 'srcset'):
        value = (img.get_attribute(attribute) or '').strip()
        if attribute == 'srcset' and value:
            # srcset looks like "url1 480w, url2 800w"; keep only the first URL.
            parts = value.split(',')[0].split()
            value = parts[0] if parts else ''
        if value and not value.startswith('data:'):
            return urllib.parse.urljoin(base_url, value)
    return None


def scrape_images_junaidjamshed(site_name, url, scroll_limit=5):
    """Scrape kurti images from a Junaid Jamshed listing page using Selenium.

    Scrolls down in 1000px steps, at most ``scroll_limit`` times and stopping
    once the bottom is reached, so lazily loaded images get real URLs. Each
    de-duplicated ``<img>`` whose alt text contains one of ``keywords`` is
    then passed to ``save_images``. Errors are logged, not raised.

    Args:
        site_name: Prefix for saved filenames.
        url: Listing page URL.
        scroll_limit: Maximum number of scroll steps (default 5).
    """
    try:
        driver.get(url)

        for _ in range(scroll_limit):
            driver.execute_script("window.scrollBy(0, 1000);")
            time.sleep(2)  # Wait for images to load
            # Stop at the real bottom of the page. Comparing page heights
            # stopped after the first step on pages that don't grow while
            # scrolling, leaving most of the listing unloaded.
            reached_bottom = driver.execute_script(
                "return window.innerHeight + window.scrollY >= document.body.scrollHeight;"
            )
            if reached_bottom:
                break

        time.sleep(3)

        base_url = driver.current_url
        images = []
        seen_urls = set()  # To track already seen images
        for img in driver.find_elements(By.TAG_NAME, "img"):
            img_url = _selenium_image_url(img, base_url)
            if img_url is None or img_url in seen_urls:
                continue  # No usable URL, or a duplicate
            seen_urls.add(img_url)

            alt_text = img.get_attribute('alt')
            if alt_text and any(keyword.lower() in alt_text.lower() for keyword in keywords):
                images.append(img_url)

        save_images(site_name, images)
    except Exception as e:
        # Top-level boundary for this site: log and let other sites proceed.
        logging.error(f"Error scraping {site_name}: {e}")

def scrape_images_khaadi(site_name, url):
    """Scrape kurta images from a Khaadi listing page with requests + BeautifulSoup.

    For each ``<img>``, the URL comes from ``src``, or from ``data-src`` when
    ``src`` is missing or an inline ``data:`` placeholder. It is resolved to
    an absolute URL and de-duplicated. Images whose alt text contains one of
    ``keywords`` are passed to ``save_images``. Errors, including HTTP error
    statuses, are logged and not raised.

    Args:
        site_name: Prefix for saved filenames.
        url: Listing page URL.
    """
    try:
        response = requests.get(url, timeout=10)
        # Don't parse a 4xx/5xx error page as if it were the product listing.
        response.raise_for_status()
        soup = BeautifulSoup(response.content, 'html.parser')

        images = []
        seen_urls = set()
        for img in soup.find_all('img'):
            candidates = (img.get('src'), img.get('data-src'))
            raw_url = next(
                (c for c in candidates if c and not c.startswith('data:')), None
            )
            # Tags with no usable URL must be skipped: urljoin(base, None)
            # returns the base page URL, which would then be downloaded and
            # saved as an "image".
            if raw_url is None:
                continue

            # Handle relative URLs (against the final URL after redirects)
            img_url = urllib.parse.urljoin(response.url, raw_url)

            # Skip duplicate URLs
            if img_url in seen_urls:
                continue
            seen_urls.add(img_url)

            # Filter images by keywords in alt text
            alt_text = img.get('alt')
            if alt_text and any(keyword.lower() in alt_text.lower() for keyword in keywords):
                images.append(img_url)

        save_images(site_name, images)
    except Exception as e:
        # Top-level boundary for this site: log and let other sites proceed.
        logging.error(f"Error scraping {site_name}: {e}")

def load_images_from_folder(folder_path):
    """Load and preprocess every ``.jpg`` image in ``folder_path``.

    Files are visited in sorted filename order, so the gallery order is
    stable across runs (``os.listdir`` order is arbitrary). The extension
    check is case-insensitive. Files that cannot be decoded are logged and
    skipped instead of aborting the whole load.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        A list of ``(image, image_path)`` tuples, where ``image`` is the
        result of ``preprocess_image(image_path)``.
    """
    images = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(".jpg"):
            continue
        image_path = os.path.join(folder_path, filename)
        try:
            image = preprocess_image(image_path)
        except OSError as e:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-image bytes, e.g. an HTML page saved under a .jpg name.
            logging.warning(f"Skipping unreadable image {image_path}: {e}")
            continue
        images.append((image, image_path))  # Store image and path tuple
    return images

def preprocess_image(image_path, size=(512, 512)):
    """Open an image file as RGB and resize it for the diffusion pipeline.

    Args:
        image_path: Path to an image file readable by PIL.
        size: Target size for ``torchvision.transforms.Resize``. A ``(h, w)``
            tuple resizes to exactly that shape; aspect ratio is not kept.

    Returns:
        The resized RGB ``PIL.Image.Image``.
    """
    # The context manager closes the underlying file. convert() returns an
    # independent in-memory image, so the result stays valid afterwards.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return transforms.Resize(size)(rgb_image)

def scrape_data():
    """Run the site-specific scraper for every entry in ``websites``.

    Each scraper writes its matches into ``output_folder``. The folder is then
    read back in full, so the result covers every image on disk.

    Returns:
        The ``(image, image_path)`` tuples produced by
        ``load_images_from_folder(output_folder)``.
    """
    scrapers = {
        'junaidjamshed': scrape_images_junaidjamshed,
        'khaadi': scrape_images_khaadi,
    }
    for site_name, url in websites.items():
        logging.info(f"Scraping {site_name} for shirts...")
        scraper = scrapers.get(site_name)
        if scraper is not None:
            scraper(site_name, url)

    # Reload the gallery images after scraping
    return load_images_from_folder(output_folder)

# Start scraping. try/finally guarantees the headless Chrome session is shut
# down even if scraping or reloading the saved images raises.
try:
    images = scrape_data()
    logging.info(f"Scraping complete. Images saved: {[img[1] for img in images]}")
finally:
    # Close the driver
    driver.quit()

# Generative Model

In [None]:
import gradio as gr
import os
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as transforms
from transformers import BlipProcessor, BlipForConditionalGeneration
from concurrent.futures import ThreadPoolExecutor
import threading

# Thread-safe variables
# NOTE(review): neither lock is acquired anywhere in the code visible here.
images_lock = threading.Lock()
descriptions_lock = threading.Lock()

# Set up models and processor (preloaded once for efficiency).
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 weights are intended for GPU inference. Many CPU kernels have no
# half-precision implementation, so use float32 when CUDA is unavailable.
pipe_dtype = torch.float16 if device == "cuda" else torch.float32
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# NOTE(review): confirm the "runwayml/stable-diffusion-v1-5" repo id still
# resolves on the Hugging Face Hub; a mirror may be needed.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=pipe_dtype,
).to(device)

In [None]:
import gradio as gr
import os
from PIL import Image
import torch
from diffusers import StableDiffusionImg2ImgPipeline
import torchvision.transforms as transforms
from typing import List, Dict
from groq import Groq
from transformers import BlipProcessor, BlipForConditionalGeneration
from functools import lru_cache  # Using lru_cache for caching

# Set up Groq API key. setdefault (rather than plain assignment) keeps a key
# already exported in the environment instead of overwriting it with "".
# Provide a real key via the environment or here before sending requests.
os.environ.setdefault("GROQ_API_KEY", "")
client = Groq()  # Initialize Groq API client
# NOTE(review): Groq retires model IDs over time; confirm this one is still served.
DEFAULT_MODEL = "llama-3.1-70b-versatile"

# Load BLIP model and processor for image captioning
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

def assistant(content: str):
    """Wrap ``content`` as a chat message authored by the assistant."""
    message = dict(role="assistant", content=content)
    return message

def user(content: str):
    """Wrap ``content`` as a chat message authored by the user."""
    message = dict(role="user", content=content)
    return message

def chat_completion(messages: List[Dict], model=DEFAULT_MODEL, temperature=0.6, top_p=0.9) -> str:
    """Send ``messages`` to the Groq chat API and return the first reply's text.

    Args:
        messages: Chat messages, e.g. built with ``user()`` / ``assistant()``.
        model: Groq model id; defaults to ``DEFAULT_MODEL``.
        temperature: Sampling temperature forwarded to the API.
        top_p: Nucleus-sampling cutoff forwarded to the API.

    Returns:
        The ``message.content`` of the first choice in the completion.
    """
    request = dict(messages=messages, model=model, temperature=temperature, top_p=top_p)
    completion = client.chat.completions.create(**request)
    first_choice = completion.choices[0]
    return first_choice.message.content

def preprocess_image(image_path, size=(512, 512)):
    """Open an image file as RGB and resize it for the diffusion pipeline.

    Args:
        image_path: Path to an image file readable by PIL.
        size: Target size for ``torchvision.transforms.Resize``. A ``(h, w)``
            tuple resizes to exactly that shape; aspect ratio is not kept.

    Returns:
        The resized RGB ``PIL.Image.Image``.
    """
    # The context manager closes the underlying file. convert() returns an
    # independent in-memory image, so the result stays valid afterwards.
    with Image.open(image_path) as source:
        rgb_image = source.convert("RGB")
    return transforms.Resize(size)(rgb_image)

def load_images_from_folder(folder_path):
    """Load and preprocess every ``.jpg`` image in ``folder_path``.

    Files are visited in sorted filename order, so the gallery order (which
    gallery-click indices rely on) is stable across runs. The extension
    check is case-insensitive. Files that cannot be decoded are logged and
    skipped instead of aborting the whole load.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        A list of ``(image, image_path)`` tuples, where ``image`` is the
        result of ``preprocess_image(image_path)``.
    """
    import logging

    images = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(".jpg"):
            continue
        image_path = os.path.join(folder_path, filename)
        try:
            image = preprocess_image(image_path)
        except OSError as e:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # non-image bytes, e.g. an HTML page saved under a .jpg name.
            logging.warning(f"Skipping unreadable image {image_path}: {e}")
            continue
        images.append((image, image_path))  # Store image and path tuple
    return images

def generate_description_with_blip(image_path):
    """Caption an image with BLIP and save the caption next to the image.

    The caption is written as UTF-8 to a ``.txt`` file with the image's stem,
    in the image's own directory. Deriving the path from ``image_path``
    (rather than a fixed folder) keeps caption and image together for any
    location. For images in ``scraped_images_folder`` it is the same file
    as before.

    Args:
        image_path: Path to an image file readable by PIL.

    Returns:
        The decoded caption text.
    """
    # Close the file handle promptly; convert() returns an in-memory copy.
    with Image.open(image_path) as source:
        image = source.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    out = blip_model.generate(**inputs)
    description = processor.decode(out[0], skip_special_tokens=True)

    description_path = os.path.splitext(image_path)[0] + ".txt"
    with open(description_path, "w", encoding="utf-8") as desc_file:
        desc_file.write(description)

    return description

# Folder holding scraped and uploaded images, plus their .txt captions.
scraped_images_folder = "scraped_images"

# Shared state read and written by the Gradio callbacks:
#   images       -> list of (preprocessed image, image path) tuples
#   descriptions -> caption text keyed by image file basename
images: list = []
descriptions: dict = {}

# Memoize captions per image path so repeated selections don't re-run BLIP.
# NOTE(review): the cache key is only the path string, so an image replaced
# on disk under the same name keeps returning its old caption.
@lru_cache(maxsize=128)
def generate_description_with_blip_cached(image_path):
    """Return the BLIP caption for ``image_path``, reusing cached results."""
    description = generate_description_with_blip(image_path)
    return description

def save_uploaded_data(uploaded_image, description=None):
    """Move an uploaded image into the gallery folder and register it.

    Args:
        uploaded_image: Filesystem path of the uploaded file (the value of a
            ``gr.File(type="filepath")``), or ``None`` if nothing was uploaded.
        description: Optional caption. When empty or omitted, a BLIP caption
            is generated. Defaults to ``None`` so the function also works
            when the upload component is its only Gradio input.

    Returns:
        ``gr.update`` that refreshes the gallery with all loaded images, or
        an empty ``gr.update()`` when nothing was uploaded or the move failed.
    """
    import logging
    import shutil

    global images, descriptions

    if uploaded_image is None:
        return gr.update()

    # Save the image to the 'scraped_images' folder
    image_path = os.path.join(scraped_images_folder, os.path.basename(uploaded_image))
    try:
        # shutil.move falls back to copy+delete when source and destination
        # are on different filesystems, where os.rename raises.
        shutil.move(uploaded_image, image_path)
    except OSError as e:
        logging.error(f"Error moving uploaded file: {e}")
        return gr.update()

    # Generate description using BLIP if no description is provided
    if not description:
        description = generate_description_with_blip_cached(image_path)

    # Replace any existing entry for this path so re-uploading a file with
    # the same name does not create a duplicate gallery tile.
    new_image = preprocess_image(image_path)
    images[:] = [entry for entry in images if entry[1] != image_path]
    images.append((new_image, image_path))

    descriptions[os.path.basename(image_path)] = description

    return gr.update(value=[img[0] for img in images])

@lru_cache(maxsize=1)
def load_pipeline():
    """Load the Stable Diffusion img2img pipeline once and reuse it.

    Cached so each "Generate Image" click reuses the loaded weights instead
    of reading them from disk and allocating another copy on the device.
    Call ``load_pipeline.cache_clear()`` to release it.

    Returns:
        A ``StableDiffusionImg2ImgPipeline`` on CUDA when available, else CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 is meant for GPU inference; many CPU kernels lack half-precision
    # support, so load float32 weights when running on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    return StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
    ).to(device)

def sync_sliders(value, ndigits=2):
    """Return the complementary weight ``1 - value`` so two sliders sum to 1.

    The result is rounded to ``ndigits`` decimals (2 matches the sliders'
    0.01 step). This keeps float noise such as ``1 - 0.07 ==
    0.9299999999999999`` out of the UI, and syncing back and forth
    reproduces the same pair of values.

    Args:
        value: The weight just set on one slider, in [0, 1].
        ndigits: Decimal places to round the complement to.

    Returns:
        ``round(1 - value, ndigits)``.
    """
    return round(1 - value, ndigits)

def blend_and_generate_image(image_paths, alpha1, alpha2, generated_prompt):
    """Blend up to two images by weight, then refine the blend with img2img.

    Args:
        image_paths: Sequence of image paths. Only the first two non-empty
            entries are used; empty strings come from unfilled selection boxes.
        alpha1: Weight of the first image, clamped to [0, 1].
        alpha2: Weight of the second image, clamped to [0, 1]. Replaced by
            ``1 - alpha1`` when the two weights do not sum to 1.
        generated_prompt: Text prompt for the diffusion pipeline.

    Returns:
        The generated ``PIL.Image.Image``, or ``None`` when no non-empty
        image path was given.
    """
    import math

    paths = [path for path in (image_paths or []) if path][:2]
    if not paths:
        return None

    alpha1 = max(0, min(1, alpha1))
    alpha2 = max(0, min(1, alpha2))
    if not math.isclose(alpha1 + alpha2, 1):
        alpha2 = 1 - alpha1

    with Image.open(paths[0]) as first:
        blended_image = first.convert("RGBA")
    if len(paths) > 1:
        with Image.open(paths[1]) as second:
            # Image.blend requires both images to share size and mode.
            overlay = second.convert("RGBA").resize(blended_image.size)
        # blend(a, b, t) == a * (1 - t) + b * t, so t is the second image's weight.
        blended_image = Image.blend(blended_image, overlay, alpha2)

    pipe = load_pipeline()
    output_image = pipe(
        prompt=generated_prompt,
        image=blended_image.convert("RGB"),
        strength=0.80,
        guidance_scale=7.5,
        num_inference_steps=50,
        generator=torch.manual_seed(42),  # Fixed seed for reproducible output.
    ).images[0]

    return output_image

def generate_prompt_from_selected_images(image_paths):
    """Ask the LLM for a kurti design prompt that combines the selected images.

    Each image's caption is looked up in ``descriptions`` by file basename.
    Missing captions are generated with the cached BLIP helper and stored
    back into ``descriptions``.

    Args:
        image_paths: Paths of the currently selected images.

    Returns:
        The prompt text from ``chat_completion``, or a fixed notice string
        when ``image_paths`` is empty.
    """
    def caption_for(path):
        key = os.path.basename(path)
        caption = descriptions.get(key, "")
        if not caption:
            caption = generate_description_with_blip_cached(path)
            descriptions[key] = caption
        return caption

    captions = [caption_for(path) for path in image_paths]
    if not captions:
        return "No descriptions found for the selected images."

    joined = ", ".join(captions)
    request_text = (
        f'Combine the provided descriptions creatively to design a unique kurti: {joined}. '
        'Emphasize the design details, patterns, and fabric texture. '
        'Do not include descriptive comments about the body of model. '
        'Ensure that the model face is completely cropped out, with no visible facial features.'
    )
    return chat_completion([user(request_text)])

def create_ui():
    """Build the Gradio interface for browsing, uploading and generating kurti images.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as app:
        with gr.Row():
            gr.Markdown("### Select images and generate a new image based on descriptions and blending")

        with gr.Row():
            fetch_data_button = gr.Button("Fetch New Data")
            upload_image = gr.File(label="Upload Image", type="filepath", file_types=[".jpg"])
            upload_button = gr.Button("Upload")

        with gr.Row():
            gallery = gr.Gallery(label="Loaded Images", value=[], interactive=True, columns=4, height="auto")

        with gr.Row():
            image1_display = gr.Textbox(label="First Selected Image Path", interactive=False)
            image2_display = gr.Textbox(label="Second Selected Image Path", interactive=False)

        with gr.Row():
            generated_prompt_display = gr.Textbox(label="Generated Prompt", interactive=True, lines=3)

        with gr.Row():
            alpha1_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Weight of First Image")
            alpha2_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Weight of Second Image")

        with gr.Row():
            output_generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)

        # Paths of the (at most two) most recently clicked gallery images.
        # NOTE(review): this list is captured by the closure, so all browser
        # sessions share it; gr.State would make the selection per-session.
        selected_images = []

        def handle_selection(evt: gr.SelectData):
            """Record a gallery click (keeping the last two) and refresh the prompt."""
            selected_images.append(images[evt.index][1])
            if len(selected_images) > 2:
                selected_images.pop(0)
            generated_prompt = generate_prompt_from_selected_images(selected_images)
            second = selected_images[1] if len(selected_images) > 1 else ""
            return selected_images[0], second, generated_prompt

        gallery.select(handle_selection, None, [image1_display, image2_display, generated_prompt_display])

        # .input fires only on user interaction. With .change, the update that
        # one handler writes into the other slider fires that slider's handler
        # too, so the two kept re-triggering each other.
        alpha1_slider.input(sync_sliders, inputs=alpha1_slider, outputs=alpha2_slider)
        alpha2_slider.input(sync_sliders, inputs=alpha2_slider, outputs=alpha1_slider)

        def fetch_new_data_and_update_gallery():
            """Reload images from disk, caption them, and refresh the gallery."""
            global images, descriptions
            updated_images = load_images_from_folder(scraped_images_folder)
            descriptions = {os.path.basename(img[1]): generate_description_with_blip_cached(img[1]) for img in updated_images}
            images = updated_images
            return gr.update(value=[img[0] for img in updated_images])

        fetch_data_button.click(fetch_new_data_and_update_gallery, None, gallery)

        def upload_and_update_gallery(uploaded_image):
            """Adapter: the upload button supplies only the file, so pass no caption."""
            return save_uploaded_data(uploaded_image, None)

        upload_button.click(upload_and_update_gallery, [upload_image], gallery)

        def blend_and_generate(image1_path, image2_path, alpha1, alpha2, generated_prompt):
            """Forward the two selected paths and weights to the generator."""
            return blend_and_generate_image([image1_path, image2_path], alpha1, alpha2, generated_prompt)

        generate_button = gr.Button("Generate Image")
        generate_button.click(blend_and_generate, [image1_display, image2_display, alpha1_slider, alpha2_slider, generated_prompt_display], output_generated_image)

    return app

if __name__ == "__main__":
    app = create_ui()  # assemble the Blocks layout and its event wiring
    app.launch()       # serve it with Gradio's default launch settings