In [1]:
# import torch, random, numpy as np, os

# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [2]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # must be before torch import

In [3]:
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plts
import random

In [4]:
import torch

# Fix Python, NumPy, PyTorch RNGs
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Force deterministic behavior in cuDNN
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [5]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [6]:
from torchvision.datasets.mnist import *

def Tansfrm(im):
    im = np.array(im).astype(np.float32)
    im = (im -128)/255
    return torch.tensor(im, requires_grad=False).cuda()

dataset_first = MNIST("", train=True, transform= Tansfrm, download=True )

In [7]:
dataloader_first = DataLoader(dataset_first, batch_size=2048, shuffle=True)
print(dataloader_first.sampler.generator)
# <torch._C.Generator object at ...>

print(dataset_first.data.shape)


None
torch.Size([60000, 28, 28])


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input: (1, 28, 28)

        # Convolutional layers with spectral norm
        self.conv1 = spectral_norm(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1))   # -> (32, 28, 28)
        self.conv2 = spectral_norm(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1))  # -> (64, 28, 28)

        self.pool = nn.MaxPool2d(2, 2)  # halves each spatial dim

        self.conv3 = spectral_norm(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)) # -> (128, 14, 14)
        self.conv4 = spectral_norm(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)) # -> (256, 14, 14)

        # Fully connected layers with spectral norm
        self.fc1 = spectral_norm(nn.Linear(256 * 7 * 7, 512))
        self.fc2 = spectral_norm(nn.Linear(512, 128))
        self.fc3 = spectral_norm(nn.Linear(128, 1))  # critic output (no sigmoid)


    def forward(self, x):
        # Ensure shape (B, 1, 28, 28)
        x = x.view(-1, 1, 28, 28)

        x = F.leaky_relu(self.conv1(x), 0.2)
        x = self.pool(F.leaky_relu(self.conv2(x), 0.2))   # -> (64, 14, 14)

        x = F.leaky_relu(self.conv3(x), 0.2)
        x = self.pool(F.leaky_relu(self.conv4(x), 0.2))   # -> (256, 7, 7)

        x = x.view(-1, 256 * 7 * 7)  # flatten
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)

        x = self.fc3(x)   # final critic score (can be any real number)
        return x


In [9]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc = nn.Linear(5, 32*7*7)
        self.bn1 = nn.BatchNorm2d(16)
        self.deconv1 = nn.ConvTranspose2d(32, 16,kernel_size=4, stride = 2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(16, 1,kernel_size=4, stride = 2, padding=1)


    def forward(self, z):
        x = self.fc(z).view(-1, 32, 7, 7)
        x = F.relu(self.bn1(self.deconv1(x)))
        x = torch.tanh(self.deconv2(x))   # (batch, 1, 28, 28)
        return x.squeeze(1)               # (batch, 28, 28)


generator = Generator().cuda()
discriminator = Discriminator().cuda()

import torch.nn.init as init

# You can iterate through all modules and apply initialization
for m in generator.modules():
      if isinstance(m, nn.Linear):
          init.uniform_(m.weight, a=-0.3, b=0.3)
          # Optionally initialize biases to zero or a small constant
          if m.bias is not None:
              init.constant_(m.bias, 0)

for m in discriminator.modules():
      if isinstance(m, nn.Linear):
          init.uniform_(m.weight, a=-0.3, b=0.3)
          # Optionally initialize biases to zero or a small constant
          if m.bias is not None:
              init.constant_(m.bias, 0)


In [10]:
g_optimizer = torch.optim.SGD(generator.parameters(), lr=0.00001, momentum=0.7, nesterov=True)
d_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.00001, momentum=0.7, nesterov=True)


In [11]:
import matplotlib.pyplot as plt

def plot_images(fake_images, num = 5):

    for i in range(num):
        img_tensor = fake_images[i]
        img = img_tensor.permute(1, 2, 0).detach().cpu().numpy().astype(np.float64)

        # Normalize to [0,1]
        img = img*255+128
        img = torch.min(torch.tensor(255), torch.tensor(img))
        # Scale to [0,255] and convert to uint8
        img_uint8 = (img).numpy().astype(np.uint8)

        plt.imshow(img_uint8[:, :, ::-1])
        plt.axis("off")
        plt.show()

In [12]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [13]:
!ls /content/drive/MyDrive/gan_models/

discriminator_checkpoint_17540.pth  generator_checkpoint_epoch_17540.pth
discriminator_checkpoint_17560.pth  generator_checkpoint_epoch_17560.pth
discriminator_checkpoint_17580.pth  generator_checkpoint_epoch_17580.pth
discriminator_checkpoint_17600.pth  generator_checkpoint_epoch_17600.pth
discriminator_checkpoint_17620.pth  generator_checkpoint_epoch_17620.pth
discriminator_checkpoint_17640.pth  generator_checkpoint_epoch_17640.pth
discriminator_checkpoint_17660.pth  generator_checkpoint_epoch_17660.pth
discriminator_checkpoint_17680.pth  generator_checkpoint_epoch_17680.pth
discriminator_checkpoint_17700.pth  generator_checkpoint_epoch_17700.pth
discriminator_checkpoint_17720.pth  generator_checkpoint_epoch_17720.pth
discriminator_checkpoint_17740.pth  generator_checkpoint_epoch_17740.pth
discriminator_checkpoint_17760.pth  generator_checkpoint_epoch_17760.pth
discriminator_checkpoint_17780.pth  generator_checkpoint_epoch_17780.pth
discriminator_checkpoint_17800.pth  generator_check

In [14]:
# !ls /content/drive/MyDrive
checkpoint = torch.load("/content/drive/MyDrive/gan_models/generator_checkpoint_epoch_20740.pth", map_location="cpu", weights_only=False)
generator.load_state_dict(checkpoint["model_state_dict"], strict=True)
g_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])


# Restore RNGs
random.setstate(checkpoint["python_rng_state"])
np.random.set_state(checkpoint["numpy_rng_state"])
torch.set_rng_state(checkpoint["rng_state"])
torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])



checkpoint = torch.load("/content/drive/MyDrive/gan_models/discriminator_checkpoint_20740.pth", map_location="cpu", weights_only=False)
discriminator.load_state_dict(checkpoint["model_state_dict"], strict=True)
d_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [15]:
dataloader_iter = iter(dataloader_first)

In [16]:
def init_ema_model(model):
    """Make a deepcopy of model for EMA tracking."""
    ema_model = copy.deepcopy(model)
    for p in ema_model.parameters():
        p.requires_grad_(False)
    return ema_model

In [17]:
def update_ema_model(model, ema_model, decay=0.9):
    """Update EMA model in place."""
    with torch.no_grad():
        msd = model.state_dict()
        for k, ema_v in ema_model.state_dict().items():
            model_v = msd[k].detach()
            if torch.is_floating_point(ema_v):
                ema_v.mul_(decay).add_(model_v, alpha=1 - decay)
            else:
                ema_v.copy_(model_v)


In [18]:
def load_ema_to_model(model, ema_model):
    """Copy EMA model weights into the actual model."""
    with torch.no_grad():
        for p, ema_p in zip(model.state_dict().values(), ema_model.state_dict().values()):
            p.copy_(ema_p)

In [19]:
def dump_images(fake_images):
    for i in range(fake_images.shape[0]):
      img_tensor = fake_images[i]
      img = img_tensor.detach().cpu().numpy().astype(np.float32)

      # Normalize to [0,1]
      img = img*255+128
      img = torch.min(torch.tensor(255), torch.tensor(img))
      # Scale to [0,255] and convert to uint8
      img_uint8 = (img).numpy().astype(np.uint8)

      cv2.imwrite(f"/content/drive/MyDrive/dump_images/{i}.png", img_uint8)


In [None]:
import copy

epochs = 0

save_interval=20

checkpoint_dir = "/content/drive/MyDrive/gan_models"

#normalize_each_param(discriminator)

discriminator_ema = init_ema_model(discriminator).cuda()  # EMA model has same architecture


for epoch in range(20740, epochs+280004):


    checkpoint_dir = "/content/drive/MyDrive/gan_models"

    if (epoch)%save_interval == 0:

      checkpoint_path =f"{checkpoint_dir}/generator_checkpoint_epoch_{epoch}.pth"


      torch.save({
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': g_optimizer.state_dict(),
            "rng_state": torch.get_rng_state(),
            "numpy_rng_state": np.random.get_state(),
            "python_rng_state": random.getstate(),
            "cuda_rng_state": torch.cuda.get_rng_state()
        }, checkpoint_path)
      checkpoint_path = f"{checkpoint_dir}/discriminator_checkpoint_{epoch}.pth"
      torch.save({
            'epoch': epoch,
            'model_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': d_optimizer.state_dict()
        }, checkpoint_path)

      print(f"Model checkpoint saved to {checkpoint_path}")


    print("Epoch starting {}".format(epoch))
    steps = 0

    d_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.00001, momentum=0.7, nesterov=True)

    try:
        #discriminator_loss = torch.tensor(0.0).cuda()
        while True:
            steps += 1
            if steps > 20:
                break

            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            try:
              ims, _ = next(dataloader_iter)

            except StopIteration:
              dataloader_first = DataLoader(dataset_first, batch_size=2048, shuffle=True)
              dataloader_iter = iter(dataloader_first)
              ims, _ = next(dataloader_iter)


            real_pred = discriminator(ims)

            gauss = torch.randn(len(ims), 5, device="cuda")

            fake_images = generator(gauss)

            fake_images = torch.tensor(fake_images, requires_grad=False).cuda().detach()

            fake_images_labeled = fake_images
            fake_pred = discriminator(fake_images_labeled)


            temp_pred = discriminator(0.5*ims+0.5*fake_images_labeled)

            temp_pred2 = discriminator(0.25*ims+0.75*fake_images_labeled)

            ls = 10*(fake_pred-real_pred).mean() + 10*((temp_pred-0.5*real_pred-0.5*fake_pred).abs().mean())+ 10*((temp_pred2-0.25*real_pred-0.75*fake_pred).abs().mean())


            ls.backward()
            d_optimizer.step()

            print("Discriminator loss in class 0 images {:.7f} and step = {} fake_pred {:.7f} real_pred {:.7f} temp_pred {:.7f} 1q {:.7f} 3q {:.7f}".format(ls, steps, np.float64(fake_pred.mean().cpu()), np.float64(real_pred.mean().cpu()), np.float64(temp_pred.mean().cpu()), np.float64(temp_pred2.mean().cpu()), np.float64(discriminator(0.75*ims+0.25*fake_images_labeled).mean().cpu())))


    except Exception as e:
        print(e)
        pass

    update_ema_model(discriminator, discriminator_ema)
    load_ema_to_model(discriminator, discriminator_ema)


    #generator_ema = init_ema_model(generator).cuda()  # EMA model has same architecture

    steps = 0

    g_optimizer = torch.optim.SGD(generator.parameters(), lr=0.00001, momentum=0.7, nesterov=True)


    while True:
        steps += 1
        if steps > 40:
            break

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()


        gauss = torch.randn(2048, 5, device="cuda")

        fake_images = generator(gauss)


        fake_images_labeled = fake_images


        fake_pred = discriminator(fake_images_labeled)

        ls = -10*(fake_pred).mean()
        ls.backward()

            #normalize_each_param(discriminator)

        g_optimizer.step()
        print("Generator loss in class 0 images {:.7f} and step = {} ".format(ls, steps))



    #update_ema_model(generator, generator_ema, decay=0.85)
    #load_ema_to_model(generator, generator_ema)

    if (epoch)%save_interval==0:
        with torch.no_grad():

            gauss = torch.randn(100, 5, device="cuda")

            fake_images = generator(gauss)

            fake_images = torch.tensor(fake_images, requires_grad=False).cuda().detach()

            fake_images_labeled = fake_images
            dump_images(torch.clone(fake_images_labeled))




Model checkpoint saved to /content/drive/MyDrive/gan_models/discriminator_checkpoint_20740.pth
Epoch starting 20740


  fake_images = torch.tensor(fake_images, requires_grad=False).cuda().detach()
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print("Discriminator loss in class 0 images {:.7f} and step = {} fake_pred {:.7f} real_pred {:.7f} temp_pred {:.7f} 1q {:.7f} 3q {:.7f}".format(ls, steps, np.float64(fake_pred.mean().cpu()), np.float64(real_pred.mean().cpu()), np.float64(temp_pred.mean().cpu()), np.float64(temp_pred2.mean().cpu()), np.float64(discriminator(0.75*ims+0.25*fake_images_labeled).mean().cpu())))


Discriminator loss in class 0 images 191.5104065 and step = 1 fake_pred 62.9859009 real_pred 46.6182098 temp_pred 55.8993835 1q 59.7100258 3q -264.9158020
Discriminator loss in class 0 images -1379.3699951 and step = 2 fake_pred -370.1751709 real_pred -230.0538788 temp_pred -299.9184570 1q -334.7703857 3q -310.9025269
Discriminator loss in class 0 images -1552.3791504 and step = 3 fake_pred -429.3036804 real_pred -271.1779785 temp_pred -349.2899475 1q -388.9368896 3q -320.7620850
Discriminator loss in class 0 images -1575.7067871 and step = 4 fake_pred -441.8797607 real_pred -280.9449158 temp_pred -360.0115967 1q -400.2639465 3q -318.1065979
Discriminator loss in class 0 images -1576.9171143 and step = 5 fake_pred -439.0374146 real_pred -277.7919006 temp_pred -356.9012146 1q -397.2281494 3q -311.5064087
Discriminator loss in class 0 images -1570.5412598 and step = 6 fake_pred -432.8453369 real_pred -272.7234802 temp_pred -351.6740417 1q -391.7859192 3q -307.5366821
Discriminator loss i

  fake_images = torch.tensor(fake_images, requires_grad=False).cuda().detach()


Epoch starting 20741
Discriminator loss in class 0 images 202.7040863 and step = 1 fake_pred -15.7248688 real_pred -33.4686699 temp_pred -23.3707256 1q -19.5831528 3q 106.7780685
Discriminator loss in class 0 images -1250.5360107 and step = 2 fake_pred 4.8215561 real_pred 143.9996948 temp_pred 68.7688293 1q 31.2445469 3q 320.1502380
Discriminator loss in class 0 images -1611.2330322 and step = 3 fake_pred 190.2107697 real_pred 359.2837830 temp_pred 270.6653748 1q 228.6114197 3q 361.1512451
Discriminator loss in class 0 images -1698.7778320 and step = 4 fake_pred 227.1005249 real_pred 406.8980103 temp_pred 312.2838135 1q 266.8466797 3q 368.6935425
Discriminator loss in class 0 images -1697.6114502 and step = 5 fake_pred 229.3674622 real_pred 410.1645813 temp_pred 314.3421631 1q 268.9553223 3q 352.4256897
Discriminator loss in class 0 images -1719.0651855 and step = 6 fake_pred 217.1934509 real_pred 397.7823792 temp_pred 303.2144775 1q 257.9381714 3q 334.8404541
Discriminator loss in cla