In [1]:
import torch, random, numpy as np, os

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

In [2]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # must be before torch import

In [3]:
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plts
import random

In [4]:
import torch

# Fix Python, NumPy, PyTorch RNGs
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Force deterministic behavior in cuDNN
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [5]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [6]:
from torchvision.datasets.mnist import *

def Tansfrm(im):
    im = np.array(im).astype(np.float32)
    im = (im -128)/255
    return torch.tensor(im, requires_grad=False).cuda()

dataset_first = MNIST("", train=True, transform= Tansfrm, download=True )

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 499kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.59MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.85MB/s]


In [7]:
dataloader_first = DataLoader(dataset_first, batch_size=2048, shuffle=True)
print(dataloader_first.sampler.generator)
# <torch._C.Generator object at ...>

print(dataset_first.data.shape)


None
torch.Size([60000, 28, 28])


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Input: (1, 28, 28)

        # Convolutional layers with spectral norm
        self.conv1 = spectral_norm(nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1))   # -> (32, 28, 28)
        self.conv2 = spectral_norm(nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1))  # -> (64, 28, 28)

        self.pool = nn.MaxPool2d(2, 2)  # halves each spatial dim

        self.conv3 = spectral_norm(nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)) # -> (128, 14, 14)
        self.conv4 = spectral_norm(nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)) # -> (256, 14, 14)

        # Fully connected layers with spectral norm
        self.fc1 = spectral_norm(nn.Linear((256+64) * 7 * 7, 512))
        self.fc2 = spectral_norm(nn.Linear(512, 128))
        self.fc3 = spectral_norm(nn.Linear(128, 1))  # critic output (no sigmoid)


    def forward(self, x0):
        # Ensure shape (B, 1, 28, 28)
        x0 = x0.view(-1, 1, 28, 28)

        x1 = F.relu(self.conv1(x0), 0.2)
        x2 = self.pool(F.relu(self.conv2(x1), 0.2))   # -> (64, 14, 14)

        x3 = F.relu(self.conv3(x2+F.interpolate(x0, scale_factor=0.5, mode='bilinear', align_corners=False)), 0.2)
        x4 = self.pool(F.relu(self.conv4(x3), 0.2))   # -> (256, 7, 7)

        x = torch.cat([x4,F.interpolate(x2, scale_factor=0.5, mode='bilinear', align_corners=False)], dim=1)
        x = x.view(-1, (256+64) * 7 * 7)  # flatten
        x_p = F.relu(self.fc1(x), 0.2)
        x_pp = F.relu(self.fc2(x_p), 0.2)

        x_t = self.fc3(x_pp)   # final critic score (can be any real number)
        return x_t


In [9]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.fc = nn.Linear(5, 32*7*7)
        self.bn1 = nn.BatchNorm2d(16)
        self.deconv1 = nn.ConvTranspose2d(32, 16,kernel_size=4, stride = 2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(16, 8,kernel_size=3, stride = 1, padding=1)
        self.deconv3 = nn.ConvTranspose2d(8, 4,kernel_size=4, stride = 2, padding=1)
        self.deconv4 = nn.ConvTranspose2d(4, 1,kernel_size=3, stride = 1, padding=1)
        self.bn2 = nn.BatchNorm2d(8)
        self.bn3 = nn.BatchNorm2d(4)


    def forward(self, z):
        x = self.fc(z).view(-1, 32, 7, 7)
        x = F.relu(self.bn1(self.deconv1(x)))
        x = F.relu(self.bn2(self.deconv2(x)))
        x = F.relu(self.bn3(self.deconv3(x)))
        x = torch.tanh(self.deconv4(x))   # (batch, 1, 28, 28)
        return x.squeeze(1)               # (batch, 28, 28)


generator = Generator().cuda()
discriminator = Discriminator().cuda()

import torch.nn.init as init

# You can iterate through all modules and apply initialization
for m in generator.modules():
      if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
          init.xavier_uniform_(m.weight)
          # Optionally initialize biases to zero or a small constant
          if m.bias is not None:
              init.constant_(m.bias, 0)

for m in discriminator.modules():
      if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
          init.xavier_uniform_(m.weight)
          # Optionally initialize biases to zero or a small constant
          if m.bias is not None:
              init.constant_(m.bias, 0)
              m.bias.requires_grad_(False)


In [10]:
g_optimizer = torch.optim.SGD(generator.parameters(), lr=0.0003, momentum=0.7, nesterov=True)
d_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.0003, momentum=0.7, nesterov=True)


In [11]:
import matplotlib.pyplot as plt

def plot_images(fake_images, num = 5):

    for i in range(num):
        img_tensor = fake_images[i]
        img = img_tensor.permute(1, 2, 0).detach().cpu().numpy().astype(np.float64)

        # Normalize to [0,1]
        img = img*255+128
        img = torch.min(torch.tensor(255), torch.tensor(img))
        # Scale to [0,255] and convert to uint8
        img_uint8 = (img).numpy().astype(np.uint8)

        plt.imshow(img_uint8[:, :, ::-1])
        plt.axis("off")
        plt.show()

In [12]:
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [13]:
checkpoint = torch.load("/content/drive/MyDrive/gan_models/generator_checkpoint_epoch_48800.pth", map_location="cpu", weights_only=False)
generator.load_state_dict(checkpoint["model_state_dict"], strict=True)
g_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])


# Restore RNGs
random.setstate(checkpoint["python_rng_state"])
np.random.set_state(checkpoint["numpy_rng_state"])
torch.set_rng_state(checkpoint["rng_state"])
torch.cuda.set_rng_state(checkpoint["cuda_rng_state"])



checkpoint = torch.load("/content/drive/MyDrive/gan_models/discriminator_checkpoint_48800.pth", map_location="cpu", weights_only=False)
discriminator.load_state_dict(checkpoint["model_state_dict"], strict=True)
d_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

In [14]:
dataloader_iter = iter(dataloader_first)

In [15]:
def init_ema_model(model):
    """Make a deepcopy of model for EMA tracking."""
    ema_model = copy.deepcopy(model)
    for p in ema_model.parameters():
        p.requires_grad_(False)
    return ema_model

In [16]:
def update_ema_model(model, ema_model, decay=0.75):
    """Update EMA model in place."""
    with torch.no_grad():
        msd = model.state_dict()
        for k, ema_v in ema_model.state_dict().items():
            model_v = msd[k].detach()
            if torch.is_floating_point(ema_v):
                ema_v.mul_(decay).add_(model_v, alpha=1 - decay)
            else:
                ema_v.copy_(model_v)


In [17]:
def load_ema_to_model(model, ema_model):
    """Copy EMA model weights into the actual model."""
    with torch.no_grad():
        for p, ema_p in zip(model.state_dict().values(), ema_model.state_dict().values()):
            p.copy_(ema_p)

In [18]:
def dump_images(fake_images):
    for i in range(fake_images.shape[0]):
      img_tensor = fake_images[i]
      img = img_tensor.detach().cpu().numpy().astype(np.float32)

      # Normalize to [0,1]
      img = img*255+128
      img = torch.min(torch.tensor(255), torch.tensor(img))
      # Scale to [0,255] and convert to uint8
      img_uint8 = (img).numpy().astype(np.uint8)

      cv2.imwrite(f"/content/drive/MyDrive/dump_images/{i}.png", img_uint8)


In [None]:
import copy
import random

epochs = 0

save_interval=20

checkpoint_dir = "/content/drive/MyDrive/gan_models"

#normalize_each_param(discriminator)

#discriminator_ema = init_ema_model(discriminator).cuda()  # EMA model has same architecture


for epoch in range(48803, epochs+280004):


    checkpoint_dir = "/content/drive/MyDrive/gan_models"

    if (epoch)%save_interval == 0:

      checkpoint_path =f"{checkpoint_dir}/generator_checkpoint_epoch_{epoch}.pth"


      torch.save({
            'epoch': epoch,
            'model_state_dict': generator.state_dict(),
            'optimizer_state_dict': g_optimizer.state_dict(),
            "rng_state": torch.get_rng_state(),
            "numpy_rng_state": np.random.get_state(),
            "python_rng_state": random.getstate(),
            "cuda_rng_state": torch.cuda.get_rng_state()
        }, checkpoint_path)
      checkpoint_path = f"{checkpoint_dir}/discriminator_checkpoint_{epoch}.pth"
      torch.save({
            'epoch': epoch,
            'model_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': d_optimizer.state_dict()
        }, checkpoint_path)

      print(f"Model checkpoint saved to {checkpoint_path}")


    print("Epoch starting {}".format(epoch))
    steps = 0

    d_optimizer = torch.optim.SGD(discriminator.parameters(), lr=0.0003, momentum=0.7, nesterov=True)

    try:
        #discriminator_loss = torch.tensor(0.0).cuda()
        while True:
            steps += 1
            if steps > 4:
                break

            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            try:
              ims, _ = next(dataloader_iter)

            except StopIteration:
              dataloader_first = DataLoader(dataset_first, batch_size=2048, shuffle=True)
              dataloader_iter = iter(dataloader_first)
              ims, _ = next(dataloader_iter)


            real_pred = discriminator(ims)


            gauss = torch.randn(len(ims), 5, device="cuda")

            fake_images = generator(gauss)

            fake_images_labeled = torch.tensor(fake_images, requires_grad=False).cuda().detach()

            fake_pred = discriminator(fake_images_labeled)

            temp_pred = discriminator(0.5*ims+0.5*fake_images_labeled)

            #temp_pred2 = discriminator(0.25*ims+0.75*fake_images_labeled)

            #temp_pred3 = discriminator(0.75*ims+0.25*fake_images_labeled)

            ls = 10*(fake_pred-real_pred).mean() + 10*((temp_pred-0.5*real_pred-0.5*fake_pred).abs().mean())#+ 10*((temp_pred2-0.25*real_pred-0.75*fake_pred).abs().mean()) + 10*((temp_pred3-0.75*real_pred-0.25*fake_pred).abs().mean())


            for _ in range(4):

              alph = random.uniform(0.0, 1.0)
              temp_pred = discriminator(alph*ims+(1-alph)*fake_images_labeled)
              ls += 10*((temp_pred-alph*real_pred-(1-alph)*fake_pred).abs().mean())



            ls.backward()
            d_optimizer.step()

            print("Discriminator loss in class 0 images {:.7f} and step = {} fake_pred {:.7f} real_pred {:.7f}".format(ls, steps, np.float64(fake_pred.mean().cpu()), np.float64(real_pred.mean().cpu())))

            #print("Discriminator loss in class 0 images {:.7f} and step = {} fake_pred {:.7f} real_pred {:.7f} temp_pred {:.7f} 1q {:.7f} 3q {:.7f}".format(ls, steps, np.float64(fake_pred.mean().cpu()), np.float64(real_pred.mean().cpu()), np.float64(temp_pred.mean().cpu()), np.float64(temp_pred2.mean().cpu()), np.float64(temp_pred3.mean().cpu())))


    except Exception as e:
        print(e)
        pass

    #update_ema_model(discriminator, discriminator_ema)
    #load_ema_to_model(discriminator, discriminator_ema)


    #generator_ema = init_ema_model(generator).cuda()  # EMA model has same architecture
    g_optimizer = torch.optim.SGD(generator.parameters(), lr=0.0003, momentum=0.7, nesterov=True)


    steps = 0

    while True:
        steps += 1
        if steps > 8:
            break

        d_optimizer.zero_grad()
        g_optimizer.zero_grad()


        gauss = torch.randn(2048, 5, device="cuda")

        fake_images_labeled = generator(gauss)

        fake_pred = discriminator(fake_images_labeled)

        ls = -10*(fake_pred).mean()
        ls.backward()

            #normalize_each_param(discriminator)

        g_optimizer.step()
        print("Generator loss in class 0 images {:.7f} and step = {} ".format(ls, steps))



    #update_ema_model(generator, generator_ema, decay=0.85)
    #load_ema_to_model(generator, generator_ema)

    if (epoch)%save_interval==0:
        with torch.no_grad():

            gauss = torch.randn(100, 5, device="cuda")

            fake_images = generator(gauss)

            fake_images = torch.tensor(fake_images, requires_grad=False).cuda().detach()

            fake_images_labeled = fake_images
            dump_images(torch.clone(fake_images_labeled))




Epoch starting 48803


  fake_images_labeled = torch.tensor(fake_images, requires_grad=False).cuda().detach()
Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  print("Discriminator loss in class 0 images {:.7f} and step = {} fake_pred {:.7f} real_pred {:.7f}".format(ls, steps, np.float64(fake_pred.mean().cpu()), np.float64(real_pred.mean().cpu())))


Discriminator loss in class 0 images -0.4267675 and step = 1 fake_pred -0.3976799 real_pred -0.1845406
Discriminator loss in class 0 images -1.1502283 and step = 2 fake_pred -0.1472352 real_pred 0.1437779
Discriminator loss in class 0 images -2.0894594 and step = 3 fake_pred 0.1470234 real_pred 0.5205612
Discriminator loss in class 0 images -0.5030944 and step = 4 fake_pred 0.4584092 real_pred 0.9309561
Generator loss in class 0 images -5.6629748 and step = 1 
Generator loss in class 0 images -6.0379400 and step = 2 
Generator loss in class 0 images -6.4438534 and step = 3 
Generator loss in class 0 images -6.7314048 and step = 4 
Generator loss in class 0 images -6.9368954 and step = 5 
Generator loss in class 0 images -7.0855474 and step = 6 
Generator loss in class 0 images -7.2283678 and step = 7 
Generator loss in class 0 images -7.3602991 and step = 8 
Epoch starting 48804
Discriminator loss in class 0 images 0.0100095 and step = 1 fake_pred 0.7492102 real_pred 1.0200742
Discrimi

  fake_images = torch.tensor(fake_images, requires_grad=False).cuda().detach()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Discriminator loss in class 0 images -2.4770975 and step = 1 fake_pred -0.0818309 real_pred 0.5266601
Discriminator loss in class 0 images -2.2653785 and step = 2 fake_pred 0.0125731 real_pred 0.6682600
Discriminator loss in class 0 images -2.0029373 and step = 3 fake_pred 0.0757670 real_pred 0.7538804
Discriminator loss in class 0 images -3.5527058 and step = 4 fake_pred 0.0920933 real_pred 0.7602668
Generator loss in class 0 images -2.6399236 and step = 1 
Generator loss in class 0 images -2.7729478 and step = 2 
Generator loss in class 0 images -2.7840190 and step = 3 
Generator loss in class 0 images -2.8979781 and step = 4 
Generator loss in class 0 images -3.0018904 and step = 5 
Generator loss in class 0 images -3.1070347 and step = 6 
Generator loss in class 0 images -3.2021480 and step = 7 
Generator loss in class 0 images -3.3288980 and step = 8 
Epoch starting 49130
Discriminator loss in class 0 images -1.40411