In [11]:
import torch
import torch.nn as nn

from src.data.adversarial import AdversarialDataProvider
from src.data.reconstruction import ReconstructionDataProvider
from src.models.stylist import Stylist
from src.models.generator import Generator
from src.models.discriminator import Discriminator
from src.render.mesh_renderer import MeshPointsRenderer

class Trainer:
    
    def __init__(self, config):
        self.G_noise_amp =  config.G_noise_amp
        self.z_size = config.latent_size
        
        self.device = torch.device("cuda")
        
        # Data providers
        self.RDP = ReconstructionDataProvider(config)
        self.ADP = AdversarialDataProvider(config)
        
        self.S = Stylist(config).to(self.device)
        self.G = Generator(config).to(self.device)
        self.D = Discriminator(config).to(self.device)
        self.R = MeshPointsRenderer(config).to(self.device)
        self.R.setup(self.device)
        
        # Optimizers
        self.optim_G = torch.optim.Adam(self.G.parameters())
        self.optim_S = torch.optim.Adam(self.S.parameters())
        self.optim_D = torch.optim.Adam(self.D.parameters())
        # Loss functions
        self.adversarial_loss = nn.BCELoss()
        self.reconstruction_loss = nn.L1Loss()
        
    def step(self):
        pass
        
    
    def adversarial_step(self):
        batch, mean_std = self.ADP.next_batch(labels=True, device=self.device)
        images, points = batch['large'], batch['points'] # ?????
        label_real = batch['label_real']
        label_fake = batch['label_fake']
        
        # (1) Update Discriminator
        ## Train with all-real batch
        self.optim_D.zero_grad()
        output = self.D(images)        
        errD_real =  self.adversarial_loss(output, label_real)
        errD_real.backward()
        D_x = output.mean().item() 
        
        ## Train with all-fake batch
        style = self.S(images)
        pts =  self.G(points, style)
        g_images = self.R(pts, mean_std=mean_std)                
        output = self.D(g_images.detach())  
        errD_fake = self.adversarial_loss(output, label_fake)    
        errD_fake.backward()
        
        self.optim_D.step()
        
        # (2) Update Generator and Stylist
        self.optim_S.zero_grad()
        self.optim_G.zero_grad()
        
        output = self.D(g_images)    
        errG = self.adversarial_loss(output, label_real)
        errG.backward()
        
        self.optim_S.step()
        self.optim_G.step()
        
    def reconstruction_step(self):
        batch = self.RDP.next_batch(self.device)        
        pts_fine = batch['points']
        bs = pts_fine.size(0)
        
        pts_noise = pts_fine + torch.randn_like(pts_fine) * self.G_noise_amp
        style = torch.zeros(bs, self.z_size, device=self.device)
        
        self.optim_G.zero_grad()
        
        vertices = self.G(pts_noise, style)
        errG_rec = self.reconstruction_loss(vertices, pts_fine)
        
        errG_rec.backward()
        
        self.optim_G.step()


from src.config import get_parser

config = get_parser().parse_args(args=[])    
trainer = Trainer(config)
trainer

torch.Size([16, 3, 512, 512])


<__main__.Trainer at 0x7f9ecc4e60a0>

In [7]:
trainer.adversarial_step()

torch.Size([8, 128, 128, 4])


In [10]:
trainer.reconstruction_step()

NameError: name 'bs' is not defined

In [4]:
self = trainer

batch, mean_std = self.ADP.next_batch(labels=True, device=self.device)
images, points = batch['large'], batch['points'] # ?????
label_real = batch['label_real']
label_fake = batch['label_fake']

# (1) Update Discriminator
## Train with all-real batch
self.optim_D.zero_grad()
output = self.D(images)        
errD_real =  self.adversarial_loss(output, label_real)
errD_real.backward()
D_x = output.mean().item() 

## Train with all-fake batch
style = self.S(images)
pts =  self.G(points, style)
g_images = self.R(pts, mean_std=mean_std)
#g_img_clone =  g_images.contiguous().detach()
output = self.D(g_images.detach())        
errD_fake = self.adversarial_loss(output, label_fake)    
errD_fake.backward()

self.optim_D.step()

# (2) Update Generator and Stylist
self.optim_S.zero_grad()
self.optim_G.zero_grad()

output = self.D(g_images)    
errG = self.adversarial_loss(output, label_real)
errG.backward()

self.optim_S.step()
self.optim_G.step()

torch.Size([8, 128, 128, 4])


In [8]:
pts.shape

torch.Size([8, 3, 128, 128])

In [17]:
t =  torch.rand(11)
t

tensor([0.6146, 0.0308, 0.7004, 0.5932, 0.1845, 0.9701, 0.3037, 0.7716, 0.3100,
        0.9828, 0.2576])

In [18]:
t.clone()

tensor([0.6146, 0.0308, 0.7004, 0.5932, 0.1845, 0.9701, 0.3037, 0.7716, 0.3100,
        0.9828, 0.2576])

In [3]:
g_img_clone.shape

torch.Size([8, 4, 128, 1])