In [1]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import GPUtil

import src.shared.util as U
#from src.decoder import Decoder
from src.dataset.decode_dataset import DecodeDataset
from src.shared.subdivide import Subdivide

# Print tensors in fixed-point notation with five decimals instead of scientific notation.
torch.set_printoptions(sci_mode=False, precision=5)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class Head(nn.Module):
    """Patch-wise prediction head conditioned on an embedding.

    The input image is cut into non-overlapping ``p_sz x p_sz`` patches. Every
    flattened 3-channel patch is projected to ``mid_sz`` features, concatenated
    with the projected embedding, and mapped back to a flattened patch. The
    predicted patches are then reassembled into an image laid out like the
    input.

    Args:
        p_sz: Edge length of the square patches.
        emb_sz: Size of the last dimension of the conditioning embedding.
        mid_sz: Width of each of the two projections; the output layer
            therefore receives ``2 * mid_sz`` features.
        dropout: Dropout probability applied to the concatenated features.
    """

    def __init__(self, p_sz, emb_sz, mid_sz, dropout=0.05):
        super().__init__()
        # Three channels times the number of pixels in one patch.
        patch_dim = 3 * p_sz ** 2
        self.p_sz = p_sz
        # Attribute names and creation order are kept as they are: they define
        # the state_dict keys and the order in which weights are initialised.
        self.proj_emb = nn.Linear(emb_sz, mid_sz, bias=False)
        self.proj_key = nn.Linear(patch_dim, mid_sz, bias=False)
        self.net = nn.Sequential(
            nn.Dropout(dropout),
            nn.ReLU(True),
            nn.Linear(2 * mid_sz, patch_dim, bias=False),
        )

    def forward(self, x, emb):
        """Predict one output patch per input patch.

        Args:
            x: Images laid out as ``(b, 3, H, W)`` where ``H`` and ``W`` are
                multiples of ``p_sz``. A batch of one is broadcast to the
                batch size of ``emb``.
            emb: Embedding whose second dimension is either 1 (shared by all
                patches) or the number of patches, and whose last dimension
                is ``emb_sz``.

        Returns:
            Tensor with the same ``(b, 3, H, W)`` layout as ``x``, with ``b``
            taken from ``emb``.
        """
        side = self.p_sz
        n_rows = x.size(-2) // side
        n_cols = x.size(-1) // side

        cond = self.proj_emb(emb)
        patches = rearrange(x, 'b c (u m) (v n) -> b (u v) (c m n)',
                            m=side, n=side)
        keys = self.proj_key(patches)

        # Broadcast the embedding over patches and the patches over the batch
        # so both tensors can be concatenated feature-wise.
        cond = cond.expand(-1, keys.size(1), -1)
        keys = keys.expand(cond.size(0), -1, -1)

        flat_patches = self.net(torch.cat([cond, keys], dim=-1))
        return rearrange(flat_patches, 'b (u v) (c m n) -> b c (u m) (v n)',
                         c=3, u=n_rows, v=n_cols, m=side, n=side)
        

class MultiHead(nn.Module):
    """A group of ``Head`` modules, each applied at its own resolution.

    Args:
        p_sz: Patch edge length handed to every head.
        e_sz: Embedding size handed to every head.
        mid_sz: Hidden width handed to every head.
        sizes: One entry per head; head ``k`` receives the input resized to
            ``p_sz * sizes[k]`` pixels per side.
    """

    def __init__(self, p_sz, e_sz, mid_sz, sizes):
        super().__init__()
        self.p_sz = p_sz
        self.sizes = sizes
        heads = []
        for _ in sizes:
            heads.append(Head(p_sz, e_sz, mid_sz))
        self.heads = nn.ModuleList(heads)

    def forward(self, emb, mean):
        """Run every head on a resized copy of ``mean`` and combine the results.

        Args:
            emb: Conditioning embedding passed unchanged to each head.
            mean: Image tensor that is bilinearly resized for each head.

        Returns:
            The result of ``U.join`` on ``[mean, head_0_out, head_1_out, ...]``.
            How the differently sized layers are merged is decided by
            ``U.join`` and is not visible here.
        """
        head_outputs = [
            head(F.interpolate(mean, self.p_sz * scale, mode='bilinear'), emb)
            for head, scale in zip(self.heads, self.sizes)
        ]
        return U.join([mean] + head_outputs)

class Decoder(nn.Module):
    """Chain of ``MultiHead`` stages in which each stage consumes the previous output.

    Args:
        p_sz: Patch edge length forwarded to every stage.
        e_sz: Embedding size forwarded to every stage.
        mid_sz: Hidden width forwarded to every stage.
        blocks: Iterable with one ``sizes`` sequence per stage.
    """

    def __init__(self, p_sz, e_sz, mid_sz, blocks):
        super().__init__()
        stages = [MultiHead(p_sz, e_sz, mid_sz, sizes) for sizes in blocks]
        self.blocks = nn.ModuleList(stages)

    def forward(self, emb, mean):
        """Return the output of every stage, first stage first.

        Args:
            emb: Embeddings; they are averaged over dimension 1 (kept as a
                singleton dimension) before being passed to the stages.
            mean: Image tensor fed to the first stage.

        Returns:
            List with one tensor per stage. Stage ``k + 1`` receives the
            output of stage ``k`` in place of ``mean``.
        """
        pooled = emb.mean(dim=1, keepdim=True)
        outputs = []
        current = mean
        for stage in self.blocks:
            current = stage(pooled, current)
            outputs.append(current)
        return outputs
        

# Decoder hyper-parameters. They have to describe the network that produced
# the checkpoint loaded below, so the constructor reads them instead of
# repeating literals (the previous constants said 8 / 64 while the model was
# actually built with 16 / 192).
p_sz = 16      # edge length of the square patches each Head works on
e_sz = 512     # size of one embedding vector
mid_sz = 192   # hidden width inside each Head
blocks = [
    (1, 1, 2, 2,),
    (3, 3, 4, 4,),
    (5, 6, 7, 8,),
    (10, 12, 14, 16,),
]
ckpt_size = 256  # size tag embedded in the checkpoint file names

# Random stand-in for the decoder's image input; only used by the shape check
# at the end of this cell.
mean = torch.randn(1, 3, 256, 256).cuda()

decoder = Decoder(p_sz, e_sz, mid_sz, blocks).cuda()

decoder.load_state_dict(torch.load(f'./data/checkpoint/decoder512-{ckpt_size}w.pth'))

optim = torch.optim.AdamW(decoder.parameters(), lr=0.0002)
optim.load_state_dict(torch.load(f'./data/checkpoint/optim/optim512-{ckpt_size}w.pth'))

# Smoke test: a batch of 7 samples with 3 embeddings each should yield one
# output per block.
[e.shape for e in  decoder(torch.randn(7, 3, 512).cuda(), mean)]

[torch.Size([7, 3, 32, 32]),
 torch.Size([7, 3, 64, 64]),
 torch.Size([7, 3, 128, 128]),
 torch.Size([7, 3, 256, 256])]

In [3]:
device = torch.device('cuda')
subdivide = Subdivide().to(device)

patch_root = './data/fitted/512x512/'
emb_file = ['./data/face_emb/', './data/face_emb_hq/',]
n_embs = 3


def transform(x):
    """Resize a single patch to 256 pixels per side.

    A batch dimension is added for ``F.interpolate`` (default interpolation
    mode) and removed again afterwards.
    """
    return F.interpolate(x[None], 256)[0]


dataset = DecodeDataset(patch_root, emb_file, n_embs=n_embs, suffix='.pth', transform=transform)

# Average over all stored patches, keeping a leading batch dimension of one.
patch_mean = dataset.patch_data.mean(dim=0, keepdim=True)

# Show what a single sample contains.
for k, v  in dataset[0].items():
    print(k, v.shape if torch.is_tensor(v) else v)

emb_sizes = (512,) * 3
e_sz = 512

# Same quantity as patch_mean; reuse it instead of reducing the whole dataset
# a second time.
mean = patch_mean
print(mean.shape)

# One sample with a batch dimension on the GPU, for quick experiments.
emb, patch = [dataset[1][k][None].cuda() for  k in ('embedding', 'patch')]
emb.shape, patch.shape

patch torch.Size([3, 256, 256])
embedding torch.Size([3, 512])
idx 0
torch.Size([1, 3, 512, 512])


(torch.Size([1, 3, 512]), torch.Size([1, 3, 256, 256]))

In [20]:
# Shuffled training batches, loaded in the main process.
loader = DataLoader(dataset, shuffle=True, batch_size=128, num_workers=0)

# Override the learning rate of every optimizer parameter group and echo it.
new_lr = 0.00005  # 075 (presumably an alternative value that was tried - TODO confirm)
for param_group in optim.param_groups:
    param_group['lr'] = new_lr
    print(param_group['lr'])

5e-05


In [None]:
from src.shared.sharpen import Sharpen

# --- Run configuration ------------------------------------------------------
keys = ('embedding', 'patch')   # batch entries consumed by the loop
no_epochs = 1_001
LOG_EVERY = 100                 # steps between progress lines / GPU temperature polls
SAVE_EVERY = 3                  # epochs between checkpoints
N_PREVIEW = 3                   # samples exported with every periodic checkpoint
GPU_TEMP_PAUSE = 92             # pause training when the reading exceeds this
GPU_TEMP_RESUME = 68            # ...and resume once it is no longer above this
GPU_POLL_SECONDS = 30           # wait between readings while paused

# Not referenced by the loop below; kept because it is created in this cell.
sharpen = Sharpen(mask=('sharpen', '3x3_3'), padd=False).to(device)

# Decoder image input: the dataset's mean patch, divided by its largest
# absolute value and resized to 32 pixels per side.
mean = dataset.patch_data.mean(dim=0, keepdim=True)
mean = mean.div(mean.abs().amax(dim=(1, 2, 3), keepdim=True))
mean =  F.interpolate(mean, 32).cuda()


def _gpu_temperature():
    """Return the temperature GPUtil reports for the first GPU."""
    return GPUtil.getGPUs()[0].temperature


def _save_checkpoint(outputs):
    """Write decoder and optimizer state; return the size tag used in the file names.

    The tag is the width of the last decoder output in ``outputs``.
    """
    size = outputs[-1].size(-1)
    torch.save(decoder.state_dict(), f'./data/checkpoint/decoder512-{size}w.pth')
    torch.save(optim.state_dict(), f'./data/checkpoint/optim/optim512-{size}w.pth')
    return size


def _export_previews(outputs, target, count=N_PREVIEW):
    """Export the first ``count`` samples: first output, last output and target."""
    samples = zip(outputs[0][:count], outputs[-1][:count], target[:count])
    for i, (pc, ps, pt) in enumerate(samples):
        U.export_stl(pc, f'{i}crc')
        U.export_stl(ps, f'{i}src')
        U.export_stl(pt, f'{i}trg')


# Make the training mode explicit so dropout is active even if the model was
# switched to eval mode elsewhere in the notebook.
decoder.train()

# One running list of loss values per decoder output, emptied after every
# progress line.
err_lst0, err_lst1, err_lst2, err_lst3  = [], [], [], []
for epoch in range(no_epochs):
    for step, batch in enumerate(loader):
        emb, patch_trg = (batch[k].cuda() for k in keys)
        patch_trg = subdivide.smooth(patch_trg, interpolate=True)
        # Scale every target by its own largest absolute value.
        patch_trg = patch_trg.div(patch_trg.abs().amax(dim=(1, 2, 3), keepdim=True))

        patch_src_lst = decoder(emb, mean)
        # Each decoder output is resized to the target resolution before the
        # MSE is taken, so all outputs are compared against the same target.
        errors =  [
            F.mse_loss(F.interpolate(patch_src, patch_trg.size(-1), mode='bilinear'), patch_trg)
            for patch_src in patch_src_lst]
        # NOTE: the loop variable `err_lst` is read by a later cell; keep its name.
        for err_lst, err in zip([err_lst0, err_lst1, err_lst2, err_lst3], errors):
            err_lst.append(err.item())

        error = sum(errors)

        optim.zero_grad()
        error.backward()
        optim.step()

        if step % LOG_EVERY == 0:
            temperature = _gpu_temperature()
            print(str(epoch).zfill(4), str(step).zfill(4),
                  *(f'{torch.tensor(lst).mean().item():.8f}'
                    for lst in (err_lst0, err_lst1, err_lst2, err_lst3)),
                  temperature)
            err_lst0, err_lst1, err_lst2, err_lst3  = [], [], [], []

            # Thermal guard: once the GPU gets too hot, wait until it has
            # cooled down to the resume threshold before continuing.
            if temperature > GPU_TEMP_PAUSE:
                while temperature > GPU_TEMP_RESUME:
                    print(f'GPU:{temperature}')
                    time.sleep(GPU_POLL_SECONDS)
                    temperature = _gpu_temperature()
    if epoch % SAVE_EVERY == 0:
        size = _save_checkpoint(patch_src_lst)
        _export_previews(patch_src_lst, patch_trg)

# Final checkpoint after the last epoch.
size = _save_checkpoint(patch_src_lst)
# 0059 0000 0.00003770 0.00003576 0.00004650 0.00004472 93.0

0000 0000 0.00008116 0.00005639 0.00005167 0.00005292 84.0
0001 0000 0.00008058 0.00005618 0.00005156 0.00005274 89.0
0002 0000 0.00007998 0.00005547 0.00005076 0.00005188 93.0
GPU:93.0
GPU:81.0


In [22]:
# Per-output MSE terms from the most recent training step (shown as the cell result).
errors

[tensor(0.00013, device='cuda:0', grad_fn=<MseLossBackward0>),
 tensor(0.00010, device='cuda:0', grad_fn=<MseLossBackward0>),
 tensor(0.00010, device='cuda:0', grad_fn=<MseLossBackward0>),
 tensor(0.00011, device='cuda:0', grad_fn=<MseLossBackward0>)]

In [19]:
# Manual checkpoint, e.g. after interrupting the training cell.
# The size tag is the width of the decoder's last output - the same source the
# training cell uses - so this overwrites the file that is loaded at start-up
# even if the target patches have a different resolution.
size = patch_src_lst[-1].size(-1)
torch.save(decoder.state_dict(), f'./data/checkpoint/decoder512-{size}w.pth')
torch.save(optim.state_dict(), f'./data/checkpoint/optim/optim512-{size}w.pth')

In [8]:
# Export the first three samples of the latest batch through U.export_stl:
# first decoder output ('crc'), last decoder output ('src') and target ('trg').
first_out, last_out = patch_src_lst[0], patch_src_lst[-1]
for i, (pc, ps, pt) in enumerate(zip(first_out[:3], last_out[:3], patch_trg[:3])):
    exports = ((pc, f'{i}crc'), (ps, f'{i}src'), (pt, f'{i}trg'))
    for tensor, name in exports:
        U.export_stl(tensor, name)

In [None]:
# Re-print a progress line by hand. `err_lst`, `epoch`, `step` and
# `temperature` are whatever the training cell last left in the namespace.
mean_err = torch.tensor(err_lst).mean().item()
print(str(epoch).zfill(4), str(step).zfill(4), f'{mean_err:.8f}', temperature)

In [None]:
# Decode every sample in dataset order and keep the decoder's first (32x32)
# output next to a 32x32 copy of the original patch.
bsz = 8
keys = ('embedding', 'patch')
device = torch.device('cuda')

decoded512_32 = torch.zeros(len(dataset), 3, 32, 32)
orig_512_32 = torch.zeros(len(dataset), 3, 32, 32)

# Decoder image input, built exactly as in the training cell.
mean = dataset.patch_data.mean(dim=0, keepdim=True)
mean = mean.div(mean.abs().amax(dim=(1, 2, 3), keepdim=True))
mean = F.interpolate(mean, 32).to(device)

# A separate name keeps the shuffled training `loader` intact.
eval_loader = DataLoader(dataset, batch_size=bsz, shuffle=False, num_workers=0)

# Dropout must be off while decoding; the previous mode is restored afterwards
# so training can be resumed without further changes.
was_training = decoder.training
decoder.eval()
try:
    start = 0
    with torch.no_grad():
        for batch in eval_loader:
            emb, orig = (batch[k].to(device) for k in keys)
            # The decoder returns one tensor per block; the first one is the
            # coarsest output.
            coarse = decoder(emb, mean)[0]
            stop = start + coarse.size(0)
            decoded512_32[start:stop] = coarse.cpu()
            # NOTE(review): the originals are only resized to the buffer
            # resolution here. Training additionally smooths and normalises
            # its targets - confirm which reference is wanted for comparison.
            orig_512_32[start:stop] = F.interpolate(orig, 32, mode='bilinear').cpu()
            start = stop
finally:
    decoder.train(was_training)

torch.save(decoded512_32, './data/decoded512_32j.pt')

In [None]:
# Export the original/decoded pair for a few hand-picked dataset indices.
for i in (100, 700, 2300):
    pair = ((orig_512_32[i], f'{i}org'), (decoded512_32[i], f'{i}src'))
    for tensor, name in pair:
        U.export_stl(tensor, name)

In [None]:
# Display the shape of the dataset's `patch_data` (shown as the cell result).
dataset.patch_data.shape