In [None]:
import sys
sys.path.append("..")

import random

import time
import math
import random
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd

from src.datasets import *
from src.util.image import * 
from src.util import ImageFilter
from src.algo import Space2d, IFS
from src.datasets.generative import *

In [None]:
def plot_samples(
        iterable,
        total: int = 32,
        nrow: int = 8,
        return_image: bool = False,
        show_compression_ratio: bool = False,
        label: Optional[Callable] = None,
):
    """
    Collect up to `total` images from `iterable` and render them as one grid.

    :param iterable: yields either an image tensor or a list/tuple whose
        first element is the image tensor. A 4-dim tensor gets `squeeze(0)`
        applied before it is stored.
    :param total: stop collecting once this many images were gathered
        (also passed to tqdm as the progress-bar total).
    :param nrow: forwarded as `nrow` to the grid builder.
    :param return_image: if True, return the PIL image instead of displaying it.
    :param show_compression_ratio: if True, each image is labeled with
        `ImageFilter.calc_compression_ratio(image)` rounded to 3 digits.
        Takes precedence over `label`.
    :param label: optional callable; receives the raw entry from `iterable`
        (not only the image) and returns the label for that sample.
    :return: the PIL image if `return_image` is True, otherwise None.

    A KeyboardInterrupt during collection is swallowed on purpose, so that
    the samples gathered so far are still rendered.
    """
    image_filter = ImageFilter()
    collected, captions = [], []

    try:
        for item in tqdm(iterable, total=total):
            # datasets may yield (image, target, ...) tuples; the image comes first
            tensor = item[0] if isinstance(item, (list, tuple)) else item
            if tensor.ndim == 4:
                tensor = tensor.squeeze(0)
            collected.append(tensor)

            if show_compression_ratio:
                captions.append(round(image_filter.calc_compression_ratio(tensor), 3))
            elif label is not None:
                captions.append(label(item))

            if len(collected) >= total:
                break
    except KeyboardInterrupt:
        pass

    # the labeled grid is only used when at least one caption was produced
    if captions:
        grid = make_grid_labeled(collected, nrow=nrow, labels=captions)
    else:
        grid = make_grid(collected, nrow=nrow)
    picture = VF.to_pil_image(grid)

    if return_image:
        return picture
    display(picture)

In [None]:
from scripts.train_autoencoder_vae import *

In [None]:
# Build the VAE for SHAPE = (1, 64, 64) and restore the trained weights.
# SHAPE is handed to the model constructor as-is; presumably
# (channels, height, width) -- TODO confirm against VariationalAutoencoderConv.
# NOTE(review): another cell in this notebook loads the *same* checkpoint with
# SHAPE = (3, 64, 64). Presumably only one of the two matches the stored
# weights -- confirm which one and drop the other.
# NOTE(review): the model is left in training mode. If it contains dropout or
# batch-norm layers, call `model.eval()` before sampling -- confirm.
SHAPE = (1, 64, 64)
CODE_SIZE = 128
model = VariationalAutoencoderConv(SHAPE, CODE_SIZE, channels=[32, 64, 128])
# map_location="cpu": the cells below feed CPU tensors to the model, and without
# it a checkpoint saved from CUDA tensors cannot be loaded on a CPU-only kernel.
model.load_state_dict(
    torch.load("../checkpoints/vae13-kali-convl3-128/best.pt", map_location="cpu")["state_dict"]
)

In [None]:
# Build the VAE for SHAPE = (3, 64, 64) and restore the trained weights.
# SHAPE is handed to the model constructor as-is; presumably
# (channels, height, width) -- TODO confirm against VariationalAutoencoderConv.
# NOTE(review): another cell in this notebook loads the *same* checkpoint with
# SHAPE = (1, 64, 64). Presumably only one of the two matches the stored
# weights -- confirm which one and drop the other.
# NOTE(review): the model is left in training mode. If it contains dropout or
# batch-norm layers, call `model.eval()` before sampling -- confirm.
SHAPE = (3, 64, 64)
CODE_SIZE = 128
model = VariationalAutoencoderConv(SHAPE, CODE_SIZE, channels=[32, 64, 128])
# map_location="cpu": the cells below feed CPU tensors to the model, and without
# it a checkpoint saved from CUDA tensors cannot be loaded on a CPU-only kernel.
model.load_state_dict(
    torch.load("../checkpoints/vae13-kali-convl3-128/best.pt", map_location="cpu")["state_dict"]
)

## plot random samples

In [None]:
# Decode 32 random latent codes (standard-normal noise scaled by .005) and show
# them as one image grid. Decoder output is clamped to [0, 1] before conversion.
with torch.no_grad():  # visualisation only -- no autograd graph needed
    random_images = model.decoder(torch.randn(32, CODE_SIZE) * .005).clamp(0, 1)
VF.to_pil_image(make_grid(random_images))

## transition between two random latent codes

In [None]:
# Decode a straight-line walk through latent space between two random codes.
f1, f2 = torch.randn(2, CODE_SIZE) * 0.01

# Blend weights 0 -> 1 in 8 evenly spaced steps. The (8, 1) column shape
# broadcasts against the (CODE_SIZE,) code vectors, giving f of shape
# (8, CODE_SIZE) without hard-coding the code size.
t = torch.linspace(0, 1, 8).unsqueeze(1)
# The weight on f1 must be (1 - t), not (t - 1): row 0 is exactly f1 and the
# last row is exactly f2. With (t - 1) the walk would start at -f1 instead.
f = f1 * (1 - t) + t * f2

with torch.no_grad():  # visualisation only -- no autograd graph needed
    transition_images = model.decoder(f).clamp(0, 1)
VF.to_pil_image(make_grid(transition_images))