In [13]:
import torch
import os
import numpy as np

from tqdm import tqdm
from options import TrainOptions
from x_dataset import CGGANDataset
from model import CHARACTERModel
from utils import save_model

from google.cloud import storage

In [14]:
# Single options object reused by every later cell: it is handed to the
# dataset, the model, the GCS client and save_model().
opt = TrainOptions()   # get training options

In [15]:
# Google Cloud Storage handle used by the training loop to upload checkpoints.
#
# NOTE(review): the first positional argument of storage.Client is `project`,
# so the bucket name is being used as the GCP project ID here. That only works
# if the two happen to be identical -- confirm, or pass the real project ID.
gcs_bucket_name = opt.bucket_name
storage_client = storage.Client(project=gcs_bucket_name)
bucket = storage_client.bucket(gcs_bucket_name)

In [16]:
# Training data: the project dataset wrapped in a shuffling DataLoader.
dataset = CGGANDataset(args=opt)

loader_kwargs = dict(
    batch_size=opt.batch_size,
    shuffle=True,                      # reshuffled on every new iterator
    sampler=None,
    drop_last=True,                    # discard the trailing incomplete batch
    num_workers=int(opt.num_threads),  # option may not be an int; coerce it
)
train_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

In [17]:
# Build the model from the shared options. The "Param count ..." lines in this
# cell's output are printed while this constructor runs.
model = CHARACTERModel(opt=opt)

Param count for Ds initialized parameters: 19541696
Param count for Ds initialized parameters: 20591296
Param count for Ds initialized parameters: 27289027


In [18]:
# Grab one batch for manual inspection. iter() builds a fresh iterator each
# time, and the loader was created with shuffle=True, so re-running this cell
# yields a different batch.
data = next(iter(train_loader))

In [19]:
# Hand the batch to the model. The following cells read model.img_write and
# model.img_print -- presumably populated by this call; confirm in
# CHARACTERModel.set_input.
model.set_input(data)

In [21]:
# Shape of the tensor that is fed to model.netStyleEncoder in the forward-pass cell.
model.img_write.shape

torch.Size([32, 3, 128, 128])

In [22]:
# Shape of the tensor that is fed to model.netContentEncoder in the forward-pass cell.
model.img_print.shape

torch.Size([32, 3, 128, 128])

In [20]:
# Manual forward pass through the three sub-networks, used to reproduce the
# decoder failure shown in this cell's output outside of optimize_parameters().

# Style branch: three outputs, all forwarded to the decoder.
style_outputs = model.netStyleEncoder(model.img_write)
style_emd, style_fc, residual_features_style = style_outputs

# Content branch: a feature map plus a sequence of residual feature maps.
content_outputs = model.netContentEncoder(model.img_print)
cont, residual_features = content_outputs

# Decoder consumes everything produced above, in this positional order.
decoder_inputs = (cont, residual_features, style_emd, style_fc, residual_features_style)
h = model.netdecoder(*decoder_inputs)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 8 for tensor number 1 in the list.

In [12]:
# Shape of the fifth residual feature map returned by the content encoder.
# NOTE(review): the output shows an 8x8 spatial size while cont/style_emd are
# 4x4 -- presumably this is the "expected 4 but got 8" mismatch reported by
# the decoder; confirm against the decoder's concatenation code.
residual_features[4].shape

torch.Size([32, 512, 8, 8])

In [16]:
# Shape of the first output of model.netContentEncoder (first decoder argument).
cont.shape

torch.Size([32, 1024, 4, 4])

In [21]:
# Shape of the first output of model.netStyleEncoder (third decoder argument).
style_emd.shape

torch.Size([32, 1024, 4, 4])

In [19]:
# Display the first layer of the decoder's first block to see which channel
# counts it expects.
# NOTE(review): its 2048 input channels presumably correspond to cont (1024)
# and style_emd (1024) concatenated along dim 1 -- confirm in the decoder.
model.netdecoder.blocks[0][0]

GBlock(
  (activation): ReLU()
  (conv1): SNConv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): SNConv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv_sc): SNConv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
  (bn1): AdaIN2d(num_features=2048)
  (bn2): AdaIN2d(num_features=512)
)

In [None]:
# Main training loop.
#
# Runs exactly `total_steps` optimizer iterations, cycling through
# `train_loader` as many times as needed. Every `checkpoint_every` iterations
# it uploads one checkpoint per entry of model.get_state_dicts() and steps the
# learning-rate schedule (so the schedule advances per checkpoint interval,
# not per epoch).
total_iters = 0            # number of training iterations completed so far
# NOTE(review): 440000 * 5 reads like "samples per epoch x epochs", but
# total_iters counts batches, not samples -- confirm the intended budget.
total_steps = 440000 * 5
checkpoint_every = 10000   # iterations between checkpoints / LR updates

# The context manager closes the progress bar even if training raises or is
# interrupted; `pbar` stays bound afterwards for inspection.
with tqdm(total=total_steps) as pbar:
    while total_iters < total_steps:    # one pass of this loop == one epoch
        model.train()
        iters_at_epoch_start = total_iters

        for data in train_loader:
            model.set_input(data)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            total_iters += 1
            pbar.update(1)

            if total_iters % checkpoint_every == 0:
                # One file per state dict, tagged with the iteration count.
                for name, state in model.get_state_dicts().items():
                    save_model(opt, bucket, state, f"{name}__{total_iters}.pth")
                newlr = model.update_learning_rate()
                pbar.set_postfix(lr=f"{newlr:.12f}")

            # Stop mid-epoch once the budget is spent; without this check the
            # loop overshoots total_steps until the current epoch finishes.
            if total_iters >= total_steps:
                break

        # With drop_last=True a dataset smaller than one batch yields no
        # batches at all; fail loudly instead of spinning forever.
        if total_iters == iters_at_epoch_start:
            raise RuntimeError(
                "train_loader produced no batches; check dataset size against batch_size"
            )

  0%|          | 1783/2200000 [38:06<817:59:41,  1.34s/it]