# Installation
Install RelTR and import necessary packages.

In [1]:
!git clone https://github.com/yrcong/RelTR.git
%cd RelTR/
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from PIL import Image
import requests
import matplotlib.pyplot as plt

Cloning into 'RelTR'...
remote: Enumerating objects: 317, done.[K
remote: Counting objects: 100% (317/317), done.[K
remote: Compressing objects: 100% (175/175), done.[K
remote: Total 317 (delta 140), reused 282 (delta 120), pack-reused 0[K
Receiving objects: 100% (317/317), 27.49 MiB | 55.09 MiB/s, done.
Resolving deltas: 100% (140/140), done.
/kaggle/working/RelTR


# VG labels
VG 150 enitiy classes and 50 relationship classes.

In [3]:
CLASSES = [ 'N/A', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike',
                'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building',
                'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup',
                'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence',
                'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy',
                'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean',
                'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men',
                'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw',
                'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post',
                'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt',
                'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow',
                'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel',
                'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle',
                'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']

REL_CLASSES = ['__background__', 'above', 'across', 'against', 'along', 'and', 'at', 'attached to', 'behind',
                'belonging to', 'between', 'carrying', 'covered in', 'covering', 'eating', 'flying in', 'for',
                'from', 'growing on', 'hanging from', 'has', 'holding', 'in', 'in front of', 'laying on',
                'looking at', 'lying on', 'made of', 'mounted on', 'near', 'of', 'on', 'on back of', 'over',
                'painted on', 'parked on', 'part of', 'playing', 'riding', 'says', 'sitting on', 'standing on',
                'to', 'under', 'using', 'walking in', 'walking on', 'watching', 'wearing', 'wears', 'with']


# Build and load the pretrained model

In [8]:
from models.backbone import Backbone, Joiner
from models.position_encoding import PositionEmbeddingSine
from models.transformer import Transformer
from models.reltr import RelTR

position_embedding = PositionEmbeddingSine(128, normalize=True)
backbone = Backbone('resnet50', False, False, False)
backbone = Joiner(backbone, position_embedding)
backbone.num_channels = 2048

transformer = Transformer(d_model=256, dropout=0.1, nhead=8,
                          dim_feedforward=2048,
                          num_encoder_layers=6,
                          num_decoder_layers=6,
                          normalize_before=False,
                          return_intermediate_dec=True)

model = RelTR(backbone, transformer, num_classes=151, num_rel_classes = 51,
              num_entities=100, num_triplets=200)

# The checkpoint is pretrained on Visual Genome
# ckpt = torch.hub.load_state_dict_from_url(
#     url='https://cloud.tnt.uni-hannover.de/index.php/s/PB8xTKspKZF7fyK/download/checkpoint0149.pth',
#     map_location='cpu', check_hash=True)
ckpt = torch.load("/kaggle/input/reitr/other/default/1/checkpoint0149.pth") #,map_location='cpu')
model.load_state_dict(ckpt['model'])
model.eval()

RelTR(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
          (self_attn_entity): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features

In [9]:
# Some transformation functions
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
          (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

# Load Image
You can replace the link with other images. Note that the entities in the used image should be included in the VG labels.

In [None]:
im = Image.open('/kaggle/input/data-hcm-ai-challenge-2023-batch-1/Data/Keyframes_L01/L01_V001/000470.jpg')
plt.imshow(im)
plt.axis('off')  # Hide axes
plt.show()

In [None]:
#url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Yellow_Happy.jpg/1200px-Yellow_Happy.jpg'

# Apply the transformations
img = transform(im).unsqueeze(0)

# Print the shape of the transformed image tensor
print(img.shape)

# Inference

In [13]:
import os
import glob
import torch
from PIL import Image
from tqdm import tqdm
# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Assuming other necessary imports and initializations are here

# Paths
input_images_path = "/kaggle/input/data-hcm-ai-challenge-2023-batch-1/Data/Keyframes_L01/L01_V003"  # Directory containing input images
output_directory = "/kaggle/working/output_file_test2"  # Directory to save output text files

# Ensure the output directory exists
os.makedirs(output_directory, exist_ok=True)

# Function to process a single image
def process_image(img_path, model):
    im = Image.open(img_path).convert("RGB")
    img = transform(im).unsqueeze(0)#.to(device)  # Adjust based on your model's input requirements
    texts = []
    with torch.no_grad():
        # propagate through the model
        
        outputs = model(img)

        # keep only predictions with >0.3 confidence
        probas = outputs['rel_logits'].softmax(-1)[0, :, :-1]
        probas_sub = outputs['sub_logits'].softmax(-1)[0, :, :-1]
        probas_obj = outputs['obj_logits'].softmax(-1)[0, :, :-1]
        keep = torch.logical_and(probas.max(-1).values > 0.3, 
                                 torch.logical_and(probas_sub.max(-1).values > 0.3,
                                                   probas_obj.max(-1).values > 0.3))

        # convert boxes from [0; 1] to image scales
        sub_bboxes_scaled = rescale_bboxes(outputs['sub_boxes'][0, keep], im.size)
        obj_bboxes_scaled = rescale_bboxes(outputs['obj_boxes'][0, keep], im.size)

        topk = 10  # display up to 10 images
        keep_queries = torch.nonzero(keep, as_tuple=True)[0]
        indices = torch.argsort(-probas[keep_queries].max(-1)[0] * 
                                probas_sub[keep_queries].max(-1)[0] * 
                                probas_obj[keep_queries].max(-1)[0])[:topk]
        keep_queries = keep_queries[indices]

        # save the attention weights
        conv_features, dec_attn_weights_sub, dec_attn_weights_obj = [], [], []
        hooks = [
            model.backbone[-2].register_forward_hook(
                lambda self, input, output: conv_features.append(output)
            ),
            model.transformer.decoder.layers[-1].cross_attn_sub.register_forward_hook(
                lambda self, input, output: dec_attn_weights_sub.append(output[1])
            ),
            model.transformer.decoder.layers[-1].cross_attn_obj.register_forward_hook(
                lambda self, input, output: dec_attn_weights_obj.append(output[1])
            )
        ]

        # Process features and attention weights
        outputs = model(img)
        for hook in hooks:
            hook.remove()

        conv_features = conv_features[0]
        dec_attn_weights_sub = dec_attn_weights_sub[0]
        dec_attn_weights_obj = dec_attn_weights_obj[0]

        # get the feature map shape
        h, w = conv_features['0'].tensors.shape[-2:]

        # Open the file in write mode
        #with open(output_file_path, 'w') as file:
        for idx, (sxmin, symin, sxmax, symax), (oxmin, oymin, oxmax, oymax) in \
                zip(keep_queries, sub_bboxes_scaled[indices], obj_bboxes_scaled[indices]):

            texts.append(title_text = (CLASSES[probas_sub[idx].argmax()] + ' ' +
                          REL_CLASSES[probas[idx].argmax()] + ' ' +
                          CLASSES[probas_obj[idx].argmax()]))

            # Write the text to the file

            #file.write(f"Title: {title_text}\n\n")

                # Print the text to the console
#                 print(f"Image: {img_path}")
#                 print(f"Title: {title_text}\n")
    merged_text = "\n".join(texts)
    return merged_text
        
        
for keyframe in tqdm(os.listdir(paths)):
  path_keyframe = os.path.join(paths,keyframe)
  video_paths = sorted(glob.glob(f"{path_keyframe}/*/"))
  video_paths = ['/'.join(i.split('/')[:-1]) for i in video_paths]

  start_time = time.time()
  for vd_path in video_paths:

    re_feats = []
    keyframe_paths = glob.glob(f'{vd_path}/*.jpg')
    keyframe_paths = sorted(keyframe_paths, key=lambda x : x.split('/')[-1].replace('.jpg',''))

    for keyframe_path in tqdm(keyframe_paths):


      #text = ocr_image(keyframe_path)

      #//////////////////////////////////
      text = process_image(keyframe_path, model)
      #//////////////////////////////////
      #if detect(text) == 'vi' :
      text = Translation(text)

      # Convert text to embedding vector
      #embedding = embedding_model(**tokenizer(text, return_tensors="pt", padding=True, truncation=True,max_length=512, add_special_tokens = True)).pooler_output.detach().numpy()
      embeddings = embedding_model.encode(text)
      # Append embedding to re_feats list
      re_feats.append(embeddings)

    name_npy = vd_path.split('/')[-1]

    # Construct output file path
    outfile = os.path.join(des_path, f'{name_npy}.npy')

    # Ensure the directory exists before saving
    os.makedirs(des_path, exist_ok=True)
    np.save(outfile, re_feats)

    print(f"Processed {vd_path} in {time.time() - start_time} seconds")


  2%|▏         | 20/1203 [01:24<1:23:35,  4.24s/it]


KeyboardInterrupt: 

In [None]:
feature_shape = 512


def write_bin_file_ocr(bin_path: str, npy_path: str, method='cosine'):
    if method in 'L2':
      index = faiss.IndexFlatL2(feature_shape)
    elif method in 'cosine':
      index = faiss.IndexFlatIP(feature_shape)
    else:
      assert f"{method} not supported"
    npy_files = glob.glob(os.path.join(npy_path, "*.npy"))
    npy_files_sorted = sorted(npy_files)

    for npy_file in npy_files_sorted:
        feats = np.load(npy_file)
        print(f"Loaded {npy_file}, shape: {feats.shape}")


        # Convert to float32 and reshape to match feature_shape
        feats = feats.astype(np.float32)
        feats = feats.reshape(-1, feats.shape[-1])

        # Resize or trim feats_normalized to match feature_shape if necessary
        if feats.shape[1] != feature_shape:
            feats = feats[:, :feature_shape]

        assert feats.shape[1] == feature_shape, \
            f"Query features dimension {feats.shape[1]} do not match index dimension {feature_shape}"

        # Add to Faiss index
        index.add(feats)

    # Write the Faiss index to disk
    faiss.write_index(index, os.path.join(bin_path, f"faiss_ReITR_{method}.bin"))
    print(f'Saved {os.path.join(bin_path, f"faiss_ReITR_{method}.bin")}')


# write ocr
write_bin_file_ocr(bin_path=f"{WORK_DIR}/data/dicts/bin_ReITR", npy_path=f"{WORK_DIR}/data/dicts/npy_ReITR")

