In [None]:
import torch
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
import random
import torchvision.utils as vutils
from custom_nets.dcgan import Generator
import os
from scipy.interpolate import interp1d

In [None]:
def display_results(name: str, nz, ngf, nc, device):
    """Plot the logged training curves of the DCGAN run identified by ``name``.

    Four text files are read (all of them before anything is drawn):
    ``../loss/dcgan_netD_{name}.txt``, ``../loss/dcgan_netG_{name}.txt``,
    ``../mean_out/dcgan_netD_real_{name}.txt`` and
    ``../mean_out/dcgan_netD_fake_{name}.txt``.

    Two figures are then shown, each with "Iterations" on the x axis:
      1. generator ("G") and discriminator ("D") loss;
      2. discriminator mean scores for the "Real" and "Fake" series.

    ``nz``, ``ngf``, ``nc`` and ``device`` are accepted so every helper in this
    notebook shares one call signature; they are not used here.
    """
    # Load everything up front so a missing file fails before any figure is shown.
    disc_loss = np.loadtxt(f'../loss/dcgan_netD_{name}.txt')
    gen_loss = np.loadtxt(f'../loss/dcgan_netG_{name}.txt')
    real_scores = np.loadtxt(f'../mean_out/dcgan_netD_real_{name}.txt')
    fake_scores = np.loadtxt(f'../mean_out/dcgan_netD_fake_{name}.txt')

    # (title, y-axis label, [(series, legend label), ...]) for each figure.
    figure_specs = [
        ("Generator and Discriminator Loss During Training", "Loss",
         [(gen_loss, "G"), (disc_loss, "D")]),
        ("Discriminator Mean Scores During Training", "Mean",
         [(real_scores, "Real"), (fake_scores, "Fake")]),
    ]

    for title, y_label, curves in figure_specs:
        _, ax = plt.subplots(figsize=(10, 5))
        ax.set_title(title)
        for series, label in curves:
            ax.plot(series, label=label)
        ax.set_xlabel("Iterations")
        ax.set_ylabel(y_label)
        ax.legend()
        plt.show()

In [None]:
def save_fake_imgs(name: str, nz, ngf, nc, device, n_imgs: int = 4000, seed: int = 42):
    """Sample images from a trained DCGAN generator and save each one as a PNG.

    Generator weights are read from ``../nets/dcgan_netG_{name}``. The images
    are written to ``../dcgan_fake_imgs/dcgan_fake_{name}_{i}.png`` for
    ``i`` in ``0 .. n_imgs-1``; the directory is created if needed.

    Args:
        name: run suffix used in the checkpoint and output file names.
        nz: latent size; the noise fed to the generator has shape
            ``(n_imgs, nz, 1, 1)``.
        ngf: passed through to ``Generator``.
        nc: passed through to ``Generator``.
        device: torch device for the generator and the noise.
        n_imgs: number of images to generate (default 4000, the original count).
        seed: seed for Python's ``random`` and for torch. A fixed seed
            reproduces the same noise across runs (default 42).
    """
    random.seed(seed)
    torch.manual_seed(seed)

    # NOTE(review): `ngpu` is a notebook global defined in the configuration
    # cell, so that cell must run before this function is called.
    netG = Generator(ngpu, nz, ngf, nc).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # (without it torch.load tries to restore tensors onto the saving device).
    netG.load_state_dict(torch.load(f'../nets/dcgan_netG_{name}', map_location=device))
    # NOTE(review): the generator stays in training mode, as before. If it
    # contains BatchNorm layers they normalise with this batch's statistics;
    # call netG.eval() if running statistics are wanted. Confirm the intent.

    fixed_noise = torch.randn(n_imgs, nz, 1, 1, device=device)
    with torch.no_grad():  # no autograd graph is built, so no detach() is needed
        fake = netG(fixed_noise).cpu()

    os.makedirs('../dcgan_fake_imgs', exist_ok=True)

    for i, img in enumerate(fake):
        # NOTE(review): save_image is called without normalize, so values
        # outside [0, 1] are clamped. If Generator ends in tanh (outputs in
        # [-1, 1]), pass normalize=True, value_range=(-1, 1). Confirm against
        # custom_nets.dcgan.
        save_image(img, f'../dcgan_fake_imgs/dcgan_fake_{name}_{i}.png')

In [None]:
def get_save_interp_vectors(net_name, nz, device, n_steps: int = 10, out_dir: str = '../interpol'):
    """Draw two random latent vectors and linearly interpolate between them.

    Two vectors of length ``nz`` are sampled with ``torch.randn`` on ``device``.
    ``n_steps`` evenly spaced points are taken on the straight line between
    them: row 0 is the start vector and the last row is the end vector, up to
    floating-point rounding. The points are written to
    ``{out_dir}/{net_name}.csv`` (one row per step; the directory is created
    if needed) and returned.

    Args:
        net_name: base name of the CSV file.
        nz: length of each latent vector.
        device: torch device used for sampling and for the returned tensor.
        n_steps: number of interpolation points, including both endpoints
            (default 10).
        out_dir: directory for the CSV (default ``'../interpol'``).

    Returns:
        Tensor of shape ``(n_steps, nz)`` on ``device``.

    Raises:
        ValueError: if ``n_steps < 2`` (a line needs two distinct knots).
    """
    if n_steps < 2:
        raise ValueError(f'n_steps must be at least 2, got {n_steps}')

    os.makedirs(out_dir, exist_ok=True)

    start = torch.randn(nz, device=device)
    end = torch.randn(nz, device=device)

    # scipy operates on host-memory NumPy arrays. Passing CUDA tensors straight
    # to interp1d raised, so the endpoints are moved to the CPU first.
    endpoints = torch.vstack([start, end]).cpu().numpy()
    linfit = interp1d([1, n_steps], endpoints, axis=0)
    interp_vectors = np.stack([linfit(i) for i in range(1, n_steps + 1)])

    np.savetxt(os.path.join(out_dir, f'{net_name}.csv'), interp_vectors)

    return torch.tensor(interp_vectors, device=device)

In [None]:
def generate_interpol_dcgan(name, nz, ngf, nc, device):
    """Render generator outputs along a latent-space interpolation.

    ``get_save_interp_vectors('dcgan', nz, device)`` supplies the latent
    points; it also writes them to ``../interpol/dcgan.csv``. They are fed
    through the generator loaded from ``../nets/dcgan_netG_{name}`` and laid
    out by ``make_grid`` with 5 images per row. The grid is saved to
    ``../interpol/dcgan.png`` at 600 dpi and then shown.

    Args:
        name: run suffix of the generator checkpoint.
        nz: latent size; points are reshaped to ``(-1, nz, 1, 1)`` for the generator.
        ngf: passed through to ``Generator``.
        nc: passed through to ``Generator``.
        device: torch device for the generator and the latent points.
    """
    interp_vectors = get_save_interp_vectors('dcgan', nz, device)
    # The generator is fed (batch, nz, 1, 1) latents elsewhere in this notebook.
    # -1 follows however many points the helper returned, instead of a
    # hard-coded 10.
    latents = interp_vectors.reshape(-1, nz, 1, 1).float()

    # NOTE(review): `ngpu` is a notebook global defined in the configuration
    # cell, so that cell must run before this function is called.
    netG = Generator(ngpu, nz, ngf, nc).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    netG.load_state_dict(torch.load(f'../nets/dcgan_netG_{name}', map_location=device))
    with torch.no_grad():
        fake = netG(latents).cpu()

    grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=5)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.axis('off')
    ax.set_title('Fake Images')
    # make_grid returns channels-first (C, H, W); imshow wants (H, W, C).
    ax.imshow(grid.permute(1, 2, 0).numpy())

    os.makedirs('../interpol', exist_ok=True)
    fig.savefig('../interpol/dcgan.png', dpi=600)
    plt.show()

In [None]:
# Shared model settings; every helper above receives these values.
ngpu = 1   # first argument to Generator(...)
nz = 100   # latent vector length (noise is drawn as randn(..., nz, 1, 1))
ngf = 64   # passed to Generator as ngf (presumably generator feature-map width)
nc = 3     # passed to Generator as nc (presumably number of image channels)
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Run suffix substituted into every checkpoint / log path (e.g. '../nets/dcgan_netG_' when empty).
name = ""

In [None]:
display_results(name=name, nz=nz, ngf=ngf, nc=nc, device=device)

In [None]:
save_fake_imgs(name=name, nz=nz, ngf=ngf, nc=nc, device=device)

In [None]:
generate_interpol_dcgan(name=name, nz=nz, ngf=ngf, nc=nc, device=device)