In [1]:
# Re-import edited project modules automatically before each cell runs,
# so changes to the imported packages are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from torch import optim
from torch.utils.data import DataLoader

from data.datasets.ffhq_dataset import FFHQDataset
from data.datasources.ffhq_datasource import FFHQDatasource
from data.datasources.golden_age_face_datasource import GoldenAgeFaceDatasource
from functional.losses.bi_discriminator_loss import BidirectionalDiscriminatorLoss, BidirectionalDiscriminatorLossType
from networks.bigan import BiGAN
from training.bigan_trainer import BiGANTrainer
from utils.config_utils import read_config, Config
from utils.logging_utils import *
from utils.plot_utils import *

from data.datasets import facedataset
from data.datasources import facedatasource
from data.datasources.datasource_mode import DataSourceMode
from networks.siamese_network import SiameseNetwork
from functional.losses.contrastive_loss import ContrastiveLoss
from functional.metrics.dissimilarity import *
from training.face_recognition_trainer import train_epochs
from configs.base_config import *

In [3]:
def save_best_loss_model(model_name, model, best_loss, results_dir=None):
    """Log a new best loss and checkpoint ``model`` to ``<results_dir>/<model_name>.pth``.

    Used below as the ``best_loss_action`` callback handed to ``BiGANTrainer``.

    Args:
        model_name: File stem of the checkpoint; ``.pth`` is appended.
        model: Object passed to ``torch.save`` (the notebook passes the whole
            network, so it is pickled rather than saved as a state dict).
        best_loss: Loss value that triggered the save; it is only logged.
        results_dir: Directory to write into. Defaults to
            ``base_dir + 'playground/bigan/results/'``, the original location.
    """
    import os

    if results_dir is None:
        # base_dir is not defined in this notebook; presumably it comes from
        # `from configs.base_config import *` — confirm there.
        results_dir = base_dir + 'playground/bigan/results/'
    # torch.save does not create missing parent directories, so the first
    # checkpoint on a fresh checkout would otherwise fail.
    os.makedirs(results_dir, exist_ok=True)
    logging.info('current best loss: %s', best_loss)
    torch.save(model, os.path.join(results_dir, model_name + '.pth'))

In [None]:
# Read the Golden Age face config, then build the TRAIN-mode datasource from it.
golden_age_config = read_config(Config.GOLDEN_AGE_FACE)
datasource = GoldenAgeFaceDatasource(
    golden_age_config,
    mode=DataSourceMode.TRAIN,
)

reading image: 0
reading image: 512
reading image: 1024
reading image: 1536
reading image: 2048
reading image: 2560
reading image: 3072
reading image: 3584
reading image: 4096
reading image: 4608
reading image: 5120
reading image: 5632
reading image: 6144
reading image: 6656
reading image: 7168
reading image: 7680
reading image: 8192
reading image: 8704
reading image: 9216
reading image: 9728
reading image: 10240
reading image: 10752
reading image: 11264
reading image: 11776
reading image: 12288
reading image: 12800
reading image: 13312
reading image: 13824
reading image: 14336
reading image: 14848
reading image: 15360
reading image: 15872
reading image: 16384
reading image: 16896
reading image: 17408
reading image: 17920
reading image: 18432
reading image: 18944
reading image: 19456
reading image: 19968
reading image: 20480
reading image: 20992
reading image: 21504
reading image: 22016
reading image: 22528
reading image: 23040
reading image: 23552
reading image: 24064
reading image: 2

In [5]:
# Counts forwarded to get_additional_data. NOTE(review): their exact meaning is
# defined in GoldenAgeFaceDatasource.get_additional_data — confirm there.
ADDITIONAL_FROM_IMAGE_COUNT = 70000
ADDITIONAL_IMAGE_COUNT = 500

annot_path = golden_age_config.annotations_folder_path
additional_data = datasource.get_additional_data(
    annot_path=annot_path,
    from_image_count=ADDITIONAL_FROM_IMAGE_COUNT,
    additional_image_count=ADDITIONAL_IMAGE_COUNT,
)

reading image: 0


In [10]:
# Append the additional images to the training data exactly once. Without the
# guard, re-running this cell would concatenate additional_data again and
# silently duplicate samples. The marker lives on the datasource object, so
# rebuilding `datasource` (earlier cell) resets it.
if not getattr(datasource, '_additional_data_appended', False):
    datasource.data = np.concatenate((datasource.data, additional_data), axis=0)
    datasource._additional_data_appended = True

In [1]:
# BiGAN hyper-parameters, and a shuffled loader over the (merged) datasource.
train_config = read_config(Config.BiGAN)
train_dataset = FFHQDataset(datasource=datasource)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_config.batch_size,
    shuffle=True,
)

NameError: name 'read_config' is not defined

In [None]:
def _linear_lr_decay(total_epochs):
    """Build a ``LambdaLR`` multiplier that falls linearly from 1 to 0 over ``total_epochs``.

    Args:
        total_epochs: Number of scheduler steps over which the factor reaches 0.

    Returns:
        A function ``epoch -> factor`` suitable for ``torch.optim.lr_scheduler.LambdaLR``.
    """

    def factor(epoch):
        # Clamp at 0: if the scheduler is ever stepped more than total_epochs
        # times, the unclamped formula would give a negative learning rate.
        return max(0.0, (total_epochs - epoch) / total_epochs)

    return factor


def train_golden_age_face_bigan(model_name='test_model'):
    """Train a BiGAN on the Golden Age face data and save a loss plot.

    Depends on notebook globals that must already exist: ``golden_age_config``
    (for ``image_dim``), ``train_config`` (optimizer and schedule settings) and
    ``train_dataloader``.

    Args:
        model_name: File stem passed to ``save_best_loss_model`` each time the
            trainer reports a new best loss.

    Returns:
        The ``BiGAN`` network after ``trainer.train_bigan()`` has run.

    Note:
        The loss plot is always written to
        ``base_dir + 'playground/bigan/results/bigan_plot.png'``. The name does
        not include ``model_name``, so each run overwrites the previous plot.
    """
    logging.info("initiate training")
    net = BiGAN(image_dim=golden_age_config.image_dim).to(ptu.device)
    criterion = BidirectionalDiscriminatorLoss(loss_type=BidirectionalDiscriminatorLossType.VANILLA_LOG_MEAN)

    # Two optimizers: one for the discriminator, and one shared by the encoder
    # and generator, each with its own hyper-parameters.
    d_optimizer = torch.optim.Adam(net.discriminator.parameters(),
                                   lr=train_config.discriminator_lr,
                                   betas=(train_config.discriminator_beta_1, train_config.discriminator_beta_2),
                                   weight_decay=train_config.discriminator_weight_decay)

    g_optimizer = torch.optim.Adam(list(net.encoder.parameters()) + list(net.generator.parameters()),
                                   lr=train_config.generator_lr,
                                   betas=(train_config.generator_beta_1, train_config.generator_beta_2),
                                   weight_decay=train_config.generator_weight_decay)

    # Both learning rates follow the same linear decay over the configured epochs.
    lr_decay = _linear_lr_decay(train_config.train_epochs)
    g_scheduler = torch.optim.lr_scheduler.LambdaLR(g_optimizer, lr_decay, last_epoch=-1)
    d_scheduler = torch.optim.lr_scheduler.LambdaLR(d_optimizer, lr_decay, last_epoch=-1)

    trainer = BiGANTrainer(model=net,
                           criterion=criterion,
                           train_loader=train_dataloader,
                           test_loader=None,
                           epochs=train_config.train_epochs,
                           optimizer_generator=g_optimizer,
                           optimizer_discriminator=d_optimizer,
                           scheduler_gen=g_scheduler,
                           scheduler_disc=d_scheduler,
                           best_loss_action=lambda m, l: save_best_loss_model(model_name, m, l))
    losses = trainer.train_bigan()

    logging.info("completed training")
    save_training_plot(losses['discriminator_loss'],
                       losses['generator_loss'],
                       "Golden Age Face BiGAN Losses",
                       base_dir + 'playground/bigan/results/bigan_plot.png')
    return net

In [None]:
# Enable GPU mode via the ptu helper, then train. get_dt_string() presumably
# returns a date-time stamp, giving each run its own checkpoint name; verify.
ptu.set_gpu_mode(True)
model_name = get_dt_string() + "_model"
model = train_golden_age_face_bigan(model_name)
