In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch
from torch import nn

In [None]:
# Prefer CUDA when available, then the MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Echo the selected device string so the cell output records which backend was chosen.
device

In [None]:
import os

# Folders holding the two image classes: with glasses (G) and without glasses (NoG).
G = 'files/glasses/G/'
NoG = 'files/glasses/NoG/'

# Make sure both folders exist so that listing them later cannot fail on a missing path.
for class_dir in (G, NoG):
    os.makedirs(class_dir, exist_ok=True)

In [None]:
import random

In [None]:
from PIL import Image

In [None]:
# Preview a reproducible random sample of the images in the "glasses" folder.
# os.listdir returns entries in arbitrary order, so sort first; otherwise the
# fixed seed below would not select the same files on every machine or run.
imgs = sorted(os.listdir(G))
random.seed(42)
# Never ask for more files than exist: random.sample raises ValueError otherwise.
samples = random.sample(imgs, min(16, len(imgs)))
fig = plt.figure(dpi=200, figsize=(8, 2))
for i, file_name in enumerate(samples):
    ax = plt.subplot(2, 8, i+1)
    # imshow converts the PIL image to an array when it is called, so the
    # underlying file handle can be released as soon as the call returns.
    with Image.open(f'{G}{file_name}') as img:
        plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()

In [None]:
# Preview a reproducible random sample of the images in the "no glasses" folder.
# os.listdir returns entries in arbitrary order, so sort first; otherwise the
# fixed seed below would not select the same files on every machine or run.
imgs = sorted(os.listdir(NoG))
random.seed(42)
# Never ask for more files than exist: random.sample raises ValueError otherwise.
samples = random.sample(imgs, min(16, len(imgs)))
fig = plt.figure(dpi=200, figsize=(8, 2))
for i, file_name in enumerate(samples):
    ax = plt.subplot(2, 8, i+1)
    # imshow converts the PIL image to an array when it is called, so the
    # underlying file handle can be released as soon as the call returns.
    with Image.open(f'{NoG}{file_name}') as img:
        plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()

In [None]:
class Critic(nn.Module):
    """Convolutional critic that reduces an input tensor to one raw score per sample.

    The network is a stride-2 convolution with LeakyReLU, five stride-2
    conv / instance-norm / LeakyReLU stages that double the channel width each
    time, and a final 4x4 stride-2 convolution with a single output channel.
    No activation follows the last layer, so scores are unbounded.

    Args:
        img_channels: number of channels of the input tensor.
        features: channel width of the first convolution; the later stages use
            2x, 4x, 8x, 16x and 32x this value.
    """

    def __init__(self, img_channels, features):
        super().__init__()
        widths = [features * factor for factor in (1, 2, 4, 8, 16, 32)]
        stem = [
            nn.Conv2d(img_channels, widths[0], kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(.2),
        ]
        # One down-sampling stage per consecutive pair of widths.
        stages = [
            self.block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        head = [nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=0)]
        self.net = nn.Sequential(*stem, *stages, *head)

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one bias-free conv -> affine instance-norm -> LeakyReLU(0.2) stage."""
        conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        norm = nn.InstanceNorm2d(out_channels, affine=True)
        return nn.Sequential(conv, norm, nn.LeakyReLU(.2))

    def forward(self, x):
        """Apply the full layer stack to ``x`` and return the resulting scores."""
        scores = self.net(x)
        return scores


In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator that up-samples an input code into an image.

    The stack is one 4x4 stride-1 transposed-conv stage, five stride-2
    transposed-conv stages (each with batch-norm and ReLU) that halve the
    channel width every time, and a final stride-2 transposed convolution
    followed by Tanh, so outputs lie in [-1, 1].

    Args:
        noise_channels: number of channels of the input code.
        img_channels: number of channels of the generated image.
        features: base width; the stages use 64x, 32x, 16x, 8x, 4x and 2x this value.
    """

    def __init__(self, noise_channels, img_channels, features):
        super().__init__()
        widths = [features * factor for factor in (64, 32, 16, 8, 4, 2)]
        first_stage = self.block(noise_channels, widths[0], 4, 1, 0)
        # One up-sampling stage per consecutive pair of widths.
        upsamplers = [
            self.block(c_in, c_out, 4, 2, 1)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        to_image = nn.ConvTranspose2d(
            widths[-1], img_channels, kernel_size=4, stride=2, padding=1
        )
        self.net = nn.Sequential(first_stage, *upsamplers, to_image, nn.Tanh())

    def block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return one bias-free transposed-conv -> batch-norm -> ReLU stage."""
        deconv = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        return nn.Sequential(deconv, nn.BatchNorm2d(out_channels), nn.ReLU())

    def forward(self, x):
        """Apply the full layer stack to ``x`` and return the generated images."""
        images = self.net(x)
        return images


In [None]:
def weight_init(net):
    """Re-initialise every Conv2d / ConvTranspose2d weight in ``net`` with Xavier-normal values.

    All sub-modules of ``net`` are visited. Biases and every other layer type
    are left untouched. One 'Conv init...' line is printed per initialised layer.
    """
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    conv_layers = (layer for layer in net.modules() if isinstance(layer, conv_types))
    for layer in conv_layers:
        print('Conv init...')
        nn.init.xavier_normal_(layer.weight.data)

In [None]:
from pathlib import Path

# Model hyper-parameters. The "+ 2" terms add two extra channels on top of the
# noise (generator input) and on top of the image (critic input).
z_dim = 100
img_channels = 3
features = 16
gen = Generator(z_dim + 2, img_channels, features)
critic = Critic(img_channels + 2, features)

# Resume from the saved checkpoint when one exists; otherwise start from
# Xavier-initialised weights. previous_epoch = -1 marks a run without checkpoint.
checkpoint_file = Path('files/models/glasses_checkpoint.pth')
if not checkpoint_file.is_file():
    weight_init(gen)
    weight_init(critic)
    previous_epoch = -1
else:
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    loaded_data = torch.load(checkpoint_file, map_location=device)
    previous_epoch = loaded_data['epoch']
    gen_state_dict = loaded_data['gen_state_dict']
    critic_state_dict = loaded_data['critic_state_dict']
    gen.load_state_dict(gen_state_dict)
    critic.load_state_dict(critic_state_dict)

gen.to(device)
critic.to(device)

# Both networks use Adam with the same learning rate and betas.
# NOTE(review): only model weights are restored above, so the optimizers start
# with fresh internal state even when training resumes from a checkpoint.
lr = .0001
adam_betas = (0, .9)
opt_gen = torch.optim.Adam(gen.parameters(), lr=lr, betas=adam_betas)
opt_critic = torch.optim.Adam(critic.parameters(), lr=lr, betas=adam_betas)


In [None]:

def GP(critic, real, fake):
    """Compute the gradient penalty of ``critic`` on random mixes of ``real`` and ``fake``.

    Each sample is interpolated as ``alpha * real + (1 - alpha) * fake`` with
    one uniform ``alpha`` per sample. The penalty is the batch mean of
    ``(||d critic / d input||_2 - 1) ** 2``, where the norm is taken over all
    non-batch dimensions.

    Args:
        critic: callable mapping a batch tensor to per-sample scores.
        real: batch of real inputs; its first dimension is the batch size.
        fake: batch of generated inputs with the same shape and device as ``real``.

    Returns:
        A scalar tensor. It stays attached to the autograd graph
        (``create_graph=True``) so the penalty can be back-propagated into
        the critic's parameters.
    """
    # One mixing coefficient per sample, created directly on the inputs' device
    # and broadcast over every remaining dimension.
    alpha_shape = (real.shape[0],) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device)
    interpolated_images = real * alpha + (1 - alpha) * fake
    # autograd.grad needs an input that tracks gradients. When both real and
    # fake are detached, the mix is a plain leaf tensor and must be switched on.
    if not interpolated_images.requires_grad:
        interpolated_images.requires_grad_(True)
    critic_scores = critic(interpolated_images)
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=critic_scores,
        grad_outputs=torch.ones_like(critic_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view) because the returned gradient may be non-contiguous.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gp = torch.mean((gradient_norm - 1) ** 2)
    return gp


In [None]:
import torchvision.transforms as T
import torchvision
batch_size = 32
imgsz = 256
# Preprocessing applied to every image: resize to imgsz x imgsz, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5.
preprocessing_steps = [
    T.Resize((imgsz, imgsz)),
    T.ToTensor(),
    T.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
]
transform = T.Compose(preprocessing_steps)

In [None]:
# Build a labelled dataset from the sub-folders of files/glasses, applying `transform` to each image.
# NOTE(review): presumably class indices follow the sorted sub-folder names (G -> 0, NoG -> 1);
# confirm with data_set.class_to_idx, because the label channels built later rely on label 0.
data_set = torchvision.datasets.ImageFolder(root="files/glasses", transform=transform)

In [None]:
from tqdm import tqdm
# Append two constant label planes to every image: plane 0 is all ones for
# label 0, plane 1 is all ones for any other label. The result has the image
# channels followed by the two label channels.
new_data = []
for img, label in tqdm(data_set):
    hot_channel = 0 if label == 0 else 1
    channels = torch.zeros(2, imgsz, imgsz)
    channels[hot_channel].fill_(1)
    img_and_label = torch.cat((img, channels), dim=0)
    new_data.append(img_and_label)


In [None]:
# Serve the pre-built tensors in shuffled mini-batches using four worker processes.
data_loader = torch.utils.data.DataLoader(
    dataset=new_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
import os
os.makedirs('files/generated_glasses', exist_ok=True)


def _plot_label_samples(label_index, out_path):
    """Generate 32 images for one label channel, save the 4x8 grid to ``out_path`` and show it."""
    noise = torch.randn(32, z_dim, 1, 1)
    labels = torch.zeros(32, 2, 1, 1)
    # Switch on the requested label channel so G knows what to generate.
    labels[:, label_index, :, :] = 1
    noise_and_labels = torch.cat([noise, labels], dim=1).to(device)
    # Sampling only: no_grad avoids building an autograd graph for the batch.
    # NOTE(review): gen is used in whatever mode it is currently in (no eval() call).
    with torch.no_grad():
        fake = gen(noise_and_labels).cpu()
    fig = plt.figure(figsize=(20, 10), dpi=200)
    for i in range(32):
        plt.subplot(4, 8, i + 1)
        # Map the Tanh output from [-1, 1] to [0, 1] and put channels last for imshow.
        plt.imshow((fake[i] / 2 + 0.5).permute(1, 2, 0))
        plt.xticks([])
        plt.yticks([])
    plt.subplots_adjust(hspace=-0.6)
    plt.savefig(out_path)
    plt.show()
    # Release the figure explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def plot_epoch(epoch):
    """Save and display generator samples for both labels at the given epoch.

    Writes ``files/generated_glasses/G{epoch}.png`` (label [1,0]) and
    ``files/generated_glasses/NoG{epoch}.png`` (label [0,1]), each holding a
    4x8 grid of freshly sampled images.

    Args:
        epoch: value inserted into the two output file names.
    """
    # test images with glasses: label [1,0]
    _plot_label_samples(0, f"files/generated_glasses/G{epoch}.png")
    # test images without glasses: label [0,1]
    _plot_label_samples(1, f"files/generated_glasses/NoG{epoch}.png")

In [None]:
def train_epoch(img_and_labels, critic_iters=5, lambda_gp=10):
    """Run the critic and generator updates for ONE batch (despite the name).

    The critic is updated ``critic_iters`` times on freshly generated fakes
    with the loss ``mean(critic(fake)) - mean(critic(real)) + lambda_gp * GP``.
    The generator is then updated once with ``-mean(critic(fake))`` using the
    fakes from the last critic iteration.

    Args:
        img_and_labels: batch whose first three channels are the image and
            whose remaining channels are the constant label planes.
        critic_iters: number of critic updates per generator update (>= 1).
        lambda_gp: weight of the gradient penalty term.

    Returns:
        ``(loss_critic, loss_gen)`` tensors from the last critic update and
        from the generator update.

    Raises:
        ValueError: if ``critic_iters`` is smaller than 1.
    """
    if critic_iters < 1:
        raise ValueError("critic_iters must be at least 1")
    # Loop-invariant pieces: full-size label planes for the critic input and
    # their 1x1 corner, which is concatenated to the noise for the generator.
    label_planes = img_and_labels[:, 3:, :, :]
    onehots = label_planes[:, :, :1, :1]
    fakelabels = label_planes.to(device)
    real = img_and_labels.to(device)
    B = real.shape[0]
    for _ in range(critic_iters):
        noise = torch.randn(B, z_dim, 1, 1)
        noise_and_labels = torch.cat([noise, onehots], dim=1).to(device)
        fake = torch.cat([gen(noise_and_labels), fakelabels], dim=1)
        # The critic update must not back-propagate into the generator: feed it
        # detached fakes. This keeps the generator graph of `fake` intact for
        # the generator step below, so retain_graph=True is no longer needed.
        critic_real = critic(real).reshape(-1)
        critic_fake = critic(fake.detach()).reshape(-1)
        # GP differentiates the critic w.r.t. its input, so it gets a detached
        # copy that tracks gradients on its own.
        gp = GP(critic, real, fake.detach().requires_grad_(True))
        loss_critic = torch.mean(critic_fake) - torch.mean(critic_real) + lambda_gp * gp
        opt_critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()
    # Generator step: score the last fakes with the updated critic.
    gen_fake = critic(fake)
    loss_gen = -torch.mean(gen_fake)
    opt_gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()
    return loss_critic, loss_gen


def save_checkpoint(epoch, gloss, closs):
    """Write the generator and critic weights plus the epoch index to disk.

    The file ``files/models/glasses_checkpoint.pth`` is overwritten on every
    call and holds the keys 'epoch', 'gen_state_dict' and 'critic_state_dict'.

    Args:
        epoch: epoch index stored under 'epoch'.
        gloss: generator loss of the epoch; accepted but not written to the file.
        closs: critic loss of the epoch; accepted but not written to the file.
    """
    checkpoint_path = 'files/models/glasses_checkpoint.pth'
    # torch.save does not create missing folders, so make sure the target exists.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save({
        'epoch': epoch,
        'gen_state_dict': gen.state_dict(),
        'critic_state_dict': critic.state_dict(),
    }, checkpoint_path)


In [None]:
# When resuming, first show what the restored generator produces.
# previous_epoch is -1 only when no checkpoint was loaded, so a checkpoint
# written after epoch 0 must also count as a resumed run (hence >= 0).
if previous_epoch >= 0:
    plot_epoch(previous_epoch)
n_batches = len(data_loader)
for epoch in range(previous_epoch + 1, 100):
    # Running means of the per-batch losses over the epoch.
    closs, gloss = 0, 0
    for img_and_labels in tqdm(data_loader):
        loss_critic, loss_gen = train_epoch(img_and_labels)
        closs += loss_critic.item() / n_batches
        gloss += loss_gen.item() / n_batches
    print(f"at epoch {epoch}, critic loss: {closs}, generator loss {gloss}")
    plot_epoch(epoch)
    save_checkpoint(epoch, gloss, closs)

In [22]:
# Export the generator as a TorchScript file.
# The target folder may not exist yet, and torch.jit.save does not create it.
os.makedirs('files/models', exist_ok=True)
torch.jit.save(torch.jit.script(gen), 'files/models/glasses_gen.pt')