In [1]:
%load_ext autoreload
%autoreload 2

In [78]:
import base64
import os
from datetime import datetime
from time import gmtime, strftime

import numpy as np
import pytz
from PIL import Image
from yattag import Doc

from scoring.clip_embedder import (acquisition_model, get_image_embeddings,
                                   get_text_embeddings, max_few_shot_score,
                                   prompt_proximity_score)
from scoring.cv_score import convexity_score
import io
import pandas as pd

In [3]:
device = "cuda"
processor, model = acquisition_model(device)

In [68]:
EST = pytz.timezone('EST') 
datetime_utc = datetime.now(EST) 
now = datetime_utc.strftime('%Y-%m-%d_%H:%M:%S')
print(now)

walking_path = "./pytorch-CycleGAN-and-pix2pix/results/outputs"
few_shot_path = "pytorch-CycleGAN-and-pix2pix/results/popsicle_stick_output_processed/popsicle_stick_output_processed"
report_path = f"./results/reports_{now}"
os.makedirs(report_path, exist_ok=True)

# few shot images
samples = [0, 10, 20, 30, 40]
ground_truth_images = [os.path.join(few_shot_path, f"{p}.png") for p in samples]

# prompt proximity score
prompt = "a picture of a popsicle stick"

2024-04-21_00:52:36


In [51]:
def burst_score(list_of_images):
  image_embeddings = get_image_embeddings(list_of_images, processor, model, device)
  few_shot_embeddings = get_image_embeddings(ground_truth_images, processor, model, device)
  text_prompt = get_text_embeddings([prompt], processor, model, device)
  pps = prompt_proximity_score(image_embeddings, text_prompt)
  mfss = max_few_shot_score(image_embeddings, few_shot_embeddings)
  cvs, imgs_out, imgs_mask = convexity_score(list_of_images, modifiers = {})
  
  scores = {
    "prompt proximity score": pps,
    "max few-shot score": mfss,
    "convexity score": cvs
  }
  assets = {
    "annotated images": imgs_out,
    "masks": imgs_mask
  }
  return scores, assets

In [52]:
def save_report(report, output_path):
  final_report = f"""
  <html>
    <head><link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap@4.6.2/dist/css/bootstrap.min.css" integrity="sha384-xOolHFLEh07PJGoPkLv1IbcEPTNtaed2xpHsD9ESMhqIYd0nLMwNLD69Npy4HI+N" crossorigin="anonymous"></head>
    <body>
    {report}
    <script src="https://cdn.jsdelivr.net/npm/jquery@3.5.1/dist/jquery.slim.min.js" integrity="sha384-DfXdz2htPH0lsSSs5nCTpuj/zy4C+OGpamoFVy38MVBnE+IbbVYUew+OrCXaRkfj" crossorigin="anonymous"></script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@4.6.2/dist/js/bootstrap.bundle.min.js" integrity="sha384-Fy6S3B9q64WdZWQUiU+q4/2Lc9npb8tCaSX9FK7E8HnRr0Jz8D6OP9dO5Vg3Q9ct" crossorigin="anonymous"></script>
    </body>
  </html>
  """
  with open(output_path, "w") as f:
    f.write(final_report)

In [74]:
def arr_to_b64(arr):
  im = Image.fromarray(arr.astype("uint8"))
  rawBytes = io.BytesIO()
  im.save(rawBytes, "PNG")
  rawBytes.seek(0)  # return to the start of the file
  return f"data:image/png;base64, {base64.b64encode(rawBytes.read()).decode('ascii')}"
def generate_report(scores, assets, imgs_paths, title):
  doc, tag, text = Doc().tagtext()
  samples_per_row = 3
  rr, cc = (len(imgs_paths) + samples_per_row - 1) // samples_per_row, samples_per_row
  with tag("div", klass="border border-2 border-primary container-fluid"):
    with tag("h3"):
      text(title)
    with tag("ul"):
      with tag("li"):
        text(f"Mean Prompt Proximity Score (pps): {np.mean(scores['prompt proximity score']):0.4f}" )
      with tag("li"):
        text(f"Mean Max Few-Shot Score (pps): {np.mean(scores['max few-shot score']):0.4f}" )
      with tag("li"):
        text(f"Mean Convexity Score (cs): {np.mean(scores['convexity score']):0.4f}" )
    for r in range(rr):
      with tag("div", klass="row justify-content-center"):
        for c in range(cc):
          i = r * samples_per_row + c
          if i >= len(imgs_paths):
            break
          with tag("div", klass="col-4"):
            with tag("div", klass="row"):
              text(os.path.basename(imgs_paths[i]))
              doc.stag('br')
              text(f"pps: {scores['prompt proximity score'][i]:0.4f}")
              text(f", mfss: {scores['max few-shot score'][i]:0.4f}")
              text(f", cs: {scores['convexity score'][i]:0.4f}")
            with tag("div", klass="row"):
              with tag("div", klass="col-6 border border-secondary"):
                doc.stag('img', src=arr_to_b64(assets["annotated images"][i]), klass="img-fluid", width="200px")
              with tag("div", klass="col-6 border border-secondary"):
                doc.stag('img', src=arr_to_b64(assets["masks"][i]*255), klass="img-fluid", width="200px")
  return doc.getvalue()

In [81]:
overall_scores = []
for root, dirs, files in os.walk(walking_path):
  if not any([f.endswith("fake_B.png") for f in files]):
    continue
  identifier = root.split("/")[-4]
  print(f"Processing: {identifier}")
  imgs_list = [os.path.join(root, file) for file in files if file.endswith("fake_B.png")]
  scores, assets = burst_score(imgs_list)
  overall_scores.append([
    identifier, 
    np.mean(scores["prompt proximity score"]),
    np.mean(scores["max few-shot score"]),
    np.mean(scores["convexity score"])
  ])
  report = generate_report(scores, assets, imgs_list, root)
  save_report(report, os.path.join(report_path, identifier + ".html"))

Processing: brightness_saturation_contrast_translation_cutout
Processing: saturation_translation_cutout
Processing: brightness_saturation
Processing: brightness_saturation_translation_cutout
Processing: saturation_contrast_translation
Processing: contrast_translation_cutout
Processing: cutout
Processing: saturation_contrast_translation_cutout
Processing: brightness_contrast_cutout
Processing: raw_method
Processing: brightness_cutout
Processing: brightness_contrast_translation_cutout
Processing: contrast_translation
Processing: brightness
Processing: saturation_translation
Processing: brightness_contrast_translation
Processing: translation_cutout
Processing: saturation_contrast
Processing: brightness_saturation_contrast_translation
Processing: brightness_saturation_cutout
Processing: diff_aug
Processing: brightness_translation
Processing: contrast
Processing: saturation
Processing: brightness_contrast
Processing: specnorm
Processing: diff_specnorm
Processing: brightness_saturation_trans

In [82]:
df = pd.DataFrame(overall_scores, columns=["report", "prompt proximity score", "max few-shot score", "convexity score"])
df

Unnamed: 0,report,prompt proximity score,max few-shot score,convexity score
0,brightness_saturation_contrast_translation_cutout,0.30812,0.89952,1.659323
1,saturation_translation_cutout,0.316603,0.902587,1.366314
2,brightness_saturation,0.314051,0.894325,1.533723
3,brightness_saturation_translation_cutout,0.317373,0.917117,1.34482
4,saturation_contrast_translation,0.310705,0.917481,1.353753
5,contrast_translation_cutout,0.305983,0.913217,1.435212
6,cutout,0.307875,0.912362,1.457365
7,saturation_contrast_translation_cutout,0.311758,0.901867,1.455248
8,brightness_contrast_cutout,0.313252,0.903504,1.501802
9,raw_method,0.315406,0.907249,1.5145


In [83]:
df.to_csv(os.path.join(report_path, "reports.csv"))