# Colab-ready notebook: Update Colab file by cell gid and test the updated `get_optimized_feature`

This notebook provides utilities to programmatically load and edit Colab notebooks by gid (cell id), apply code style checks, run the modified notebook, and includes the corrected `get_optimized_feature` function (with spatial relations) plus smoke tests.

Run on Colab: Runtime → Change runtime type → GPU (optional).

In [None]:
# Install required packages (run this cell first).
# This line is an IPython shell escape (`!`), so the cell only runs inside Jupyter/Colab.
# NOTE(review): papermill, google-auth, google-api-python-client, gitpython, isort,
# flake8 and pytest are not used by any cell visible here — confirm they are still needed.
!pip install -q open_clip_torch Pillow requests nbformat nbclient papermill google-auth google-api-python-client gitpython black isort flake8 pytest

In [None]:
# Imports and lightweight checks
import os
import json
import requests
import logging
import nbformat
from nbclient import NotebookClient
import black
import open_clip
import torch
from PIL import Image

logging.basicConfig(level=logging.INFO)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Device:', device)


In [None]:
# Helper: load open-clip model (same as your project) 

def load_clip_model(model_name='ViT-B-32', pretrained='laion2b_s34b_b79k'):
    """Create an open_clip model ready for inference.

    Args:
        model_name: open_clip architecture name.
        pretrained: open_clip pretrained-weights tag for that architecture.

    Returns:
        ``(model, preprocess, tokenizer)``. The model has been moved to the
        module-level ``device`` and switched to eval mode. ``preprocess`` is the
        third transform returned by ``open_clip.create_model_and_transforms``.
    """
    clip_model, _unused_transform, eval_transform = open_clip.create_model_and_transforms(
        model_name, pretrained=pretrained
    )
    text_tokenizer = open_clip.get_tokenizer(model_name)
    clip_model.to(device)
    clip_model.eval()
    return clip_model, eval_transform, text_tokenizer

# Placeholders only: the real model is loaded lazily in the smoke-test cell,
# so running this cell stays fast.
model = preprocess = tokenizer = None


In [None]:
# Corrected get_optimized_feature (with spatial relations and optional relation fusion)
import torch.nn.functional as F

def _encode_prompts(model, tokenizer, device, prompts):
    """Encode *prompts* as one batch and return one L2-normalised row per prompt."""
    tokens = tokenizer(list(prompts)).to(device)
    emb = model.encode_text(tokens)
    return emb / emb.norm(dim=-1, keepdim=True)


def _pooled_text_embedding(model, tokenizer, device, prompts, like):
    """Average the unit embeddings of *prompts* and re-normalise to a single row.

    Returns ``torch.zeros_like(like)`` when *prompts* is empty, so the caller can
    always fuse the result with a fixed weight.
    """
    if not prompts:
        return torch.zeros_like(like)
    embs = _encode_prompts(model, tokenizer, device, prompts)
    # Mean over the prompt axis with keepdim keeps the result rank 2 even for a
    # single prompt. stack(...).squeeze() would collapse a lone (1, D) row to (D,),
    # and a later mean(dim=0) would then average across the feature dimension.
    pooled = embs.mean(dim=0, keepdim=True)
    return pooled / pooled.norm(dim=-1, keepdim=True)


def get_optimized_feature(json_data, model, tokenizer, device, preprocess=None,
                          include_relation_in_fusion=False, relation_weight=0.1):
    """Build a fused, L2-normalised CLIP feature for one class description.

    Args:
        json_data: dict with the required key ``'class_name'`` and these optional keys:
            ``'global_description'`` (str), ``'part_details'`` (mapping part -> description),
            ``'discriminative_attributes'`` (iterable of str),
            ``'spatial_relations'`` (mapping part -> description) and
            ``'image_path'`` (path to an image file).
        model: CLIP-style model exposing ``encode_text`` and ``encode_image``.
        tokenizer: callable mapping a list of strings to a token tensor.
        device: torch device that token and image tensors are moved to.
        preprocess: image transform. Required only when ``'image_path'`` is set.
        include_relation_in_fusion: if True, add the relation embedding to the
            fused feature.
        relation_weight: weight of the relation embedding when it is fused.

    Returns:
        ``(final_feature, global_emb, local_emb, image_emb, relation_emb)``.
        Components with no input data are zero tensors shaped like ``global_emb``.
        Assuming ``encode_text`` returns one row per prompt, every tensor is ``(1, D)``.

    Raises:
        KeyError: ``'class_name'`` is missing.
        ValueError: ``'image_path'`` is set but ``preprocess`` is None.
    """
    class_name = json_data['class_name']
    image_path = json_data.get('image_path')
    # Validate up front so we don't encode every text prompt before failing.
    if image_path and preprocess is None:
        raise ValueError("preprocess is required to compute image embeddings. Pass preprocess from load_clip_model.")

    # Pure inference: keep every encoder call out of autograd so no graphs are
    # built and none of the returned tensors requires grad.
    with torch.no_grad():
        # --- 1. GLOBAL FEATURE ---
        global_text = f"A photo of a {class_name}. {json_data.get('global_description','')}"
        global_emb = _encode_prompts(model, tokenizer, device, [global_text])

        # --- 2. LOCAL PARTS FEATURE ---
        part_texts = [
            f"A close-up photo of the {part_name} of a {class_name}, described as {description}"
            for part_name, description in (json_data.get('part_details') or {}).items()
        ]
        local_emb = _pooled_text_embedding(model, tokenizer, device, part_texts, global_emb)

        # --- 3. ATTRIBUTE FEATURE ---
        attr_texts = [
            f"A photo of {class_name} with {attr}"
            for attr in (json_data.get('discriminative_attributes') or [])
        ]
        attr_emb = _pooled_text_embedding(model, tokenizer, device, attr_texts, global_emb)

        # --- 3.5 IMAGE FEATURE (optional) ---
        image_emb = torch.zeros_like(global_emb)
        if image_path:
            # The context manager closes the underlying file handle promptly.
            with Image.open(image_path) as img:
                image_input = preprocess(img.convert('RGB')).unsqueeze(0).to(device)
            image_emb = model.encode_image(image_input)
            image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)

        # --- 3.6 RELATIONAL (spatial relations) ---
        # NOTE(review): the prompt uses only the relation description, not the
        # part name it is keyed by. Confirm that this is intended.
        relation_texts = [
            f"A close-up photo of a {class_name}, with the part {description}"
            for description in (json_data.get('spatial_relations') or {}).values()
        ]
        relation_emb = _pooled_text_embedding(model, tokenizer, device, relation_texts, global_emb)

        # --- 4. FUSION ---
        # A zero component keeps its weight slot; the final re-normalisation
        # rescales the remaining terms.
        final_feature = 0.4 * global_emb + 0.25 * image_emb + 0.2 * local_emb + 0.15 * attr_emb
        if include_relation_in_fusion:
            final_feature = final_feature + relation_weight * relation_emb
        final_feature = final_feature / final_feature.norm(dim=-1, keepdim=True)

    return final_feature, global_emb, local_emb, image_emb, relation_emb


In [None]:
# Smoke tests: example JSONs and runs
# Load model now (this may take a moment). Use GPU if available.
model, preprocess, tokenizer = load_clip_model()

json_no_image = {
    "class_name": "golden retriever",
    "global_description": "a friendly medium-large dog with golden coat",
    "part_details": {"face": "broad skull, kind eyes", "tail": "feathered"},
    "discriminative_attributes": ["golden coat", "friendly expression"],
    "spatial_relations": {"face": "front with eyes above nose", "tail": "rear, wagging"}
}

final_feat, g, l, img, rel = get_optimized_feature(json_no_image, model, tokenizer, device, preprocess=None)
print('No image run shapes -> final:', final_feat.shape, 'global:', g.shape, 'local:', l.shape, 'relation:', rel.shape)

# With image: download a sample and run.
# The timeout keeps a stalled download from hanging the cell.
# raise_for_status() stops us from saving an HTTP error page as "sample.jpg".
img_url = 'https://images.unsplash.com/photo-1517841905240-472988babdf9?auto=format&fit=crop&w=400&q=80'
sample_path = '/content/sample.jpg'  # Colab's working directory
resp = requests.get(img_url, timeout=30)
resp.raise_for_status()
with open(sample_path, 'wb') as fh:
    fh.write(resp.content)
# Copy instead of mutating, so json_no_image still has no image.
json_with_image = {**json_no_image, 'image_path': sample_path}
final_feat2, g2, l2, img2, rel2 = get_optimized_feature(json_with_image, model, tokenizer, device, preprocess=preprocess, include_relation_in_fusion=True, relation_weight=0.1)
print('With image run shapes -> final:', final_feat2.shape, 'image:', img2.shape, 'relation:', rel2.shape)

# Quick similarity checks (.item() already returns a Python float)
cos = torch.nn.functional.cosine_similarity
print('cos(final,global):', cos(final_feat2, g2).item())
print('cos(final,image):', cos(final_feat2, img2).item())
print('cos(final,relation):', cos(final_feat2, rel2).item())


In [None]:
# Notebook-by-GID utilities (load/update/format/execute/diff/save)

def load_notebook(path):
    """Read the notebook at *path* and return it as an nbformat v4 node."""
    return nbformat.read(path, as_version=4)

def map_gid_to_index(nb):
    """Return ``{gid: cell_index}`` for every cell in *nb*.

    A cell's gid is the first truthy value among its top-level ``id``, its
    ``metadata['gid']``, and the positional fallback ``'cell-<index>'``.
    If two cells resolve to the same gid, the later index is kept and a warning
    is logged, because update_cells_by_gid could then reach only one of them.
    """
    mapping = {}
    for i, cell in enumerate(nb.cells):
        gid = cell.get('id') or cell.get('metadata', {}).get('gid') or f'cell-{i}'
        if gid in mapping:
            logging.warning('duplicate gid %s (cells %d and %d); keeping %d',
                            gid, mapping[gid], i, i)
        mapping[gid] = i
    return mapping

def update_cells_by_gid(nb, updates: dict):
    """Replace the source of the cells addressed by gid, in place.

    Args:
        nb: notebook node (as returned by load_notebook). It is mutated in place.
        updates: mapping ``{gid: new_source_str}``. Gids are resolved with
            map_gid_to_index; unknown gids are logged and skipped.

    Returns:
        The same ``nb`` object, for chaining.
    """
    mapping = map_gid_to_index(nb)
    for gid, new_src in updates.items():
        idx = mapping.get(gid)
        if idx is None:
            logging.warning('gid %s not found', gid)
            continue
        cell = nb.cells[idx]
        cell.source = new_src
        # Only code cells carry outputs/execution_count in nbformat v4. Adding
        # them to markdown/raw cells makes the notebook fail schema validation.
        if cell.cell_type == 'code':
            cell.outputs = []
            cell.execution_count = None
    return nb

def format_code_cells(nb):
    """Format every code cell of *nb* with black, in place.

    Cells that black cannot format (for example, cells with IPython syntax such
    as ``!pip``) are left unchanged, and a warning naming the cell index is logged.
    Returns the same ``nb`` object, for chaining.
    """
    for idx, cell in enumerate(nb.cells):
        if cell.cell_type != 'code':
            continue
        try:
            formatted = black.format_str(cell.source, mode=black.Mode())
        except Exception as e:  # best effort: keep the cell unchanged
            logging.warning('black failed on cell %d: %s', idx, e)
            continue
        # format_str always ends its output with a newline. Notebook cells
        # conventionally don't (black's own Jupyter mode strips it too), so strip
        # it to avoid a spurious diff on every cell.
        cell.source = formatted.rstrip('\n')
    return nb

def run_notebook(nb, timeout=300, *, kernel_name='python3', cwd=None):
    """Execute *nb* in a fresh kernel and return it with its outputs filled in.

    Args:
        nb: notebook node. nbclient writes outputs into it in place.
        timeout: per-cell timeout in seconds, passed to NotebookClient.
        kernel_name: Jupyter kernel to start (default ``'python3'``).
        cwd: working directory for the kernel, so relative paths in the notebook
            resolve. ``None`` keeps nbclient's default.

    Returns:
        The same ``nb`` object.

    Raises:
        Whatever ``NotebookClient.execute`` raises. By default nbclient stops at
        the first failing cell; see the nbclient documentation.
    """
    resources = {'metadata': {'path': cwd}} if cwd is not None else {}
    client = NotebookClient(nb, timeout=timeout, kernel_name=kernel_name,
                            resources=resources)
    client.execute()
    return nb

def notebook_diff(nb_a, nb_b):
    """Return per-cell unified diffs between two notebooks.

    Cells are paired by position. When one notebook has more cells, the missing
    side is treated as an empty cell, so appended or removed cells still appear.

    Returns:
        A list of ``(cell_index, diff_text)`` tuples, one for each position
        whose source differs.
    """
    # Local imports keep this cell self-contained (difflib was never imported above).
    import difflib
    from itertools import zip_longest

    diffs = []
    for i, (ca, cb) in enumerate(zip_longest(nb_a.cells, nb_b.cells)):
        src_a = ca.source if ca is not None else ''
        src_b = cb.source if cb is not None else ''
        if src_a != src_b:
            # lineterm='' because the input lines have no newlines and we join
            # with '\n'. The default would double the newlines after the headers.
            diff = '\n'.join(difflib.unified_diff(
                src_a.splitlines(), src_b.splitlines(),
                fromfile=f'a_cell_{i}', tofile=f'b_cell_{i}', lineterm=''))
            diffs.append((i, diff))
    return diffs

# Example usage (local path):
# nb = load_notebook('/content/my_colab.ipynb')
# nb2 = update_cells_by_gid(nb, {'some-gid': "print('new')"})
# nb2 = format_code_cells(nb2)
# nb2 = run_notebook(nb2)


In [None]:
# Final notes and short instructions (printed for whoever runs the notebook)
print('\n💡 How to use this notebook:')
print(
    '1) Run the install cell once.\n'
    '2) Run the model/test cells (may take a minute to download weights).\n'
    '3) To apply edits to a Colab notebook file: load it with load_notebook(), '
    'map gids with map_gid_to_index(), apply update_cells_by_gid(), '
    'format_code_cells(), then run_notebook() to test outputs.'
)
