In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.models import efficientnet_b7, EfficientNet_B7_Weights

import pandas as pd
import copy
import numpy as np

import shutil
import os
from PIL import Image
from tqdm import tqdm
from google.cloud import storage

# Google Cloud Storage client (project "leo_font") and a handle to the bucket of
# the same name.
# NOTE(review): `bucket` is not referenced in the cells shown here — remove if unused.
storage_client = storage.Client(project="leo_font")
bucket = storage_client.bucket("leo_font")

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing applied to every image:
#   1. resize to a fixed 224x224,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the standard ImageNet mean/std used by
#      torchvision's pretrained models.
# NOTE(review): presumably this matches the preprocessing used when the
# checkpoint loaded further down was trained — TODO confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that loads image files from disk as RGB images.

    Args:
        files: Sequence of image file paths; item ``idx`` is loaded from
            ``files[idx]``.
        transform: Optional callable applied to each loaded RGB PIL image
            (e.g. the ``transforms.Compose`` pipeline defined earlier). When
            ``None`` the RGB PIL image itself is returned.
    """

    def __init__(self, files, transform=None):
        self.transform = transform
        self.files = files

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        # Open inside a context manager so the file handle is closed
        # deterministically; convert() returns a converted copy that stays
        # valid after the source image is closed.
        with Image.open(path) as img:
            image = img.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image

In [None]:
class StyleMeasurer(nn.Module):
    """EfficientNet-B7 whose final output is squashed into (-1, 1) with tanh.

    The network's classifier output is used as a style embedding.
    NOTE(review): with torchvision's default head this is presumably
    1000-dimensional — confirm against the checkpoint that gets loaded.

    Args:
        weights: torchvision weights enum passed to ``efficientnet_b7``.
            Defaults to the ImageNet-1k weights (the original behaviour). Pass
            ``None`` to skip downloading/initialising pretrained weights when a
            fine-tuned ``state_dict`` is loaded over them immediately anyway.
    """

    def __init__(self, weights=EfficientNet_B7_Weights.IMAGENET1K_V1):
        super().__init__()
        self.effnet = efficientnet_b7(weights=weights)

    def forward(self, x):
        """Return ``tanh(effnet(x))``.

        In this notebook ``x`` is a stacked batch of preprocessed images,
        i.e. (batch, 3, 224, 224) after the transform defined above.
        """
        return torch.tanh(self.effnet(x))

In [None]:
# Build the model and move its parameters to the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
style_measurer = StyleMeasurer().to(device)

In [None]:
# Load the trained style-measurer checkpoint and switch to inference mode
# (eval() puts dropout / batch-norm layers into evaluation behaviour).
# NOTE(review): `localfd` and `model1` are not defined anywhere in this
# notebook — they must come from a deleted or out-of-order cell. TODO define
# them explicitly (e.g. next to the other paths) so Restart & Run All works.
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel too,
# consistent with the `device` fallback chosen above.
style_measurer.load_state_dict(torch.load(f"{localfd}/{model1}", map_location=device))
# Trailing ';' suppresses the (very long) EfficientNet-B7 repr in the cell output.
style_measurer.eval();

In [None]:
# Input/output locations. Every path hangs off one root so the notebook can be
# pointed at a different data checkout by editing a single constant.
DATA_ROOT = "/home/jupyter/ai_font/data"
TEST_TTF_DIR = f"{DATA_ROOT}/test_ttf"

reportfd = f"{TEST_TTF_DIR}/report"                 # output directory, created below
contentfd = f"{DATA_ROOT}/zipfiles/raw/size96/seen"  # NOTE(review): not used in the cells shown
stylefd = f"{TEST_TTF_DIR}/pngs"                     # per-font .png files read by the prototype cell
filterfd = f"{TEST_TTF_DIR}/filter"                  # NOTE(review): not used in the cells shown
os.makedirs(reportfd, exist_ok=True)

In [None]:
# Compute one prototype embedding per font: the mean style_measurer output over
# all of that font's .png files in `stylefd`. Result: prototypes[font] is a CPU
# tensor of shape (1, D).
# NOTE(review): `fonts` is not defined anywhere in this notebook — it must come
# from a deleted or out-of-order cell. TODO define it explicitly.
batchsize = 16
# List the style directory once instead of re-scanning it for every font.
style_pngs = [f for f in os.listdir(stylefd) if ".png" in f]

prototypes = {}
# Pure inference: don't build autograd graphs or keep activations alive.
with torch.no_grad():
    for font in tqdm(fonts):
        # NOTE(review): plain substring match — a font whose name is contained
        # in another font's name (e.g. "Foo" vs "FooBold") also picks up that
        # font's files. Confirm the file-naming scheme rules this out.
        protofiles = [f"{stylefd}/{f}" for f in style_pngs if font in f]
        if not protofiles:
            # Previously this surfaced as an opaque error from concatenating an
            # empty list; fail with an actionable message instead.
            raise ValueError(f"no .png files matching font {font!r} in {stylefd}")

        ds = CustomDataset(protofiles, transform=transform)
        embs = []
        for start in range(0, len(ds), batchsize):
            imgs = [ds[i] for i in range(start, min(start + batchsize, len(ds)))]
            # Use the device selected earlier rather than hard-coding .cuda().
            embs.append(style_measurer(torch.stack(imgs).to(device)).cpu())

        proto = torch.concat(embs).mean(0, keepdim=True)
        prototypes[font] = proto

In [None]:
import pickle

# Persist the per-font prototype embeddings (dict: font -> CPU tensor) so that
# later notebooks can reuse them without recomputing.
prototypes_path = "/home/jupyter/ai_font/data/test_ttf/prototypes.pickle"
with open(prototypes_path, "wb") as fh:
    pickle.dump(prototypes, fh)