In [1]:
from typing import Dict
import webdataset as wds
import numpy as np
from omegaconf import DictConfig, ListConfig
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
from PIL import Image
from torchvision import transforms
import torchvision
from einops import rearrange
from ldm.util import instantiate_from_config
from datasets import load_dataset
import pytorch_lightning as pl
import copy
import csv
import cv2
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, ConcatDataset
import json
import os, sys
import webdataset as wds
import io
import tarfile
import math
import re
from safetensors.torch import load as load_sftr, load_file as load_sftr_file
from torch.utils.data.distributed import DistributedSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ExtendedObjaverseData(Dataset):
    """Dataset yielding a (target view, conditioning view) pair for one object per item.

    Each entry of the JSON list loaded from the data config is a path, relative
    to ``root_dir``, to either a directory or a ``.tar`` archive holding the
    per-view files of one object:

    * ``frame_NNNN.json``       -- viewpoint metadata with keys ``polar``, ``azimuth``, ``r``
    * ``rgba/rgba_NNNN.png``    -- RGBA renders (used when ``load_tensors`` is False)
    * ``view-NNN-cCC.sftr``     -- safetensors file with key ``vae_latent`` (``load_tensors`` True)
    * ``clip-NNN-cCC.sftr``     -- safetensors file with key ``clip_emb`` (``load_tensors`` True)

    Objects that provide too few views, or whose files fail to load, are
    reported and skipped in favour of the next object in ``self.paths``.
    """

    # Patterns used to count how many views of each kind an object provides.
    _VIEW_LATENT_RE = re.compile(r'view-(\d+)-c00\.sftr')
    _CLIP_EMB_RE = re.compile(r'clip-(\d+)-c00\.sftr')
    _FRAME_META_RE = re.compile(r'frame_(\d+)\.json')
    _RGBA_IMAGE_RE = re.compile(r'rgba/rgba_(\d+)\.png')

    def __init__(self,
        root_dir='.objaverse/hf-objaverse-v1/views',
        data_config_file=None,
        load_tensors=False,
        image_transforms=None,
        postprocess=None,
        return_paths=False,
        total_view=54,
        use_canonical_views=True,
        num_canonical_views=6,
        color0_prob=0.5,
        cond_polar_thr=180,
        elevation_cond=False,
        elevation_std=10.0,
        validation=False,
        min_total_view=48,
        invalid_files_log="/fsx/proj-mod3d/dmitry/repos/zero123/zero123/invalid_files.txt",
        ) -> None:
        """Load the list of object paths and store the sampling options.

        Args:
            root_dir: Directory that the object paths are relative to.
            data_config_file: JSON file containing the list of object paths.
                Defaults to ``<root_dir>/data_config.json``.
            load_tensors: If True, items carry precomputed tensors read from
                ``.sftr`` files instead of images.
            image_transforms: Callable applied to each composited RGB image;
                when falsy (the default) images are returned untransformed.
            postprocess: Optional callable applied to the finished item. A
                ``DictConfig`` is instantiated via ``instantiate_from_config``.
            return_paths: If True, each item gets a ``"path"`` entry.
            total_view: Stored on the instance; ``__getitem__`` recounts the
                views from the files actually present instead of using it.
            use_canonical_views: If False, view indices below
                ``num_canonical_views`` are never sampled.
            num_canonical_views: Number of leading view indices affected by
                ``use_canonical_views``.
            color0_prob: Probability of picking ``color_idx`` 0 rather than 1.
            cond_polar_thr: Threshold in degrees; when positive, the
                conditioning view must have a polar angle below it.
            elevation_cond: If True, the last component of ``T`` is a noisy
                conditioning polar angle instead of the radius difference.
            elevation_std: Standard deviation of that noise, in degrees.
            validation: If True, all paths are kept and shuffled in place;
                otherwise the first 99% of the paths are kept.
            min_total_view: Objects with fewer views than this are skipped.
            invalid_files_log: File that failing objects are appended to.
                Pass a falsy value to disable the log.
        """
        print("********** ", root_dir)
        self.root_dir = Path(root_dir)
        self.return_paths = return_paths
        if isinstance(postprocess, DictConfig):
            postprocess = instantiate_from_config(postprocess)
        self.postprocess = postprocess
        self.total_view = total_view
        self.use_canonical_views = use_canonical_views
        self.num_canonical_views = num_canonical_views
        self.color0_prob = color0_prob
        self.cond_polar_thr = cond_polar_thr
        self.load_tensors = load_tensors
        self.elevation_cond = elevation_cond
        self.elevation_std = elevation_std
        self.min_total_view = min_total_view
        self.invalid_files_log = invalid_files_log

        data_config_path = data_config_file if data_config_file else os.path.join(root_dir, 'data_config.json')
        with open(data_config_path) as f:
            self.paths = json.load(f)

        total_objects = len(self.paths)
        if validation:
            # NOTE: the "last 1%" slice is disabled, so validation iterates over
            # every object (in shuffled order), including the training ones.
            random.shuffle(self.paths)
        else:
            self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training
        print('============= length of dataset %d =============' % len(self.paths))
        # A fresh list per instance avoids sharing one mutable default object.
        self.tform = image_transforms if image_transforms is not None else []

    def __len__(self):
        """Return the number of object paths kept for this split."""
        return len(self.paths)

    def load_img(self, object_storage, index, fname_format_str='rgba/rgba_{view_index:04d}.png'):
        """Load view ``index`` as an RGBA PIL image.

        ``object_storage`` is either an open ``tarfile.TarFile`` (members are
        looked up under the archive's first member name) or a directory path.
        """
        img_fname = f'{fname_format_str.format(view_index=index)}'
        if isinstance(object_storage, tarfile.TarFile):
            object_name = object_storage.getnames()[0]
            image = object_storage.extractfile(f'{object_name}/{img_fname}').read()
            image = Image.open(io.BytesIO(image)).convert('RGBA')
        else:
            image = Image.open(os.path.join(object_storage, img_fname)).convert('RGBA')
        return image

    def load_tensor(self, object_storage, index, color_idx=0,
                    fname_format_str='view-{view_index:03d}-c{color_idx:02d}.sftr', key='vae_latent'):
        """Load entry ``key`` of the safetensors file for view ``index`` / ``color_idx``.

        ``object_storage`` is either an open ``tarfile.TarFile`` or a directory path.
        """
        tensor_fname = f'{fname_format_str.format(view_index=index, color_idx=color_idx)}'
        if isinstance(object_storage, tarfile.TarFile):
            object_name = object_storage.getnames()[0]
            tensor = object_storage.extractfile(f'{object_name}/{tensor_fname}').read()
            tensor = load_sftr(tensor)[key]
        else:
            tensor = load_sftr_file(os.path.join(object_storage, tensor_fname))[key]
        return tensor

    def load_viewpoint(self, object_storage, index, prefix='frame_'):
        """Return ``(polar, azimuth, r)`` read from ``<prefix><index:04d>.json``."""
        metas_fname = f'{prefix}{index:04d}.json'
        if isinstance(object_storage, tarfile.TarFile):
            object_name = object_storage.getnames()[0]
            metas = object_storage.extractfile(f'{object_name}/{metas_fname}').read()
            metas = json.loads(metas)
        else:
            with open(os.path.join(object_storage, metas_fname), 'r') as f:
                metas = json.loads(f.read())
        polar, azimuth, r = metas["polar"], metas["azimuth"], metas["r"]
        return polar, azimuth, r

    def process_img(self, img, background_color=(255, 255, 255)):
        """Composite ``img`` over a solid background, convert to RGB and apply ``self.tform`` if set."""
        background = Image.new('RGBA', img.size, background_color)
        img = Image.alpha_composite(background, img).convert("RGB")
        return self.tform(img) if self.tform else img

    def get_T(self, object_storage, index_target, index_cond):
        """Build the 4-element relative-pose tensor between the two views.

        Returns ``[d_polar, sin(d_azimuth), cos(d_azimuth), last]`` where
        ``d_azimuth`` is wrapped into ``[0, 2*pi)``. ``last`` is the radius
        difference, or, when ``elevation_cond`` is set, the conditioning polar
        angle perturbed by Gaussian noise and clipped to ``[0, pi]``.
        """
        target_polar, target_azimuth, target_r = self.load_viewpoint(object_storage, index_target)
        cond_polar, cond_azimuth, cond_r = self.load_viewpoint(object_storage, index_cond)

        d_polar = target_polar - cond_polar
        d_azimuth = (target_azimuth - cond_azimuth) % (2 * math.pi)
        d_r = target_r - cond_r

        if self.elevation_cond:
            # elevation_std is given in degrees; the angles here are in radians.
            randomized_polar_cond = np.clip(np.random.normal(cond_polar, self.elevation_std / 180. * math.pi), 0, math.pi)
            d_T = torch.tensor([d_polar, math.sin(d_azimuth), math.cos(d_azimuth), randomized_polar_cond])
        else:
            d_T = torch.tensor([d_polar, math.sin(d_azimuth), math.cos(d_azimuth), d_r])
        return d_T

    def get_background_ratio(self, img):
        """Return the fraction of pixels whose alpha channel is non-zero.

        NOTE(review): despite the name, this counts pixels with non-zero alpha,
        not transparent ones -- confirm what consumers of the value expect.
        """
        return (np.array(img)[..., 3] != 0).sum() / img.height / img.width

    def extract_data(self, object_storage, index_target, index_cond, color_idx=0):
        """Assemble the item dict for one (target, conditioning) view pair.

        Image mode yields ``image_*``, ``background_ratio_*`` and the raw
        ``*_polar`` / ``*_azimuth`` values; tensor mode yields ``latent_*`` and
        ``clip_emb_cond``. Both modes add the relative pose under ``"T"``.
        """
        data = {}
        if not self.load_tensors:
            img_target = self.load_img(object_storage, index_target)
            img_cond = self.load_img(object_storage, index_cond)
            data["image_target"] = self.process_img(img_target)
            data["image_cond"] = self.process_img(img_cond)

            data["background_ratio_target"] = self.get_background_ratio(img_target)
            data["background_ratio_cond"] = self.get_background_ratio(img_cond)
            data["target_polar"], data["target_azimuth"], _ = \
                self.load_viewpoint(object_storage, index_target)
            data["cond_polar"], data["cond_azimuth"], _ = \
                self.load_viewpoint(object_storage, index_cond)
        else:
            data["latent_target"] = self.load_tensor(object_storage, index_target, color_idx=color_idx)
            data["latent_cond"] = self.load_tensor(object_storage, index_cond, color_idx=color_idx)
            data["clip_emb_cond"] = self.load_tensor(
                object_storage, index_cond, color_idx=color_idx,
                fname_format_str='clip-{view_index:03d}-c{color_idx:02d}.sftr',
                key='clip_emb')
        data["T"] = self.get_T(object_storage, index_target, index_cond)

        return data

    def get_indices(self, object_storage, total_view):
        """Sample two distinct view indices and return ``(index_target, index_cond)``.

        With a positive ``cond_polar_thr`` the conditioning view is drawn only
        from views whose polar angle is below the threshold.

        Raises:
            IndexError: if no view qualifies as conditioning view, or no other
                view is left to serve as target.
        """
        available_indices = [v for v in range(total_view) if (self.use_canonical_views or v >= self.num_canonical_views)]
        if self.cond_polar_thr > 0:
            cond_polar_thr_rad = self.cond_polar_thr / 180. * math.pi
            polars = [(i, self.load_viewpoint(object_storage, i)[0]) for i in available_indices]
            polars = [p for p in polars if p[1] < cond_polar_thr_rad]
            if not polars:
                raise IndexError(
                    f"no view with polar angle below {self.cond_polar_thr} degrees to condition on")
            index_cond = random.choice(polars)[0]
            index_target = random.choice([i for i in available_indices if i != index_cond])
        else:
            index_target, index_cond = random.sample(available_indices, 2)
        return index_target, index_cond

    @staticmethod
    def _list_object_files(object_storage):
        """Return '/'-separated file names of an object, relative for directories.

        Directories are walked recursively so that nested files such as
        ``rgba/rgba_0000.png`` are visible, matching what ``getnames()`` reports
        for archives.
        """
        if isinstance(object_storage, tarfile.TarFile):
            return object_storage.getnames()

        def _reraise(err):
            # os.walk swallows listing errors by default; keep them loud.
            raise err

        names = []
        for dirpath, _dirnames, filenames in os.walk(object_storage, onerror=_reraise):
            rel_dir = os.path.relpath(dirpath, object_storage)
            for fname in filenames:
                rel = fname if rel_dir == os.curdir else os.path.join(rel_dir, fname)
                names.append(rel.replace(os.sep, '/'))
        return names

    def _count_views(self, object_files):
        """Return how many views have every file kind required by the current mode."""
        def count(pattern):
            return sum(1 for name in object_files if pattern.search(name))

        n_frames = count(self._FRAME_META_RE)
        if self.load_tensors:
            return min(count(self._VIEW_LATENT_RE), count(self._CLIP_EMB_RE), n_frames)
        return min(count(self._RGBA_IMAGE_RE), n_frames)

    def _log_invalid(self, object_filepath, index_target, index_cond):
        """Append a failing object to the invalid-files log; never raises on I/O errors."""
        if not self.invalid_files_log:
            return
        try:
            with open(self.invalid_files_log, "a") as f:
                f.write(f'{object_filepath}:({index_target}, {index_cond})\n')
        except OSError as err:
            print(f"Could not write to {self.invalid_files_log}: {err}")

    def _try_load_object(self, index):
        """Load the object at ``self.paths[index]``; return its item dict or None if unusable."""
        object_filepath = os.path.join(self.root_dir, self.paths[index])
        is_tar = object_filepath.endswith('.tar')
        object_name = Path(object_filepath).stem

        object_storage = tarfile.open(object_filepath) if is_tar else object_filepath
        # Pre-bind so the failure report below works even if sampling itself failed.
        index_target = index_cond = None
        try:
            total_view = self._count_views(self._list_object_files(object_storage))

            # TODO: remove
            if total_view < self.min_total_view:
                print(f"==== Invalid object {object_name} ====")
                return None

            color_idx = 0 if random.random() < self.color0_prob else 1

            try:
                index_target, index_cond = self.get_indices(object_storage, total_view)
                data = self.extract_data(object_storage, index_target, index_cond, color_idx=color_idx)
            except Exception:
                # KeyboardInterrupt / SystemExit are not Exception subclasses and propagate.
                print(f"************* Invalid files {object_filepath} {index_target} {index_cond} ***************")
                self._log_invalid(object_filepath, index_target, index_cond)
                return None
        finally:
            # Everything needed has been read into memory; release the archive handle.
            if is_tar:
                object_storage.close()

        # Added after extraction so the entries are not overwritten by it.
        if self.return_paths:
            data["path"] = str(object_filepath)
        data['object_name'] = object_name
        data['index_target'], data['index_cond'] = index_target, index_cond
        return data

    def __getitem__(self, index):
        """Return the item for ``index``, falling back to following objects if it is unusable.

        Raises:
            IndexError: if ``index`` is out of range for ``self.paths``.
            RuntimeError: if every object in the dataset turned out unusable.
        """
        # TODO: set seed
        num_objects = len(self.paths)
        current = index
        # At most one attempt per object; the first attempt uses the raw index
        # so that an out-of-range index still raises IndexError.
        for _ in range(max(num_objects, 1)):
            data = self._try_load_object(current)
            if data is not None:
                if self.postprocess is not None:
                    data = self.postprocess(data)
                return data
            current = (current + 1) % num_objects

        raise RuntimeError(
            f"No valid object found among {num_objects} dataset entries (started at index {index})")

In [7]:
from glob import glob
import json

In [4]:
# Directory whose *.json files are scanned by the following cells.
# NOTE: hardcoded absolute path -- adjust for the machine running the notebook.
data_dir = '/scratch/objaverse_new_untar/'

In [5]:
# Collect every *.json file directly inside data_dir (glob returns them in arbitrary order).
paths = glob(os.path.join(data_dir, "*.json"))

In [10]:
# Parse the first matched file to inspect its structure.
with open(paths[0]) as f:
    j = json.load(f)

In [15]:
# Display the first matched path.
paths[0]

'/scratch/objaverse_new_untar/f854f5eb9d4e480f97db63805f04ee59.ffffff.json'

In [16]:
# Display the license string stored under the 'objaverse' entry of the parsed file.
j['objaverse']['license']

'by'

In [19]:
# Keep only the metadata files whose license is not one of the excluded ones.
excluded_licenses = ('by-nc-sa', 'by-nc', 'by-sa')
valid_paths = []
for path in paths:
    with open(path) as f:
        j = json.load(f)
    license_name = j['objaverse']['license']
    if license_name in excluded_licenses:
        continue
    valid_paths.append(path)

In [20]:
# Fraction of the scanned metadata files that passed the license filter.
len(valid_paths) / len(paths)

0.9040178571428571

In [24]:
# File name of the first kept path, truncated at its first '.'.
os.path.basename(valid_paths[0]).split('.')[0]

'f854f5eb9d4e480f97db63805f04ee59'

In [25]:
# Load the previously saved JSON for comparison (path is relative to the working directory).
with open('good_objaverse_dirs.json') as f:
    old_data = json.load(f)

In [29]:
# Count how many kept objects also appear in old_data.
# Build the lookup once: membership in a set is a hash lookup, whereas testing
# against a list rescans it for every path.
# NOTE(review): assumes the entries of old_data are hashable (e.g. id strings) -- confirm.
old_ids = set(old_data)
num_matches = 0
for vpath in valid_paths:
    base = os.path.basename(vpath).split('.')[0]
    num_matches += base in old_ids

In [33]:
# Entry names found in the untarred directory (a different location from data_dir above).
untarred_objects = os.listdir('/scratch/objaverse_untar_2')
# Previously saved JSON, read from the working directory.
with open('objaverse_filtered_by_nc.json') as f:
    filtered_objects = json.load(f)

In [41]:
# Number of entries in the untarred directory.
len(untarred_objects)

189342

In [39]:
# Entries that are both in the filtered list and present on disk.
# Sorted so the result (and the JSON file written from it) has a reproducible
# order; a bare list(set) ordering can differ between kernel sessions.
filtered_present_objects = sorted(set(filtered_objects).intersection(untarred_objects))

In [46]:
# Ad-hoc ratio of two literals. 189342 equals the len(untarred_objects) output recorded
# in this notebook; the origin of 151067 is not shown here -- TODO confirm or replace
# both literals with the expressions they came from.
151067 / 189342

0.797852563086901

In [45]:
# Fraction of untarred entries that also appear in the filtered list.
len(filtered_present_objects) / len(untarred_objects)

0.8107762672835398

In [44]:
# Persist the filtered-and-present entries as a JSON list in the working directory
# (overwrites any existing file of the same name).
with open('objaverse_filtered_by_nc_valid.json', "w") as f:
    json.dump(filtered_present_objects, f)