In [10]:
# pyright: reportMissingImports=false
import os
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Grayscale,
    Normalize,
    Resize,
    RandomHorizontalFlip,
    ToTensor,
    ToPILImage,
)
from pytorch3d.io import load_obj, save_obj
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes

from torchvision.datasets import ImageFolder

from src.utilities.util import scale_geometry
from src.augment.geoaug import GeoAugment

def pyramid_transform(img_size, mask_size,  mean=0, std=1):
    """Build a callable that masks, randomly flips and normalizes an image.

    The returned function runs three stages in order:

    1. resize the input to ``mask_size`` x ``mask_size`` and convert it to a
       tensor, then multiply it element-wise by ``mask``;
    2. convert the masked tensor back to a PIL image and apply a random
       horizontal flip;
    3. resize to ``img_size`` x ``img_size``, convert to grayscale, convert
       to a tensor and normalize with ``mean`` / ``std``.

    Args:
        img_size: Side length of the final square image.
        mask_size: Side length the image is resized to before masking.
        mean: Mean handed to ``Normalize``.
        std: Standard deviation handed to ``Normalize``.

    Returns:
        A function ``final_transform(img, mask)`` returning a dict with the
        single key ``'image'``.
    """
    to_masked_size = Compose([
        Resize([mask_size, mask_size]),
        ToTensor(),
    ])
    random_flip = Compose([
        ToPILImage(),
        RandomHorizontalFlip(),
    ])
    to_network_input = Compose([
        Resize([img_size, img_size]),
        Grayscale(),
        ToTensor(),
        Normalize(mean=mean, std=std),
    ])

    def final_transform(img, mask):
        # NOTE(review): `mask` is presumably a tensor broadcastable against
        # the (C, mask_size, mask_size) image tensor - confirm with callers.
        masked = to_masked_size(img) * mask
        return {'image': to_network_input(random_flip(masked))}

    return final_transform

class FullDataset(torch.utils.data.Dataset):
    """Dataset serving pre-computed vertex grids and random blends of them.

    Every grid file found in ``config.data_grid_dir`` is loaded once at
    construction time and rescaled to two resolutions (``baseline`` and
    ``outline``).  An item is a dict holding the two resolutions of one grid
    plus the two resolutions of a random convex blend of several grids.

    Note that ``__len__`` reports the number of images in the image folder,
    while items are looked up modulo the number of grid files.
    """

    def __init__(self, config):
        """Read settings from ``config`` and pre-load all grids.

        Args:
            config: Object exposing the attributes read below (for example
                the namespace produced by the project's argument parser).
        """
        super().__init__()

        self.num_workers = config.num_workers
        self.pin_memory = config.pin_memory
        # Target sizes handed to F.interpolate for the two grid resolutions.
        self.outline_size = config.fast_outline_size
        self.baseline_size = config.fast_baseline_size
        self.stl_offset = config.stl_offset
        self.geoaug_policy = config.geoaug_policy

        self.image_root = config.fast_image_root
        self.mask_root = config.mask_root
        self.image_size = config.fast_image_size
        self.mask_size = config.mask_size
        self.image_mean = config.fast_image_mean
        self.image_std = config.fast_image_std
        # Number of grids mixed together by get_blends.
        self.blends_no = config.data_blends_no

        self.data_grid_dir = config.data_grid_dir
        self.data_mesh_dir = config.data_mesh_dir

        self.transform = pyramid_transform(self.image_size, self.mask_size,
                                           self.image_mean, self.image_std)
        self.img_ds = ImageFolder(self.image_root)

        # os.listdir returns entries in arbitrary order; sort so that an
        # index always maps to the same file.
        self.grid_files = sorted(
            os.path.join(self.data_grid_dir, f)
            for f in os.listdir(self.data_grid_dir))
        grid_verts = [torch.load(f)['vertices'] for f in self.grid_files]
        self.grid_baselines = [
            self.scale(v, self.baseline_size) for v in grid_verts]
        self.grid_outlines = [
            self.scale(v, self.outline_size) for v in grid_verts]

        # Sorted for the same reproducibility reason as grid_files.
        self.mesh_files = sorted(
            os.path.join(self.data_mesh_dir, f)
            for f in os.listdir(self.data_mesh_dir))
        self.device = torch.device('cpu')
        print('Dataset setup finished')

    def scale(self, t, size):
        """Bilinearly resize ``t`` to ``size`` with ``align_corners=True``.

        A leading batch dimension is added for ``F.interpolate`` and removed
        again, so the result has the same rank as ``t``.
        """
        return F.interpolate(t[None], size=size, mode='bilinear', align_corners=True)[0]

    def __len__(self):
        """Return the number of images in the underlying image folder."""
        return len(self.img_ds)

    def get_samples(self, idx):
        """Sample ``baseline_size ** 2`` surface points from one mesh.

        The mesh file is chosen as ``idx`` modulo the number of mesh files.
        """
        idx_mesh = idx % len(self.mesh_files)
        mesh_file = self.mesh_files[idx_mesh]
        # NOTE(review): scale_geometry is a project helper; presumably it
        # returns vertex and face tensors on self.device - confirm.
        verts, faces = scale_geometry(mesh_file, self.device, offset=self.stl_offset)
        trg_mesh = Meshes(verts=[verts], faces=[faces])
        samples = sample_points_from_meshes(trg_mesh, self.baseline_size ** 2)[0]
        #samples = samples.t().reshape(3, self.baseline_size, self.baseline_size)
        return samples.contiguous()

    def get_grid(self, idx):
        """Return both resolutions of the grid at ``idx`` modulo grid count."""
        idx_grid = idx % len(self.grid_files)
        return {
            'grid_baseline': self.grid_baselines[idx_grid],
            'grid_outline':  self.grid_outlines[idx_grid],
        }

    def get_blends(self, _):
        """Return a random convex blend of ``blends_no`` grids.

        The same randomly chosen grids and the same weights are used for
        both resolutions, so ``blend_baseline`` and ``blend_outline``
        describe the same blended grid (mirroring how ``get_grid`` pairs the
        two resolutions of a single grid).  The positional argument is
        ignored.
        """
        # Draw the indices once; drawing them separately per resolution
        # would blend two unrelated sets of grids.
        indices = torch.randint(
            0, len(self.grid_files), (self.blends_no,)).tolist()
        baselines = torch.stack([self.grid_baselines[i] for i in indices])
        outlines = torch.stack([self.grid_outlines[i] for i in indices])
        # L1-normalized non-negative weights, i.e. they sum to one.
        q = F.normalize(torch.rand(self.blends_no), p=1, dim=0).reshape(-1, 1, 1, 1)
        return {
            'blend_baseline': (baselines * q).sum(dim=0),
            'blend_outline':  (outlines * q).sum(dim=0),
        }

    def __getitem__(self, idx):
        """Return a dict with the grid and blend tensors for ``idx``.

        Keys: ``grid_baseline``, ``grid_outline``, ``blend_baseline`` and
        ``blend_outline``.
        """
        res = {}
        # Image loading is currently disabled.
        # NOTE(review): the disabled code below looks up the mask through
        # imgs[0] rather than imgs[idx_img] - check before re-enabling.
        # idx_img = idx % len(self.img_ds)
        # image, _ = self.img_ds[idx_img]
        # mask_path =  self.img_ds.imgs[0][0].replace(self.image_root, self.mask_root)
        # mask = torch.load(mask_path.replace('.png', '.pth'))
        # res = self.transform(image, mask)

        #res['samples'] = self.get_samples(idx)

        res.update(self.get_grid(idx))
        res.update(self.get_blends(idx))
        return res
    
from src.config import get_parser

# Parse an empty argument list so that no command-line arguments are read
# (presumably leaving every option at its parser default - confirm in
# src.config).
config = get_parser().parse_args(args=[])    

# Build the dataset and fetch the first item as a quick smoke test.
ds = FullDataset(config)
ds[0]

Dataset setup finished


{'grid_baseline': tensor([[[-0.2217, -0.2188, -0.2152,  ...,  0.2601,  0.2589,  0.2581],
          [-0.2262, -0.2236, -0.2202,  ...,  0.2605,  0.2587,  0.2575],
          [-0.2349, -0.2326, -0.2292,  ...,  0.2608,  0.2582,  0.2565],
          ...,
          [-0.3745, -0.3743, -0.3739,  ...,  0.3509,  0.3519,  0.3524],
          [-0.3754, -0.3754, -0.3751,  ...,  0.3523,  0.3527,  0.3529],
          [-0.3760, -0.3760, -0.3759,  ...,  0.3530,  0.3532,  0.3532]],
 
         [[-0.2912, -0.2940, -0.2974,  ..., -0.2078, -0.2098, -0.2110],
          [-0.2857, -0.2881, -0.2914,  ..., -0.2073, -0.2101, -0.2119],
          [-0.2751, -0.2772, -0.2805,  ..., -0.2069, -0.2107, -0.2133],
          ...,
          [ 0.3433,  0.3447,  0.3479,  ...,  0.3685,  0.3717,  0.3741],
          [ 0.3499,  0.3514,  0.3551,  ...,  0.3715,  0.3734,  0.3756],
          [ 0.3535,  0.3553,  0.3595,  ...,  0.3729,  0.3737,  0.3753]],
 
         [[ 0.0507,  0.0522,  0.0545,  ...,  0.0296,  0.0301,  0.0304],
          [

In [4]:
# Alias the dataset as `self` (presumably so method bodies can be pasted and
# run interactively), then inspect the sorted list of grid file paths.
self =  ds
self.grid_files

['/home/bobi/Desktop/pic2mesh/data/stl_grid/abe_white_inpatient_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/abigaile_ortiz_tomb_raider_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/ada_wong_resident_evil_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/administrator_evil_within_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/aidan_overkill_walking_dead_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/amelia_croft_tomb_raider_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/anakin_skywalker_battlefront_2_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/angela_civilian_detroit_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/annie_dead_rising_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/anya_gears_of_war_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/aphrodite_ascendant_one_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/aranea_highwind_final_fantasy_256.pth',
 '/home/bobi/Desktop/pic2mesh/data/stl_grid/ares_ascendant

[tensor([[[-0.2217, -0.2204, -0.2188,  ...,  0.2589,  0.2584,  0.2581],
          [-0.2235, -0.2223, -0.2207,  ...,  0.2588,  0.2583,  0.2579],
          [-0.2262, -0.2251, -0.2236,  ...,  0.2586,  0.2580,  0.2575],
          ...,
          [-0.3754, -0.3754, -0.3754,  ...,  0.3527,  0.3529,  0.3529],
          [-0.3758, -0.3758, -0.3757,  ...,  0.3530,  0.3531,  0.3531],
          [-0.3760, -0.3760, -0.3760,  ...,  0.3532,  0.3532,  0.3532]],
 
         [[-0.2912, -0.2924, -0.2939,  ..., -0.2098, -0.2105, -0.2110],
          [-0.2890, -0.2901, -0.2916,  ..., -0.2099, -0.2108, -0.2114],
          [-0.2857, -0.2867, -0.2881,  ..., -0.2101, -0.2111, -0.2118],
          ...,
          [ 0.3499,  0.3505,  0.3515,  ...,  0.3735,  0.3746,  0.3756],
          [ 0.3520,  0.3527,  0.3537,  ...,  0.3737,  0.3747,  0.3755],
          [ 0.3535,  0.3542,  0.3553,  ...,  0.3737,  0.3745,  0.3753]],
 
         [[ 0.0507,  0.0513,  0.0522,  ...,  0.0301,  0.0303,  0.0304],
          [ 0.0507,  0.0514,