In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import GPUtil

import src.shared.util as U
import src.shared.augmentation as AUG
#from src.decoder import Decoder
from src.dataset.decode_dataset import DecodeDataset
from src.shared.subdivide import Subdivide
from decoder import Decoder

torch.set_printoptions(sci_mode=False, precision=5)

In [2]:
device = torch.device('cuda')
subdivide = Subdivide().to(device)

patch_root = './data/fitted/512x512/'
emb_file = ['./data/face_emb/', './data/face_emb_hq/',]
n_embs = 3

transform = lambda x: AUG.random_nth(x, 2)
dataset = DecodeDataset(patch_root, emb_file, n_embs=n_embs, suffix='.pth', transform=transform)

patch_mean = dataset.patch_data.mean(dim=0, keepdim=True)

for k, v  in dataset[0].items():
    print(k, v.shape if torch.is_tensor(v) else v)

emb_sizes = (512,) * 3
e_sz = 512

mean = dataset.patch_data.mean(dim=0, keepdim=True)
mean = mean.div(mean.abs().amax(dim=(1, 2, 3), keepdim=True))
mean =  F.interpolate(mean, 32).cuda()
print(mean.shape)

emb, patch = [dataset[1][k][None].cuda() for  k in ('embedding', 'patch')]

p_sz = 8
e_sz = 512
mid_sz = 64
blocks = [
    (1, 1, 2, 2,),
    (3, 3, 4, 4,),
    (5, 6, 7, 8,),
    (10, 12, 14, 16,),
]

decoder = Decoder(16, 512, 192, blocks).cuda()

decoder.load_state_dict(torch.load(f'./data/checkpoint/decoder512-{256}w.pth'))


emb.shape, patch.shape, mean.shape

patch torch.Size([3, 256, 256])
embedding torch.Size([3, 512])
idx 0
torch.Size([1, 3, 32, 32])


(torch.Size([1, 3, 512]),
 torch.Size([1, 3, 256, 256]),
 torch.Size([1, 3, 32, 32]))

In [5]:
emb.shape, mean.shape, [t.shape for t in decoder(emb, mean)]

(torch.Size([1, 3, 512]),
 torch.Size([1, 3, 32, 32]),
 [torch.Size([1, 3, 32, 32]),
  torch.Size([1, 3, 64, 64]),
  torch.Size([1, 3, 128, 128]),
  torch.Size([1, 3, 256, 256])])

In [15]:
"""
PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538)
"""

import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, image_channels, num_filters_last=64, n_layers=3):
        super(Discriminator, self).__init__()

        layers = [nn.Conv2d(image_channels, num_filters_last, 7, 3), nn.LeakyReLU(0.2)]
        num_filters_mult = 1

        for i in range(1, n_layers + 1):
            num_filters_mult_last = num_filters_mult
            num_filters_mult = min(2 ** i, 8)
            layers += [
                nn.Conv2d(num_filters_last * num_filters_mult_last, num_filters_last * num_filters_mult, 4,
                          2 if i < n_layers else 1, 1, bias=False),
                nn.BatchNorm2d(num_filters_last * num_filters_mult),
                nn.LeakyReLU(0.2, True)
            ]

        layers.append(nn.Conv2d(num_filters_last * num_filters_mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

discriminator =  Discriminator(3, 32).to(device)
optimD = torch.optim.AdamW(discriminator.parameters(), lr=0.0002)
discriminator(torch.randn(1, 3, 256,  256).to(device)).shape

torch.Size([1, 1, 19, 19])

In [9]:
p_sz = 8
e_sz = 512
mid_sz = 64
blocks = [    
    #(5,  6,  7,  8,),
    (9, 11, 13, 15, 16,),
]

generator =  Decoder(16, 512, 192, blocks).to(device)
optimG =  torch.optim.AdamW(generator.parameters(), lr=0.0002)

In [12]:
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
# for g in optim.param_groups:
#     #g['lr'] = 0.00005# 075
#     print(g['lr'])

In [16]:
epoch_no = 1_001

keys = ('embedding', 'patch')

loss_fn = F.binary_cross_entropy_with_logits
errD_lst, errG_lst = [], []
for epoch in range(epoch_no):
    for step, batch in enumerate(iter(loader)):
        patch_trg = batch['patch'].to(device)
        emb = batch['embedding'].to(device)        
        with torch.no_grad():
            coarse = decoder(emb, mean)[-1]
            
        fake = generator(emb, coarse)[0]
        
        ratios  = torch.rand(fake.size(0), 1, 1, 1, device=device)
        blend = ratios * patch_trg +  (1-ratios) * fake.detach()
        ##### Update D network: #####
        output = discriminator(blend)
        labels = ratios.expand(-1, -1, output.size(-2), output.size(-2))
        error = loss_fn(output, labels)
        errD_lst.append(error.item())
        
        optimD.zero_grad()
        error.backward()        
        optimD.step()
                                
        ##### Update G network: #####
        
        
        output = discriminator(fake)        
        error = loss_fn(output, torch.zeros_like(output))        
        errG_lst.append(error.item())
        
        optimG.zero_grad()
        error.backward()
        optimG.step()
        
        if step % 50 == 0:
            temperature = GPUtil.getGPUs()[0].temperature
            print(epoch, step, temperature,
                  torch.tensor(errD_lst).mean(), 
                  torch.tensor(errG_lst).mean(),)
            errD_lst, errG_lst = [], []
            if  temperature > 92:
                while temperature > 70:
                    print(f'GPU:{temperature}')
                    time.sleep(10)
                    temperature = GPUtil.getGPUs()[0].temperature

0 0 63.0 tensor(0.70843) tensor(0.75907)
0 50 73.0 tensor(0.52876) tensor(0.55006)
0 100 76.0 tensor(0.51805) tensor(0.48044)
0 150 79.0 tensor(0.50872) tensor(0.41785)
1 0 81.0 tensor(0.51169) tensor(0.35214)
1 50 83.0 tensor(0.50401) tensor(0.30887)
1 100 86.0 tensor(0.50247) tensor(0.27693)
1 150 86.0 tensor(0.50488) tensor(0.24632)
2 0 87.0 tensor(0.50386) tensor(0.23003)
2 50 89.0 tensor(0.49239) tensor(0.22668)
2 100 91.0 tensor(0.50873) tensor(0.23133)
2 150 91.0 tensor(0.50663) tensor(0.21868)
3 0 91.0 tensor(0.50801) tensor(0.20390)
3 50 92.0 tensor(0.50484) tensor(0.19461)
3 100 91.0 tensor(0.50811) tensor(0.18791)
3 150 91.0 tensor(0.49591) tensor(0.18533)
4 0 91.0 tensor(0.50133) tensor(0.17895)
4 50 91.0 tensor(0.50392) tensor(0.17696)
4 100 91.0 tensor(0.50672) tensor(0.17745)
4 150 91.0 tensor(0.49844) tensor(0.18332)
5 0 91.0 tensor(0.50484) tensor(0.19530)
5 50 91.0 tensor(0.50488) tensor(0.20274)
5 100 91.0 tensor(0.50932) tensor(0.23851)
5 150 91.0 tensor(0.50745) te

KeyboardInterrupt: 

In [22]:
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

In [21]:
for i, (pc, ps, pt) in enumerate(zip(coarse, fake, patch_trg)):
    if i < 3:        
        U.export_stl(pc, f'{i}crc')
        U.export_stl(ps, f'{i}src')
        U.export_stl(pt, f'{i}trg')

In [20]:
optimD = torch.optim.AdamW(discriminator.parameters(), lr=0.0001)

In [26]:
patch_trg[:, :, 8:-8, 8:-8] = 0

In [28]:
patch_trg[:, :, 8:-8, 8:-8] .shape

torch.Size([32, 3, 16, 16])