In [3]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge

In [None]:
os.environ['CUDA_LAUNCH_BLOCKING']= "1"

In [2]:
import os

In [4]:
os.environ

environ{'ALLUSERSPROFILE': 'C:\\ProgramData',
        'APPDATA': 'C:\\Users\\kanan\\AppData\\Roaming',
        'ASL.LOG': 'Destination=file',
        'CHOCOLATEYINSTALL': 'C:\\ProgramData\\chocolatey',
        'CHOCOLATEYLASTPATHUPDATE': '132607965707908918',
        'COMMONPROGRAMFILES': 'C:\\Program Files\\Common Files',
        'COMMONPROGRAMFILES(X86)': 'C:\\Program Files (x86)\\Common Files',
        'COMMONPROGRAMW6432': 'C:\\Program Files\\Common Files',
        'COMPUTERNAME': 'KIRTAN',
        'COMSPEC': 'C:\\WINDOWS\\system32\\cmd.exe',
        'CONDA_DEFAULT_ENV': 'pytorch',
        'CONDA_EXE': 'C:\\Users\\kanan\\anaconda3\\Scripts\\conda.exe',
        'CONDA_PREFIX': 'C:\\Users\\kanan\\anaconda3\\envs\\pytorch',
        'CONDA_PREFIX_1': 'C:\\Users\\kanan\\anaconda3',
        'CONDA_PROMPT_MODIFIER': '(pytorch) ',
        'CONDA_PYTHON_EXE': 'C:\\Users\\kanan\\anaconda3\\python.exe',
        'CONDA_SHLVL': '2',
        'CUDA_LAUNCH_BLOCKING': '1',
        'CUDA_PATH': 'C:\

In [5]:
class Discriminator(nn.Module):
    def __init__(self,in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features,128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid(),
        )
    def forward(self,x):
        return self.disc(x)

In [6]:
class Generator(nn.Module):
    def __init__(self,z_dim,img_dims):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256,img_dims),
            nn.Tanh()
        )
    def forward(self,x):
        return self.gen(x)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [8]:
lr = 3e-4
z_dim=64
img_dims = 28*28*1
batch_size=32
num_epochs=50

In [9]:
disc= Discriminator(img_dims).to(device)
gen = Generator(z_dim,img_dims).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)
transforms= transforms.Compose(
        [transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))]
        )

In [10]:
dataset=datasets.MNIST(root="dataset/",transform= transforms,download= True)
loader= DataLoader(dataset,batch_size = batch_size,shuffle= True)
optim_disc=optim.Adam(disc.parameters(),lr=lr)
optim_gen = optim.Adam(gen.parameters(),lr=lr)
criterion = nn.BCELoss()

In [11]:
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")

In [15]:
steps=0

In [20]:
for epoch in range(num_epochs):
    for batch_index,(real,_) in enumerate(loader):
        
        real = real.reshape(-1,784).to(device)
        batch_size = real.shape[0]
        noise = torch.randn(batch_size,z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        loosD_real = criterion(disc_real,torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        loosD_fake = criterion(disc_fake,torch.zeros_like(disc_fake))
        loosD = (loosD_fake + loosD_real)/2
        disc.zero_grad()
        loosD.backward(retain_graph = True)
        optim_disc.step()
        output= disc(fake).view(-1)
        loosG= criterion(output,torch.ones_like(output))
        gen.zero_grad()
        loosG.backward()
        optim_gen.step()
    
        if batch_index == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_index}/{len(loader)} \
                Loss D : {loosD:.4f} , loosG: {loosG:.4f}"
            )
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1,1,28,28)
                data = real.reshape(-1,1,28,28)
                image_grid_fake = torchvision.utils.make_grid(fake,normalize=True)
                image_grid_real = torchvision.utils.make_grid(data,normalize = True)
                
                writer_fake.add_image("Mnist fake image" , image_grid_fake, global_step = steps)
                writer_real.add_image("Mnist real image" , image_grid_real, global_step = steps)
                
                steps +=1
                
                
        

Epoch [0/50] Batch 0/1875                 Loss D : 0.1005 , loosG: 3.4196
Epoch [1/50] Batch 0/1875                 Loss D : 0.1307 , loosG: 3.6059
Epoch [2/50] Batch 0/1875                 Loss D : 0.0442 , loosG: 3.5541
Epoch [3/50] Batch 0/1875                 Loss D : 0.0310 , loosG: 3.7329
Epoch [4/50] Batch 0/1875                 Loss D : 0.0132 , loosG: 5.1319
Epoch [5/50] Batch 0/1875                 Loss D : 0.0490 , loosG: 5.2856
Epoch [6/50] Batch 0/1875                 Loss D : 0.0378 , loosG: 4.4586
Epoch [7/50] Batch 0/1875                 Loss D : 0.1150 , loosG: 5.5717
Epoch [8/50] Batch 0/1875                 Loss D : 0.0207 , loosG: 4.6727
Epoch [9/50] Batch 0/1875                 Loss D : 0.0094 , loosG: 5.1724
Epoch [10/50] Batch 0/1875                 Loss D : 0.0063 , loosG: 5.6917
Epoch [11/50] Batch 0/1875                 Loss D : 0.0146 , loosG: 6.1333
Epoch [12/50] Batch 0/1875                 Loss D : 0.0117 , loosG: 4.6335
Epoch [13/50] Batch 0/1875         

In [21]:
print(torch.__version__)

1.8.1


In [22]:
!pip install -q tb-nightly

In [25]:
!tensorboard --logdir runs

^C


In [26]:
%load_ext tensorboard
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 21152), started 1:40:40 ago. (Use '!kill 21152' to kill it.)