In [1]:
import os

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

from preprocess import extract_images
from dataset import get_full_list, ChineseCharacterDataset
from models import CVAE
from utils import visualize_images, show_images, plot_generated_images, pad_to_target_size
from loss import vae_loss

  warn(


In [2]:
# Enable IPython's autoreload extension in mode 2 so edits to the imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Loading and Preprocessing:

In [3]:
# Set the image directory and build the full sample list from it.
# NOTE(review): what get_full_list returns per entry is not visible in this
# notebook; later cells only take the list's length and slice it.
image_dir = './chinese_chars/pngs'
full_data_list = get_full_list(image_dir)

In [4]:
# Train on every entry: the training-set size is just the length of the full list.
train_size = len(full_data_list)
print("Training Set Size:", train_size)

Training Set Size: 9574


In [5]:
# Dataset over the first `train_size` entries, using the 'Row' condition type
# with rows [20, 43].
train_data = ChineseCharacterDataset(
    full_data_list[:train_size],
    cond_type='Row',
    rows=[20, 43],
)

# Single loader used for both training and plotting: batches of 32, fixed order.
# NOTE(review): shuffle=False means every epoch sees the same batch order —
# confirm this is intended for training.
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=False)

### Model:

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Conv_CVAE(nn.Module):
    """Convolutional conditional VAE for single-channel 64x64 images.

    The encoder sees the image and the condition stacked as two channels and
    produces the mean and log-variance of the latent code. The decoder first
    compresses the condition image into a ``condition_dim`` vector, joins it
    with the latent sample and upsamples back to a 1x64x64 image in [0, 1].

    Args:
        z_dim: size of the latent vector.
        condition_dim: size of the encoded condition vector fed to the decoder.
    """

    def __init__(self, z_dim, condition_dim):
        super().__init__()
        self.z_dim = z_dim
        self.condition_dim = condition_dim

        # Three stride-2 convolutions turn a 64x64 input into a 128x8x8 map,
        # which is what the linear layers below are sized for.
        bottleneck_shape = (128, 8, 8)
        flat_features = 128 * 8 * 8

        def conv_stack(in_channels):
            """Return the Conv2d/ReLU downsampling layers used by both encoders."""
            widths = (in_channels, 32, 64, 128)
            layers = []
            for src, dst in zip(widths[:-1], widths[1:]):
                layers.append(nn.Conv2d(src, dst, kernel_size=4, stride=2, padding=1))
                layers.append(nn.ReLU())
            return layers

        # Image encoder: two input channels (image + condition).
        self.encoder = nn.Sequential(*conv_stack(2), nn.Flatten())
        self.fc1 = nn.Linear(flat_features, z_dim)  # latent mean head
        self.fc2 = nn.Linear(flat_features, z_dim)  # latent log-variance head

        # Condition encoder: one input channel, projected to `condition_dim`.
        self.condition_encoder = nn.Sequential(
            *conv_stack(1),
            nn.Flatten(),
            nn.Linear(flat_features, condition_dim),
        )

        # Decoder: linear lift to the bottleneck, then three stride-2 upsamplings.
        self.fc3 = nn.Linear(z_dim + condition_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.Unflatten(1, bottleneck_shape),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x, condition):
        """Return ``(mu, logvar)`` for image ``x`` stacked with ``condition`` on the channel axis."""
        stacked = torch.cat((x, condition), dim=1)
        features = self.encoder(stacked)
        mu = self.fc1(features)
        logvar = self.fc2(features)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``std = exp(logvar / 2)`` and ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, condition):
        """Decode latent ``z`` into an image, conditioned on the condition image."""
        condition_code = self.condition_encoder(condition)
        joint = torch.cat((z, condition_code), dim=1)
        hidden = F.relu(self.fc3(joint))
        return self.decoder(hidden)

    def forward(self, x, condition):
        """Return ``(reconstruction, mu, logvar)`` for one batch."""
        mu, logvar = self.encode(x, condition)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent, condition)
        return reconstruction, mu, logvar

### Training:

In [10]:
# Prefer CUDA when torch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cuda


In [11]:
# Train
def train(model, train_data_set, input_dim, optimizer, scheduler, num_epochs):
    """Run the CVAE training loop for ``num_epochs`` epochs.

    Relies on the notebook-level names ``device``, ``pad_to_target_size`` and
    ``vae_loss``.

    Args:
        model: module called as ``model(data, padded_condition)`` and returning
            ``(reconstruction, mu, logvar)``.
        train_data_set: iterable of ``(data, condition)`` batches that also
            exposes ``.dataset``; ``len(.dataset)`` normalises the reported loss.
        input_dim: unused; kept so existing calls keep working.
        optimizer: optimizer stepped after every batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        num_epochs: number of passes over ``train_data_set``.

    Returns:
        None. Prints the normalised loss after every epoch and once more at the
        end. Nothing is printed when ``num_epochs`` is 0 (previously this case
        raised ``UnboundLocalError`` on the final print).
    """
    last_epoch_loss = None  # stays None when no epoch runs
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for data, condition in train_data_set:
            data = data.to(device)
            condition = condition.to(device)
            # The model consumes conditions padded to 64x64.
            padded_condition = pad_to_target_size(condition, (64, 64)).to(device)
            optimizer.zero_grad()
            reconstructed_batch, mu, logvar = model(data, padded_condition)
            loss = vae_loss(reconstructed_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        scheduler.step()
        # NOTE(review): dividing by the dataset length gives a per-sample value
        # only if vae_loss sums over the batch — confirm against loss.py.
        last_epoch_loss = train_loss / len(train_data_set.dataset)
        print(f'Epoch {epoch+1}, Loss: {last_epoch_loss}')
    if last_epoch_loss is not None:
        print(f'Final Loss: {last_epoch_loss}')

In [None]:
# Hyperparameters for this run.
input_dim = 64 * 64       # flattened image size; forwarded to train()
z_dim = 50                # latent vector size
condition_dim = 50        # encoded-condition vector size
learning_rate = 1e-4
num_epochs = 25

# Model, optimizer and a step schedule that halves the learning rate every
# 10 scheduler steps.
model = Conv_CVAE(z_dim=z_dim, condition_dim=condition_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

train(
    model=model,
    train_data_set=train_loader,
    input_dim=input_dim,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
)

In [None]:
# Persist the learned weights (state_dict only). Create the target folder
# first: torch.save does not create missing parent directories, and failing
# here would throw away the finished training run.
os.makedirs('./saves', exist_ok=True)
torch.save(model.state_dict(), './saves/unet_cvae_r20_43.pth')

### Generating Images:

In [None]:
import matplotlib.pyplot as plt
import torch

def plot_generated_images(model, data_loader, num_images, batch_size, device='cpu'):
    """Plot reference images next to images decoded from random latents.

    For each batch, a latent vector is sampled per item from a standard normal,
    decoded together with that item's condition, and shown beside the matching
    reference image. Relies on the notebook-level ``plt`` and
    ``pad_to_target_size``.

    Args:
        model: object exposing ``z_dim`` and ``decode(z, condition)``.
        data_loader: iterable of ``(data, condition)`` batches.
        num_images: total number of image pairs to show.
        batch_size: upper bound on the pairs shown per batch. The latent batch
            itself is sized from the actual batch, so a short final batch works.
        device: device used for the tensors passed to the model.

    Returns:
        None. Figures are shown with ``plt.show()`` and then closed.
    """
    model.eval()
    remaining = num_images
    with torch.no_grad():
        for data, condition in data_loader:
            # Stop before doing any device transfer once enough pairs were shown.
            if remaining <= 0:
                break

            data = data.to(device)
            # Pad the condition to 64x64 exactly as the training loop does
            # before it reaches the model.
            condition = pad_to_target_size(condition.to(device), (64, 64)).to(device)

            # Size the latent batch from the real batch, not from `batch_size`.
            current = condition.size(0)
            z = torch.randn(current, model.z_dim).to(device)
            sample = model.decode(z, condition).cpu().view(current, 1, 64, 64)

            shown = min(remaining, current, batch_size)
            remaining -= shown

            for i in range(shown):
                ref = data[i].cpu().detach().numpy().reshape(64, 64)
                img = sample[i].cpu().detach().numpy().reshape(64, 64)

                fig, (ref_ax, gen_ax) = plt.subplots(1, 2, figsize=(8, 4))

                # Reference image on the left.
                ref_ax.set_title('Reference Image')
                ref_ax.imshow(ref, cmap='gray')
                ref_ax.axis('off')

                # Generated image on the right.
                gen_ax.set_title('Generated Image')
                gen_ax.imshow(img, cmap='gray')
                gen_ax.axis('off')

                plt.show()
                # Release the figure so long runs do not accumulate open figures.
                plt.close(fig)

# Example usage: `model` is the Conv_CVAE instance created above and
# `train_loader` is the DataLoader that yields (image, condition) batches.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
plot_generated_images(
    model=model,
    data_loader=train_loader,
    num_images=5,
    batch_size=32,
    device=device,
)