In [102]:
import torch.nn as nn
import torch
import torchvision.transforms
import os
from PIL import Image
from itertools import chain
import generate 

In [8]:
# Expected input geometry for the glyph images: number of colour channels,
# and side length in pixels (images are assumed to be square).
IMAGE_CHANNELS: int = 3
IMAGE_DIMENSION: int = 100

In [75]:
class FCCharacterNet(nn.Module):
    """Single fully-connected layer classifier over flattened images.

    Each image is flattened to ``image_channels * h * w`` features, projected
    to ``num_labels`` scores by one linear layer, and passed through a
    softmax over the class dimension. ``forward`` therefore returns per-class
    probabilities (each row sums to 1), not logits.

    NOTE(review): if this is trained with a loss that applies its own
    (log-)softmax, such as ``nn.CrossEntropyLoss``, feeding it these
    probabilities applies softmax twice. Confirm against the training code.

    Args:
        h: Expected image height in pixels.
        w: Expected image width in pixels.
        image_channels: Expected number of channels per image.
        num_labels: Number of output classes.
    """

    def __init__(self, h=100, w=100, image_channels=3, num_labels=10):
        super(FCCharacterNet, self).__init__()

        self.h = h
        self.w = w
        self.image_channels = image_channels
        self.num_labels = num_labels

        # Length of one flattened image; inputs must satisfy C * H * W == fc_inputs.
        self.fc_inputs = image_channels * h * w

        self.fc1 = nn.Linear(in_features=self.fc_inputs, out_features=num_labels)
        torch.nn.init.xavier_uniform_(self.fc1.weight)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class probabilities for a batch of images.

        Args:
            x: Tensor of shape (B, C, H, W) with ``C * H * W == self.fc_inputs``.
               It does not need to be contiguous (e.g. the result of ``permute``).

        Returns:
            Tensor of shape (B, num_labels) whose rows sum to 1.
        """
        # x: (B, C, H, W) -> (B, C*H*W). Unlike view(), flatten() also works
        # on non-contiguous tensors, copying only when it has to.
        x = torch.flatten(x, start_dim=1)

        x = self.fc1(x)

        return self.softmax(x)

In [80]:
def load_image(path, mode=None):
    """Load the image at ``path`` and return it as a tensor.

    Args:
        path: Filesystem path of the image to read.
        mode: Optional PIL mode (e.g. ``"RGB"``) to convert the image to
            before tensor conversion, forcing a fixed channel count.
            ``None`` (the default) keeps the file's own mode, so the
            number of channels follows the file.

    Returns:
        A (C, H, W) tensor as produced by torchvision's ``ToTensor``.
    """
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    # The context manager closes the file handle deterministically; load()
    # forces the pixel data to be read while the file is still open.
    with Image.open(path) as img:
        img.load()
        converted = img if mode is None else img.convert(mode)
        return transform(converted)


In [81]:
# Sanity-check a single glyph: load_image yields (C, H, W), and unsqueeze(0)
# adds the leading batch dimension. Files are named by code point (52 == ord("4")).
# The path is relative to the project root, the same working directory the
# dataset cell assumes, rather than a machine-specific absolute path.
path = os.path.join("images", "Great_Vibes", "52.png")
input1 = load_image(path)
print(input1.size())
batch = input1.unsqueeze(0)
print(batch.size())

torch.Size([3, 100, 100])
torch.Size([1, 3, 100, 100])


In [103]:
# Build the dataset: one image per (font, character) pair, each labelled with
# a one-hot vector identifying the font. Character files are named by code
# point (e.g. 97.png for "a").
# range() excludes its stop value, so each stop is ord(last) + 1 to include
# "z", "Z" and "9". NOTE(review): this assumes generate also rendered those
# glyphs; confirm against generate.py.
CHARACTER_CODES = list(chain(
    range(ord("a"), ord("z") + 1),
    range(ord("A"), ord("Z") + 1),
    range(ord("0"), ord("9") + 1),
))

in_list = []
out_list = []

num_fonts = len(generate.fonts)
for font_idx, font_name in enumerate(generate.fonts):

    out_row = [1 if x == font_idx else 0 for x in range(num_fonts)]
    # Build the label tensor once per font; torch.stack copies it per sample.
    font_label = torch.tensor(out_row)

    for character in CHARACTER_CODES:
        image = load_image(f"images/{font_name}/{character}.png")
        in_list.append(image)

        out_list.append(font_label)

inputs = torch.stack(in_list)
outputs = torch.stack(out_list)

In [108]:
# Size the network from the notebook's image constants and the number of
# fonts, so its output width matches the one-hot font labels built above
# (which have len(generate.fonts) columns) instead of the default of 10.
net = FCCharacterNet(
    h=IMAGE_DIMENSION,
    w=IMAGE_DIMENSION,
    image_channels=IMAGE_CHANNELS,
    num_labels=len(generate.fonts),
)