# Set up

In [None]:
# Optional path overrides. Any name left as None is replaced by a default
# derived from the working directory in a later cell.
save_dir = checkpoint_dir = keyframes_dir = None


In [None]:
# Fetch the Recognize Anything sources (presumably what provides the `ram`
# package imported in later cells).
!git clone https://github.com/xinyu1205/recognize-anything.git
# Install the Python dependencies into the kernel's environment.
# NOTE(review): versions are unpinned — consider pinning for reproducibility.
%pip install timm transformers fairscale pycocoevalcap

# change the working directory to the cloned repo
# NOTE(review): running this cell a second time without restarting the kernel
# would clone and cd again from inside the repo — confirm it is meant to run once.
%cd recognize-anything

In [None]:
import os

# Anchor the default locations two directory levels above the current working
# directory.
dir_path = os.getcwd()
parent_path = os.path.dirname(os.path.dirname(dir_path))

# Keep any path that was already set; only None is replaced by a default.
save_dir = f"{parent_path}/metadata/tag_output" if save_dir is None else save_dir
checkpoint_dir = f"{parent_path}/metadata/pretrained" if checkpoint_dir is None else checkpoint_dir
keyframes_dir = f'{parent_path}/transnet/Keyframes' if keyframes_dir is None else keyframes_dir

In [None]:
import os
import glob
import json
import torch
import numpy as np
from PIL import Image
from ram.models import ram_plus
from ram import get_transform
from tqdm import tqdm
import subprocess

# Download checkpoint

In [None]:
def download_checkpoints(model, checkpoint_dir):
    """Ensure the pretrained weights for ``model`` exist in ``checkpoint_dir``.

    The weights are fetched with ``wget`` into a temporary ``.part`` file that
    is renamed to its final name only after the download succeeds. A failed or
    interrupted download therefore never leaves a truncated file that a later
    run would mistake for a complete checkpoint.

    Args:
        model: Name of the model whose weights are needed. Only ``"RAM++"`` is
            supported.
        checkpoint_dir: Directory that holds (or will hold) the weights file.
            It is created, including missing parents, when absent.

    Raises:
        ValueError: If ``model`` is not a supported model name.
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
        FileNotFoundError: If the ``wget`` executable cannot be found.
    """
    print('You selected', model)
    print(f'Checkpoint directory: {checkpoint_dir}')

    # Fail loudly instead of silently doing nothing: callers report success
    # after this function returns.
    if model != "RAM++":
        raise ValueError(f"Unsupported model {model!r}; expected 'RAM++'")

    os.makedirs(checkpoint_dir, exist_ok=True)

    ram_weights_path = os.path.join(checkpoint_dir, 'ram_plus_swin_large_14m.pth')
    if os.path.exists(ram_weights_path):
        print("RAM weights already downloaded!")
        return

    url = "https://huggingface.co/xinyu1205/recognize-anything-plus-model/resolve/main/ram_plus_swin_large_14m.pth"
    partial_path = ram_weights_path + '.part'
    try:
        subprocess.run(['wget', url, '-O', partial_path], check=True)
        # Publish the file under its final name only once it is complete.
        os.replace(partial_path, ram_weights_path)
    finally:
        # Still present only if the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"RAM weights downloaded to {ram_weights_path}")

# Make sure the RAM++ weights are available, then echo the keyframe location
# that will be parsed next.
model = "RAM++"
download_checkpoints(model=model, checkpoint_dir=checkpoint_dir)
print(f"{model} weights are downloaded!")
print(keyframes_dir)

# Parse data path

In [None]:

def parse_data_path(keyframes_dir):
    """Index every ``.jpg`` keyframe found under ``keyframes_dir``.

    The directory is walked as ``<keyframes_dir>/<part>/<video>/*.jpg``.
    Entries that are not directories at the part or video level (stray files
    such as ``.DS_Store`` or a README) are skipped instead of raising
    ``NotADirectoryError`` or producing empty pseudo-videos.

    Args:
        keyframes_dir: Root directory containing one sub-directory per data
            part (e.g. ``L01``, ``L02``).

    Returns:
        dict: ``{part: {video_id: [keyframe_path, ...]}}`` with parts, video
        ids and keyframe paths each in sorted order. A video directory without
        ``.jpg`` files maps to an empty list.
    """
    all_keyframe_paths = {}
    for data_part in sorted(os.listdir(keyframes_dir)):  # L01, L02
        data_part_path = f'{keyframes_dir}/{data_part}'
        if not os.path.isdir(data_part_path):
            continue

        video_keyframes = {}
        for video_id in sorted(os.listdir(data_part_path)):  # V001, V002, ...
            video_path = f'{data_part_path}/{video_id}'
            if not os.path.isdir(video_path):
                continue
            video_keyframes[video_id] = sorted(glob.glob(f'{video_path}/*.jpg'))
        all_keyframe_paths[data_part] = video_keyframes

    return all_keyframe_paths
    
# Build the part -> video -> keyframe-path index that the inference loop iterates over.
all_keyframe_paths = parse_data_path(keyframes_dir)

# Inference

In [None]:
import json
import torch
from PIL import Image
from ram.models import ram_plus
from ram import inference_ram as inference
from ram import get_transform
import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

# Configuration
# Image size handed to both the preprocessing transform and the model.
IMAGE_SIZE = 384
# Location of the RAM++ weights file inside the checkpoint directory.
PRETRAINED_MODEL_PATH = f'{checkpoint_dir}/ram_plus_swin_large_14m.pth'

def main():
    """Tag every indexed keyframe with RAM++ and write one JSON file per video.

    Reads the module-level ``all_keyframe_paths`` index. For each video that
    has keyframes it writes ``<save_dir>/<part>/<video_id>_tagging.json``,
    mapping each keyframe file name to its list of tags. The mapping is reset
    for every video, so a file only ever contains that video's keyframes, and
    the file is written once per video rather than once per keyframe.

    A keyframe that fails to load or tag is reported on stdout and left out of
    the output; processing continues with the next keyframe.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = get_transform(image_size=IMAGE_SIZE)

    # Load model
    model = ram_plus(pretrained=PRETRAINED_MODEL_PATH, image_size=IMAGE_SIZE, vit='swin_l')
    model.eval()
    model = model.to(device)

    feature_type = "tagging"

    # Process each keyframe
    for part in tqdm(all_keyframe_paths, desc="Processing parts"):
        for video_id in tqdm(all_keyframe_paths[part], desc=f"Processing videos in {part}", leave=False):
            keyframe_paths = all_keyframe_paths[part][video_id]
            if not keyframe_paths:
                # Nothing to tag: no output file is produced for this video.
                continue

            # Fresh mapping per video so results from other videos do not leak
            # into this file and same-named frames cannot overwrite each other.
            output_data = {}
            for image_path in tqdm(keyframe_paths, desc=f"Processing keyframes in {video_id}", leave=False):
                try:
                    # Context manager releases the file handle once the tensor is built.
                    with Image.open(image_path) as keyframe:
                        image = transform(keyframe).unsqueeze(0).to(device)
                    res = inference(image, model)
                    # assumes res[0] is a ' | '-separated tag string — TODO confirm
                    output_data[os.path.basename(image_path)] = res[0].split(' | ')
                except Exception as e:
                    print(f"Error processing {image_path}: {str(e)}")

            # Save output JSON once all keyframes of the video were attempted.
            json_path = os.path.join(
                save_dir, f'{part}', f'{video_id}_{feature_type}.json')
            os.makedirs(os.path.dirname(json_path), exist_ok=True)
            with open(json_path, 'w') as f:
                json.dump(output_data, f)

# Run the tagging pipeline when this cell is executed.
if __name__ == "__main__":
    main()

In [1]:
# remove the cloned repo
# NOTE(review): there is no `!` prefix, so this relies on IPython automagic for
# `rm`. The relative path assumes the working directory is still inside the
# cloned repo (set by the earlier `%cd recognize-anything`) — confirm before running.
rm -rf ../recognize-anything

In [2]:
# remove the pretrained weights
# NOTE(review): the default checkpoint_dir set earlier is
# f"{parent_path}/metadata/pretrained", which is not `../pretrained` relative
# to the cloned repo — confirm this path really targets the downloaded weights.
rm -rf ../pretrained