<a href="https://colab.research.google.com/github/mou121/MOU/blob/master/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# List the attached NVIDIA GPU(s) with driver/CUDA versions and current memory use.
!nvidia-smi

Tue Apr 30 12:00:51 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   62C    P8              12W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [5]:
import torch,pdb
from torch.utils.data import DataLoader
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

In [6]:
#visualization function
def show(tensor, ch=1, size=(28,28), num=16, nrow=4):
  """Plot the first `num` images of a batch as a single grid.

  Args:
    tensor: batch of images; it is reshaped with .view(-1, ch, *size), so its
      element count must be a multiple of ch * size[0] * size[1]
      (e.g. flat (N, 784) generator output works for the 28x28 default).
    ch: channels per image.
    size: (height, width) of each image.
    num: maximum number of images to plot.
    nrow: number of images per grid row.
  """
  data = tensor.detach().cpu().view(-1, ch, *size)
  # make_grid's keyword is `nrow`; `nrows` is not a make_grid parameter.
  grid = make_grid(data[:num], nrow=nrow)
  # make_grid returns (C, H, W); imshow expects (H, W) or (H, W, C).
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [7]:
# Training hyper-parameters and shared state used by the cells below.
epochs = 500
cur_step = 0
info_step = 300        # presumably: log/visualize every `info_step` steps — TODO confirm in training loop
mean_gen_loss = 0
mean_disc_loss = 0
z_dim = 64             # length of the generator's input noise vector
lr = 0.00001
# Expects raw logits; the discriminator's last layer is a plain nn.Linear.
loss_func = nn.BCEWithLogitsLoss()
bs = 128
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ToTensor() yields float image tensors in [0, 1], matching the generator's Sigmoid range.
dataloader = DataLoader(MNIST('.', download=True, transform=transforms.ToTensor()), shuffle=True, batch_size=bs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [05:02<00:00, 32795.79it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 133500.66it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:50<00:00, 32766.41it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5206485.04it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw






In [19]:
#declare our models
#Generator
def genBlock(inp,out):
  """Build one generator stage: Linear(inp -> out), BatchNorm1d(out), in-place ReLU."""
  layers = [
      nn.Linear(inp, out),
      nn.BatchNorm1d(out),
      nn.ReLU(inplace=True),
  ]
  return nn.Sequential(*layers)

class Generator(nn.Module):
  """MLP generator mapping z_dim-long noise vectors to i_dim-long outputs in (0, 1).

  Four genBlock stages widen z_dim -> h -> 2h -> 4h -> 8h (h = h_dim); a final
  Linear projects to i_dim and a Sigmoid squashes every output element.
  """
  def __init__(self,z_dim=64,i_dim=784,h_dim=128):
    super().__init__()
    widths = [z_dim] + [h_dim * (2 ** k) for k in range(4)]
    stages = [genBlock(a, b) for a, b in zip(widths[:-1], widths[1:])]
    self.gen = nn.Sequential(*stages, nn.Linear(widths[-1], i_dim), nn.Sigmoid())

  def forward(self,noise):
    """Run a batch of noise vectors through the generator stack."""
    return self.gen(noise)

def gen_noise(number,z_dim):
  """Sample a (number, z_dim) standard-normal batch and move it to the global `device`."""
  samples = torch.randn(number, z_dim)
  return samples.to(device)

In [9]:
def discBlock(inp,out):
  """Build one discriminator stage: Linear(inp -> out) followed by LeakyReLU(0.2)."""
  linear = nn.Linear(inp, out)
  activation = nn.LeakyReLU(negative_slope=0.2)
  return nn.Sequential(linear, activation)

class Discriminator(nn.Module):
  """MLP discriminator producing one raw logit per input sample.

  Three discBlock stages narrow i_dim -> 4h -> 2h -> h (h = h_dim), then a
  Linear maps to a single unnormalised score (no sigmoid), so it should be
  paired with a logits-based criterion such as nn.BCEWithLogitsLoss.
  """
  def __init__(self, i_dim=784,h_dim=256):
    super().__init__()
    self.disc=nn.Sequential(discBlock(i_dim,h_dim*4),
                            discBlock(h_dim*4,h_dim*2),
                            discBlock(h_dim*2,h_dim),
                            nn.Linear(h_dim,1))

  # forward must be a method at class level; when nested inside __init__ it was
  # only a local function and nn.Module raised NotImplementedError on call.
  def forward(self,image):
    """Return logits with last dimension 1 for inputs whose last dimension is i_dim."""
    return self.disc(image)


In [15]:
# Build both networks on `device` (generator first), then one Adam optimizer per network.
gen = Generator(z_dim).to(device)
disc = Discriminator().to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

In [16]:
# Pull one batch to sanity-check the shapes and labels the dataloader yields.
x, y = next(iter(dataloader))
print(*(t.shape for t in (x, y)))
print(y[0:10])

torch.Size([128, 1, 28, 28]) torch.Size([128])
tensor([6, 7, 6, 7, 6, 5, 2, 0, 0, 8])


In [20]:
def calc_gen_loss(loss_func,gen,disc,number,z_dim):
  """Compute the generator's adversarial loss on a fresh batch of fakes.

  Samples `number` noise vectors, generates fakes, scores them with `disc`
  and compares the scores to all-ones targets, so the generator is rewarded
  when the discriminator labels its fakes as real.

  Args:
    loss_func: criterion called as loss_func(predictions, targets).
    gen: generator network.
    disc: discriminator network.
    number: how many fake samples to generate.
    z_dim: length of each noise vector.

  Returns:
    The loss tensor produced by loss_func.
  """
  noise = gen_noise(number, z_dim)
  fake = gen(noise)
  # No .detach(): gradients must flow back through disc into gen's parameters.
  pred = disc(fake)
  targets = torch.ones_like(pred)
  gen_loss = loss_func(pred, targets)
  return gen_loss

def calc_disc_loss(loss_func,gen,disc,number,real,z_dim):
  """Compute the discriminator loss on one batch of generated and real samples.

  Fakes are scored against all-zeros targets and reals against all-ones
  targets; the two losses are averaged.

  Args:
    loss_func: criterion called as loss_func(predictions, targets).
    gen: generator network.
    disc: discriminator network.
    number: how many fake samples to generate.
    real: batch of real samples in the layout `disc` accepts.
      NOTE(review): the MNIST loader yields (bs, 1, 28, 28), so callers
      presumably flatten to (bs, 784) first — confirm in the training loop.
    z_dim: length of each noise vector.

  Returns:
    Mean of the fake-sample and real-sample losses.
  """
  noise = gen_noise(number, z_dim)
  fake = gen(noise)
  # Detach so this loss does not backpropagate into the generator.
  pred_fake = disc(fake.detach())
  targets_fake = torch.zeros_like(pred_fake)  # `zeroes_like` does not exist in torch
  disc_fake_loss = loss_func(pred_fake, targets_fake)
  pred_real = disc(real)
  targets_real = torch.ones_like(pred_real)
  disc_real_loss = loss_func(pred_real, targets_real)
  disc_loss = (disc_fake_loss + disc_real_loss) / 2
  return disc_loss


