In [2]:
import argparse
import os
import random

import cv2
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython.display import HTML
from matplotlib import pyplot as plt
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import os
import cv2
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [3]:
batch_size = 64


# Define custom dataset class
class HandwritingDataset(Dataset):
    """Dataset of handwritten-name images paired with encoded text labels.

    The CSV is expected to have the image file name in its first column and
    the transcription in its second column, which must be named ``IDENTITY``.
    Rows with missing values and rows labelled ``UNREADABLE`` are dropped.

    Args:
        csv_file: Path to the label CSV.
        root_dir: Directory that contains the image files named in the CSV.
        transform: Optional callable applied to the grayscale PIL image. It
            must return a channel-first tensor. When ``None`` a default
            resize-to-256x256 / to-tensor / normalise pipeline is used.
        max_samples: Keep only the first ``max_samples`` rows after
            filtering; ``None`` keeps every row. Defaults to 64.
    """

    def __init__(self, csv_file, root_dir, transform=None, max_samples=64):
        # Characters that can be encoded; a character's code is its index.
        self.alphabets = u"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz-' "
        self.max_str_len = 24  # max length of input labels
        self.num_of_characters = len(
            self.alphabets) + 1  # +1 for ctc pseudo blank
        self.num_of_timestamps = 64  # max length of predicted labels

        labels = pd.read_csv(csv_file)
        labels = labels.dropna(axis=0)
        labels = labels[labels['IDENTITY'] != 'UNREADABLE']
        if max_samples is not None:
            labels = labels.iloc[:max_samples]
        self.labels = labels
        self.root_dir = root_dir

        # Honour a caller-supplied transform; fall back to the default
        # pipeline only when none is given.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))
            ])
        self.transform = transform

    def label_to_num(self, label):
        """Map each character of ``label`` to its index in ``self.alphabets``.

        Characters that are not in the alphabet map to -1 (the value
        ``str.find`` returns), which is also the padding value.
        """
        return np.array([self.alphabets.find(ch) for ch in label])

    def num_to_label(self, num):
        """Decode a sequence of codes back to text, stopping at the first -1.

        Codes may be ints or integer-valued floats (``__getitem__`` returns
        float arrays), so each code is converted with ``int`` before indexing.
        """
        chars = []
        for code in num:
            code = int(code)
            if code == -1:  # padding / blank marker
                break
            chars.append(self.alphabets[code])
        return "".join(chars)

    def encode_label(self, label):
        """Return ``label`` as a float array of length ``max_str_len``.

        The text is truncated to ``max_str_len`` characters and the unused
        tail is filled with -1.
        """
        codes = self.label_to_num(str(label)[:self.max_str_len])
        padded = np.full(self.max_str_len, -1.0)
        padded[:len(codes)] = codes
        return padded

    def preprocess(self, img):
        """Fit a 2-D image onto a 256x256 canvas and rotate it 90 degrees clockwise.

        Anything beyond 256 rows/columns is cropped. Smaller images are
        placed in the top-left corner of a canvas filled with 255.
        """
        img = np.asarray(img)[:256, :256]
        h, w = img.shape

        # NOTE(review): the canvas value 255 assumes 8-bit pixel scale; after
        # the default Normalize the image is presumably in [-1, 1]. With the
        # default 256x256 resize the canvas is fully overwritten -- confirm
        # the fill value before using smaller images.
        final_img = np.full((256, 256), 255.0)  # blank white image
        final_img[:h, :w] = img
        return cv2.rotate(final_img, cv2.ROTATE_90_CLOCKWISE)

    def __len__(self):
        """Number of rows kept after filtering and truncation."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is the preprocessed 256x256 array and ``label`` is the
        padded code array produced by ``encode_label``.
        """
        img_path = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        # convert() returns a fully loaded copy, so the file can be closed
        # right away instead of being left to the garbage collector.
        with Image.open(img_path) as handle:
            image = handle.convert('L')

        image = self.transform(image)
        image = image[-1, :, :]  # drop the channel axis (keeps the last channel)

        label = self.encode_label(self.labels.iloc[idx, 1])
        return self.preprocess(image), label

In [4]:
# Define generator network

class Generator(nn.Module):
    """Fully connected generator conditioned on a text vector.

    The noise and text vectors are concatenated and passed through a
    four-layer MLP whose final ``Tanh`` keeps every output value in
    [-1, 1]. The flat output is reshaped to a single-channel square image.

    Args:
        noise_dim: Length of the noise vector per sample.
        text_dim: Length of the conditioning text vector per sample.
        image_size: Height and width of the generated image.
    """

    def __init__(self, noise_dim, text_dim, image_size):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.text_dim = text_dim
        self.image_size = image_size
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + text_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, image_size * image_size),
            nn.Tanh()
        )

    def forward(self, noise, text):
        """Generate images of shape ``(batch, 1, image_size, image_size)``.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.
            text: Tensor of shape ``(batch, text_dim)``.
        """
        # Cast both inputs to the parameter dtype so that float32 noise or
        # float64 / integer label tensors work whatever precision the model
        # was converted to, instead of relying on torch.cat type promotion.
        param_dtype = self.fc[0].weight.dtype
        x = torch.cat([noise.to(param_dtype), text.to(param_dtype)], dim=1)
        x = self.fc(x)
        x = x.view(-1, 1, self.image_size, self.image_size)
        return x

# Define discriminator network


class Discriminator(nn.Module):
    """Convolutional discriminator conditioned on a text vector.

    Four stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size, so an ``image_size`` x ``image_size`` input ends up as a
    ``512 x (image_size // 16) x (image_size // 16)`` feature map. The
    flattened features are concatenated with the text vector and reduced
    to one sigmoid probability per sample.

    Args:
        text_dim: Length of the conditioning text vector per sample.
        image_size: Height and width of the input images.
    """

    def __init__(self, text_dim, image_size):
        super(Discriminator, self).__init__()
        self.text_dim = text_dim
        self.image_size = image_size
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Sequential(
            nn.Linear(512 * (image_size // 16) *
                      (image_size // 16) + text_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, image, text):
        """Score each image/text pair.

        Args:
            image: Tensor of shape ``(batch, 1, H, W)`` or ``(batch, H, W)``;
                a missing channel axis is inserted.
            text: Tensor of shape ``(batch, text_dim)``.

        Returns:
            Tensor of shape ``(batch,)`` with values in [0, 1].
        """
        # Cast inputs to the parameter dtype so the module works whatever
        # precision it was converted to.
        param_dtype = self.fc[0].weight.dtype
        if image.dim() == 3:
            image = image.unsqueeze(1)
        x = self.conv(image.to(param_dtype))
        x = x.flatten(start_dim=1)
        x = torch.cat([x, text.to(param_dtype)], dim=1)
        x = self.fc(x)
        # Squeeze only the feature axis: a bare squeeze() would also drop the
        # batch axis for a batch of one and break the loss against (1,) targets.
        return x.squeeze(1)

In [5]:
def weights_init(m):
    """Re-initialise ``m`` in place based on its class name.

    Modules whose class name contains ``Conv`` get weights drawn from a
    normal distribution with mean 0.0 and std 0.02. Modules whose class name
    contains ``BatchNorm`` get weights drawn from a normal distribution with
    mean 1.0 and std 0.02 and a zero bias. Every other module is left
    untouched. Intended to be passed to ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [6]:
# Hyperparameters and paths
learning_rate = 0.0002
num_epochs = 2
batch_size = 64
noise_dim = 128
text_dim = 24
image_size = 256
dataset_path = 'handwriting-recognition'
seed = 42

# Seed every RNG in play so a fresh Restart & Run All is repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# # Decide which device we want to run on
ngpu = 0  # set > 0 to use CUDA when it is available
device = torch.device("cuda:0" if (
    torch.cuda.is_available() and ngpu > 0) else "cpu")

# %%

# The dataset yields float64 images and labels, so both networks run in
# double precision.
generator = Generator(noise_dim, text_dim, image_size).to(device)
discriminator = Discriminator(text_dim, image_size).to(device)
generator.double()
discriminator.double()
# NOTE(review): weights_init is defined in an earlier cell but is never
# applied to either network here -- confirm whether
# discriminator.apply(weights_init) was intended.
# generator.load_state_dict(torch.load('generator.pth', map_location=device))
# discriminator.load_state_dict(torch.load('discriminator.pth', map_location=device))

criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))


# %%
train_dataset = HandwritingDataset(os.path.join(dataset_path, 'written_name_train_v2.csv'), os.path.join(dataset_path, 'train_v2', 'train'), transform=None)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

val_dataset = HandwritingDataset(os.path.join(dataset_path, 'written_name_validation_v2.csv'), os.path.join(dataset_path, 'validation_v2', 'validation'), transform=None)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=0)


# %%
for epoch in range(num_epochs):
    # Re-enter train mode each epoch so re-running this cell after the
    # validation cell (which switches to eval mode) still trains correctly.
    generator.train()
    discriminator.train()

    for i, (real_images, text_labels) in enumerate(train_loader):
        real_images = real_images.to(device)
        text_labels = text_labels.to(device)

        # Size targets and noise from the batch actually received: the last
        # batch of an epoch can be smaller than batch_size.
        current_batch = real_images.size(0)
        real_targets = torch.ones(current_batch, dtype=torch.double, device=device)
        fake_targets = torch.zeros(current_batch, dtype=torch.double, device=device)

        # ---- Train discriminator ----
        optimizer_d.zero_grad()

        # Real images should be scored as 1.
        real_output = discriminator(real_images, text_labels)
        d_loss_real = criterion(real_output, real_targets)

        # Generated images should be scored as 0; detach so this step does
        # not back-propagate into the generator.
        noise = torch.randn(current_batch, noise_dim, dtype=torch.double, device=device)
        fake_images = generator(noise, text_labels)
        fake_output = discriminator(fake_images.detach(), text_labels)
        d_loss_fake = criterion(fake_output, fake_targets)

        # Total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ---- Train generator ----
        optimizer_g.zero_grad()

        # The generator is rewarded when the discriminator scores its
        # images as real, hence the all-ones targets.
        noise = torch.randn(current_batch, noise_dim, dtype=torch.double, device=device)
        fake_images = generator(noise, text_labels)
        fake_output = discriminator(fake_images, text_labels)
        g_loss = criterion(fake_output, real_targets)

        g_loss.backward()
        optimizer_g.step()

        # Print losses every 100 iterations
        if i % 100 == 0:
            print(f"Epoch {epoch}, Batch {i}: D_loss={d_loss.item():.4f}, G_loss={g_loss.item():.4f}")




Epoch 0, Batch 0: D_loss=1.4983, G_loss=54.5758


In [None]:
# Evaluate the discriminator loss on the validation set.
# Both networks go into eval mode: the discriminator contains BatchNorm
# layers, which in train mode normalise with batch statistics and keep
# updating their running statistics even inside torch.no_grad().
generator.eval()
discriminator.eval()

val_losses = []
with torch.no_grad():
    for i, (real_images, text_labels) in enumerate(val_loader):
        real_images = real_images.to(device)
        text_labels = text_labels.to(device)

        # Size targets and noise from the batch actually received; the last
        # batch can be smaller than batch_size.
        current_batch = real_images.size(0)

        # Compute loss for discriminator
        real_labels = torch.ones(current_batch, dtype=torch.double, device=device)
        real_output = discriminator(real_images, text_labels)
        d_loss_real = criterion(real_output, real_labels)

        noise = torch.randn(current_batch, noise_dim, dtype=torch.double, device=device)
        fake_labels = torch.zeros(current_batch, dtype=torch.double, device=device)
        fake_images = generator(noise, text_labels)
        fake_output = discriminator(fake_images, text_labels)
        d_loss_fake = criterion(fake_output, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        val_losses.append(d_loss.item())

# Guard against an empty loader instead of dividing by zero.
if val_losses:
    print(f"Validation loss: {sum(val_losses)/len(val_losses):.4f}")
else:
    print("Validation loss: n/a (validation loader produced no batches)")

In [None]:
# Save the models to a file
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

# The generator expects a (batch, text_dim) tensor, not a Python string, so
# encode the text the same way the dataset encodes labels: character indices
# followed by -1 padding up to text_dim.
test_text = "test"
test_codes = train_dataset.label_to_num(test_text)[:text_dim]
test_label = torch.full((1, text_dim), -1.0, dtype=torch.double, device=device)
test_label[0, :len(test_codes)] = torch.as_tensor(
    test_codes, dtype=torch.double, device=device)

generator.eval()
with torch.no_grad():
    test_noise = torch.randn(1, noise_dim, dtype=torch.double, device=device)
    test_image = generator(test_noise, test_label)

# The output is (batch, channel, H, W); imshow needs a 2-D array, so pick the
# first sample and its only channel.
plt.imshow(test_image[0, 0].cpu().numpy(), cmap='gray')