In [None]:
# Change directory two levels up (presumably the repository root, so that the
# `netdissect` import and the relative paths used below resolve).
# NOTE(review): not idempotent — re-running this cell climbs two more levels;
# use %popd to return to the previous directory.
%pushd ../../

In [None]:
# Expose only physical GPU 2 to this process; inside the process it is then
# CUDA device 0 (hence torch.cuda.set_device(0) below).
# NOTE(review): only takes effect if set before CUDA is initialized.
%env CUDA_VISIBLE_DEVICES=2
# Build/cache directory for JIT-compiled torch extensions (presumably used by
# the segmenter). NOTE(review): hardcoded user-specific path.
%env TORCH_EXTENSIONS_DIR=/tmp/torch_extensions_tongzhou

In [None]:
import json
import logging
import os
import sys
import tempfile

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

# With CUDA_VISIBLE_DEVICES=2 the single visible GPU has index 0.
torch.cuda.set_device(0)

In [None]:
from netdissect import setting

In [None]:
# Segmenter configuration name passed to netdissect's setting.load_segmenter.
segopts = 'netpqc'

In [None]:
# Load the segmentation model and its label names; the third return value is
# not used in this notebook.
segmodel, seglabels, _ = setting.load_segmenter(segopts)

In [None]:
class UnsupervisedImageFolder(torchvision.datasets.ImageFolder):
    """Expose a directory of images as an unlabeled ``ImageFolder`` dataset.

    ``ImageFolder`` expects ``root/<class>/<images>``. To reuse it on a plain
    image directory, ``root`` is symlinked as the single pseudo-class
    ``dummy`` inside a private temporary directory, and the class label is
    dropped from every item.

    Args:
        root (str): Directory containing the images (relative paths are
            resolved against the current working directory).
        transform (callable, optional): Applied to each loaded image.
        max_size (int, optional): If given and smaller than the number of
            images, expose a random subset of ``max_size`` images, fixed at
            construction time.
        get_path (bool): If True, ``__getitem__`` returns ``(sample, path)``
            instead of just ``sample``.
        generator (torch.Generator, optional): RNG used for the subset
            permutation; pass a seeded generator for a reproducible subset.
    """

    def __init__(self, root, transform=None, max_size=None, get_path=False, generator=None):
        # Held on the instance so the temporary directory lives as long as the dataset.
        self.temp_dir = tempfile.TemporaryDirectory()
        # A relative symlink target is resolved relative to the link's own
        # directory (the temp dir), not the CWD, which would leave a dangling
        # link — so always link the absolute path.
        os.symlink(os.path.abspath(root), os.path.join(self.temp_dir.name, 'dummy'))
        super().__init__(self.temp_dir.name, transform=transform)
        self.get_path = get_path
        self.perm = None
        if max_size is not None:
            actual_size = super().__len__()
            if actual_size > max_size:
                # clone() so the slice does not keep the full permutation alive.
                self.perm = torch.randperm(actual_size, generator=generator)[:max_size].clone()
                # Log the caller's root, not the temporary directory.
                logging.info(f"{root} has {actual_size} images, downsample to {max_size}")
            else:
                logging.info(f"{root} has {actual_size} images <= max_size={max_size}")

    def find_classes(self, directory):
        """Class-discovery hook used by torchvision >= 0.10; see ``_find_classes``."""
        return self._find_classes(directory)

    def _find_classes(self, dir):
        """Report the single pseudo-class ``./dummy`` with index 0 (older torchvision hook)."""
        return ['./dummy'], {'./dummy': 0}

    def __getitem__(self, key):
        """Return the (possibly subsampled) image at ``key``, optionally with its path.

        NOTE(review): the returned path goes through the temporary symlink
        directory, so it stops resolving once this dataset is garbage-collected.
        """
        if self.perm is not None:
            key = self.perm[key].item()
        # ImageFolder yields (sample, class_index); the class index is meaningless here.
        sample = super().__getitem__(key)[0]
        if self.get_path:
            path, _ = self.samples[key]
            return sample, path
        return sample

    def __len__(self):
        """Number of exposed images (``max_size`` when subsampled)."""
        if self.perm is not None:
            return self.perm.size(0)
        return super().__len__()

In [None]:
# Print the index and name of every segmentation label whose name contains
# 'dome' (presumably how `dome_idx` below was chosen).
for i, l in enumerate(seglabels):
    if 'dome' not in l:
        continue
    print(i, l)

In [None]:
import glob

In [None]:
# Let cuDNN autotune convolution algorithms; this pays off when input shapes
# stay constant across batches (presumably the case here, given the fixed-size
# resize in `transform`).
torch.backends.cudnn.benchmark=True

In [None]:
class Sup2UnsupDatasetWrapper(object):
    """Adapt a supervised dataset into an unsupervised one.

    Indexing returns only element 0 of each wrapped item (the sample),
    optionally paired with the index into the wrapped dataset.

    Args:
        dataset: Indexable object supporting ``len()`` whose items support
            ``item[0]``.
        max_size (int, optional): If given and smaller than ``len(dataset)``,
            expose a random subset of ``max_size`` items, fixed at
            construction time.
        get_key (bool): If True, ``__getitem__`` returns ``(sample, index)``,
            where ``index`` is the index into ``dataset`` (after subsampling).
        generator (torch.Generator, optional): RNG used for the subset
            permutation; pass a seeded generator for a reproducible subset.
    """

    def __init__(self, dataset, max_size=None, get_key=False, generator=None):
        self.perm = None
        if max_size is not None:
            actual_size = len(dataset)
            if actual_size > max_size:
                # clone() so the slice does not keep the full permutation alive.
                self.perm = torch.randperm(actual_size, generator=generator)[:max_size].clone()
                logging.info(f"{dataset} has {actual_size} images, downsample to {max_size}")
            else:
                logging.info(f"{dataset} has {actual_size} images <= max_size={max_size}")
        self.dataset = dataset
        self.get_key = get_key

    def __getitem__(self, key):
        """Return the sample at ``key`` (mapped through the subset, if any)."""
        if self.perm is not None:
            key = self.perm[key].item()
        sample = self.dataset[key][0]
        if self.get_key:
            return sample, key
        return sample

    def __len__(self):
        """Number of exposed items (``max_size`` when subsampled)."""
        if self.perm is not None:
            return self.perm.size(0)
        return len(self.dataset)

In [None]:
# LSUN root directory and the category split to load.
# NOTE(review): hardcoded absolute cluster path; consider making it configurable.
root = '/data/vision/torralba/datasets/LSUN/lsun2017'
split = 'church_outdoor_train'

In [None]:

# Max center crop, adapted from BigGAN:
# https://github.com/ajbrock/BigGAN-PyTorch/blob/65ade92981e9f44e3b7aea895e20886219a85a25/utils.py#L434
class CenterCropLongEdge(object):
    """Center-crop an image to the largest square that fits inside it.

    The square's side is the image's shorter edge (``min(img.size)``), so the
    longer edge is trimmed around the center. The transform takes no
    constructor arguments.
    """

    def __call__(self, img):
        """Return the centered square crop of ``img``.

        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: Square crop whose side equals the shorter edge of ``img``.
        """
        side = min(img.size)
        return torchvision.transforms.functional.center_crop(img, side)
    
# Preprocessing: largest centered square -> 256x256 -> tensor scaled to [-1, 1].
transform = transforms.Compose([
    CenterCropLongEdge(),
    # Resize's second positional parameter is `interpolation`, not a width:
    # Resize(256, 256) passed 256 as the filter. The target size must be a
    # single (h, w) tuple.
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # ToTensor yields values in [0, 1]; (x - 0.5) / 0.5 maps them to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [None]:
# Wrap the LSUN split so each item is (transformed image, dataset index);
# the LSUN target is discarded.
dataset = Sup2UnsupDatasetWrapper(
    torchvision.datasets.LSUN(root=root, classes=[split], transform=transform),
    get_key=True,
)

In [None]:
# Display the number of images in the wrapped LSUN split.
len(dataset)

In [None]:
# NOTE(review): lists an FFHQ stats directory unrelated to the LSUN churches
# pipeline in this notebook; looks like leftover exploration.
!ls ./notebooks/stats/ffhq/smiling

In [None]:
# Inspect the `churches` output directory (results are written under
# churches/real/train below).
!ls churches

In [None]:
# Split the dataset into `num_samplers` contiguous shards; this notebook
# processes shard `sampler_idx` (presumably each shard is run separately).
num_samplers = 4
# Derived from num_samplers (was a second hard-coded 4) so the two cannot drift apart.
sampler_shard = len(dataset) // num_samplers
sampler_idx = 2
assert 0 <= sampler_idx < num_samplers, f"sampler_idx={sampler_idx} must be in [0, {num_samplers})"

In [None]:
class ShardSampler(torch.utils.data.Sampler):
    """Yield, in order, the indices of one contiguous shard of a dataset.

    Indices ``[0, dataset_len)`` are split into ``num_shards`` shards of
    ``dataset_len // num_shards`` indices each; the last shard also takes the
    remainder, so every index is covered exactly once across all shards.

    Args:
        dataset_len (int): Number of items in the dataset.
        num_shards (int): Total number of shards.
        shard_idx (int): Which shard to iterate, in ``[0, num_shards)``.

    Raises:
        ValueError: If ``num_shards < 1`` or ``shard_idx`` is out of range.
    """

    def __init__(self, dataset_len, num_shards, shard_idx):
        if num_shards < 1 or not 0 <= shard_idx < num_shards:
            raise ValueError(
                f"need 0 <= shard_idx < num_shards, got shard_idx={shard_idx}, num_shards={num_shards}")
        shard_size = dataset_len // num_shards
        self.sampler_min = shard_size * shard_idx
        # The last shard absorbs the remainder left by integer division.
        if shard_idx == num_shards - 1:
            self.sampler_max = dataset_len
        else:
            self.sampler_max = shard_size * (shard_idx + 1)

    def __len__(self):
        return self.sampler_max - self.sampler_min

    def __iter__(self):
        return iter(range(self.sampler_min, self.sampler_max))


# The instance gets a different name than the class, so re-running this cell
# does not try to call an instance (which raised TypeError before).
sampler = ShardSampler(len(dataset), num_samplers, sampler_idx)

In [None]:
# Segmentation label index treated as "dome".
# NOTE(review): magic number, presumably read off the 'dome' label listing
# printed earlier; consider deriving it from `seglabels` and asserting that
# 'dome' is in seglabels[dome_idx].
dome_idx = 1708

In [None]:
# Segment this notebook's shard of the dataset and record the dataset index of
# every image whose segmentation contains the dome label.
seg_path = 'churches/real/train'
device = 'cuda'
os.makedirs(seg_path, exist_ok=True)

loader = torch.utils.data.DataLoader(
    dataset, num_workers=24, batch_size=8, pin_memory=True, sampler=sampler)

has_dome = []  # dataset indices of images containing a dome (converted to int when saved)
largest = -1   # last dataset index processed; -1 until the first batch finishes

with torch.no_grad():
    for x, keys in tqdm(loader, desc=f'shard {sampler_idx}'):
        # Batches are in pinned memory, so the host->device copy can be asynchronous.
        # No .detach() needed: nothing computed under no_grad tracks gradients.
        segs = segmodel.segment_batch(x.to(device, non_blocking=True)).cpu()
        # `keys` holds the indices returned by Sup2UnsupDatasetWrapper(get_key=True),
        # collated by the DataLoader.
        for key, seg in zip(keys, segs):
            if (seg == dome_idx).any():
                has_dome.append(key)
            # Disabled: cache each full segmentation to disk.
            # torch.save(seg, os.path.join(seg_path, f'{key}.pth'))
        largest = key
        del segs

In [None]:
# Convert the collected indices to plain ints for JSON. int() accepts both
# single-element tensors and ints, so re-running this cell is safe (.item()
# failed on the second run). The filename carries the shard index so shards
# do not overwrite each other.
has_dome = [int(v) for v in has_dome]
with open(os.path.join(seg_path, f'has_dome_{sampler_idx}.json'), 'w') as f:
    json.dump(has_dome, f)