In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as dset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the torch RNG seed (CPU and CUDA) so weight initialisation, DataLoader
# shuffling and any sampled noise are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

In [None]:
# MNIST dataset and training hyper-parameters
num_epochs = 10
batch_size = 200
learning_rate = 0.0001

root = './MNIST'

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Download the data (if not already present under `root`) and load it
train_data = dset.MNIST(root=root, train=True, transform=transform, download=True)
test_data = dset.MNIST(root=root, train=False, transform=transform, download=True)

# drop_last on the training loader keeps every batch at exactly batch_size, which
# also avoids a size-1 final batch that BatchNorm1d cannot normalise in training
# mode. The test loader keeps every sample so evaluation covers the full test set.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

# Class index -> display name for the ten MNIST digits ('0' ... '9').
label_tags = {digit: str(digit) for digit in range(10)}

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting ./MNIST/MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting ./MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting ./MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting ./MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# GAN

# Discriminator (D)
class Discriminator(nn.Module):
    """Three-stage MLP applied to flattened images.

    Each input is flattened to one row per sample (the first Linear expects 784
    features, i.e. 28x28) and passed through Linear -> BatchNorm1d -> ReLU twice
    (784 -> 256 -> 64), then through a final Linear (64 -> 10). The raw scores
    of that last Linear are returned.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and Sequential indices (layer1.0, layer1.1, ...) match
        # the original layout, so state_dict keys stay the same.
        self.layer1 = self._hidden_block(784, 256)
        self.layer2 = self._hidden_block(256, 64)
        self.layer3 = nn.Sequential(nn.Linear(64, 10, bias=True))

    @staticmethod
    def _hidden_block(n_in, n_out):
        """Build one Linear -> BatchNorm1d -> ReLU stage."""
        return nn.Sequential(
            nn.Linear(n_in, n_out, bias=True),
            nn.BatchNorm1d(n_out),
            nn.ReLU(),
        )

    def forward(self, x):
        # Flatten everything after the batch dimension.
        h = x.view(x.size(0), -1)
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        return h

# Loss computation: cross-entropy applied to the discriminator's outputs
criterion = nn.CrossEntropyLoss().to(device)

# TODO: how should the weights be computed / initialised? (still undecided)




In [None]:
# Generator (G); input noise dimension = 100

class Generator(nn.Module):
    """MLP that maps a noise vector to a flattened image.

    Args:
        noise_dim: length of the input noise vector (default 100).
        img_dim: number of output features. The default 784 (= 28x28) matches
            the 784 inputs expected by the Discriminator's first Linear layer.

    The output goes through Tanh, so values lie in [-1, 1]. This is the same
    range as the real images after Normalize(mean=0.5, std=0.5).
    """

    def __init__(self, noise_dim=100, img_dim=784):
        super().__init__()
        self.noise_dim = noise_dim
        self.layer1 = nn.Sequential(
            nn.Linear(noise_dim, 256, bias=True),
            nn.BatchNorm1d(256),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(256, 512, bias=True),
            nn.BatchNorm1d(512),
            nn.ReLU(),
        )
        # Previously Linear(512, 10): that could not be fed to the Discriminator,
        # which expects 784 features per sample.
        self.layer3 = nn.Sequential(
            nn.Linear(512, img_dim, bias=True),
            nn.Tanh(),
        )

    def forward(self, x):
        # One noise row per sample. reshape also tolerates non-contiguous input.
        x = x.reshape(x.size(0), self.noise_dim)
        # The original called an undefined `self.model`; run the stages explicitly.
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out


In [None]:
  # Initialisation: build both networks on the configured `device`.
  # (A hard-coded .cuda() would fail on a CPU-only machine; `device` falls back to CPU.)
  discriminator = Discriminator().to(device)
  generator = Generator().to(device)

  # Optimizers: one Adam instance per network, both using `learning_rate`.
  g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)
  d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

In [2]:
# TODO: the models still need to be trained.
# Open question: how should the training loop be structured?

def train_discriminator(discriminator, x, real_labels, fake_images, fake_labels, y,
                        optimizer=None, loss_fn=None):
    """Run one optimisation step of the discriminator on a real and a fake batch.

    Args:
        discriminator: model being trained; called as ``discriminator(images)``.
        x: batch of real images.
        real_labels: targets for the real batch, passed to the loss.
        fake_images: batch produced by the generator. It is detached here so this
            step neither backpropagates into nor frees the generator's graph.
        fake_labels: targets for the fake batch, passed to the loss.
        y: currently unused. ``Discriminator.forward`` accepts only images, so
            passing ``y`` raised a TypeError. It is kept for call compatibility.
            TODO: use it if the discriminator becomes label-conditional.
        optimizer: optimizer over the discriminator's parameters. Defaults to the
            notebook-level ``d_optimizer``.
        loss_fn: callable ``loss_fn(outputs, targets)``. Defaults to the
            notebook-level ``criterion``.

    Returns:
        ``(d_loss, real_score, fake_score)``: the summed real+fake loss and the
        raw discriminator outputs on the real and the fake batch.
    """
    if optimizer is None:
        optimizer = d_optimizer
    if loss_fn is None:
        loss_fn = criterion

    discriminator.zero_grad()

    real_score = discriminator(x)
    real_loss = loss_fn(real_score, real_labels)

    fake_score = discriminator(fake_images.detach())
    fake_loss = loss_fn(fake_score, fake_labels)

    d_loss = real_loss + fake_loss
    d_loss.backward()
    optimizer.step()
    # fake_score is the discriminator output on the fakes (previously the loss was returned).
    return d_loss, real_score, fake_score

def train_generator(generator, discriminator_outputs, real_labels, y,
                    optimizer=None, loss_fn=None):
    """Run one optimisation step of the generator.

    Computes the loss of ``discriminator_outputs`` against ``real_labels``,
    backpropagates it and steps the generator's optimizer.

    Args:
        generator: model being trained. Its gradients are cleared first.
        discriminator_outputs: discriminator output on generated images. It must
            NOT be computed from detached fakes, otherwise no gradient reaches the
            generator. If it came from a discriminator forward pass, backward also
            accumulates gradients on the discriminator's parameters. These are not
            stepped here, and train_discriminator clears them with
            ``discriminator.zero_grad()``.
        real_labels: targets passed to the loss.
        y: currently unused. Kept for call compatibility.
        optimizer: optimizer over the generator's parameters. Defaults to the
            notebook-level ``g_optimizer``.
        loss_fn: callable ``loss_fn(outputs, targets)``. Defaults to the
            notebook-level ``criterion``.

    Returns:
        The scalar generator loss.
    """
    if optimizer is None:
        optimizer = g_optimizer
    if loss_fn is None:
        loss_fn = criterion

    generator.zero_grad()
    g_loss = loss_fn(discriminator_outputs, real_labels)
    g_loss.backward()
    optimizer.step()
    return g_loss