In [1]:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, latent_dim=100, hidden_size=[512, 512, 512], output_size=784):
        super(Generator, self).__init__()
        self.hidden_layers = nn.ModuleList()
        
        self.hidden_layers.append(
            nn.Sequential(
                nn.Linear(latent_dim, hidden_size[0]),
                nn.BatchNorm1d(hidden_size[0]),
                nn.ReLU()
            )
        )
        
        for i in range(len(hidden_size) - 1):
            self.hidden_layers.append(
                nn.Sequential(
                    nn.Linear(hidden_size[i], hidden_size[i+1]),
                    nn.BatchNorm1d(hidden_size[i+1]),
                    nn.ReLU()
                )
            )

        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size[-1], output_size),
            nn.Tanh()
        )
            
    def forward(self, x):
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

gen = Generator()
print(gen)

Generator(
  (hidden_layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=100, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1-2): 2 x Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (output_layer): Sequential(
    (0): Linear(in_features=512, out_features=784, bias=True)
    (1): Tanh()
  )
)


In [2]:
z = torch.rand(64, 100) * 2 - 1  # [0, 1] → [-1, 1]

# 확인
print(" - input data")
print("size:", z.shape)
print("최소값:", z.min().item())
print("최대값:", z.max().item())

output = gen(z)
print(" - output data")
print("size:", output.shape)
print("최소값:", output.min().item())
print("최대값:", output.max().item())


 - input data
size: torch.Size([64, 100])
최소값: -0.9995927810668945
최대값: 0.999503493309021
 - output data
size: torch.Size([64, 784])
최소값: -0.9509106874465942
최대값: 0.9578351974487305


In [3]:
import torch
import torch.nn as nn

# maxout class
class Maxout(nn.Module):
    def __init__(self, input_size, hidden_size, size_units=2, num_mo_layers=3, dropout_prob=0.5):
        super(Maxout, self).__init__()
        
        self.mo_layers = nn.ModuleList()
        
        for i in range(num_mo_layers):
            self.mo_layers.append(nn.Sequential(
                nn.Linear(input_size, hidden_size*size_units),
                nn.BatchNorm1d(hidden_size*size_units)
            ))
        self.size_units = size_units
        self.num_mo_layers = num_mo_layers
        
        self.dropout = nn.Dropout(dropout_prob)
        
    def forward(self, x):
        ini_x = [layer(x) for layer in self.mo_layers]
        for i in range(self.num_mo_layers):
            ini = ini_x[i].view(ini_x[i].size(0), -1, self.size_units)
            ini_x[i], _ = ini.max(dim=2)
            ini_x[i] = self.dropout(ini_x[i])
        output = torch.cat(ini_x, dim=1)
        
        return output
    
# maxout class test
maxout = Maxout(512, 256)
print(maxout)

input_tensor = torch.rand(64, 512) # [64, 512]

output = maxout(input_tensor)
print(" -- Maxout class test. --")
print(f"input: {input_tensor.shape}\noutput: {output.shape}")

Maxout(
  (mo_layers): ModuleList(
    (0-2): 3 x Sequential(
      (0): Linear(in_features=512, out_features=512, bias=True)
      (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
)
 -- Maxout class test. --
input: torch.Size([64, 512])
output: torch.Size([64, 768])


In [4]:
import torch
import torch.nn as nn

# Discriminator with Maxout class
class Discriminator(nn.Module):
    def __init__(self, input_size=784, hidden_size=[512, 256, 128], num_mo_layer=3, output_size=1, dropout_prob=0.5):
        super(Discriminator, self).__init__()
        self.input_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size[0]),
            nn.BatchNorm1d(hidden_size[0]),
            nn.ReLU()
        )
        
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_size) - 1):
            if i == 0:
                in_size = hidden_size[i]
            else:
                in_size = hidden_size[i] * 3
            self.hidden_layers.append(
                Maxout(input_size=in_size, hidden_size=hidden_size[i+1], dropout_prob=dropout_prob)
            )
        
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_size[-1]*num_mo_layer, output_size),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = self.input_layer(x)
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x
    
# Discriminator test    
dis = Discriminator()
print(dis)

input_tensor = torch.rand(64, 784)
output_tensor = dis(input_tensor)
print(" -- Discriminator test -- ")
print(f"input size: {input_tensor.shape}\noutput size: {output_tensor.shape}")

Discriminator(
  (input_layer): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (hidden_layers): ModuleList(
    (0): Maxout(
      (mo_layers): ModuleList(
        (0-2): 3 x Sequential(
          (0): Linear(in_features=512, out_features=512, bias=True)
          (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): Maxout(
      (mo_layers): ModuleList(
        (0-2): 3 x Sequential(
          (0): Linear(in_features=768, out_features=256, bias=True)
          (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
  (output_layer): Sequential(
    (0): Linear(in_features=384, out_features=1, bias=True)
    (1): Sigmoid()
  )
)
 -- Di

## Gan 모델 선언

In [34]:
import torch
import torch.nn as nn

class Gen_loss(nn.Module):
    def __init__(self, eps=1e-8):
        super(Gen_loss, self).__init__()
        self.is_early = True
        self.eps=eps
        
    def set_phase(self, is_early):
        self.is_early = is_early
        
    def forward(self, d_gz):
        if self.is_early:
            loss = -torch.log(d_gz + self.eps).mean()
        else:
            loss = torch.log(1 - d_gz + self.eps).mean()
            
        return loss
    
class Dis_loss(nn.Module):
    def __init__(self, eps=1e-8):
        super(Dis_loss, self).__init__()
        self.eps=eps
        
        
    def forward(self, d_x, d_gz):
        x_loss = -torch.log(d_x + self.eps).mean()
        z_loss = -torch.log(1 - d_gz + self.eps).mean()
        loss = x_loss + z_loss
        return loss

In [41]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from models.gen_dis import build_discriminator, build_generator

# from models.utils import Gen_loss, Dis_loss

class Gan(nn.Module):
    def __init__(self, k=1, device="cpu"):
        super(Gan, self).__init__()
        
        self.generator = build_generator().to(device)
        self.discriminator = build_discriminator().to(device)
        
        self.k = k
        self.device = device
        
        self.gen_loss = Gen_loss()
        self.dis_loss = Dis_loss()
        
        self.optim_g = torch.optim.SGD(self.generator.parameters(), lr=0.0001, momentum=0.9)
        self.optim_d = torch.optim.SGD(self.discriminator.parameters(), lr=0.0001, momentum=0.9)
        
    
    def set_loss(self, gen_loss, dis_loss):
        self.gen_loss = gen_loss
        self.dis_loss = dis_loss
    
    def set_optimizer(self, optimizer_g, optimizer_d):
        self.optim_g = optimizer_g
        self.optim_d = optimizer_d
        
    def train_one_epoch(self, dataLoader, epoch=0):
        # train discriminator k steps
        self.discriminator.train()
        self.generator.eval()
        
        epoch_loss_d = 0.0
        for k in range(self.k):
            with tqdm(dataLoader, unit="batch", leave=False) as tepoch:
                for x in tepoch:
                    tepoch.set_description(f"Epoch {epoch+1} | discriminator {k+1}/{self.k}")

                    self.optim_d.zero_grad()
                    
                    x = x.to(self.device)
                    
                    z = torch.randn(64, 100).to(self.device)
                    gz = self.generator(z)
                    d_gz = self.discriminator(gz)
                    dx = self.discriminator(x)
                    
                    loss_d = self.dis_loss(dx, d_gz)
                    loss_d.backward()
                    
                    self.optim_d.step()
                    epoch_loss_d += loss_d.item()

        epoch_loss_d = epoch_loss_d / self.k

        
        # train generator one steps
        self.generator.train()
        self.discriminator.eval()
        
        epoch_loss_g = 0.0
        with tqdm(dataLoader, unit="batch", leave=False) as tepoch:
            for x in tepoch:
                tepoch.set_description(f"Epoch {epoch+1} | generator")

                self.optim_g.zero_grad()
                x = x.to(self.device)
                
                z = torch.randn(64, 100).to(self.device)
                gz = self.generator(z)
                d_gz = self.discriminator(gz)
                
                loss_g = self.gen_loss(d_gz)
                loss_g.backward()
                
                self.optim_g.step()
                epoch_loss_g += loss_g.item()
        
        return epoch_loss_d, epoch_loss_g
    
    def train(self, trainLoader, epochs, log_path="experiment_01", early_rate=0.1):
        writer = SummaryWriter(log_dir=f"./runs/{log_path}")
        early_epoch = int(epochs * early_rate)
        print(early_epoch)
        for epoch in range(epochs):
            if epoch > early_epoch:
                self.gen_loss.set_phase(False)
                
            train_loss_d, train_loss_g = self.train_one_epoch(trainLoader, epoch=epoch)
            writer.add_scalar("Loss/Discriminator", train_loss_d, epoch)
            writer.add_scalar("Loss/Generator", train_loss_g, epoch)
            print(f"Epoch [{epoch+1}/{epochs}] train_loss_d: {train_loss_d}, train_loss_g: {train_loss_g}")

        writer.close()

        
        
        
    
    def forward(self, x):
        return x

In [42]:
from data.datasets import MNISTDataset, transform
from torch.utils.data import DataLoader
train_dataset = MNISTDataset('./mnist', train=True, transform=transform)



dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True
)

In [2]:
100%7

2

In [43]:
gan = Gan(device="cuda:0")
gan.train(trainLoader=dataloader, epochs=50)

5


                                                                                  

Epoch [1/50] train_loss_d: 1323.9593309164047, train_loss_g: 195.78534737974405


                                                                                  

KeyboardInterrupt: 

In [16]:
gan = Gan(device="cuda:0")
gan.train_one_epoch(dataloader)

                                                                                               

(1321.750668644905, 191.3832784742117)

In [15]:
gan.train_one_epoch(dataloader)

(1218.3548020124435, 113.96270113438368)

In [16]:
for epoch in range(10):
    print(gan.train_one_epoch(dataloader))

(1051.3563024401665, 38.42546373791993)


KeyboardInterrupt: 

In [100]:
torch.randn(64, 100).shape

torch.Size([64, 100])