In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from preprocess import extract_images
from dataset import get_full_list, ChineseCharacterDataset
from models import CVAE
from utils import visualize_images, show_images, plot_generated_images, pad_to_target_size
from loss import vae_loss

  warn(


In [2]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to the project's .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Loading and Preprocessing:

In [3]:
# Location of the image files, given relative to the notebook's working directory
# (presumably one PNG per character — confirm against the data folder).
image_dir = './chinese_chars/pngs'
# Collect the dataset entries for that directory.
# NOTE(review): the element type returned by get_full_list is defined in dataset.py — confirm there.
full_data_list = get_full_list(image_dir)

In [4]:
# Every listed entry is used for training; len() already returns an int.
train_size = len(full_data_list)
print("Training Set Size:", train_size)

Training Set Size: 9574


In [5]:
# Dataset over the first train_size entries, built with the 'Row' condition type
# and rows=[20, 43] (the meaning of both options is defined in dataset.py).
train_data = ChineseCharacterDataset(
    full_data_list[:train_size],
    cond_type='Row',
    rows=[20, 43],
)

# Serve the dataset in batches of 32, in fixed (unshuffled) order.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=False,
)

### Training:

In [6]:
# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

Using device: cuda


In [7]:
# Train
def train(model, train_data_set, input_dim, optimizer, scheduler, num_epochs):
    """Train ``model`` for ``num_epochs`` passes over ``train_data_set``.

    Args:
        model: Module called as ``model(data, padded_condition)`` that returns
            ``(reconstructed_batch, mu, logvar)``.
        train_data_set: Iterable of ``(data, condition)`` batches exposing a
            ``dataset`` attribute (e.g. a ``DataLoader``); ``len(dataset)`` is
            the divisor used for the reported loss.
        input_dim: Not used by the loop; kept so existing calls keep working.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Number of passes over ``train_data_set``.

    Returns:
        None. After each epoch the summed batch losses divided by the dataset
        size are printed, and the last epoch's value is printed again as the
        final loss. With ``num_epochs == 0`` nothing is trained or printed.

    Tensors are moved to the notebook-level ``device``.
    """
    num_samples = len(train_data_set.dataset)
    # Stays None when no epoch runs, so the final report never reads an
    # accumulator that was only ever bound inside the loop.
    epoch_loss = None
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for data, condition in train_data_set:
            data = data.to(device)
            condition = condition.to(device)
            # The condition is padded to 64x64 before being fed to the model.
            padded_condition = pad_to_target_size(condition, (64, 64)).to(device)
            optimizer.zero_grad()
            reconstructed_batch, mu, logvar = model(data, padded_condition)
            # Reshape the model output to (batch, 1, 64, 64) for the loss.
            loss = vae_loss(reconstructed_batch.view(data.shape[0],1,64,64), data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        scheduler.step()
        epoch_loss = train_loss / num_samples
        print(f'Epoch {epoch+1}, Loss: {epoch_loss}')
    if epoch_loss is not None:
        print(f'Final Loss: {epoch_loss}')

In [None]:
# Sizes handed to the CVAE constructor below.
input_dim = 64 * 64
z_dim = 50
condition_dim = 50
hidden_dim = 8 * 6400

# Optimisation settings.
learning_rate = 1e-3
num_epochs = 25

model = CVAE(input_dim, hidden_dim, z_dim, condition_dim).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# train() steps the scheduler once per epoch, so the learning rate is
# multiplied by 0.5 every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)

train(
    model=model,
    train_data_set=train_loader,
    input_dim=input_dim,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
)

Epoch 1, Loss: 1734.7816492339277
Epoch 2, Loss: 1408.894734182013
Epoch 3, Loss: 1394.9308922908888
Epoch 4, Loss: 1378.8421255238798
Epoch 5, Loss: 1367.3558959833435
Epoch 6, Loss: 1351.2693015425893
Epoch 7, Loss: 1339.6957619696736
Epoch 8, Loss: 1328.135917426102
Epoch 9, Loss: 1317.3303180835628
Epoch 10, Loss: 1307.6427123079923
Epoch 11, Loss: 1291.252030646151
Epoch 12, Loss: 1271.71491581407
Epoch 13, Loss: 1255.7590383542015
Epoch 14, Loss: 1235.8538233125685
Epoch 15, Loss: 1211.0373370423804
Epoch 16, Loss: 1181.8956076719828
Epoch 17, Loss: 1154.9805920188303
Epoch 18, Loss: 1130.7025966831557
Epoch 19, Loss: 1107.3552968264473
Epoch 20, Loss: 1084.8433331198103
Epoch 21, Loss: 1062.1705125147698
Epoch 22, Loss: 1042.746285767851
Epoch 23, Loss: 1026.9685948046956
Epoch 24, Loss: 1013.7911331817455


In [None]:
# torch.save does not create missing parent directories, so make sure the
# target folder exists first; otherwise a finished training run would end in
# an error with the weights still unsaved.
checkpoint_path = './saves/cvae_r20_43_extra.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

### Generating Images:

In [None]:
import matplotlib.pyplot as plt

def plot_generated_images(model, data_loader, num_images, batch_size, device='cpu', start_batch=18):
    """Show condition / reference / generated image triplets, one figure per sample.

    Batches with index below ``start_batch`` are skipped. For each later batch
    a random latent vector is drawn per sample, decoded together with the
    encoded (64x64-padded) condition, and plotted next to the condition and the
    reference image, until ``num_images`` figures have been shown or the loader
    is exhausted.

    Note: this notebook-local definition shadows the ``plot_generated_images``
    imported from ``utils`` at the top of the notebook.

    Args:
        model: Object providing ``eval()``, ``z_dim``,
            ``forward_condition_encoder(padded_condition)`` and
            ``decode(z, cond_encoded)``.
        data_loader: Iterable of ``(data, condition)`` batches.
        num_images: Total number of figures to show.
        batch_size: Kept for backward compatibility. The number of latent
            samples is taken from each batch itself, so a short batch (e.g.
            the last one of an epoch) no longer causes a shape mismatch.
        device: Device the tensors are moved to before running the model.
        start_batch: Index of the first batch to plot from; the default of 18
            matches the previously hard-coded value.
    """
    model.eval()
    left_images = num_images
    with torch.no_grad():
        for batch_idx, (data, condition) in enumerate(data_loader):
            if batch_idx < start_batch:
                continue
            if left_images <= 0:
                break
            condition = condition.to(device)
            data = data.to(device)

            # Size the latent batch from the data actually received rather
            # than from the nominal batch size.
            current_batch = data.shape[0]
            z = torch.randn(current_batch, model.z_dim).to(device)
            padded_condition = pad_to_target_size(condition, (64, 64)).to(device)
            cond_encoded = model.forward_condition_encoder(padded_condition)
            sample = model.decode(z, cond_encoded).cpu()
            sample = sample.view(current_batch, 1, 64, 64)

            # Never index past the end of a short batch.
            print_images = min(left_images, current_batch)
            left_images -= print_images

            for i in range(print_images):
                ref = data[i].detach().cpu().numpy().reshape(64, 64)
                img = sample[i].detach().cpu().numpy().reshape(64, 64)

                # The condition keeps its own (unpadded) shape; only the first
                # channel is shown when it has more than two dimensions.
                cond = condition[i].detach().cpu().numpy()
                if cond.ndim > 2:
                    cond = cond[0]

                panels = (
                    ('Condition Image', cond),
                    ('Reference Image', ref),
                    ('Generated Image', img),
                )
                plt.figure(figsize=(12, 4))
                for position, (title, image) in enumerate(panels, start=1):
                    plt.subplot(1, 3, position)
                    plt.title(title)
                    plt.imshow(image, cmap='gray')
                    plt.axis('off')
                plt.show()
                

In [None]:
# Show five condition / reference / generated triplets drawn from the training loader.
plot_generated_images(
    model,
    train_loader,
    num_images=5,
    batch_size=32,
    device=device,
)