# NFT GAN

What do the most expensive, most sought-after NFTs have in common? The truth is, we don't know — but with the help of GANs (generative adversarial networks), we might be able to find out.

In [None]:
import torch as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as Transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from ai import *
from networks import MyGAN
from tqdm import tqdm
from torchvision.utils import save_image
import os

### Hyperparameters

In [None]:
# Training images are resized and center-cropped to image_size x image_size;
# batch_size images per batch; latent_size is the generator's input dimension.
image_size, batch_size, latent_size = 256, 64, 256
# Per-channel (mean, std) for Normalize: maps ToTensor's [0, 1] RGB to [-1, 1].
stats = ((0.5,) * 3, (0.5,) * 3)

### Load Dataset

In [None]:
# Load every image under dataset/ (one sub-folder per class, labels unused)
# and preprocess it into a normalized square tensor.
train_ds = ImageFolder(
    "dataset/",
    transform=Transforms.Compose(
        [
            Transforms.Resize(image_size),       # shorter side -> image_size
            Transforms.CenterCrop(image_size),   # square image_size x image_size crop
            Transforms.ToTensor(),               # PIL image -> float tensor in [0, 1]
            Transforms.Normalize(mean=stats[0], std=stats[1]),  # (x - mean) / std
        ]
    ),
)

# Shuffled mini-batches, loaded by 3 worker processes into pinned memory.
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)

In [None]:
def denorm(img_tensors):
    """Undo ``Normalize(*stats)``: map normalized images back to [0, 1].

    Computes ``x * std + mean`` per channel, using the (mean, std) pair in the
    module-level ``stats``. The statistics are broadcast over the last three
    dimensions, so both ``(C, H, W)`` and ``(B, C, H, W)`` inputs work. The
    helper tensors are created on the input's device and with its dtype.
    """
    mean = img_tensors.new_tensor(stats[0]).view(-1, 1, 1)
    std = img_tensors.new_tensor(stats[1]).view(-1, 1, 1)
    return img_tensors * std + mean

def show_images(images, nmax=64, nrow=8):
    """Plot up to ``nmax`` normalized images as a grid, ``nrow`` images per row.

    ``images`` may require grad or live on a GPU. It is detached and copied to
    the CPU before plotting, because matplotlib can only read CPU data.
    """
    _, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(denorm(images.detach()[:nmax].cpu()), nrow=nrow)
    # clamp: imshow clips float RGB to [0, 1] anyway; doing it here avoids its warning.
    ax.imshow(grid.permute(1, 2, 0).clamp(0, 1))

def show_batch(dl, nmax=64):
    """Plot the images of the first batch yielded by ``dl`` (labels ignored).

    Does nothing if ``dl`` yields no batches.
    """
    batch = next(iter(dl), None)
    if batch is not None:
        images, _ = batch
        show_images(images, nmax)

In [None]:
# Preview one batch of real, preprocessed training images.
show_batch(dl=train_dl, nmax=64)

In [None]:
gan_model = MyGAN(latent_size, image_size, batch_size)

# Smoke test: push one batch of random latent vectors through the untrained
# generator and look at the output shape and images.
# no_grad() keeps the global `fake_images` from holding the generator's autograd
# graph (and its intermediate activations) for the rest of the session.
with T.no_grad():
    xb = T.randn(batch_size, latent_size, 1, 1)  # random latent tensors
    fake_images = gan_model.generator(xb)
print(fake_images.shape)
show_images(fake_images)

In [None]:
# Output folder for the sample grids written by save_samples (created if missing).
sample_dir = "generated"
os.makedirs(name=sample_dir, exist_ok=True)

In [None]:
def save_samples(index, latent_tensors, show=True):
    """Generate images from ``latent_tensors`` and save them as one PNG grid.

    Args:
        index: Integer put in the file name, zero-padded to 4 digits
            (e.g. ``generated-images-0007.png``), inside ``sample_dir``.
        latent_tensors: Latent batch fed to ``gan_model.generator``.
        show: If True, also display the grid inline.
    """
    # Sampling only: no_grad avoids building an autograd graph that is never used.
    with T.no_grad():
        fake_images = gan_model.generator(latent_tensors)
    fake_fname = "generated-images-{0:0=4d}.png".format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print("Saving", fake_fname)
    if show:
        # show_images denormalizes to [0, 1] just like the saved PNG; plotting
        # raw generator output here would show the wrong value range.
        show_images(fake_images.cpu(), nmax=len(fake_images))

Move the model and the training data loader to the compute device.

In [None]:
# Choose the compute device with the project helper get_default_device
# (presumably CUDA when available, otherwise CPU — confirm in the `ai` module),
# then move the model onto it.
device = get_default_device()
gan_model.to_device(device)
# NOTE(review): DeviceDataLoader is assumed to wrap train_dl and move each batch
# to `device` as it is iterated — verify against its definition in `ai`.
train_dl = DeviceDataLoader(train_dl, device)

In [None]:
# Training schedule: passes over the dataset, the Adam learning rate shared by
# both optimizers, and the file index used for the first epoch's sample grid.
epochs, lr, start_idx = 200, 4e-4, 1

In [None]:
# One latent batch sampled once and reused every epoch, so the saved grids are comparable.
fixed_latent = T.randn((64, latent_size, 1, 1), device=gan_model.device)

In [None]:
# Save (and display) the generator's output before training as index 0.
save_samples(index=0, latent_tensors=fixed_latent, show=True)

### Train GAN

In [None]:
T.cuda.empty_cache()

print(f"Discriminator parameters count: {count_parameters(gan_model.discriminator)}")
print(f"Generator parameters count: {count_parameters(gan_model.generator)}")

# Per-epoch history. Each entry is the mean over all batches of that epoch,
# which is much less noisy than the value of the last batch alone.
losses_g = []
losses_d = []
real_scores = []
fake_scores = []

# Create optimizers (beta1 = 0.5 instead of Adam's default 0.9)
opt_d = T.optim.Adam(gan_model.discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_g = T.optim.Adam(gan_model.generator.parameters(), lr=lr, betas=(0.5, 0.999))

for epoch in range(epochs):
    # Running sums for this epoch. float() accepts Python numbers and 0-dim
    # tensors alike, and keeps no tensor (or autograd graph) alive.
    sum_g = sum_d = sum_real = sum_fake = 0.0
    n_batches = 0

    for real_images, _ in tqdm(train_dl):
        # Train discriminator
        loss_d, real_score, fake_score = gan_model.train_discriminator(real_images, opt_d)
        # Train generator
        loss_g = gan_model.train_generator(opt_g)

        sum_d += float(loss_d)
        sum_g += float(loss_g)
        sum_real += float(real_score)
        sum_fake += float(fake_score)
        n_batches += 1

    if n_batches == 0:
        # Fail with a clear message instead of dividing by zero below.
        raise RuntimeError("train_dl produced no batches; is the dataset empty?")

    # Record mean losses & scores for this epoch
    losses_g.append(sum_g / n_batches)
    losses_d.append(sum_d / n_batches)
    real_scores.append(sum_real / n_batches)
    fake_scores.append(sum_fake / n_batches)

    # Log losses & scores (epoch means)
    print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
        epoch+1, epochs, losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1]))

    # Save generated images (index start_idx for the first epoch)
    save_samples(epoch+start_idx, fixed_latent, show=False)

history = losses_g, losses_d, real_scores, fake_scores