In [None]:
import os
import json
from PIL import Image
from orient_anything import get_3angle, render_3D_axis, overlay_images_with_scaling
from transformers import AutoImageProcessor
from rewards.orient import OrientLoss
from torchvision.transforms import ToTensor, ToPILImage
import torch
import numpy as np
device = torch.device("cuda:1")

orientation = [0, 90, 90]
reward_loss = OrientLoss(1.0, torch.float16, device, '/root/.cache/huggingface/hub', False)
results = {}
val_preprocess = AutoImageProcessor.from_pretrained("facebook/dinov2-large", cache_dir='./')

for img_path in sorted([path for path in os.listdir("./results/orient") if path.startswith("a photo of a car_(90,90,90)")]):
    image = Image.open(os.path.join("./results/orient", img_path, "init.png"))
    angles = get_3angle(image, reward_loss.orient_estimator, val_preprocess, device)
    key = img_path.split("_")[-1].zfill(2)  # JSON의 key 값으로 사용
    results[key] = angles[0].item()  # value 값으로 angles[0] 저장

# JSON 파일 저장
with open("./results/orient/azimuth_per_seed.json", "w") as f:
    json.dump(results, f, indent=4, sort_keys=True)


In [None]:
azimuths = results.values()
distribution = [[], [], [], []]
for angle in azimuths:
    if angle >= 0 and angle < 90:
        distribution[0].append(angle)
    elif angle >= 90 and angle < 180:
        distribution[1].append(angle)
    elif angle >= 180 and angle < 270:
        distribution[2].append(angle)
    elif angle >= 270 and angle < 360:
        distribution[3].append(angle)

print(len(distribution[0]))
print(len(distribution[1]))
print(len(distribution[2]))
print(len(distribution[3]))

In [None]:
print(distribution[1])

In [None]:
import os
import re
import matplotlib.pyplot as plt
from PIL import Image

def sort_filenames(filenames):
    """ 파일 리스트를 숫자 인덱스를 기준으로 정렬하는 함수 """
    def extract_third_number(filename):
        match = re.search(r'prompt_\d+_orientation_\d+_(\d+)', filename)
        return int(match.group(1)) if match else float('inf')  # 정수 변환하여 정렬
    return sorted(filenames, key=extract_third_number)

def visualize_orientations(image_folder):
    """ 주어진 폴더에서 이미지를 정렬하여 10x10 그리드로 시각화하는 함수 """
    # 폴더 내 파일 목록 가져오기
    filenames = [f for f in os.listdir(image_folder) if f.endswith("orientation.png")]
    
    # 파일 정렬
    sorted_filenames = sort_filenames(filenames)
    
    # 10x10 그리드 생성
    #fig, axes = plt.subplots(10, 10, figsize=(15, 15), dpi=100)
    fig, axes = plt.subplots(5, 10, figsize=(15, 8), dpi=1000)
    #fig, axes = plt.subplots(2, 10, figsize=(15, 3), dpi=1000)
    
    for i, ax in enumerate(axes.flat):
        if i < len(sorted_filenames):
            img_path = os.path.join(image_folder, sorted_filenames[i])
            img = Image.open(img_path)  # Pillow를 사용하여 이미지 로드
            
            ax.imshow(img)
            ax.set_title(f"Iteration {i}", fontsize=2)
            ax.axis("off")
        else:
            ax.axis("off")  # 이미지가 부족한 경우 빈 칸 처리
    
    plt.tight_layout()
    #plt.show()
    fig.savefig(os.path.join(image_folder, "total_distributions.png"))

image_folder = "/root/code/ReNO/results/var1/sd-turbo/reg_True_lr_3.0_seed_0_noise_optimize_False_noises_0"  # 이미지가 저장된 폴더 경로 지정
visualize_orientations(image_folder)


In [None]:
def visualize_images(image_folder):
    """ 주어진 폴더에서 이미지를 정렬하여 10x10 그리드로 시각화하는 함수 """
    # 폴더 내 파일 목록 가져오기
    filenames = [f for f in os.listdir(image_folder) if f.endswith(".png") and not f.endswith("orientation.png") and not f.endswith("init.png") and not f.endswith("result.png")]
    
    # 파일 정렬
    sorted_filenames = sort_filenames(filenames)
    
    # 10x10 그리드 생성
    #fig, axes = plt.subplots(10, 10, figsize=(15, 15))
    fig, axes = plt.subplots(5, 10, figsize=(15, 8), dpi=1000)
    #fig, axes = plt.subplots(2, 10, figsize=(15, 3), dpi=1000)
    
    for i, ax in enumerate(axes.flat):
        if i < len(sorted_filenames):
            img_path = os.path.join(image_folder, sorted_filenames[i])
            img = Image.open(img_path)  # Pillow를 사용하여 이미지 로드
            
            ax.imshow(img)
            ax.set_title(f"Iteration {i}", fontsize=8)
            ax.axis("off")
        else:
            ax.axis("off")  # 이미지가 부족한 경우 빈 칸 처리
    
    plt.tight_layout()
    #plt.show()
    fig.savefig(os.path.join(image_folder, "total_images.png"))
# 예제 사용법
image_folder = "/root/code/ReNO/results/var1/sd-turbo/reg_True_lr_3.0_seed_0_noise_optimize_False_noises_0"  # 이미지가 저장된 폴더 경로 지정
visualize_images(image_folder)

## Inference Orient-Anything

In [12]:
import argparse
from vision_tower import DINOv2_MLP
from transformers import AutoImageProcessor
import torch
from PIL import Image
import torch.nn.functional as F
from utils import *
from inference import *
import os
from huggingface_hub import hf_hub_download

# Argument parser for source and save directories
# Download the model checkpoint
ckpt_path = hf_hub_download(repo_id="Viglong/Orient-Anything", filename="croplargeEX2/dino_weight.pt", repo_type="model", cache_dir='/root/data/model', resume_download=True)
print(ckpt_path)

# Setup device and model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dino = DINOv2_MLP(
    dino_mode='large',
    in_dim=1024,
    out_dim=360+180+180+2,
    evaluate=True,
    mask_dino=False,
    frozen_back=False
)

dino.eval()
print('Model created')
dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
dino = dino.to(dtype=torch.float32, device=device)
print('Weights loaded')
val_preprocess = AutoImageProcessor.from_pretrained("facebook/dinov2-large", cache_dir='/root/data/model')

/root/data/model/models--Viglong--Orient-Anything/snapshots/5249ecae5cf2b8371874a88e9ab766ce81760242/croplargeEX2/dino_weight.pt
large
Model created


  dino.load_state_dict(torch.load(ckpt_path, map_location='cpu'))


Weights loaded


In [16]:
image_folder = "/root/code/ReNO/results/resampled/sd-turbo/reg_True_lr_3.0_seed_0_noise_optimize_False_noises_0"
output_file = "orientations_prompt_0_orientation_4.txt"

import re

def sort_filenames(filenames):
    """ 파일 리스트를 숫자 인덱스를 기준으로 정렬하는 함수 """
    def extract_third_number(filename):
        match = re.search(r'prompt_\d+_orientation_\d+_(\d+)', filename)
        return int(match.group(1)) if match else float('inf')  # 정수 변환하여 정렬
    return sorted(filenames, key=extract_third_number)


filenames = [f for f in os.listdir(image_folder) if f.endswith(".png") and not f.endswith("orientation.png") and not f.endswith("init.png") and not f.endswith("result.png")]
sorted_filenames = sort_filenames(filenames)
with open(os.path.join(image_folder, output_file), "w") as file:
    for image_path in sorted_filenames:
        image = Image.open(os.path.join(image_folder, image_path)).convert('RGB')

        angles = get_3angle(image, dino, val_preprocess, device)
        azimuth = float(np.radians(angles[0]))
        polar = float(np.radians(angles[1]))
        rotation = float(angles[2])
        confidence = float(angles[3])
        if image_path.startswith("prompt_0_orientation_4"):
            result_line = f"filename: {image_path}, azimuth: {angles[0]}\n"
            file.write(result_line)
            print(result_line.strip())  # 화면에도 출력

    #render_axis = render_3D_axis(azimuth, polar, rotation)
    #res_img = overlay_images_with_scaling(render_axis, image)
    
    #output_file = image_path[:-4] + "_rendered"
    #res_img.save()'rendered_png')

  azimuth = float(np.radians(angles[0]))
  polar = float(np.radians(angles[1]))


filename: prompt_0_orientation_4_0.png, azimuth: 40.0
filename: prompt_0_orientation_4_1.png, azimuth: 36.0
filename: prompt_0_orientation_4_2.png, azimuth: 30.0
filename: prompt_0_orientation_4_3.png, azimuth: 319.0
filename: prompt_0_orientation_4_4.png, azimuth: 315.0
filename: prompt_0_orientation_4_5.png, azimuth: 315.0
filename: prompt_0_orientation_4_6.png, azimuth: 315.0
filename: prompt_0_orientation_4_7.png, azimuth: 310.0
filename: prompt_0_orientation_4_8.png, azimuth: 311.0
filename: prompt_0_orientation_4_9.png, azimuth: 315.0
filename: prompt_0_orientation_4_10.png, azimuth: 311.0
filename: prompt_0_orientation_4_11.png, azimuth: 311.0
filename: prompt_0_orientation_4_12.png, azimuth: 315.0
filename: prompt_0_orientation_4_13.png, azimuth: 317.0
filename: prompt_0_orientation_4_14.png, azimuth: 311.0
filename: prompt_0_orientation_4_15.png, azimuth: 310.0
filename: prompt_0_orientation_4_16.png, azimuth: 315.0
filename: prompt_0_orientation_4_17.png, azimuth: 296.0
filen

In [6]:
val_preprocess

BitImageProcessor {
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "BitImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 256
  }
}

In [7]:
import rembg
import torch
import numpy as np 
from PIL import Image

def remove_background(image: Image,
    rembg_session=None,
    force=None,
    **rembg_kwargs,
) -> Image:

    rembg_session = rembg.new_session()
    removed_image = rembg.remove(image, session=rembg_session)

    alpha = torch.from_numpy(np.array(removed_image))[..., 3] > 0
    nonzero_coords = torch.nonzero(alpha, as_tuple=True)

    # Use bounding box and expand it
    y_min, x_min = nonzero_coords[0].min().item(), nonzero_coords[1].min().item()
    y_max, x_max = nonzero_coords[0].max().item(), nonzero_coords[1].max().item()
    
    return [x_min, x_max, y_min, y_max]

In [8]:
for idx in range(10):
    image = Image.open(f"/root/code/ReNO/results/masking_test/no_mask/sd-turbo/reg_True_lr_3.0_seed_{idx}_noise_optimize_False_noises_0/prompt_0_orientation_0_init.png")
    print(remove_background(image))

NameError: name 'torch' is not defined