In [27]:
# pyright: reportMissingImports=false
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import (
    Compose,
    Grayscale,
    Normalize,
    Resize,
    RandomHorizontalFlip,
    ToTensor,    
)

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import pytorch_lightning as pl

from src.utilities.util import (
    grid_to_list,
    list_to_grid,
)
def pyramid_transform(img_size, mean=0, std=1):
    transform = Compose([            
        Resize([img_size, img_size]),
        RandomHorizontalFlip(),
        Grayscale(),
        ToTensor(),
        Normalize(mean=(mean), std=(std)),
    ])
    def final_transform(img):        
        return transform(img)
    
    return final_transform

class SlicedRenderDataset(torch.utils.data.Dataset):
    
    def __init__(self, config):        
        self.image_root = config.data_renders_dir        
        self.image_mean = config.fast_image_mean
        self.image_std = config.fast_image_std
        self.image_size = config.fast_image_size
        self.full_size = config.grid_full_size
        self.slice_size = config.grid_slice_size
                
        self.transform = pyramid_transform(self.image_size, 
                                           self.image_mean, self.image_std)
        self.img_ds = ImageFolder(self.image_root, transform=self.transform)
        self.slice_indices = self.make_indices(self.full_size - self.slice_size)
        
    def __len__(self):
        return len(self.img_ds) * len(self.slice_indices)
    
    def make_indices(self, n):         
        t = torch.arange(n)
        return torch.stack(torch.torch.meshgrid(t, t), dim=-1).reshape(-1, 2)
    
    def scale(self, t, size):
        return F.interpolate(t, size=size, mode='bilinear', align_corners=True)    
    
    def get_grid(self, idx):
        f =  self.img_ds.imgs[idx][0]
        f = f.replace('renders', 'grid').replace('.png', '.pth')
        grid = list_to_grid(torch.load(f)[None])        
        grid = self.scale(grid, self.full_size)[0]
        return grid
    
    def get_slice(self, idx):
        grid = self.get_grid(idx % len(self.img_ds))
        indices = self.slice_indices[idx % len(self.slice_indices)]
        r, c = indices
        return (grid[:, r:r+self.slice_size, c:c+self.slice_size],
                indices)
   
    def __getitem__(self, idx):              
        res = {}        
        image, label = self.img_ds[idx % len(self.img_ds)]
        res['image'] =  image
        #res['label'] =  label
        slice_data, slice_idx = self.get_slice(idx)
        res['slice_data'] = slice_data
        res['slice_idx'] = slice_idx
        return res

class SlicedRenderDataModule(pl.LightningDataModule):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.fast_batch_size
        self.num_workers = config.num_workers
        self.pin_memory = config.pin_memory
        self.train_ds = SlicedRenderDataset(self.config)         
        
    def train_dataloader(self):
        return DataLoader(self.train_ds, shuffle=True, batch_size=self.batch_size, 
            num_workers=self.num_workers, pin_memory=self.pin_memory)        
    
from src.config import get_parser

config = get_parser().parse_args(args=[]) 
ds = SampleRenderDataset(config)
ds0 = ds[0]
for k,v in ds0.items():
    print(k, v.shape)

image torch.Size([1, 128, 128])
slice_data torch.Size([3, 16, 16])
slice_idx torch.Size([2])


In [28]:
dm = SlicedRenderDataModule(config)
dm

<__main__.SlicedRenderDataModule at 0x7ff86735acd0>

In [30]:
len(dm.train_dataloader())

595200

In [25]:
ds.slice_size, ds.full_size, len(ds)

(16, 32, 4761600)

In [16]:
grid_full_size = 5
grid_slice_size = 2
res =  []
for r in range(grid_full_size-grid_slice_size):
    for c in range(grid_full_size-grid_slice_size):
        res.append((r, c+10)) 
        print(r, c+10)
len(res)

0 10
0 11
0 12
1 10
1 11
1 12
2 10
2 11
2 12


9

In [13]:
n = grid_full_size-grid_slice_size
t = torch.arange(n)
r = torch.stack(torch.torch.meshgrid(t, t+10), dim=-1).reshape(-1, 2)
r, r.shape

(tensor([[ 0, 10],
         [ 0, 11],
         [ 0, 12],
         [ 1, 10],
         [ 1, 11],
         [ 1, 12],
         [ 2, 10],
         [ 2, 11],
         [ 2, 12]]),
 torch.Size([9, 2]))

In [9]:
t

tensor([0, 1, 2])