# VRD Scene Graph with GAT-like (padded) + YOLOv8 Detection Inference

Notebook này tương thích VRD gốc, huấn luyện GraphClassifier & RelationExtractor, và inference từ ảnh qua YOLOv8 (nếu có).

## 0) Thiết lập & (tuỳ chọn) cài YOLOv8

In [None]:

# !pip -q install ultralytics
import os, json, numpy as np, math, warnings
import torch, torch.nn as nn, torch.nn.functional as F
from torch import optim
from pathlib import Path
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')
SEED=123
np.random.seed(SEED); torch.manual_seed(SEED)
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
try:
    from ultralytics import YOLO
    HAS_YOLO=True
except Exception:
    HAS_YOLO=False
    print('[INFO] YOLOv8 chưa sẵn sàng. Bạn có thể cài ultralytics để dùng detector.')


Device: cuda


## 1) Cấu hình đường dẫn VRD

In [None]:

ROOT_PATH = "H:/Download/vrd/"

VRD_ROOT = Path(ROOT_PATH)
POSSIBLE_IMG_DIRS_TRAIN=[ROOT_PATH + 'sg_train_images','images/train','images/train_images','train']
POSSIBLE_IMG_DIRS_TEST =[ROOT_PATH + 'sg_test_images','images/test','images/test_images','test']
POSSIBLE_TRAIN_JSON=[ROOT_PATH + 'sg_train_annotations.json','annotations_train.json','vrd_train.json']
POSSIBLE_TEST_JSON =[ROOT_PATH + 'sg_test_annotations.json','annotations_test.json','vrd_test.json']

def find_first_exists(base: Path, candidates):
    for c in candidates:
        p = base / c
        if p.exists(): return p
    return None

VRD_TRAIN_JSON = find_first_exists(VRD_ROOT, POSSIBLE_TRAIN_JSON)
VRD_TEST_JSON  = find_first_exists(VRD_ROOT, POSSIBLE_TEST_JSON)
VRD_TRAIN_DIR  = find_first_exists(VRD_ROOT, POSSIBLE_IMG_DIRS_TRAIN)
VRD_TEST_DIR   = find_first_exists(VRD_ROOT, POSSIBLE_IMG_DIRS_TEST)

print('VRD_TRAIN_JSON:', VRD_TRAIN_JSON)
print('VRD_TEST_JSON :', VRD_TEST_JSON)
print('VRD_TRAIN_DIR :', VRD_TRAIN_DIR)
print('VRD_TEST_DIR  :', VRD_TEST_DIR)
assert VRD_TRAIN_JSON is not None and VRD_TRAIN_DIR is not None, "Thiếu train.json/dir"
assert VRD_TEST_JSON is not None and VRD_TEST_DIR is not None, "Thiếu test.json/dir"


VRD_TRAIN_JSON: H:\Download\vrd\sg_train_annotations.json
VRD_TEST_JSON : H:\Download\vrd\sg_test_annotations.json
VRD_TRAIN_DIR : H:\Download\vrd\sg_train_images
VRD_TEST_DIR  : H:\Download\vrd\sg_test_images


## 2) Đọc annotations VRD → vocab + đồ thị

In [None]:
import json
from pathlib import Path

# 1) Đọc JSON VRD thô
def read_vrd_json(json_path: Path):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data

# 2) Chuẩn hoá: raw -> list item (mỗi item ~ 1 ảnh)
def normalize_items(raw):
    if isinstance(raw, list):
        return raw
    if isinstance(raw, dict):
        for key in ['images', 'annotations', 'data']:
            if key in raw and isinstance(raw[key], list):
                return raw[key]
    return [raw]

train_raw = read_vrd_json(VRD_TRAIN_JSON)
test_raw  = read_vrd_json(VRD_TEST_JSON)

train_ann = normalize_items(train_raw)
test_ann  = normalize_items(test_raw)

print("Train items:", len(train_ann), "| Test items:", len(test_ann))
print("Ví dụ item[0] keys:", list(train_ann[0].keys()))
# -> ['relationships', 'photo_id', 'height', 'width', 'objects', 'filename']

# ================== PARSER CHUẨN CHO SCHEMA NÀY ==================

def iter_relationships_vrd(item):
    """
    Chuẩn hoá 1 item (1 ảnh) thành các quan hệ:
      { 'subj': <tên lớp chủ thể>, 'obj': <tên lớp đối tượng>, 'predicate': <tên quan hệ> }

    Dựa trên schema:
      - item['relationships']: list các dict
        mỗi dict có:
          'text': ['woman', 'wears', 'shirt']
          'objects': [0, 1]  # không dùng cũng được
          'relationship': 'wearing'
    """
    rels = item.get('relationships', [])

    for r in rels:
        text = r.get('text', [])
        # text = [subj, verb, obj] theo ví dụ bạn gửi
        subj_name = text[0] if len(text) > 0 else 'unknown'
        obj_name  = text[2] if len(text) > 2 else 'unknown'

        # quan hệ chính: ưu tiên field 'relationship'
        rel_name = r.get('relationship')
        if rel_name is None:
            # fallback: lấy từ text[1] nếu có
            rel_name = text[1] if len(text) > 1 else 'none'

        # chuẩn hoá sang string
        subj_name = str(subj_name) if subj_name is not None else 'unknown'
        obj_name  = str(obj_name)  if obj_name  is not None else 'unknown'
        rel_name  = str(rel_name)  if rel_name  is not None else 'none'

        yield {
            'subj': subj_name,
            'obj':  obj_name,
            'predicate': rel_name
        }

# ============ XÂY VOCAB TỪ TRAIN + TEST ============

cats  = set()
preds = set()

for split in (train_ann, test_ann):
    for item in split:
        for r in iter_relationships_vrd(item):
            cats.add(r['subj'])
            cats.add(r['obj'])
            if r['predicate'] != 'none':
                preds.add(r['predicate'])

# thêm 'unknown' & sắp xếp
cats  = sorted(cats | {'unknown'})
preds = sorted(preds)

print("Số lượng categories:", len(cats))
print("Một vài category:", cats[:20])
print("Số lượng predicates (không tính 'none'):", len(preds))
print("Một vài predicates:", preds[:20])

# vocab toàn cục
NODES_VOCAB = cats
REL_VOCAB   = preds + ['none']
REL_NO_NONE = preds

node2id = {n: i for i, n in enumerate(NODES_VOCAB)}
rel2id  = {r: i for i, r in enumerate(REL_VOCAB)}
rel2id_no_none = {r: i for i, r in enumerate(REL_NO_NONE)}

def one_hot(idx, size):
    v = np.zeros(size, dtype=np.float32)
    v[idx] = 1.0
    return v


Train items: 4000 | Test items: 1000
Ví dụ item[0] keys: ['relationships', 'photo_id', 'height', 'width', 'objects', 'filename']
Số lượng categories: 6455
Một vài category: ['1', '1003', '11', '13', '148', '15', '18', '2', '2 hands', '20', '2012', '2013', '2375', '3', '3pm', '4', '40', '50', '6', '6am']
Số lượng predicates (không tính 'none'): 1310
Một vài predicates: ['ab', 'aboe', 'aboev', 'about', 'about to catch', 'about to cut', 'about to hit', 'abouve', 'above', 'aboveq', 'abovw', 'across', 'across from', 'acrosss', 'acrosst', 'adjacent ot', 'adjacent to', 'adjusting', 'adjusts', 'admiring']


## 3) Data splits & batching

In [None]:
def build_graph_from_item(item, max_nodes=32):
    """
    Xây đồ thị cho 1 ảnh:
      - node: category (tên lớp) xuất hiện trong các quan hệ
      - edge: (subj, obj) nếu có quan hệ giữa 2 lớp
      - R[i,j]: id predicate (nếu không, 'none')
      - y: nhãn graph demo (1 nếu có 'person' -- 'wearing' --> 'shirt', v.v. bạn có thể chỉnh)
    """
    rels_norm = list(iter_relationships_vrd(item))

    # gom list node
    nodes_list = []
    for r in rels_norm:
        nodes_list.append(r['subj'])
        nodes_list.append(r['obj'])
    nodes = sorted(set(nodes_list))[:max_nodes]
    if not nodes:
        nodes = ['unknown']

    N = len(nodes)
    X = np.stack([one_hot(node2id[n], len(NODES_VOCAB)) for n in nodes], axis=0)
    A = np.zeros((N, N), dtype=np.float32)
    R = np.full((N, N), rel2id['none'], dtype=np.int64)

    for r in rels_norm:
        s = r['subj']
        o = r['obj']
        p = r['predicate']
        if s in nodes and o in nodes and s != o and p in rel2id:
            i = nodes.index(s)
            j = nodes.index(o)
            A[i, j] = A[j, i] = 1.0
            if p in rel2id_no_none:
                R[i, j] = R[j, i] = rel2id[p]

    # Nhãn graph demo (tuỳ bạn định nghĩa):
    # Ví dụ: 1 nếu có 'woman' -- 'wearing' --> 'shirt'
    y = 0
    for r in rels_norm:
        if (r['predicate'] == 'wearing' and
            r['subj'] == 'woman' and
            r['obj'] == 'shirt'):
            y = 1
            break

    img_name = item.get('filename') or item.get('file_name') or f"{item.get('photo_id')}.jpg"
    return A, X, R, y, nodes, img_name


In [None]:
def make_dataset(ann_list):
    return [build_graph_from_item(item) for item in ann_list]

train_data_all = make_dataset(train_ann)
test_data      = make_dataset(test_ann)

def split_list(L, va_ratio=0.15, seed=SEED):
    rng = np.random.default_rng(seed)
    idx = np.arange(len(L)); rng.shuffle(idx)
    v = int(len(L) * va_ratio)
    return [L[i] for i in idx[v:]], [L[i] for i in idx[:v]]

train_data, val_data = split_list(train_data_all, va_ratio=0.15, seed=SEED)

print("Train:", len(train_data), "| Val:", len(val_data), "| Test:", len(test_data))
A, X, R, y, nodes, img_name = train_data[0]
print("Example graph shapes:", A.shape, X.shape, R.shape,
      "| y:", y,
      "| nodes sample:", nodes[:5],
      "| img:", img_name)

Train: 3400 | Val: 600 | Test: 1000
Example graph shapes: (10, 10) (10, 6455) (10, 10) | y: 0 | nodes sample: ['boy', 'goats', 'hat', 'pants', 'road'] | img: 5101356337_86fe7ee94d_b.jpg


In [None]:
def to_tensors_padded(batch, device):
    B=len(batch); Ns=[b[0].shape[0] for b in batch]; Nmax=max(Ns); F=batch[0][1].shape[1]
    A=np.zeros((B,Nmax,Nmax),dtype=np.float32)
    X=np.zeros((B,Nmax,F),dtype=np.float32)
    R=np.full((B,Nmax,Nmax), rel2id['none'], dtype=np.int64)
    M=np.zeros((B,Nmax),dtype=np.float32)
    y=np.zeros(B,dtype=np.int64); nodes_list=[]; img_names=[]
    for i,(Ai,Xi,Ri,yi,nodes,img_name) in enumerate(batch):
        n=Ai.shape[0]
        A[i,:n,:n]=Ai; X[i,:n,:]=Xi; R[i,:n,:n]=Ri; M[i,:n]=1.0; y[i]=yi
        nodes_list.append(nodes); img_names.append(img_name)
    return (torch.tensor(A,dtype=torch.float32,device=device),
            torch.tensor(X,dtype=torch.float32,device=device),
            torch.tensor(R,dtype=torch.long,device=device),
            torch.tensor(M,dtype=torch.float32,device=device),
            torch.tensor(y,dtype=torch.long,device=device),
            nodes_list, img_names)

def batches(data, bs=32):
    for i in range(0, len(data), bs):
        yield data[i:i+bs]


## 4) Models

In [None]:

class RelGATLayer(nn.Module):
    def __init__(self,in_dim,out_dim,n_rel,heads=2,dropout=0.1):
        super().__init__()
        self.heads=heads; self.dk=out_dim//heads; assert out_dim%heads==0
        self.Wq=nn.Linear(in_dim,out_dim,bias=False)
        self.Wk=nn.Linear(in_dim,out_dim,bias=False)
        self.Wv=nn.Linear(in_dim,out_dim,bias=False)
        self.rel_emb=nn.Embedding(n_rel, self.dk)
        self.dp=nn.Dropout(dropout)
    def forward(self,X,A,R,mask):
        B,N,F=X.shape; H=self.heads; dk=self.dk
        X=X*mask.unsqueeze(-1)
        Q=self.Wq(X).view(B,N,H,dk); K=self.Wk(X).view(B,N,H,dk); V=self.Wv(X).view(B,N,H,dk)
        rel=self.rel_emb(R).view(B,N,N,dk)
        logits=torch.einsum('bnhd,bmhd->bhnm',Q,K)/(dk**0.5)
        logits=logits+torch.einsum('bnhd,bnmd->bhnm',Q,rel)
        edge_mask=(A==0).unsqueeze(1); pad_mask=(mask==0).unsqueeze(1).unsqueeze(2)
        logits=logits.masked_fill(edge_mask|pad_mask,-1e9)
        attn=torch.softmax(logits,dim=-1); attn=self.dp(attn)
        out=torch.einsum('bhnm,bmhd->bnhd',attn,V).contiguous().view(B,N,H*dk)
        out=out*mask.unsqueeze(-1); return out,attn

class GraphClassifier(nn.Module):
    def __init__(self,in_dim,hid_dim,n_rel,heads=2):
        super().__init__()
        self.g1=RelGATLayer(in_dim,hid_dim,n_rel,heads=heads,dropout=0.1)
        self.g2=RelGATLayer(hid_dim,hid_dim,n_rel,heads=heads,dropout=0.1)
        self.cls=nn.Sequential(nn.Linear(hid_dim,hid_dim), nn.ReLU(), nn.Linear(hid_dim,2))
    def forward(self,X,A,R,mask):
        H1,_=self.g1(X,A,R,mask); H1=F.relu(H1)
        H2,_=self.g2(H1,A,R,mask); H2=F.relu(H2)
        m=mask.unsqueeze(-1); g=(H2*m).sum(dim=1)/m.sum(dim=1).clamp_min(1e-6)
        return self.cls(g)

class RelationExtractor(nn.Module):
    def __init__(self,in_dim,hid_dim,n_rel,heads=2):
        super().__init__()
        self.g1=RelGATLayer(in_dim,hid_dim,n_rel,heads=heads,dropout=0.1)
        self.g2=RelGATLayer(hid_dim,hid_dim,n_rel,heads=heads,dropout=0.1)
        self.edge_mlp=nn.Sequential(nn.Linear(2*hid_dim,hid_dim), nn.ReLU(), nn.Linear(hid_dim, len(REL_NO_NONE)))
    def forward(self,X,A,R,mask):
        H1,_=self.g1(X,A,R,mask); H1=F.relu(H1)
        H2,_=self.g2(H1,A,R,mask); H2=F.relu(H2)
        return H2
    def edge_logits(self,H,pairs):
        B,P,_=pairs.shape; i=pairs[...,0]; j=pairs[...,1]
        Hi=H.gather(1, i.unsqueeze(-1).expand(-1,-1,H.size(-1)))
        Hj=H.gather(1, j.unsqueeze(-1).expand(-1,-1,H.size(-1)))
        feat=torch.cat([Hi,Hj],dim=-1)
        return self.edge_mlp(feat)


## 5) Train loops

In [None]:

gc_model=GraphClassifier(in_dim=len(NODES_VOCAB), hid_dim=64, n_rel=len(REL_VOCAB), heads=2).to(device)
gc_opt=optim.Adam(gc_model.parameters(), lr=1e-3, weight_decay=1e-4)
crit=nn.CrossEntropyLoss()

def eval_graph_acc(dataset, bs=64):
    gc_model.eval(); correct=0; total=0
    with torch.no_grad():
        for batch in batches(dataset, bs):
            A,X,R,M,y,_,_=to_tensors_padded(batch, device)
            pred=gc_model(X,A,R,M).argmax(1)
            correct += (pred==y).sum().item(); total+=y.numel()
    return correct/total if total>0 else 0.0

best=(0.0,None)
for ep in range(1,11):
    gc_model.train(); losses=[]
    for batch in batches(train_data,64):
        A,X,R,M,y,_,_=to_tensors_padded(batch, device)
        logits=gc_model(X,A,R,M); loss=crit(logits,y)
        gc_opt.zero_grad(); loss.backward(); gc_opt.step(); losses.append(loss.item())
    va=eval_graph_acc(val_data)
    if va>best[0]: best=(va, {k:v.detach().cpu().clone() for k,v in gc_model.state_dict().items()})
    print(f"Epoch {ep:02d} | loss={np.mean(losses):.4f} | val_graph_acc={va:.3f}")
if best[1] is not None: gc_model.load_state_dict(best[1])
te=eval_graph_acc(test_data); print('Graph Task — Best Val:', round(best[0],3), '| Test:', round(te,3))

re_model=RelationExtractor(in_dim=len(NODES_VOCAB), hid_dim=128, n_rel=len(REL_VOCAB), heads=2).to(device)
re_opt=optim.Adam(re_model.parameters(), lr=1e-3, weight_decay=1e-4)
edge_crit=nn.CrossEntropyLoss()

def build_edge_targets(A,R,M):
    B,N,_=A.shape; pairs_list=[]; targets_list=[]
    for b in range(B):
        valid=(M[b]>0.5).nonzero(as_tuple=False).flatten()
        if valid.numel()<2: continue
        Ai=A[b][valid][:,valid]; Ri=R[b][valid][:,valid]
        idxs=(Ai>0.5).nonzero(as_tuple=False)
        keep=[]; targ=[]
        for ii,jj in idxs.tolist():
            rel_id=int(Ri[ii,jj].item()); rn=REL_VOCAB[rel_id]
            if rn=='none': continue
            keep.append([ii,jj]); targ.append(rel2id_no_none[rn])
        if not keep: continue
        pairs=valid[torch.tensor(keep, dtype=torch.long, device=A.device)]
        pairs_list.append(pairs)
        targets_list.append(torch.tensor(targ, dtype=torch.long, device=A.device))
    if not pairs_list: return None, None
    P=min(p.size(0) for p in pairs_list)
    pairs=torch.stack([p[:P] for p in pairs_list], dim=0); targets=torch.stack([t[:P] for t in targets_list], dim=0)
    return pairs, targets

def evaluate_edge_acc(dataset, bs=48):
    re_model.eval(); correct=0; total=0
    with torch.no_grad():
        for batch in batches(dataset, bs):
            A,X,R,M,_,_,_ = to_tensors_padded(batch, device)
            H=re_model(X,A,R,M)
            pairs,targets=build_edge_targets(A,R,M)
            if pairs is None: continue
            logits=re_model.edge_logits(H,pairs)
            pred=logits.argmax(-1)
            correct += (pred==targets).sum().item(); total+=targets.numel()
    return correct/total if total>0 else 0.0

best=(0.0,None)
for ep in range(1,16):
    re_model.train(); losses=[]
    for batch in batches(train_data,48):
        A,X,R,M,_,_,_=to_tensors_padded(batch, device)
        H=re_model(X,A,R,M)
        pairs,targets=build_edge_targets(A,R,M)
        if pairs is None: continue
        logits=re_model.edge_logits(H,pairs)
        loss=edge_crit(logits.view(-1, logits.size(-1)), targets.view(-1))
        re_opt.zero_grad(); loss.backward(); re_opt.step(); losses.append(loss.item())
    va=evaluate_edge_acc(val_data, bs=48)
    if va>best[0]: best=(va, {k:v.detach().cpu().clone() for k,v in re_model.state_dict().items()})
    print(f"Epoch {ep:02d} | edge_loss={np.mean(losses):.4f} | val_edge_acc={va:.3f}")
if best[1] is not None: re_model.load_state_dict(best[1])
te=evaluate_edge_acc(test_data, bs=48)
print('Edge Task — Best Val:', round(best[0],3), '| Test:', round(te,3))


Epoch 01 | loss=0.4725 | val_graph_acc=0.942
Epoch 02 | loss=0.2063 | val_graph_acc=0.942
Epoch 03 | loss=0.1665 | val_graph_acc=0.942
Epoch 04 | loss=0.1342 | val_graph_acc=0.942
Epoch 05 | loss=0.0953 | val_graph_acc=0.945
Epoch 06 | loss=0.0557 | val_graph_acc=0.960
Epoch 07 | loss=0.0295 | val_graph_acc=0.967
Epoch 08 | loss=0.0169 | val_graph_acc=0.972
Epoch 09 | loss=0.0100 | val_graph_acc=0.967
Epoch 10 | loss=0.0080 | val_graph_acc=0.970
Graph Task — Best Val: 0.972 | Test: 0.966
Epoch 01 | edge_loss=4.9687 | val_edge_acc=0.218
Epoch 02 | edge_loss=3.0134 | val_edge_acc=0.218
Epoch 03 | edge_loss=2.9626 | val_edge_acc=0.218
Epoch 04 | edge_loss=2.9413 | val_edge_acc=0.218
Epoch 05 | edge_loss=2.9121 | val_edge_acc=0.222
Epoch 06 | edge_loss=2.8242 | val_edge_acc=0.264
Epoch 07 | edge_loss=2.7626 | val_edge_acc=0.269
Epoch 08 | edge_loss=2.7255 | val_edge_acc=0.272
Epoch 09 | edge_loss=2.7019 | val_edge_acc=0.272
Epoch 10 | edge_loss=2.6806 | val_edge_acc=0.272
Epoch 11 | edge_l

## 6) YOLOv8 detection → inference từ ảnh

In [None]:

from PIL import Image

def map_det_name(name: str):
    # map tên lớp YOLO về vocab VRD; nếu không có thì 'unknown'
    return name if name in node2id else 'unknown'

def yolo_detect(image_path: Path, conf=0.25, model_name='yolov8n.pt'):
    if not HAS_YOLO:
        raise RuntimeError("YOLOv8 chưa sẵn sàng. Hãy cài ultralytics hoặc đặt HAS_YOLO=True với model đã load.")
    model = YOLO(model_name)
    res = model.predict(source=str(image_path), conf=conf, verbose=False)[0]
    dets = []  # [(name, [x1,y1,x2,y2], score)]
    for b in res.boxes:
        cls_id = int(b.cls.item())
        name = res.names.get(cls_id, str(cls_id))
        xyxy = b.xyxy[0].tolist()
        score = float(b.conf.item())
        dets.append((name, xyxy, score))
    return dets

def build_graph_from_dets(dets, max_nodes=20):
    """
    dets: list (name, bbox, score)
    Trả về: A, X, R, nodes, pairs
    """
    names = [map_det_name(n) for (n, _, _) in dets]
    # nodes = các lớp (category) xuất hiện
    nodes = sorted(set(names))[:max_nodes]
    if not nodes:
        # không có đối tượng nào
        return (
            np.zeros((0, 0), dtype=np.float32),
            np.zeros((0, len(NODES_VOCAB)), dtype=np.float32),
            np.zeros((0, 0), dtype=np.int64),
            [],
            []
        )

    N = len(nodes)
    X = np.stack([one_hot(node2id[n], len(NODES_VOCAB)) for n in nodes], axis=0)
    A = np.zeros((N, N), dtype=np.float32)
    R = np.full((N, N), rel2id['none'], dtype=np.int64)

    pairs = []
    # Ở đây cho full graph; bạn có thể thay bằng nối theo proximity nếu muốn
    for i in range(N):
        for j in range(N):
            if i != j:
                A[i, j] = A[j, i] = 1.0
                pairs.append((i, j))

    return A, X, R, nodes, pairs

def infer_image(image_path: Path,
                yolo_model='yolov8n.pt',
                conf=0.25,
                topk=10,
                threshold=0.2):
    """
    Trả về: strings, triples, nodes, dets
      - strings: list chuỗi "<object> <relation> <subject>"
      - triples: list (subj, rel, obj, confidence)
      - nodes:   danh sách node (category) trong graph
      - dets:    detections gốc từ YOLO hoặc anotations VRD (fallback)
    """

    # 1) Lấy detections
    if HAS_YOLO:
        dets = yolo_detect(image_path, conf=conf, model_name=yolo_model)
    else:
        dets = []
        # fallback: dùng anotations VRD nếu trùng tên file
        for split in (train_ann, test_ann):
            for item in split:
                fn = item.get('file_name') or f"{item.get('image_id')}.jpg"
                if fn == image_path.name:
                    for r in item.get('relationships', item.get('rels', [])):
                        dets.append(
                            (safe_cat(r.get('subject', {}).get('category', 'unknown')), None, 1.0)
                        )
                        dets.append(
                            (safe_cat(r.get('object', {}).get('category', 'unknown')), None, 1.0)
                        )
                    break

    # Nếu hoàn toàn không có dets → trả về rỗng nhưng đủ 4 phần
    if len(dets) == 0:
        print(f"[WARN] Không tìm thấy detection / anotations nào cho ảnh: {image_path}")
        return [], [], [], dets

    # 2) Xây graph từ dets
    A, X, R, nodes, pairs = build_graph_from_dets(dets)
    n = A.shape[0]

    if n == 0 or len(pairs) == 0:
        print(f"[WARN] Graph sau khi build không có node/cạnh hợp lệ cho ảnh: {image_path}")
        return [], [], nodes, dets

    A_t = torch.tensor(A, dtype=torch.float32, device=device).unsqueeze(0)
    X_t = torch.tensor(X, dtype=torch.float32, device=device).unsqueeze(0)
    R_t = torch.tensor(R, dtype=torch.long,    device=device).unsqueeze(0)
    M_t = torch.ones(1, n, dtype=torch.float32, device=device)

    # 3) Chạy RelationExtractor
    re_model.eval()
    with torch.no_grad():
        H = re_model(X_t, A_t, R_t, M_t)
        pairs_t = torch.tensor(pairs, dtype=torch.long, device=device).unsqueeze(0)
        logits = re_model.edge_logits(H, pairs_t).squeeze(0)  # [P, |REL_NO_NONE|]
        probs = torch.softmax(logits, dim=-1)
        confs, preds = probs.max(dim=-1)

    # 4) Lọc theo threshold + sort
    triples = []
    for (i, j), c, p in zip(
        pairs,
        confs.cpu().numpy().tolist(),
        preds.cpu().numpy().tolist()
    ):
        rel = REL_NO_NONE[p]
        s = nodes[i]
        o = nodes[j]
        if (threshold is not None) and (c < threshold):
            continue
        triples.append((s, rel, o, float(c)))

    triples.sort(key=lambda x: -x[3])
    if topk is not None:
        triples = triples[:topk]

    strings = [f"{subj} {rel} {obj}" for (subj, rel, obj, _) in triples]
    return strings, triples, nodes, dets



### 7) Demo: suy luận trên 1 ảnh test

In [None]:

# Chọn ảnh từ thư mục test và chạy
#############################################################################
sample = test_ann[0]
img_name = '116298453_57a957315a_o.jpg'
img_path = (Path(VRD_TEST_DIR) / img_name)
print('Image:', img_path)

strings, triples, nodes, dets = infer_image(
    img_path,
    yolo_model='yolov8n.pt',
    conf=0.25,
    topk=15,
    threshold=0.195
)
print('Nodes:', nodes)
print('Predicted triples:')
for subj, rel, obj, c in triples:
    print(f"  {subj} {rel} {obj}  (conf={c:.3f})")
#for s in strings:
#    print('  ', s)

Image: H:\Download\vrd\sg_test_images\116298453_57a957315a_o.jpg
Nodes: ['book', 'keyboard', 'mouse', 'tv']
Predicted triples:
  mouse on keyboard  (conf=0.199)
  tv on keyboard  (conf=0.195)


# Task
Implement a proximity-based graph construction method for scene graph generation within the existing GCN+YOLO framework, enhance the YOLO-VRD class name mapping, and then compare its inference performance against the full graph method on selected VRD test images. Perform a qualitative analysis by visually comparing predicted scene graphs with ground truth annotations and summarize the findings, including the effectiveness of both graph construction approaches and the overall GCN+YOLO model performance.

## Implement Proximity Graph

### Subtask:
Implement a new function `build_graph_from_dets_proximity` to construct an adjacency matrix `A` based on the proximity of detected bounding boxes, instead of creating a full graph.


**Reasoning**:
I need to implement the `build_graph_from_dets_proximity` function as described in the instructions, which involves extracting bounding box information, calculating distances, and constructing the adjacency matrix based on proximity. This code block will contain the complete implementation of the function.



In [None]:
def get_bbox_center(bbox):
    x1, y1, x2, y2 = bbox
    return (x1 + x2) / 2, (y1 + y2) / 2

def euclidean_distance(pt1, pt2):
    return math.sqrt((pt1[0] - pt2[0])**2 + (pt1[1] - pt2[1])**2)

def build_graph_from_dets_proximity(dets, max_nodes=20, threshold_dist=50):
    """
    dets: list (name, bbox, score)
    Trả về: A, X, R, nodes, pairs
    """

    # 1) Extract object names, bboxes, and map to VRD vocab
    # Store (mapped_name, bbox) for each detection
    det_info = []
    for name, bbox, _ in dets:
        mapped_name = map_det_name(name)
        det_info.append((mapped_name, bbox))

    # 2) Create a list of unique nodes (categories) present in the detections
    all_category_names = [info[0] for info in det_info if info[0] != 'unknown']
    # Only consider `unknown` if there are no other categories
    if not all_category_names and det_info:
        all_category_names = ['unknown']

    nodes = sorted(list(set(all_category_names)))[:max_nodes]

    # 3) If no nodes are found, return empty arrays
    if not nodes:
        return (
            np.zeros((0, 0), dtype=np.float32),
            np.zeros((0, len(NODES_VOCAB)), dtype=np.float32),
            np.zeros((0, 0), dtype=np.int64),
            [],
            []
        )

    N = len(nodes)

    # 4) Generate node feature matrix X
    X = np.stack([one_hot(node2id[n], len(NODES_VOCAB)) for n in nodes], axis=0)

    # 5) Initialize adjacency matrix A and relation matrix R
    A = np.zeros((N, N), dtype=np.float32)
    R = np.full((N, N), rel2id['none'], dtype=np.int64)
    pairs = []

    # 6) Iterate through all unique pairs of nodes
    for i in range(N):
        for j in range(i + 1, N):
            node_i_name = nodes[i]
            node_j_name = nodes[j]

            # Find all detections associated with nodes[i] and nodes[j]
            dets_node_i_bboxes = [info[1] for info in det_info if info[0] == node_i_name]
            dets_node_j_bboxes = [info[1] for info in det_info if info[0] == node_j_name]

            if not dets_node_i_bboxes or not dets_node_j_bboxes:
                continue # No detections for one or both nodes, skip proximity check

            min_dist = float('inf')
            for bbox_i in dets_node_i_bboxes:
                center_i = get_bbox_center(bbox_i)
                for bbox_j in dets_node_j_bboxes:
                    center_j = get_bbox_center(bbox_j)
                    dist = euclidean_distance(center_i, center_j)
                    min_dist = min(min_dist, dist)

            # 7) If minimum distance is less than threshold_dist, set A[i,j] and add to pairs
            if min_dist < threshold_dist:
                A[i, j] = 1.0
                A[j, i] = 1.0 # Adjacency is symmetric
                pairs.append((i, j))
                # R is not filled here as this is just for building the graph structure

    return A, X, R, nodes, pairs

print("Defined build_graph_from_dets_proximity function.")

## Integrate Proximity Graph into Inference

### Subtask:
Cập nhật hàm `infer_image` để có thể sử dụng ma trận kề A được xây dựng từ `build_graph_from_dets_proximity`. Điều này sẽ cho phép so sánh kết quả suy luận giữa hai phương pháp xây dựng đồ thị.


In [2]:
from PIL import Image
from pathlib import Path # Added this import

def map_det_name(name: str):
    # map tên lớp YOLO về vocab VRD; nếu không có thì 'unknown'
    return name if name in node2id else 'unknown'

def yolo_detect(image_path: Path, conf=0.25, model_name='yolov8n.pt'):
    if not HAS_YOLO:
        raise RuntimeError("YOLOv8 chưa sẵn sàng. Hãy cài ultralytics hoặc đặt HAS_YOLO=True với model đã load.")
    model = YOLO(model_name)
    res = model.predict(source=str(image_path), conf=conf, verbose=False)[0]
    dets = []  # [(name, [x1,y1,x2,y2], score)]
    for b in res.boxes:
        cls_id = int(b.cls.item())
        name = res.names.get(cls_id, str(cls_id))
        xyxy = b.xyxy[0].tolist()
        score = float(b.conf.item())
        dets.append((name, xyxy, score))
    return dets

def build_graph_from_dets(dets, max_nodes=20):
    """
    dets: list (name, bbox, score)
    Trả về: A, X, R, nodes, pairs
    """
    names = [map_det_name(n) for (n, _, _) in dets]
    # nodes = các lớp (category) xuất hiện
    nodes = sorted(set(names))[:max_nodes]
    if not nodes:
        # không có đối tượng nào
        return (
            np.zeros((0, 0), dtype=np.float32),
            np.zeros((0, len(NODES_VOCAB)), dtype=np.float32),
            np.zeros((0, 0), dtype=np.int64),
            [],
            []
        )

    N = len(nodes)
    X = np.stack([one_hot(node2id[n], len(NODES_VOCAB)) for n in nodes], axis=0)
    A = np.zeros((N, N), dtype=np.float32)
    R = np.full((N, N), rel2id['none'], dtype=np.int64)

    pairs = []
    # Ở đây cho full graph; bạn có thể thay bằng nối theo proximity nếu muốn
    for i in range(N):
        for j in range(N):
            if i != j:
                A[i, j] = A[j, i] = 1.0
                pairs.append((i, j))

    return A, X, R, nodes, pairs

def infer_image(image_path: Path,
                yolo_model='yolov8n.pt',
                conf=0.25,
                topk=10,
                threshold=0.2,
                graph_construction_method='full', # New parameter
                proximity_threshold=50):            # New parameter
    """
    Trả về: strings, triples, nodes, dets
      - strings: list chuỗi "<object> <relation> <subject>"
      - triples: list (subj, rel, obj, confidence)
      - nodes:   danh sách node (category) trong graph
      - dets:    detections gốc từ YOLO hoặc anotations VRD (fallback)
    """

    # 1) Lấy detections
    if HAS_YOLO:
        dets = yolo_detect(image_path, conf=conf, model_name=yolo_model)
    else:
        dets = []
        # fallback: dùng anotations VRD nếu trùng tên file
        for split in (train_ann, test_ann):
            for item in split:
                fn = item.get('file_name') or f"{item.get('image_id')}.jpg"
                if fn == image_path.name:
                    for r in item.get('relationships', item.get('rels', [])):
                        dets.append(
                            (safe_cat(r.get('subject', {}).get('category', 'unknown')), None, 1.0)
                        )
                        dets.append(
                            (safe_cat(r.get('object', {}).get('category', 'unknown')), None, 1.0)
                        )
                    break

    # Nếu hoàn toàn không có dets → trả về rỗng nhưng đủ 4 phần
    if len(dets) == 0:
        print(f"[WARN] Không tìm thấy detection / anotations nào cho ảnh: {image_path}")
        return [], [], [], dets

    # 2) Xây graph từ dets, chọn phương pháp xây dựng đồ thị
    if graph_construction_method == 'proximity':
        A, X, R, nodes, pairs = build_graph_from_dets_proximity(dets, threshold_dist=proximity_threshold)
    else:
        A, X, R, nodes, pairs = build_graph_from_dets(dets)

    n = A.shape[0]

    if n == 0 or len(pairs) == 0:
        print(f"[WARN] Graph sau khi build không có node/cạnh hợp lệ cho ảnh: {image_path}")
        return [], [], nodes, dets

    A_t = torch.tensor(A, dtype=torch.float32, device=device).unsqueeze(0)
    X_t = torch.tensor(X, dtype=torch.float32, device=device).unsqueeze(0)
    R_t = torch.tensor(R, dtype=torch.long,    device=device).unsqueeze(0)
    M_t = torch.ones(1, n, dtype=torch.float32, device=device)

    # 3) Chạy RelationExtractor
    re_model.eval()
    with torch.no_grad():
        H = re_model(X_t, A_t, R_t, M_t)
        pairs_t = torch.tensor(pairs, dtype=torch.long, device=device).unsqueeze(0)
        logits = re_model.edge_logits(H, pairs_t).squeeze(0)  # [P, |REL_NO_NONE|]
        probs = torch.softmax(logits, dim=-1)
        confs, preds = probs.max(dim=-1)

    # 4) Lọc theo threshold + sort
    triples = []
    for (i, j), c, p in zip(
        pairs,
        confs.cpu().numpy().tolist(),
        preds.cpu().numpy().tolist()
    ):
        rel = REL_NO_NONE[p]
        s = nodes[i]
        o = nodes[j]
        if (threshold is not None) and (c < threshold):
            continue
        triples.append((s, rel, o, float(c)))

    triples.sort(key=lambda x: -x[3])
    if topk is not None:
        triples = triples[:topk]

    strings = [f"{subj} {rel} {obj}" for (subj, rel, obj, _) in triples]
    return strings, triples, nodes, dets

## Enhance YOLO-VRD Mapping

### Subtask:
Mở rộng hàm `map_det_name` để ánh xạ các tên lớp từ YOLO (ví dụ: 'tv', 'bike') sang các lớp gần nghĩa trong VRD (ví dụ: 'television', 'bicycle').


In [3]:
YOLO_TO_VRD_MAPPING = {
    'tv': 'television',
    'bike': 'bicycle',
    'person': 'man', # Mapping 'person' to 'man' as a common equivalent in VRD, can be adjusted
    'car': 'automobile',
    'bus': 'bus',
    'truck': 'truck',
    'motorcycle': 'motorcycle',
    'airplane': 'plane',
    'boat': 'boat',
    'traffic light': 'traffic light',
    'fire hydrant': 'fire hydrant',
    'stop sign': 'stop sign',
    'parking meter': 'parking meter',
    'bench': 'bench',
    'bird': 'bird',
    'cat': 'cat',
    'dog': 'dog',
    'horse': 'horse',
    'sheep': 'sheep',
    'cow': 'cow',
    'elephant': 'elephant',
    'bear': 'bear',
    'zebra': 'zebra',
    'giraffe': 'giraffe',
    'backpack': 'backpack',
    'umbrella': 'umbrella',
    'handbag': 'handbag',
    'tie': 'tie',
    'suitcase': 'suitcase',
    'frisbee': 'frisbee',
    'skis': 'skis',
    'snowboard': 'snowboard',
    'sports ball': 'ball',
    'kite': 'kite',
    'baseball bat': 'baseball bat',
    'baseball glove': 'baseball glove',
    'skateboard': 'skateboard',
    'surfboard': 'surfboard',
    'tennis racket': 'tennis racket',
    'bottle': 'bottle',
    'wine glass': 'wine glass',
    'cup': 'cup',
    'fork': 'fork',
    'knife': 'knife',
    'spoon': 'spoon',
    'bowl': 'bowl',
    'banana': 'banana',
    'apple': 'apple',
    'sandwich': 'sandwich',
    'orange': 'orange',
    'broccoli': 'broccoli',
    'carrot': 'carrot',
    'hot dog': 'hot dog',
    'pizza': 'pizza',
    'donut': 'donut',
    'cake': 'cake',
    'chair': 'chair',
    'couch': 'couch',
    'potted plant': 'plant',
    'bed': 'bed',
    'dining table': 'table',
    'toilet': 'toilet',
    'laptop': 'laptop',
    'mouse': 'mouse',
    'remote': 'remote',
    'keyboard': 'keyboard',
    'cell phone': 'cell phone',
    'microwave': 'microwave',
    'oven': 'oven',
    'toaster': 'toaster',
    'sink': 'sink',
    'refrigerator': 'refrigerator',
    'book': 'book',
    'clock': 'clock',
    'vase': 'vase',
    'scissors': 'scissors',
    'teddy bear': 'teddy bear',
    'hair drier': 'hair drier',
    'toothbrush': 'toothbrush'
}

def map_det_name(name: str):
    # 1. Check if the name exists in YOLO_TO_VRD_MAPPING
    mapped_name = YOLO_TO_VRD_MAPPING.get(name, name)

    # 2. Then, check if this (either original or mapped) name exists in node2id
    if mapped_name in node2id:
        return mapped_name

    # 3. If the name is not found in the mapping or node2id, return 'unknown'
    return 'unknown'

print('YOLO_TO_VRD_MAPPING and updated map_det_name function defined.')

YOLO_TO_VRD_MAPPING and updated map_det_name function defined.
