In [14]:
import os
import random
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models, transforms
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from utils import FontSampler
import time


In [None]:
# --------------------------
# Font-related path configuration (all paths are relative to the working directory)
fonts_dir = "./font_ds/fonts"            # Directory containing the font files
text_file = "./font_ds/cleaned_text.txt" # Path to the text file
chars_file = "./font_ds/chars.txt"                  # Path to the common-characters file

In [4]:
# Seed Python's built-in RNG only; torch is not seeded in this cell.
random.seed(42)

# Initialise the FontSampler; per the original note this also splits the fonts into train/test sets.
# NOTE(review): FontSampler comes from the project-local `utils` module — confirm the split behaviour there.
sampler = FontSampler(fonts_dir, text_file, chars_file, font_size=76)

Loading fonts and rendering characters: 100%|██████████| 28/28 [00:49<00:00,  1.75s/it]
Loading fonts and rendering characters: 100%|██████████| 28/28 [00:51<00:00,  1.83s/it]


In [15]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --------------------------
# Build the model, move it to the device, and make it output embedding vectors
model = models.resnet34(weights=models.ResNet34_Weights.DEFAULT).to(device)
embedding_dim = 128  # Dimension of the output embedding vector
hidden_dim = 256     # Width of the hidden layers

# Replace the final fully connected layer so the network outputs the embedding directly
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, hidden_dim),
    nn.ReLU(inplace=True),
    nn.Linear(hidden_dim, hidden_dim),
    nn.ReLU(inplace=True),
    nn.Linear(hidden_dim, embedding_dim)
).to(device)

In [16]:
import torch.nn.functional as F

def compute_loss_and_acc(style_vecs, group_size):
    """
    Compute a grouped KL-divergence loss and a top-k retrieval accuracy.

    Rows of ``style_vecs`` are treated as consecutive groups of ``group_size``
    embeddings (rows ``0 .. group_size-1`` form the first group, and so on).
    For every row, its similarity to every *other* row (the row's own score on
    the diagonal is removed) is turned into a probability distribution and
    compared against a target that puts equal mass ``1 / (group_size - 1)`` on
    the remaining members of the row's own group.

    :param style_vecs: embeddings produced by the model, shape [N, embedding_dim]
    :param group_size: number of consecutive rows per group; must be >= 2
    :return: ``(loss, acc)`` where ``loss`` is a scalar tensor (mean per-row KL
             divergence) and ``acc`` is a Python float: the fraction of
             top-(group_size-1) neighbours that belong to the row's own group
    :raises ValueError: if ``group_size`` is smaller than 2
    """
    if group_size < 2:
        # With a single-member group there are no positives to match against
        # (the target weight 1 / (group_size - 1) would divide by zero).
        raise ValueError(f"group_size must be >= 2, got {group_size}")

    # L2-normalise the embeddings so the dot product below is a cosine similarity.
    style_vecs = F.normalize(style_vecs, p=2, dim=1)

    # Clamp the cosine similarity to a small positive floor, then sharpen it
    # by raising it to the power alpha.
    alpha = 4.0
    dot_prod = torch.matmul(style_vecs, style_vecs.T)
    similarity_matrix = torch.clamp(dot_prod, min=1e-8) ** alpha

    N = similarity_matrix.size(0)
    peers = group_size - 1  # number of same-group rows other than the row itself
    losses = []
    correct = 0

    for i in range(N):
        row = similarity_matrix[i]  # shape: [N]
        # Drop the self-similarity (element i); everything after i shifts left by one.
        new_row = torch.cat((row[:i], row[i + 1:]))  # shape: [N-1]

        # Normalise to a probability distribution by the explicit row sum.
        # F.normalize(p=1) divides by max(sum, 1e-12); when every entry sits at
        # the clamp floor (1e-8 ** 4 = 1e-32) the sum is far below that eps, the
        # row would not sum to 1, and the reported KL would be inflated.
        new_row = new_row / new_row.sum()

        # Target distribution: after removing element i, the other members of
        # row i's group occupy positions [group_start, group_start + peers).
        target = torch.zeros_like(new_row)
        group_start = (i // group_size) * group_size
        group_end = group_start + peers
        target[group_start:group_end] = 1.0 / peers

        # KL(target || new_row), summed over the row.
        row_loss = F.kl_div(new_row.log(), target, reduction='sum')
        losses.append(row_loss)

        # Accuracy: count how many of the top-`peers` entries fall inside the
        # row's own group range.
        topk_indices = new_row.topk(peers).indices
        correct_in_row = ((topk_indices >= group_start) & (topk_indices < group_end)).sum().item()
        correct += correct_in_row

    loss = torch.stack(losses).mean()
    acc = correct / (peers * N)

    # With probability 1e-3, print the similarity matrix for debugging.
    if random.random() < 1e-3:
        print(f"Sim Matrix: {similarity_matrix}")

    return loss, acc

In [17]:
# --------------------------
# 训练和验证步骤（loss 基于 word2vec 风格的 compute_loss）
def train_step(model, epoch, data_loader, optimizer, batch_size, font_size, group_size):
    """
    Run one training epoch and return the per-batch mean loss and accuracy.

    :param model: network mapping an image batch to embedding vectors
    :param epoch: zero-based epoch index (only used for the progress-bar label)
    :param data_loader: iterable of batches; each batch is iterated as
        samples -> images, and every image must support ``.to(device)``
    :param optimizer: optimizer stepped once per batch
    :param batch_size: number of samples in every batch
    :param font_size: multiplied by ``group_size`` to get the number of
        embeddings per sample (the visible caller passes ``font_cnt`` here)
    :param group_size: group size forwarded to ``compute_loss_and_acc``
    :return: ``(mean_loss, mean_acc)`` averaged over ``len(data_loader)``
    """
    model.train()
    vecs_per_sample = font_size * group_size
    running_loss, running_acc = 0, 0

    progress_bar = tqdm(data_loader, desc=f'Epoch {epoch + 1} - Training', leave=True)
    for batch in progress_bar:
        # Gather every image of every sample into one flat list on the device.
        images = []
        for sample_imgs in batch:
            for img in sample_imgs:
                images.append(img.to(device))
        # Stack and drop dimension 1 when it has size 1.
        batch_tensor = torch.stack(images).squeeze(1)

        # Single forward pass, then regroup the embeddings per sample:
        # [batch_size, font_size * group_size, embedding_dim]
        embeddings = model(batch_tensor)
        embeddings = embeddings.view(batch_size, vecs_per_sample, -1)

        # Loss/accuracy are computed independently per sample and averaged.
        per_sample = [compute_loss_and_acc(embeddings[k], group_size) for k in range(batch_size)]
        loss = sum(pair[0] for pair in per_sample) / batch_size
        acc = sum(pair[1] for pair in per_sample) / batch_size

        # Backpropagate and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        running_acc += acc
        progress_bar.set_postfix(loss=loss_value, acc=acc)
    progress_bar.close()

    num_batches = len(data_loader)
    return running_loss / num_batches, running_acc / num_batches

def validate(model, data_loader, batch_size, font_size, group_size):
    """
    Evaluate the model without gradient tracking and return mean loss/accuracy.

    :param model: network mapping an image batch to embedding vectors
    :param data_loader: sized iterable of batches; each batch is iterated as
        samples -> images, and every image must support ``.to(device)``
    :param batch_size: number of samples in every batch
    :param font_size: multiplied by ``group_size`` to get the number of
        embeddings per sample (the visible caller passes ``font_cnt`` here)
    :param group_size: group size forwarded to ``compute_loss_and_acc``
    :return: ``(mean_loss, mean_acc)`` averaged over the batches, or
        ``(nan, nan)`` when ``data_loader`` contains no batches
    """
    model.eval()

    # An empty loader (e.g. fewer validation samples than one full batch) used
    # to end in a ZeroDivisionError at the final average; report NaN instead so
    # the caller can still log the epoch.
    num_batches = len(data_loader)
    if num_batches == 0:
        return float("nan"), float("nan")

    total_loss = 0
    total_acc = 0
    with torch.no_grad():
        progress_bar = tqdm(data_loader, desc="Validating", leave=True)
        for batch in progress_bar:
            # Flatten the nested samples -> images structure into one list on the device.
            flattened_batch = [img.to(device) for sample in batch for img in sample]
            # Stack and drop dimension 1 when it has size 1.
            batch_tensor = torch.stack(flattened_batch).squeeze(1)

            # Pass the entire batch through the model in one forward pass.
            style_vecs = model(batch_tensor)

            # Regroup per sample: [batch_size, font_size * group_size, embedding_dim]
            style_vecs = style_vecs.view(batch_size, font_size * group_size, -1)

            # Loss/accuracy are computed independently per sample and averaged.
            loss, acc = 0, 0
            for i in range(batch_size):
                sample_loss, sample_acc = compute_loss_and_acc(style_vecs[i], group_size)
                loss += sample_loss
                acc += sample_acc

            loss /= batch_size
            acc /= batch_size

            total_loss += loss.item()
            total_acc += acc
            progress_bar.set_postfix(loss=loss.item(), acc=acc)
        progress_bar.close()

    return total_loss / num_batches, total_acc / num_batches

In [18]:
# Transformations applied to each image
data_transforms = transforms.Compose([
    transforms.ToTensor(),  # Convert to a tensor
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Tile the first dimension 3x (turns a single-channel image into 3 channels)
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet normalisation
])

# 定义一个简单的 Dataset 类来处理样本
class FontDataset(Dataset):
    """
    Dataset whose items are pre-assembled batches.

    Each item of ``batchs`` is iterated as a list of inner lists of images;
    when a ``transform`` is supplied it is applied to every image and the
    nested-list structure is rebuilt, otherwise the stored item is returned
    unchanged.
    """

    def __init__(self, batchs, transform=None):
        self.transform = transform
        self.batchs = batchs

    def __len__(self):
        # One dataset item per stored batch.
        return len(self.batchs)

    def __getitem__(self, idx):
        item = self.batchs[idx]
        if not self.transform:
            return item
        # Apply the transform image by image, keeping the two-level nesting.
        transformed = []
        for inner_list in item:
            transformed.append(list(map(self.transform, inner_list)))
        return transformed

In [19]:
# Define the optimizer (covers the model parameters only)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4, betas=(0.9, 0.999), eps=1e-08)

# Number of epochs; adjust as needed
num_epochs = 16
epoch_length = 16  # Number of batches per epoch

# Assumes each drawn sample contains sample_cnt images per font; sample_cnt is used as group_size
font_cnt = 2
sample_cnt = 2
batch_size = 8  # Number of samples per batch

In [None]:
def sample(sampler, font_cnt, sample_cnt, sample_source):
    """
    Draw one sample from ``sampler``.

    Thin module-level wrapper around ``sampler.sample`` so the call can be
    submitted to an executor; all arguments are forwarded as keywords and the
    sampler's return value is passed through unchanged.
    """
    return sampler.sample(
        font_cnt=font_cnt,
        sample_cnt=sample_cnt,
        sample_source=sample_source,
    )

import concurrent.futures

for epoch in range(num_epochs):
    # Sample budget for this epoch: validation draws 1/8 as many samples as training.
    total_samples = epoch_length * batch_size
    train_samples_count = total_samples
    val_samples_count = total_samples // 8

    # Draw every sample for the epoch on a thread pool (validation first, then training).
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        val_futures = []
        for _ in range(val_samples_count):
            val_futures.append(executor.submit(sample, sampler, font_cnt, sample_cnt, "test"))
        val_samples = []
        for pending in tqdm(val_futures, desc=f"Epoch {epoch + 1} - Collecting val samples"):
            val_samples.append(pending.result())

        train_futures = []
        for _ in range(train_samples_count):
            train_futures.append(executor.submit(sample, sampler, font_cnt, sample_cnt, "train"))
        train_samples = []
        for pending in tqdm(train_futures, desc=f"Epoch {epoch + 1} - Collecting train samples"):
            train_samples.append(pending.result())

    # Regroup the flat sample lists into batches of batch_size samples each.
    train_batches = [
        train_samples[k * batch_size:(k + 1) * batch_size]
        for k in range(epoch_length)
    ]

    # Only complete validation batches are kept.
    val_length = len(val_samples) // batch_size
    val_batches = [
        val_samples[k * batch_size:(k + 1) * batch_size]
        for k in range(val_length)
    ]

    # Each dataset item is already a full batch, so the loaders use batch_size=1.
    train_dataset = FontDataset(train_batches, transform=data_transforms)
    val_dataset = FontDataset(val_batches, transform=data_transforms)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

    # font_cnt fills the functions' `font_size` parameter and sample_cnt their `group_size`.
    train_loss, train_acc = train_step(model, epoch, train_loader, optimizer, batch_size, font_cnt, sample_cnt)
    val_loss, val_acc = validate(model, val_loader, batch_size, font_cnt, sample_cnt)

    print(f"Epoch {epoch + 1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save the model object after every epoch.
    model_save_path = f'font_identifier_model(dot)_epoch_{epoch + 1}.pth'
    torch.save(model, model_save_path)
    print(f"Model saved to {model_save_path}")

Epoch 1 - Collecting val samples:   0%|          | 0/16 [00:00<?, ?it/s]

Epoch 1 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 76.14it/s]
Epoch 1 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 71.94it/s]
Epoch 1 - Training:  44%|████▍     | 7/16 [00:00<00:01,  7.03it/s, acc=0.781, loss=7.88] 

Sim Matrix: tensor([[1.0000e+00, 6.4500e-01, 1.0000e-32, 1.0000e-32],
        [6.4500e-01, 1.0000e+00, 1.0000e-32, 1.0000e-32],
        [1.0000e-32, 1.0000e-32, 1.0000e+00, 8.1512e-01],
        [1.0000e-32, 1.0000e-32, 8.1512e-01, 1.0000e+00]], device='cuda:0',
       grad_fn=<PowBackward0>)


Epoch 1 - Training: 100%|██████████| 16/16 [00:02<00:00,  6.84it/s, acc=0.75, loss=0.488] 
Validating: 100%|██████████| 2/2 [00:00<00:00,  7.05it/s, acc=0.562, loss=0.979]


Epoch 1/16 - Train Loss: 1.1565, Train Acc: 0.8457, Val Loss: 0.8873, Val Acc: 0.6250
Model saved to font_identifier_model(dot)_epoch_1.pth


Epoch 2 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 75.00it/s]
Epoch 2 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 71.20it/s]
Epoch 2 - Training: 100%|██████████| 16/16 [00:02<00:00,  6.67it/s, acc=0.969, loss=0.368]
Validating: 100%|██████████| 2/2 [00:00<00:00,  7.26it/s, acc=0.781, loss=0.693]


Epoch 2/16 - Train Loss: 0.4110, Train Acc: 0.8965, Val Loss: 0.6148, Val Acc: 0.8750
Model saved to font_identifier_model(dot)_epoch_2.pth


Epoch 3 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 74.73it/s]
Epoch 3 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 67.88it/s]
Epoch 3 - Training: 100%|██████████| 16/16 [00:02<00:00,  7.06it/s, acc=0.906, loss=0.15] 
Validating: 100%|██████████| 2/2 [00:00<00:00,  5.42it/s, acc=0.969, loss=0.296]


Epoch 3/16 - Train Loss: 0.2227, Train Acc: 0.9258, Val Loss: 0.4722, Val Acc: 0.7656
Model saved to font_identifier_model(dot)_epoch_3.pth


Epoch 4 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 74.43it/s]
Epoch 4 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 70.93it/s]
Epoch 4 - Training: 100%|██████████| 16/16 [00:02<00:00,  7.46it/s, acc=1, loss=0.177]    
Validating: 100%|██████████| 2/2 [00:00<00:00,  5.34it/s, acc=1, loss=0.00345]


Epoch 4/16 - Train Loss: 0.2498, Train Acc: 0.9395, Val Loss: 0.0610, Val Acc: 1.0000
Model saved to font_identifier_model(dot)_epoch_4.pth


Epoch 5 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 76.82it/s]
Epoch 5 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 72.72it/s]
Epoch 5 - Training: 100%|██████████| 16/16 [00:02<00:00,  7.44it/s, acc=0.844, loss=0.447]
Validating: 100%|██████████| 2/2 [00:00<00:00,  7.25it/s, acc=0.906, loss=0.752]


Epoch 5/16 - Train Loss: 0.2626, Train Acc: 0.9375, Val Loss: 2.5437, Val Acc: 0.8594
Model saved to font_identifier_model(dot)_epoch_5.pth


Epoch 6 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 81.20it/s]
Epoch 6 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 72.42it/s]
Epoch 6 - Training: 100%|██████████| 16/16 [00:02<00:00,  6.62it/s, acc=0.906, loss=0.327]
Validating: 100%|██████████| 2/2 [00:00<00:00,  6.88it/s, acc=0.812, loss=0.442]


Epoch 6/16 - Train Loss: 0.3404, Train Acc: 0.9199, Val Loss: 0.4238, Val Acc: 0.7969
Model saved to font_identifier_model(dot)_epoch_6.pth


Epoch 7 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 77.49it/s]
Epoch 7 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 73.67it/s]
Epoch 7 - Training: 100%|██████████| 16/16 [00:02<00:00,  6.79it/s, acc=0.969, loss=0.107]
Validating: 100%|██████████| 2/2 [00:00<00:00,  6.98it/s, acc=0.844, loss=0.382]


Epoch 7/16 - Train Loss: 0.1756, Train Acc: 0.9375, Val Loss: 0.2832, Val Acc: 0.8594
Model saved to font_identifier_model(dot)_epoch_7.pth


Epoch 8 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 76.56it/s]
Epoch 8 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 70.52it/s]
Epoch 8 - Training: 100%|██████████| 16/16 [00:02<00:00,  6.68it/s, acc=0.75, loss=4.24]  
Validating: 100%|██████████| 2/2 [00:00<00:00,  7.18it/s, acc=0.781, loss=1.02]


Epoch 8/16 - Train Loss: 2.0353, Train Acc: 0.8770, Val Loss: 0.9899, Val Acc: 0.7656
Model saved to font_identifier_model(dot)_epoch_8.pth


Epoch 9 - Collecting val samples: 100%|██████████| 16/16 [00:00<00:00, 76.53it/s]
Epoch 9 - Collecting train samples: 100%|██████████| 128/128 [00:01<00:00, 72.94it/s]
Epoch 9 - Training:  38%|███▊      | 6/16 [00:00<00:01,  6.58it/s, acc=0.938, loss=0.557]