In [1]:
import os
import random
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import Dataset, DataLoader

class FontsDataset(Dataset):
    """Triplet dataset of rendered glyphs for learning font-style similarity.

    Every index maps to one (font, character) pair. ``__getitem__`` returns an
    ``(anchor, pos, neg)`` triple of grayscale ("L") PIL images:

    * anchor -- the indexed character rendered in the indexed font;
    * pos    -- a *different* character rendered in the *same* font;
    * neg    -- a random character rendered in a *different* font.

    Args:
        fonts_root: directory searched recursively for font files.
        characters_set_path: text file whose characters form the character set.
        img_size: ``(width, height)`` of the rendered images in pixels.
    """

    def __init__(self, fonts_root, characters_set_path, img_size):
        self.fonts_path = find_fonts(fonts_root)
        self.characters_set = load_characters_set(characters_set_path)
        self.img_size = img_size
        self.size = len(self.fonts_path) * len(self.characters_set)
        self.sample_num = len(self.characters_set)
        # Largest valid index offsets into fonts_path / characters_set.
        self.fonts_random_step = len(self.fonts_path) - 1
        self.char_random_step = self.sample_num - 1

    @staticmethod
    def _other_index(index, count):
        """Return a uniformly random index in ``[0, count)`` that differs from ``index``.

        Raises:
            ValueError: if ``count < 2`` (there is no "other" index to pick).
        """
        if count < 2:
            raise ValueError(
                f"need at least 2 items to pick one different from index {index}, got {count}"
            )
        # Offsets 1..count-1 wrap around to every index except `index` itself.
        return (index + random.randint(1, count - 1)) % count

    def draw_char(self, font_path, char):
        """Render ``char`` in black, centred on a white grayscale ("L") canvas.

        Args:
            font_path: path of a font file loadable by ``ImageFont.truetype``.
            char: the character to draw.

        Returns:
            A PIL image of size ``self.img_size``.
        """
        width, height = self.img_size
        img = Image.new("L", self.img_size, color="white")
        draw = ImageDraw.Draw(img)
        # Integer point size: float sizes are only accepted by Pillow >= 10.1.
        # NOTE(review): FreeType may raise OSError (e.g. "stack overflow") for fonts whose
        # hinting bytecode fails; if that persists, filter such fonts out of fonts_path.
        font = ImageFont.truetype(font_path, max(1, int(height / 1.2)))
        # getbbox() is (left, top, right, bottom) relative to the draw origin and left/top
        # may be non-zero, so centre the actual ink box rather than (0, 0)-(right, bottom).
        left, top, right, bottom = font.getbbox(char)
        x = (width - (right - left)) // 2 - left
        y = (height - (bottom - top)) // 2 - top
        draw.text((x, y), char, font=font, fill="black")
        return img

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return the ``(anchor, pos, neg)`` image triple for flat index ``idx``."""
        font_id, char_id = divmod(idx, self.sample_num)
        char = self.characters_set[char_id]
        # Positive: another character (never the anchor's) from the same font.
        pos_char = self.characters_set[self._other_index(char_id, self.sample_num)]
        # Negative: any character from a font other than the anchor's.
        neg_font_id = self._other_index(font_id, len(self.fonts_path))
        neg_char = self.characters_set[random.randint(0, self.char_random_step)]
        anchor = self.draw_char(self.fonts_path[font_id], char)
        pos = self.draw_char(self.fonts_path[font_id], pos_char)
        neg = self.draw_char(self.fonts_path[neg_font_id], neg_char)
        return anchor, pos, neg

def find_fonts(directory, extensions=('.ttf', '.otf', '.ttc')):
    """Recursively collect font files under ``directory``.

    Args:
        directory: root directory to walk.
        extensions: a suffix string or iterable of suffixes marking a font file;
            matched case-insensitively against the file name.

    Returns:
        Sorted list of ``os.path.join(root, name)`` paths of matching files.
        Sorting keeps dataset indices stable across machines, because
        ``os.walk`` yields directory entries in arbitrary order.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    # Lower-case both sides so e.g. ".TTF" in `extensions` still matches.
    suffixes = tuple(ext.lower() for ext in extensions)
    font_files = [
        os.path.join(root, name)
        for root, _, names in os.walk(directory)
        for name in names
        if name.lower().endswith(suffixes)
    ]
    return sorted(font_files)

def load_characters_set(path, encoding="utf-8-sig"):
    """Read a character-set file and return its characters as a list.

    Whitespace (newlines, spaces, tabs, full-width spaces, ...) is dropped so
    line breaks and separators never become "characters" to render.

    Args:
        path: text file listing the characters.
        encoding: file encoding. The default ``utf-8-sig`` reads plain UTF-8 and
            also strips a leading byte-order mark, which would otherwise appear
            as an invisible character. Being explicit avoids depending on the
            platform's locale encoding.

    Returns:
        List of single-character strings, in file order.
    """
    with open(path, "r", encoding=encoding) as f:
        return [ch for ch in f.read() if not ch.isspace()]

In [2]:
from torchvision.utils import make_grid
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
# Data locations. Set the FONTDREAM_DATA environment variable to point at another
# data directory; the default is the original author's local path.
DATA_DIR = os.environ.get("FONTDREAM_DATA", "/home/qba/Data/Project/DeepLearning/FontDream/data")
fonts_root = os.path.join(DATA_DIR, "font", "中文")
char_set_path = os.path.join(DATA_DIR, "common-char-level-1.txt")

In [3]:
# Fixed seed so the sampled triplets (Python `random` in the dataset) and the
# DataLoader shuffle order (torch RNG) are reproducible on Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

transformer = transforms.Compose([transforms.ToTensor()])


def collate_fn(batch):
    """Turn a list of (anchor, pos, neg) PIL-image triples into three batched tensors.

    Each image is converted with ``transformer`` (``ToTensor``) and the results
    are stacked along a new leading batch dimension.
    """
    anchor, pos, neg = zip(*batch)
    return tuple(
        torch.stack([transformer(img) for img in group]) for group in (anchor, pos, neg)
    )


data = FontsDataset(fonts_root, char_set_path, (64, 64))
datalader = DataLoader(data, batch_size=16, shuffle=True, collate_fn=collate_fn, pin_memory=False)


In [12]:
# NOTE(review): the recorded "OSError: stack overflow" below this cell was presumably
# raised by FreeType while rendering a glyph for this batch; check which font triggers it.
anchor, pos, neg = next(iter(datalader))
# Concatenate along the width axis: each row of the grid is anchor | positive | negative.
img = torch.cat([anchor, pos, neg], dim=-1)
print(img.shape)
fig, ax = plt.subplots(figsize=(5, 30), dpi=100)
ax.imshow(make_grid(img, nrow=1).permute(1, 2, 0))
ax.set_title("anchor | positive | negative")
ax.axis("off");

OSError: stack overflow