In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt


# custom dataset class
from dataset_loader import CustomDataset
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ConvVAE(nn.Module):
    """Convolutional VAE whose decoder is fed encoder skip connections.

    Five encoder stages shrink the spatial size (150 -> 150 -> 76 -> 39 -> 20
    -> 11 for a 150x150 input) while widening the channels up to 512. Two
    fully connected heads map the flattened 512x11x11 encoder output through a
    ``latent_size``-wide bottleneck and back to 512x11x11, producing the mean
    and log-variance tensors. A sample drawn from them replaces the deepest
    feature map, and the decoder upsamples it back to 150x150 while
    concatenating the matching encoder feature map at every stage.

    The fully connected heads and the fixed ``nn.Upsample`` sizes tie the
    network to 150x150 inputs.

    NOTE(review): both heads end in a ReLU, so ``mean`` and ``logvar`` are
    always >= 0 -- confirm this is intended.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the reconstructed image.
        latent_size: width of the bottleneck inside the mean/logvar heads.
    """

    def __init__(self, in_channels, out_channels, latent_size):
        super().__init__()

        # Layer factories. Attribute names, layer order inside every
        # nn.Sequential and construction order are kept stable so state_dict
        # keys (and seeded initialisation) do not change.
        def conv_bn_relu(c_in, c_out, pad):
            # 3x3 stride-1 convolution -> batch norm -> in-place ReLU
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=pad),
                nn.BatchNorm2d(num_features=c_out),
                nn.ReLU(inplace=True),
            ]

        def down_stage(c_in, c_out):
            # padding=2 grows each side by 2 px before the 2x2 max-pool halves it
            return conv_bn_relu(c_in, c_out, pad=2) + [nn.MaxPool2d(kernel_size=2, stride=2)]

        def up_stage(side, c_in, c_out):
            # nearest-neighbour resize to side x side, then a size-preserving conv
            return nn.Sequential(
                nn.Upsample(size=(side, side), mode='nearest'),
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(num_features=c_out),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
            )

        def latent_head():
            # flatten(512x11x11) -> latent_size -> 512x11x11
            flat_features = 512 * 11 * 11
            return nn.Sequential(
                nn.Linear(flat_features, latent_size),
                nn.BatchNorm1d(latent_size),
                nn.ReLU(inplace=True),
                nn.Linear(latent_size, flat_features),
                nn.BatchNorm1d(flat_features),
                nn.ReLU(inplace=True),
                nn.Unflatten(1, (512, 11, 11)),  # back to the encoder output shape
            )

        # ---- Encoder ----
        self.encoder1 = nn.Sequential(*conv_bn_relu(in_channels, 32, pad=1))  # 150 -> 150
        self.encoder2 = nn.Sequential(*down_stage(32, 64))                    # 150 -> 152 -> 76
        self.encoder3 = nn.Sequential(*down_stage(64, 128))                   # 76 -> 78 -> 39
        self.encoder4 = nn.Sequential(*down_stage(128, 256))                  # 39 -> 41 -> 20
        self.encoder5 = nn.Sequential(                                        # 20 -> 22 -> 11
            *down_stage(256, 512),
            *conv_bn_relu(512, 512, pad=1),  # 11 -> 11
            *conv_bn_relu(512, 512, pad=1),  # 11 -> 11
        )

        # ---- Decoder ----
        # From decoder2 on, the input channel count is doubled by the skip
        # connection concatenated in decode().
        self.decoder1 = up_stage(20, 512, 256)   # 11 -> 20
        self.decoder2 = up_stage(39, 512, 128)   # 20 -> 39
        self.decoder3 = up_stage(76, 256, 64)    # 39 -> 76
        self.decoder4 = up_stage(150, 128, 32)   # 76 -> 150

        refine_layers = []
        for _ in range(3):
            # a fresh set of modules per repetition (no weight sharing)
            refine_layers.extend(conv_bn_relu(64, 64, pad=1))  # 150 -> 150
        refine_layers.append(
            nn.Conv2d(64, out_channels, kernel_size=3, stride=1, padding=1)  # 150 -> 150
        )
        self.decoder5 = nn.Sequential(*refine_layers)

        # ---- Fully connected heads for mean and log-variance ----
        self.mean = latent_head()
        self.logvar = latent_head()

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mean + noise * sigma

    def encode(self, x):
        """Run the five encoder stages and return every stage's output, shallowest first."""
        stage_outputs = []
        features = x
        for stage in (self.encoder1, self.encoder2, self.encoder3,
                      self.encoder4, self.encoder5):
            features = stage(features)
            stage_outputs.append(features)
        return tuple(stage_outputs)

    def decode(self, x1, x2, x3, x4, x5):
        """Decode ``x5``, concatenating the skips ``x4`` ... ``x1`` channel-wise on the way up."""
        z = self.decoder1(x5)
        skip_path = (
            (self.decoder2, x4),
            (self.decoder3, x3),
            (self.decoder4, x2),
            (self.decoder5, x1),
        )
        for stage, skip in skip_path:
            z = stage(torch.cat((z, skip), dim=1))
        return z

    def forward(self, x):
        """Return ``(reconstruction, mean, logvar)`` for the batch ``x``."""
        x1, x2, x3, x4, x5 = self.encode(x)

        # Both heads read the same flattened bottleneck features.
        flat = x5.reshape(x5.shape[0], -1)
        mean = self.mean(flat)
        logvar = self.logvar(flat)

        # The sampled tensor takes the place of the deepest encoder feature map.
        latent = self.reparameterize(mean, logvar)

        x_recon = self.decode(x1, x2, x3, x4, latent)
        return x_recon, mean, logvar

# Define the loss function for Convolutional VAE
def conv_vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO-style loss: summed squared error plus KL divergence.

    Both terms are summed over the whole batch so they live on the same
    scale. Averaging the reconstruction error per element while summing the
    KL term over every batch/latent element makes the KL term dominate by
    several orders of magnitude. A summed loss is also what the training
    log's ``loss / len(inputs)`` (per-sample loss) presumes.

    Args:
        recon_x: reconstructed batch produced by the model.
        x: target batch, same shape as ``recon_x``.
        mu: mean tensor returned by the model.
        logvar: log-variance tensor returned by the model (same shape as ``mu``).

    Returns:
        Scalar tensor ``sum((recon_x - x)^2) + KL(N(mu, exp(logvar)) || N(0, I))``.
    """
    #BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # KL divergence between N(mu, sigma^2) and N(0, I), summed over all elements:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kld

# Set random seed for reproducibility.
# This seeds torch's global random number generator only.
torch.manual_seed(42)


ModuleNotFoundError: No module named 'dataset_loader'

In [None]:

def train_vae(model, train_loader, optimizer, num_epochs=5):
    """Train ``model`` with ``conv_vae_loss`` and plot the per-epoch average loss.

    The learning rate is multiplied by 0.95 after every epoch.

    Args:
        model: network whose forward pass returns ``(reconstruction, mu, logvar)``.
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer built over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        List with the average batch loss of every epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    # Move batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` global.
    model_device = next(model.parameters()).device

    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

    # Average loss of every epoch, kept for plotting
    all_losses = []

    for epoch in tqdm(range(num_epochs)):
        model.train()
        epoch_losses = []

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(model_device), targets.to(model_device)

            optimizer.zero_grad()
            # Use the model that was passed in, not a notebook global.
            recon_batch, mu, logvar = model(inputs)
            loss = conv_vae_loss(recon_batch, targets, mu, logvar)
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())

            # Progress line every 100 batches
            if batch_idx % 100 == 99:
                print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(inputs), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(inputs)))

        if not epoch_losses:
            raise ValueError("train_loader produced no batches; nothing to train on")

        # Save average loss for the epoch
        all_losses.append(sum(epoch_losses) / len(epoch_losses))

        # decrease the learning rate
        scheduler.step()

    # Plot the loss over epochs
    plt.plot(range(1, len(all_losses) + 1), all_losses, marker='o')
    plt.xlabel('Epochs')
    plt.ylabel('Average Loss')
    plt.title('Training Loss Over Epochs')
    plt.show()

    return all_losses



In [None]:

# Set your dataset path and other parameters.
# Forward slashes work on every platform; the previous "\g" / "\c" were
# invalid escape sequences inside a normal string literal.
input_folder = "landscape_images/gray"
output_folder = "landscape_images/color"
batch_size = 4
epochs = 200
learning_rate = 1e-3
latent_size = 20


# Transform handed to the dataset: converts images to tensors, nothing else
# (no grayscale conversion, resizing or normalisation happens here).
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load the paired dataset (input folder -> output folder).
# NOTE(review): how CustomDataset applies `transform` to inputs vs. targets is
# defined in dataset_loader, which is not visible here.
dataset = CustomDataset(input_folder, output_folder, transform=transform)

# Split it into train and test data.
# A fixed-size training subset is used (alternative: int(0.001 * len(dataset)));
# it is capped so random_split cannot fail on datasets with fewer than 32 samples.
train_size = min(32, len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Training loader: shuffled, incomplete final batch dropped
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Instantiate the Convolutional VAE model and optimizer
conv_vae = ConvVAE(in_channels=1, out_channels=3,latent_size=latent_size).to(device)
optimizer = optim.Adam(conv_vae.parameters(), lr=learning_rate)


In [None]:
%time

train=0

if train:
    train_vae(conv_vae, train_loader, optimizer, epochs)



In [None]:
# save the model weights (set save_params to a truthy value to enable)
save_params= 0
if save_params:
    import os  # only needed here; keeps the cell self-contained

    params_dir = "generative_models/004_params"
    # torch.save does not create missing parent directories
    os.makedirs(params_dir, exist_ok=True)

    # Weights only (what load_state_dict expects) ...
    torch.save(conv_vae.state_dict(), f"{params_dir}/004_vae_weights_g2c.pth")
    # ... and the whole pickled module (needs the ConvVAE class available when loading)
    torch.save(conv_vae, f"{params_dir}/004_vae_g2c.pth")

In [None]:
# Load the trained model
# NOTE(review): weights are read from 005_params while the save cell writes to
# 004_params -- confirm this is the intended checkpoint.
conv_vae = ConvVAE(in_channels=1, out_channels=3,latent_size=latent_size).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (and vice versa). torch.load unpickles the file, so only load trusted checkpoints.
conv_vae.load_state_dict(
    torch.load("generative_models/005_params/005_vae_weights_g2c.pth", map_location=device)
)
conv_vae.eval()  # Set model to evaluation mode

# Create a test data loader
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Reconstruction quality is measured with the (element-wise mean) squared error
criterion = nn.MSELoss()

# Evaluate on the test set (set test to a truthy value to enable)
test = 0
if test:
    test_loss = 0.0
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs,_,_ = conv_vae(inputs.to(device))
            # compare on the CPU, where the targets from the loader live
            outputs = outputs.detach().cpu()
            loss = criterion(outputs, targets)
            # weight each batch mean by its size so the final value is a per-sample average
            test_loss += loss.item() * targets.size(0)

    test_loss /= len(test_loader.dataset)
    print("Test Loss:", test_loss)


In [None]:

# Generate a sample from the Convolutional VAE, this part can be used for both image generation and noise testing
with torch.no_grad():
    # Gaussian noise batch; currently only inspected by the shape print further down.
    # NOTE(review): it has 3 channels while the model is built with in_channels=1,
    # so the commented-out forward call below would need a 1-channel tensor.
    sample = torch.randn([32, 3, 150, 150]).to(device)
    #sample = conv_vae.forward(sample)

    # Take one batch of inputs/targets (from the *training* loader) and colourise it
    input_images, target_images = next(iter(train_loader)) # .to(device) #+sample*0.5
    input_images, target_images = input_images.to(device), target_images.to(device)
    output_images,_,_ = conv_vae(input_images)


def _to_unit_range(image):
    """Min-max scale ``image`` to [0, 1]; a constant image maps to zeros instead of NaN."""
    lowest, highest = image.min(), image.max()
    return (image - lowest) / (highest - lowest).clamp_min(1e-8)


# One row per image in the batch: target on the left, model output on the right.
# The row count comes from the batch itself, and squeeze=False keeps `axes`
# two-dimensional even when there is a single row.
num_images = output_images.shape[0]
fig, axes = plt.subplots(num_images, 2, figsize=(20,10), squeeze=False)

for i in range(num_images):

    # Bring both images to the CPU and normalize them to [0, 1] for display
    output_image = _to_unit_range(output_images[i,:,:,:].squeeze().detach().cpu())
    target_image = _to_unit_range(target_images[i,:,:,:].squeeze().detach().cpu())

    # imshow expects channels last: (C, H, W) -> (H, W, C)
    axes[i,0].imshow(target_image.permute(1,2,0))
    axes[i,0].set_title("Target Image")
    axes[i,1].imshow(output_image.permute(1,2,0))
    axes[i,1].set_title("Output Image")

plt.show()

In [None]:
# `sample` is already a tensor (possibly on the GPU); wrapping it in the legacy
# torch.Tensor(...) constructor is unnecessary.
print(sample.shape)

In [None]:
# Report the test loss from the evaluation cell. That cell already averages it
# over the test set, so it must not be divided again (and never by the size of
# the training set). `test_loss` only exists when evaluation was enabled.
if test:
    print("Test Loss:", test_loss)
