<a href="https://colab.research.google.com/github/nori-sayamaru/TOYOTAModel/blob/main/TOYOTAModel%F0%9F%84%AC4_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- Install ---
!pip -q install -U "transformers>=4.41.0,<6.0.0" "accelerate" "gradio" "pillow"

# --- Drive mount (B: save to Google Drive) ---
# Mounts the user's Google Drive at /content/drive; the output paths defined
# further down in this file (BASE_DIR and friends) live under this mount point.
from google.colab import drive
drive.mount("/content/drive")

# --- Imports ---
import os, json, uuid
from datetime import datetime
from PIL import Image

import torch
import gradio as gr
from transformers import AutoProcessor, BlipForConditionalGeneration

# --- Paths (Drive) ---
# BASE_DIR holds everything this notebook writes: images go to IMG_DIR and
# one JSON record per line goes to LOG_PATH.
BASE_DIR = "/content/drive/MyDrive/traffic_semantic_log"
IMG_DIR  = os.path.join(BASE_DIR, "images")
LOG_PATH = os.path.join(BASE_DIR, "log.jsonl")
# Creating IMG_DIR also creates BASE_DIR, so LOG_PATH's parent exists afterwards.
os.makedirs(IMG_DIR, exist_ok=True)

# --- Model (BLIP caption) ---
# Use the GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)
# Inference only: switch the model to evaluation mode.
model.eval()

def caption_image(pil_image: Image.Image) -> str:
    """Caption ``pil_image`` with the module-level BLIP processor and model.

    Args:
        pil_image: The picture to describe.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    batch = processor(images=pil_image, return_tensors="pt")
    batch = batch.to(device)

    # No gradients are needed for captioning.
    with torch.inference_mode():
        generated = model.generate(**batch, max_new_tokens=40)

    first_sequence = generated[0]
    return processor.decode(first_sequence, skip_special_tokens=True)

# --- Heuristic "semantic distillation" (lightweight scoring) ---
# This is the part that is meant to be replaced by a learned model later.
# For now, lightweight rules bolt the "potential difference" on from outside.
#
# Maps an observation category to the caption phrases that signal it.
# NOTE: contains_any() in this file tests these phrases as plain substrings of
# the lowercased caption, so short entries also hit inside longer words
# (e.g. "rain" is found in "train", "man" in "many").
KEYWORDS = {
    "pedestrian": ["pedestrian", "person", "people", "man", "woman", "child", "crowd"],
    "bicycle":    ["bicycle", "bike", "cyclist"],
    "car":        ["car", "vehicle", "truck", "bus", "van", "motorcycle"],
    "crosswalk":  ["crosswalk", "zebra crossing", "crossing"],
    "traffic_light": ["traffic light", "signal light", "stoplight"],
    "red":        ["red light", "red signal"],
    "construction":["construction", "road work", "cones", "barrier"],
    "intersection":["intersection", "junction", "crossroad"],
    "night_rain": ["night", "dark", "rain", "fog", "snow"],
}

def contains_any(text: str, words: list[str]) -> bool:
    """Return True when at least one of ``words`` occurs inside ``text``.

    Only ``text`` is lowercased before the check; each entry of ``words`` is
    used exactly as given and is matched as a plain substring (no word
    boundaries), so ``"rain"`` is also found in ``"train"``.
    """
    lowered = text.lower()
    for candidate in words:
        if candidate in lowered:
            return True
    return False

def distill_and_score(caption: str):
    """Turn a caption into observation flags, action scores, reasons and a verdict.

    Args:
        caption: Free-text image description to scan for keywords.

    Returns:
        A 4-tuple ``(obs, scores, reasons, action)``:
        ``obs`` maps each observation flag to a bool; ``scores`` maps the
        three action labels to weights that sum to 1; ``reasons`` maps each
        label to at most three explanation strings; ``action`` is the label
        with the highest score (the earliest label wins a tie).
    """
    labels = ("停止", "徐行", "進行")

    # Observation flag -> KEYWORDS group that feeds it (order fixes dict order).
    flag_sources = (
        ("pedestrian", "pedestrian"),
        ("bicycle", "bicycle"),
        ("car", "car"),
        ("crosswalk", "crosswalk"),
        ("traffic_light", "traffic_light"),
        ("red", "red"),
        ("construction", "construction"),
        ("intersection", "intersection"),
        ("low_visibility", "night_rain"),
    )
    obs = {flag: contains_any(caption, KEYWORDS[group]) for flag, group in flag_sources}

    person_seen = obs["pedestrian"]
    on_crosswalk = obs["crosswalk"]

    # (condition, label to boost, weight, explanation) — applied top to bottom.
    rules = (
        (obs["red"], "停止", 0.70, "信号が赤っぽい記述（停止を強化）"),
        (person_seen and on_crosswalk, "停止", 0.55, "横断歩道＋歩行者らしき記述（停止寄り）"),
        (person_seen and not on_crosswalk, "徐行", 0.45, "歩行者らしき記述（徐行寄り）"),
        (obs["bicycle"], "徐行", 0.35, "自転車らしき記述（徐行寄り）"),
        (obs["construction"], "徐行", 0.40, "工事・コーン等の記述（徐行寄り）"),
        (obs["intersection"], "徐行", 0.25, "交差点らしき記述（徐行寄り）"),
        (obs["low_visibility"], "徐行", 0.25, "視界が悪そう（夜/雨/霧などの記述）（安全側）"),
    )

    # Every label starts from the same 0.20 baseline.
    raw = dict.fromkeys(labels, 0.20)
    why = {label: [] for label in labels}
    for fired, label, weight, explanation in rules:
        if fired:
            raw[label] += weight
            why[label].append(explanation)

    # "Go" is derived inversely: it only rises while little risk has accumulated.
    risk = (raw["停止"] + raw["徐行"]) - 0.40
    raw["進行"] += max(0.0, 0.35 - risk)
    if raw["進行"] > 0.35:
        why["進行"].append("危険要素キーワードが少なめ（進行寄り）")

    # Clip each score into [0, 1], then normalise so the three sum to 1.
    clipped = {label: max(0.0, min(1.0, raw[label])) for label in labels}
    total = clipped["停止"] + clipped["徐行"] + clipped["進行"]
    scores = {label: clipped[label] / total for label in labels}

    # Verdict: highest score; max() keeps the first label on ties.
    action = max(labels, key=scores.__getitem__)

    top_reasons = {label: why[label][:3] for label in labels}
    return obs, scores, top_reasons, action

def bar(p: float, width: int = 18) -> str:
    """Render ``p`` as a fixed-width text progress bar.

    Args:
        p: Fraction to display; expected in [0, 1]. Values outside that range
            are shown as an empty or a full bar.
        width: Total number of characters in the bar.

    Returns:
        A string of ``width`` characters: filled cells first, empty cells after.
    """
    filled = int(round(p * width))
    # Clamp so an out-of-range p cannot make the bar longer than ``width``
    # or silently drop the empty part.
    filled = max(0, min(width, filled))
    return "█" * filled + "░" * (width - filled)

def render_panel(caption: str, scores: dict, reasons: dict, action: str) -> str:
    """Build the Markdown result panel shown next to the image.

    Args:
        caption: Caption text, quoted at the bottom of the panel.
        scores: Score per action label ("停止", "徐行", "進行").
        reasons: List of reason strings per action label.
        action: The chosen action label, shown as the conclusion.

    Returns:
        One Markdown string with a score card, the conclusion, a reason tree
        (labels without reasons are omitted) and the quoted caption.
    """
    labels = ("停止", "徐行", "進行")

    # One bullet per label: name, text bar, two-decimal score.
    score_rows = [
        f"- **{label}**  {bar(scores[label])}  {scores[label]:.2f}"
        for label in labels
    ]

    # Nested bullets; a label appears only when it has at least one reason.
    reason_rows = []
    for label in labels:
        items = reasons[label]
        if not items:
            continue
        reason_rows.append(f"- **{label}**")
        reason_rows.extend(f"  - {item}" for item in items)

    sections = [
        "## スコアカード",
        *score_rows,
        "",
        f"## 結論（暫定）: **{action}**",
        "",
        "## 理由ツリー（上位3つまで）",
        *reason_rows,
        "",
        "## 観察（caption）",
        f"> {caption}",
    ]
    return "\n".join(sections)

# --- State to keep last inference for saving ---
# Module-level slot filled by analyze() and read by save_feedback():
#   "image_path": path of the most recently written image file, or None
#   "record":     the log record dict for that inference, or None
LAST = {"image_path": None, "record": None}

def analyze(image: Image.Image):
    """Caption an image, score it, and stage a log record for later feedback.

    Args:
        image: The uploaded picture, or None when nothing was provided.

    Returns:
        A 2-tuple for the Gradio outputs: the Markdown panel text (or a
        "no image" message) and an update that makes the feedback section
        visible.

    Side effects:
        Writes the image as a JPEG into IMG_DIR and stores the record and
        image path in the module-level LAST dict.
    """
    if image is None:
        return "画像がありません。", gr.update(visible=True)

    cap = caption_image(image)
    obs, scores, reasons, action = distill_and_score(cap)
    panel = render_panel(cap, scores, reasons, action)

    # Read the clock once so the id prefix and the "timestamp" field describe
    # the same instant (two separate now() calls can straddle a second).
    now = datetime.now()
    tmp_id = now.strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]

    # The image is written to IMG_DIR (on Drive) right away; only the JSONL
    # record waits for the "save" button.
    img_path = os.path.join(IMG_DIR, f"{tmp_id}.jpg")
    # JPEG has no alpha channel, so convert to RGB before saving.
    image.convert("RGB").save(img_path, quality=92)

    record = {
        "id": tmp_id,
        "timestamp": now.isoformat(),
        "image_path": img_path,
        "caption": cap,
        "observations": obs,
        "scores": scores,
        "reasons": reasons,
        "action": action,
        "rating_action": None,     # filled in later by the user
        "rating_grounded": None,   # filled in later by the user
        "comment": "",
    }
    LAST["image_path"] = img_path
    LAST["record"] = record

    return panel, gr.update(visible=True)

def save_feedback(rating_action, rating_grounded, comment):
    """Append the last inference plus the user's ratings to the JSONL log.

    Args:
        rating_action: Rating of the conclusion; stored as an int.
        rating_grounded: Rating of how grounded the reasons are; stored as an int.
        comment: Optional free text; None is treated as empty and the text
            is stripped of surrounding whitespace.

    Returns:
        A status message: a prompt to run an inference first when nothing is
        staged in LAST, otherwise a confirmation naming the image file and
        the log path.
    """
    pending = LAST["record"]
    if pending is None:
        return "まだ推論結果がありません。先に画像を入れて「判定」を押してね。"

    # Copy the staged record so LAST itself is not modified; the three
    # feedback keys already exist in it, so key order is unchanged.
    rec = {
        **pending,
        "rating_action": int(rating_action),
        "rating_grounded": int(rating_grounded),
        "comment": (comment or "").strip(),
    }

    with open(LOG_PATH, "a", encoding="utf-8") as log_file:
        log_file.write(json.dumps(rec, ensure_ascii=False) + "\n")

    image_name = os.path.basename(rec["image_path"])
    return f"保存しました ✅\n- 画像: {image_name}\n- ログ: {LOG_PATH}"

# --- UI ---
# Layout: image input and result panel side by side, an "analyze" button, and
# a feedback section that starts hidden and is revealed by analyze().
# NOTE(review): analyze() and save_feedback() exchange data through the
# module-level LAST dict rather than per-session state — confirm this is
# acceptable if more than one person uses the shared link at once.
with gr.Blocks() as demo:
    gr.Markdown("# 交通状況：純粋セマンティック空間（ライト版）")
    gr.Markdown("写真→構造化→スコア（停止/徐行/進行）→理由。あなたの評価をDriveへ蓄積。")

    with gr.Row():
        # type="pil" hands analyze() a PIL image.
        img = gr.Image(type="pil", label="街中で撮った写真を入れる")
        out_md = gr.Markdown("ここに結果が出ます")

    btn = gr.Button("判定（推論）")

    # Feedback form; created with visible=False.
    with gr.Accordion("あなたの評価（保存）", open=True, visible=False) as acc:
        gr.Markdown("**A: 結論は妥当？**（1=危険/間違い, 5=かなり妥当）")
        rating_action = gr.Slider(1, 5, value=3, step=1, label="A: 結論の妥当性")
        gr.Markdown("**B: 根拠は画像に基づいてる？**（1=妄想, 5=根拠明確）")
        rating_grounded = gr.Slider(1, 5, value=3, step=1, label="B: 根拠の実在性")
        comment = gr.Textbox(label="コメント（任意）", placeholder="例：信号は写ってない／歩行者は実際いなかった など", lines=2)
        save_btn = gr.Button("Driveへ保存（画像＋JSONL）")
        save_msg = gr.Textbox(label="保存結果", interactive=False)

    # analyze() returns (panel markdown, accordion update) in this output order.
    btn.click(fn=analyze, inputs=img, outputs=[out_md, acc])
    # save_feedback() receives the two slider values and the comment text.
    save_btn.click(fn=save_feedback, inputs=[rating_action, rating_grounded, comment], outputs=save_msg)

# share=True asks Gradio for a public share link in addition to the local one.
demo.launch(share=True)
