In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import GPUtil

import src.shared.util as U
import src.shared.augmentation as AUG
#from src.decoder import Decoder
from src.dataset.decode_dataset import DecodeDataset
from src.shared.subdivide import Subdivide
from decoder import Decoder

torch.set_printoptions(sci_mode=False, precision=5)

In [2]:
# --- Setup: dataset, normalised mean patch, one sample, and the pretrained decoder ---
device = torch.device('cuda')
subdivide = Subdivide().to(device)

patch_root = './data/fitted/512x512/'
emb_file = ['./data/face_emb/', './data/face_emb_hq/',]
n_embs = 3

transform = lambda x: AUG.random_nth(x, 2)
dataset = DecodeDataset(patch_root, emb_file, n_embs=n_embs, suffix='.pth', transform=transform)

# Mean over the first axis of the raw patch data; computed once and reused below.
patch_mean = dataset.patch_data.mean(dim=0, keepdim=True)

# Show what a single dataset item contains (tensor shapes, or the raw value otherwise).
for k, v in dataset[0].items():
    print(k, v.shape if torch.is_tensor(v) else v)

emb_sizes = (512,) * 3
e_sz = 512

# Scale the mean so its largest absolute value is 1, then resize it to 32x32 on the target device.
mean = patch_mean.div(patch_mean.abs().amax(dim=(1, 2, 3), keepdim=True))
mean = F.interpolate(mean, 32).to(device)
print(mean.shape)

# Index the dataset ONCE: `transform` is random, so fetching the item separately per key
# could pair an embedding and a patch that come from two different draws.
sample = dataset[1]
emb, patch = [sample[k][None].to(device) for k in ('embedding', 'patch')]

p_sz = 8
mid_sz = 64
blocks = [
    (1, 1, 2, 2,),
    (3, 3, 4, 4,),
    (5, 6, 7, 8,),
    (10, 12, 14, 16,),
]

decoder = Decoder(16, 512, 192, blocks).to(device)

# map_location keeps the load independent of the device the checkpoint was saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint_path = './data/checkpoint/decoder512-256w.pth'
decoder.load_state_dict(torch.load(checkpoint_path, map_location=device))


emb.shape, patch.shape, mean.shape

patch torch.Size([3, 256, 256])
embedding torch.Size([3, 512])
idx 0
torch.Size([1, 3, 32, 32])


(torch.Size([1, 3, 512]),
 torch.Size([1, 3, 256, 256]),
 torch.Size([1, 3, 32, 32]))

In [5]:
# Shape check: the embedding, the resized mean patch, and every tensor the decoder returns for them.
(emb.shape, mean.shape, [out.shape for out in decoder(emb, mean)])

(torch.Size([1, 3, 512]),
 torch.Size([1, 3, 32, 32]),
 [torch.Size([1, 3, 32, 32]),
  torch.Size([1, 3, 64, 64]),
  torch.Size([1, 3, 128, 128]),
  torch.Size([1, 3, 256, 256])])

In [3]:
"""
PatchGAN Discriminator (https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py#L538)
"""

import torch.nn as nn


class Discriminator(nn.Module):
    """Convolutional patch discriminator that maps an image to a grid of raw scores.

    Layout, in order:
      * stem: 7x7 convolution, stride 3, no padding, followed by LeakyReLU(0.2);
      * ``n_layers`` blocks of 4x4 convolution (padding 1, no bias) + BatchNorm2d +
        in-place LeakyReLU(0.2). Every block uses stride 2 except the last, which
        uses stride 1. The channel multiplier doubles per block and is capped at 8;
      * head: 4x4 convolution, stride 1, padding 1, down to a single channel.

    No activation follows the head, so the output is unnormalised.

    Args:
        image_channels: number of channels of the input image.
        num_filters_last: channel count produced by the stem (base width).
        n_layers: number of conv/norm/activation blocks after the stem.
    """

    def __init__(self, image_channels, num_filters_last=64, n_layers=3):
        super().__init__()

        base = num_filters_last
        stack = [
            nn.Conv2d(image_channels, base, 7, 3),
            nn.LeakyReLU(0.2),
        ]

        in_mult = 1
        for depth in range(1, n_layers + 1):
            out_mult = min(2 ** depth, 8)
            # Only the final block keeps the spatial resolution (stride 1).
            stride = 1 if depth == n_layers else 2
            stack.extend((
                nn.Conv2d(base * in_mult, base * out_mult, 4, stride, 1, bias=False),
                nn.BatchNorm2d(base * out_mult),
                nn.LeakyReLU(0.2, True),
            ))
            in_mult = out_mult

        stack.append(nn.Conv2d(base * in_mult, 1, 4, 1, 1))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the single-channel score map."""
        scores = self.model(x)
        return scores

# Discriminator for 3-channel input with a base width of 32, plus its own AdamW optimiser.
discriminator = Discriminator(image_channels=3, num_filters_last=32).to(device)
optimD = torch.optim.AdamW(discriminator.parameters(), lr=0.0002)

# Probe with one random 3x256x256 input to see the size of the score map.
probe = torch.randn(1, 3, 256, 256).to(device)
discriminator(probe).shape

torch.Size([1, 1, 19, 19])

In [4]:
# Size constants; this cell only defines them, it does not read them.
p_sz, e_sz, mid_sz = 8, 512, 64

# Block configuration handed to Decoder: a single five-element tuple.
# Earlier alternative, kept for reference: (5,  6,  7,  8,)
# Note that this rebinds `blocks` (and the constants above) for the rest of the notebook.
blocks = [(9, 11, 13, 15, 16,)]

# The generator is a second Decoder instance, trained from scratch with its own AdamW optimiser.
generator = Decoder(16, 512, 192, blocks)
generator = generator.to(device)
optimG = torch.optim.AdamW(generator.parameters(), lr=0.0002)

In [5]:
# Shuffled batches of 32, loaded in the main process (num_workers=0).
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
# for g in optim.param_groups:
#     #g['lr'] = 0.00005# 075
#     print(g['lr'])

In [19]:
# --- Adversarial training of `generator` against `discriminator` ---
# Per batch: the frozen `decoder` produces a coarse patch, `generator` refines it, the
# discriminator is updated on real/fake patches, then the generator is updated with an
# adversarial term plus an MSE term against the target patch.
epoch_no = 1_001

keys = ('embedding', 'patch')

lbl_real, lbl_fake = 1, 0
loss_fn = F.binary_cross_entropy_with_logits


def _soft_labels(logits, real):
    """Return noisy targets shaped like `logits`.

    Real targets are `lbl_real` minus uniform noise in [0, 0.1); fake targets are
    `lbl_fake` plus uniform noise in [0, 0.1).
    """
    jitter = torch.rand_like(logits).div(10)
    if real:
        return torch.zeros_like(logits) + lbl_real - jitter
    return torch.zeros_like(logits) + lbl_fake + jitter


def _wait_for_gpu_cooldown(temperature, trigger=92, resume=70, poll_seconds=10):
    """Block while the first GPU is hot: start waiting above `trigger`, stop at or below `resume`."""
    if temperature <= trigger:
        return
    while temperature > resume:
        print(f'GPU:{temperature}')
        time.sleep(poll_seconds)
        temperature = GPUtil.getGPUs()[0].temperature


errDR_lst, errDF_lst, errG_lst = [], [], []
for epoch in range(epoch_no):
    for step, batch in enumerate(loader):
        # `emb`, `patch_trg`, `coarse` and `fake` keep their names: later cells read them.
        emb, patch_trg = (batch[k].to(device) for k in keys)

        # NOTE(review): `decoder` is used as a frozen model but is never switched to eval
        # mode; no_grad only disables gradient tracking. If Decoder contains BatchNorm or
        # dropout layers, consider calling decoder.eval() before training — confirm.
        with torch.no_grad():
            coarse = decoder(emb, mean)[-1]

        fake = generator(emb, coarse)[0]

        ##### Update D network: #####
        optimD.zero_grad()

        logits_real = discriminator(patch_trg)
        err_d_real = loss_fn(logits_real, _soft_labels(logits_real, real=True))
        err_d_real.backward()
        errDR_lst.append(err_d_real.item())

        # detach(): the discriminator update must not back-propagate into the generator.
        logits_fake = discriminator(fake.detach())
        err_d_fake = loss_fn(logits_fake, _soft_labels(logits_fake, real=False))
        err_d_fake.backward()
        errDF_lst.append(err_d_fake.item())

        optimD.step()

        ##### Update G network: #####
        # The generator is scored against (noisy) real labels, plus an MSE term to the target.
        logits_gen = discriminator(fake)
        err_g = loss_fn(logits_gen, _soft_labels(logits_gen, real=True)) + F.mse_loss(fake, patch_trg)
        errG_lst.append(err_g.item())

        # This backward pass also deposits gradients in the discriminator; they are
        # cleared by optimD.zero_grad() at the start of the next step.
        optimG.zero_grad()
        err_g.backward()
        optimG.step()

        if step % 50 == 0:
            # Report the mean of each loss since the previous report, then reset the windows.
            temperature = GPUtil.getGPUs()[0].temperature
            print(epoch, step, temperature,
                  torch.tensor(errDR_lst).mean(),
                  torch.tensor(errDF_lst).mean(),
                  torch.tensor(errG_lst).mean(),)
            errDR_lst, errDF_lst, errG_lst = [], [], []
            _wait_for_gpu_cooldown(temperature)

0 0 56.0 tensor(0.23056) tensor(0.22090) tensor(3.50133)
0 50 69.0 tensor(0.20951) tensor(0.20948) tensor(2.91743)
0 100 73.0 tensor(0.20966) tensor(0.21376) tensor(2.99928)
0 150 77.0 tensor(0.20868) tensor(0.21102) tensor(2.91898)
1 0 80.0 tensor(0.20326) tensor(0.21030) tensor(2.99533)
1 50 84.0 tensor(0.20445) tensor(0.20516) tensor(2.86844)
1 100 85.0 tensor(0.20901) tensor(0.21445) tensor(2.98051)
1 150 86.0 tensor(0.20525) tensor(0.20537) tensor(2.85232)
2 0 87.0 tensor(0.20835) tensor(0.21481) tensor(3.01564)
2 50 90.0 tensor(0.21212) tensor(0.20971) tensor(2.80933)
2 100 90.0 tensor(0.20590) tensor(0.21069) tensor(2.95996)
2 150 91.0 tensor(0.20736) tensor(0.20964) tensor(2.95514)
3 0 90.0 tensor(0.20561) tensor(0.20876) tensor(2.92113)
3 50 91.0 tensor(0.21411) tensor(0.21871) tensor(2.96688)
3 100 91.0 tensor(0.20871) tensor(0.21012) tensor(2.91594)
3 150 91.0 tensor(0.23215) tensor(0.23464) tensor(2.86980)
4 0 92.0 tensor(0.25701) tensor(0.25432) tensor(2.72472)
4 50 91.0 t

KeyboardInterrupt: 

In [22]:
# Rebuild the loader with a batch size of 64 (still shuffled, still loaded in the main process).
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0,
)

In [20]:
# Export the first few samples of the most recent training batch through U.export_stl:
# coarse decoder output ('crc'), generator output ('src') and target patch ('trg').
# Slicing up front avoids walking the whole batch just to skip everything past the limit.
n_export = 3
for i, (pc, ps, pt) in enumerate(zip(coarse[:n_export], fake[:n_export], patch_trg[:n_export])):
    U.export_stl(pc, f'{i}crc')
    U.export_stl(ps, f'{i}src')
    U.export_stl(pt, f'{i}trg')

In [20]:
# Replace the discriminator optimiser with a new AdamW instance at lr=0.0001.
# NOTE(review): a new optimiser starts without the previous instance's accumulated state;
# if only the learning rate should change, updating optimD.param_groups would keep it — confirm intent.
optimD = torch.optim.AdamW(
    discriminator.parameters(),
    lr=0.0001,
)

In [26]:
# Zero everything except an 8-element border along the last two axes.
# This writes into `patch_trg` in place, so any later cell reading it sees the zeroed data.
inner = slice(8, -8)
patch_trg[:, :, inner, inner] = 0

In [28]:
# Shape of the region left after dropping 8 elements from each side of the last two axes.
interior = slice(8, -8)
patch_trg[:, :, interior, interior].shape

torch.Size([32, 3, 16, 16])

In [14]:
# Scratch check: ten uniform samples scaled into [0, 0.1), the same noise form used for the soft labels.
torch.rand(10).div(10)

tensor([0.08695, 0.01874, 0.09883, 0.02212, 0.06796, 0.06031, 0.08695, 0.03643,
        0.02735, 0.01541])