In [1]:
'''
 * The Recognize Anything Plus Model (RAM++) inference on unseen classes
 * Written by Xinyu Huang
'''
import argparse
import numpy as np
import random
from torchvision.transforms import Normalize, Compose, Resize, ToTensor

import torch

from PIL import Image
from ram.models import ram_plus
from ram import inference_ram_openset as inference
from ram import get_transform

from ram.utils import build_openset_llm_label_embedding
from torch import nn
import json

# parser = argparse.ArgumentParser(
#     description='Tag2Text inferece for tagging and captioning')
# parser.add_argument('--image',
#                     metavar='DIR',
#                     help='path to dataset',
#                     default='images/openset_example.jpg')
# parser.add_argument('--pretrained',
#                     metavar='DIR',
#                     help='path to pretrained model',
#                     default='pretrained/ram_plus_swin_large_14m.pth')
# parser.add_argument('--image-size',
#                     default=384,
#                     type=int,
#                     metavar='N',
#                     help='input image size (default: 448)')
# parser.add_argument('--llm_tag_des',
#                     metavar='DIR',
#                     help='path to LLM tag descriptions',
#                     default='datasets/openimages_rare_200/openimages_rare_200_llm_tag_descriptions.json')

# args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"device using: {device}")


def convert_to_rgb(image):
    return image.convert("RGB")


# transform = get_transform(image_size=384)
transform = Compose([
        convert_to_rgb,
        Resize((2160, 3840)),
        ToTensor(),
    ])

transform2 = Compose([
        Resize((384, 384)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

#######load model
model = ram_plus(pretrained='pretrained/ram_plus_swin_large_14m.pth',
                         image_size=384,
                         vit='swin_l')

#######set openset interference

print('Building tag embedding:')
with open('datasets/openimages_rare_200/openimages_rare_200_llm_tag_descriptions.json', 'rb') as fo:
    llm_tag_des = json.load(fo)
openset_label_embedding, openset_categories = build_openset_llm_label_embedding(llm_tag_des)

model.tag_list = np.array(openset_categories)

model.label_embed = nn.Parameter(openset_label_embedding.float())

model.num_class = len(openset_categories)
# the threshold for unseen categories is often lower
model.class_threshold = torch.ones(model.num_class) * 0.5
#######

model.eval()

model = model.to(device)




device using: cuda


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


--------------
pretrained/ram_plus_swin_large_14m.pth
--------------
load checkpoint from pretrained/ram_plus_swin_large_14m.pth
vit: swin_l
Building tag embedding:
Creating pretrained CLIP model


100%|██████████| 14/14 [00:00<00:00, 28.28it/s]


In [2]:
image = transform(Image.open('../2ndFloorData/images/L2ndFloor-D-2021-12-22_T-22_00_01.jpg')).unsqueeze(0).to(device)

res = inference(transform2(image), model)
print("Image Tags: ", res)

Image Tags:  Traffic cone


In [3]:
transform2(image).shape

torch.Size([1, 3, 384, 384])

In [None]:
import torch

def Slicing(img, window_height, window_width, stride):
    batch_size, num_channels, img_height, img_width = img.shape
    print(f"Image shape: {img.shape}")
    
    slices = []
    for i in range(0, img_height - window_height + 1, int(stride * window_height)):
        for j in range(0, img_width - window_width + 1, int(stride * window_width)):
            img_segment = img[:, :, i:i+window_height, j:j+window_width]
            slices.append(img_segment)
    
    return torch.cat(slices, dim=0)

sliced = Slicing(image, 512, 512, stride=0.5)
sliced.shape

Image shape: torch.Size([1, 3, 2160, 3840])


torch.Size([98, 3, 512, 512])

In [5]:
transform2(sliced).shape

torch.Size([98, 3, 384, 384])

In [6]:
with torch.no_grad():
    tags = model.generate_tag_openset_prob(transform2(sliced))

In [7]:
len(tags)

98

In [9]:
def get_max_probabilities(list_of_dicts):
    max_probs = {}

    for d in list_of_dicts:
        for tag, prob in d.items():
            if tag not in max_probs:
                max_probs[tag] = prob
            else:
                max_probs[tag] = max(max_probs[tag], prob)

    return max_probs

In [11]:
max_probs = get_max_probabilities(tags)
max_probs

{'Barricade': 0.8060165643692017,
 'Traffic cone': 0.9115189909934998,
 'Traffic barrel': 0.7850309014320374,
 'Scaffold': 0.5142012238502502,
 'Trailer truck': 0.4665645658969879,
 'Police car': 0.8906596899032593,
 'Ambulance': 0.748346209526062,
 'Firecar': 0.6207081079483032,
 'Excavator': 0.3295373320579529,
 'Construction truck': 0.5063043832778931,
 'Car moving truck': 0.5681664347648621,
 'Offroad parking': 0.4186611473560333,
 'Construction car': 0.5562924742698669,
 'Construction worker': 0.690403163433075}

In [None]:
import glob
import os
from tqdm import tqdm

image_dir = '../2ndFloorData/images/'
image_paths = glob.glob(os.path.join(image_dir, '*.*'))

image_paths = [p for p in image_paths if p.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff'))]

results = {}
for img_path in tqdm(image_paths):
    file_name = os.path.basename(img_path) 
    image = transform(Image.open(img_path)).unsqueeze(0).to(device)
    sliced = Slicing(image, 512, 512, stride=0.5)
    with torch.no_grad():
        tags = model.generate_tag_openset_prob(transform2(sliced))
    max_probs = get_max_probabilities(tags)
    results[file_name] = max_probs

In [21]:
results

{'L2ndFloor-D-2022-01-28_T-00_30_01.jpg': {'Barricade': 0.8109170198440552,
  'Traffic cone': 0.9033704996109009,
  'Traffic barrel': 0.7182284593582153,
  'Scaffold': 0.5912752151489258,
  'Trailer truck': 0.5170974135398865,
  'Police car': 0.8763160109519958,
  'Ambulance': 0.6897092461585999,
  'Firecar': 0.541136622428894,
  'Excavator': 0.3267523944377899,
  'Construction truck': 0.5438394546508789,
  'Car moving truck': 0.5371878147125244,
  'Offroad parking': 0.43454355001449585,
  'Construction car': 0.5737396478652954,
  'Construction worker': 0.6603937745094299},
 'L2ndFloor-D-2022-02-27_T-03_30_01.jpg': {'Barricade': 0.7447649836540222,
  'Traffic cone': 0.749268651008606,
  'Traffic barrel': 0.6054819822311401,
  'Scaffold': 0.4709216356277466,
  'Trailer truck': 0.8329820036888123,
  'Police car': 0.8434134721755981,
  'Ambulance': 0.8311448097229004,
  'Firecar': 0.6064722537994385,
  'Excavator': 0.32304590940475464,
  'Construction truck': 0.8629425168037415,
  'Car mo

In [26]:
import pickle

# Save
with open('../out_ram.pkl', 'wb') as f:
    pickle.dump(results, f)