## Logging in

In [None]:
# Project-wide settings used by the gcloud commands and vertexai init() calls below.
PROJECT_ID="rm-imagen-vertex" ## Google Cloud project ID; replace with your own project
BILLING_ACCOUNT="ur billing account num"   # placeholder; find yours with: gcloud billing accounts list
LOCATION="us-central1" # region passed to vertexai init(); this is where Imagen serves

In [None]:
# If logging in fails, uncomment and run the gcloud commands below to check the setup.

#! gcloud billing accounts list
#! gcloud auth login


from google.colab import auth
auth.authenticate_user()   # opens an interactive Google login prompt

In [None]:
# Set the active gcloud project to PROJECT_ID
!gcloud config set project "$PROJECT_ID"

# Link the project to the billing account stored in BILLING_ACCOUNT
!gcloud billing projects link "$PROJECT_ID" --billing-account="$BILLING_ACCOUNT"

In [None]:
# Enable the Vertex AI API service (aiplatform.googleapis.com)
! gcloud services enable aiplatform.googleapis.com

## Getting started
* Connect to Gemini and Google Cloud

In [None]:
# !pip install -U google-generativeai
import google.generativeai as genai

from vertexai import init
from vertexai.generative_models import GenerativeModel ## for the prompts
from vertexai.preview.vision_models import ImageGenerationModel ## prompts->imagen

# 
import random
import time
import json
from typing import List, Dict, Any



# Initialise the Vertex AI SDK with the project ID and region defined in the settings cell
init(project=PROJECT_ID, location=LOCATION)

In [None]:
# Switching to the google-generativeai client for the Gemini text model.
# NOTE(review): placeholder value — avoid committing a real API key with the notebook.
GEMINI_API_KEY="your API key"

# Initialise Vertex AI (repeats the init call made in the imports cell)
init(project=PROJECT_ID, location=LOCATION)



# Authenticate the google-generativeai client with the API key
genai.configure(api_key=GEMINI_API_KEY)

# Module-level Gemini model handle, used by classify_descriptors_with_gemini()
model = genai.GenerativeModel("gemini-2.0-flash")  # or "gemini-2.0-pro"

In [None]:
# Smoke test: send a short prompt and print the reply text to confirm the connection works
resp = model.generate_content("hello gemini ?")
print(resp.text) # check connection

## Step 1: Gemini to sort descriptors -> JSON
* Uses `extracted_descriptors2.csv`
* Gemini sorts the extracted descriptors into the best-matching category
    *  e.g. 'sharp' goes to 'mouth_teeth'
    *  resulting in JSON objects
* Very detailed instructions are included in the prompt sent to Gemini

In [None]:
# pandas is used here and in generate_from_desc1() but is not imported in any
# other cell of this notebook, so import it before first use.
import pandas as pd

# Load the extracted descriptors table that the later steps classify.
df_2 = pd.read_csv("extracted_descriptors2.csv")

In [None]:
#------1-------

# setting out what needs to be matched
# descriptive buckets
# Descriptive buckets that every descriptor is sorted into, in prompt order.
BUCKETS = [
    "colour",
    "patterns",
    "shape",
    "mouth_teeth",
    "special_features",
    "personality",
]

# Response outline handed to Gemini: one array of strings per bucket, plus an
# "unused_terms" array for leftovers. Only the bucket keys are required.
RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        key: {"type": "array", "items": {"type": "string"}}
        for key in [*BUCKETS, "unused_terms"]
    },
    "required": BUCKETS,
}




# System prompt sent to Gemini ahead of the descriptors to classify.
# The fallback bucket is called "unused_terms" throughout, so the rules agree
# with the key list at the end of the prompt and with the response schema
# (which has no "uncategorized" key).
SYSTEM_RULES = """Focus on matching the descriptors to the physical attributes.
You are categorizing short fish descriptors into predefined visual/style buckets for image generation.
Work only with the exact descriptors provided — do not invent new descriptors or add concepts.

RULES:
- Each descriptor must go into exactly one bucket, or "unused_terms" if it does not fit.
- If a descriptor matches multiple buckets, assign it to the single most specific and relevant one.
- Prioritize matches in this order if ambiguous: colour > patterns > shape > mouth_teeth > special_features > personality.

BUCKETS:
- **colour**: words about hue/shade/tint, e.g. grey, silver, olive, golden, blue, dark, pale, black, white, red, brown, yellow, green.
- **patterns**: surface markings, e.g. striped, spotted, mottled, banded, blotched, patterned.
- **shape**: global body form/outline, e.g. elongated, flat, torpedo, round, slender, bulky, compressed, streamlined.
- **mouth_teeth**: anything about teeth, jaws, sharpness, or biting structures, e.g. teeth, fangs, jaws, jaw, sharp, barbed, toothed, mouth.
- **special_features**: anatomical structures or standout traits, e.g. spines, whiskers, barbels, fins, armour/armor, bioluminescent, tail, gills, scales, stinger, plates.
- **personality**: behavior or perceived character, e.g. ferocious, timid, aggressive, curious, docile, fearsome, elusive, mysterious, bold, shy.
- **unused_terms**: fallback bucket if no category is clearly appropriate.

IMPORTANT:
- Never generate or invent descriptors beyond those provided.
- Always return descriptors grouped under their bucket labels.
- Keep descriptors exactly as given (no rewriting).
- Ensure that descriptors in the same bucket do not share similar words.

Return only JSON following the provided schema. Use lower-case single words or short phrases; no punctuation .
If a bucket has nothing, return an empty array for that bucket .
Also return 'unused_terms' with any descriptors you did not use.

Use exactly these keys: colour, patterns, shape, mouth_teeth, special_features, personality, unused_terms.
Map terms to buckets in JSON format strictly.


"""
#- Only use words that appear in the provided descriptors (and optional wiki terms); DO NOT invent new concepts. (og)

## Function that sends the descriptors to Gemini and returns the sorted buckets



def _clean_terms(raw_terms):
    """Return the non-blank string entries of *raw_terms*, stripped and lower-cased."""
    if not isinstance(raw_terms, list):
        # Missing key or JSON null: treat as "no terms" instead of crashing.
        return []
    cleaned = []
    for term in raw_terms:
        if not isinstance(term, str):
            continue
        term = term.strip().lower()
        if term:
            cleaned.append(term)
    return cleaned


def classify_descriptors_with_gemini(transcript_top, wiki_top, max_items_per_bucket=2):
    """Ask Gemini to sort descriptors into the predefined buckets.

    Uses the module-level ``model``, ``SYSTEM_RULES``, ``RESPONSE_SCHEMA`` and
    ``BUCKETS``.

    Args:
        transcript_top: Descriptors taken from the transcript; a falsy value
            is sent as an empty list.
        wiki_top: Descriptors taken from the wiki text; a falsy value is sent
            as an empty list.
        max_items_per_bucket: Cap mentioned in the prompt and enforced on
            every bucket of the parsed reply.

    Returns:
        dict with one list per key in ``BUCKETS`` (lower-cased, stripped,
        capped) plus an uncapped ``"unused_terms"`` list.

    Raises:
        json.JSONDecodeError: If the reply text is not valid JSON.
        ValueError: If the reply is valid JSON but not an object.
    """
    input_desc = {
        "transcript_top": transcript_top or [],
        "wiki_top": wiki_top or [],
    }

    resp = model.generate_content(
        [
            SYSTEM_RULES,
            f"Max {max_items_per_bucket} items per bucket. Map these terms:\n{input_desc}"
        ],
        generation_config={
            "temperature": 0.2,
            "top_p": 0.95,
            "max_output_tokens": 1024,
            "response_mime_type": "application/json",
            "response_schema": RESPONSE_SCHEMA,  # a plain dict is passed as the schema
        },
    )

    data = json.loads(resp.text)
    if not isinstance(data, dict):
        raise ValueError(
            f"Expected a JSON object from Gemini, got {type(data).__name__}"
        )

    # Normalise every bucket and cap it; leftovers are normalised but not capped.
    for bucket in BUCKETS:
        data[bucket] = _clean_terms(data.get(bucket))[:max_items_per_bucket]
    data["unused_terms"] = _clean_terms(data.get("unused_terms"))

    # Pause between calls to avoid quota errors.
    delay = random.uniform(2, 4)
    print(f"Sleeping for {delay:.1f} seconds (rate limiting)...")
    time.sleep(delay)

    return data

## Step 2: From the JSONs, build standard prompts
* Includes a variety of phrasings, core instructions and styles to add some variation
  * Found very useful for avoiding generic, same-looking fish

In [None]:
def phrase_bucket(name, items):
    """Render one bucket's descriptors as a short prompt fragment.

    For a known bucket name, one of its phrasing templates is picked at random
    and filled with the comma-joined items. Any other name falls back to
    ``"<name>: <items>"``.
    """
    templates_by_bucket = {
        "colour": ["with scales in shades of {}", "coloured {}", "a body of {} hues"],
        "patterns": ["showing {} markings", "{} body pattern", "distinctly {}"],
        "shape": ["with a {} form", "shaped {}", "{} body"],
        "mouth_teeth": ["{}", "with {}"],
        "special_features": ["notable for {}", "showing {}", "characterised by {}"],
        "personality": ["giving off a {} vibe", "appearing {}", "a {} presence"],
    }

    options = templates_by_bucket.get(name)
    if options is not None:
        chosen = random.choice(options)
        return chosen.format(", ".join(items))

    # Unknown bucket: plain label followed by the items.
    return f"{name}: {', '.join(items)}"

# Opening clauses for the image prompt; buckets_to_prompt() picks one at random.
core = [
    "illustration of a giant freshwater fish",
    "cinematic portrait of a mysterious river fish",
    "detailed biological illustration of a freshwater predator",
    "dynamic scene of a strange predatory fish lurking underwater",
    # No trailing space: the prompt builder appends ", " right after this text.
    "image of a giant freshwater predator fish",
]


# Rendering-style suffixes for the image prompt; buckets_to_prompt() picks one at random.
style = [
    "hyper-realistic, cinematic, dramatic lighting, murky river background",
    "true to life, anatomical accuracy, dark waters background",
    "artistic rendering, surreal, moody atmosphere, deep shadows",
    "photorealistic, high detail on textures and skin, full body",
]

def buckets_to_prompt(buckets: dict) -> str:
    """Assemble one image-ready prompt from classified descriptor buckets.

    Each non-empty bucket is phrased with phrase_bucket() in a fixed order,
    then combined with a randomly chosen opening from ``core`` and a randomly
    chosen suffix from ``style``. With no non-empty buckets the detail part
    reads "neutral appearance".
    """
    ordered_keys = (
        "colour",
        "patterns",
        "shape",
        "mouth_teeth",
        "special_features",
        "personality",
    )
    fragments = [
        phrase_bucket(key, buckets[key])
        for key in ordered_keys
        if buckets.get(key)
    ]

    opening = random.choice(core)
    rendering = random.choice(style)

    if fragments:
        details = "; ".join(fragments)
    else:
        details = "neutral appearance"

    return f"{opening}, {details}. Style: {rendering}"


In [None]:
# generate the prompts -- wrapper

def generate_from_desc1(df: pd.DataFrame, max_items_per_bucket=2) -> pd.DataFrame:
    """Classify each row's descriptors and build image prompts from them.

    For every row, the ``transcript_top`` and ``wiki_top`` descriptors are
    classified separately with classify_descriptors_with_gemini() and turned
    into prompts with buckets_to_prompt().

    Args:
        df: Table with ``transcript_top`` and ``wiki_top`` columns.
        max_items_per_bucket: Cap forwarded to the classifier.

    Returns:
        A copy of ``df`` with four new columns: ``transcript_bucket``,
        ``wiki_bucket``, ``transcript_prompt`` and ``wiki_prompt``.
    """
    df = df.copy()

    transcript_buckets = []
    wiki_buckets = []
    transcript_prompts = []
    wiki_prompts = []

    for _, row in df.iterrows():
        # The classifier's signature is (transcript_top, wiki_top, max_items_per_bucket).
        # Pass the cap by keyword so it is not mistaken for wiki_top, and send
        # each source in its own slot so the prompt labels it correctly.
        tb = classify_descriptors_with_gemini(
            row["transcript_top"], None, max_items_per_bucket=max_items_per_bucket
        )
        wb = classify_descriptors_with_gemini(
            None, row["wiki_top"], max_items_per_bucket=max_items_per_bucket
        )

        transcript_buckets.append(tb)
        wiki_buckets.append(wb)
        transcript_prompts.append(buckets_to_prompt(tb))
        wiki_prompts.append(buckets_to_prompt(wb))

        # No sleep needed here: the classifier pauses after every call.

    df["transcript_bucket"] = transcript_buckets
    df["wiki_bucket"] = wiki_buckets
    df["transcript_prompt"] = transcript_prompts
    df["wiki_prompt"] = wiki_prompts

    print('Complete!')

    return df

# in use
#prompts_df2 = generate_from_desc1(df_2, max_items_per_bucket=2)

## Step 3: generate the images!

In [None]:
# generates and saves the images

def one_image_with_imagen(prompt, prefix, model_name="imagen-3.0-generate-001"):
    """Generate a single image for *prompt* and write it to ``<prefix>.png``.

    The model named by ``model_name`` is tried first; on any exception a
    warning is printed and "imagegeneration@002" is used instead.

    Returns:
        The path of the saved PNG, or None when the response holds no image.

    NOTE(review): the image bytes are read from the private ``_image_bytes``
    attribute of the returned image object — confirm this is still available
    in the installed SDK version.
    """
    try:
        # Preferred model
        generator = ImageGenerationModel.from_pretrained(model_name)
        resp = generator.generate_images(prompt=prompt, number_of_images=1)
    except Exception as e:
        print(f"[warn] {model_name} failed ({e}). Trying Imagen 2...")
        generator = ImageGenerationModel("imagegeneration@002")
        resp = generator.generate_images(prompt=prompt)

    # --- Response exposing a non-empty `images` collection ---
    images = getattr(resp, "images", None)
    if images:
        first = images[0]
        out = f"{prefix}.png"
        with open(out, "wb") as f:
            f.write(first._image_bytes)
        print(f"✅ Saved image as: {out}")
        return out

    # --- Response exposing `candidates` with inline image parts ---
    if hasattr(resp, "candidates"):
        for part in resp.candidates[0].content.parts:
            if not hasattr(part, "inline_data"):
                continue
            mime_type = getattr(part.inline_data, "mime_type", "")
            if not mime_type.startswith("image/"):
                continue
            payload = part.inline_data.data
            out = f"{prefix}.png"
            with open(out, "wb") as f:
                f.write(payload)
            print(f"✅ Saved image as: {out}")
            return out

    print("⚠️ No image returned.")
    return None


# uses promptsdf

def df_prompt_imagen(df):
    """Generate a transcript image and a wiki image for every row of *df*.

    The frame must contain the columns ``transcript_prompt``, ``wiki_prompt``,
    ``english_name`` and ``latin_name``. Transcript images are saved as
    ``<english_name>_transcript.png`` and wiki images as
    ``<latin_name>_wiki.png``, with names lower-cased and spaces replaced by
    underscores.

    Returns:
        A copy of ``df`` with ``transcript_image`` and ``wiki_image`` columns
        holding the saved file path, or None where no image was produced.
    """

    def pause():
        # Announce the wait before it starts, then sleep to avoid request spikes.
        delay = random.uniform(10, 20)
        print(f"Sleeping for {delay:.1f} seconds...")
        time.sleep(delay)

    df = df.copy()

    transcript_images = []
    wiki_images = []

    for _, row in df.iterrows():
        english_name = str(row["english_name"]).replace(" ", "_").lower()
        latin_name = str(row["latin_name"]).replace(" ", "_").lower()

        ti = one_image_with_imagen(
            row["transcript_prompt"], prefix=f"{english_name}_transcript"
        )
        transcript_images.append(ti)
        # Only report success when a file was actually written.
        if ti is not None:
            print(f"{english_name} image generated!")
        pause()

        wi = one_image_with_imagen(row["wiki_prompt"], prefix=f"{latin_name}_wiki")
        wiki_images.append(wi)
        if wi is not None:
            print(f"{latin_name} image generated!")
        pause()

    df["transcript_image"] = transcript_images
    df["wiki_image"] = wiki_images

    return df

In [None]:
#prompts_df = prompts_df2.copy()
#prompts_df = df_prompt_imagen(prompts_df)