In [None]:

!pip -q install git+https://github.com/openai/CLIP.git torch torchvision --upgrade

import os, io, math
import numpy as np
import torch, clip
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
from google.colab import files
import pandas as pd

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP backbone; `preprocess` is the image transform returned alongside the model.
MODEL_NAME = "ViT-L/14@336px"
model, preprocess = clip.load(MODEL_NAME, device=device)
# Learned logit scale (exp of the stored parameter) as a Python float; the
# image/text similarities are multiplied by it before the softmax.
LOGIT_SCALE = model.logit_scale.exp().detach().cpu().item()

# Softmax temperature: logits are divided by TEMP (values > 1 soften the distribution).
TEMP = 1.12
# Bucket reduction used by bucket_scores_from_feat: "max", "mean_prob" or "mean_feat".
DECISION_MODE = "max"
# A prediction is flagged uncertain only when confidence < UNCERT_THRESH
# AND the bucket-score margin < MARGIN_MIN (see decide_final_tta).
UNCERT_THRESH = 0.56
MARGIN_MIN = 0.05
# Additive shift applied to the mixed fist probability before the 0.5 threshold.
BIAS_DELTA = 0.035
# Mixing weights for the prompt-bucket probability and the directional vote.
W_BUCKET = 0.65
W_DIR    = 0.35
# Number of highest-probability prompts per bucket printed for debugging.
TOPK_SHOW = 3

# Prompt templates; the "{}" placeholder is filled with each phrase by build_prompts.
TEMPLATES_INFANT = [
    "close-up of a baby's hand, {}",
    "a photo of a baby's hand, {}",
]

# Phrases describing a clenched hand (the "fist" bucket).
FIST_PHRASES = [
    "fist",
    "fingers curled",
    "thumb tucked",
]

# Phrases describing an open hand (the "open" bucket).
OPEN_PHRASES = [
    "open",
    "fingers extended",
    "palm visible",
]

def build_prompts(phrases, templates):
    """Expand every phrase into every template and drop duplicate prompts.

    Prompts are generated phrase-major (all templates for the first phrase,
    then the next phrase, ...). Two prompts count as duplicates when they are
    equal after lower-casing and stripping surrounding whitespace; the first
    occurrence is kept with its original spelling.

    Args:
        phrases: iterable of strings substituted into the templates.
        templates: iterable of format strings with one "{}" placeholder.

    Returns:
        A list of unique prompt strings in generation order.
    """
    candidates = (tpl.format(phrase) for phrase in phrases for tpl in templates)
    unique_prompts = []
    keys_seen = set()
    for prompt in candidates:
        key = prompt.lower().strip()
        if key in keys_seen:
            continue
        keys_seen.add(key)
        unique_prompts.append(prompt)
    return unique_prompts

# Full prompt list for each bucket: every phrase in every template, de-duplicated.
PROMPTS_FIST = build_prompts(FIST_PHRASES, TEMPLATES_INFANT)
PROMPTS_OPEN = build_prompts(OPEN_PHRASES, TEMPLATES_INFANT)

@torch.no_grad()
def encode_texts(texts):
    """Encode *texts* with the CLIP text encoder and L2-normalize each embedding.

    Tokenization uses ``truncate=True``, so over-long prompts are cut instead of
    raising. Returns one normalized embedding row per input text.
    """
    tokens = clip.tokenize(texts, truncate=True)
    emb = model.encode_text(tokens.to(device))
    return emb / emb.norm(dim=-1, keepdim=True)

# Unit-normalized text embeddings, one row per prompt of each bucket.
text_fist = encode_texts(PROMPTS_FIST)
text_open = encode_texts(PROMPTS_OPEN)

# Per-bucket centroid of the prompt embeddings, re-normalized to unit length
# (used by the "mean_feat" decision mode).
text_fist_mean = text_fist.mean(dim=0); text_fist_mean /= text_fist_mean.norm()
text_open_mean = text_open.mean(dim=0); text_open_mean /= text_open_mean.norm()

@torch.no_grad()
def text_mean(texts):
    """Return the mean of the unit-normalized CLIP text embeddings of *texts*.

    The mean itself is NOT re-normalized to unit length.
    """
    # encode_texts already tokenizes (with truncation), encodes and normalizes.
    return encode_texts(texts).mean(dim=0)

# Anchor prompt for a hand with no pose described.
NEUTRAL_INFANT = ["a baby's hand"]

# Bucket centroids for the directional vote (mean of unit-normalized prompt
# embeddings, not re-normalized). They come from the same PROMPTS_FIST /
# PROMPTS_OPEN embeddings used for prompt-bucket scoring, so editing
# FIST_PHRASES / OPEN_PHRASES keeps both scoring paths consistent and no extra
# text-encoder passes are needed.
t_fist_ctx = text_fist.mean(dim=0)
t_open_ctx = text_open.mean(dim=0)
t_neutral  = text_mean(NEUTRAL_INFANT)

# Fist-minus-open axis. Subtracting t_neutral from both centroids would cancel
# exactly, (a - n) - (b - n) == a - b, so the neutral anchor does not influence
# the axis; t_neutral is still computed so the name stays available.
dir_vec = t_fist_ctx - t_open_ctx
dir_vec = dir_vec / dir_vec.norm()

def directional_vote(img_feat, k=6.0, direction=None):
    """Turn the projection of *img_feat* onto a direction into a probability.

    Args:
        img_feat: 1-D image embedding tensor (presumably unit-normalized, as
            returned by embed_image — confirm for other callers).
        k: steepness of the logistic applied to the projection.
        direction: 1-D tensor to project onto; defaults to the module-level
            fist-minus-open axis ``dir_vec``.

    Returns:
        ``sigmoid(k * <img_feat, direction>)`` as a Python float in [0, 1].
        With the default axis, values above 0.5 lean towards "fist".
    """
    vec = dir_vec if direction is None else direction
    z = k * float((img_feat @ vec).item())
    # Numerically stable logistic: never exponentiate a large positive number,
    # so a large |k * s| cannot raise OverflowError.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def center_crop_scale(img: Image.Image, scale=0.92):
    """Return a centred crop whose sides are *scale* times those of *img*.

    The crop size is clamped to at least 1 pixel and at most the image size,
    so tiny images (or tiny scales) never yield an empty crop, and scale > 1
    cannot request a region outside the image.

    Args:
        img: source PIL image.
        scale: fraction of width and height to keep (truncated to whole pixels).

    Returns:
        The cropped PIL image.
    """
    W, H = img.size
    w = min(W, max(1, int(W * scale)))
    h = min(H, max(1, int(H * scale)))
    x0 = (W - w) // 2
    y0 = (H - h) // 2
    return img.crop((x0, y0, x0 + w, y0 + h))

def tta_images(pil_img):
    """Build the test-time-augmentation views of *pil_img*.

    Order: the input image itself (same object), a mirrored copy, +12 and -12
    degree bicubic rotations that keep the original canvas size, then 95 % and
    90 % centre crops.
    """
    views = [pil_img, ImageOps.mirror(pil_img)]
    views += [
        pil_img.rotate(angle, resample=Image.BICUBIC, expand=False)
        for angle in (12, -12)
    ]
    views += [center_crop_scale(pil_img, frac) for frac in (0.95, 0.90)]
    return views

@torch.no_grad()
def embed_image(pil_img):
    """Embed a single PIL image with the CLIP image encoder.

    The image is run through ``preprocess`` as a batch of one; the returned
    1-D embedding is L2-normalized.
    """
    batch = preprocess(pil_img)[None].to(device)
    emb = model.encode_image(batch)[0]
    return emb / emb.norm(dim=-1, keepdim=True)

@torch.no_grad()
def bucket_scores_from_feat(img_feat, mode="max", temp=1.12):
    """Score an image embedding against the fist and open prompt buckets.

    Logits are ``(img_feat @ text) * LOGIT_SCALE / temp``; with a
    unit-normalized ``img_feat`` (as from embed_image) the dot product is a
    cosine similarity.

    Modes:
        "max"       -- one softmax over all fist + open prompts; each bucket
                       scores its highest prompt probability.
        "mean_prob" -- same softmax; each bucket scores its mean probability.
        "mean_feat" -- two-way softmax between the bucket centroid embeddings.

    Returns:
        ``(fist_score, open_score)`` as Python floats.

    Raises:
        ValueError: for any other *mode*.
    """
    if mode == "mean_feat":
        pair_logits = torch.stack([
            (img_feat @ text_fist_mean) * LOGIT_SCALE / temp,
            (img_feat @ text_open_mean) * LOGIT_SCALE / temp,
        ], dim=0)
        pair_probs = pair_logits.softmax(dim=-1).detach().cpu().numpy()
        return float(pair_probs[0]), float(pair_probs[1])

    if mode not in ("max", "mean_prob"):
        raise ValueError("DECISION_MODE csak: 'max', 'mean_prob', 'mean_feat' lehet.")

    joint_logits = torch.cat([
        (img_feat @ text_fist.T) * LOGIT_SCALE / temp,
        (img_feat @ text_open.T) * LOGIT_SCALE / temp,
    ], dim=0)
    joint_probs = joint_logits.softmax(dim=-1).detach().cpu().numpy()
    n_fist = len(PROMPTS_FIST)
    fist_probs, open_probs = joint_probs[:n_fist], joint_probs[n_fist:]
    reduce = np.max if mode == "max" else np.mean
    return float(reduce(fist_probs)), float(reduce(open_probs))

@torch.no_grad()
def _top_prompts(img_feat, temp, k):
    """Return the *k* most probable ``(prompt, prob)`` pairs for each bucket.

    Probabilities come from one softmax over all fist + open prompts, the same
    computation as the "max" / "mean_prob" modes of bucket_scores_from_feat.
    """
    logits_fist = (img_feat @ text_fist.T) * LOGIT_SCALE / temp
    logits_open = (img_feat @ text_open.T) * LOGIT_SCALE / temp
    probs_all = torch.cat([logits_fist, logits_open], dim=0).softmax(dim=-1).detach().cpu().numpy()
    n_fist = len(PROMPTS_FIST)
    pf, po = probs_all[:n_fist], probs_all[n_fist:]
    top_f = [(PROMPTS_FIST[i], float(pf[i])) for i in pf.argsort()[::-1][:k]]
    top_o = [(PROMPTS_OPEN[i], float(po[i])) for i in po.argsort()[::-1][:k]]
    return top_f, top_o


def _trimmed_blend(arr, lo_q=0.20, hi_q=0.90, w_trim=0.7, w_max=0.3):
    """Blend a trimmed mean of *arr* with its maximum.

    The trimmed mean averages the sorted elements in
    ``[floor(n * lo_q), ceil(n * hi_q))``; if that slice is empty it falls
    back to the plain mean. Result: ``w_trim * trimmed_mean + w_max * max``.
    """
    arr = np.asarray(arr, dtype=float)
    arr_sorted = np.sort(arr)
    n = len(arr_sorted)
    lo = int(np.floor(n * lo_q))
    hi = int(np.ceil(n * hi_q))
    tm = float(arr.mean()) if hi <= lo else float(arr_sorted[lo:hi].mean())
    return w_trim * tm + w_max * float(arr.max())


@torch.no_grad()
def decide_final_tta(pil_img):
    """Classify *pil_img* as fist / not fist using TTA plus a directional vote.

    Steps:
      1. Score every view from tta_images with bucket_scores_from_feat and
         aggregate each bucket's scores with _trimmed_blend.
      2. p_bucket_fist = score_fist / (score_fist + score_open).
      3. Mix with the directional vote (W_BUCKET / W_DIR), add BIAS_DELTA,
         clip to [0, 1] and predict "fist" when the result is >= 0.5.

    Returns:
        ``(pred, conf, dbg)``: the label string (suffixed with " (BIZONYTALAN)"
        when flagged uncertain), the confidence ``max(p, 1 - p)``, and a dict
        of intermediate scores used for display and CSV export.
    """
    # Embed the un-augmented image once. It feeds the directional vote and the
    # debug top-k listing, and is reused for any TTA view that is the input
    # object itself, avoiding a second forward pass through the image encoder.
    f0 = embed_image(pil_img)

    pairs = []
    for aug in tta_images(pil_img):
        feat = f0 if aug is pil_img else embed_image(aug)
        pairs.append(bucket_scores_from_feat(feat, mode=DECISION_MODE, temp=TEMP))

    # Per-prompt breakdown only exists for the joint-softmax modes.
    dbg_top_f, dbg_top_o = None, None
    if DECISION_MODE in ("max", "mean_prob"):
        dbg_top_f, dbg_top_o = _top_prompts(f0, TEMP, TOPK_SHOW)

    score_fist = _trimmed_blend([sf for sf, _ in pairs])
    score_open = _trimmed_blend([so for _, so in pairs])

    p_dir_fist = directional_vote(f0, k=6.0)

    # Avoid dividing by ~0 when both bucket scores vanish.
    total = score_fist + score_open
    denom = total if total > 1e-8 else 1.0
    p_bucket_fist = score_fist / denom

    p_fist = W_BUCKET * p_bucket_fist + W_DIR * p_dir_fist
    p_fist_eff = min(max(p_fist + BIAS_DELTA, 0.0), 1.0)

    pred = "ÖKÖLBE SZORÍTVA" if p_fist_eff >= 0.5 else "NINCS ÖKÖLBE SZORÍTVA"
    conf = max(p_fist_eff, 1.0 - p_fist_eff)

    margin = abs(score_fist - score_open)
    # NOTE(review): BOTH a low confidence AND a small raw bucket margin are
    # required to flag the result as uncertain; confirm `and` (not `or`) is
    # the intended rule.
    uncertain = (conf < UNCERT_THRESH) and (margin < MARGIN_MIN)

    if uncertain:
        pred = f"{pred} (BIZONYTALAN)"

    dbg = {
        "top_fist": dbg_top_f or [("n/a", score_fist)],
        "top_open": dbg_top_o or [("n/a", score_open)],
        "score_fist": score_fist, "score_open": score_open,
        "p_bucket_fist": float(p_bucket_fist),
        "p_dir_fist": float(p_dir_fist),
        "p_fist_mix": float(p_fist),
        "p_fist_eff": float(p_fist_eff),
        "margin": float(margin),
        "uncertain": bool(uncertain),
    }
    return pred, conf, dbg

# Interactive driver: upload images, classify each one, plot it with its
# scores, print the top prompts, and export a results table as CSV.
print("Válassz ki egy vagy több képet (pl. .jpg, .png).")
uploaded = files.upload()
os.makedirs("infant_hand_images", exist_ok=True)

rows = []
for fname, data in uploaded.items():
    # Keep only the final path component so an uploaded name cannot write
    # outside the target directory.
    path = os.path.join("infant_hand_images", os.path.basename(fname))
    with open(path, "wb") as f:
        f.write(data)

    try:
        img = Image.open(io.BytesIO(data))
        # Image.open does not apply the EXIF Orientation tag, which many
        # cameras use instead of rotating the pixels; apply it so CLIP sees
        # the hand upright. Best effort: keep the raw orientation on failure.
        try:
            img = ImageOps.exif_transpose(img)
        except Exception as e:
            print(f"Figyelem: az EXIF-tájolás nem alkalmazható ({fname}): {e}")
        img = img.convert("RGB")
    except Exception as e:
        print(f"Hiba a {fname} megnyitásakor: {e}")
        continue

    pred, conf, dbg = decide_final_tta(img)

    plt.figure(figsize=(5,5))
    plt.imshow(img); plt.axis('off')
    plt.title(
        f"{fname}\n{pred} (bizt.: {conf:.2f}) | TEMP={TEMP}, mode={DECISION_MODE}, bias={BIAS_DELTA}\n"
        f"score_fist={dbg['score_fist']:.2f}, score_open={dbg['score_open']:.2f}, "
        f"p_dir={dbg['p_dir_fist']:.2f}, p_mix={dbg['p_fist_mix']:.2f}, margin={dbg['margin']:.2f}"
    )
    plt.show()

    print("FIST – TOP sorok:")
    for s, p in dbg["top_fist"]:
        print(f"  {p:5.3f}  |  {s}")
    print("OPEN – TOP sorok:")
    for s, p in dbg["top_open"]:
        print(f"  {p:5.3f}  |  {s}")
    print("-"*60)

    rows.append({
        "fájl": fname,
        "predikció": pred,
        "biztonság": conf,
        "score_fist": dbg["score_fist"],
        "score_open": dbg["score_open"],
        "p_bucket_fist": dbg["p_bucket_fist"],
        "p_dir_fist": dbg["p_dir_fist"],
        "p_fist_mix": dbg["p_fist_mix"],
        "p_fist_eff": dbg["p_fist_eff"],
        "margin": dbg["margin"],
        "bizonytalan": dbg["uncertain"],
        "modell": MODEL_NAME
    })

# Most confident predictions first.
if rows:
    df = pd.DataFrame(rows).sort_values(by="biztonság", ascending=False)
    from IPython.display import display
    display(df)
    df.to_csv("clip_infant_hand_results.csv", index=False)
    print("Eredmények mentve: clip_infant_hand_results.csv")
else:
    print("Nem sikerült kiértékelni a képeket.")
