In [56]:
import os
import random
import copy
from PIL import Image

import pickle
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from google.cloud import storage

storage_client = storage.Client("leo_font")
bucket = storage_client.bucket("leo_font")

from scr import SCR

In [67]:
def save_model(state_dict, savefd, model_name):
    blob = bucket.blob(f"{savefd}/{model_name}.pth")
    with blob.open("wb", ignore_flush=True) as f:
        torch.save(state_dict, f)

In [68]:
# Sample, Positive, Negative. By Style
class SCRDataset(Dataset):

    def __init__(self, path, num_neg=4):
        super().__init__()
        self.path = path
        self.resolution = 96 # default
        self.num_neg = num_neg
        self.letter_mapper_a = pd.read_pickle(f"{path}/pickle/letter_mapper_a.pickle")
        self.letter_mapper_b = pd.read_pickle(f"{path}/pickle/letter_mapper_b.pickle")
        self.font_mapper = pd.read_pickle(f"{path}/pickle/font_mapper.pickle")
        self.letter_mapper_ab = self.letter_mapper_a.similar + self.letter_mapper_b.similar
        self.fonts = self.font_mapper.index
        
        self.transform = transforms.Compose([
            transforms.Resize((96, 96)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        
    def __len__(self):
        return len(self.fonts)
        
    def __getitem__(self, index):
        font = self.fonts[index]
        contents = copy.deepcopy(self.font_mapper.loc[font])
        
        content = contents.pop(random.randint(0, len(contents)-1))
        sample_img_path = f"{self.path}/train/pngs/{font}__{content}.png"
        
        pos_content = contents.pop(random.randint(0, len(contents)-1))
        pos_img_path = f"{self.path}/train/pngs/{font}__{content}.png"
        
        sample_img = self.transform(Image.open(sample_img_path).convert("RGB"))
        pos_img = self.transform(Image.open(pos_img_path).convert("RGB"))
        
        neg_imgs = []
        while len(neg_imgs) < self.num_neg:
            neg_font = random.choice(self.fonts)
            neg_img_path = f"{self.path}/train/pngs/{neg_font}__{content}.png"
            if os.path.exists(neg_img_path) & (font != neg_font):
                neg_imgs.append(self.transform(Image.open(neg_img_path).convert("RGB")))
        
        return sample_img, pos_img, torch.stack(neg_imgs)

In [69]:
scr_ds = SCRDataset(path=path)
scr_dl = DataLoader(scr_ds, shuffle=True, batch_size=64, num_workers=4)

In [70]:
scr_model = SCR()
scr_model = scr_model.cuda()

In [71]:
optimizer = torch.optim.AdamW(scr_model.parameters(), lr=1e-5)

In [72]:
epoch = 1000
savefd = "exp0514/scr"

In [74]:
lossdicts = []
for epoch_i in tqdm(range(epoch)):
    losses = []
    for x in scr_dl:
        optimizer.zero_grad()

        sample_img, pos_img, neg_imgs = x
        sample_emb, pos_emb, neg_emb = scr_model(sample_img.cuda(), pos_img.cuda(), neg_imgs.cuda())
        loss = scr_model.calculate_nce_loss(sample_emb, pos_emb, neg_emb)

        loss.backward()
        optimizer.step()
    save_model(scr_model.state_dict(), savefd, f"scr__{epoch_i}.pth")

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  0%|          | 2/1000 [00:58<8:08:37, 29.38s/it]


KeyboardInterrupt: 