In [None]:
from torch_snippets import *
import torchvision 
from torchvision import transforms 
import torchvision.utils as vutils
import cv2, numpy as np, pandas as pd
import glob
import os

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's RNGs (CPU and CUDA) so weight init, latent noise and DataLoader shuffling
# repeat across Restart & Run All. GPU kernels may still be nondeterministic.
SEED = 42
torch.manual_seed(SEED)
os.getcwd()

In [None]:
# Haar-cascade frontal-face detector shipped with OpenCV, used to crop faces from the raw photos.
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
# CascadeClassifier does not raise on a missing/corrupt XML file; it yields an empty
# classifier, and the failure only surfaces later inside detectMultiScale. Fail fast here.
if face_cascade.empty():
    raise RuntimeError('Could not load haarcascade_frontalface_default.xml from ' + cv2.data.haarcascades)

In [16]:
# Detect faces in every source photo and save each detected face as its own JPEG.
# cv2.imwrite does not create missing directories (it only reports failure through its
# return value), so create the output folder up front instead of relying on a manual mkdir.
os.makedirs('cropped_faces', exist_ok=True)
images = glob.glob(os.getcwd() + '/content/females/*.jpg') + glob.glob(os.getcwd() + '/content/males/*.jpg')
for i, image_path in enumerate(images):
    img = read(image_path, 1)
    # The crop is converted RGB->BGR before writing, i.e. `img` is treated as RGB, so the
    # grayscale conversion must use RGB2GRAY too (BGR2GRAY would swap the R/B weights).
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    faces = face_cascade.detectMultiScale(gray, 1.3, 5)  # scaleFactor=1.3, minNeighbors=5
    for j, (x, y, w, h) in enumerate(faces):
        face = img[y:y + h, x:x + w, :]
        # One file per (image, face) pair: naming by image index alone let later faces
        # in the same photo overwrite earlier ones.
        cv2.imwrite(f'cropped_faces/{i}_{j}.jpg', cv2.cvtColor(face, cv2.COLOR_RGB2BGR))


In [None]:
# Preprocessing applied to every face crop before it reaches the networks.
transform = transforms.Compose(
    [
        transforms.Resize(64),      # shorter side scaled to 64 px
        transforms.CenterCrop(64),  # square 64x64 crop
        transforms.ToTensor(),      # PIL image -> float CHW tensor in [0, 1]
        # Map [0, 1] to [-1, 1], the same range as the generator's final Tanh.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
class Faces(Dataset):
    """Unlabelled dataset of face crops: one preprocessed image tensor per file.

    Args:
        folder: path/pattern handed to ``Glob``; the sorted result is the item list.
        transform: optional callable applied to each PIL image. When omitted (``None``),
            the notebook-level ``transform`` is used, which matches the original behaviour.
    """
    def __init__(self, folder, transform=None):
        super().__init__()
        self.folder = folder
        self.images = sorted(Glob(folder))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, ix):
        image_path = self.images[ix]
        # Use the context manager so the file handle is closed. convert('RGB') forces 3
        # channels, because the 3-channel Normalize would fail on grayscale/RGBA files.
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')
        tfm = self.transform if self.transform is not None else transform
        return tfm(image)

In [None]:
# Every saved face crop, served in shuffled mini-batches of 64 for training.
ds = Faces(folder='cropped_faces/')
dataloader = DataLoader(dataset=ds, shuffle=True, batch_size=64)

In [None]:
def weights_init(m):
    """Initializer intended for ``nn.Module.apply``; mutates ``m`` in place.

    - class name containing 'Conv' (so ConvTranspose too): weight ~ N(0, 0.02)
    - class name containing 'BatchNorm': weight ~ N(1, 0.02), bias = 0
    - any other module: left untouched
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        # Only the weight is reset; the conv layers in this notebook are built with bias=False.
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Discriminator(nn.Module):
    """Convolutional classifier scoring each image with a probability of being real.

    For 3x64x64 inputs, four stride-2 convolutions reduce the spatial size
    64 -> 32 -> 16 -> 8 -> 4. A final 4x4 valid convolution plus Sigmoid then gives
    an output of shape (B, 1, 1, 1) in (0, 1).
    """
    def __init__(self):
        super().__init__()
        # First layer has no BatchNorm; the three middle layers double the channel count.
        layers = [nn.Conv2d(3, 64, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True)]
        for c_in, c_out in ((64, 64 * 2), (64 * 2, 64 * 4), (64 * 4, 64 * 8)):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers += [nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

        self.apply(weights_init)

    def forward(self, input):
        return self.model(input)

In [None]:
class Generator(nn.Module):
    """Transposed-conv network mapping (B, 100, 1, 1) latent noise to (B, 3, 64, 64) images.

    The first transposed conv (stride 1, no padding) lifts 1x1 to 4x4. Four stride-2
    layers then double the size: 4 -> 8 -> 16 -> 32 -> 64. The final Tanh bounds pixel
    values to [-1, 1].
    """
    def __init__(self):
        super().__init__()
        widths = (100, 64 * 8, 64 * 4, 64 * 2, 64)
        blocks = []
        for depth, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Only the first layer projects the 1x1 latent; later layers upsample by 2.
            stride, padding = (1, 0) if depth == 0 else (2, 1)
            blocks += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        blocks += [nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.model = nn.Sequential(*blocks)
        self.apply(weights_init)

    def forward(self, input):
        return self.model(input)

In [None]:
# Build both networks (each applies weights_init in its constructor) and move them to `device`.
discriminator, generator = Discriminator().to(device), Generator().to(device)

In [None]:
def discriminator_train_step(real_data, fake_data, model=None, optimizer=None, criterion=None):
    """Run one discriminator update: push D(real) toward 1 and D(fake) toward 0.

    Args:
        real_data: batch of real images.
        fake_data: batch of generated images. Pass it detached (or produced under
            ``torch.no_grad()``) so this step does not backpropagate into the generator.
        model, optimizer, criterion: default to the notebook globals ``discriminator``,
            ``d_optimizer`` and ``loss``. Pass them explicitly to reuse the step with other
            networks or to exercise it in isolation.

    Returns:
        The sum of the real and fake losses, as returned by ``criterion``.
    """
    if model is None:
        model = discriminator
    if optimizer is None:
        optimizer = d_optimizer
    if criterion is None:
        criterion = loss

    optimizer.zero_grad()
    # reshape(-1) instead of squeeze(): squeeze() turns a batch of one into a 0-d tensor,
    # which no longer matches the 1-d target, and BCELoss then raises on the shape mismatch.
    # ones_like/zeros_like give targets with the prediction's exact shape, dtype and device.
    prediction_real = model(real_data).reshape(-1)
    error_real = criterion(prediction_real, torch.ones_like(prediction_real))
    error_real.backward()
    prediction_fake = model(fake_data).reshape(-1)
    error_fake = criterion(prediction_fake, torch.zeros_like(prediction_fake))
    error_fake.backward()
    optimizer.step()
    return error_real + error_fake

In [11]:
def generator_train_step(fake_data, discriminator_net=None, optimizer=None, criterion=None):
    """Run one generator update: push D(G(z)) toward 1 (i.e. fool the discriminator).

    Args:
        fake_data: generated images that still carry the generator's autograd graph.
        discriminator_net, optimizer, criterion: default to the notebook globals
            ``discriminator``, ``g_optimizer`` and ``loss``. Pass them explicitly to reuse
            the step with other networks or to exercise it in isolation.

    Returns:
        The generator loss as returned by ``criterion``.
    """
    if discriminator_net is None:
        discriminator_net = discriminator
    if optimizer is None:
        optimizer = g_optimizer
    if criterion is None:
        criterion = loss

    optimizer.zero_grad()
    # The target must be sized from fake_data itself. The previous version read the global
    # `real_data` leaked from the training loop, which fails with NameError on a fresh kernel.
    # reshape(-1) keeps a batch of one 1-d, whereas squeeze() made it 0-d and broke BCELoss.
    prediction = discriminator_net(fake_data).reshape(-1)
    error = criterion(prediction, torch.ones_like(prediction))
    # backward() also deposits gradients on the discriminator's parameters; only the generator
    # optimizer steps here, and the discriminator step zeroes its own grads before use.
    error.backward()
    optimizer.step()
    return error
     

In [12]:
# Binary cross-entropy on the discriminator's sigmoid outputs.
loss = nn.BCELoss()
# Identical Adam settings (lr=2e-4, betas=(0.5, 0.999)) for the two networks.
d_optimizer, g_optimizer = (
    optim.Adam(net.parameters(), lr=0.0002, betas=(0.5, 0.999))
    for net in (discriminator, generator)
)

In [13]:

N_EPOCHS = 25  # single source for both the loop bound and the progress reporter

# The sampling cell below calls generator.eval(). Without switching back, re-running this cell
# afterwards would train with BatchNorm stuck on its running statistics.
generator.train()
discriminator.train()

log = Report(N_EPOCHS)
N = len(dataloader)  # batches per epoch; loop-invariant, so computed once
for epoch in range(N_EPOCHS):
    for i, images in enumerate(dataloader):
        real_data = images.to(device)
        # Discriminator step: these fakes are only inputs, so skip building the generator graph
        # (same effect as the previous .detach(), without the wasted autograd bookkeeping).
        with torch.no_grad():
            fake_data = generator(torch.randn(len(real_data), 100, 1, 1, device=device))
        d_loss = discriminator_train_step(real_data, fake_data)
        # Generator step: fresh fakes WITH a graph so the loss can flow back into the generator.
        fake_data = generator(torch.randn(len(real_data), 100, 1, 1, device=device))
        g_loss = generator_train_step(fake_data)
        log.record(epoch + (1 + i) / N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch + 1)
log.plot_epochs(['d_loss', 'g_loss'])

EPOCH: 1.000  d_loss: 1.496  g_loss: 3.020  (2.15s - 51.50s remaining)ng)
EPOCH: 1.000  d_loss: 1.478  g_loss: 3.973  (2.36s - 56.52s remaining)ng)
EPOCH: 1.000  d_loss: 1.119  g_loss: 4.799  (2.56s - 61.35s remaining)g)
EPOCH: 1.000  d_loss: 0.882  g_loss: 5.195  (2.77s - 66.36s remaining)g)
EPOCH: 1.000  d_loss: 0.733  g_loss: 5.429  (2.96s - 71.14s remaining)g)
EPOCH: 1.000  d_loss: 0.655  g_loss: 5.618  (3.18s - 76.42s remaining)g)
EPOCH: 1.000  d_loss: 0.602  g_loss: 5.775  (3.38s - 81.23s remaining)g)
EPOCH: 1.000  d_loss: 0.550  g_loss: 5.888  (3.59s - 86.27s remaining)g)
EPOCH: 1.000  d_loss: 0.513  g_loss: 6.073  (3.80s - 91.19s remaining)g)
EPOCH: 1.000  d_loss: 0.474  g_loss: 6.195  (3.99s - 95.83s remaining)g)
EPOCH: 1.000  d_loss: 0.443  g_loss: 6.298  (4.19s - 100.60s remaining))
EPOCH: 1.000  d_loss: 0.421  g_loss: 6.455  (4.41s - 105.75s remaining))
EPOCH: 1.000  d_loss: 0.394  g_loss: 6.618  (4.62s - 110.95s remaining))
EPOCH: 1.000  d_loss: 0.371  g_loss: 6.732  (4.83

KeyboardInterrupt: 

In [None]:
# Sample an 8x8 grid of faces from random latent vectors.
generator.eval()  # BatchNorm switches to its running statistics for inference
with torch.no_grad():  # inference only: no autograd graph needed
    noise = torch.randn(64, 100, 1, 1, device=device)
    sample_images = generator(noise).cpu()
# normalize=True rescales the Tanh output (min/max) into [0, 1] for display.
grid = vutils.make_grid(sample_images, nrow=8, normalize=True)
show(grid.permute(1, 2, 0), sz=10, titl='Generated Images')