In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def show_tensor_images(title, image_tensor, num_images=16, size=(3, 64, 64), nrow=3):
    """Plot the leading images of a batch as one titled grid.

    Values are mapped with ``(x + 1) / 2`` and clamped to ``[0, 1]`` before
    being drawn, so inputs are presumably expected in ``[-1, 1]`` — TODO confirm
    against how the data is normalised.

    Args:
        title: Text shown above the grid.
        image_tensor: Batch of images handed to ``make_grid`` after rescaling.
        num_images: Upper bound on how many images (from the front of the
            batch) are drawn.
        size: Not used by this function; kept so existing calls stay valid.
        nrow: Passed straight through to ``make_grid``.
    """
    plt.figure(figsize=(20, 10))

    # Rescaling allocates a new tensor, so the in-place clamp below never
    # touches the caller's data.
    rescaled = (image_tensor + 1) / 2
    cpu_images = rescaled.detach().cpu()
    cpu_images.clamp_(0, 1)

    grid = make_grid(cpu_images[:num_images], nrow=nrow, padding=0)
    # make_grid output is channel-first; imshow wants the channel axis last.
    drawable = grid.permute(1, 2, 0).squeeze()

    plt.title(title)
    plt.imshow(drawable)
    plt.axis('off')
    plt.show()

class Generator(nn.Module):
    """Generator that turns noise vectors into images by repeated upsampling.

    A two-layer tanh MLP maps each noise vector to 64 values, which are viewed
    as a single-channel 8x8 map and pushed through five
    upsample -> convolution blocks. With the default block settings each block
    maps a side length ``n`` to ``2n - 3``, so the 8x8 seed ends up 163x163.

    Args:
        z_dim: Length of each input noise vector.
        im_chan: Number of channels of the generated image.
        hidden_dim: Base channel width of the convolutional blocks.
    """

    def __init__(self, z_dim=64, im_chan=3, hidden_dim=16):
        super().__init__()
        self.z_dim = z_dim

        # Dense stem; the 64 outputs of dense2 become the 1x8x8 seed in forward().
        self.dense1 = nn.Linear(z_dim, z_dim)
        self.dense2 = nn.Linear(z_dim, 64)
        self.activation = nn.Tanh()

        wide = hidden_dim * 2
        self.gen = nn.Sequential(
            self.make_gen_block(1, wide),
            self.make_gen_block(wide, wide),
            self.make_gen_block(wide, hidden_dim),
            self.make_gen_block(hidden_dim, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, scale_factor=2, kernel_size=4, stride=1, final_layer=False):
        """Build one bilinear-upsample + convolution stage.

        Hidden stages are Upsample -> Conv2d -> BatchNorm2d -> LeakyReLU(0.2);
        the final stage leaves out the batch norm. Note that the final stage
        also ends in LeakyReLU, so the generator output is not range-bounded.

        Args:
            input_channels: Channels entering the convolution.
            output_channels: Channels leaving the convolution.
            scale_factor: Spatial upsampling factor applied first.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            final_layer: True for the last stage (no batch norm).

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        stage = [
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
            nn.Conv2d(input_channels, output_channels, kernel_size, stride),
        ]
        if not final_layer:
            stage.append(nn.BatchNorm2d(output_channels))
        stage.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*stage)

    def forward(self, noise):
        """Generate images from ``noise`` (one row of length ``z_dim`` per sample)."""
        hidden = self.activation(self.dense1(noise))
        hidden = self.activation(self.dense2(hidden))
        # 64 features per sample -> one-channel 8x8 map for the conv stack.
        seed_map = hidden.view(-1, 1, 8, 8)
        return self.gen(seed_map)

def get_noise(n_samples, z_dim, device='cpu'):
    """Sample ``n_samples`` standard-normal noise vectors of length ``z_dim`` on ``device``."""
    noise_shape = (n_samples, z_dim)
    return torch.randn(noise_shape, device=device)

class Critic(nn.Module):
    """Convolutional critic that assigns one real-valued score to each image.

    Four stride-2 convolution blocks reduce the image to a single-channel map
    that is flattened and fed to a two-layer head (Linear -> Tanh -> Linear).
    The flattened map must contain exactly 64 values per image (for example a
    1x8x8 map, which is what a 163x163 input produces) because ``dense1`` is
    declared with 64 input features.

    The score is returned without a sigmoid: the Wasserstein critic loss and
    gradient penalty used with this model need unbounded scores, and squashing
    them into (0, 1) would cap the estimated distance and flatten gradients.
    No parameters are affected, so existing checkpoints still load.

    Args:
        im_chan: Number of channels of the input images.
        hidden_dim: Base channel width of the convolutional blocks.
        z_dim: Width of the hidden layer in the dense head.
    """

    def __init__(self, im_chan=1, hidden_dim=16, z_dim=64):
        super().__init__()
        # Dense head applied to the flattened output of the conv stack.
        self.dense1 = nn.Linear(64, z_dim)
        self.dense2 = nn.Linear(z_dim, 1)
        self.activation = nn.Tanh()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        """Build one downsampling stage.

        Hidden stages are Conv2d -> BatchNorm2d -> LeakyReLU(0.2); the final
        stage is a bare convolution.

        Args:
            input_channels: Channels entering the convolution.
            output_channels: Channels leaving the convolution.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            final_layer: True for the last stage (convolution only).

        Returns:
            An ``nn.Sequential`` holding the stage's layers.
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(negative_slope=0.2)
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        """Return an unbounded score of shape ``(batch, 1)`` for each image."""
        crit_pred = self.crit(image)
        # One row of conv features per image.
        x = crit_pred.view(len(crit_pred), -1)
        x = self.dense1(x)
        x = self.activation(x)
        return self.dense2(x)
    
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
im_chan = 3
z_dim = 64

gen = Generator(z_dim=z_dim, im_chan=im_chan, hidden_dim=48).to(device)
crit = Critic(im_chan=im_chan, hidden_dim=48).to(device)

# Smoke test: push one batch of noise through both networks. This confirms the
# critic accepts what the generator emits and reveals the generated image size.
noise = torch.randn(16, z_dim, device=device)
fake = gen(noise)
print(fake.shape)
crit(fake)
# Side length of the generated images (last dimension of the generator output).
im_size = fake.shape[3]

In [None]:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_gradient
def get_gradient(crit, real, fake, epsilon):
    """Return the gradient of the critic's scores with respect to mixed images.

    Args:
        crit: Callable that maps a batch of images to a batch of scores.
        real: Batch of real images.
        fake: Batch of generated images with the same shape as ``real``.
        epsilon: Mixing weights that broadcast against the images; each mixed
            image is ``real * epsilon + fake * (1 - epsilon)``.

    Returns:
        A tensor shaped like the mixed images holding d(score)/d(image). It is
        computed with ``create_graph=True`` so a penalty built from it can
        itself be back-propagated.
    """
    # Mix the images together
    mixed_images = real * epsilon + fake * (1 - epsilon)

    # autograd.grad can only differentiate with respect to a tensor that is
    # tracked. The mix is tracked automatically when any of its inputs requires
    # grad; otherwise it is a plain leaf and has to be marked explicitly.
    if not mixed_images.requires_grad:
        mixed_images.requires_grad_(True)

    # Calculate the critic's scores on the mixed images
    mixed_scores = crit(mixed_images)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        # Note: You need to take the gradient of outputs with respect to inputs.
        # This documentation may be useful, but it should not be necessary:
        # https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad
        #### START CODE HERE ####
        inputs=mixed_images,
        outputs=mixed_scores,
        #### END CODE HERE ####
        # These other parameters control how the PyTorch autograd engine works
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient

# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: gradient_penalty
def gradient_penalty(gradient):
    """Return the mean squared deviation of per-image gradient norms from 1.

    Args:
        gradient: Tensor whose first dimension indexes images; all remaining
            dimensions are flattened into one row per image.

    Returns:
        A scalar tensor: ``mean((||gradient_i||_2 - 1) ** 2)`` over the batch.
    """
    # Flatten the gradients so that each row captures one image.
    # reshape (unlike view) also accepts non-contiguous input, e.g. gradients
    # that come back in channels_last memory format.
    gradient = gradient.reshape(len(gradient), -1)

    # Calculate the magnitude of every row
    gradient_norm = gradient.norm(2, dim=1)

    # Penalize the mean squared distance of the gradient norms from 1
    #### START CODE HERE ####
    penalty = nn.functional.mse_loss(gradient_norm, torch.ones_like(gradient_norm))
    #### END CODE HERE ####
    return penalty

# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_gen_loss
def get_gen_loss(crit_fake_pred):
    """Generator loss: the negated mean of the critic's scores on generated images.

    Args:
        crit_fake_pred: Critic scores for a batch of generated images.

    Returns:
        A scalar tensor; lower values correspond to higher critic scores.
    """
    #### START CODE HERE ####
    average_score = crit_fake_pred.mean()
    gen_loss = average_score.neg()
    #### END CODE HERE ####
    return gen_loss

# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_crit_loss
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    """Critic loss: mean fake score minus mean real score plus the weighted penalty.

    Args:
        crit_fake_pred: Critic scores for generated images.
        crit_real_pred: Critic scores for real images.
        gp: Gradient penalty term.
        c_lambda: Multiplier applied to ``gp``.

    Returns:
        ``mean(crit_fake_pred) - mean(crit_real_pred) + gp * c_lambda``.
    """
    #### START CODE HERE ####
    score_gap = crit_fake_pred.mean() - crit_real_pred.mean()
    weighted_penalty = gp * c_lambda
    crit_loss = score_gap + weighted_penalty
    #### END CODE HERE ####
    return crit_loss

In [None]:
# Start from a clean slate so the cell can be re-run. The target is the
# working-directory folder (not a path under the filesystem root), and -r is
# required because it is a directory.
!rm -rf fabric_dataset
# %pip installs into the environment of the running kernel.
%pip install -q gdown
# -y answers the confirmation prompt, which cannot be typed in a notebook cell.
!apt-get install -y unrar
# Download the archive from Google Drive. A full URL is used instead of the
# deprecated --id flag.
!gdown "https://drive.google.com/uc?id=1iQST_v2PJSSnuybcg61G2tBG6Nq5QK8v"
# Nest the files inside a single sub-folder ("class"); -p creates both levels.
!mkdir -p fabric_dataset/class
!mv Data-Omema.rar fabric_dataset/class/Data-Omema.rar
# "e" extracts every file directly into the current directory, without archive paths.
!cd fabric_dataset/class/ && unrar e Data-Omema.rar


In [None]:
# Delete the first 2000 entries (in `ls` order) of the class folder, then report what is left.
# NOTE(review): not idempotent — every re-run removes another 2000 entries. The backtick
# expansion is split on whitespace, so names containing spaces would not be removed as intended.
!cd fabric_dataset/class/ && rm `ls | head -n2000`
# wc prints line, word and byte counts of the listing; the first number is the entry count.
!ls fabric_dataset/class/ | wc

In [None]:
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

batch_size = 6

# Preprocessing applied to every image: a random crop resized to im_size x im_size
# (the generator's output size), conversion to a tensor, then per-channel normalisation.
data_transform = transforms.Compose([
    # NOTE(review): `scale` bounds the crop area as a fraction of the source image per the
    # torchvision docs; an upper bound of 1.5 exceeds the whole image — confirm this is intended.
    transforms.RandomResizedCrop(size=im_size, scale=(1.0, 1.5)),
    transforms.ToTensor(),
    # NOTE(review): these std values are unusually small for data in [0, 1]; confirm they were
    # computed with the correct pixel count.
    transforms.Normalize(
        mean=[0.6002, 0.5573, 0.5342],
        std=[0.0825, 0.0801, 0.0796],
    ),
])

# ImageFolder treats each sub-directory of the root as one class.
fabric_dataset = datasets.ImageFolder(root='./fabric_dataset', transform=data_transform)

dataloader = DataLoader(
    fabric_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from tqdm import tqdm

# Adam hyper-parameters shared by both networks.
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
# Weight of the gradient penalty inside the critic loss.
c_lambda = 10

gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

# get_noise is defined next to the model classes earlier in this notebook, so it
# is not redefined here.

n_epochs = 50
# Report roughly once per pass over the dataset. The floor of 1 keeps the modulo
# below from dividing by zero when the dataset is smaller than one batch.
display_step = max(1, len(dataloader.dataset) // batch_size)
cur_step = 0
generator_losses = []
critic_losses = []
# Number of critic updates performed for every generator update.
crit_repeats = 1
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ### Update critic ###
        mean_iteration_critic_loss = 0.0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # The critic step never back-propagates into the generator, so the
            # generator graph is not built at all for this forward pass.
            with torch.no_grad():
                fake = gen(fake_noise)
            crit_fake_pred = crit(fake)
            crit_real_pred = crit(real)

            # One mixing weight per image, broadcast over channels and pixels.
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake, epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Every repeat builds a fresh graph, so it does not need to be
            # retained after the backward pass.
            crit_loss.backward()
            crit_opt.step()
        critic_losses.append(mean_iteration_critic_loss)

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()
        # Only the generator is stepped here. The critic gradients produced by
        # this backward pass are cleared by crit_opt.zero_grad() on the next batch.
        gen_opt.step()
        generator_losses.append(gen_loss.item())

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")

            # NOTE(review): show_tensor_images maps values with (x + 1) / 2, while `real`
            # was normalised with per-channel mean/std, so the displayed real images
            # are presumably not in their original colours — confirm.
            show_tensor_images('fake', fake)
            show_tensor_images('real', real)

            # Average the loss history in bins of step_bins steps to smooth the curves.
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1

In [None]:
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torch.utils import data

# Every image is resized and centre-cropped to 64x64 so that all samples
# contribute the same number of pixels.
dataset = datasets.ImageFolder('./fabric_dataset', transform=transforms.Compose([transforms.Resize(64),
                             transforms.CenterCrop(64),
                             transforms.ToTensor()]))

loader = data.DataLoader(dataset,
                         batch_size=10,
                         num_workers=0,
                         shuffle=False)

# Pass 1: per-channel mean, accumulated as the sum of per-image channel means.
# The number of pixels per channel is counted on the way, so the variance below
# is divided by the real count rather than by a hard-coded image size.
mean = 0.0
pixel_count = 0
for images, _ in loader:
    batch_samples = images.size(0)
    # (batch, channels, H, W) -> (batch, channels, H*W)
    images = images.view(batch_samples, images.size(1), -1)
    mean += images.mean(2).sum(0)
    pixel_count += batch_samples * images.size(2)
mean = mean / len(loader.dataset)

# Pass 2: per-channel sum of squared deviations from the mean.
var = 0.0
for images, _ in loader:
    batch_samples = images.size(0)
    images = images.view(batch_samples, images.size(1), -1)
    var += ((images - mean.unsqueeze(1))**2).sum([0,2])
# NOTE(review): dividing by a hard-coded 224*224 would understate std by a factor
# of 224/64 for these 64x64 crops; re-check any Normalize constants copied from
# earlier runs of this cell.
std = torch.sqrt(var / pixel_count)

print(mean, std)

In [None]:
# Persist only the learned weights (state_dict) of each network, as files in the
# current working directory.
torch.save(gen.state_dict(), 'gen.pth')
torch.save(crit.state_dict(), 'crit.pth')

In [None]:
# Restore weights from gen.pth / crit.pth into the existing gen and crit instances;
# map_location places the loaded tensors on `device`.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted
# source (consider weights_only=True if the installed torch version supports it).
gen.load_state_dict(torch.load('gen.pth', map_location=device))
crit.load_state_dict(torch.load('crit.pth', map_location=device))