In [None]:
import os

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim

from preprocess import extract_images
from dataset import get_full_list, ChineseCharacterDataset
from models import CVAE
from utils import visualize_images, show_images, pad_to_target_size
from loss import vae_loss
from loss import vae_loss

In [None]:
# Automatically re-import edited modules (e.g. the project's dataset, models,
# utils and loss modules imported above) before each cell executes, so edits
# to those .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Loading and Preprocessing:

In [None]:
# Set the directory and load the dataset.
# The path is relative, so it resolves against the notebook's working directory.
# `get_full_list` comes from the project's `dataset` module; presumably it
# returns one entry per character image under `image_dir` -- TODO confirm.
image_dir = './chinese_chars/pngs'
full_data_list = get_full_list(image_dir)

In [None]:
# Every sample goes to training: no validation/test split is carved out here.
train_size = len(full_data_list)
print("Training Set Size:", train_size)

In [None]:
# Build (image, condition) pairs; cond_type='Row' with rows=[20, 43]
# presumably selects that band of rows as the condition -- TODO confirm in dataset.py.
train_data = ChineseCharacterDataset(
    full_data_list[:train_size],
    cond_type='Row',
    rows=[20, 43],
)

# Create data loaders.
# NOTE(review): shuffle=False feeds batches in the same order every epoch;
# plot_generated_images below only depends on the order to skip a fixed
# number of batches, so shuffle=True is worth considering for training.
train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=False,
)

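### Model (Version 1)

An earlier attention-based CVAE (`AttentionCVAE`). The Training section below trains the Version 2 model instead.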

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    """Softmax attention with scores ``q^T k`` scaled by 1/sqrt(d).

    Note the transposed convention: scores are ``q.transpose(-2, -1) @ k``,
    so the contraction runs over dim -2 (of size d) and the softmax is taken
    over the last dim of the resulting (..., q.size(-1), k.size(-1)) matrix.
    With the (batch, 1, features) inputs built by the encoder/decoder, this
    attends across feature positions.

    Returns:
        (output, attn_weights) with ``output = v @ attn_weights``.
    """

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        # Scale by 1/sqrt(d) where d is the contracted dimension (dim -2),
        # matching the `C**-0.5` scaling of the Version 2 `Attention` module.
        # The class name promised this scaling but none was applied. For the
        # (batch, 1, features) inputs used by the encoder/decoder d == 1, so
        # their results are unchanged.
        d = q.size(-2)
        scores = torch.matmul(q.transpose(-2, -1), k) * d ** -0.5
        attn_weights = self.softmax(scores)
        output = torch.matmul(v, attn_weights)
        return output, attn_weights
    
class AttentionEncoder(nn.Module):
    """Encoder of the Version 1 attention CVAE.

    Maps an image and its condition to the mean and log-variance of the
    approximate posterior. Three stride-2 convolutions (kernel 4, padding 1)
    halve each side, and ``fc1`` expects 256 * 8 * 8 features, so the image is
    presumably (batch, 1, 64, 64). The condition is flattened per sample and
    must hold ``condition_dim`` values.
    """

    def __init__(self, latent_dim, condition_dim):
        super().__init__()
        # Image branch. Creation order is kept so parameter init is unchanged.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(256 * 8 * 8, 512)
        # Condition branch and fusion.
        self.fc2 = nn.Linear(condition_dim, 512)
        self.attention = ScaledDotProductAttention()
        # Posterior heads.
        self.fc_mu = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

    def forward(self, x, condition):
        feats = x
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = F.relu(conv(feats))
        image_emb = F.relu(self.fc1(feats.view(feats.size(0), -1)))

        cond_emb = F.relu(self.fc2(condition.view(condition.size(0), -1)))
        cond_tokens = cond_emb.unsqueeze(1)

        # The image embedding is the query; the condition embedding is both key and value.
        fused, _ = self.attention(image_emb.unsqueeze(1), cond_tokens, cond_tokens)
        fused = fused.squeeze(1)
        return self.fc_mu(fused), self.fc_logvar(fused)

class AttentionDecoder(nn.Module):
    """Decoder of the Version 1 attention CVAE.

    Embeds z and the flattened condition to 512 features each, lets the z
    embedding attend over the condition embedding, concatenates the result
    with the z embedding, and upsamples a (256, 8, 8) map through three
    stride-2 transposed convolutions (8 -> 16 -> 32 -> 64) to a 1-channel
    image squashed into (0, 1) by a sigmoid.
    """

    def __init__(self, latent_dim, condition_dim):
        super().__init__()
        # Creation order matches the original so parameter init is unchanged.
        self.fc1 = nn.Linear(512 + 512, 512)  # [z embedding, attended] -> hidden
        self.fc2 = nn.Linear(512, 256 * 8 * 8)
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1)
        self.attention = ScaledDotProductAttention()
        self.fc3 = nn.Linear(latent_dim, 512)  # z -> embedding
        self.fc4 = nn.Linear(condition_dim, 512)  # flattened condition -> embedding

    def forward(self, z, condition):
        z_emb = F.relu(self.fc3(z))
        cond_emb = F.relu(self.fc4(condition.view(condition.size(0), -1)))
        cond_tokens = cond_emb.unsqueeze(1)

        # The z embedding queries the condition embedding (key and value).
        attended, _ = self.attention(z_emb.unsqueeze(1), cond_tokens, cond_tokens)
        hidden = F.relu(self.fc1(torch.cat([z_emb, attended.squeeze(1)], dim=1)))
        hidden = F.relu(self.fc2(hidden))
        hidden = hidden.view(hidden.size(0), 256, 8, 8)
        for deconv in (self.deconv1, self.deconv2):
            hidden = F.relu(deconv(hidden))
        return torch.sigmoid(self.deconv3(hidden))

class AttentionCVAE(nn.Module):
    """Version 1 conditional VAE built from AttentionEncoder and AttentionDecoder.

    ``forward(x, condition)`` returns ``(x_recon, mu, logvar)``.
    """

    def __init__(self, latent_dim, condition_dim):
        super().__init__()
        self.encoder = AttentionEncoder(latent_dim, condition_dim)
        self.decoder = AttentionDecoder(latent_dim, condition_dim)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma, with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x, condition):
        mu, logvar = self.encoder(x, condition)
        x_recon = self.decoder(self.reparameterize(mu, logvar), condition)
        return x_recon, mu, logvar

In [None]:
# Version 1 hyperparameters.
input_dim = 64 * 64
z_dim = 50
condition_dim = 64*23  # presumably 64 columns x 23 condition rows (rows=[20, 43]) -- TODO confirm
learning_rate = 1e-4

num_epochs = 100

# Version 1 is off by default. As originally written this cell failed on a
# fresh kernel: `device` and `train` are only defined in the Training section
# below, and `train_loader_cols` is not defined anywhere in this notebook
# (`train_loader` is the only loader built above). To train it, run the
# Training section first, then set the flag to True.
# NOTE(review): `train` pads each condition to 64x64 before calling the model,
# which would give 64*64 condition values per sample rather than
# condition_dim = 64*23 -- check `pad_to_target_size` before enabling.
TRAIN_ATTENTION_CVAE_V1 = False

if TRAIN_ATTENTION_CVAE_V1:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_cols = AttentionCVAE(z_dim, condition_dim).to(device)
    optimizer = optim.Adam(model_cols.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    train(model_cols, train_loader, input_dim, optimizer, scheduler, num_epochs)

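### Model (Version 2)

This `CVAE` class shadows the `CVAE` imported from `models` in the first cell; it is the model that the Training section below trains and saves.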

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention mechanism
class Attention(nn.Module):
    """Residual self-attention over the spatial positions of a feature map.

    Queries, keys and values are 1x1 convolutions of the input ``x``
    (N, C, H, W). Every one of the H*W positions attends to every other
    position, scores are scaled by 1/sqrt(C), and the attended values are
    added back onto ``x``. The output shape equals the input shape.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv_query = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.conv_key = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.conv_value = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        N, C, H, W = x.shape
        q = self.conv_query(x).reshape(N, C, H*W)
        k = self.conv_key(x).reshape(N, C, H*W)
        v = self.conv_value(x).reshape(N, C, H*W)

        # attention[n, i, j] is the weight of key position j for query
        # position i; the softmax over the last dim makes each query's
        # weights sum to 1.
        attention = q.transpose(-2, -1) @ k * C**-0.5
        attention = F.softmax(attention, dim=-1)
        # out[n, :, i] = sum_j attention[n, i, j] * v[n, :, j].
        # The transpose is essential: plain `v @ attention` sums over the
        # *query* index, whose weights do not form a distribution.
        out = v @ attention.transpose(-2, -1)
        out = out.reshape(N, C, H, W)
        return x + out

# Conditional Variational Autoencoder (CVAE)
class CVAE(nn.Module):
    """Conditional VAE: MLP encoder/decoder plus a conv + attention condition encoder.

    The condition image is embedded into ``condition_dim`` values by
    ``forward_condition_encoder``. That embedding is concatenated to the
    flattened input for the encoder and to ``z`` for the decoder.

    Args:
        input_dim: number of values in one flattened input image.
        hidden_dim: width of the hidden layer of encoder and decoder.
        z_dim: latent dimensionality.
        condition_dim: size of the condition embedding.
    """

    # Condition features are always average-pooled to this square grid,
    # which fixes the input width of ``fc_condition`` at 128 * 4 * 4.
    COND_POOL_SIZE = 4

    def __init__(self, input_dim, hidden_dim, z_dim, condition_dim):
        super(CVAE, self).__init__()
        self.z_dim = z_dim
        self.condition_dim = condition_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim + condition_dim, hidden_dim),
            nn.ReLU(),
        )

        self.fc_mu = nn.Linear(hidden_dim, z_dim)
        self.fc_logvar = nn.Linear(hidden_dim, z_dim)

        # Condition encoder: three stride-2 convs, each roughly halving H and W.
        self.condition_encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
        )

        self.attention = Attention(128)

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(z_dim + condition_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()
        )

        # Fully connected layer for condition encoding. This layer used to be
        # created lazily on the first forward pass, i.e. *after* the optimizer
        # was built from model.parameters(). As a result it was never trained,
        # and a freshly constructed model could not load a saved state_dict.
        # It is now built eagerly, and last, so that the other layers'
        # initialisation is unchanged.
        self.fc_condition = None
        self.initialize_fc_condition()

    def initialize_weights(self):
        """Re-initialise every Linear/Conv2d with Xavier-normal weights and zero bias.

        Not called by ``__init__``; call explicitly if this scheme is wanted.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def initialize_fc_condition(self, condition=None):
        """(Re)create ``fc_condition`` with Xavier-normal weights and zero bias.

        Its input width is 128 * COND_POOL_SIZE**2, because
        ``forward_condition_encoder`` always pools to that grid.

        Args:
            condition: optional tensor; only its device is used to place the
                layer. Defaults to the device of the existing parameters.

        Note: if this is called after an optimizer was built, the new
        parameters are not tracked by that optimizer.
        """
        if condition is not None:
            device = condition.device
        else:
            device = self.fc_mu.weight.device
        fc_input_dim = 128 * self.COND_POOL_SIZE * self.COND_POOL_SIZE
        layer = nn.Linear(fc_input_dim, self.condition_dim).to(device)
        nn.init.xavier_normal_(layer.weight)
        nn.init.zeros_(layer.bias)
        self.fc_condition = layer

    def forward_condition_encoder(self, condition):
        """Embed a (batch, 1, H, W) condition image into (batch, condition_dim)."""
        if self.fc_condition is None:
            # Only reachable if the attribute was cleared after construction.
            self.initialize_fc_condition(condition)

        _, _, h, w = condition.shape
        if h < 8 or w < 8:
            # Zero-pad bottom/right up to 8 pixels in each spatial dim.
            pad_h = max(0, 8 - h)
            pad_w = max(0, 8 - w)
            condition = F.pad(condition, (0, pad_w, 0, pad_h))

        x = self.condition_encoder(condition)
        x = self.attention(x)
        # Always pool to a fixed grid. The old code pooled to min(size, 4),
        # so the flattened width -- and the required fc_condition input
        # size -- changed with the condition's spatial size.
        x = F.adaptive_avg_pool2d(x, self.COND_POOL_SIZE)
        return self.fc_condition(x.flatten(1))

    def encode(self, x, cond_encoded):
        """Return (mu, logvar) of q(z | x, condition) from the flattened image and condition embedding."""
        x_flat = x.view(x.size(0), -1)
        # Concatenate the condition with the input image
        x_cond = torch.cat([x_flat, cond_encoded], dim=1)
        h1 = self.encoder(x_cond)
        return self.fc_mu(h1), self.fc_logvar(h1)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, cond_encoded):
        """Decode z plus the condition embedding into a flattened image in (0, 1)."""
        # Concatenate z with the encoded condition
        z_cond = torch.cat([z, cond_encoded], dim=1)
        return self.decoder(z_cond)

    def forward(self, x, condition):
        """Return (reconstruction of shape (batch, input_dim), mu, logvar)."""
        cond_encoded = self.forward_condition_encoder(condition)
        mu, logvar = self.encode(x, cond_encoded)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, cond_encoded), mu, logvar


### Training:

In [None]:
# Use GPU if available
# Select the compute device; the training and plotting cells below pass it to `.to(...)`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Train
def train(model, train_data_set, input_dim, optimizer, scheduler, num_epochs,
          target_size=(64, 64)):
    """Train a conditional VAE and report the average per-sample loss of each epoch.

    Args:
        model: module called as ``model(data, condition)`` and returning
            ``(reconstruction, mu, logvar)``.
        train_data_set: a DataLoader yielding ``(data, condition)`` batches;
            its ``.dataset`` length is used to normalise the loss.
        input_dim: unused; kept for backward compatibility.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over ``train_data_set``.
        target_size: spatial size each condition is padded to with
            ``pad_to_target_size`` before it is fed to the model.

    Returns:
        List with the average loss per sample for each epoch (empty if
        ``num_epochs`` is 0).
    """
    # Use the device of the model's parameters rather than the notebook-global
    # `device`, so there is no hidden dependency on cell execution order and
    # batches always land next to the model.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for data, condition in train_data_set:
            data = data.to(run_device)
            condition = condition.to(run_device)
            padded_condition = pad_to_target_size(condition, target_size).to(run_device)
            optimizer.zero_grad()
            reconstructed_batch, mu, logvar = model(data, padded_condition)
            # The Version 2 CVAE returns flattened (batch, input_dim) images;
            # reshape to (batch, 1, 64, 64) before computing the loss.
            loss = vae_loss(reconstructed_batch.view(data.shape[0], 1, 64, 64), data, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
        scheduler.step()
        # Normalised per sample (dataset length), not per batch.
        avg_loss = train_loss / len(train_data_set.dataset)
        epoch_losses.append(avg_loss)
        print(f'Epoch {epoch+1}, Loss: {avg_loss}')
    # Guarded: with num_epochs == 0 the loss variable was previously unbound (NameError).
    if epoch_losses:
        print(f'Final Loss: {epoch_losses[-1]}')
    return epoch_losses

In [None]:
# Version 2 CVAE hyperparameters.
input_dim = 64 * 64  # flattened 64x64 image
z_dim = 50
condition_dim = 50
# NOTE(review): 6400*8 = 51,200 hidden units gives two Linear layers of
# ~210M weights each (input_dim + condition_dim -> hidden and
# hidden -> input_dim); with gradients and Adam state that is several GB.
# Consider a much smaller width.
hidden_dim = 6400*8
learning_rate = 1e-3

num_epochs = 25

# Seed PyTorch (CPU and CUDA generators) so weight initialisation and the
# reparameterisation noise are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = CVAE(input_dim, hidden_dim, z_dim, condition_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Halve the learning rate every 10 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

train(model, train_loader, input_dim, optimizer, scheduler, num_epochs)

In [None]:
# Save the trained Version 2 weights. Create the target directory first so a
# fresh checkout does not fail with FileNotFoundError.
save_path = './saves/attn_cvae_r20_43.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)

### Generating Images:

In [None]:
import matplotlib.pyplot as plt
import torch

def plot_generated_images(model, data_loader, num_images, batch_size, device='cpu',
                          skip_batches=18, target_size=(64, 64)):
    """Show condition / reference / generated triplets from a trained CVAE.

    For each plotted sample, a latent code z ~ N(0, I) is decoded together
    with that sample's encoded condition. The generated image is therefore a
    sample for that condition, not a reconstruction of the reference image.

    Args:
        model: CVAE exposing ``z_dim``, ``forward_condition_encoder`` and
            ``decode`` (the Version 2 ``CVAE``). Decoded outputs are reshaped
            to 64x64.
        data_loader: iterable of ``(data, condition)`` batches.
        num_images: total number of triplets to plot.
        batch_size: maximum number of samples used from each batch. A smaller
            (e.g. final) batch is handled instead of causing a shape error.
        device: device the model lives on.
        skip_batches: number of leading batches to skip (previously
            hard-coded to 18).
        target_size: size conditions are padded to before encoding, mirroring
            the ``pad_to_target_size`` call in ``train``. The unpadded
            condition is what gets displayed.
    """
    model.eval()
    left_images = num_images
    titles = ('Condition Image', 'Reference Image', 'Generated Image')
    with torch.no_grad():
        for batch_idx, (data, condition) in enumerate(data_loader):
            if left_images <= 0:
                break
            if batch_idx < skip_batches:
                continue

            n = min(batch_size, condition.size(0))
            data = data[:n].to(device)
            condition = condition[:n].to(device)

            # Encode the condition exactly as during training (padded);
            # the unpadded version feeds a differently sized map to fc_condition.
            padded_condition = pad_to_target_size(condition, target_size).to(device)
            z = torch.randn(n, model.z_dim).to(device)
            cond_encoded = model.forward_condition_encoder(padded_condition)
            sample = model.decode(z, cond_encoded).cpu().view(n, 1, 64, 64)

            print_images = min(n, left_images)
            left_images -= print_images

            for i in range(print_images):
                ref = data[i].cpu().numpy().reshape(64, 64)
                img = sample[i].numpy().reshape(64, 64)
                cond = condition[i].cpu().numpy()
                if cond.ndim > 2:
                    cond = cond[0]  # Select the first channel if condition is multi-channel

                fig, axes = plt.subplots(1, 3, figsize=(12, 4))
                for ax, title, image in zip(axes, titles, (cond, ref, img)):
                    ax.set_title(title)
                    ax.imshow(image, cmap='gray')
                    ax.axis('off')
                plt.show()
                # Release the figure; many open figures hold memory and trigger warnings.
                plt.close(fig)


plot_generated_images(model, train_loader, num_images=100, batch_size=32, device=device)