In [None]:
# Change the working directory two levels up, pushing the current one onto
# IPython's directory stack. Relative paths used later (e.g. the
# 'churches/...' output folders passed to `process`) resolve from there.
# NOTE(review): presumably this is the repo root that contains the
# `netdissect` package imported below — confirm.
%pushd ../../

In [None]:
# Expose only physical GPU 3 to this kernel; it then appears to CUDA as
# device index 0 (hence torch.cuda.set_device(0) in the next cell).
# This only takes effect if it runs before CUDA is initialised, so keep it
# above the torch import / any CUDA call.
%env CUDA_VISIBLE_DEVICES=3

In [None]:
import json
import logging
import os
import sys
import tempfile

import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

# Index 0 is the first (and only) GPU left visible by CUDA_VISIBLE_DEVICES above.
torch.cuda.set_device(0)

In [None]:
from netdissect import setting

In [None]:
# Option string passed to setting.load_segmenter below. What each character
# selects is defined inside netdissect.setting, not in this notebook —
# see that module before changing it.
segopts = 'netpqc'

In [None]:
# Build the segmentation model (used via segmodel.segment_batch in `process`)
# and its label list; the third returned value is not used in this notebook.
segmodel, seglabels, _ = setting.load_segmenter(segopts)

In [None]:
class UnsupervisedImageFolder(torchvision.datasets.ImageFolder):
    """A directory of images exposed through ``ImageFolder`` without labels.

    ``ImageFolder`` expects ``root/<class_name>/<images>``. To reuse it on a
    plain image directory, that directory is symlinked as the single pseudo
    class ``dummy`` inside a private temporary directory, and the class label
    is dropped from every item.

    Args:
        root: directory holding the images. Relative paths are resolved
            against the current working directory at construction time.
        transform: optional transform applied to each loaded image.
        max_size: if given and smaller than the number of images found, a
            random subset of ``max_size`` images (``torch.randperm``) is used.
        get_path: if True, ``__getitem__`` returns ``(sample, path)``;
            otherwise it returns just ``sample``.
    """

    def __init__(self, root, transform=None, max_size=None, get_path=False):
        # Stored on the instance so the temporary directory (and the symlink
        # inside it) lives as long as the dataset does.
        self.temp_dir = tempfile.TemporaryDirectory()
        # A relative symlink target is resolved relative to the link's own
        # directory (the temp dir), not the CWD, so make it absolute first.
        image_root = os.path.abspath(root)
        os.symlink(image_root, os.path.join(self.temp_dir.name, 'dummy'))
        super().__init__(self.temp_dir.name, transform=transform)
        self.get_path = get_path
        self.perm = None
        if max_size is not None:
            actual_size = super().__len__()
            if actual_size > max_size:
                # Random subset of indices; unseeded, so it differs per run.
                self.perm = torch.randperm(actual_size)[:max_size].clone()
                logging.info(f"{image_root} has {actual_size} images, downsample to {max_size}")
            else:
                logging.info(f"{image_root} has {actual_size} images <= max_size={max_size}")

    def _find_classes(self, dir):
        # One pseudo-class: the symlinked image directory, with label 0.
        return ['./dummy'], {'./dummy': 0}

    def find_classes(self, directory):
        # Recent torchvision releases call the public ``find_classes`` hook
        # instead of ``_find_classes``; delegate so the single-class override
        # applies on either API.
        return self._find_classes(directory)

    def __getitem__(self, key):
        """Return the transformed image at ``key`` (plus its path if ``get_path``).

        When a subset was drawn in ``__init__``, ``key`` indexes into that
        subset and is mapped to the underlying sample index first.
        """
        if self.perm is not None:
            key = self.perm[key].item()
        # Drop the (always 0) pseudo-class label.
        sample = super().__getitem__(key)[0]
        if self.get_path:
            path, _ = self.samples[key]
            return sample, path
        else:
            return sample

    def __len__(self):
        """Number of items: the subset size if one was drawn, else all images."""
        if self.perm is not None:
            return self.perm.size(0)
        else:
            return super().__len__()

In [None]:
# Display how many labels the segmenter returned (the cell's output value).
len(seglabels)

In [None]:
# Image preprocessing for `process`: ToTensor converts the PIL image to a
# float tensor in [0, 1]; Normalize with mean = std = 0.5 per channel then
# computes (x - 0.5) / 0.5, i.e. rescales to [-1, 1].
# NOTE(review): presumably [-1, 1] is the input range segmodel.segment_batch
# expects — confirm against netdissect.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


In [None]:
def process(img_path, seg_path, device='cuda', batch_size=128, **kwargs):
    """Segment every image in ``img_path`` and save one tensor file per image.

    Images are loaded with ``UnsupervisedImageFolder`` using the notebook-level
    ``transform``, segmented batch by batch with the global
    ``segmodel.segment_batch``, and each per-image result is written with
    ``torch.save`` to ``<seg_path>/<image basename without extension>.pth``.

    Args:
        img_path: directory of input images.
        seg_path: output directory; created if it does not exist.
        device: device each image batch is moved to before segmentation.
        batch_size: number of images per DataLoader batch.
        **kwargs: forwarded unchanged to ``segmodel.segment_batch``.
    """
    os.makedirs(seg_path, exist_ok=True)

    loader = torch.utils.data.DataLoader(
        UnsupervisedImageFolder(img_path, transform=transform, get_path=True),
        batch_size=batch_size,
        num_workers=24,
        pin_memory=True,
    )

    def output_file(src_path):
        # NOTE: only the basename is kept, so two inputs with the same file
        # name (e.g. from different subfolders) would overwrite each other.
        stem = os.path.splitext(os.path.basename(src_path))[0]
        return os.path.join(seg_path, stem + '.pth')

    with torch.no_grad():
        for images, src_paths in tqdm(loader):
            batch_segs = segmodel.segment_batch(images.to(device), **kwargs)
            batch_segs = batch_segs.detach().cpu()
            for src_path, seg in zip(src_paths, batch_segs):
                torch.save(seg, output_file(src_path))
            # Release the batch result before the next iteration loads data.
            del batch_segs

In [None]:
import glob

In [None]:
# Let cuDNN benchmark candidate convolution algorithms and cache the fastest
# one per input shape (helps when batch shapes stay fixed).
torch.backends.cudnn.benchmark = True

In [None]:
# Segment the 'domes' sample images into churches/domes/.
process(
    img_path='/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/domes',
    seg_path='churches/domes',
    batch_size=12,
)

In [None]:
# Segment the 'dome2tree' sample images into churches/dome2tree/ours/.
process(
    img_path='/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/dome2tree',
    seg_path='churches/dome2tree/ours',
    batch_size=12,
)

In [None]:
# Segment the 'dome2spire' sample images into churches/dome2spire/ours/
# (smaller batch size than the other runs).
process(
    img_path='/data/vision/torralba/ganprojects/placesgan/tracer/utils/samples/dome2spire',
    seg_path='churches/dome2spire/ours',
    batch_size=8,
)