<a href="https://colab.research.google.com/github/kerimoglutolga/AdversarialLearning/blob/master/GATconditioned.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np 

In [2]:
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

# Build the MNIST train and test splits rooted at './' (where the archive above
# was extracted). Every sample passes through transforms.ToTensor(); download=True
# lets torchvision fetch the files itself if they are not already on disk.
train_set, test_set = (
    MNIST(
        './',
        download=True,
        transform=transforms.ToTensor(),
        train=is_train_split,
    )
    for is_train_split in (True, False)
)

--2021-10-04 09:47:19--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2021-10-04 09:47:20--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘MNIST.tar.gz’

MNIST.tar.gz            [            <=>     ]  33.20M  6.35MB/s    in 5.5s    

2021-10-04 09:47:27 (6.05 MB/s) - ‘MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-idx3-ubyte
MNIST/raw/tra

In [3]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 128
SEED = 0

# Seed torch's global RNG so the training-set shuffle order (which DataLoader
# draws from the global generator when none is passed) is reproducible, as are
# any weights initialised after this cell when the notebook is run in order.
torch.manual_seed(SEED)

train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
# The test split is not shuffled, so evaluation always visits samples in the same order.
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
class Generator(nn.Module):
    """Label-conditioned convolutional generator.

    ``forward(x, labels)`` concatenates a learned embedding of ``labels`` with
    the flattened image ``x``, maps the result back to 28*28 values with a
    small MLP, reshapes that into a single-channel 28x28 map and refines it
    with a stack of size-preserving 3x3 convolutions. A final 1x1 convolution
    followed by ``Tanh`` yields one output channel bounded to [-1, 1].

    Args:
        n_classes: number of label classes; also the label-embedding width.
            Defaults to 10 (MNIST).
    """

    def __init__(self, n_classes=10):
        super().__init__()

        self.label_emb = nn.Embedding(n_classes, n_classes)

        # MLP that fuses the label embedding into the flattened image and
        # projects back to 28*28 values so the result can be reshaped to an image.
        self.add_conditions = nn.Sequential(
            nn.Linear(28 * 28 + n_classes, 256), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 28 * 28), nn.LeakyReLU(0.2, inplace=True),
        )
        # kernel_size=3, stride=1, padding=1 keeps the 28x28 spatial size through
        # every layer; the 1x1 conv collapses 48 channels to a single one.
        self.net = nn.Sequential(
            nn.Conv2d(1, 48, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(48, 1, kernel_size=1), nn.Tanh(),
        )

    def forward(self, x, labels):
        """Generate an output image conditioned on ``labels``.

        Args:
            x: images of shape (batch, 784) or (batch, 1, 28, 28) — anything
                whose non-batch dims flatten to 784 values.
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        batch = x.size(0)
        # Flatten so both (B, 784) and (B, 1, 28, 28) inputs can be concatenated
        # with the (B, n_classes) label embedding along the feature axis.
        cond = torch.cat((self.label_emb(labels), x.reshape(batch, -1)), dim=1)
        img = self.add_conditions(cond)
        # Conv2d needs an explicit channel axis: (B, 784) -> (B, 1, 28, 28).
        # Without it a (B, 28, 28) tensor is read as one unbatched B-channel image.
        img = img.view(batch, 1, 28, 28)
        return self.net(img)


In [6]:
class Classifier(nn.Module):
    """Label-conditioned all-convolutional classifier for 28x28 MNIST digits.

    ``forward(x, labels)`` concatenates a learned embedding of ``labels`` with
    the flattened image, maps the result back to 28*28 values with a small
    MLP, reshapes it to (batch, 1, 28, 28) and runs a strided convolutional
    network whose final ``AvgPool2d(8)`` reduces the map to 1x1. The output is
    a softmax over classes of shape (batch, n_classes).

    NOTE(review): the label being predicted is also an input, so the model can
    learn to read the answer from the embedding instead of the image — confirm
    this conditioning is intended.

    Args:
        n_classes: number of classes; also the label-embedding width and the
            number of output scores. Defaults to 10 (MNIST).
    """

    def __init__(self, n_classes=10):
        super().__init__()
        # Sized by an explicit argument rather than an external `opt` object,
        # which no visible cell defines, so the class is self-contained.
        self.label_emb = nn.Embedding(n_classes, n_classes)
        self.add_conditions = nn.Sequential(
            nn.Linear(28 * 28 + n_classes, 256), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 28 * 28), nn.LeakyReLU(0.2, inplace=True),
        )
        # Spatial size for a 28x28 input: 28 -> 26 -> 13 -> 11 -> 6 -> 6 -> 8 -> 8 -> 1.
        # padding=1 on the 1x1 conv grows 6x6 to 8x8 so AvgPool2d(8) ends at 1x1.
        self.net = nn.Sequential(
            nn.Conv2d(1, 48, kernel_size=3), nn.ReLU(),
            nn.Conv2d(48, 48, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(48, 96, kernel_size=3), nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(96, 96, kernel_size=1, stride=1, padding=1), nn.ReLU(),
            nn.Conv2d(96, n_classes, kernel_size=1),
            nn.AvgPool2d(kernel_size=8),
        )

    def forward(self, x, labels):
        """Return class probabilities for ``x`` conditioned on ``labels``.

        Args:
            x: images of shape (batch, 784) or (batch, 1, 28, 28) — anything
                whose non-batch dims flatten to 784 values.
            labels: integer class indices of shape (batch,).

        Returns:
            Softmax probabilities of shape (batch, n_classes); each row sums to 1.
        """
        batch = x.size(0)
        # Flatten so the image can be concatenated with the (B, n_classes) embedding.
        cond = torch.cat((self.label_emb(labels), x.reshape(batch, -1)), dim=1)
        # Conv2d needs an explicit channel axis: (B, 784) -> (B, 1, 28, 28).
        img = self.add_conditions(cond).view(batch, 1, 28, 28)
        # Run the conv stack exactly once; its output is (B, n_classes, 1, 1).
        scores = self.net(img).flatten(start_dim=1)
        # NOTE(review): these are probabilities, not logits — do not feed them to
        # nn.CrossEntropyLoss, which applies log-softmax itself.
        return F.softmax(scores, dim=1)

In [None]:
# Run configuration.
# NOTE(review): the meanings below are guessed from the names only — confirm them
# against the code that reads these values:
#   epochs  - number of training passes
#   epsilon - presumably the perturbation bound
#   alpha, cg - presumably loss-weighting coefficients
#   k       - presumably an inner-step count per iteration
epochs, epsilon, alpha, cg, k = 20, 0.1, 0.5, 0.5, 1