In [None]:
from google.colab import files
files.upload()

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!pip install kaggle
!chmod 600 /root/.kaggle/kaggle.json
!kaggle datasets download -d ashwingupta3012/human-faces
!unzip human-faces


In [None]:
import torch
import torch.nn as nn

In [None]:
import matplotlib.pyplot as plt 
import cv2


# Ten hand-picked photos from the unpacked dataset, drawn as a 2 x 5 grid.
pictures2show = [
    '/content/Humans/1 (1000).jpg',
    '/content/Humans/1 (1029).jpg',
    '/content/Humans/1 (1053).jpg',
    '/content/Humans/1 (1076).jpg',
    '/content/Humans/1 (1102).jpg',
    '/content/Humans/1 (1121).jpg',
    '/content/Humans/1 (1140).jpg',
    '/content/Humans/1 (1178).jpg',
    '/content/Humans/1 (1202).jpg',
    '/content/Humans/1 (1224).jpg',
]

pic_box = plt.figure(figsize=(16, 4))

# Subplot slots are 1-based, so count from 1 instead of adding 1 to the index.
for slot, image_path in enumerate(pictures2show, start=1):
    bgr_pixels = cv2.imread(image_path)
    # OpenCV decodes to BGR channel order; matplotlib draws RGB.
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    pic_box.add_subplot(2, 5, slot)
    plt.imshow(rgb_pixels)
    plt.axis('off')
plt.show()

In [None]:
from torchvision import transforms
from PIL import Image

def get_tensor_image_from_path(path):
    """Load an image file and return it as a 256 x 256 tensor.

    Args:
        path: Path of the image file, as accepted by ``PIL.Image.open``.

    Returns:
        The tensor produced by torchvision's ``ToTensor`` for the resized
        image, laid out channels-first. The image mode is not converted, so
        the channel count depends on the file (callers in this notebook keep
        only results whose first dimension is 3).

    The returned tensor does not require gradients: it is input data, and
    setting ``requires_grad`` on the transform object would not affect it.
    """
    # The context manager releases the file handle as soon as the resized
    # copy (which owns its own pixel data) has been created.
    with Image.open(path) as img:
        resized = img.resize((256, 256))
    return transforms.ToTensor()(resized)


In [None]:
from torch.utils.data import Dataset, DataLoader
import os

class FacesDataset(Dataset):
    """Dataset of 3-channel ``jpg`` images found under a directory tree.

    Args:
        image_dir: Root directory to scan. Sub-directories are scanned too,
            and a trailing slash is optional.

    Attributes are filled once in ``__init__``:
        images_pathes: paths of every file whose name ends in ``jpg`` and
            that decodes to a tensor with 3 channels. Each candidate is
            decoded once here to check its channel count, so construction
            time grows with the number of files.
    """

    def __init__(self, image_dir = '/content/Humans/'):
        self.images_pathes = []
        for root, dir_names, file_names in os.walk(image_dir):
            # Sorting makes indices reproducible across runs and machines;
            # os.walk itself gives no ordering guarantee.
            dir_names.sort()
            for name_file in sorted(file_names):
                if not name_file.endswith('jpg'):
                    continue
                # Join with the directory os.walk is currently visiting, not
                # with image_dir, so files in sub-directories resolve correctly.
                path = os.path.join(root, name_file)
                if get_tensor_image_from_path(path).shape[0] == 3:
                    self.images_pathes.append(path)

    def __getitem__(self, index):
        """Decode and return the image tensor stored at position ``index``."""
        return get_tensor_image_from_path(self.images_pathes[index])

    def __len__(self):
        """Return the number of images that passed the filter."""
        return len(self.images_pathes)

In [None]:
# Build the dataset from the default image directory and preview one sample.
dataset = FacesDataset()
print(dataset[1].shape)
# Samples are channels-first; imshow wants channels last, hence the permute.
plt.imshow(dataset[1].permute(1,2,0))

In [None]:
# Number of images that passed the 3-channel jpg filter in FacesDataset.
len(dataset)

In [None]:
# Shuffled mini-batches of 10 images, loaded in the main process (num_workers=0).
image_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)

In [None]:
# Pull a single batch and stop; `batch` stays bound to it after the loop.
for batch in image_loader:
    break

In [None]:
import torch.nn as nn

class Discriminator(nn.Module):
    """Convolutional real/fake classifier for 3-channel images.

    The network is a stack of Conv2d -> BatchNorm2d -> LeakyReLU(0.01)
    stages (some followed by Dropout(0.1)), then a small fully connected
    head ending in a sigmoid. Five of the stages use stride 2, so the
    spatial size is divided by 32; the head's ``8 * 8 * 8`` input size
    therefore corresponds to 256 x 256 inputs.

    ``forward`` returns one value in [0, 1] per image, shape ``(N, 1)``.
    """

    def __init__(self):
        super().__init__()

        # One row per convolutional stage:
        # (in_channels, out_channels, kernel, stride, padding, dropout_after)
        conv_stages = (
            (3, 32, 3, 1, 1, False),
            (32, 64, 4, 2, 1, False),
            (64, 64, 4, 2, 1, True),
            (64, 64, 3, 1, 1, True),
            (64, 64, 3, 1, 1, False),
            (64, 64, 3, 1, 1, False),
            (64, 64, 3, 1, 1, False),
            (64, 64, 4, 2, 1, True),
            (64, 64, 4, 2, 1, True),
            (64, 8, 4, 2, 1, True),
        )

        stack = []
        for fan_in, fan_out, kernel, stride, padding, use_dropout in conv_stages:
            stack.append(nn.Conv2d(fan_in, fan_out, kernel, stride, padding))
            stack.append(nn.BatchNorm2d(fan_out))
            stack.append(nn.LeakyReLU(0.01))
            if use_dropout:
                stack.append(nn.Dropout(0.1))

        # Classification head on the flattened 8-channel 8 x 8 feature map.
        stack.extend([
            nn.Flatten(),
            nn.Linear(8 * 8 * 8, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*stack)

    def forward(self, input):
        """Return the sigmoid score for each image in ``input``."""
        return self.model(input)
    

In [None]:
class AdaIn(nn.Module):
    """Adaptive instance normalisation of ``x`` towards the statistics of ``y``.

    Each channel of ``x`` is standardised with its own spatial mean and
    standard deviation, then rescaled and shifted with the per-channel
    spatial standard deviation and mean of ``y``. ``eps`` is added to both
    standard deviations.
    """

    def __init__(self, eps = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return ``x`` re-normalised with the channel statistics of ``y``.

        Statistics are taken over dimensions 2 and 3 and kept as size-1
        dimensions so they broadcast against ``x``.
        """
        spatial_dims = [2, 3]

        content_mean = torch.mean(x, dim=spatial_dims, keepdim=True)
        style_mean = torch.mean(y, dim=spatial_dims, keepdim=True)
        content_std = torch.std(x, dim=spatial_dims, keepdim=True) + self.eps
        style_std = torch.std(y, dim=spatial_dims, keepdim=True) + self.eps

        standardised = (x - content_mean) / content_std
        return style_std * standardised + style_mean

In [None]:
# Smoke test: for two single-image batches, AdaIn returns a tensor with the same shape as its inputs.
layer_sample = AdaIn()
assert layer_sample(dataset[0].unsqueeze(0), dataset[1].unsqueeze(0)).shape == torch.Size([1, 3, 256, 256])

In [None]:
class Vec2Image(nn.Module):
    """Project a flat vector onto a square, multi-channel image.

    A single linear layer maps ``input_space`` features to
    ``out_channels * image_size * image_size`` values, which are then
    reshaped to ``(N, out_channels, image_size, image_size)``.
    """

    def __init__(self, input_space = 512,
                       out_channels = 3,
                       image_size = 256):
        super().__init__()

        self.input_space = input_space
        self.out_channels = out_channels
        self.image_size = image_size

        flat_size = out_channels * image_size * image_size
        self.first = nn.Sequential(nn.Linear(input_space, flat_size))

    def forward(self, x):
        """Map ``x`` of shape ``(N, input_space)`` to an image batch.

        Raises:
            ValueError: if the second dimension of ``x`` is not ``input_space``.
        """
        if x.shape[1] != self.input_space:
            raise ValueError("Incorrect shape of input vector")
        image_shape = (-1, self.out_channels, self.image_size, self.image_size)
        return self.first(x).reshape(image_shape)

In [None]:
# Smoke test with the default sizes: a (1, 512) vector maps to a (1, 3, 256, 256) image.
# With the defaults the linear layer has 512 * 3 * 256 * 256 weights, so this allocates a large module.
layer = Vec2Image()

input_vector = torch.rand(1, 512)
output = layer(input_vector)
assert output.shape == torch.Size([1, 3, 256, 256])

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

In [None]:
class Block(nn.Module):
    """One generator stage: upsample x2, then two conv + noise + AdaIn steps.

    Args:
        in_channels: Channels of the incoming feature map.
        input_shape: Spatial size (height == width) of the incoming feature
            map; the style images are generated at ``2 * input_shape`` to
            match the upsampled resolution.
        out_channels: Channels of the returned feature map.
        style_space: Length of the style vector passed to ``forward``.
    """

    def __init__(self, 
                 in_channels, 
                 input_shape, 
                 out_channels,
                 style_space=512
                 ):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor = 2, mode='nearest')
        self.layer_1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 1, 1),
            nn.ReLU()
        )
        # A_1 / A_2 turn the style vector into "style images" whose channel
        # statistics drive the two AdaIn layers.
        self.A_1 = Vec2Image(style_space, in_channels, 2 * input_shape)
        self.layer_2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU()
        )
        self.A_2 = Vec2Image(style_space, out_channels, 2 * input_shape)
        self.adain_1 = AdaIn()
        self.adain_2 = AdaIn()

    def forward(self, x, space_vector):
        """Return the styled feature map at twice the resolution of ``x``.

        Args:
            x: Feature map of shape ``(N, in_channels, input_shape, input_shape)``.
            space_vector: Style vectors of shape ``(N, style_space)``.

        Returns:
            Tensor of shape ``(N, out_channels, 2 * input_shape, 2 * input_shape)``.
        """
        hidden = self.layer_1(self.upsample(x))
        # Uniform noise in [0, 0.1). rand_like creates it directly on the
        # device (and in the dtype) of the activations, so the block does not
        # depend on a module-level `device` variable.
        hidden = hidden + torch.rand_like(hidden) / 10
        hidden = self.adain_1(hidden, self.A_1(space_vector))

        hidden = self.layer_2(hidden)
        hidden = hidden + torch.rand_like(hidden) / 10
        return self.adain_2(hidden, self.A_2(space_vector))

In [None]:
# Smoke test for Block: a (1, 32, 32, 32) input is upsampled x2 and mapped to 64 channels.
# Note that this rebinds `layer`, which previously held the Vec2Image instance.
layer = Block(32, 32, 64).to(device)
input_image = torch.rand(1, 32, 32, 32).to(device)
vec = torch.rand(1, 512).to(device)
print(layer(input_image, vec).shape)

In [None]:
class StyleGAN(nn.Module):
    """Style-based generator: latent vector -> 3-channel 256 x 256 image.

    A fully connected mapping network turns the latent vector into a style
    vector. A learned constant 16 x 32 x 32 tensor is then pushed through
    three ``Block`` stages (each doubling the resolution: 32 -> 64 -> 128
    -> 256) that are all conditioned on the same style vector.

    Args:
        latent_space: Length of the input latent vectors.
        style_space: Length of the style vectors fed to the blocks.
    """

    def __init__(self, latent_space = 512, 
                       style_space = 512, 
                       ):
        super().__init__()
        self.latent_space = latent_space
        self.style_space = style_space
        # The third layer switches from latent_space to style_space, so the
        # layers after it must take style_space inputs; otherwise the network
        # only works when both sizes happen to be equal.
        self.fc_net = nn.Sequential(
            nn.Linear(self.latent_space, self.latent_space),
            nn.ReLU(),
            nn.Linear(self.latent_space, self.latent_space),
            nn.ReLU(),
            nn.Linear(self.latent_space, self.style_space),
            nn.ReLU(),
            nn.Linear(self.style_space, self.style_space),
            nn.ReLU(),
            nn.Linear(self.style_space, self.style_space)
        )

        # Learned constant input. It is created like every other parameter
        # and follows the module when the caller does `.to(device)`.
        self.start_image = nn.Parameter(torch.rand(16, 32, 32))
        self.block3 = Block(16, 32, 8, style_space)
        self.block4 = Block(8, 64, 4, style_space)
        self.block5 = Block(4, 128, 3, style_space)

    def forward(self, x):
        """Generate one image per latent vector.

        Args:
            x: Latent vectors of shape ``(N, latent_space)``.

        Returns:
            Whatever ``block5`` returns for the batch; with the ``Block``
            defined in this notebook that is ``(N, 3, 256, 256)``.
        """
        style_vec = self.fc_net(x)
        batch_size = x.shape[0]
        # Broadcast the constant over the batch without copying it; gradients
        # from every sample are summed into start_image.
        output = self.start_image.unsqueeze(0).expand(batch_size, -1, -1, -1)
        output1 = self.block3(output, style_vec)
        output2 = self.block4(output1, style_vec)
        output3 = self.block5(output2, style_vec)
        return output3



In [None]:
# Instantiate the generator on the selected device and run one random latent vector through it.
model = StyleGAN().to(device)
generated_image = model(torch.rand(1, 512).to(device))
print(generated_image.shape)


In [None]:
# Drop the batch axis, move channels last and convert to NumPy so matplotlib can draw the image.
plt.imshow(generated_image.squeeze(0).permute(1,2,0).cpu().detach().numpy())

In [None]:
# Discriminator on the selected device, with its own Adam optimizer (learning rate 1e-5).
discriminator = Discriminator().to(device)
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr = 1e-5)

In [None]:
# Adam optimizer over the generator's parameters (learning rate 1e-3).
optimizer_generator = torch.optim.Adam(model.parameters(), lr = 1e-3)


In [None]:
# Turn on autograd anomaly detection for every backward pass that follows.
# NOTE(review): this is a debugging aid that slows training down; consider
# disabling it once the loops below run cleanly.
torch.autograd.set_detect_anomaly(True)

In [None]:
from tqdm.auto import tqdm

# Generator pre-training: regress generated images onto real batches with an
# MSE loss before any adversarial training takes place.
num_pretrain = 100

criterion_pretrain = nn.MSELoss()

# NOTE(review): limit_unseen and cur are assigned here but not used in this cell.
limit_unseen = 20
cur = 0

for epoch in tqdm(range(num_pretrain)):
    for batch in tqdm(image_loader):
        optimizer_generator.zero_grad()
        # One uniform random latent vector per real image in the batch.
        input_vec = torch.rand((batch.shape[0], 512)).to(device)
        fake_images = model(input_vec)
        loss = criterion_pretrain(fake_images, batch.to(device))
        loss.backward()
        optimizer_generator.step()
        # Show the first generated image of the batch after every update.
        plt.imshow(fake_images[0].squeeze(0).permute(1,2,0).cpu().detach().numpy())
        plt.show()
        # NOTE(review): this break ends the inner loop after the first batch,
        # so each "epoch" is a single optimisation step — confirm this is intended.
        break

In [None]:
from tqdm.auto import tqdm

# Adversarial training: alternate one discriminator update and one generator
# update per batch, both with binary cross-entropy on the discriminator score.
num_epochs = 40

criterion = nn.BCELoss()

# Show a generated sample once every `limit_unseen` batches.
limit_unseen = 20
cur = 0


for epoch in tqdm(range(num_epochs)):
    for batch in tqdm(image_loader):
        real_images = batch.to(device)

        # ---- Discriminator step: real images -> 1, generated images -> 0 ----
        optimizer_discriminator.zero_grad()

        output_real = discriminator(real_images)
        loss_real = criterion(output_real, torch.ones_like(output_real))

        input_vec = torch.rand((batch.shape[0], 512)).to(device)
        fake_images = model(input_vec)
        # detach(): this step only updates the discriminator, so there is no
        # reason to back-propagate through the generator here.
        output_fake = discriminator(fake_images.detach())
        loss_fake = criterion(output_fake, torch.zeros_like(output_fake))

        loss = loss_real + loss_fake
        loss.backward()
        optimizer_discriminator.step()

        # ---- Generator step: make the discriminator score fakes as real ----
        optimizer_generator.zero_grad()
        fake_images = model(input_vec)
        output_fake = discriminator(fake_images)
        # The generator's target is 1 ("real"). Using zeros here would train
        # the generator to be detected, i.e. to help the discriminator.
        target_real = torch.ones_like(output_fake)
        loss = criterion(output_fake, target_real)
        loss.backward()
        optimizer_generator.step()

        cur += 1
        if cur % limit_unseen == limit_unseen - 1:
            # First generated image of the batch, channels moved last for imshow.
            plt.imshow(fake_images[0].detach().cpu().permute(1, 2, 0).numpy())
            plt.show()
