# 0. Set-up

## Installations & Data

In [None]:
#@title Install Dependencies

##### Set up SigLIP #####

# Install the right jax version for TPU/GPU/CPU
import os
if 'COLAB_TPU_ADDR' in os.environ:
  raise "TPU colab not supported."
elif 'NVIDIA_PRODUCT_NAME' in os.environ:
  !nvidia-smi
import jax
jax.devices()


# Get latest version of big_vision codebase.
!git clone --quiet --branch=main --depth=1 https://github.com/google-research/big_vision
!cd big_vision && git pull --rebase --quiet
!pip -q install -r big_vision/big_vision/requirements.txt

# Gives us ~2x faster gsutil cp to get the model checkpoints.
!pip3 -q install --no-cache-dir -U crcmod

%cd big_vision

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import ml_collections

from google.colab.output import _publish as publish

!pip install --upgrade tensorflow

##### Set up Alibi Detect (UOD) #####

!pip install alibi-detect
!pip install easyfsl
import os
import torch
import tensorflow as tf
from tensorflow.keras.layers import InputLayer, Conv2D, Dense, Reshape, Conv2DTranspose, Flatten
from tqdm import tqdm
from easyfsl.datasets import EasySet
from torchvision import transforms
from alibi_detect.od import OutlierAE
from alibi_detect.saving import save_detector, load_detector

In [None]:
#@title Mount Google Drive for data
#@markdown If you don't have data, try making an example flag in MS paint.
# Mount Google Drive at /content/drive; later cells build their data and model
# paths under this mount point.
from google.colab import drive
drive.mount('/content/drive')

## Select Model Settings

In [None]:
#@title Select SigLIP Model Settings

# Various settings for our SigLIP model
# Uncomment exactly one (VARIANT, RES) pair; it must be a key of the table below.
# VARIANT, RES = 'B/16', 224
# VARIANT, RES = 'B/16', 256
# VARIANT, RES = 'B/16', 384
# VARIANT, RES = 'B/16', 512
# VARIANT, RES = 'L/16', 256
VARIANT, RES = 'L/16', 384
# VARIANT, RES = 'So400m/14', 224
# VARIANT, RES = 'So400m/14', 384
# VARIANT, RES = 'B/16-i18n', 256

# For the chosen (VARIANT, RES): checkpoint file name, text-tower variant,
# embedding dimension, token sequence length, and vocabulary size.
CKPT, TXTVARIANT, EMBDIM, SEQLEN, VOCAB = {
    ('B/16', 224): ('webli_en_b16_224_63724782.npz', 'B', 768, 64, 32_000),
    ('B/16', 256): ('webli_en_b16_256_60500360.npz', 'B', 768, 64, 32_000),
    ('B/16', 384): ('webli_en_b16_384_68578854.npz', 'B', 768, 64, 32_000),
    ('B/16', 512): ('webli_en_b16_512_68580893.npz', 'B', 768, 64, 32_000),
    ('L/16', 256): ('webli_en_l16_256_60552751.npz', 'L', 1024, 64, 32_000),
    ('L/16', 384): ('webli_en_l16_384_63634585.npz', 'L', 1024, 64, 32_000),
    ('So400m/14', 224): ('webli_en_so400m_224_57633886.npz', 'So400m', 1152, 16, 32_000),
    ('So400m/14', 384): ('webli_en_so400m_384_58765454.npz', 'So400m', 1152, 64, 32_000),
    ('B/16-i18n', 256): ('webli_i18n_b16_256_66117334.npz', 'B', 768, 64, 250_000),
}[VARIANT, RES]

# It is significantly faster to first copy the checkpoint (30s vs 8m30 for B and 1m vs ??? for L)
# The copy is skipped when /tmp/{CKPT} already exists.
!test -f /tmp/{CKPT} || gsutil cp gs://big_vision/siglip/{CKPT} /tmp/

# The '-i18n' suffix is only needed for the table lookup above; strip it before
# VARIANT is passed to the image model config.
if VARIANT.endswith('-i18n'):
  VARIANT = VARIANT[:-len('-i18n')]

import big_vision.models.proj.image_text.two_towers as model_mod

model_cfg = ml_collections.ConfigDict()
model_cfg.image_model = 'vit'  # TODO(lbeyer): remove later, default
model_cfg.text_model = 'proj.image_text.text_transformer'  # TODO(lbeyer): remove later, default
model_cfg.image = dict(variant=VARIANT, pool_type='map')
model_cfg.text = dict(variant=TXTVARIANT, vocab_size=VOCAB)
model_cfg.out_dim = (None, EMBDIM)  # (image_out_dim, text_out_dim)
model_cfg.bias_init = -10.0
model_cfg.temperature_init = 10.0

model = model_mod.Model(**model_cfg)

# Using `init_params` is slower but will lead to `load` below performing sanity-checks.
# init_params = jax.jit(model.init, backend="cpu")(jax.random.PRNGKey(42), jnp.zeros([1, RES, RES, 3], jnp.float32), jnp.zeros([1, SEQLEN], jnp.int32))['params']
init_params = None  # Faster but bypasses loading sanity-checks.

# Load the checkpoint that was copied to /tmp above.
params = model_mod.load(init_params, f'/tmp/{CKPT}', model_cfg)

In [None]:
#@title Load Trained UOD Model

# Change to match your mounted drive layout
PROJECT_BASE_PATH = "/content/drive/MyDrive/CS 229 Project"  # I know it says CS 229; I was too lazy to move the new model away from the data
saved_detector_name = "outlier_save"
saved_detector_path = os.path.join(PROJECT_BASE_PATH, saved_detector_name)

# Load the previously saved detector from the Drive folder above
# (load_detector comes from alibi_detect.saving).
outlier_detector = load_detector(saved_detector_path)

## Prepare Models for Use

In [None]:
#@title Tokenize and embed texts (SigLIP)

import big_vision.pp.builder as pp_builder
import big_vision.pp.ops_general
import big_vision.pp.ops_image
import big_vision.pp.ops_text
import PIL
import os
import random

# Text prompts that every image is scored against: one "<country> flag" prompt
# per country, followed by two non-country prompts at the end of the list.
# `texts` and `countries_and_flags` are two names for the same list object.
texts = countries_and_flags = [
    "Afghanistan flag", "Albania flag", "Algeria flag", "Andorra flag", "Angola flag",
    "Antigua and Barbuda flag", "Argentina flag", "Armenia flag", "Australia flag", "Austria flag",
    "Azerbaijan flag", "Bahamas flag", "Bahrain flag", "Bangladesh flag", "Barbados flag",
    "Belarus flag", "Belgium flag", "Belize flag", "Benin flag", "Bhutan flag",
    "Bolivia flag", "Bosnia and Herzegovina flag", "Botswana flag", "Brazil flag", "Brunei flag",
    "Bulgaria flag", "Burkina Faso flag", "Burundi flag", "Cambodia flag", "Cameroon flag",
    "Canada flag", "Cape Verde flag", "Central African Republic flag", "Chad flag","Chile flag",
    "China flag", "Colombia flag", "Comoros flag", "Costa Rica flag", "Croatia flag",
    "Cuba flag", "Cyprus flag", "Czechia flag", "Democratic Republic of the Congo flag", "Denmark flag",
    "Djibouti flag", "Dominica flag", "Dominican Republic flag", "East Timor flag", "Ecuador flag",
    "Egypt flag", "El Salvador flag", "Equatorial Guinea flag", "Eritrea flag", "Estonia flag",
    "Eswatini flag", "Ethiopia flag", "Fiji flag", "Finland flag", "France flag",
    "Gabon flag", "Gambia flag", "Georgia flag", "Germany flag", "Ghana flag",
    "Greece flag", "Grenada flag", "Guatemala flag", "Guinea flag", "Guinea Bissau flag",
    "Guyana flag", "Haiti flag", "Honduras flag", "Hungary flag", "Iceland flag",
    "India flag", "Indonesia flag", "Iran flag", "Iraq flag", "Ireland flag",
    "Israel flag", "Italy flag", "Jamaica flag", "Japan flag", "Jordan flag",
    "Kazakhstan flag", "Kenya flag", "Kiribati flag", "Korea North flag", "Korea South flag",
    "Kosovo flag", "Kuwait flag", "Kyrgyzstan flag", "Laos flag", "Latvia flag",
    "Lebanon flag", "Lesotho flag", "Liberia flag", "Libya flag", "Liechtenstein flag",
    "Lithuania flag", "Luxembourg flag", "Madagascar flag", "Malawi flag", "Malaysia flag",
    "Maldives flag", "Mali flag", "Malta flag", "Marshall Islands flag", "Mauritania flag",
    "Mauritius flag", "Mexico flag", "Micronesia flag", "Moldova flag", "Monaco flag",
    "Mongolia flag", "Montenegro flag", "Morocco flag", "Mozambique flag", "Myanmar flag",
    "Namibia flag", "Nauru flag", "Nepal flag","Netherlands flag", "New Zealand flag",
    "Nicaragua flag", "Niger flag", "Nigeria flag", "North Macedonia flag", "Norway flag",
    "Oman flag", "Pakistan flag", "Palau flag", "Panama flag", "Papua New Guinea flag",
    "Paraguay flag", "Peru flag", "Philippines flag", "Poland flag", "Portugal flag",
    "Qatar flag", "Republic of the Congo flag", "Romania flag", "Russia flag", "Rwanda flag",
    "Saint Kitts and Nevis flag", "Saint Lucia flag", "Saint Vincent and the Grenadines flag", "Samoa flag", "San Marino flag",
    "Sao Tome and Principe flag", "Saudi Arabia flag", "Senegal flag", "Serbia flag", "Seychelles flag",
    "Sierra Leone flag", "Singapore flag", "Slovakia flag", "Slovenia flag", "Solomon Islands flag",
    "Somalia flag", "South Africa flag", "South Sudan flag", "Spain flag", "Sri Lanka flag",
    "State of Palestine flag", "Sudan flag", "Suriname flag", "Sweden flag", "Switzerland flag",
    "Syria flag", "Taiwan flag", "Tajikistan flag", "Tanzania flag", "Thailand flag",
    "Togo flag", "Tonga flag", "Trinidad and Tobago flag", "Tunisia flag", "Turkey flag",
    "Turkmenistan flag", "Tuvalu flag", "Uganda flag", "Ukraine flag", "United Arab Emirates flag",
    "United Kingdom flag", "United States of America flag", "Uruguay flag", "Uzbekistan flag", "Vanuatu flag",
    "Vatican City flag", "Venezuela flag", "Vietnam flag", "Yemen flag", "Zambia flag",
    "Zimbabwe flag", "Solid square of color", "Creative flag"
]

# Tokenizer model name for each supported vocabulary size (keyed by VOCAB).
TOKENIZERS = {
    32_000: 'c4_en',
    250_000: 'mc4',
}

# Build the text preprocessing function (max_len=SEQLEN, tokenizer chosen by
# VOCAB) and tokenize every prompt in `texts`.
pp_txt = pp_builder.get_preprocess_fn(f'tokenize(max_len={SEQLEN}, model="{TOKENIZERS[VOCAB]}", eos="sticky", pad_value=1, inkey="text")')
txts = np.array([pp_txt({'text': text})['labels'] for text in texts])
# Apply the model with no image input (None); the second return value is kept
# as the text embeddings `ztxt`, which later cells compare images against.
_, ztxt, out = model.apply({'params': params}, None, txts)

print(txts.shape, ztxt.shape)

# 1. Sifting

In [None]:
#@title Load Random Image
#@markdown You should replace this with your own images and code.
#@markdown It is fine to just directly upload and use your own flags instead
#@markdown of random selection.

##### Random Generation #####
import json
import random

# For testing purposes (kept as comments so the cell does not evaluate or echo
# a stray string literal):
#
# labeled_data_routes_path = "/final_data.json"
# with open(PROJECT_BASE_PATH + labeled_data_routes_path, 'r') as in_fp:
#   final_data = json.load(in_fp)

data_base_path = PROJECT_BASE_PATH + "/data/"

image_paths = []

# Get all of our images in the form of (containing_folder_path, file_path)
# ex. (folder_1/folder_2/, image_name.png)
# Useful for testing versus final data where entry keys are just the local path.
for folder_idx in range(11):
  folder_path = data_base_path + f"{folder_idx}_test/"
  files = os.listdir(folder_path)
  for file_path in files:
    image_paths.append((folder_path, file_path))

# Fail with a readable message instead of random.choice's IndexError.
if not image_paths:
  raise FileNotFoundError(f"No images found in the *_test folders under {data_base_path}")

image_duo = random.choice(image_paths)
# Normalise to 3-channel RGB: image files may be stored as RGBA, palette or
# grayscale, and the next cell feeds np.array(image) straight into preprocessing.
image = PIL.Image.open(image_duo[0] + image_duo[1]).convert("RGB")
#############################

##### Loaded Image #####
# To use a specific image instead of a random one, replace the selection above
# with the two lines below (the result must be assigned to `image`):
#
# image_path = ...
# image = PIL.Image.open(image_path).convert("RGB")
########################

In [None]:
#@title Apply SigLIP

# Pre-process: resize to the model resolution and scale values to [-1, 1];
# `img` is a batch holding the single selected image.
pp_img = pp_builder.get_preprocess_fn(f'resize({RES})|value_range(-1, 1)')
img = np.array([pp_img({'image': np.array(image)})['image']])

# Apply SigLIP: embed the image, then score it against every text embedding
# using the model's temperature `t` and bias `b`.
zimg, _, out = model.apply({'params': params}, img, None)
probs = jax.nn.sigmoid(zimg @ ztxt.T * out['t'] + out['b'])

# Extract prediction: for each image keep every prompt whose probability
# exceeds 0.6; an image with no such prompt is labelled "edge".
predicted_labels = []
corresponding_probs = []
for i in range(len(probs)):
  prob = probs[i]
  indices = np.where(prob > 0.6)[0]
  predicted_labels.append([])
  corresponding_probs.append([])
  # Test the number of matches, not their values: `indices.any()` is False
  # when the only match is index 0, which would mislabel that image as "edge".
  if indices.size > 0:
    for index in indices:
      predicted_labels[i].append(texts[index])
      corresponding_probs[i].append(float(probs[i, index]))
  else:
    predicted_labels[i].append("edge")

# The first kept label of the first (only) image is the prediction.
prediction = predicted_labels[0][0]

# 2. Edge-Case Handling

In [None]:
#@title Run UOD
# Only images that SigLIP left unlabelled (prediction == 'edge') are passed to
# the outlier detector; any other prediction is left untouched.
if prediction == 'edge':
  # Apply UOD
  # NOTE(review): `img` was built as np.array([...]) and so already has a
  # leading batch axis; wrapping it in a list adds another level. Confirm this
  # is the input layout the saved detector expects.
  results = outlier_detector.predict([img],
                                    outlier_type="instance",
                                    return_feature_score=True,
                                    return_instance_score=True)

  # Extract instance score to gauge outlier vs inlier
  selected_threshold = 0.05  # Manual threshold, can change
  outlier = results["data"]["instance_score"][0] >= selected_threshold

  # Make the decision in this case: creative flag versus bug
  prediction = "creative" if outlier else "bug"

In [None]:
# @title Demo Visual (Set-up Code)
from IPython.display import Javascript

# Side length in pixels used for each image thumbnail (and its table cell) in
# the demo output.
DEMO_IMG_SIZE = 96

import base64
import io

def bv2rgb(bv_img):
  """Convert a big_vision-style image array to 8-bit pixel values.

  Applies `x * 127.5 + 127.5` element-wise and casts the result to
  `np.uint8`, so -1 maps to 0 and 1 maps to 255.
  """
  rescaled = bv_img * 127.5 + 127.5
  return rescaled.astype(np.uint8)

def html_img(*, enc_img=None, pixels=None, id=None, size=100, max_size=None, max_height=None, style=""):
  """Return an HTML `<img>` tag with the image embedded as a base64 data URI.

  Args:
    enc_img: Already-encoded image bytes. Takes precedence over `pixels`.
    pixels: Raw pixel array, JPEG-encoded here when `enc_img` is not given.
    id: Optional value for the tag's `id` attribute.
    size: Fixed width and height in pixels. Pass None to fall through to
      `max_size` / `max_height`.
    max_size: Maximum width and height in pixels (used when `size` is None).
    max_height: Fixed height in pixels (used when `size` and `max_size` are
      both None).
    style: Extra CSS placed at the start of the `style` attribute.

  Returns:
    The `<img .../>` tag as a string.

  Raises:
    ValueError: If neither `enc_img` nor `pixels` is provided.
  """
  if enc_img is None and pixels is None:
    raise ValueError("html_img needs either `enc_img` or `pixels`.")

  # Caller-supplied bytes keep the PNG label this function has always used;
  # bytes encoded below are JPEG and are labelled as such.
  mime = "image/png"
  if enc_img is None:
    with io.BytesIO() as buf:
      PIL.Image.fromarray(np.asarray(pixels)).save(buf, format="JPEG")
      enc_img = buf.getvalue()
    mime = "image/jpeg"

  img_data = base64.b64encode(np.ascontiguousarray(enc_img)).decode('ascii')

  id_spec = f'id={id}' if id else ''
  # The first sizing option that is set wins: size, then max_size, then max_height.
  if size is not None:
    style_spec = f'style="{style}; width: {size}px; height: {size}px"'
  elif max_size is not None:
    style_spec = f'style="{style}; width: auto; height: auto; max-width: {max_size}px; max-height: {max_size}px;"'
  elif max_height is not None:
    style_spec = f'style="{style}; object-fit: cover; width: auto; height: {max_height}px;"'
  else:
    style_spec = ''

  return f'<img {id_spec} {style_spec} src="data:{mime};base64,{img_data}"/>'


def make_table(zimg, ztxt, out, imgs=None):
  """Render the demo output: a bias slider and one thumbnail per image.

  Publishes the tempered logits and a colour map to the page as JavaScript
  variables, emits the HTML (slider plus a table row of thumbnails), and
  wires the slider to a JavaScript `update()` function.

  Args:
    zimg: Image embeddings, as returned by `model.apply`.
    ztxt: Text embeddings, as returned by `model.apply`.
    out: Mapping with the model's temperature under 't' and bias under 'b'.
    imgs: Iterable of preprocessed images to show as thumbnails. Defaults to
      the notebook-level batch `img`.
  """
  if imgs is None:
    # The notebook only ever defines the batch `img`; the global `imgs` this
    # function used to read does not exist, which made every call fail.
    imgs = img

  # The default learnable bias is a little conservative. Play around with it!
  t, b = out['t'].item(), out['b'].item()
  tempered_logits = zimg @ ztxt.T * t
  # `probs` and `cell` below are only needed by the per-text rows, which are
  # currently disabled (see the commented-out loop further down).
  probs = 1 / (1 + np.exp(-tempered_logits - b))
  publish.javascript(f"var logits = {tempered_logits.tolist()};")

  def color(p):
    return mpl.colors.rgb2hex(mpl.cm.Greens(p / 2)) if p >= 0.01 else "transparent"

  publish.javascript(f"var cmap = {[color(x) for x in np.linspace(0, 1, 50)]};")
  def cell(x, iimg, itxt):
    return f"<td id=td_{iimg}_{itxt} style=background-color:{color(x)} class=pct><pre id=p_{iimg}_{itxt}>{x * 100:>4.0f}%</pre>"

  html = f'''
  <p>
  <label for=b>Bias value:</label>
  <input id=b type=range min=-15 max=0 step=0.1 name=b value={b} style=vertical-align:middle>
  <output id=value></output>
  </p>
  '''

  html += "<table>\n"
  html += "<tr>"
  html += "".join([f"<td style='width:{DEMO_IMG_SIZE}px;line-height:0'>" + html_img(pixels=bv2rgb(one_img), size=DEMO_IMG_SIZE) for one_img in imgs])
  html += "<td>"
  # for itxt, txt in enumerate(texts):
  #   html += f"<tr>" + "".join([cell(probs[iimg, itxt], iimg, itxt) for iimg in range(len(imgs))]) + f"<td class=txt>{txt}"

  publish.css(r"""
  table {
    border-collapse: collapse;
  }

  tr {
    border: 1px transparent;
  }

  tr:nth-child(odd) {
    background-color: #F5F5F5;
  }

  tr:hover {
    background-color: lightyellow;
    border: 1px solid black;
  }

  td.pct {
    text-align: center;
  }
  """)
  publish.html(html)

  # JS code to compute and write all probs from the logits.
  # Cells that were not rendered (the per-text rows are disabled above) are
  # skipped instead of being dereferenced as null.
  display(Javascript('''
  function update(b) {
    for(var iimg = 0; iimg < logits.length; iimg++) {
      for(var itxt = 0; itxt < logits[iimg].length; itxt++) {
        const el = document.getElementById(`p_${iimg}_${itxt}`);
        const td = document.getElementById(`td_${iimg}_${itxt}`);
        if (!el || !td) continue;

        const p = Math.round(100 / (1 + Math.exp(-logits[iimg][itxt] - b)));
        const pad = p < 10.0 ? '  ' : p < 100.0 ? ' ' : ''
        el.innerHTML = pad + (p).toFixed(0) + '%';

        const c = cmap[Math.round(p / 100 * (cmap.length - 1))];
        td.style.backgroundColor = c;
      }
    }
  }
  '''))

  # JS code to connect the bias value slider
  display(Javascript('''
  const value = document.querySelector("#value");
  const input = document.querySelector("#b");
  value.textContent = input.value;
  input.addEventListener("input", (event) => {
    value.textContent = event.target.value;
    update(event.target.value);
  });
  '''))

  # Make the cell output as large as the table to avoid annoying scrollbars.
  display(Javascript(f'update({b})'))
  display(Javascript('google.colab.output.resizeIframeToContent()'))

In [None]:
# Demo Visual
# Render the demo output for the current image, then print the final label
# (a SigLIP label, or "creative"/"bug" if the edge-case cell ran).
make_table(zimg, ztxt, out)
print(prediction)

# 3. Flag Feedback
Set the variables in the cells below to request feedback on the flag, based on the classification result and the code snippet that drew it.

In [None]:
#@title Define API calls
# `userdata` has to be imported before it can be used; the key is read from it
# under the name "openai_api_key" rather than being hard-coded here.
from google.colab import userdata

openai_api_key = userdata.get("openai_api_key")

def get_headers(api_key):
  """Build the HTTP headers for a JSON request authenticated with `api_key`.

  Returns a dict with a JSON `Content-Type` and a bearer `Authorization`.
  """
  bearer_token = "Bearer {}".format(api_key)
  return {
      "Content-Type": "application/json",
      "Authorization": bearer_token,
  }

def get_classification_payload(flag_result, flag_code_snippet,
                               model="gpt-4-vision-preview", max_tokens=300):
    """Build the chat-completions request body asking for flag feedback.

    The prompt is a single user message: one worked example (Indian flag)
    followed by the actual question for `flag_result` and `flag_code_snippet`.
    The message is text only; no image is attached.

    Args:
        flag_result: Label of the flag (the classification result).
        flag_code_snippet: Source code that drew the flag.
        model: Model name placed in the payload. Defaults to the value this
            function has always used.
        max_tokens: Cap on the length of the reply. Defaults to 300.

    Returns:
        A dict ready to be sent as the JSON body of the request.
    """
    message = f"""Q: This is an image of the Indian Flag created by the following code snippet. How could they improve it:
    A:
    Saffron (not orange): The top band of the Indian flag is saffron in color, not pure orange. You might need to find the right RGB code for saffron.

    Wheel (Ashoka Chakra): In the center of the white band, there is a navy blue wheel with 24 spokes, known as the Ashoka Chakra. In your current code, you've drawn a blue oval, but for an accurate representation, you need a circular shape (not an oval), and it should have 24 spokes.

    Dimensions: The actual Indian flag has a proportion of 2:3 (height:width), which seems to be followed in your current code.

    This is an image of the {flag_result} created by the following code snippet {flag_code_snippet}. How could they improve it:
    """
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": message
                    }
                ]
            }
        ],
        "max_tokens": max_tokens
    }

    return payload

def make_classification_request(flag_result, flag_code_snippet, api_key=None):
  """Send the flag-feedback prompt to the OpenAI chat-completions endpoint.

  Args:
    flag_result: Label of the flag (the classification result).
    flag_code_snippet: Source code that drew the flag.
    api_key: OpenAI API key. When omitted, the notebook-level
      `openai_api_key` is used.

  Returns:
    The decoded JSON body of the HTTP response.
  """
  # `requests` is used here but is not imported anywhere else in the notebook.
  import requests

  if api_key is None:
    api_key = openai_api_key

  # A timeout keeps the cell from hanging indefinitely if the server never answers.
  response = requests.post("https://api.openai.com/v1/chat/completions",
                            headers=get_headers(api_key),
                            json=get_classification_payload(flag_result, flag_code_snippet),
                            timeout=60)

  return response.json()

In [None]:
#@title Request for feedback
flag_classification = prediction
# Paste the code that drew the flag between the triple quotes.
code_snippet = """

"""

# The API key is passed explicitly as the helper's third argument.
response = make_classification_request(flag_classification, code_snippet, openai_api_key)