In [None]:
import io
import math

import numpy as np
import torch
import torchvision
from datasets import load_dataset
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, IterableDataset


In [None]:
# Training split in streaming mode: records are read lazily; each record is a
# dict whose encoded image lives under record['target']['bytes'].
train_dataset = load_dataset(path="mingyy/chinese_landscape_paintings", split="train", streaming=True)
class StreamingDataLoader(IterableDataset):
    """Regroup a (possibly unbounded) iterable into lists of `batch_size` items.

    Despite its name this is an IterableDataset, not a DataLoader. Each
    iteration pass walks `dataset` once; the last list is shorter when the
    stream length is not a multiple of `batch_size`.

    Args:
        dataset: any iterable of samples.
        batch_size: number of samples per yielded list.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        pending = []
        for element in self.dataset:
            pending.append(element)
            if len(pending) != self.batch_size:
                continue
            # Hand out the full chunk and start a fresh list (the yielded list
            # is never mutated afterwards).
            yield pending
            pending = []
        # Flush the trailing partial chunk, if any.
        if pending:
            yield pending

In [None]:
def streaming_map(dataset, transform_fn):
    """Lazily apply `transform_fn` to every sample of `dataset`.

    Returns a generator, so the result can be consumed only once and nothing
    is read from `dataset` until the first item is requested.
    """
    yield from map(transform_fn, dataset)

In [None]:

# Side length (pixels) of every image fed to the VAE.
IMAGE_SIZE = 512

# Shared preprocessing for full images, crops and masked images:
# PIL image / HxWxC uint8 array -> float tensor in [0, 1], shorter edge resized
# to IMAGE_SIZE, then an IMAGE_SIZE x IMAGE_SIZE center crop.
# The crop makes every sample the same size, which torch.stack in the training
# loop requires. It also gives image and crop latents the same spatial size for
# channel concatenation in the UNet. A *center* crop (not RandomCrop) is used
# because each call is independent: a random crop would misalign an image and
# its masked copy.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(IMAGE_SIZE),
    torchvision.transforms.CenterCrop(IMAGE_SIZE),
])

In [None]:
batch_size = 64


def _decode_and_transform(item):
    """Decode one streamed record into a preprocessed image tensor.

    Records carry the encoded image under item['target']['bytes'] (the same
    field the training loop reads); `transforms` needs an image, not the
    record dict itself.
    """
    image = Image.open(io.BytesIO(item['target']['bytes'])).convert('RGB')
    return transforms(image)


transformed_dataset = streaming_map(train_dataset, _decode_and_transform)
# StreamingDataLoader already yields lists of `batch_size` tensors, so automatic
# batching is disabled (batch_size=None); with the default batch_size=1 every
# batch would be wrapped again with an extra leading dimension of size 1.
# num_workers=0: with an IterableDataset each worker iterates its own copy of the
# whole stream (duplicating samples), and a one-shot generator cannot be shared.
# NOTE(review): `transformed_dataset` is a generator, so this loader yields data
# for one pass only; the training loop below iterates `train_dataset` directly.
dataloader = DataLoader(StreamingDataLoader(transformed_dataset, batch_size=batch_size),
                        batch_size=None, num_workers=0, pin_memory=True)

In [None]:

class Encoder(nn.Module):
    """Convolutional VAE encoder returning per-position mean and log-variance maps.

    Four stride-2 Conv/BatchNorm/LeakyReLU stages widen the channels
    3 -> 64 -> 128 -> 256 -> 512. Two parallel stride-2 convolution heads then
    map to `latent_dim` channels. Each 4x4/stride-2/pad-1 convolution halves an
    even spatial size, so a 64x64 input gives 2x2 latents.

    Args:
        latent_dim: number of channels of the mu / logvar maps.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (3, 64, 128, 256, 512)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # One flat Sequential keeps the checkpoint keys `conv_layers.<0..11>.*`
        # and the original parameter-creation order.
        self.conv_layers = nn.Sequential(*stages)
        self.conv6_mu = nn.Conv2d(widths[-1], latent_dim, 4, 2, 1)
        self.conv6_logvar = nn.Conv2d(widths[-1], latent_dim, 4, 2, 1)

    def forward(self, x):
        """Return the tuple (mu, logvar) for the image batch `x`."""
        features = self.conv_layers(x)
        return self.conv6_mu(features), self.conv6_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution VAE decoder mapping latents back to 3-channel images.

    Four stride-2 ConvTranspose/BatchNorm/LeakyReLU stages narrow the channels
    latent_dim -> 512 -> 256 -> 128 -> 64. A final stride-2 ConvTranspose to 3
    channels follows, then Tanh, so outputs lie in [-1, 1]. Each stage doubles
    the spatial size.
    NOTE(review): the preprocessing in this notebook (ToTensor, no Normalize)
    yields images in [0, 1]. The decoder is not called by VAE.forward here.

    Args:
        latent_dim: number of channels of the input latent map.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 512, 256, 128, 64)
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers.extend((
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        layers.append(nn.ConvTranspose2d(widths[-1], 3, 4, 2, 1))
        layers.append(nn.Tanh())
        # Flat Sequential keeps the checkpoint keys `deconv_layers.<0..13>.*`.
        self.deconv_layers = nn.Sequential(*layers)

    def forward(self, z):
        """Decode latent batch `z` into images."""
        return self.deconv_layers(z)

class VAE(nn.Module):
    """Encoder/decoder pair whose forward pass returns a *sampled* latent.

    `forward` encodes `x` and draws z = mu + eps * sigma; the decoder is
    constructed but not called here (presumably kept so the full checkpoint
    loads with strict key matching).

    Args:
        latent_dim: channel count of the latent maps.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Return mu + eps * exp(logvar / 2) with eps ~ N(0, I) drawn per element."""
        # Clamping the log-variance keeps exp() finite before the sqrt.
        std = logvar.clamp(-30, 20).exp().sqrt()
        return mu + torch.randn_like(std) * std

    def forward(self, x):
        """Encode `x` and return one random latent sample (not the mean)."""
        mu, logvar = self.encoder(x)
        return self.reparameterize(mu, logvar)

In [None]:
VAE_CHECKPOINT_PATH = "/vae.pth"

vae = VAE(128)
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the model is moved to the training device later on.
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(VAE_CHECKPOINT_PATH, map_location="cpu")
vae.load_state_dict(checkpoint['vae'])

In [None]:
def get_time_embedding(timestep, d_embed=64):
    """Sinusoidal embedding of one timestep as a 1-D float tensor.

    Frequencies are 10000 ** (-2i / d_embed) for i = 0 .. d_embed/2 - 1. The
    result is [cos(timestep * f), sin(timestep * f)], of length d_embed for an
    even d_embed.
    NOTE(review): assumes a scalar timestep (int or 0-d tensor); a 1-D tensor
    only broadcasts when its length is 1 or d_embed / 2.
    """
    exponents = torch.arange(start=0, end=d_embed, step=2, dtype=torch.float32) / d_embed
    angles = timestep * torch.pow(10000, -exponents)
    return torch.cat((angles.cos(), angles.sin()))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownsampleBlock(nn.Module):
    """Two 3x3 conv + ReLU layers followed by 2x2 max-pooling.

    `forward` returns both the pre-pool features (e.g. for a skip connection)
    and the pooled map. UNetDiffusionModel below does not use this block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return the tuple (features, pooled_features)."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = F.relu(conv(features))
        return features, self.pool(features)

class UpsampleBlock(nn.Module):
    """2x transposed-conv upsampling, concatenation with a skip tensor, then two 3x3 conv + ReLU.

    `self.up` maps in_channels // 2 -> out_channels channels. After
    concatenation the tensor must have `in_channels` channels, so the skip
    tensor `x2` must supply in_channels - out_channels channels at twice the
    spatial size of `x1`. UNetDiffusionModel below does not use this block.

    Args:
        in_channels: channel count entering conv1 (skip + upsampled).
        out_channels: channel count produced by `up` and by the block.
        mid_channels: width between conv1 and conv2 (defaults to out_channels).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid = out_channels if mid_channels is None else mid_channels
        self.up = nn.ConvTranspose2d(in_channels // 2, out_channels, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels, mid, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(mid, out_channels, kernel_size=3, padding=1)

    def forward(self, x1, x2):
        """Upsample `x1`, prepend skip tensor `x2` on channels, and convolve."""
        merged = torch.cat([x2, self.up(x1)], dim=1)
        return F.relu(self.conv2(F.relu(self.conv1(merged))))

class ConvBlock(nn.Module):
    """Two 3x3 convolutions (padding 1, spatial size preserved), each followed by ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        for conv in (self.conv1, self.conv2):
            x = F.relu(conv(x))
        return x

class UNetDiffusionModel(nn.Module):
    """Crop-conditioned denoiser over 128-channel latents.

    Despite the name there is no down/up-sampling and no skip connection. The
    model is four same-resolution ConvBlocks (down1, down2, up1, up2) followed
    by a 1x1 output convolution. down1 receives the channel concatenation of
    [latents (128), broadcast time embedding (time_embedding_channels),
    crop latents (128)]. SinusoidalTimeEmbedding is defined in a later cell
    and must exist before this class is instantiated.

    Args:
        time_embedding_channels: width of the sinusoidal time embedding.
        batch_size: stored for backward compatibility. The forward pass now
            takes the batch size from `x`, so a short final batch also works.
    """

    def __init__(self, time_embedding_channels, batch_size):
        super().__init__()
        self.down1 = ConvBlock(128 + 128 + time_embedding_channels, 256)
        self.down2 = ConvBlock(256, 512)
        self.up1 = ConvBlock(512, 256)
        self.up2 = ConvBlock(256, 128)
        self.out_conv = nn.Conv2d(128, 128, kernel_size=1)
        self.time_embed = SinusoidalTimeEmbedding(time_embedding_channels)
        self.batch_size = batch_size

    def forward(self, x, t, crops):
        """Predict the model output for latents `x` at timestep `t`, conditioned on `crops`.

        `x` and `crops` must share batch and spatial sizes. The time embedding
        (first dim 1, or equal to the batch size) is broadcast over batch and
        space.
        """
        batch, _, height, width = x.shape
        time_emb = self.time_embed(t)[:, :, None, None]
        # Broadcast to x's own batch size instead of repeat(self.batch_size).
        # repeat() broke on a partial batch, and it multiplied a per-sample
        # embedding instead of matching it.
        time_emb = time_emb.expand(batch, -1, height, width)

        # Concatenate time embedding and crop latents with the input on channels.
        x = torch.cat([x, time_emb, crops], dim=1)

        # Convolutional blocks at constant resolution.
        x = self.down2(self.down1(x))
        x = self.up2(self.up1(x))

        # 1x1 projection back to 128 latent channels.
        return self.out_conv(x)


In [None]:
class SinusoidalTimeEmbedding(nn.Module):
    """Transformer-style sinusoidal embedding of a diffusion timestep.

    For a scalar (0-d tensor or Python number) `t`, `forward` returns a
    (1, dim) tensor [sin(t * f), cos(t * f)]. Here f_i = 10000 ** (-i / (h - 1))
    for i < h and h = dim // 2. An odd `dim` is right-padded with one zero
    column. The result is created on `t`'s device.
    NOTE(review): a 1-D `t` is multiplied element-wise with the h frequencies
    (it broadcasts only when len(t) is 1 or h). This is not an outer product,
    so per-sample batched timesteps are not supported as-is.

    Args:
        dim: embedding width.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        t = torch.as_tensor(t)
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError for dim in {2, 3} (half_dim == 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        # Build the frequencies directly on t's device instead of relying on a
        # notebook-global `device`.
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=t.device) * -scale)
        args = t.float() * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:  # Zero pad if odd dimension
            emb = F.pad(emb, (0, 1, 0, 0))
        return emb


In [None]:
import numpy as np
import torch
import torch.nn.functional as F

class DDPMSampler:
    """DDPM scheduler: forward noising (`add_noise`) and ancestral reverse steps (`step`).

    Betas follow a "scaled linear" schedule: linear in sqrt(beta) between
    `beta_start` and `beta_end`, then squared. `timesteps` is the descending
    sequence of timesteps to walk when sampling. It defaults to every training
    timestep; `set_inference_timesteps` replaces it with an evenly strided
    subset.

    Args:
        generator: torch.Generator used for every noise draw. Its device must
            match the tensors passed to `step` / `add_noise`.
        num_training_steps: length T of the beta schedule.
        beta_start, beta_end: first and last beta of the schedule.
    """

    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        self.generator = generator

        self.num_train_timesteps = num_training_steps
        # Default to stepping through every training timestep. Without this,
        # step() / _get_previous_timestep() / set_strength() raised
        # AttributeError unless set_inference_timesteps() had been called first.
        self.num_inference_steps = num_training_steps
        self.start_step = 0
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        """Sample with `num_inference_steps` timesteps spaced T // num_inference_steps apart (descending).

        Raises:
            ValueError: if num_inference_steps is not in [1, num_train_timesteps].
                0 would divide by zero. More than T gives a stride of 0, where
                the "previous" timestep equals the current one and every step
                is degenerate.
        """
        if not 1 <= num_inference_steps <= self.num_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be in [1, {self.num_train_timesteps}], got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Timestep one inference stride earlier (may be negative past t = 0)."""
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """Posterior variance (DDPM beta-tilde) between `timestep` and its previous timestep."""
        prev_t = self._get_previous_timestep(timestep)

        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t

        # Clamp away from 0 so later sqrt/log of the variance stay well defined.
        variance = torch.clamp(variance, min=1e-20)

        return variance

    def set_strength(self, strength=1):
        """
            Set how much noise to add to the input image.
            More noise (strength ~ 1) means that the output will be further from the input image.
            Less noise (strength ~ 0) means that the output will be closer to the input image.
        """
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """One reverse step x_t -> x_{t_prev}, treating `model_output` as the predicted noise.

        Noise from `self.generator` is added only when timestep > 0.
        """
        t = timestep
        prev_t = self._get_previous_timestep(t)

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # 6. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            variance = (self._get_variance(t) ** 0.5) * noise

        # sample from N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # the variable "variance" is already multiplied by the noise N(0, 1)
        pred_prev_sample = pred_prev_sample + variance

        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Sample x_t ~ q(x_t | x_0) for the given timestep(s).

        `timesteps` may be a scalar or one value per leading-dim sample. Its
        coefficients are broadcast over the remaining dimensions.
        """
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # Because N(mu, sigma) = X can be obtained by X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

        

    

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

diffusion = UNetDiffusionModel(128, batch_size).to(device)
generator = torch.Generator(device=device)
sampler = DDPMSampler(generator)

# The VAE is a frozen, pre-trained encoder. Move it to the device and switch to
# eval mode, so BatchNorm uses the running statistics stored in the checkpoint
# instead of batch statistics (which would also overwrite those running stats).
# Then stop gradient tracking for all of its parameters.
vae.to(device)
vae.eval()
vae.requires_grad_(False)

In [None]:
import cv2
import numpy as np


def process_image(pil_image, required_dimension=512):
    """Separate the non-white content of an image and crop to its bounding box.

    Pixels whose R, G and B values are all in [254, 255] count as white
    background.

    Args:
        pil_image: input PIL image (converted to RGB first).
        required_dimension: if the cropped content's shorter edge is smaller than
            this, the crop is upscaled bilinearly so that edge becomes about
            `required_dimension`. Larger crops are returned unscaled.

    Returns:
        A tuple of three PIL images (mask, cropped, separated):
        mask is 255 on non-white pixels and 0 on white background;
        cropped is the bounding box of all non-white content (possibly upscaled);
        separated is the full image with white background set to black.
        If the image is entirely white there is no content box, and `cropped`
        is the (possibly upscaled) full separated image.
    """
    # PIL already delivers RGB, so no BGR<->RGB conversion (the old
    # COLOR_BGR2RGB call on RGB data swapped the red and blue channels).
    image_rgb = np.array(pil_image.convert('RGB'))

    # Create a mask for white pixels, then invert it to get non-white areas.
    lower_white = np.array([254, 254, 254], dtype=np.uint8)
    upper_white = np.array([255, 255, 255], dtype=np.uint8)
    white_pixels_mask = cv2.inRange(image_rgb, lower_white, upper_white)
    non_white_pixels_mask = cv2.bitwise_not(white_pixels_mask)

    # Keep only the non-white pixels (background becomes black).
    separated_rgb = cv2.bitwise_and(image_rgb, image_rgb, mask=non_white_pixels_mask)

    contours, _ = cv2.findContours(non_white_pixels_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        # Bounding rect of all contour points == union of the per-contour rects.
        x, y, w, h = cv2.boundingRect(np.vstack(contours))
        cropped_image = separated_rgb[y:y + h, x:x + w]
    else:
        # All-white image: fall back to the whole image instead of leaving the
        # result unbound.
        cropped_image = separated_rgb

    min_dimension = min(cropped_image.shape[:2])
    if 0 < min_dimension < required_dimension:
        scale_factor = required_dimension / min_dimension
        rescaled_image = cv2.resize(cropped_image, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
    else:
        rescaled_image = cropped_image

    # NumPy arrays have no ToPILImage() method; convert with Image.fromarray.
    return (
        Image.fromarray(non_white_pixels_mask),
        Image.fromarray(rescaled_image),
        Image.fromarray(separated_rgb),
    )

In [None]:
def create_random_box(image_width, image_height):
    """Sample a random square box that lies completely inside the image.

    The side length is drawn uniformly (NumPy global RNG) from
    [int(0.3 * s), int(0.8 * s)] inclusive, where s is the shorter image edge.
    The top-left corner is then drawn uniformly among positions that keep the
    box inside the image.

    Returns:
        (x, y, width, height) with width == height.
    """
    shorter_edge = min(image_width, image_height)
    low = int(shorter_edge * 0.3)
    high = int(shorter_edge * 0.8)

    side = np.random.randint(low, high + 1)
    left = np.random.randint(0, image_width - side + 1)
    top = np.random.randint(0, image_height - side + 1)
    return left, top, side, side

def modify_image(pil_image, box_x, box_y, box_width, box_height):
    """Black out a rectangular box in an image and return what was removed.

    Args:
        pil_image: source image (anything np.array accepts; the training loop
            passes an RGB PIL image).
        box_x, box_y: top-left corner of the box in pixels.
        box_width, box_height: box size in pixels.

    Returns:
        cropped_image_data: copy of the *original* pixels inside the box
            (same dtype as the input array), taken before blacking out.
        mask: float array of the image's height x width, 1.0 outside the box
            and 0.0 inside it.
        modified_image: PIL image of the input with the box set to 0.
    """
    image_array = np.array(pil_image)
    box = (slice(box_y, box_y + box_height), slice(box_x, box_x + box_width))

    # Copy the original pixels *before* zeroing them. The crop used to be sliced
    # after the box was blacked out, so it was always an all-black patch.
    cropped_image_data = image_array[box].copy()

    mask = np.ones(image_array.shape[:2])
    mask[box] = 0
    image_array[box] = 0

    modified_image = Image.fromarray(image_array)
    return cropped_image_data, mask, modified_image



In [None]:
mask_threshold = 0.1


def generate_latent_mask(original_latents, masked_latents, threshold=None):
    """Mark latent positions that masking the image did not change.

    Args:
        original_latents: latents of the untouched image.
        masked_latents: latents of the image with the box blacked out.
        threshold: absolute-difference cutoff. Defaults to the module-level
            `mask_threshold`, read at call time.

    Returns:
        int64 tensor with the inputs' broadcast shape, on their device.
        It is 1 where |original - masked| does not exceed `threshold` (known
        region, kept) and 0 where it does (region to inpaint).
    """
    if threshold is None:
        threshold = mask_threshold
    difference = torch.abs(original_latents - masked_latents)
    # Derive the mask from the comparison itself. abs() on the bool comparison
    # was meaningless, and torch.where with CPU scalar tensors could mismatch the
    # inputs' device. Same truth table as where(diff > thr, 0, 1).
    changed = difference > threshold
    return torch.logical_not(changed).long()

In [None]:
import torch.optim as optim
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import io
import math
from PIL import Image
from tqdm import tqdm

sampler.set_inference_timesteps(30)
l1_loss = nn.L1Loss()
lr = 0.0001
epochs = 2
clip_value = 1.0
# Optimizers
optimizer = optim.Adam(diffusion.parameters(), lr=lr)
criterion = nn.MSELoss()

# Training loop. For every full batch:
#   1. Encode the image, the removed crop and the masked image with the frozen VAE.
#   2. Noise only the latent region that the mask marks as changed.
#   3. Unroll all sampler timesteps through the UNet, re-imposing the known
#      latents after each step.
#   4. Regress the final latents onto the clean ones.
# Gradients flow through the whole unrolled chain.
for epoch in range(epochs):
    # The VAE is frozen: keep its BatchNorm on the stored running statistics.
    vae.eval()
    diffusion.train()
    totalLoss = 0
    batch_idx = 0
    batch = []
    crops = []
    masked_images = []

    for item in train_dataset:
        # Extract the image from the 'target' attribute
        image_bytes = io.BytesIO(item['target']['bytes'])
        image = Image.open(image_bytes).convert('RGB')
        img_width, img_height = image.size

        x, y, width, height = create_random_box(img_width, img_height)
        cropped_image_data, mask, masked_image = modify_image(image, x, y, width, height)

        image = transforms(image)
        batch.append(image.to(device))
        crops.append(transforms(cropped_image_data).to(device))
        masked_images.append(transforms(masked_image).to(device))

        # Only full batches are trained on; a trailing partial batch is dropped.
        if len(batch) == batch_size:
            batch_idx += 1
            images = torch.stack(batch)
            crops_stacked = torch.stack(crops)
            masked_images_stacked = torch.stack(masked_images)

            with torch.no_grad():
                # Same sampled latents vae(images) would return, keeping the mean too.
                mu, logvar = vae.encoder(images)
                latents_without_noise = vae.reparameterize(mu, logvar)
                encoded_crops = vae(crops_stacked)
                masked_mu, _ = vae.encoder(masked_images_stacked)

                # Compare deterministic posterior means. Two independently sampled
                # latents differ everywhere by sampling noise, which swamps the
                # threshold and makes the mask meaningless.
                latent_mask = generate_latent_mask(mu, masked_mu)

            optimizer.zero_grad()

            with torch.no_grad():
                # Start from the first (largest) sampler timestep, noising only the
                # changed region.
                latents = latent_mask * latents_without_noise + (1 - latent_mask) * sampler.add_noise(latents_without_noise, sampler.timesteps[0])

            for timestep in sampler.timesteps:
                # The UNet embeds the raw timestep itself (SinusoidalTimeEmbedding).
                # Passing get_time_embedding(timestep) here embedded it twice.
                # latents: (batch_size, 128, H', W'), where 128 is the VAE latent_dim.
                model_output = diffusion(latents, timestep.to(device), encoded_crops)

                # Denoise one step, then restore the known (unmasked) latents.
                latents = latent_mask * latents_without_noise + (1 - latent_mask) * sampler.step(timestep, latents, model_output)

            # MSELoss already averages over every element, batch included.
            loss = criterion(latents, latents_without_noise)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(diffusion.parameters(), clip_value)
            optimizer.step()

            batch = []
            crops = []
            masked_images = []
            totalLoss += loss.item()

            if batch_idx % 10 == 0:
                print('Train Epoch: {} [Batch {}] \tLoss: {:.6f}'.format(
                    epoch, batch_idx, loss.item()))

    # Mean per-batch loss. The old "total" divided an already-averaged loss by
    # batch_size, so its scale was meaningless.
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, totalLoss / max(batch_idx, 1)))



In [None]:
# Persist only the diffusion UNet weights (the VAE is frozen and loaded separately).
torch.save(dict(diffusion=diffusion.state_dict()), f='/kaggle/working/weights_4.pth')

In [None]:
# NOTE(review): the save cell above writes weights_4.pth but this cell loads
# weights_2.pth - confirm which checkpoint is intended.
# map_location=device lets a checkpoint saved on GPU load on a CPU-only host
# (and vice versa). torch.load unpickles the file: only load trusted checkpoints.
checkpoint = torch.load("/kaggle/working/weights_2.pth", map_location=device)
diffusion.load_state_dict(checkpoint['diffusion'])