In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from utils import Generator, Discriminator, DatasetImages, init_weights_ ,Self_Attn
from IPython.display import clear_output
from datetime import datetime, timedelta
import time
import os
import pathlib
from PIL import Image
from DiffAugment_pytorch import DiffAugment
from torch.utils.tensorboard import SummaryWriter
import pprint
from torchvision.utils import make_grid
import fid_score

In [10]:
# Global training configuration read by the data-loading and training cells.
batch_size = 64
# DiffAugment policy string: comma-separated augmentation families, passed to
# DiffAugment for every batch the discriminator sees in train().
policy = ','.join(['color', 'translation', 'cutout'])


# Utility functions
def cuda(data):
    """Move ``data`` to the default CUDA device if one is available.

    On CPU-only machines the input is returned unchanged, so callers can wrap
    any tensor or module unconditionally.
    """
    if not torch.cuda.is_available():
        return data
    return data.cuda()


def denorm(x):
    """Rescale ``x`` from [-1, 1] to [0, 1] for image saving, clipping outliers.

    The clamp is applied to the newly computed tensor, so ``x`` itself is not
    modified.
    """
    return ((x + 1) / 2).clamp_(0, 1)

In [11]:
# All TensorBoard event files for this run go under outputs/out.
output_path = pathlib.Path("outputs", "out")
output_path.mkdir(parents=True, exist_ok=True)

# Prepare tensorboard writer
writer = SummaryWriter(output_path)

# Run metadata logged as text at step 0. TensorBoard renders text as markdown,
# which needs two spaces before a newline to produce a line break.
args_d = {"time": datetime.now()}
writer.add_text("hyperparameter", pprint.pformat(args_d).replace("\n", "  \n"), 0)

# One fixed latent batch, reused for every sample mosaic so that snapshots
# taken at different steps are directly comparable.
fixed_z = cuda(torch.randn(64, 100))

In [12]:
# Images come from the ./data directory (via utils.DatasetImages) and are
# converted to tensors normalized with mean 0.5 / std 0.5.
data_path = pathlib.Path("data")
tform = transforms.Compose([
    # NOTE(review): an int size resizes only the shorter side; non-square
    # images would not come out 32x32 and could not be batched — confirm the
    # dataset is square or use Resize((32, 32)).
    transforms.Resize(32),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5: maps [0, 1] to [-1, 1], the range denorm() undoes.
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = DatasetImages(data_path, transform=tform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)



In [13]:
# Grid layout shared by the real-data mosaic below and the generated mosaics in train().
mosaic_kwargs = dict(nrow=8, normalize=True)
n_mosaic_cells = 64
sample_showcase_ix = 0  # NOTE(review): not referenced in any visible cell

# Log the first n_mosaic_cells training images once, at step 0, as a visual reference.
writer.add_image(
    "true_data",
    make_grid(torch.stack([dataset[idx] for idx in range(n_mosaic_cells)]), **mosaic_kwargs),
    0,
)

In [14]:
def _ensure_output_dirs():
    """Create every directory ``train`` writes into; no-op for ones that already exist.

    ``./dir`` receives sample mosaics, ``./fid_temp_folder`` the individual
    samples scored by FID, and ``./models`` the G/D checkpoints.
    """
    for folder in ('./dir', './fid_temp_folder', './models'):
        os.makedirs(folder, exist_ok=True)


def _next_batch(batch_iter, loader):
    """Return ``(batch, iterator)``, restarting ``loader`` once ``batch_iter`` is exhausted.

    Only ``StopIteration`` triggers a restart. Any other error raised while
    loading data propagates instead of being silently retried.
    """
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def _sample_and_score(G, step):
    """Generate samples from ``fixed_z``, save them, compute FID, and log to TensorBoard.

    Sampling runs under ``torch.no_grad()`` because these images are never
    backpropagated through. G is left in whatever mode the caller set.
    """
    with torch.no_grad():
        samples = G(fixed_z)
    save_image(denorm(samples), os.path.join('./dir', '{}_fake.jpg'.format(step + 1)))
    # One file per sample, named 1..N; overwritten at every evaluation.
    for idx, img in enumerate(samples, start=1):
        save_image(denorm(img), os.path.join('./fid_temp_folder', '{}_fake.jpg'.format(idx)))
    # NOTE(review): FID is estimated from only len(fixed_z) generated images,
    # so the value is very noisy; treat it as a trend indicator.
    fid = fid_score.calculate_fid_given_paths(
        ('./data', './fid_temp_folder'),
        64,    # presumably the feature-extraction batch size — confirm in fid_score
        True,  # presumably the use-CUDA flag — confirm in fid_score
        2048,  # feature dimensionality (the default value)
    )
    print("FID at step %d: %.6f" % (step, fid))
    # Logged against the global training step, like the loss scalars.
    writer.add_scalar("FID", fid, step)
    writer.add_image("fake", make_grid(samples, **mosaic_kwargs), step)


def _save_checkpoint(G, D, step):
    """Write G and D state dicts to ./models, tagged with the 1-based step number."""
    torch.save(G.state_dict(), os.path.join('./models', '{}_G.pth'.format(step + 1)))
    torch.save(D.state_dict(), os.path.join('./models', '{}_D.pth'.format(step + 1)))


def train(steps, batch_size = 64, z_dim = 100, attn = True):
    """Train a hinge-loss GAN with DiffAugment, logging progress to TensorBoard.

    Each step does one discriminator update followed by one generator update.

    Args:
        steps: number of D/G update pairs to run.
        batch_size: number of latent vectors sampled per update. It is also the
            first constructor argument of Generator/Discriminator (its meaning
            there is defined in utils — TODO confirm). Real batches come from
            the module-level ``dataloader``, whose batch size is set separately.
        z_dim: latent vector size; must match ``fixed_z``'s second dimension.
        attn: forwarded to Generator/Discriminator, presumably to toggle self-attention.

    Reads the module-level ``dataloader``, ``policy``, ``writer``, ``fixed_z``
    and ``mosaic_kwargs``. Writes sample images to ./dir and ./fid_temp_folder,
    checkpoints to ./models, and scalars/images/graph to ``writer``.
    """
    # Initialize models
    G = cuda(Generator(batch_size, attn))
    D = cuda(Discriminator(batch_size, attn))
    _ensure_output_dirs()

    # Adam with betas (0.5, 0.99).
    # NOTE(review): G uses lr 4e-4 and D 1e-4; the SAGAN TTUR setup is the
    # reverse (D 4e-4, G 1e-4) — confirm this is intentional.
    g_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, G.parameters()), 0.0004, [0.5, 0.99])
    d_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), 0.0001, [0.5, 0.99])
    G.apply(init_weights_)
    D.apply(init_weights_)

    batch_iter = iter(dataloader)
    start_time = time.time()

    for step in range(steps):
        D.train(); G.train()

        # ================== Train D (hinge loss) ================== #
        real_images, batch_iter = _next_batch(batch_iter, dataloader)
        d_out_real = D(DiffAugment(cuda(real_images), policy=policy))
        d_loss_real = torch.relu(1.0 - d_out_real).mean()

        z = cuda(torch.randn(batch_size, z_dim))
        # detach(): the D update must not backpropagate into G.
        fake_images = G(z).detach()
        d_out_fake = D(DiffAugment(fake_images, policy=policy))
        d_loss_fake = torch.relu(1.0 + d_out_fake).mean()

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad(); g_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ================== Train G ================== #
        z = cuda(torch.randn(batch_size, z_dim))
        g_out_fake = D(DiffAugment(G(z), policy=policy))
        g_loss_fake = -g_out_fake.mean()
        d_optimizer.zero_grad(); g_optimizer.zero_grad()
        g_loss_fake.backward()
        g_optimizer.step()

        # Console progress every 20 steps (replaces the previous printout).
        if (step + 1) % 20 == 0:
            elapsed = time.time() - start_time
            expect = elapsed / (step + 1) * (steps - step - 1)
            clear_output(wait=True)
            # NOTE(review): assumes Generator exposes attn.gamma even when attn=False.
            print("Elapsed [{}], Expect [{}], step [{}/{}], D_real_loss: {:.4f}, "
                  " ave_generator_gamma: {:.4f}".format(
                      str(timedelta(seconds=elapsed)), str(timedelta(seconds=expect)),
                      step + 1, steps, d_loss_real.item(), G.attn.gamma.mean().item()))

        # TensorBoard loss scalars every 50 steps.
        if step % 50 == 0:
            writer.add_scalar("d_real", d_loss_real.item(), step)
            writer.add_scalar("d_fake", d_loss_fake.item(), step)
            writer.add_scalar("gen", g_loss_fake.item(), step)
            writer.add_scalar("D_loss", d_loss.item(), step)
            writer.add_scalar("G_loss", g_loss_fake.item(), step)

        # Samples, FID and checkpoints every 1000 steps.
        if (step + 1) % 1000 == 0:
            _sample_and_score(G, step)
            _save_checkpoint(G, D, step)

        # Log the generator graph once, at the final step, for any value of
        # ``steps``. G and z were both placed by cuda(), so they already share
        # a device and this also works on CPU-only machines.
        if step == steps - 1:
            writer.add_graph(G, input_to_model=z)

In [15]:
# Baseline run without attention. batch_size is the same global that sizes the
# DataLoader, so real and generated batches always have the same size.
train(steps=60000, batch_size=batch_size, z_dim=100, attn=False)
print('Done training part 1')


Elapsed [0:58:35.399427], Expect [0:00:00], step [60000/60000], D_real_loss: 0.0430,  ave_generator_gamma: 0.0000


100%|██████████| 5/5 [00:00<00:00,  5.47it/s]
100%|██████████| 1/1 [00:00<00:00,  5.46it/s]


FID at epoch 59999: 291.899042
Done training part 1


In [8]:
#python projector.py --label=0 --label-dim=10 --outdir=out --target=data/plane_10.jpg --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl

In [19]:
#gif creation
from PIL import Image, ImageDraw, ImageFont

# Caption font loaded from ./demo at size 18; used by create_image_with_text.
font = ImageFont.truetype("./demo/arial.ttf", 18)


def create_image_with_text(img, wh, text):
    """Draw ``text`` in white onto ``img`` at position ``wh`` and return ``img``.

    ``wh`` is the (x, y) offset of the text, despite its name. ``img`` is
    modified in place; the same object is returned for chaining. The
    module-level ``font`` is used.
    """
    x, y = wh
    ImageDraw.Draw(img).text((x, y), text, font=font, fill="white")
    return img

# NOTE(review): this cell reads 'samples_mnist/<step>_fake.png' every 100 steps
# up to 20000, i.e. a different run than train() above (which writes
# './dir/<step>_fake.jpg' every 1000 steps) — update the paths before reusing it.
frames = []

for i in range(100, 20001, 100):
    with Image.open('samples_mnist/{}_fake.png'.format(str(i))) as img:
        #img1 = Image.open('samples_mnist_attn/{}_fake.png'.format(str(i)))
        width, height = img.size
        # Canvas is two mosaics wide (plus a 10 px gap) and 40 px taller, leaving
        # the right half for a with-attention mosaic and the bottom for captions.
        expand = Image.new(img.mode, (width*2 + 10, height + 40), "black")
        # paste() copies the pixels, so the source file may be closed afterwards.
        expand.paste(img, (0, 0))
        #expand.paste(img1, (width + 10, 0))
    # Images seen so far (i steps * 64 per batch) / 60000 training images —
    # presumably the MNIST train-set size; confirm for other datasets.
    epoch = round(i*64/60000, 2)
    captioned = create_image_with_text(expand, (10, 258), "After "+str(epoch)+" epoches")
    captioned = create_image_with_text(captioned, (10, 238), "Without Attention")
    #captioned = create_image_with_text(captioned, (width + 20, 238), "With Attention")
    new_frame = captioned
    frames.append(new_frame)

# 60 ms per frame, looping forever.
frames[0].save('./demo/comparison_mnist.gif', format='GIF',
               append_images=frames[1:],
               save_all=True,
               duration=60, loop=0)
 