In [None]:
from init_notebook import *

In [None]:
# ImageNet-1k, filtered to images of size (500, 375) — axis order presumably
# (width, height); TODO confirm. Other sizes such as (375, 500), (500, 500)
# or (500, 400) can be appended to the filter list.
image_ds = Imagenet1kIterableDataset(
    size_filter=[(500, 375)],
)

# Preview the first 64 images, each resized by a factor of 1/4, in an 8-column grid.
grid = []
for image in tqdm(image_ds.limit(64)):
    grid.append(resize(image, 1 / 4))
display(VF.to_pil_image(make_grid(grid, nrow=8)))

In [None]:
# Larger overview: 2000 images taken after `.shuffle(1000)`, each resized by
# a factor of 1/16, laid out 32 per row.
grid = []
for image in tqdm(image_ds.shuffle(1000).limit(2000)):
    grid.append(resize(image, 1 / 16))
display(VF.to_pil_image(make_grid(grid, nrow=32)))

In [None]:
# All images found (recursively) under the local "__diverse" folder,
# forced to 3 channels so every item can feed a 3-input-channel conv.
image_ds = ImageFolderIterableDataset(
    "~/Pictures/__diverse",
    recursive=True,
    force_channels=3,
)
# Sanity check: show the first two images.
for image in image_ds.limit(2):
    display(VF.to_pil_image(image))

In [None]:
PATCH_SIZE = 31

# Cut the (shuffled) images into PATCH_SIZE x PATCH_SIZE patches with a stride of
# half the patch size, so neighbouring patches overlap. `interleave_images=64`
# presumably mixes patches from 64 images at a time — TODO confirm.
patch_ds = ImagePatchIterableDataset(
    image_ds.shuffle(1000),
    shape=PATCH_SIZE,
    stride=PATCH_SIZE // 2,
    interleave_images=64,
)
# Preview the first 64 * 64 patches as a 64-column grid.
VF.to_pil_image(
    make_grid([*patch_ds.limit(64 * 64)], nrow=64)
)

In [None]:
from sklearn.decomposition import IncrementalPCA
NUM_PATCHES = 128  # number of PCA components, i.e. stage-1 filters to learn
pca = IncrementalPCA(NUM_PATCHES)

try:
    for batch in tqdm(DataLoader(patch_ds, batch_size=1024)):
        # IncrementalPCA.partial_fit raises ValueError when a batch has fewer rows
        # than n_components; the DataLoader's last batch can be that small.
        if batch.shape[0] < NUM_PATCHES:
            continue
        # Flatten each 3 x PATCH_SIZE x PATCH_SIZE patch into one row.
        pca.partial_fit(batch.numpy().reshape(batch.shape[0], 3 * PATCH_SIZE**2))
except KeyboardInterrupt:
    # Interrupting the kernel deliberately ends fitting early; the state fitted so far is kept.
    pass

In [None]:
# Each PCA component is a flattened 3 x PATCH_SIZE x PATCH_SIZE patch; restore that layout.
patches = torch.from_numpy(pca.components_).reshape(
    NUM_PATCHES, 3, PATCH_SIZE, PATCH_SIZE
)
print(*(reduce_fn(patches) for reduce_fn in (torch.min, torch.max, torch.mean)))
def normalize(patch: torch.Tensor):
    """Min-max scale ``patch`` into [0, 1] for display.

    The minimum maps to 0 and the maximum to 1. A constant patch (max == min)
    has no range to scale by and is returned as all zeros instead of NaN.

    Args:
        patch: tensor of any shape.

    Returns:
        Tensor of the same shape with values clamped to [0, 1].
    """
    patch = patch - patch.min()
    value_range = patch.max()
    # Guard the division: a constant patch would otherwise produce 0/0 = NaN.
    if value_range > 0:
        patch = patch / value_range
    return patch.clamp(0, 1)
# All stage-1 filters, each min-max normalized, 16 per row, upscaled 4x.
VF.to_pil_image(
    resize(make_grid([normalize(p) for p in patches], nrow=16), 4)
)

In [None]:
# Output directory (relative to the working directory) and the file name for the
# stage-1 filters, which encodes the component count and patch size.
FILEPATH = Path("data")
FILEPATH.mkdir(parents=True, exist_ok=True)
PATCHES_FILENAME = FILEPATH.joinpath(f"pca-patches-{NUM_PATCHES}-{PATCH_SIZE}x{PATCH_SIZE}")

In [None]:
torch.save(obj=patches, f=PATCHES_FILENAME)  # persist the stage-1 filters

In [None]:
# Sample image: the first item yielded after `.skip(0)`.
# (Append `[..., :512, :512]` to work on a crop instead.)
image = next(iter(image_ds.skip(0)))

# Stage-1 feature extractor: one conv output channel per PCA component,
# applied with the same stride used to cut the training patches.
conv = nn.Conv2d(
    in_channels=3,
    out_channels=NUM_PATCHES,
    kernel_size=PATCH_SIZE,
    stride=PATCH_SIZE // 2,
    bias=False,
)
with torch.no_grad():
    conv.weight.copy_(patches)

In [None]:
# Stage-1 features of the sample image. Inference only, so no autograd graph is built.
with torch.no_grad():
    feat1 = conv(image)
print(feat1.shape)
# Show every feature channel (signed values mapped via signed_to_image, amplified 3x),
# 2 per row, upscaled 2x.
VF.to_pil_image(resize(make_grid(
    [(signed_to_image(i) * 3).clamp(0, 1) for i in feat1],
    nrow=2,
), 2))

# Stage 2: PCA on patches of the stage-1 feature maps

In [None]:
NUM_PATCHES2 = NUM_PATCHES
pca2 = IncrementalPCA(NUM_PATCHES2)

# Rows per pca2.partial_fit call. partial_fit raises ValueError when a call has
# fewer rows than n_components, and a single image usually yields far fewer
# stage-2 patches than that, so rows are buffered across images.
FIT_BATCH_SIZE = 1024
assert FIT_BATCH_SIZE >= NUM_PATCHES2, "each partial_fit call needs >= n_components rows"

# NOTE(review): the stage-1 conv (kernel PATCH_SIZE, stride PATCH_SIZE // 2) turns e.g.
# a 500x375 image into a 32x23 feature map, smaller than a PATCH_SIZE window, so such
# images presumably contribute no stage-2 patches — TODO confirm with iter_image_patches.
pending_rows = []     # flattened patch batches not yet fitted
num_pending_rows = 0  # total number of rows across pending_rows

try:
    with torch.no_grad():
        for image in tqdm(image_ds):
            features = conv(image)
            for batch in iter_image_patches(features, shape=PATCH_SIZE, stride=PATCH_SIZE // 2, batch_size=FIT_BATCH_SIZE):
                pending_rows.append(batch.reshape(batch.shape[0], -1))
                num_pending_rows += batch.shape[0]
                if num_pending_rows >= FIT_BATCH_SIZE:
                    pca2.partial_fit(torch.cat(pending_rows).numpy())
                    pending_rows, num_pending_rows = [], 0
except KeyboardInterrupt:
    # Interrupting the kernel deliberately ends fitting early; fitted state is kept.
    pass

# Fit the leftover rows if partial_fit can accept them; a remainder with fewer
# than n_components rows is dropped.
if num_pending_rows >= NUM_PATCHES2:
    pca2.partial_fit(torch.cat(pending_rows).numpy())

In [None]:
# Each stage-2 component spans NUM_PATCHES input channels of PATCH_SIZE x PATCH_SIZE.
patches2 = torch.from_numpy(pca2.components_).reshape(
    NUM_PATCHES2, NUM_PATCHES, PATCH_SIZE, PATCH_SIZE
)
print(patches2.shape, patches2.min(), patches2.max(), patches2.mean())
# One grid row per stage-2 component: its weights over the first 16 input channels,
# each channel normalized separately, upscaled 4x.
tiles = [
    normalize(component[channel:channel + 1])
    for component in patches2
    for channel in range(16)
]
VF.to_pil_image(resize(make_grid(tiles, nrow=16), 4))

# Dataset from `.pt` files

In [None]:
class TensorFilesIterableDataset(BaseIterableDataset):
    """Iterate over the items stored in every ``*.pt`` file of a directory.

    Files are matched non-recursively and visited in sorted path order. Each file
    is loaded with ``torch.load`` and its contents are yielded one item at a time
    (for a tensor, one slice along the first dimension per item).

    Args:
        path: directory containing the ``.pt`` files. A leading ``~`` is expanded
            to the user's home directory.
    """
    def __init__(
            self,
            path: Union[str, Path],
    ):
        super().__init__()
        self._path = Path(path).expanduser()

    def __iter__(self):
        for file in sorted(self._path.glob("*.pt")):
            # NOTE: torch.load is pickle-based; only point this at trusted files.
            yield from torch.load(file)

tensor_ds = TensorFilesIterableDataset(
    config.BIG_DATASETS_PATH / "imagenet1k-uint8-by-shape" / "3x375x500"
)
# Iterate once over the whole dataset (e.g. to gauge loading speed).
# The loop variable is `_` so this cell does not overwrite the global `image`
# that the stage-1 feature cell convolves.
for _ in tqdm(tensor_ds):
    pass


In [None]:
# Scratch calculation (purpose not stated in the notebook): evaluates to 12.0.
96 / (2 ** 3)