In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import optim

import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import os
from utils.svg import SVG
from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
import pandas as pd
from torchsummary import summary

import utils.dataloader as dl
from svglib.svglib import svg2rlg
from reportlab.graphics import renderPM
import shutil
import optuna
from optuna.trial import TrialState

from IPython import display
%matplotlib inline

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

SVG.ENCODE_HEIGHT = 80
# fonts_number = 100
fonts_number = None

print(f'Device: {device}')

Device: cuda


## Загрузка данных

In [3]:
print('Loading data')
dl.load_data(fonts_number)

Loading data
<################################################################################>: 100.% [15147 / 15147]


In [4]:
print('Encoding data')
stored_path = Path(f'data/data_{SVG.ENCODE_HEIGHT}_{fonts_number}.json')
if stored_path.exists():
    data = pd.read_json(str(stored_path))
    # data['data'] = data['data'].apply(lambda x: np.array(x))
else:
    data = dl.get_data(fonts_number)
    data.to_json(str(stored_path))
data

Encoding data


Unnamed: 0,font,letter,data
0,!crass_roots_ofl,a,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
1,!crass_roots_ofl,b,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
2,!crass_roots_ofl,c,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
3,!crass_roots_ofl,d,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
4,!crass_roots_ofl,e,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
...,...,...,...
337538,çarsi,v,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
337539,çarsi,w,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
337540,çarsi,x,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."
337541,çarsi,y,"[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,..."


## Определение даталоадеров
`Dataloader` для букв

`DataloaderRows` для линий

In [5]:
class Dataloader:
    def __init__(self, df: pd.DataFrame, test_size=0.1, shuffle=False, batch_size=24):
        if shuffle:
            df = df.sample(frac=1).reset_index(drop=True)
        xs = np.array(df['data'].to_list(), dtype=np.float32)
        ys = df['letter'].to_numpy()
        self.x_train, self.x_test, self.y_train, self.y_test = train_test_split(xs, ys, test_size=test_size, shuffle=shuffle)
        self.batch_size = batch_size

    def iterate(self):
        bs = self.batch_size
        for i in range(len(self.x_train) // bs):
            yield self.x_train[i * bs: (i + 1) * bs], self.y_train[i * bs: (i + 1) * bs]

    def iterate_test(self):
        yield self.x_test, self.y_test

    def __len__(self):
        return len(self.x_train) // self.batch_size + int(len(self.x_train) % self.batch_size > 0)
    
class DataloaderRows:
    def __init__(self, df: pd.DataFrame, test_size=0.1, shuffle=False, batch_size=24):
        if shuffle:
            df = df.sample(frac=1).reset_index(drop=True)
        xs = np.array(df['data'].to_list(), dtype=np.float32)
        xs = xs.reshape((-1, SVG.ENCODE_WIDTH))
        self.x_train, self.x_test = train_test_split(xs, test_size=test_size, shuffle=shuffle)
        self.batch_size = batch_size

    def iterate(self):
        bs = self.batch_size
        for i in range(len(self.x_train) // bs):
            yield self.x_train[i * bs: (i + 1) * bs]

    def iterate_test(self):
        yield self.x_test

    def __len__(self):
        return len(self.x_train) // self.batch_size + int(len(self.x_train) % self.batch_size > 0)
    

In [6]:
dataloader = Dataloader(data, test_size=0.15, shuffle=True)
dataloader_rows = DataloaderRows(data, test_size=0.15, shuffle=True)

In [7]:
def save_sampled(x: np.ndarray, name, close=True):
    file = Path('imgs') / name
    file.parent.mkdir(parents=True, exist_ok=True)
    svg = SVG.decode(x, path=file)
    svg.dump_to_file()

In [8]:
save_sampled(dataloader.x_test[3], 'test_print.svg', close=False)

## Определение энкодеров
`AE` - автоэнкодер

`CAE` - автоэнкодер с условием

In [9]:
class BnAndDropout(nn.Module):
    def __init__(self, features, p=0.15):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features=features)
        self.do = nn.Dropout(p)
    
    def forward(self, x):
        return self.do(self.bn(x))


class Block(nn.Module):
    def __init__(self, f_in, f_out):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(f_in, f_out),
            nn.Tanh(),
            # BnAndDropout(f_out),
        )

    def forward(self, x):
        return self.layer(x)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        # TODO: Double check the ordering here
        return embeddings

In [10]:
class AE(nn.Module):
    def __init__(self, params):
        super().__init__()
        
        self.encode = True
        self.decode = True
        
        self.encoder = []
        for i in range(len(params) - 1):
            self.encoder.append(Block(params[i], params[i + 1]))
        
        self.decoder = []
        for i in range(len(params) - 1, 0, -1):
            self.decoder.append(Block(params[i], params[i - 1]))
        
        self.encoder = nn.Sequential(*self.encoder)
        self.decoder = nn.Sequential(*self.decoder)
        
    def forward(self, x):
        if self.encode:
            x = self.encoder(x)
        if self.decode:
            x = self.decoder(x)
        return x

    def loss(self):
        def _inner(y_hat, y):
            return ((y - y_hat)**2).mean(axis=0).sum()

        return _inner


class CAE(nn.Module):
    def __init__(self, params, in_labels):
        super().__init__()
        
        encoder_params = params[:]
        encoder_params[0] += in_labels
        
        decoder_params = params[::-1]
        decoder_params[0] += in_labels

        self.encoder = []
        for i in range(len(encoder_params) - 1):
            self.encoder.append(Block(encoder_params[i], encoder_params[i + 1]))
        
        self.decoder = []
        self.decoder.append(Block(decoder_params[0], decoder_params[1]))
        for i in range(1, len(decoder_params) - 1):
            self.decoder.append(Block(decoder_params[i] * 2, decoder_params[i + 1]))

        self.encoder = nn.ParameterList(self.encoder)
        self.decoder = nn.ParameterList(self.decoder)
        
    def forward(self, x, labels):
        shape = x.shape
        
        x = x.view(shape[0], -1)
        x = torch.cat((x, labels), 1)
        accumulate = []
        for layer in self.encoder:
            x = layer(x)
            accumulate.append(x)
        accumulate.pop()
        
        x = torch.cat((x, labels), 1)
        x = self.decoder[0](x)
        
        for layer in self.decoder[1:]:
            connection = accumulate.pop()
            x = torch.cat((x, connection), 1)
            x = layer(x)
        x = x.view(shape[0], shape[1], -1)
        
        return x
    
    def loss(self):
        def _inner(y_hat, y):
            return ((y - y_hat)**2).mean(axis=0).sum()

        return _inner

## Вспомогательные функции

In [11]:
ONE_HOT_LEN = len(dl.GLYPH_FILTER)
one_hot_rules = {
    glyph: one_hot
    for glyph, one_hot in zip(
        dl.GLYPH_FILTER, 
        range(0, ONE_HOT_LEN),
    )
}

def labels2num(labels):
    return torch.Tensor([one_hot_rules[i] for i in labels]).long()

def labels2one_hot(labels):
    return F.one_hot(labels2num(labels), num_classes=ONE_HOT_LEN)

In [12]:
interval = len(dataloader) / 6

train_ts, train_loss = [], []
test_ts, test_loss = [], []


def show_progress(t, epochs, save_to=None, info: dict | None = None):
    display.clear_output(wait=True)
    fig, (ax1, ax2) = plt.subplots(2, 1, constrained_layout=True, figsize=(12, 10))
    fig.suptitle(f'Epoch {t:3.3f} / {epochs}', fontsize=16)
    
    last_size = 0.5
    
    for ax, msg in zip((ax1, ax2), ('', f'last {int(last_size*100)}%')):
        title = f'loss {msg}'
        if info is not None:
            title += ' | ' + ' | '.join(f'{key}: {value}' for key, value in info.items())
        ax.set_title(title)
        ax.set_xlabel('time (epochs)')
        ax.set_ylabel('loss')
    last_train = str(train_loss[-1]) if len(train_loss) > 0 else ''
    last_test = str(test_loss[-1]) if len(test_loss) > 0 else ''
    
    ax1.plot(train_ts, train_loss, c='darkblue', lw=3, label=f'train: {last_train}')
    ax1.plot(test_ts, test_loss, c='green', marker='o', lw=5, label=f'test: {last_test}')
    
    ax2.plot(train_ts[-int(len(train_ts)*last_size):], train_loss[-int(len(train_loss)*last_size):], c='darkblue', lw=3, label=f'train: {last_train}')
    ax2.plot(test_ts[-int(len(test_ts)*last_size):], test_loss[-int(len(test_loss)*last_size):], c='green', marker='o', lw=5, label=f'test: {last_test}')
    
    ax1.legend()
    ax2.legend()
    if save_to is None:
        plt.show() 
    else:
        plt.savefig(save_to)
        plt.close()
    
def train_cae(epoch, epochs, dataloader, model, loss_fn, optimizer, scheduler, pbar=None, show=True):
    model.train()
    num_batches = len(dataloader)
    for batch, (inp_data, labels) in enumerate(dataloader.iterate()):
        inp_data = torch.Tensor(inp_data).to(device)        
        labels = labels2one_hot(labels).to(device)

        output = model(inp_data, labels)
        loss = loss_fn(output, inp_data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % interval == 0:
            t = epoch + (batch + 1) / num_batches
            train_ts.append(t)
            train_loss.append(loss.item())
            if show:                
                show_progress(t, epochs, info={'lr': scheduler.get_last_lr()[0]})
            if pbar is not None:
                pbar.refresh()
    scheduler.step()        
    
def test_cae(epoch, epochs, dataloader, model, loss_fn, show=True):
    model.eval()
    num_batches = len(dataloader)
    tmp_test_loss = []
    with torch.no_grad():
        for inp_data, labels in dataloader.iterate_test():
            inp_data = torch.Tensor(inp_data).to(device)
            labels = labels2one_hot(labels).to(device)

            result = model(inp_data, labels)
            loss = loss_fn(result, inp_data)

            tmp_test_loss.append(loss.item())
            
    test_ts.append(epoch)
    test_loss.append(np.mean(tmp_test_loss))
    if show:
        show_progress(epoch, epochs)


In [13]:
def train_ae(epoch, epochs, dataloader, model, loss_fn, optimizer, scheduler, pbar=None, show=True):
    model.train()
    num_batches = len(dataloader)
    for batch, inp_data in enumerate(dataloader.iterate()):
        inp_data = torch.Tensor(inp_data).to(device)

        output = model(inp_data)
        loss = loss_fn(output, inp_data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % interval == 0:
            t = epoch + (batch + 1) / num_batches
            train_ts.append(t)
            train_loss.append(loss.item())
            if show:                
                show_progress(t, epochs, info={'lr': scheduler.get_last_lr()[0]})
            if pbar is not None:
                pbar.refresh()
    scheduler.step()        
    
def test_ae(epoch, epochs, dataloader, model, loss_fn, show=True):
    model.eval()
    num_batches = len(dataloader)
    tmp_test_loss = []
    with torch.no_grad():
        for images in dataloader.iterate_test():
            images = torch.Tensor(images).to(device)

            decoded = model(images)
            loss = loss_fn(decoded, images)

            tmp_test_loss.append(loss.item())
            
    test_ts.append(epoch)
    test_loss.append(np.mean(tmp_test_loss))
    if show:
        show_progress(epoch, epochs)

In [14]:
epoch = 0

def run_maker(model_type):
    is_cae =  model_type == 'cae'
    model_char = 'c' if is_cae else ''
    def _inner(model, dataloader, optimizer, scheduler, epochs, params, batch_size, _epoch=0, run_name=None, trial=None):
        global epoch

        if run_name is None:
            run_name = f'run_size{SVG.ENCODE_HEIGHT}_{model_char}ae_{",".join(map(str, params))}'

        save_folder = Path(f'models_{model_char}ae') / run_name
        loss_img_path = str(save_folder / '_loss.png')
        save_folder.mkdir(parents=True, exist_ok=True)
        plt.clf()
        file_format = 'svg'
        max_epoch = epoch + epochs
        loss_fn = model.loss()
        dataloader.batch_size = batch_size

        if trial is None:
            pbar = trange(epoch, max_epoch)
            rng = pbar
        else:
            pbar = None
            rng = range(epoch, max_epoch)
        for _epoch in rng:
            
            if is_cae:
                train_fn = train_cae
                test_fn = test_cae
            else:
                train_fn = train_ae
                test_fn = test_ae
                
            train_fn(_epoch, max_epoch, dataloader, model, loss_fn, optimizer, scheduler, pbar, show=trial is None)
            test_fn(_epoch + 1, max_epoch, dataloader, model, loss_fn, show=trial is None)
                
            torch.save(model.state_dict(), save_folder / 'ckpt.pt')

            model.eval()

            if trial is not None:
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()
            else:
                # if is_cae:
#                     with torch.no_grad():
#                         num = np.random.randint(0, len(dataloader.x_test))

#                         image = torch.tensor(dataloader.x_test[num].reshape((1, SVG.ENCODE_HEIGHT, LINE_WIDTH))).to(device)
#                         label = labels2one_hot([dataloader.y_test[num]]).to(device)
#                         sample = model(image, label)[0]
#                         save_sampled(dataloader.x_test[num], f'{run_name}/test/{_epoch}_{dataloader.y_test[num]}_orig.{file_format}')
#                         save_sampled(sample.cpu().detach().numpy(), f'{run_name}/test/{_epoch}_{dataloader.y_test[num]}_gen.{file_format}')

#                         image = torch.tensor(dataloader.x_train[num].reshape((1, SVG.ENCODE_HEIGHT, SVG.ENCODE_WIDTH))).to(device)
#                         label = labels2one_hot([dataloader.y_train[num]]).to(device)
#                         sample = model(image, label)[0]
#                         save_sampled(dataloader.x_train[num], f'{run_name}/train/{_epoch}_{dataloader.y_test[num]}_orig.{file_format}')
#                         save_sampled(sample.cpu().detach().numpy(), f'{run_name}/train/{_epoch}_{dataloader.y_test[num]}_gen.{file_format}')
                epoch = _epoch + 1
            show_progress(_epoch + 1, max_epoch, loss_img_path, info={'params':params, 'lr': scheduler.get_last_lr()[0]})

    return _inner


run_cae = run_maker('cae')
run_ae = run_maker('ae')

In [15]:
def setup_maker(model_type):
    is_cae =  model_type == 'cae'
    model_char = 'c' if is_cae else ''

    def _inner(params, lr, weight_decay=2e-5, step_size=1):
        global train_ts, train_loss, test_ts, test_loss, epoch

        if is_cae:
            model = CAE(params, ONE_HOT_LEN).to(device)
        else:
            model = AE(params).to(device)
            
        loss_fn = model.loss()
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.95)

        train_ts, train_loss = [], []
        test_ts, test_loss = [], []
        epoch = 0

        return model, optimizer, scheduler, loss_fn

    return _inner
    
setup_cae = setup_maker('cae')
setup_ae = setup_maker('ae')

## Поиск конфигурации модели для кодирования линий

In [95]:
def run_for_search(trial):
    lr = trial.suggest_float('lr', 1e-6, 1e-3, log=True)
    params = [
        SVG.ENCODE_WIDTH,
        trial.suggest_int('p1', 10, 25),
        trial.suggest_int('p2', 8, 20),
        trial.suggest_int('p3', 4, 10),
    ]

    for i in range(3):
        params.append(trial.suggest_int(f'p{i}', 4, 20))
    
    model, optimizer, scheduler, loss_fn = setup_ae(
        params=params,
        lr=lr,
        weight_decay=5e-5,
    )
    run_ae(
        model=model, 
        dataloader=dataloader_rows, 
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=4,
        params=params,
        batch_size=2048,
        run_name='temp',
        trial=trial,
    )
    return test_loss[-1] + max(test_loss[-1] - train_loss[-1], 0) ** 2 


study = optuna.create_study(
    direction="minimize",
    storage="sqlite:///db.sqlite3",
    study_name="ae_lr_3_layers_4_epoch"
)

study.optimize(run_for_search, n_trials=50, show_progress_bar=True)

[32m[I 2023-02-18 19:23:50,745][0m Trial 49 finished with value: 9.119536116486415e-05 and parameters: {'lr': 0.00013387223829575858, 'p0': 20, 'p1': 17, 'p2': 14}. Best is trial 44 with value: 6.124428604674469e-05.[0m


<Figure size 640x480 with 0 Axes>

## Обучаем кодировщик линий

In [16]:
LINE_WIDTH = 10
params = [SVG.ENCODE_WIDTH, 19, 15, LINE_WIDTH]

model_rows, optimizer_rows, scheduler_rows, loss_fn_rows = setup_ae(
    params=params,
    lr=6e-4,
    weight_decay=3e-5,
)

In [17]:
run_ae(
    model=model_rows, 
    dataloader=dataloader_rows,
    optimizer=optimizer_rows,
    scheduler=scheduler_rows,
    epochs=3,
    params=params,
    batch_size=1024,
)

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [03:32<00:00, 70.83s/it]


## Кодируем все линии в датасете

In [18]:
model_rows.decode = False
model_rows.eval()
encoded = []
with torch.no_grad():
    to_encode = np.array(data['data'].to_list(), dtype=np.float32)
    for batch in tqdm(to_encode):
        batch = torch.Tensor(batch).to(device)
        encoded.append(model_rows(batch).cpu().detach().numpy())
        
enc_data = data.copy()
enc_data['data'] = encoded

enc_dataloader = Dataloader(enc_data, test_size=0.15, shuffle=True)

100%|████████████████████████████████████████████████████████████████████████| 337543/337543 [02:46<00:00, 2025.09it/s]


In [274]:
params = [SVG.ENCODE_HEIGHT * LINE_WIDTH, 2780, 1820, 1260, 1104, 865, 350]

model, optimizer, scheduler, loss_fn = setup_cae(
    params=params,
    lr=2e-4,
    weight_decay=6e-6,
    step_size=3,
)

In [276]:
run_cae(
    model=model, 
    dataloader=enc_dataloader, 
    optimizer=optimizer, 
    scheduler=scheduler, 
    epochs=1,
    params=params,
    batch_size=128,
    # run_name=f'run_relative_01_tanh_size{SVG.ENCODE_HEIGHT}_{"c" if model.vae else ""}vae_{",".join(map(str, params))}',
    run_name=f'run_encoded_tanh_size{SVG.ENCODE_HEIGHT}_ucae_{",".join(map(str, params))}',
)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:40<00:00, 40.54s/it]
