In [None]:
"""Test code for torchfont module."""

from fontTools.ttLib import TTFont

from truetype_vs_postscript_transformer.torchfont.transforms import (
    Compose,
    DecomposeSegment,
    NormalizeSegment,
    PostScriptSegmentToTensor,
    QuadToCubic,
)

# Load the sample font from the local fonts checkout.
font_path = "../fonts/ofl/abeezee/ABeeZee-Regular.ttf"
font = TTFont(file=font_path)

# Segment transforms, listed in the order they are handed to Compose.
segment_steps = [
    DecomposeSegment(),
    NormalizeSegment(),
    QuadToCubic(),
    PostScriptSegmentToTensor("trajectory"),
]
transforms = Compose(segment_steps)

# Bare trailing expression so the notebook displays the composed pipeline.
transforms


In [2]:
from truetype_vs_postscript_transformer.torchfont.datasets.font_pair import (
    FontPairDataset,
)

# The same font is passed as both source and target of the pair.
dataset = FontPairDataset(src_font=font, target_font=font, transform=transforms)


In [3]:
from torch.utils.data import DataLoader

from truetype_vs_postscript_transformer.modules.collate_fn import (
    FontPairPostScriptCollate,
)

# Collate callable that assembles individual samples into a batch.
pair_collate = FontPairPostScriptCollate()

dataloader = DataLoader(dataset, batch_size=16, collate_fn=pair_collate)


In [None]:
import torch

from truetype_vs_postscript_transformer.modules.loss import ReconstructionLoss
from truetype_vs_postscript_transformer.torchfont.io.font import (
    POSTSCRIPT_COMMAND_TYPE_TO_NUM,
)

# Neither the loss module nor the command-vocabulary size depends on the
# batch, so build them once instead of on every iteration.
# Only the Chamfer term carries a non-zero weight here.
loss_fn = ReconstructionLoss(ce_weight=0, mse_weight=0, chamfer_weight=1)
num_commands = len(POSTSCRIPT_COMMAND_TYPE_TO_NUM)

for batch in dataloader:
    # Only the source half of each pair is used: it serves as the target and,
    # re-encoded below, as the prediction.
    src, _ = batch
    src_cmd, src_coords = src

    # Build logits that put 100.0 on each command index found in src_cmd and
    # 0 everywhere else, i.e. a prediction that picks the same commands.
    batch_size, seq_len = src_cmd.shape
    src_cmd_logits = torch.zeros(
        (batch_size, seq_len, num_commands),
        device=src_cmd.device,
    )
    src_cmd_logits.scatter_(-1, src_cmd.unsqueeze(-1), 100.0)

    # Prediction reuses the source coordinates unchanged.
    # NOTE(review): presumably the loss should come out near zero for this
    # self-comparison — confirm against ReconstructionLoss.
    pred = (src_cmd_logits, src_coords)
    loss_value = loss_fn(pred, src)

    print("Loss:", loss_value.item())
