In [1]:
import torch
import os
import numpy as np

from tqdm import tqdm
from options import TrainOptions
from x_dataset import CGGANDataset
from model import CHARACTERModel
from utils import save_model

from google.cloud import storage

In [2]:
# Training options for this run. The same `opt` object is read by the storage
# setup (opt.bucket_name), the data loader (opt.batch_size, opt.num_threads),
# the dataset and model constructors, and save_model().
opt = TrainOptions()

In [3]:
# Google Cloud Storage handles; `bucket` is what the training loop hands to
# save_model() when it writes checkpoints.
# NOTE(review): storage.Client's first argument is the *project* ID, so the
# bucket name is being used as the project here -- confirm that is intended.
storage_client = storage.Client(project=opt.bucket_name)
bucket = storage_client.bucket(opt.bucket_name)

In [4]:
# Dataset built from the run options, wrapped in a shuffling batched loader.
dataset = CGGANDataset(args=opt)
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,                       # default sampler is used (none is supplied)
    drop_last=True,                     # skip the trailing incomplete batch
    num_workers=int(opt.num_threads),   # option may not be an int already, hence the cast
)

In [5]:
# Build the model from the same options object that configured the dataset
# and the data loader.
model = CHARACTERModel(opt=opt)

Param count for Ds initialized parameters: 19541696
Param count for Ds initialized parameters: 20591296
Param count for Ds initialized parameters: 27289027


In [None]:
# Training schedule. Progress is measured in optimizer steps, not epochs: the
# loader is re-iterated as many times as needed to reach `total_steps`.
total_iters = 0                # optimizer steps completed so far
total_steps = 440000*5         # step budget for the whole run
CHECKPOINT_EVERY = 10000       # steps between checkpoint saves / learning-rate updates


def _save_checkpoint(step):
    """Save every state dict exposed by the model, suffixing each name with `step`."""
    for name, state in model.get_state_dicts().items():
        save_model(opt, bucket, state, f"{name}__{step}")


# The context manager closes the progress bar even if training raises.
with tqdm(total=total_steps) as pbar:
    while total_iters < total_steps:    # one pass of this loop == one epoch over train_loader
        model.train()
        iters_at_epoch_start = total_iters

        for data in train_loader:
            model.set_input(data)         # hand the batch to the model
            model.optimize_parameters()   # one optimization step

            # Checkpointing and the learning-rate update share the same
            # step-based cadence (this also fires on the very first step).
            if total_iters % CHECKPOINT_EVERY == 0:
                _save_checkpoint(total_iters)
                newlr = model.update_learning_rate()
                pbar.set_postfix(lr=f'{newlr:.12f}')

            pbar.update(1)
            total_iters += 1

            # Stop exactly at the budget instead of running on to the end of
            # the current epoch (which would also overrun the progress bar).
            if total_iters >= total_steps:
                break

        # A loader that yields nothing (e.g. fewer samples than one batch with
        # drop_last=True) would otherwise make the outer loop spin forever.
        if total_iters == iters_at_epoch_start:
            raise RuntimeError("train_loader produced no batches; cannot make progress")

# The periodic save runs before the step counter is incremented, so the
# weights after the final step have not been written yet -- save them now.
_save_checkpoint(total_iters)
        

 14%|█▍        | 303276/2200000 [108:34:17<651:54:22,  1.24s/it, lr=0.000069306931] 