In [1]:
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

from preprocess import extract_images
from dataset import get_full_list, ChineseCharacterDataset
from models import CVAE
from utils import visualize_images, show_images
from loss import vae_loss

In [2]:
# Enable IPython's autoreload extension in mode 2 so that edits to the local
# helper modules imported by this notebook (preprocess, dataset, models, utils,
# loss) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Loading and Preprocessing:

In [3]:
# Set the directory and load the dataset
# NOTE(review): get_full_list is imported from the local dataset module;
# presumably it returns one entry per PNG found under image_dir — confirm
# against dataset.py.
image_dir = './chinese_chars/pngs'
full_data_list = get_full_list(image_dir)

In [4]:
# The factor 1 acts as the train fraction: train_size equals the length of the
# full list, so the slice taken from it below keeps every sample.
train_size = int(1*len(full_data_list))
print("Training Set Size:", train_size)

Training Set Size: 9574


In [5]:
# Wrap the first train_size entries in the project dataset. The training and
# plotting loops unpack every batch as a (data, condition) pair.
# NOTE(review): cond_type='Half' presumably makes the condition one half of the
# character image — confirm in dataset.py.
train_data_cols = ChineseCharacterDataset(full_data_list[:train_size], cond_type='Half')

# Create data loaders
# Batches of 32 samples, with shuffling enabled.
train_loader_cols = DataLoader(train_data_cols, batch_size=32, shuffle=True)

### Training:

In [6]:
# Use GPU if available
# The model is moved to this device when it is built, and the same value is
# passed to plot_generated_images() at the end of the notebook.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [7]:
# Train
def train(model, train_data_set, input_dim, optimizer, scheduler, num_epochs,
          device=None, log_every=1):
    """Train ``model`` with ``vae_loss`` for ``num_epochs`` passes over the loader.

    Args:
        model: Module called as ``model(data, condition)`` that returns
            ``(reconstruction, mu, logvar)``.
        train_data_set: DataLoader-like iterable of ``(data, condition)``
            batches. Must support ``len()`` and expose ``.dataset``.
        input_dim: Unused by the loop; kept so existing calls keep working.
        optimizer: Optimizer that updates ``model``'s parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run; values <= 0 train nothing.
        device: Device each batch is moved to. When ``None`` the device of the
            model's first parameter is used (CPU if the model has none), so the
            function no longer depends on a notebook-global ``device``.
        log_every: Print a progress line every ``log_every`` batches. The
            default of 1 prints every batch; 0 or ``None`` silences the
            per-batch lines.

    Returns:
        None. Progress and losses are printed.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Stays None when no epoch runs, so the final summary can be skipped
    # instead of reading an undefined accumulator.
    epoch_loss = None
    for epoch in range(num_epochs):
        model.train()
        num_batches = len(train_data_set)
        train_loss = 0.0
        for batch_idx, (data, condition) in enumerate(train_data_set):
            if log_every and batch_idx % log_every == 0:
                print(f'Batch {batch_idx}/{num_batches} for epoch {epoch+1}')
            data = data.to(device)
            condition = condition.to(device)
            optimizer.zero_grad()
            reconstructed_batch, mu, logvar = model(data, condition)
            # The reconstruction is reshaped to (batch, 1, 64, 64) before it is
            # compared with the input batch.
            loss = vae_loss(reconstructed_batch.view(data.shape[0], 1, 64, 64), data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        scheduler.step()
        # Summed batch losses divided by the sample count.
        # NOTE(review): this is a per-sample average only if vae_loss sums over
        # the batch — confirm in loss.py. max(..., 1) avoids dividing by zero
        # for an empty dataset.
        epoch_loss = train_loss / max(len(train_data_set.dataset), 1)
        print(f'Epoch {epoch+1}, Loss: {epoch_loss}')
    if epoch_loss is not None:
        print(f'Final Loss: {epoch_loss}')

In [None]:
# Hyperparameters handed to the CVAE constructor, the optimizer and train().
# input_dim is 64 * 64 = 4096; the training loop reshapes reconstructions to
# (batch, 1, 64, 64), so this is presumably the flattened image size.
# hidden_dim evaluates to 64000.
input_dim = 64 * 64
z_dim = 50
condition_dim = 50
hidden_dim = 6400*10
learning_rate = 1e-3

num_epochs = 40

# Build the model on the selected device, optimise it with Adam, and multiply
# the learning rate by gamma=0.5 every 10 scheduler steps (train() steps the
# scheduler once per epoch).
model_cols = CVAE(input_dim, hidden_dim, z_dim, condition_dim).to(device)
optimizer = optim.Adam(model_cols.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

print("training now")
train(model_cols, train_loader_cols, input_dim, optimizer, scheduler, num_epochs) 

training now
Batch 0/300 for epoch 1
Batch 1/300 for epoch 1
Batch 2/300 for epoch 1
Batch 3/300 for epoch 1
Batch 4/300 for epoch 1
Batch 5/300 for epoch 1
Batch 6/300 for epoch 1
Batch 7/300 for epoch 1
Batch 8/300 for epoch 1
Batch 9/300 for epoch 1
Batch 10/300 for epoch 1
Batch 11/300 for epoch 1
Batch 12/300 for epoch 1
Batch 13/300 for epoch 1
Batch 14/300 for epoch 1
Batch 15/300 for epoch 1
Batch 16/300 for epoch 1
Batch 17/300 for epoch 1
Batch 18/300 for epoch 1
Batch 19/300 for epoch 1
Batch 20/300 for epoch 1
Batch 21/300 for epoch 1
Batch 22/300 for epoch 1


### Generating Images:

In [None]:
import matplotlib.pyplot as plt
import torch

def _plot_triplet(cond, ref, img):
    """Draw one figure with three grayscale panels: condition, reference, generated."""
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    panels = (
        ('Condition Image', cond),
        ('Reference Image', ref),
        ('Generated Image', img),
    )
    for ax, (title, image) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(image, cmap='gray')
        ax.axis('off')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_generated_images(model, data_loader, num_images, batch_size, device='cpu'):
    """Plot condition / reference / generated triplets for up to ``num_images`` samples.

    For every batch a random latent is drawn, decoded together with the encoded
    condition, and each generated image is shown next to its condition and its
    reference image.

    Args:
        model: Object exposing ``eval()``, ``z_dim``,
            ``forward_condition_encoder(condition)`` and
            ``decode(z, cond_encoded)``. It is left in eval mode.
        data_loader: Iterable of ``(data, condition)`` batches.
        num_images: Maximum number of triplets to show; values <= 0 show none.
        batch_size: Kept for backward compatibility. The size of each batch is
            read from the batch itself, so a smaller final batch no longer
            breaks sampling or indexing.
        device: Device the batches and the latent are moved to.

    Returns:
        None. Figures are displayed through matplotlib.
    """
    model.eval()
    remaining = num_images
    with torch.no_grad():
        for data, condition in data_loader:
            # Stop before doing any work for a batch that will not be shown.
            if remaining <= 0:
                break

            condition = condition.to(device)
            data = data.to(device)

            # Use the real batch size: the loader's last batch can be smaller
            # than the nominal batch_size.
            current_batch = condition.shape[0]
            z = torch.randn(current_batch, model.z_dim).to(device)
            cond_encoded = model.forward_condition_encoder(condition)
            sample = model.decode(z, cond_encoded).cpu()
            sample = sample.view(current_batch, 1, 64, 64)

            to_show = min(remaining, current_batch)
            remaining -= to_show

            for i in range(to_show):
                ref = data[i].cpu().detach().numpy().reshape(64, 64)
                img = sample[i].cpu().detach().numpy().reshape(64, 64)

                # The condition image may have any 2-D shape; when it carries a
                # leading channel axis, only the first channel is shown.
                cond = condition[i].cpu().detach().numpy()
                if len(cond.shape) > 2:
                    cond = cond[0]

                _plot_triplet(cond, ref, img)


# Show 5 condition / reference / generated triplets using batches drawn from
# the training loader; batch_size is set to the loader's batch size (32).
plot_generated_images(model_cols, train_loader_cols, num_images=5, batch_size=32, device=device)