In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from torchvision import transforms
from PIL import Image
import os
import torch.nn.functional as F
from tqdm import tqdm
import torchvision
import random
import numpy as np

# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Step 1: Prepare your dataset
class SketchDataset(Dataset):
    """Training images paired with (unpaired) sketches chosen by cyclic index.

    Each item is the 6-tuple ``(label, sketch, image, img_name, rand_label, variance)``:

    * ``label``: float32 tensor of row ``idx``'s CSV columns after the first.
    * ``sketch``: ``sketch_{idx % num_sketches + 1}.png`` from ``sketch_root_dir``,
      converted to RGB, then passed through ``transform`` (if any).
    * ``image``: ``<first CSV column>.jpg`` from ``image_root_dir``, converted to
      RGB, then passed through ``transform`` (if any).
    * ``img_name``: path of that image file.
    * ``rand_label``: label of a row drawn uniformly with ``random.randint``.
    * ``variance``: dataset-wide pixel variance, computed once in ``__init__``
      and stored on ``device``.

    Args:
        csv_file: CSV whose first column is the image file stem; the remaining
            columns are cast to float32 (assumed numeric -- confirm).
        image_root_dir: directory holding the ``<stem>.jpg`` images.
        sketch_root_dir: directory holding ``sketch_<n>.png`` for n = 1..num_sketches.
        transform: callable applied to both the image and the sketch.
        num_sketches: number of sketch files available; sketches repeat cyclically.
    """

    def __init__(self, csv_file, image_root_dir, sketch_root_dir, transform=None,
                 num_sketches=3594):
        self.data_frame = pd.read_csv(csv_file)
        self.image_root_dir = image_root_dir
        self.sketch_root_dir = sketch_root_dir
        self.transform = transform
        self.num_sketches = num_sketches
        self.num_samples = len(self.data_frame)

        # Scan the dataset once; the resulting scalar is returned with every item.
        self.variance = self.compute_variance().to(device)

    def _image_path(self, idx):
        """Path of the image for CSV row ``idx``: ``<first column>.jpg``."""
        return os.path.join(self.image_root_dir, self.data_frame.iloc[idx, 0] + '.jpg')

    def _label(self, idx):
        """Float32 tensor of every CSV column after the first, for row ``idx``."""
        # A row that mixes the string stem with numbers comes back as an object-dtype
        # Series, which torch.tensor cannot consume; convert through NumPy first.
        return torch.tensor(self.data_frame.iloc[idx, 1:].to_numpy(dtype=np.float32))

    def compute_variance(self):
        """Return the unbiased variance of every pixel value across the dataset.

        Values are taken as the model receives them: the output of ``transform``
        when one is set (already scaled/normalized by it, so no further rescaling),
        otherwise the raw RGB bytes scaled to [0, 1].  Per-image statistics are
        merged incrementally (Chan et al.'s pairwise update), so the dataset is
        never held in memory at once.

        Returns:
            0-dim float32 tensor; NaN when fewer than two values were seen,
            mirroring ``torch.var``.
        """
        count = 0
        mean = 0.0
        m2 = 0.0  # running sum of squared deviations from `mean`
        # NOTE: random augmentations inside `transform` also apply here, so the
        # estimate varies slightly from run to run.
        for idx in tqdm(range(len(self))):
            with Image.open(self._image_path(idx)) as img:
                image = img.convert('RGB')
            if self.transform:
                pixels = self.transform(image)
            else:
                pixels = torch.from_numpy(np.asarray(image, dtype=np.float64) / 255.0)
            pixels = pixels.to(torch.float64).flatten()
            n = pixels.numel()
            if n == 0:
                continue
            image_mean = pixels.mean().item()
            image_m2 = ((pixels - image_mean) ** 2).sum().item()
            delta = image_mean - mean
            total = count + n
            mean += delta * n / total
            m2 += image_m2 + delta * delta * count * n / total
            count = total

        if count < 2:
            return torch.tensor(float('nan'))
        return torch.tensor(m2 / (count - 1), dtype=torch.float32)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        img_name = self._image_path(idx)
        sketch_idx = idx % self.num_sketches  # cyclic indexing: fewer sketches than images
        sketch_name = os.path.join(self.sketch_root_dir, f"sketch_{sketch_idx + 1}.png")

        with Image.open(img_name) as img:
            image = img.convert('RGB')
        with Image.open(sketch_name) as sketch_file:
            # The shared transform normalizes three channels; grayscale or RGBA
            # sketches would otherwise fail inside Normalize.
            sketch = sketch_file.convert('RGB')

        label = self._label(idx)

        rand_idx = random.randint(0, self.num_samples - 1)
        rand_label = self._label(rand_idx)

        if self.transform:
            image = self.transform(image)
            sketch = self.transform(sketch)

        return label, sketch, image, img_name, rand_label, self.variance

# Preprocessing shared by images and sketches.  The VQ-VAE below is built for
# 28x28 inputs, so every stage keeps that size: resize, augment, then a random
# crop resized back to 28x28 (cropping to 256 would upsample each sample ~9x).
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) then maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(28, scale=(0.8, 1.0), ratio=(0.75, 1.333)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Root of the A4 dataset.  Set the DATASET_A4_DIR environment variable to point
# elsewhere instead of editing paths; the default is the original location.
DATA_DIR = os.environ.get('DATASET_A4_DIR', '/home/cvlab/Karan/A_3/Dataset_A4')

train_dataset = SketchDataset(csv_file=os.path.join(DATA_DIR, 'Train_labels.csv'),
                              image_root_dir=os.path.join(DATA_DIR, 'Train_data'),
                              sketch_root_dir=os.path.join(DATA_DIR, 'Unpaired_sketch'),
                              transform=transform)
dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)


In [None]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Strided-conv encoder mapping an image to a ``latent_dim``-channel map ``z_e``.

    Each entry of ``num_layers`` adds a 3x3, stride-2, padding-1 convolution
    followed by ReLU, halving H and W (rounding up), e.g. 28x28 -> 14x14 -> 7x7
    for two layers.  A final 3x3, padding-1 convolution (no activation) projects
    to ``latent_dim`` channels.

    Args:
        num_layers: output channels of each strided convolution.
        latent_dim: channels of ``z_e``; defaults to the notebook-global ``d``,
            read when the module is constructed.
        in_channels: channels of the input image (1 = grayscale).
    """

    def __init__(self, num_layers=(16, 32), latent_dim=None, in_channels=1):
        super(Encoder, self).__init__()
        if latent_dim is None:
            latent_dim = d  # backward-compatible fallback to the notebook global
        self.conv_layers = nn.ModuleList()
        prev_channels = in_channels
        for filters in num_layers:
            self.conv_layers.append(
                nn.Conv2d(prev_channels, filters, kernel_size=3, padding=1, stride=2)
            )
            prev_channels = filters
        self.z_e = nn.Conv2d(num_layers[-1], latent_dim, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_channels, H, W) to ``z_e`` (batch, latent_dim, H', W')."""
        for conv_layer in self.conv_layers:
            x = nn.functional.relu(conv_layer(x))
        return self.z_e(x)

class Decoder(nn.Module):
    """Transposed-conv decoder mapping a latent map back to an image in (0, 1).

    Each entry of ``num_layers`` adds a 4x4, stride-2, padding-1 transposed
    convolution followed by ReLU, which exactly doubles H and W (7 -> 14 -> 28
    for two layers).  A final 3x3, stride-1, padding-1 transposed convolution
    keeps the size and maps to ``out_channels``; a sigmoid squashes the result.

    Args:
        num_layers: output channels of each upsampling layer.
        latent_dim: channels of the incoming latent; defaults to the
            notebook-global ``d``, read when the module is constructed.
        out_channels: channels of the reconstructed image (1 = grayscale).
    """

    def __init__(self, num_layers=(32, 16), latent_dim=None, out_channels=1):
        super(Decoder, self).__init__()
        if latent_dim is None:
            latent_dim = d  # backward-compatible fallback to the notebook global
        self.convT_layers = nn.ModuleList()
        prev_channels = latent_dim
        for filters in num_layers:
            self.convT_layers.append(
                nn.ConvTranspose2d(prev_channels, filters, kernel_size=4, padding=1, stride=2)
            )
            prev_channels = filters
        self.output = nn.ConvTranspose2d(num_layers[-1], out_channels, kernel_size=3, padding=1)

    def forward(self, y):
        """Decode ``y`` of shape (batch, latent_dim, h, w) to (batch, out_channels, h*2^L, w*2^L)."""
        for convT_layer in self.convT_layers:
            y = nn.functional.relu(convT_layer(y))
        return torch.sigmoid(self.output(y))

# Usage: smoke-test that a 28x28 single-channel input round-trips to the same shape.
d = 64  # latent channel width; Encoder/Decoder read this global when constructed
encoder = Encoder()
decoder = Decoder()
inputs = torch.randn(1, 1, 28, 28)  # (batch_size, channels, height, width)
with torch.no_grad():  # shape check only; no autograd graph is needed
    z_e = encoder(inputs)
    decoded = decoder(z_e)
assert decoded.shape == inputs.shape, (
    f"decoder output {tuple(decoded.shape)} does not match input {tuple(inputs.shape)}"
)


In [None]:
import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer over a learnable codebook of ``k`` vectors.

    The codebook is created by ``init_codebook(d)``, or immediately when ``d``
    is passed to the constructor.
    """

    def __init__(self, k, d=None):
        super(VectorQuantizer, self).__init__()
        self.k = k
        if d is not None:
            self.init_codebook(d)

    def forward(self, inputs):
        """Quantize channels-last ``inputs`` of shape (batch_size, w, h, d).

        Returns:
            quantized: (batch_size, w, h, d) tensor of the nearest codebook rows.
            indices: (batch_size, w, h) long tensor of the chosen row numbers.

        Raises:
            AttributeError: if ``init_codebook`` has not been called yet.
        """
        if not hasattr(self, 'codebook'):
            raise AttributeError("codebook is not initialized; call init_codebook(d) first")
        batch_size, w, h, d = inputs.size()
        # reshape, not view: channels-last inputs are often permuted (non-contiguous).
        inputs_flat = inputs.reshape(-1, d)
        # (N, k) Euclidean distances without materializing an (N, k, d) difference tensor.
        distances = torch.cdist(inputs_flat, self.codebook)
        indices = torch.argmin(distances, dim=-1).view(batch_size, w, h)
        # Row lookup by advanced indexing; torch.gather needs index and source of
        # equal rank, so it cannot index a 2-D codebook with a 4-D index.
        quantized = self.codebook[indices]
        return quantized, indices

    def init_codebook(self, d):
        """Create a (k, d) codebook initialized from a standard normal distribution."""
        self.codebook = nn.Parameter(torch.randn(self.k, d))

# Usage: quantize a random channels-last (batch, width, height, channels) tensor
# against a freshly initialized codebook.
k = 10  # number of codebook entries
vq = VectorQuantizer(k)
inputs = torch.randn(1, 28, 28, 64)
vq.init_codebook(inputs.shape[-1])  # codebook vectors share the inputs' channel width
quantized, indices = vq(inputs)


In [None]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Strided-conv feature extractor used by ``build_vqvae`` (no latent projection).

    Each entry of ``num_layers`` adds a 3x3, stride-2, padding-1 convolution
    followed by ReLU, halving H and W (rounding up), e.g. 28x28 -> 14x14 -> 7x7
    for two layers.  The output has ``num_layers[-1]`` channels.

    Args:
        num_layers: output channels of each strided convolution.
        in_channels: channels of the input image (1 = grayscale).
    """

    def __init__(self, num_layers=(16, 32), in_channels=1):
        super(Encoder, self).__init__()
        self.conv_layers = nn.ModuleList()
        prev_channels = in_channels
        for filters in num_layers:
            self.conv_layers.append(
                nn.Conv2d(prev_channels, filters, kernel_size=3, padding=1, stride=2)
            )
            prev_channels = filters

    def forward(self, x):
        """Map (batch, in_channels, H, W) to ReLU features (batch, num_layers[-1], H', W')."""
        for conv_layer in self.conv_layers:
            x = nn.functional.relu(conv_layer(x))
        return x

class Decoder(nn.Module):
    """Transposed-conv decoder mapping a latent map back to an image in (0, 1).

    Each entry of ``num_layers`` adds a 4x4, stride-2, padding-1 transposed
    convolution followed by ReLU, which exactly doubles H and W (7 -> 14 -> 28
    for two layers).  A final 3x3, stride-1, padding-1 transposed convolution
    keeps the size and maps to ``out_channels``; a sigmoid squashes the result.

    Args:
        num_layers: output channels of each upsampling layer.
        latent_dim: channels of the incoming latent; defaults to the
            notebook-global ``d``, read when the module is constructed.
        out_channels: channels of the reconstructed image (1 = grayscale).
    """

    def __init__(self, num_layers=(32, 16), latent_dim=None, out_channels=1):
        super(Decoder, self).__init__()
        if latent_dim is None:
            latent_dim = d  # backward-compatible fallback to the notebook global
        self.convT_layers = nn.ModuleList()
        prev_channels = latent_dim
        for filters in num_layers:
            self.convT_layers.append(
                nn.ConvTranspose2d(prev_channels, filters, kernel_size=4, padding=1, stride=2)
            )
            prev_channels = filters
        self.output = nn.ConvTranspose2d(num_layers[-1], out_channels, kernel_size=3, padding=1)

    def forward(self, y):
        """Decode ``y`` of shape (batch, latent_dim, h, w) to (batch, out_channels, h*2^L, w*2^L)."""
        for convT_layer in self.convT_layers:
            y = nn.functional.relu(convT_layer(y))
        return torch.sigmoid(self.output(y))

class VectorQuantizer(nn.Module):
    """Nearest-neighbour vector quantizer over a learnable (k, d) codebook.

    Args:
        k: number of codebook entries.
        d: dimension of each entry; the codebook starts as standard-normal noise.
    """

    def __init__(self, k, d):
        super(VectorQuantizer, self).__init__()
        self.k = k
        self.d = d
        self.codebook = nn.Parameter(torch.randn(k, d))

    def forward(self, inputs):
        """Quantize channels-last ``inputs`` of shape (batch_size, w, h, d).

        Returns:
            quantized: (batch_size, w, h, d) tensor of the nearest codebook rows.
            indices: (batch_size, w, h) long tensor of the chosen row numbers.
        """
        batch_size, w, h, d = inputs.size()
        # reshape, not view: channels-last inputs are often permuted (non-contiguous).
        inputs_flat = inputs.reshape(-1, d)
        # (N, k) Euclidean distances without materializing an (N, k, d) difference tensor.
        distances = torch.cdist(inputs_flat, self.codebook)
        indices = torch.argmin(distances, dim=-1).view(batch_size, w, h)
        # Row lookup by advanced indexing; torch.gather needs index and source of
        # equal rank, so it cannot index a 2-D codebook with a 4-D index.
        quantized = self.codebook[indices]
        return quantized, indices

    def sample(self, indices):
        """Return the codebook rows for ``indices``: shape ``indices.shape + (d,)``."""
        return self.codebook[indices]

# Building VQ-VAE
def build_vqvae(k, d, input_shape=(1, 28, 28), num_layers=(16, 32)):
    """Assemble a VQ-VAE from ``Encoder``, ``VectorQuantizer`` and ``Decoder``.

    Args:
        k: number of codebook entries.
        d: dimension of each codebook entry (channels of z_e and z_q).
        input_shape: (channels, height, width) of the inputs; only ``height`` is
            used, to compute the latent spatial size.
        num_layers: encoder filter counts; the decoder uses them reversed.

    Returns:
        ``(vq_vae, vq_vae_sampler, encoder, decoder, get_vq_vae_codebook)``:

        * ``vq_vae(x)`` returns ``(reconstructed, codes)``, where ``codes`` is
          ``torch.cat([z_e, z_q], dim=-1)`` in channels-last layout
          (batch, h, w, 2 * d) -- the layout ``latent_loss``, ``zq_norm`` and
          ``ze_norm`` split in half along the last axis.
        * ``vq_vae_sampler(indices)`` decodes a (batch, h, w) grid of codebook indices.
        * ``get_vq_vae_codebook()`` returns a detached NumPy copy of the codebook.

    Side effect: sets the global ``SIZE`` to the encoder's output height.

    Raises:
        ValueError: if the decoder's first layer does not take ``d`` channels
            (``Decoder`` may size its input from the notebook-global ``d``).
    """
    global SIZE

    encoder = Encoder(num_layers=num_layers)
    vector_quantizer = VectorQuantizer(k, d)
    decoder = Decoder(num_layers=num_layers[::-1])

    decoder_in = decoder.convT_layers[0].in_channels
    if decoder_in != d:
        raise ValueError(
            f"Decoder expects {decoder_in} input channels but the codebook has d={d}; "
            "make the global `d` match the d passed to build_vqvae"
        )

    # The encoder ends with num_layers[-1] channels; a 1x1 conv maps them to the
    # codebook dimension d so z_e can be compared with codebook vectors.
    pre_quant = nn.Conv2d(num_layers[-1], d, kernel_size=1)

    # Each encoder conv (kernel 3, stride 2, padding 1) maps n pixels to (n - 1) // 2 + 1.
    size = input_shape[1]
    for _ in num_layers:
        size = (size - 1) // 2 + 1
    SIZE = size

    def quantize(z_e):
        """Nearest-codebook lookup for channels-last ``z_e`` (..., d) -> (z_q, indices)."""
        codebook = vector_quantizer.codebook
        with torch.no_grad():  # argmin is not differentiable; skip building a graph
            indices = torch.cdist(z_e.reshape(-1, d), codebook).argmin(dim=-1)
        indices = indices.view(z_e.shape[:-1])
        # Indexing the Parameter keeps z_q differentiable w.r.t. the codebook.
        return codebook[indices], indices

    class VQVAE(nn.Module):
        """Training model: encode, quantize, straight-through, decode."""

        def __init__(self):
            super().__init__()
            # Registered as submodules so .parameters() and .to() include the codebook.
            self.encoder = encoder
            self.pre_quant = pre_quant
            self.vector_quantizer = vector_quantizer
            self.decoder = decoder

        def forward(self, x):
            # (B, C, H, W) -> channels-last (B, H, W, d) to line up with codebook rows.
            z_e = self.pre_quant(self.encoder(x)).permute(0, 2, 3, 1)
            z_q, _ = quantize(z_e)
            # Straight-through estimator: the forward value is z_q, while the
            # gradient reaches z_e unchanged.
            z_st = z_e + (z_q - z_e).detach()
            reconstructed = self.decoder(z_st.permute(0, 3, 1, 2))
            codes = torch.cat([z_e, z_q], dim=-1)
            return reconstructed, codes

    vq_vae = VQVAE()

    class VQVAESampler(nn.Module):
        """Inference model: decode a grid of codebook indices."""

        def __init__(self):
            super().__init__()
            self.vector_quantizer = vector_quantizer
            self.decoder = decoder

        def forward(self, indices):
            z_q = self.vector_quantizer.codebook[indices]   # (B, h, w, d)
            return self.decoder(z_q.permute(0, 3, 1, 2))    # decoder expects (B, d, h, w)

    vq_vae_sampler = VQVAESampler()

    def get_vq_vae_codebook():
        """Detached NumPy copy of the (k, d) codebook, e.g. for visualization."""
        return vector_quantizer.codebook.detach().cpu().numpy()

    return vq_vae, vq_vae_sampler, encoder, decoder, get_vq_vae_codebook

# Hyperparameters
NUM_LATENT_K = 10                 # Number of codebook entries
NUM_LATENT_D = 64                 # Dimension of each codebook entry; Decoder sizes its
                                  # input from the global `d`, so keep the two equal
BETA = 1.0                        # Weight for the commitment loss

INPUT_SHAPE = (1, 28, 28)         # (channels, height, width) fed to the encoder
SIZE = None                       # Spatial size of latent embedding
                                  # will be set dynamically in `build_vqvae`

VQVAE_BATCH_SIZE = 128            # Batch size for training the VQVAE
VQVAE_NUM_EPOCHS = 20             # Number of epochs
VQVAE_LEARNING_RATE = 3e-4        # Learning rate
VQVAE_LAYERS = [16, 32]           # Number of filters for each layer in the encoder

PIXELCNN_BATCH_SIZE = 128         # Batch size for training the PixelCNN prior
PIXELCNN_NUM_EPOCHS = 10          # Number of epochs
PIXELCNN_LEARNING_RATE = 3e-4     # Learning rate
PIXELCNN_NUM_BLOCKS = 12          # Number of Gated PixelCNN blocks in the architecture
PIXELCNN_NUM_FEATURE_MAPS = 32    # Width of each PixelCNN block

# Build the model from the constants above so changing them reconfigures it.
vq_vae, vq_vae_sampler, encoder, decoder, get_vq_vae_codebook = build_vqvae(
    NUM_LATENT_K, NUM_LATENT_D, INPUT_SHAPE, VQVAE_LAYERS)


In [None]:
import torch

def mse_loss(ground_truth, predictions):
    """Mean squared error, averaged over every element of the (broadcast) difference."""
    residual = ground_truth - predictions
    return residual.pow(2).mean()

def latent_loss(dummy_ground_truth, outputs, beta=None):
    """VQ-VAE latent loss: codebook term plus ``beta`` times the commitment term.

    Args:
        dummy_ground_truth: ignored; kept for a (y_true, y_pred) call signature.
        outputs: ``torch.cat([z_e, z_q], dim=-1)`` -- the last axis holds z_e in
            its first half and z_q in its second half.
        beta: commitment weight; defaults to the notebook-global ``BETA``.

    Returns:
        Scalar tensor ``mean((sg[z_e] - z_q)^2) + beta * mean((z_e - sg[z_q])^2)``
        where ``sg`` is stop-gradient (``.detach()``).

    Raises:
        ValueError: if the last axis of ``outputs`` has odd length.
    """
    del dummy_ground_truth
    if beta is None:
        beta = BETA  # read-only use of the global; no `global` statement needed
    if outputs.size(-1) % 2:
        raise ValueError(f"last dimension must be even (z_e and z_q halves), got {outputs.size(-1)}")
    z_e, z_q = torch.chunk(outputs, 2, dim=-1)
    # Codebook term: moves codebook vectors toward the (frozen) encoder output.
    vq_loss = torch.mean((z_e.detach() - z_q) ** 2)
    # Commitment term: keeps the encoder output close to its (frozen) codebook vector.
    commit_loss = torch.mean((z_e - z_q.detach()) ** 2)
    return vq_loss + beta * commit_loss


In [None]:
import torch

def zq_norm(y_true, y_pred):
    """Mean L2 norm (over the last axis) of the quantized half of ``y_pred``.

    ``y_pred`` is expected to hold z_e in the first half of its last axis and
    z_q in the second half; ``y_true`` is ignored.
    """
    half = y_pred.size(-1) // 2
    _, quantized = torch.split(y_pred, half, dim=-1)
    return quantized.norm(dim=-1).mean()

def ze_norm(y_true, y_pred):
    """Mean L2 norm (over the last axis) of the encoder half of ``y_pred``.

    ``y_pred`` is expected to hold z_e in the first half of its last axis and
    z_q in the second half; ``y_true`` is ignored.
    """
    half = y_pred.size(-1) // 2
    encoded, _ = torch.split(y_pred, half, dim=-1)
    return encoded.norm(dim=-1).mean()


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Mini-batches for the VQ-VAE come straight from `train_dataset`; `dataloader`
# above is itself a DataLoader and cannot serve as a Dataset.
train_loader = DataLoader(train_dataset, batch_size=VQVAE_BATCH_SIZE, shuffle=True)

# Position, within each SketchDataset item
# (label, sketch, image, img_name, rand_label, variance), of the tensor to reconstruct.
# NOTE(review): 1 = sketch is an assumption -- confirm the intended training target.
VQVAE_INPUT_INDEX = 1


def to_vqvae_input(batch_images):
    """Adapt a (B, C, H, W) batch from the dataset transform to the VQ-VAE's input.

    The encoder takes one channel and the decoder ends in a sigmoid, so channels
    are averaged to grayscale and values are rescaled to [0, 1].  Assumes the
    Normalize(mean=0.5, std=0.5) transform, i.e. inputs in [-1, 1] -- confirm.
    """
    if batch_images.size(1) != 1:
        batch_images = batch_images.mean(dim=1, keepdim=True)
    return ((batch_images + 1.0) / 2.0).to(device)


vq_vae.to(device)
optimizer = optim.Adam(vq_vae.parameters(), lr=VQVAE_LEARNING_RATE)

for epoch in range(VQVAE_NUM_EPOCHS):
    vq_vae.train()
    total_loss = 0.0
    total_mse_loss = 0.0
    total_latent_loss = 0.0
    total_zq_norm = 0.0
    total_ze_norm = 0.0

    for batch in train_loader:
        inputs = to_vqvae_input(batch[VQVAE_INPUT_INDEX])

        optimizer.zero_grad()
        reconstructed, codes = vq_vae(inputs)

        mse_loss_val = mse_loss(inputs, reconstructed)
        latent_loss_val = latent_loss(None, codes)
        loss = mse_loss_val + latent_loss_val

        loss.backward()
        optimizer.step()

        # Diagnostics only; no gradients needed.
        with torch.no_grad():
            zq_norm_val = zq_norm(None, codes)
            ze_norm_val = ze_norm(None, codes)

        total_loss += loss.item()
        total_mse_loss += mse_loss_val.item()
        total_latent_loss += latent_loss_val.item()
        total_zq_norm += zq_norm_val.item()
        total_ze_norm += ze_norm_val.item()

    # max(..., 1) keeps an empty loader from dividing by zero.
    num_batches = max(len(train_loader), 1)
    avg_loss = total_loss / num_batches
    avg_mse_loss = total_mse_loss / num_batches
    avg_latent_loss = total_latent_loss / num_batches
    avg_zq_norm = total_zq_norm / num_batches
    avg_ze_norm = total_ze_norm / num_batches

    print(f"Epoch {epoch + 1}/{VQVAE_NUM_EPOCHS}:")
    print(f"  Loss: {avg_loss:.4f}, MSE Loss: {avg_mse_loss:.4f}, Latent Loss: {avg_latent_loss:.4f}")
    print(f"  z_q Norm: {avg_zq_norm:.4f}, z_e Norm: {avg_ze_norm:.4f}")
