In [1]:
from tqdm import tqdm
import torch
import cv2
import numpy as np
import math
import matplotlib.pyplot as plt
# Directory holding the input cube map (read in the next cell).
OUTPUT_DIR = "./output/"
# Torch device for the coordinate computations; the conversion cell passes this
# name explicitly, so it must exist on a fresh kernel. Falls back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the cube map (expected in the 4x3 cross layout used by
# cube_to_equirectangular_torch). OpenCV returns a BGR array.
CUBE_MAP_PATH = f"{OUTPUT_DIR}cube_map.png"
cube_img = cv2.imread(CUBE_MAP_PATH)
# cv2.imread returns None instead of raising when the file is missing/unreadable.
if cube_img is None:
    raise FileNotFoundError(f"could not read cube map: {CUBE_MAP_PATH}")

In [None]:
def create_3dmap_from_size_torch(img_w, img_h, device):
    """Return unit direction vectors for every pixel of an equirectangular image.

    Longitude ``theta`` covers [-pi, pi] across the columns and latitude ``phi``
    covers [-pi/2, pi/2] down the rows; both are sampled at pixel centres.

    Args:
        img_w: image width in pixels (number of longitude samples).
        img_h: image height in pixels (number of latitude samples).
        device: torch device on which the tensors are created.

    Returns:
        Tuple ``(x, y, z)`` of float tensors, each of shape ``(img_h, img_w)``, with
        ``x = cos(phi) cos(theta)``, ``y = cos(phi) sin(theta)``, ``z = sin(phi)``.
    """
    # torch.linspace always includes both endpoints (no ``endpoint=False``), so
    # the endpoints are moved half a pixel inward to land on pixel centres.
    # Adding a half-step offset after an endpoint-inclusive linspace would
    # overshoot pi/2 and pi on the last sample.
    half_h = (np.pi / 2) / img_h
    half_w = np.pi / img_w
    h = torch.linspace(-np.pi / 2 + half_h, np.pi / 2 - half_h, img_h, device=device)
    w = torch.linspace(-np.pi + half_w, np.pi - half_w, img_w, device=device)

    # indexing="xy" matches np.meshgrid and yields image-shaped (img_h, img_w)
    # grids; torch's default "ij" indexing would return them transposed.
    theta, phi = torch.meshgrid(w, h, indexing="xy")

    # Spherical -> Cartesian unit vectors.
    x = torch.cos(phi) * torch.cos(theta)
    y = torch.cos(phi) * torch.sin(theta)
    z = torch.sin(phi)

    return x, y, z

In [2]:
def padding_cube(img, device):
    """Pad a cross-layout cube map with a 2-pixel border copied from adjacent faces.

    The face grid is assumed to match cube_to_equirectangular_torch (column, row
    in face units): left (0, 1), front (1, 1), right (2, 1), back (3, 1),
    up (1, 0), bottom (1, 2). Presumably the border exists so that bilinear
    sampling near a face edge blends with the neighbouring face, not with black.

    Args:
        img: cube map array of shape (H, W, C); the face size is ``W // 4``.
        device: torch device used for the intermediate tensor work.

    Returns:
        NumPy array of shape (H + 4, W + 4, C) with ``img``'s dtype; the input
        image occupies ``[2:-2, 2:-2]``.
    """
    img_tensor = torch.tensor(img, device=device)

    h, w, c = img_tensor.shape
    cw = w // 4  # face size in pixels

    # Canvas = input image surrounded by a 2-pixel zero border.
    canvas = torch.zeros((h+4, w+4, c), dtype=img_tensor.dtype, device=device)
    canvas[2:-2, 2:-2,:] = img_tensor

    # Border above the up face <- top 2 rows of the back face, rotated 180 degrees.
    canvas[0:2, cw+2:2*cw+2,:] = torch.rot90(img_tensor[cw:cw+2, 3*cw:,:], 2, [0,1])
    # Border below the bottom face <- bottom 2 rows of the back face, rotated 180 degrees.
    canvas[-2:, cw+2:2*cw+2,:] = torch.rot90(img_tensor[2*cw-2:2*cw, 3*cw:,:], 2, [0,1])
    # Border left of the left face <- last 2 columns of the back face (horizontal wrap).
    canvas[cw+2:2*cw+2, 0:2,:] = img_tensor[cw:2*cw, -2:,:]
    # Border right of the back face <- first 2 columns of the left face (horizontal wrap).
    canvas[cw+2:2*cw+2, -2:,:] = img_tensor[cw:2*cw, 0:2,:]

    # Seams between the up/bottom faces and the left/right faces: each border
    # strip is a 90-degree rotation of the matching 2-pixel edge of the other
    # face. Statement order matters: these copies read border pixels written by
    # the statements above (and by each other).
    # left face top border <-> up face left edge
    canvas[cw:cw+2, :cw+2,:] = torch.rot90(canvas[:cw+2, cw+2:cw+4,:], 1, [0,1])
    canvas[:cw+2, cw:cw+2,:] = torch.rot90(canvas[cw+2:cw+4, :cw+2,:], 3, [0,1])
    # left face bottom border <-> bottom face left edge
    canvas[2*cw+2:2*cw+4, :cw+2,:] = torch.rot90(canvas[2*cw+2:, cw+2:cw+4,:], 3, [0,1])
    canvas[2*cw+2:, cw:cw+2,:] = torch.rot90(canvas[2*cw:2*cw+2, :cw+2,:], 1, [0,1])
    # right face top border <-> up face right edge
    canvas[cw:cw+2, 2*cw+2:3*cw+4,:] = torch.rot90(canvas[:cw+2, 2*cw:2*cw+2,:], 3, [0,1])
    canvas[:cw+2, 2*cw+2:2*cw+4,:] = torch.rot90(canvas[cw+2:cw+4, 2*cw+2:3*cw+4,:], 1, [0,1])
    # right face bottom border <-> bottom face right edge
    canvas[2*cw+2:2*cw+4, 2*cw+2:3*cw+2,:] = torch.rot90(canvas[2*cw+2:-2, 2*cw:2*cw+2,:], 1, [0,1])
    canvas[2*cw+2:, 2*cw+2:2*cw+4,:] = torch.rot90(canvas[2*cw:2*cw+2, 2*cw+2:3*cw+4,:], 3, [0,1])

    # TODO: the top and bottom borders of the back face (rows cw:cw+2 and
    # 2*cw+2:2*cw+4 from column 3*cw+2 onward) are left (mostly) zero, so samples
    # right at the back face's top/bottom edges still blend with black.

    return canvas.cpu().numpy()

In [3]:
def cube_to_equirectangular_torch(img, width, device):
    """Resample a 4x3 cross-layout cube map into an equirectangular panorama.

    Face grid (column, row, in face units): left (0, 1), front (1, 1),
    right (2, 1), back (3, 1), up (1, 0), bottom (1, 2).

    Args:
        img: cube map array of shape (3*f, 4*f, C), where the face size ``f`` is
            taken as ``img.shape[1] // 4``.
        width: output panorama width in pixels; the height is ``width // 2``.
        device: torch device used to compute the sampling coordinates.

    Returns:
        float32 array produced by ``cv2.remap`` with bilinear interpolation, of
        shape (width // 2, width, C) assuming create_3dmap_from_size_torch
        returns (img_h, img_w) grids.
    """
    out_w = width
    out_h = width // 2
    face = img.shape[1] // 4  # face size in pixels
    pad = 2  # border width added on every side by padding_cube

    x, y, z = create_3dmap_from_size_torch(out_w, out_h, device)

    half = 0.5
    # (u, v, "ray points toward this face", face column, face row). u/v are the
    # in-face coordinates of the ray's intersection with that face. Division by
    # zero yields inf/nan, which fails the range test below. Later entries
    # overwrite earlier ones where they overlap.
    faces = (
        (half * y / x + half, half * z / x + half, x > 0, 1, 1),     # front
        (half * y / x + half, -half * z / x + half, x < 0, 3, 1),    # back
        (-half * x / y + half, half * z / y + half, y > 0, 2, 1),    # right
        (-half * x / y + half, -half * z / y + half, y < 0, 0, 1),   # left
        (-half * y / z + half, -half * x / z + half, z < 0, 1, 0),   # up
        (half * y / z + half, -half * x / z + half, z > 0, 1, 2),    # bottom
    )

    # Pixels that hit no face keep coordinate 0, i.e. the padded canvas' corner.
    map_x = torch.zeros_like(x)
    map_y = torch.zeros_like(x)
    for u, v, facing, col, row in faces:
        # Inclusive bounds so rays exactly on a face edge are not left unmapped.
        hit = facing & (u >= 0) & (u <= 1) & (v >= 0) & (v <= 1)
        # Offset into the padded canvas; -0.5 because cv2.remap addresses pixel
        # centres at integer coordinates.
        map_x = torch.where(hit, u * face + col * face + pad - 0.5, map_x)
        map_y = torch.where(hit, v * face + row * face + pad - 0.5, map_y)

    # cv2.remap expects an (H, W, C) array, so the padded cube stays channel-last.
    cube = padding_cube(img, device).astype(np.float32)

    output = cv2.remap(
        cube,
        map_x.cpu().numpy().astype(np.float32),
        map_y.cpu().numpy().astype(np.float32),
        interpolation=cv2.INTER_LINEAR,
    )

    return output

numpy.ndarray

In [4]:
# Render a 3840-px-wide panorama (height = width // 2 = 1920 px) from the cube map.
immm = cube_to_equirectangular_torch(
    cube_img,
    width=2 * 1920,
    device=device,
)

In [5]:
# JPEG is 8-bit: convert the float32 remap result explicitly (round + saturate)
# rather than relying on imwrite's implicit conversion.
ok = cv2.imwrite("eq2.jpg", np.clip(np.rint(immm), 0, 255).astype(np.uint8))
# cv2.imwrite reports failure by returning False rather than raising.
if not ok:
    raise IOError("cv2.imwrite failed to write eq2.jpg")

torch.Size([6, 3, 1560, 1560])
