In [26]:
# pyright: reportMissingImports=false
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Grayscale,
    Normalize,
    Resize,
    RandomHorizontalFlip,
    ToTensor,
    ToPILImage,
)
from pytorch3d.io import load_obj, save_obj
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes

from torchvision.datasets import ImageFolder

from src.utilities.util import scale_geometry
from src.augment.geoaug import GeoAugment

def pyramid_transform(img_size, mean=0, std=1):
    """Build the preprocessing pipeline applied to each render image.

    The pipeline, in order: resize to ``img_size`` x ``img_size``, random
    horizontal flip, conversion to grayscale, conversion to a tensor, and
    normalization with ``mean`` / ``std``.

    Args:
        img_size: Target height and width passed to ``Resize``.
        mean: Normalization mean, forwarded unchanged to ``Normalize``.
        std: Normalization standard deviation, forwarded unchanged to
            ``Normalize``.

    Returns:
        A callable that maps an input image to the transformed tensor.
    """
    # Return the Compose itself instead of wrapping it in a nested function:
    # a locally defined closure cannot be pickled, which breaks DataLoader
    # workers on platforms that use the "spawn" start method, while a Compose
    # instance pickles fine and is called the same way.
    return Compose([
        Resize([img_size, img_size]),
        RandomHorizontalFlip(),
        Grayscale(),
        ToTensor(),
        # `(mean)` in the old code was just `mean` (no tuple was created), so
        # the values are passed through exactly as before.
        Normalize(mean=mean, std=std),
    ])

class SampleRenderDataset(torch.utils.data.Dataset):
    """Dataset pairing each render image with its saved sample tensor.

    Images are read through an ``ImageFolder`` rooted at
    ``config.data_renders_dir``. For every image, the matching samples file is
    located by replacing ``'renders'`` with ``'samples'`` and ``'.png'`` with
    ``'.pth'`` in the image path, and loaded with ``torch.load``.
    """

    def __init__(self, config):
        """Set up the image folder and its transform.

        Args:
            config: Object providing ``data_renders_dir``, ``fast_image_mean``,
                ``fast_image_std`` and ``fast_image_size``.
        """
        self.image_root = config.data_renders_dir
        self.image_mean = config.fast_image_mean
        self.image_std = config.fast_image_std
        self.image_size = config.fast_image_size

        self.transform = pyramid_transform(self.image_size,
                                           self.image_mean, self.image_std)
        self.img_ds = ImageFolder(self.image_root, transform=self.transform)

    def scale(self, t, size):
        """Bilinearly resize tensor ``t`` to ``size``.

        A leading batch dimension is added for ``F.interpolate`` and removed
        again from the result.
        """
        return F.interpolate(t[None], size=size, mode='bilinear', align_corners=True)[0]

    def __len__(self):
        """Return the number of images in the underlying image folder."""
        return len(self.img_ds)

    def get_samples(self, idx):
        """Load the samples file that corresponds to image ``idx``.

        Args:
            idx: Index into the underlying image folder.

        Returns:
            Whatever ``torch.load`` returns for the derived ``.pth`` path.
        """
        # Must go through this instance's ImageFolder. The previous code read a
        # global named `img_ds`, which only existed because another notebook
        # cell had defined it, and raised NameError otherwise.
        render_path = self.img_ds.imgs[idx][0]
        sample_path = render_path.replace('renders', 'samples').replace('.png', '.pth')
        return torch.load(sample_path)

    def __getitem__(self, idx):
        """Return a dict with the ``'image'``, ``'label'`` and ``'samples'`` of item ``idx``."""
        image, label = self.img_ds[idx]
        return {
            'image': image,
            'label': label,
            'samples': self.get_samples(idx),
        }
    
    
# Smoke-test SampleRenderDataset against the augmented renders on disk.
from src.config import get_parser

# Parse with an empty argument list so the defaults are used and the kernel's
# own command line is ignored (presumably an argparse parser; confirm in src.config).
config = get_parser().parse_args(args=[])   

# NOTE(review): absolute, machine-specific path; this cell only runs where
# this directory exists.
config.data_renders_dir = '/home/bobi/Desktop/pic2mesh/data/augmented/renders'

ds = SampleRenderDataset(config)
ds

<__main__.SampleRenderDataset at 0x7f7d661ebbb0>

In [27]:
# Fetch the first item to inspect the dict returned by __getitem__
# (keys: 'image', 'label', 'samples').
ds[0]

{'image': tensor([[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          ...,
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]]),
 'label': 0,
 'samples': tensor([[ 0.0339, -0.2852,  0.4394],
         [ 0.2332,  0.4915,  0.5623],
         [ 0.4892, -0.3453,  0.2136],
         ...,
         [-0.2259, -0.1352,  0.1179],
         [ 0.0136,  0.2328,  0.4754],
         [-0.3026,  0.8602,  0.1647]])}

In [2]:
# Open the renders directory as a plain ImageFolder (no transform) to inspect
# the raw items.
# NOTE(review): absolute, machine-specific path.
image_root = '/home/bobi/Desktop/pic2mesh/data/augmented/renders'
img_ds = ImageFolder(image_root)
img_ds

Dataset ImageFolder
    Number of datapoints: 18600
    Root location: /home/bobi/Desktop/pic2mesh/data/augmented/renders

In [3]:
# With no transform set, indexing yields a (PIL image, class index) tuple.
img_ds[0]

(<PIL.Image.Image image mode=RGB size=512x512 at 0x7F7D671CB9A0>, 0)

In [11]:
# Derive the samples file for one render: same path with 'renders' replaced by
# 'samples' and the '.png' extension replaced by '.pth'.
idx = 0
path = img_ds.imgs[idx][0].replace('renders', 'samples').replace('.png', '.pth')
# Check the shape of the stored tensor.
torch.load(path).shape

torch.Size([65536, 3])

In [1]:
# Same check, but through the packaged data module instead of the
# notebook-local dataset class.
from src.data.sample_render import SampleRenderDataModule

from src.config import get_parser

# Parse with an empty argument list so the defaults are used and the kernel's
# own command line is ignored (presumably an argparse parser; confirm in src.config).
config = get_parser().parse_args(args=[])   

# NOTE(review): absolute, machine-specific path.
config.data_renders_dir = '/home/bobi/Desktop/pic2mesh/data/augmented/renders'

dm = SampleRenderDataModule(config)
dm

<src.data.sample_render.SampleRenderDataModule at 0x7f02b46c6640>

In [2]:
# Pull a single batch from the training dataloader to inspect its structure.
res = next(iter(dm.train_dataloader()))
res

{'image': tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           ...,
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.]]],
 
 
         [[[-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           ...,
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.]]],
 
 
         [[[-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           ...,
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.],
           [-1., -1., -1.,  ..., -1., -1., -1.]]],
 
 
         ...,
 
 
         [[[-1., -1., -1.