In [None]:
from datasets import load_dataset
import numpy as np
import torchvision
from torch import nn
import torch
from torch.utils.data import DataLoader, IterableDataset


In [None]:
# Training split of the landscape-painting dataset, opened with streaming=True
# so records are produced while iterating over `train_dataset`.
train_dataset = load_dataset(
    "mingyy/chinese_landscape_paintings",
    split='train',
    streaming=True,
)
class StreamingDataLoader(IterableDataset):
    """Iterable dataset that groups the items of another iterable into lists.

    Args:
        dataset: Any iterable of samples.
        batch_size: Number of samples collected into each yielded list.

    Iterating yields plain Python lists of ``batch_size`` consecutive samples;
    a final shorter list is yielded when samples are left over.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        pending = []
        for sample in self.dataset:
            pending.append(sample)
            if len(pending) != self.batch_size:
                continue
            # Hand out the completed group and start collecting a new one.
            full, pending = pending, []
            yield full
        # Leftover samples that did not fill a whole group.
        if pending:
            yield pending

In [None]:
def streaming_map(dataset, transform_fn):
    """Lazily apply ``transform_fn`` to each sample of ``dataset``.

    This is a generator: nothing is read from ``dataset`` until the result is
    iterated, and results are yielded one at a time in the source order.
    """
    yield from map(transform_fn, dataset)

In [None]:

# Preprocessing applied to every image: convert to a tensor, then resize.
# NOTE(review): Resize is given a single int, not an (h, w) pair — confirm the
# resulting tensors all share one shape before stacking them into a batch.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(size=512),
    ]
)

In [None]:
batch_size = 64


def _decode_and_transform(sample, _pipeline=transforms):
    """Decode one streamed record into a preprocessed image tensor.

    The image bytes are read from ``sample['target']['bytes']`` (the same field
    the training loop reads), decoded as RGB, and passed through the
    preprocessing pipeline. ``_pipeline`` is bound when this cell runs so a
    later rebinding of the name ``transforms`` cannot change it.
    """
    # Function-scope imports: this cell runs before the notebook's own
    # `io` / `PIL` import cell.
    import io
    from PIL import Image

    image = Image.open(io.BytesIO(sample['target']['bytes'])).convert('RGB')
    return _pipeline(image)


# The pipeline must see a decoded image, not the whole sample record.
# Note: this is a one-shot generator; it is exhausted after a single pass.
transformed_dataset = streaming_map(train_dataset, _decode_and_transform)

dataloader = DataLoader(
    StreamingDataLoader(transformed_dataset, batch_size=batch_size),
    batch_size=None,         # the wrapped dataset already yields whole batches
    collate_fn=torch.stack,  # list of (C, H, W) tensors -> one (B, C, H, W) tensor
    num_workers=0,           # each worker would replay the full generator-backed stream
    pin_memory=True,
)

In [None]:

class Encoder(nn.Module):
    """Convolutional encoder that outputs the Gaussian parameters of the latent.

    Four Conv2d(kernel 4, stride 2, padding 1) stages take a 3-channel input
    through 64, 128, 256 and 512 channels, each followed by BatchNorm2d and
    LeakyReLU(0.2). Two parallel convolutions with the same kernel settings
    then produce ``mu`` and ``logvar`` with ``latent_dim`` channels each.
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (3, 64, 128, 256, 512)

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.conv_layers = nn.Sequential(*stages)

        # Separate heads for the mean and the log-variance of the latent.
        self.conv6_mu = nn.Conv2d(widths[-1], latent_dim, kernel_size=4, stride=2, padding=1)
        self.conv6_logvar = nn.Conv2d(widths[-1], latent_dim, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from the shared convolutional features."""
        features = self.conv_layers(x)
        return self.conv6_mu(features), self.conv6_logvar(features)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a latent map back to an image.

    Four ConvTranspose2d(kernel 4, stride 2, padding 1) stages go from
    ``latent_dim`` channels through 512, 256, 128 and 64 channels, each followed
    by BatchNorm2d and LeakyReLU(0.2). A fifth transposed convolution produces
    3 channels and a Tanh bounds the output to [-1, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        widths = (latent_dim, 512, 256, 128, 64)

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Output head. Tanh gives values in [-1, 1]; nn.Sigmoid() would be the
        # matching choice for data normalized to [0, 1].
        stages += [
            nn.ConvTranspose2d(widths[-1], 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.deconv_layers = nn.Sequential(*stages)

    def forward(self, z):
        """Decode the latent tensor ``z`` into a 3-channel image tensor."""
        return self.deconv_layers(z)

class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``logvar`` is clamped to [-30, 20] before exponentiation; the standard
        deviation is the square root of the resulting variance.
        """
        bounded_logvar = logvar.clamp(-30, 20)
        sigma = torch.sqrt(torch.exp(bounded_logvar))
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``.

        The returned ``logvar`` is the encoder output as-is (not clamped).
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

In [None]:
# Number of channels in the latent map produced by the encoder heads.
latent_dim = 128

# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

vae = VAE(latent_dim)
vae = vae.to(device)

In [None]:
def add_random_white_box_and_get_mask_batch(images):
    """Randomly paint a white box on about half of the images in a batch.

    Args:
        images: Tensor of shape (B, C, H, W). It is modified in place.

    Returns:
        ``(images, masks)`` where ``images`` is the same tensor that was passed
        in and ``masks`` is a float32 tensor of shape (B, H, W) on the same
        device. For an image that received a box, its mask is 1 inside the box
        and 0 elsewhere; for an image left untouched, its mask is all ones.
    """
    B, _, H, W = images.shape
    masks = torch.zeros((B, H, W), dtype=torch.float32, device=images.device)

    for i in range(B):
        if torch.rand(1).item() < 0.5:
            images[i], masks[i] = add_random_white_box_and_get_mask(images[i])
        else:
            # Fill only this image's (H, W) slot. Assigning a (B, H, W) tensor
            # here raises a shape error whenever B > 1.
            # NOTE(review): untouched images get the same value (1) that marks
            # the boxed region elsewhere — confirm this is the intended meaning.
            masks[i] = 1.0

    return images, masks

def add_random_white_box_and_get_mask(image):
    """Paint a randomly placed box of value 1.0 onto a single (C, H, W) image.

    The box width is drawn from [0.3*W, 0.8*W) and the height from
    [0.3*H, 0.8*H), both truncated to integers; its top-left corner is chosen
    so the box stays inside the image. ``image`` is modified in place.

    Returns:
        ``(image, mask)`` where ``image`` is the tensor that was passed in and
        ``mask`` is a float32 (H, W) tensor on the same device, 1 inside the
        box and 0 elsewhere.
    """
    _, height, width = image.shape

    def _uniform():
        # One uniform sample from [0, 1) as a Python float.
        return torch.rand(1).item()

    # Draw order is width, height, x, y.
    box_w = int(_uniform() * (0.5 * width) + 0.3 * width)
    box_h = int(_uniform() * (0.5 * height) + 0.3 * height)
    left = int(_uniform() * (width - box_w))
    top = int(_uniform() * (height - box_h))

    rows = slice(top, top + box_h)
    cols = slice(left, left + box_w)

    mask = torch.zeros((height, width), dtype=torch.float32, device=image.device)
    mask[rows, cols] = 1
    image[:, rows, cols] = 1.0

    return image, mask

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import io
from PIL import Image
import copy

# Assuming 'vae' is your model and 'train_dataset' is your dataset
# Make sure to define 'vae', 'train_dataset', 'transforms', 'device', and 'batch_size' correctly

# Adam over every VAE parameter with a small fixed learning rate.
optimizer = optim.Adam(params=vae.parameters(), lr=2e-5)

# Exponential-moving-average setup: `alpha` is the weight kept from the old
# average at each update; `ema_vae` starts as an independent copy of `vae`.
alpha = 0.99
ema_vae = copy.deepcopy(vae)

def update_ema(ema_model, model, alpha):
    """Blend ``model``'s current state into ``ema_model`` in place.

    Args:
        ema_model: Module holding the running average; updated in place.
        model: Module being trained; left unchanged.
        alpha: Weight kept from the previous average. Each parameter becomes
            ``alpha * ema + (1 - alpha) * current``.

    Buffers (for example BatchNorm running statistics and step counters) are
    copied from ``model`` rather than averaged. Without this the EMA model
    would keep the buffer values it had when it was created.
    """
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.mul_(alpha).add_(param.detach(), alpha=1 - alpha)
        # Integer buffers cannot be averaged, so all buffers are copied as-is.
        for ema_buffer, buffer in zip(ema_model.buffers(), model.buffers()):
            ema_buffer.copy_(buffer)

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed squared error plus the KL term.

    The reconstruction term is the squared error between ``recon_x`` and ``x``
    summed over all elements. The KL term is
    ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``. Both are sums, not means.
    """
    reconstruction_term = F.mse_loss(input=recon_x, target=x, reduction='sum')
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)
    return reconstruction_term + kl_term

# Training loop
num_epochs = 5
clip_value = 1.0  # Maximum gradient norm passed to clip_grad_norm_

for epoch in range(num_epochs):
    vae.train()
    train_loss = 0
    batch_idx = 0
    batch = []

    for item in train_dataset:
        # Decode the image stored under item['target']['bytes'] and preprocess it.
        image_bytes = io.BytesIO(item['target']['bytes'])
        image = transforms(Image.open(image_bytes).convert('RGB'))
        batch.append(image.to(device))

        # Keep collecting until a full batch is available.
        if len(batch) != batch_size:
            continue

        batch_idx += 1
        images = torch.stack(batch)
        batch = []

        # Forward and backward pass.
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(images)
        loss = loss_function(recon_batch, images, mu, logvar)
        loss.backward()

        # Clip gradients, then step the optimizer.
        torch.nn.utils.clip_grad_norm_(vae.parameters(), clip_value)
        train_loss += loss.item()
        optimizer.step()

        # Fold the new weights into the EMA copy.
        update_ema(ema_vae, vae, alpha)

        if batch_idx % 30 == 0:
            print('Train Epoch: {} [Batch {}] \tLoss: {:.6f}'.format(
                epoch, batch_idx, loss.item() / batch_size))

    # Samples left in `batch` here (fewer than batch_size) are not trained on.
    print('====> Epoch: {} Total loss: {:.4f}'.format(epoch, train_loss))


In [None]:
# Write the VAE weights to disk under the 'vae' key.
checkpoint_payload = {'vae': vae.state_dict()}
torch.save(checkpoint_payload, '/kaggle/working/weights_10.pth')

In [None]:
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import requests
from io import BytesIO

def _load_test_image(image_path):
    """Return an RGB PIL image from a URL or a local path.

    An empty ``image_path`` falls back to the notebook's default demo URL.

    Raises:
        requests.HTTPError: If the download returns an error status.
    """
    default_url = 'https://th.bing.com/th/id/OIP.L4nUSvQ7ZaefejVVEkLG5QHaEp?rs=1&pid=ImgDetMain'
    source = image_path or default_url

    if source.startswith(('http://', 'https://')):
        response = requests.get(source, timeout=30)
        # Fail loudly instead of continuing without an image.
        response.raise_for_status()
        source = BytesIO(response.content)

    # The encoder's first convolution expects exactly 3 input channels.
    return Image.open(source).convert('RGB')


def test_image(image_path, model, device):
    """Reconstruct one image with ``model`` and plot it next to the input.

    Args:
        image_path: URL or local file path of the image. An empty string uses
            the default demo URL.
        model: Module whose forward pass returns ``(reconstruction, mu, logvar)``.
            It is switched to eval mode.
        device: Device the input tensor is moved to before the forward pass.
    """
    model.eval()

    imgTransform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize(512),
    ])

    image = _load_test_image(image_path)
    # Apply transformations and add the batch dimension.
    image_tensor = imgTransform(image).unsqueeze(0).to(device)

    # Use the model that was passed in, not a global.
    with torch.no_grad():
        recon_batch, _, _ = model(image_tensor)

    # `torchvision.transforms` is spelled out because this notebook also binds
    # the bare name `transforms` to a Compose pipeline.
    to_pil = torchvision.transforms.ToPILImage()
    # The decoder ends in Tanh, so values may fall outside [0, 1]; clamp them
    # before converting to an 8-bit image.
    recon_image = to_pil(recon_batch.squeeze(0).clamp(0, 1).cpu())
    real_image = to_pil(image_tensor.squeeze(0).cpu())

    # Plot input and reconstruction side by side.
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(real_image)
    ax[0].set_title("Input Image")
    ax[0].axis("off")

    ax[1].imshow(recon_image)
    ax[1].set_title("Generated Image")
    ax[1].axis("off")

    plt.tight_layout()
    plt.show()


test_image("", vae, device)

In [None]:
# Restore the VAE weights stored under the 'vae' key of the checkpoint.
# map_location=device lets a checkpoint written on a GPU load on a CPU-only
# runtime (and the other way round).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
# NOTE(review): this path differs from the one the save cell writes
# ('/kaggle/working/weights_10.pth') — confirm which file is intended.
checkpoint = torch.load("/vae.pth", map_location=device)
vae.load_state_dict(checkpoint['vae'])