In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
def onehot_encode(label, device, n_class=10):
    """Convert class indices into one-hot maps shaped (N, n_class, 1, 1).

    Args:
        label: index/indices used to select rows of an ``n_class`` x ``n_class``
            identity matrix (e.g. a 1-D integer tensor or a Python int).
        device: device on which the identity matrix (and hence the result) lives.
        n_class: number of classes, i.e. the one-hot length.

    Returns:
        Float tensor viewed as (-1, n_class, 1, 1); a single int label yields
        a leading dimension of 1.
    """
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    one_hot_rows = torch.eye(n_class, device=device)[label]
    # Trailing 1x1 spatial dims let the result broadcast/concat with 4-D tensors.
    target_shape = (-1, n_class, 1, 1)
    return one_hot_rows.view(*target_shape)

In [None]:
# Sanity check: two labels -> one-hot maps of shape (2, 10, 1, 1).
# (Previously called with no arguments, which always raised TypeError.)
onehot_encode(torch.tensor([0, 3]), torch.device("cpu")).shape

In [4]:
def concat_image_label(image, label, device, n_class=10):
    """Append one-hot label planes to an image batch along the channel axis.

    Args:
        image: 4-D tensor unpacked as (B, C, H, W).
        label: class indices passed to ``onehot_encode``; presumably one per
            image, shape (B,) — TODO confirm with callers.
        device: device on which the one-hot tensor is built (should match
            ``image``'s device for ``torch.cat``).
        n_class: number of classes, i.e. how many label planes are appended.

    Returns:
        Tensor of shape (B, C + n_class, H, W).
    """
    B, _, H, W = image.shape

    # Forward n_class: previously onehot_encode always used its default of 10,
    # so any other n_class made the expand() below fail on a shape mismatch.
    oh_label = onehot_encode(label, device, n_class=n_class)
    # Broadcast each (n_class, 1, 1) one-hot vector over the full H x W grid.
    oh_label = oh_label.expand(B, n_class, H, W)

    return torch.cat((image, oh_label), dim=1)

In [5]:
def concat_noise_label(noise, label, device, n_class=10):
    """Append one-hot label channels to a latent noise batch.

    Args:
        noise: latent tensor; since it is concatenated with an
            (N, n_class, 1, 1) tensor along dim 1, it must be 4-D with 1x1
            spatial dims, i.e. (N, z_dim, 1, 1).
        label: class indices passed to ``onehot_encode``.
        device: device on which the one-hot tensor is built.
        n_class: number of classes (default 10, matching the previous
            hard-wired behaviour and ``concat_image_label``).

    Returns:
        Tensor of shape (N, z_dim + n_class, 1, 1).
    """
    oh_label = onehot_encode(label, device, n_class=n_class)

    return torch.cat((noise, oh_label), dim=1)

In [None]:
def _sample_noise(batch_size, z_dim, device):
    """Draw a standard-normal latent batch of shape (batch_size, z_dim, 1, 1) on ``device``."""
    return torch.randn(batch_size, z_dim, 1, 1, device=device)


def train_model(G, D, dataloader, num_epochs, z_dim=20):
    """Train a GAN with an MSE (least-squares) objective and report per-epoch losses.

    The discriminator is trained to output 1 on real images and 0 on generated
    ones. The generator is trained so that D outputs 1 on its samples. Both use
    Adam with betas (0.0, 0.9); G uses lr 1e-4 and D uses lr 4e-4.

    Args:
        G: generator module, called as ``G(z)`` with ``z`` of shape
            (B, z_dim, 1, 1).
        D: discriminator module. Its output is flattened with ``view(-1)`` and
            compared to one target per sample, so it should yield one score
            per image.
        dataloader: iterable of image tensors (batch dimension first). Batches
            of size 1 are skipped (presumably to avoid BatchNorm issues —
            TODO confirm).
        num_epochs: number of passes over ``dataloader``.
        z_dim: latent dimension fed to G (default 20, the previously
            hard-coded value).

    Returns:
        ``(G, D, (d_loss_list, g_loss_list))``. Each list holds, per epoch, the
        mean of the per-batch losses over the batches actually trained on
        (0.0 if every batch in that epoch was skipped).
    """
    # ==== device ====
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("device : ", device)

    # ==== optimizer setting ====
    g_lr, d_lr = 0.0001, 0.0004
    beta1, beta2 = 0.0, 0.9
    g_optimizer = optim.Adam(G.parameters(), g_lr, [beta1, beta2])
    d_optimizer = optim.Adam(D.parameters(), d_lr, [beta1, beta2])

    # ==== loss ====
    criterion = nn.MSELoss()

    # ==== network ====
    G.to(device)
    D.to(device)
    G.train()
    D.train()

    torch.backends.cudnn.benchmark = True

    d_loss_list = []
    g_loss_list = []

    for epoch in range(num_epochs):
        # Reset per-epoch accumulators (previously never initialised at all).
        epoch_d_loss = 0.0
        epoch_g_loss = 0.0
        num_batches = 0

        for images in dataloader:
            if images.size(0) == 1:
                continue

            images = images.to(device)
            mini_batch_size = images.size(0)

            # Float targets: MSELoss needs the target dtype to match D's float
            # output; torch.full(..., 1) would create an integer tensor.
            label_real = torch.full((mini_batch_size,), 1.0, device=device)
            label_fake = torch.full((mini_batch_size,), 0.0, device=device)

            # ==== Discriminator ====
            d_out_real = D(images)
            fake_images = G(_sample_noise(mini_batch_size, z_dim, device))
            # detach(): the D update must not backpropagate through G.
            d_out_fake = D(fake_images.detach())

            d_loss_real = criterion(d_out_real.view(-1), label_real)
            d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
            d_loss = d_loss_real + d_loss_fake

            g_optimizer.zero_grad()
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ==== Generator ====
            fake_images = G(_sample_noise(mini_batch_size, z_dim, device))
            d_out_fake = D(fake_images)
            # G wants D to label its samples as real.
            g_loss = criterion(d_out_fake.view(-1), label_real)

            g_optimizer.zero_grad()
            d_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()  # was d_loss: G loss was never tracked
            num_batches += 1

        # Average over batches actually processed (not over batch_size).
        # max(..., 1) guards an epoch in which every batch was skipped.
        denom = max(num_batches, 1)
        mean_d_loss = epoch_d_loss / denom
        mean_g_loss = epoch_g_loss / denom
        d_loss_list.append(mean_d_loss)
        g_loss_list.append(mean_g_loss)

        print(f"#======Epoch: {epoch}=======")
        print(f'd_loss: {mean_d_loss}, g_loss: {mean_g_loss}')

    return G, D, (d_loss_list, g_loss_list)