In [1]:
import torch.nn as nn
import torch
import torchvision.transforms
import os
from PIL import Image
from itertools import chain
import generate 
import torch.optim as optim

In [2]:
IMAGE_CHANNELS = 3
IMAGE_DIMENSION = 100

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [331]:
class FCCharacterNet(nn.Module):

    def __init__(self, num_labels, h=100, w=100, image_channels=3):
        super(FCCharacterNet, self).__init__()
        
        self.h = h
        self.w = w
        self.image_channels = image_channels
        self.num_labels = num_labels
        
        self.fc_inputs = image_channels * h * w
        
        self.fc1 = nn.Linear(in_features=self.fc_inputs, out_features=num_labels)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        
        # X: B 3 W H
        image_channels = x.size(1)
        H = x.size(2)
        W = x.size(3)

        x = x.view(-1, image_channels * H * W)
        
        x = self.fc1(x)
        
        x = self.softmax(x)
        
        return x

In [197]:
def load_image(path):
    """This function loads an image into memory when you give it
       the path of the image
    """
    img = Image.open(path)
    img.load()

    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    
    return transform(img)


In [6]:
path = '/home/chris/code/font-learning/images/Great_Vibes/52.png'
input1 = load_image(path)
print(input1.size())
batch = input1.unsqueeze(0)
print(batch.size())

torch.Size([3, 100, 100])
torch.Size([1, 3, 100, 100])


In [254]:
in_list = []
out_list = []

for font_idx, font_name in enumerate(generate.fonts):
        
    out_row = [1 if x == font_idx else 0 for x in range(len(generate.fonts))]
    
    for character in chain(range(ord("a"), ord("z")), range(ord("A"), ord("Z")), range(ord("0"), ord("9"))):
        image = load_image(f"images/{font_name}/{character}.png")
        in_list.append(image)
        
        out_list.append(torch.tensor(out_row, dtype=torch.float32))

#in_list = [in_list[0], in_list[-1]]
#out_list = [out_list[0], out_list[-1]]

inputs = torch.stack(in_list).to(device)
true_classes = torch.stack(out_list).to(device)

In [338]:
net = FCCharacterNet(num_labels=len(generate.fonts)).to(device)
criterion = nn.MSELoss()
#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [339]:
outputs = net(inputs)
loss = criterion(outputs, true_classes)
loss

tensor(0.2080, device='cuda:0', grad_fn=<MseLossBackward>)

In [340]:
n_epochs = 1000

for epoch_idx in range(n_epochs):
    # zero the parameter gradients
    optimizer.zero_grad()
    
    # forward + backward + optimize
    outputs = net(inputs)

    loss = criterion(outputs, true_classes)
    loss.backward()
    optimizer.step()
    
    # print statistics
    if epoch_idx % 100 == 0:
        print(f'Epoch {epoch_idx} loss: {loss:0.5f}')


Epoch 0 loss: 0.20804
Epoch 100 loss: 0.32000
Epoch 200 loss: 0.32000
Epoch 300 loss: 0.32000
Epoch 400 loss: 0.32000
Epoch 500 loss: 0.32000
Epoch 600 loss: 0.32000
Epoch 700 loss: 0.32000
Epoch 800 loss: 0.32000
Epoch 900 loss: 0.32000


In [327]:
outputs = net(inputs)
for i in range(len(outputs)):
    guesses = [f'{c:0.3f}' for c in outputs[i].tolist()]
    trues = [f'{c:0.3f}' for c in true_classes[i].tolist()]
    print(f'{guesses}: {trues}')



['0.293', '0.283', '0.110', '0.151', '0.053']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.581', '-0.229', '0.196', '0.372', '0.037']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.354', '0.374', '0.330', '-0.108', '-0.052']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.527', '0.089', '0.343', '0.281', '-0.041']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.175', '0.161', '0.252', '0.212', '-0.181']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.433', '0.072', '-0.121', '0.175', '0.247']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.314', '0.329', '0.077', '-0.088', '-0.061']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.348', '-0.247', '0.209', '0.411', '0.208']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.414', '0.029', '0.045', '0.116', '0.121']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.375', '-0.120', '0.056', '0.187', '-0.112']: ['1.000', '0.000', '0.000', '0.000', '0.000']
['0.508', '-0.340', '0.240', '0.170', '0.390']: ['1.000', '

In [194]:
list(outputs.tolist())

[[2.1860256043467146e-43, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0]]