In [None]:
import asyncio
import json
import os
from datetime import datetime
from getpass import getpass
from io import BytesIO
from typing import Optional

import aiohttp
import matplotlib.pyplot as plt
import requests
from PIL import Image

class ImageGenerator():
  """Generate images for a text prompt through the Hugging Face Inference API.

  The same prompt is POSTed ``number_of_request`` times, either sequentially
  (``get_image``) or concurrently with aiohttp (``async_request``). Raw
  responses and their bodies are kept so ``show_picture`` can plot them.
  """
  # Inference API endpoint for Stable Diffusion v1.5.
  url = 'https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5'
  # Headers sent with every request; filled in by set_headers().
  headers = None

  def __init__(self, prompt:str, HF_TOKEN:str, async_req=False,number_of_request:Optional[int]=1):
    """Store the request settings and select the default headers.

    Args:
      prompt: text prompt; each request appends its own ``" #<n>"`` suffix.
      HF_TOKEN: Hugging Face access token, sent as a Bearer token.
      async_req: kept on the instance; no method of this class reads it.
      number_of_request: number of requests per run; ``None`` means 1.
    """
    self.prompt = prompt
    self.number_of_request = 1 if number_of_request is None else number_of_request
    self.HF_TOKEN = HF_TOKEN
    # Numbers the prompt suffixes; advanced exactly once per request built.
    self.item_counter = 0
    self.response_list = []
    self.image_list = []
    self.async_req = async_req
    # NOTE(review): "max-age" is not a standard request header; it probably
    # belongs inside the Cache-Control value ("no-store, max-age=0").
    self.default_headers = {
      "Authorization": f"Bearer {HF_TOKEN}",
      "Cache-Control":"no-store",
      "max-age":"0",
      "User-Agent":"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/117.0.0.0 Safari/537.36",
    }
    # Wall-clock seconds of the most recent get_image()/async_request() run.
    self.time_taken = 0
    # Select the default headers now so requests carry the Authorization
    # header even if the caller never calls set_headers().
    self.set_headers()

  def set_headers(self, headers:Optional[dict[str,str]]=None):
    """Choose the headers sent with every request.

    Args:
      headers: custom headers to use; when empty or ``None`` the default
        headers built in ``__init__`` (including the Bearer token) are used.
    """
    self.headers = headers if headers else self.default_headers

  def _next_payload(self) -> str:
    """Return the JSON body for the next request and advance ``item_counter``.

    The ``" #<n>"`` suffix makes each prompt distinct (presumably so repeated
    requests are not answered from a cache); ``self.prompt`` is not modified.
    """
    # NOTE(review): the body is a bare JSON string; the Inference API docs
    # show {"inputs": prompt} payloads — confirm which form the endpoint expects.
    payload = json.dumps(self.prompt + " #" + str(self.item_counter))
    self.item_counter += 1
    return payload

  def sync_request(self):
    """Send one blocking POST and return the ``requests.Response``."""
    return requests.request("POST", self.url, headers=self.headers, data=self._next_payload())

  def get_async_tasks(self, session):
    """Build one pending POST per requested image on the aiohttp ``session``.

    Returns a list of awaitables suitable for ``asyncio.gather``; each one
    carries its own numbered prompt.
    """
    return [
      session.post(self.url, headers=self.headers, data=self._next_payload())
      for _ in range(self.number_of_request)
    ]

  async def async_request(self):
    """Send all requests concurrently and wait for every body.

    Stores the elapsed wall-clock seconds in ``time_taken`` (and prints it).

    Returns:
      ``(responses, bodies)``: the aiohttp responses and their bytes bodies,
      in request order.
    """
    start = datetime.now()
    results = []
    temp_response = []
    async with aiohttp.ClientSession() as session:
      responses = await asyncio.gather(*self.get_async_tasks(session))
      for resp in responses:
        # read() loads the whole body and releases the connection; the bytes
        # stay cached on the response after the session closes.
        results.append(await resp.read())
        temp_response.append(resp)
    duration = datetime.now() - start
    self.time_taken = duration.total_seconds()
    print(f"time taken {duration.total_seconds()}")
    return temp_response, results

  def get_image(self):
    """Send ``number_of_request`` requests one after another.

    Appends each response to ``response_list`` and its body to
    ``image_list``, then stores the total elapsed seconds in ``time_taken``.
    """
    self.time_taken = 0
    req_time = datetime.now()
    for _ in range(self.number_of_request):
      response = self.sync_request()
      self.response_list.append(response)
      self.image_list.append(response.content)
    self.time_taken = (datetime.now() - req_time).total_seconds()

  def show_picture(self, ncols:Optional[int]=3):
    """Plot every body in ``image_list`` on a grid with ``ncols`` columns.

    A body that PIL cannot decode (e.g. a JSON error returned instead of an
    image) is shown as text in its grid cell instead of being skipped
    silently. Nothing is drawn when ``image_list`` is empty.
    """
    images = self.image_list
    if not images:
      return
    ncols = ncols or 3
    # Ceiling division: just enough rows, with no trailing empty row.
    nrows = (len(images) + ncols - 1) // ncols
    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid.
    _, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True, squeeze=False)
    for index, ax in enumerate(axs.flat):
      if index >= len(images):
        ax.axis("off")  # unused grid cell
        continue
      body = images[index]
      try:
        ax.imshow(Image.open(BytesIO(body)))
      except (OSError, ValueError):
        # PIL raises UnidentifiedImageError (an OSError) for non-image bodies;
        # show the start of the body so API errors are visible.
        ax.text(0.5, 0.5, body[:200].decode("utf-8", errors="replace"),
                transform=ax.transAxes, ha="center", va="center",
                wrap=True, fontsize=6)
        ax.axis("off")



### 1. Synchronous Requests

In [None]:
# Read the token from the environment (or prompt for it) instead of pasting a
# secret into the notebook, where it would be saved and shared with the file.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
image_generator = ImageGenerator(prompt="a dragon fish ", HF_TOKEN=HF_TOKEN, number_of_request=5, async_req=False)


In [None]:
image_generator.get_image()
print(f"time taken = {image_generator.time_taken}")
# Print the numeric HTTP status rather than the Response object's repr.
print(f"last_response_status = {image_generator.response_list[-1].status_code}")

In [None]:
# Lay the generated images out three per row.
image_generator.show_picture(ncols=3)



### 2. Asynchronous Requests



In [None]:
# Reuse HF_TOKEN from the synchronous section rather than overwriting it with
# a placeholder string here.
image_generator2 = ImageGenerator(prompt="a colorful dragon fish", HF_TOKEN=HF_TOKEN, number_of_request=5, async_req=True)

In [None]:
async_response = await image_generator2.async_request()
image_generator2.response_list, image_generator2.image_list = async_response
print(f"last response status = {image_generator2.response_list[-1]}")

In [None]:
# Lay the concurrently generated images out three per row.
image_generator2.show_picture(ncols=3)