In [1]:
import os
import random
import copy
from PIL import Image

import pickle
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from google.cloud import storage

# Google Cloud Storage handles used for checkpoint uploads.
# The project id and the bucket name are both "leo_font".
storage_client = storage.Client(project="leo_font")
bucket = storage_client.bucket(bucket_name="leo_font")

from scr import SCR

In [2]:
def save_model(state_dict, savefd, model_name):
    """Upload a serialized state dict to the module-level GCS ``bucket``.

    Args:
        state_dict: Object handed to ``torch.save`` (here a model state dict).
        savefd: Folder prefix inside the bucket.
        model_name: Checkpoint name without extension; ".pth" is appended.
    """
    destination = f"{savefd}/{model_name}.pth"
    # ignore_flush=True is passed through to the blob writer unchanged.
    with bucket.blob(destination).open("wb", ignore_flush=True) as sink:
        torch.save(state_dict, sink)

In [3]:
# Sample, Positive, Negative. By Style
class SCRDataset(Dataset):
    """Style-contrastive dataset yielding (sample, positive, negatives) per font.

    Expected layout on disk: ``{path}/{font}/{font}__{content}.png``.
    One dataset item corresponds to one font (style). Each access draws a
    random glyph of that font as the sample, a different random glyph of the
    same font as the positive, and ``num_neg`` random glyphs from other fonts
    as negatives.

    Args:
        path: Root directory that contains one sub-directory per font.
        num_neg: Number of negative images returned per item.
        resolution: Side length (pixels) every image is resized to.
    """

    def __init__(self, path, num_neg=16, resolution=128):
        super().__init__()
        self.path = path
        # Entries whose name contains ".ipy" are ignored (presumably notebook
        # checkpoint folders left behind by Jupyter).
        self.fonts = sorted(f for f in os.listdir(self.path) if ".ipy" not in f)
        # font name -> sorted list of content identifiers (the part of the
        # file name after the last "__", without the ".png" suffix).
        self.fontdict = {}
        for font in self.fonts:
            self.fontdict[font] = sorted(
                f.replace(".png", "").split("__")[-1]
                for f in os.listdir(f"{self.path}/{font}/")
                if f.endswith(".png")
            )
        self.resolution = resolution
        self.num_neg = num_neg

        self.transform = transforms.Compose([
            transforms.Resize((self.resolution, self.resolution)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        """Return the number of fonts (one item per font)."""
        return len(self.fonts)

    def _pick_positive_content(self, font, content):
        """Pick a glyph of ``font`` that differs from ``content`` when possible.

        A single re-draw on collision can still return the sample glyph, so
        the choice is made among the *other* glyphs. If the font has only one
        glyph, that glyph is returned because there is no alternative.
        """
        others = [c for c in self.fontdict[font] if c != content]
        return random.choice(others) if others else content

    def _load_image(self, img_path):
        """Open ``img_path``, convert to RGB and apply the transform.

        The file is opened in a context manager so the handle is closed
        right away instead of lingering until garbage collection.
        """
        with Image.open(img_path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, index):
        """Return ``(sample_img, pos_img, neg_imgs)`` for the font at ``index``.

        ``sample_img`` and ``pos_img`` are ``(3, resolution, resolution)``
        tensors; ``neg_imgs`` stacks ``num_neg`` such tensors along a new
        leading dimension.

        Raises:
            IndexError: If negatives are requested but no other font exists.
        """
        font = self.fonts[index]
        content = random.choice(self.fontdict[font])
        sample_img_path = f"{self.path}/{font}/{font}__{content}.png"

        pos_content = self._pick_positive_content(font, content)
        pos_img_path = f"{self.path}/{font}/{font}__{pos_content}.png"

        sample_img = self._load_image(sample_img_path)
        pos_img = self._load_image(pos_img_path)

        neg_fonts = [f for f in self.fonts if f != font]
        if self.num_neg > 0 and not neg_fonts:
            # Same exception type random.choice would raise, with a clearer message.
            raise IndexError(
                "SCRDataset needs at least two fonts to sample negative images"
            )

        neg_imgs = []
        while len(neg_imgs) < self.num_neg:
            neg_font = random.choice(neg_fonts)
            neg_content = random.choice(self.fontdict[neg_font])
            neg_img_path = f"{self.path}/{neg_font}/{neg_font}__{neg_content}.png"
            neg_imgs.append(self._load_image(neg_img_path))

        return sample_img, pos_img, torch.stack(neg_imgs)

In [4]:
# Run configuration.
# NOTE: despite its name, `epoch` is compared against a per-batch counter in
# the training loop, so it is the total number of optimizer steps.
epoch = 660_000
# Root folder with one sub-directory of glyph PNGs per font.
path = "/home/jupyter/ai_font/data/exp0717/train0730_whole"
# Folder prefix inside the GCS bucket where checkpoints are written.
savefd = "exp0717/scr"

In [5]:
# One dataset item per font; each batch therefore mixes 16 randomly chosen fonts.
scr_ds = SCRDataset(path=path)
scr_dl = DataLoader(
    dataset=scr_ds,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [6]:
# Build the SCR network for 128x128 inputs and move it to the GPU in one step.
scr_model = SCR(image_size=128).cuda()

In [7]:
# AdamW over every SCR parameter with a fixed learning rate of 1e-5.
optimizer = torch.optim.AdamW(params=scr_model.parameters(), lr=1e-5)

In [None]:
# Training loop: runs exactly `epoch` optimizer steps (here `epoch` is a step
# budget, not a number of passes over the data) and uploads a checkpoint
# every 10000 steps.
lossdicts = []  # one {"step", "loss"} record per optimizer step
epoch_count = 0
# The context manager closes the progress bar even if training is interrupted.
with tqdm(total=epoch) as pbar:
    while epoch_count < epoch:
        for sample_img, pos_img, neg_imgs in scr_dl:
            optimizer.zero_grad()

            sample_emb, pos_emb, neg_emb = scr_model(sample_img.cuda(), pos_img.cuda(), neg_imgs.cuda())
            loss = scr_model.calculate_nce_loss(sample_emb, pos_emb, neg_emb)

            loss.backward()
            optimizer.step()

            epoch_count += 1
            pbar.update(1)
            # Previously the list was created but never filled, so the run
            # left no trace of the loss curve.
            lossdicts.append({"step": epoch_count, "loss": loss.item()})

            if epoch_count % 10000 == 0:
                save_model(scr_model.state_dict(), savefd, f"scr__{epoch_count}")

            # The `while` condition is only re-evaluated after a full pass
            # over the dataloader; stop here so the step budget is honoured.
            if epoch_count >= epoch:
                break

  0%|          | 112/660000 [02:55<289:10:53,  1.58s/it]