In [1]:
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

from src.data.data_util import (
    pyramid_transform, 
    calculate_mean_std,
    loader_generator,
)
from src.data.masked_dataset import MaskedDataset
from src.config import get_parser

from src.models.stylist import Stylist
from src.models.generator import Generator
from src.models.discriminator import Discriminator
from src.render.mesh_renderer import MeshPointsRenderer
from src.utilities.util import grid_to_list
# itr = iter(loader)
# for i in range(len(loader)):
#     batch = next(itr)
#     print(i)

mean = 0.1834
std = 0.2670

In [3]:
config = get_parser().parse_args(args=[])

config.batch_size = 8

root = '/home/bobi/Desktop/db/ffhq-dataset/thumbnails'
ds = ImageFolder(root, transform=pyramid_transform(mean, std))
loader = DataLoader(ds, batch_size=config.batch_size, shuffle=True)
res = next(iter(loader))
res[0]['large'].shape, res[0]['medium'].shape, res[0]['small'].shape

(torch.Size([8, 1, 128, 128]),
 torch.Size([8, 1, 64, 64]),
 torch.Size([8, 1, 32, 32]))

In [4]:
def loader_generator(loader):
    current = 0
    iterator = iter(loader)
    while True:
        if current >= len(loader):
            current = 0
            iterator = iter(loader)
        yield next(iterator)
        current += 1
        
img_gen =  loader_generator(loader)        
img_gen

<generator object image_generator at 0x7fab83cbcc10>

In [5]:
len(loader)

8750

In [6]:
for i in range(9000):
    batch = next(img_gen)
    if i % 100 == 0:
        print(i)

0
100
200
300
400
500
600
700
800
900
1000
1100
1200
1300
1400
1500
1600
1700
1800
1900
2000
2100
2200
2300
2400
2500
2600
2700
2800
2900
3000
3100
3200
3300
3400
3500
3600
3700
3800
3900
4000
4100
4200
4300
4400
4500
4600
4700
4800
4900
5000
5100
5200
5300
5400
5500
5600
5700
5800
5900
6000
6100
6200
6300
6400
6500
6600
6700
6800
6900
7000
7100
7200
7300
7400
7500
7600
7700
7800
7900
8000
8100
8200
8300
8400
8500
8600
8700
8800
8900


In [7]:
batch = next(img_gen)
batch

[{'large': tensor([[[[-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            ...,
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869]]],
  
  
          [[[-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            ...,
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869],
            [-0.6869, -0.6869, -0.6869,  ..., -0.6869, -0.6869, -0.6869]]],
  
  
          [[[-0.6869, -0.6869, -0.6869,  ..., -0.6869

In [8]:
from enum import Enum

class LossType(Enum):
    ADVERSARIAL = 0
    RECONSTRUCTION = 1
    CONTRASTIVE = 2

LossType(2) == LossType.ADVERSARIAL, LossType(2) == LossType.CONTRASTIVE

(False, True)

In [9]:
def get_loss_type(step, epoch):
    return LossType.ADVERSARIAL

In [None]:
epoch_no = 100
step_no = 1000

for epoch in range(epoch_no):
    for step in range(step_no):
        loss_type =  get_loss_type(step, epoch)
        if loss_type == LossType.ADVERSARIAL:
            adversarial_step()
        elif loss_type == LossType.RECONSTRUCTION:
            reconstruction_step()
        elif loss_type == LossType.CONTRASTIVE:
            contrastive_step()    

In [None]:
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.


def adversarial_step():    
    # Generator step        
    batch = provider.adversarial_batch()
    images = batch['large'].to(device)
    points = batch['points'].to(device)
    
    optimizer.zero_grad()
    
    style = stylist(images)
    pts = generator(pts, style)
    # Render
    g_images = mesh_renderer(grid_to_list(pts))
    g_images = g_images[..., :3].mean(-1, keepdim=True)
    g_images = g_images.permute(0, 3, 1, 2)    
    g_images = (g_images - mean) / std
    
    d_output = discriminator(g_images)
    
    ones = torch.ones(d_output.shape[0], device=device)
    g_loss = adversarial_loss(d_output, ones)    
    g_loss.backward()
    optimizer.step()
    
    # Discriminator step
    batch = provider.adversarial_batch()
    images = batch['large'].to(device)
    points = batch['points'].to(device)
    
    
    


def reconstruction_step():
    pass

def contrastive_step():
    pass