In [None]:
from init_notebook import *

In [None]:
from experiments.datasets.classic import *
# Load the two classic training sets and preview the first 64 CIFAR-10 images.
ds = cifar10_dataset(train=True)
ds2 = mnist_dataset(train=True)

print([ds[i][0].shape for i in range(8)])
VF.to_pil_image(make_grid([ds[i][0] for i in range(64)]))

In [None]:
def blur(images, a: float):
    """Gaussian-blur ``images`` with a strength that grows non-linearly in ``a``.

    A fixed 21x21 kernel is used; the standard deviation is ``7 * a ** 1.7``
    (via ``math.pow``), floored at 0.001 so ``a == 0`` is effectively a no-op.
    """
    kernel_size = [21] * 2
    sigma = max(0.001, math.pow(a, 1.7) * 7)
    return VF.gaussian_blur(images, kernel_size, [sigma] * 2)
    
# Batch the first 8 images and show them at 20 blur strengths from 0 to 1.
images = torch.stack([ds[i][0] for i in range(8)])
for i in range(20):
    display(VF.to_pil_image(make_grid(blur(images, i / 19))))

In [None]:
def elastic(images, a: float):
    """Apply a random elastic deformation whose strength is linear in ``a``.

    Args:
        images: image tensor(s) accepted by ``VT.ElasticTransform``.
        a: deformation amount; the transform uses ``alpha = a * 250`` with a
            fixed ``sigma = 3``.

    Returns:
        The deformed images.
    """
    # (the original also computed an unused blur sigma here; removed as dead code)
    return VT.ElasticTransform(a * 250., 3.)(images)
    
# Batch the first 8 images and show them at 20 elastic-deformation strengths.
images = torch.stack([ds[i][0] for i in range(8)])
for i in range(20):
    display(VF.to_pil_image(make_grid(elastic(images, i / 19))))

In [None]:
from experiments.diffusion.sampler import *
class DiffusionSamplerDeform(DiffusionSamplerBase):
    """Diffusion sampler whose "noise" is an elastic deformation of the image.

    For each image a random, Gaussian-smoothed displacement field is scaled by
    the image's noise amount and by ``alpha``, and the image is warped with
    ``VF.elastic_transform``.

    Args:
        alpha: overall displacement magnitude multiplier.
        sigma: std-dev of the Gaussian that smooths the random displacement field.
        fill: pixel value used where the warp samples outside the image.
    """

    def __init__(
            self,
            alpha: float = 50.,
            sigma: float = 3.,
            fill: float = -1.,
    ):
        self.alpha = alpha
        self.sigma = sigma
        self.fill = fill

    @staticmethod
    def _smooth_uniform_noise(size: List[int], sigma: List[float], axis_sigma: float, generator):
        """Uniform noise in [-1, 1) of shape (1, 1, *size), Gaussian-smoothed if ``axis_sigma > 0``."""
        noise = torch.rand([1, 1, *size], generator=generator) * 2 - 1
        if axis_sigma > 0.0:
            kernel = int(8 * axis_sigma + 1)
            # gaussian_blur needs an odd kernel size
            if kernel % 2 == 0:
                kernel += 1
            # like torchvision's ElasticTransform, the kernel size comes from this
            # axis' sigma while the full [sigma_x, sigma_y] list is passed as sigma
            noise = VF.gaussian_blur(noise, [kernel, kernel], sigma)
        return noise

    @staticmethod
    def get_displacement(sigma: List[float], size: List[int], batch_size, generator):
        """Draw a random, smoothed displacement field.

        Args:
            sigma: [sigma_x, sigma_y] for smoothing the dx and dy components.
            size: [rows, cols] of the field. ``_add_noise`` passes ``[B * H, W]``
                so all images of the batch are cut from one tall field.
            batch_size: number of images stacked along ``size[0]``; dx is
                normalized by the per-image height ``size[0] / batch_size``.
            generator: optional ``torch.Generator`` for reproducible draws.

        Returns:
            Tensor of shape (size[0], size[1], 2) holding (dx, dy).
        """
        # dx is drawn before dy, exactly as before, so seeded runs are unchanged
        dx = DiffusionSamplerDeform._smooth_uniform_noise(size, sigma, sigma[0], generator)
        dx = dx / (size[0] / batch_size)

        dy = DiffusionSamplerDeform._smooth_uniform_noise(size, sigma, sigma[1], generator)
        dy = dy / size[1]
        # (1, 2, rows, cols) -> (1, rows, cols, 2) -> (rows, cols, 2)
        return torch.concat([dx, dy], 1).permute([0, 2, 3, 1])[0]

    def _add_noise(
            self,
            images: torch.Tensor,
            noise_amounts: torch.Tensor,
            generator: Optional[torch.Generator],
    ):
        """Warp each image by its own slice of a random field scaled by its noise amount.

        Args:
            images: batch of shape (B, C, H, W).
            noise_amounts: per-image amounts; any shape with B elements, e.g.
                (B,), (B, 1) or (B, 1, 1, 1).
            generator: optional ``torch.Generator`` for the displacement draw.

        Returns:
            The warped batch, same shape as ``images``.
        """
        B, C, H, W = images.shape
        # NOTE(review): the field is smoothed as one [B * H, W] image, so the
        # displacements at the bottom of one image and the top of the next are correlated.
        disp = self.get_displacement([self.sigma, self.sigma], [B * H, W], B, generator).to(images)
        # reshape (not [:, None, None]) so (B,) amounts broadcast per image, not across H
        disp = disp.reshape(B, H, W, 2) * noise_amounts.reshape(-1, 1, 1, 1) * self.alpha

        # elastic_transform is applied image by image with a (1, H, W, 2) displacement
        return torch.cat([
            VF.elastic_transform(image.unsqueeze(0), d.unsqueeze(0), fill=self.fill)
            for image, d in zip(images, disp)
        ])

sampler = DiffusionSamplerDeform(alpha=200, sigma=6)

# Show the first 8 images deformed at 10 noise levels from 0 to 1 (images mapped to [-1, 1]).
images = torch.stack([ds[i][0] for i in range(8)])
for i in range(10):
    amounts = torch.ones(images.shape[0], 1).to(images) * i / 9
    noisy_images, _ = sampler.add_noise(images * 2 - 1., amounts)
    display(VF.to_pil_image(make_grid(noisy_images * .5 + .5)))

In [None]:
from experiments.denoise.resconv import ResConv

class Module(nn.Module):
    """Residual wrapper around a 3-layer ResConv.

    The ResConv output is subtracted from the input and the result is clamped
    to [0, 1]. The submodule must stay named ``module`` so checkpoint keys match.
    """

    def __init__(self):
        super().__init__()
        self.module = ResConv(
            in_channels=3,
            num_layers=3,
            channels=32,
            stride=1,
            kernel_size=[3, 7, 9],
            padding=[1, 3, 4],
            activation="gelu",
            activation_last_layer=None,
        )

    def forward(self, x):
        residual = self.module(x)
        restored = x - residual
        return restored.clamp(0, 1)

# Load the trained de-blur ResConv weights.
# map_location="cpu": all tensors in this notebook live on the CPU, and this keeps
# loading working on machines without CUDA even if the checkpoint was saved on a GPU.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint_path = "../checkpoints/denoise/deblur5x10-resconv-bs:64_opt:AdamW_lr:0.0003_l:3_ks1:3_ks2:9_ch:32_stride:1_act:gelu/best.pt"
model = Module()
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"])
# the model is only used for inference below
model.eval()
model

In [None]:
def generate(
    batch_size: int, shape: Tuple[int, int, int],
    ks: int = 5, sigma: float = 10.,
    steps: int = 5,
):
    """Start from smoothed random noise and repeatedly apply the global ``model``.

    Every intermediate result is shown as one image row via ``display``.

    Args:
        batch_size: number of samples; their initial brightness ramps from 0.1 to 1.0.
        shape: (channels, height, width) of the output; noise is drawn at 1/4 of
            height/width and enlarged with ``resize(noise, 4)``.
        ks: Gaussian blur kernel size used throughout.
        sigma: Gaussian blur std-dev used throughout.
        steps: number of model refinement steps (previously ignored; the loop
            was hard-coded to 5, which is still the default).
    """
    kernel = [ks, ks]
    sigmas = [sigma, sigma]
    with torch.no_grad():
        # NOTE(review): randn().clamp(0, 1) sets ~half the values to 0 and ~16% to 1;
        # torch.rand may have been intended -- confirm.
        noise = torch.randn((batch_size, shape[0], shape[1] // 4, shape[2] // 4)).clamp(0, 1)
        noise *= .1 + .9 * torch.linspace(0, 1, batch_size)[:, None, None, None]
        noise = VF.gaussian_blur(noise, kernel, sigmas)
        noise = resize(noise, 4)  # presumably upscales by 4 back to (height, width) -- confirm
        noise = VF.gaussian_blur(noise, kernel, sigmas)
        display(VF.to_pil_image(make_grid(noise, nrow=batch_size)))

        for _ in range(steps):
            denoised = model(noise)
            display(VF.to_pil_image(make_grid(denoised, nrow=batch_size)))

            # move halfway towards the model output, then re-smooth
            noise = (noise + denoised) / 2
            noise = VF.gaussian_blur(noise, kernel, sigmas)

generate(batch_size=10, shape=(3, 48, 48))

In [None]:
import torchvision
# STL10 (default split); note this rebinds `ds`, replacing the CIFAR-10 dataset used above.
ds = torchvision.datasets.STL10(root="~/prog/data/datasets/", download=True)

In [None]:
print(ds.data.shape)
# as_tensor shares memory with the existing array instead of allocating a full
# float32 copy of the dataset (torch.Tensor(...) did) just to read its maximum.
torch.as_tensor(ds.data).max()

In [None]:
# Display the label array of the STL10 dataset loaded above.
ds.labels

In [None]:
# Wrap STL10 images and labels (both converted to float tensors) in a TensorDataset,
# then preview 64 images rescaled from [0, 255] to [0, 1].
ds2 = TensorDataset(*(torch.Tensor(arr) for arr in (ds.data, ds.labels)))
print(len(ds2))
print(ds2[0][0].shape)
VF.to_pil_image(make_grid([ds2[idx][0] / 255 for idx in range(64)]))

In [None]:
class ViT(nn.Module):
    """Image-to-image transformer built on the encoder of torchvision's VisionTransformer.

    The image is patchified with a strided conv and a learned class token is
    prepended. The sequence runs through the encoder, the class token is removed
    again, and a transposed conv un-patchifies back to image space. For inputs
    whose height/width equal ``image_size`` (divisible by ``patch_size``), the
    output has the input's shape.

    Args:
        image_size: input height/width the encoder is built for.
        image_channels: input and output channel count.
        patch_size: side length of the square patches.
        num_layers, num_heads, hidden_dim, mlp_dim, dropout, attention_dropout:
            forwarded to ``torchvision.models.VisionTransformer``.
    """

    def __init__(
            self,
            image_size: int,
            image_channels: int,
            patch_size: int,
            num_layers: int,
            num_heads: int,
            hidden_dim: int,
            mlp_dim: int,
            dropout: float = 0.0,
            attention_dropout: float = 0.0,
    ):
        super().__init__()
        # only the encoder (and patch size) of this model is reused; its encoder is
        # sized for (image_size // patch_size) ** 2 patch tokens plus one class token
        model = torchvision.models.VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
        )
        self.patch_size = model.patch_size
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.proj_in = nn.Conv2d(image_channels, hidden_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.transformer = model.encoder
        self.proj_out = nn.ConvTranspose2d(hidden_dim, image_channels, kernel_size=self.patch_size, stride=self.patch_size)

    def forward(self, x):
        y = self.proj_in(x)                  # (B, hidden_dim, H/p, W/p)
        shape = y.shape
        y = y.flatten(-2).permute(0, 2, 1)   # (B, N, hidden_dim) patch tokens

        batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
        y = torch.cat([batch_class_token, y], dim=1)  # class token at index 0

        y = self.transformer(y)
        # drop the class token (index 0); slicing off the last token instead would
        # drop the last patch and shift every patch output by one position
        y = y[:, 1:, :].permute(0, 2, 1)     # (B, hidden_dim, N)
        # reshape: the permuted tensor is not contiguous
        y = self.proj_out(y.reshape(shape))
        return y

# Small ViT on 32x32 RGB with 4x4 patches: report parameter count and output shape.
m = ViT(
    image_size=32, image_channels=3, patch_size=4,
    num_layers=3, num_heads=4, hidden_dim=100, mlp_dim=1000,
)
print(f"params: {num_module_parameters(m):,}")
print(m(torch.ones(1, 3, 32, 32)).shape)
m

In [None]:
# Check that a stride-4 patch conv and its transposed counterpart round-trip the spatial size.
img = torch.ones(1, 3, 32, 32)
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=4)
conv2 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=4)

img2 = conv(img)
print(img2.shape)
print(conv2(img2).shape)