In [None]:
import os

import matplotlib as matplotlib
import matplotlib.pyplot as plt
import torch as torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch import nn

In [None]:
# Builds a Normalize object for the [-1, 1] range with clipping disabled. The result is
# not assigned to a name, so nothing later in the notebook can refer to it.
matplotlib.colors.Normalize(vmin=-1, vmax=1, clip=False)

In [None]:
# Pick the compute device: "cuda:0" when CUDA is available, otherwise the CPU.
device = "cpu"
if torch.cuda.is_available():
    # Looked up for inspection only; the device string below is always "cuda:0".
    cuda_id = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(cuda_id)
    device = "cuda:0"

## Creating Discriminator and Generator model classes

In [None]:
class Discriminator(nn.Module):
    """Convolutional classifier that outputs one sigmoid score per input image.

    Three stride-2 convolution blocks (conv -> batch norm -> leaky ReLU) halve the
    spatial size each time. The final linear layer expects 1024 flattened features,
    i.e. 64 channels x 4 x 4, which corresponds to 3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()

        def conv_block(in_channels, out_channels):
            # One downsampling stage: 4x4 kernel, stride 2, padding 1, no conv bias.
            return [
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                ),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ]

        layers = [
            *conv_block(3, 32),
            *conv_block(32, 64),
            *conv_block(64, 64),
            nn.Flatten(),
            nn.Dropout(p=0.2),
            nn.Linear(in_features=1024, out_features=1),
            nn.Sigmoid(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, img):
        """Return the sigmoid score for each image in `img`."""
        return self.network(img)

In [None]:
class Reshape(nn.Module):
    """Module wrapper around a reshape so it can be placed inside ``nn.Sequential``."""

    def __init__(self, shape):
        super().__init__()
        # Target shape applied to every tensor passed through forward().
        self.shape = shape

    def forward(self, input):
        """Return `input` reshaped to the stored shape."""
        return input.reshape(self.shape)


class Generator(nn.Module):
    """Maps 64-dimensional latent vectors to 3-channel images in the tanh range.

    A linear layer expands the latent to 1024 values, which are reshaped to
    64x4x4. Three stride-2 transposed convolutions then upsample, and a final 5x5
    convolution followed by tanh produces the 3 output channels.
    """

    def __init__(self):
        super().__init__()
        stages = [
            nn.Linear(in_features=64, out_features=1024),
            Reshape((-1, 64, 4, 4)),
        ]
        # (in_channels, out_channels) of each upsampling stage.
        for in_channels, out_channels in ((64, 64), (64, 128), (128, 256)):
            stages.append(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            stages.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        stages.append(
            nn.Conv2d(256, 3, kernel_size=5, stride=1, padding=2, bias=False)
        )
        stages.append(nn.Tanh())
        self.network = nn.Sequential(*stages)

    def forward(self, img):
        """Run the latent input `img` through the network and return the image batch."""
        return self.network(img)

In [None]:
def get_normal(shape):
    """Return a tensor of the given shape sampled from N(0, 1), moved to `device`."""
    return torch.normal(torch.zeros(shape), 1).to(device)


normal_vector = get_normal(64)

In [None]:
# Instantiate the generator and push one latent vector through it as a smoke test.
generator = Generator().to(device)
# Call the module itself rather than .forward() so nn.Module.__call__ (and any
# registered hooks) is used.
forwarded_random_image = generator(normal_vector)
forwarded_random_image.shape

In [None]:
def scale_and_display(image, normalize=False, size=8):
    """Show an image tensor (or a batch of them) as a matplotlib grid.

    Args:
        image: tensor handed to ``vutils.make_grid``.
        normalize: when True, rescale using the tensor's own min and max;
            otherwise rescale from the fixed range (-1, 1).
        size: side length of the square figure, passed as ``figsize``.
    """
    value_range = (image.min(), image.max()) if normalize else (-1, 1)

    plt.figure(figsize=(size, size))
    plt.axis("off")

    grid = vutils.make_grid(
        image, nrow=4, padding=2, normalize=True, value_range=value_range
    )
    # Move the first axis to the end and convert to a NumPy array for imshow.
    pixels = grid.permute(1, 2, 0).detach().cpu().numpy()
    plt.imshow(pixels)
    plt.show()

In [None]:
# Show a CPU copy of the untrained generator's output using the fixed (-1, 1) range.
scale_and_display(torch.clone(forwarded_random_image).cpu())

In [None]:
# Same output, this time rescaled by its own min/max (normalize=True).
scale_and_display(forwarded_random_image, True)

In [None]:
# Let's ensure that the Discriminator doesn't throw any obvious errors
discriminator = Discriminator().to(device)
# Call the module rather than .forward() so nn.Module.__call__ (and any hooks) is used.
discriminator(forwarded_random_image)

## Loading the data

In [None]:
def lambda_scaling(tensor):
    """Linearly rescale values via ``tensor * 2 - 1`` (maps 0 to -1 and 1 to 1)."""
    doubled = tensor * 2
    return doubled - 1

In [None]:
# Preprocessing applied to every image: Resize(32), CenterCrop(32), conversion to a
# tensor, then rescaling with lambda_scaling.
data_transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda_scaling),
    ]
)
data_dir = dset.ImageFolder(root="./imgs", transform=data_transform)

In [None]:
# Load the whole dataset eagerly: take element[0] (the transformed image) of every
# dataset item, stack them into one tensor and move it to `device`.
data = torch.stack([element[0] for element in data_dir]).to(device)
data.shape

In [None]:
# Preview one training image (index 2) from a CPU copy.
scale_and_display(torch.clone(data[2]).cpu())

In [None]:
# Shape of a single training image.
data[0].shape

## Example of training a non-standard model in a loop

In [None]:
class NonstandardModel(nn.Module):
    """Small fully connected network: 10 inputs -> 128 -> 128 -> 1 sigmoid output.

    Each hidden linear layer is followed by dropout (p=0.25) and an in-place ReLU.
    """

    def __init__(self):
        super().__init__()
        hidden = 128
        self.network = nn.Sequential(
            nn.Linear(in_features=10, out_features=hidden),
            nn.Dropout(p=0.25),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden, out_features=hidden),
            nn.Dropout(p=0.25),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden, out_features=1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return the sigmoid output for each row of `img`."""
        return self.network(img)

In [None]:
my_model = NonstandardModel().to(device)
# Keep only the first two fully connected layers trainable. Their positions inside
# my_model.network are 0 and 3. The flat parameters() sequence interleaves weights and
# biases, so indexing it with 0 and 3 would select the layer-1 weight and the layer-2
# bias rather than the two layers.
trainable_layers = {0, 3}
for layer_index, layer in enumerate(my_model.network):
    if layer_index not in trainable_layers:
        for param in layer.parameters():
            param.requires_grad = False

In [None]:
# Snapshot every parameter now so the values after training can be compared to it.
reference_list = [torch.clone(p) for p in my_model.parameters()]

In [None]:
# Fixed batch of 10 random 10-dimensional inputs, reused on every iteration.
batch = torch.stack([get_normal(10) for _ in range(10)])
# Every parameter is handed to the optimizer, including any with requires_grad=False.
optimizer = torch.optim.SGD(
    my_model.parameters(), lr=0.001, momentum=0.9, nesterov=True
)
for epoch in range(1000):
    y_pred = my_model(batch)
    # Custom objective: the loss approaches 0 as the mean prediction approaches 1.
    loss = 42 - 42 * y_pred.mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Report the loss every 100 iterations.
    if epoch % 100 == 0:
        print(loss)

In [None]:
# Compare the trained parameters with the snapshot taken before training.
result_list = list(my_model.parameters())

for ref, res in zip(reference_list, result_list):
    # Sum absolute differences: signed differences can cancel out and hide a change.
    print((ref - res).abs().sum())

## The basic idea of discriminator training

In [None]:
# Mini-batch size used by the demonstration cells and by the DataLoader.
BATCH_SIZE = 16

In [None]:
# Real images with smoothed targets drawn from (0.95, 1]. The noise is subtracted so it
# survives the clip below; this matches how the training loop builds its labels.
real_batch = data[:BATCH_SIZE]
real_labels = torch.ones(BATCH_SIZE) - torch.rand(BATCH_SIZE) * 0.05

# Generated images with smoothed targets drawn from [0, 0.05).
generated_batch = generator(
    torch.stack([get_normal(64) for _ in range(BATCH_SIZE)])
)
generated_labels = torch.zeros(BATCH_SIZE) + torch.rand(BATCH_SIZE) * 0.05

batch = torch.concat([real_batch, generated_batch])
labels = torch.concat([real_labels, generated_labels])
# Clip as a safety net so targets stay valid for BCELoss, then shape to (N, 1).
labels = labels.clip(min=0, max=1).reshape((-1, 1)).to(device)

In [None]:
# Adam optimizer for the discriminator and binary cross-entropy as the loss. BCELoss
# takes probabilities, which the discriminator's final Sigmoid layer produces.
discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=0.00001)
loss_function = nn.BCELoss()

In [None]:
# One discriminator update on the mixed real/generated `batch` and its `labels`.
y_pred = discriminator(batch)
loss = loss_function(y_pred, labels)

# Clear old gradients, backpropagate, then apply the update.
discriminator_opt.zero_grad()
loss.backward()
discriminator_opt.step()

loss

## The basic idea of generator training

In [None]:
# Latent batch for a generator update. Every target is 1: the generator is trained so
# that the discriminator scores its images as real.
batch = torch.stack([get_normal(64) for _ in range(BATCH_SIZE)])
labels = torch.ones(BATCH_SIZE).clip(max=1).reshape((-1, 1)).to(device)

generator_opt = torch.optim.Adam(generator.parameters(), lr=0.00001)

In [None]:
# One generator update: the loss is computed on the discriminator's scores for the
# generated images, but only generator_opt steps, so only the generator is updated.
images = generator(batch)
y_pred = discriminator(images)
loss = loss_function(y_pred, labels)
generator_opt.zero_grad()
loss.backward()
generator_opt.step()

loss

## Training the model

In [None]:
# Split the dataset into consecutive chunks of BATCH_SIZE; the last chunk holds the
# remainder. Slicing by start offset stays correct when the remainder is 0, whereas
# data[-0:] would select the whole dataset as an extra batch.
n_full_batches = data.shape[0] // BATCH_SIZE
n_ending_elems = data.shape[0] % BATCH_SIZE
batches = [
    data[start : start + BATCH_SIZE] for start in range(0, data.shape[0], BATCH_SIZE)
]

In [None]:
# Shuffled mini-batches of BATCH_SIZE drawn directly from the stacked `data` tensor.
dataloader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
def save_state(dis, gen, dis_opt, gen_opt, path, epoch):
    """Save both networks and both optimizers into directory `path`.

    Four files are written, each suffixed with `epoch`: ``discriminator``,
    ``generator``, ``discriminator_optimizer`` and ``generator_optimizer``.
    The directory is created when it does not exist yet.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    suffix = str(epoch)
    torch.save(dis, os.path.join(path, "discriminator" + suffix))
    torch.save(gen, os.path.join(path, "generator" + suffix))
    torch.save(dis_opt, os.path.join(path, "discriminator_optimizer" + suffix))
    torch.save(gen_opt, os.path.join(path, "generator_optimizer" + suffix))


def load_state(path, epoch):
    """Load the four objects written by `save_state` for the given `epoch`.

    Returns:
        Tuple ``(dis, gen, dis_opt, gen_opt)``.
    """
    # NOTE(review): torch.load unpickles the files, so only load checkpoints you trust.
    # Recent PyTorch releases default to weights_only=True, which may reject whole
    # pickled modules/optimizers - confirm against the installed version.
    suffix = str(epoch)
    dis = torch.load(os.path.join(path, "discriminator" + suffix))
    gen = torch.load(os.path.join(path, "generator" + suffix))
    dis_opt = torch.load(os.path.join(path, "discriminator_optimizer" + suffix))
    gen_opt = torch.load(os.path.join(path, "generator_optimizer" + suffix))
    return dis, gen, dis_opt, gen_opt

In [None]:
# Fixed set of 16 latent vectors, reused so generated samples can be compared over time.
constant_for_comparison = torch.stack([get_normal(64) for _ in range(16)])

In [None]:
# Re-create both networks and their optimizers so training starts from fresh weights.
# Only the discriminator's optimizer applies weight decay.
discriminator = Discriminator().to(device)
generator = Generator().to(device)
discriminator_opt = torch.optim.Adam(
    discriminator.parameters(), lr=0.00001, weight_decay=0.01
)
generator_opt = torch.optim.Adam(generator.parameters(), lr=0.00001)
loss_function = nn.BCELoss()

In [None]:
# Alternating GAN training: each mini-batch gets one discriminator update followed by
# one generator update. Every 50 epochs the cumulative losses are recorded, a
# checkpoint is saved and samples from the fixed latent vectors are displayed.
epoch_loss = []
for epoch in range(3000):
    dis_cum_loss = 0
    gen_cum_loss = 0
    for batch in dataloader:
        batch_size = batch.shape[0]

        # Training discriminator to recognise real images (targets in (0.95, 1])
        true_labels = torch.ones(batch_size) - torch.rand(batch_size) * 0.05
        dis_pred_for_true = discriminator(batch)

        # Training discriminator to recognise fakes (targets in [0, 0.05)).
        # detach(): the discriminator step must not backpropagate into the generator.
        generated_batch = generator(get_normal((batch_size, 64))).detach()
        false_labels = torch.zeros(batch_size) + torch.rand(batch_size) * 0.05
        dis_pred_for_false = discriminator(generated_batch)

        dis_loss = loss_function(
            torch.concat([dis_pred_for_true, dis_pred_for_false]),
            torch.concat([true_labels, false_labels]).reshape((-1, 1)).to(device),
        )
        discriminator_opt.zero_grad()
        # Accumulate a detached value so the running total keeps no autograd history.
        dis_cum_loss += dis_loss.detach()
        dis_loss.backward()
        discriminator_opt.step()

        # Training generator: it is rewarded when the discriminator answers "real".
        gen_labels = torch.ones(batch_size)
        gen_labels = gen_labels.reshape((-1, 1)).to(device)
        gen_batch = generator(get_normal((batch_size, 64)))
        gen_pred = discriminator(gen_batch)
        gen_loss = loss_function(gen_pred, gen_labels)
        generator_opt.zero_grad()
        gen_cum_loss += gen_loss.detach()
        gen_loss.backward()
        generator_opt.step()

    if epoch % 50 == 0 and epoch > 0:
        epoch_loss.append((dis_cum_loss, gen_cum_loss))
        save_state(
            discriminator,
            generator,
            discriminator_opt,
            generator_opt,
            "./saved_models",
            epoch,
        )
        print("epoch: ", epoch)
        # Average over the batches actually iterated this epoch.
        print("d loss: ", dis_cum_loss / len(dataloader))
        print("g loss: ", gen_cum_loss / len(dataloader))
        # Preview only: no gradients are needed for the sample images.
        with torch.no_grad():
            images = generator(constant_for_comparison)
        scale_and_display(images.cpu(), normalize=True)

In [None]:
# Inspect the recorded (discriminator, generator) cumulative loss pairs.
epoch_loss

In [None]:
# Split the recorded pairs into two NumPy-valued series:
# losses[0] holds the discriminator losses, losses[1] the generator losses.
losses = tuple(
    [pair[position].cpu().detach().numpy() for pair in epoch_loss]
    for position in (0, 1)
)

In [None]:
# Plot both recorded loss series on a single set of axes.
fig, ax1 = plt.subplots()

for series, series_label in zip(losses, ("discriminator loss", "generator loss")):
    ax1.plot(series, label=series_label)
ax1.legend()
plt.show()

## Recreating training data

In [None]:
# Show training image 2 at a smaller figure size.
scale_and_display(data[2], size=4)

In [None]:
# Search for a latent vector whose generated image reproduces a training image:
# only `input1` is handed to the optimizer, so the generator's weights are not stepped.
reference_image = data[2]
input1 = get_normal(64)
input1.requires_grad = True

optimizer = torch.optim.SGD([input1], 0.01, momentum=0.9, nesterov=True)
loss_function_MSE = nn.MSELoss()
# A single latent vector yields a batch of one image (the generator's Reshape uses -1
# for the leading axis). Give the target the same leading axis so MSELoss compares
# equal-sized tensors instead of broadcasting with a size-mismatch warning.
target_image = reference_image.unsqueeze(0)

# Starting point before optimisation.
generated_image = generator(input1)
scale_and_display(generated_image, size=4)
for i in range(1000):
    generated_image = generator(input1)
    loss = loss_function_MSE(generated_image, target_image)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Result after optimisation.
scale_and_display(generator(input1), size=4)

## Manually modifying the input vector

In [None]:
# let's see how the image will change if we manually modify the vector
modified_input = torch.clone(input1)
indices = [2 * i for i in range(32)]
modified_input[indices] = 0.5
scale_and_display(generator(modified_input), size=4)

## Attempt to generate an image from outside of training data

In [None]:
# Load the images under ./szczoor with the same preprocessing steps as the training
# data, stack them into one tensor on `device`, and display them.
rat_transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda_scaling),
    ]
)
rat_dir = dset.ImageFolder(root="./szczoor", transform=rat_transform)
rat_images = [element[0] for element in rat_dir]
rat = torch.stack(rat_images).to(device)
scale_and_display(rat, size=4)

In [None]:
# Latent search against the out-of-dataset `rat` tensor: only `input2` is optimised,
# here with a learning rate of 0.1.
# NOTE(review): `rat` keeps the leading axis from torch.stack; if the folder holds more
# than one image, MSELoss will broadcast the single generated image against all of
# them - confirm the folder contains exactly one image.
input2 = get_normal(64)
input2.requires_grad = True
optimizer = torch.optim.SGD([input2], 0.1, momentum=0.9, nesterov=True)
loss_function_MSE = nn.MSELoss()

# Starting point before optimisation.
generated_image = generator(input2)
scale_and_display(generated_image, size=4)
for i in range(1000):
    generated_image = generator(input2)
    loss = loss_function_MSE(generated_image, rat)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Result after optimisation.
scale_and_display(generator(input2), size=4)

## Generating data transition

In [None]:
# Find a latent vector for a second training image; it serves as the end point of the
# transition generated afterwards.
reference_image = data[3]
input3 = get_normal(64)
input3.requires_grad = True
scale_and_display(reference_image, size=4)

optimizer = torch.optim.SGD([input3], 0.01, momentum=0.9, nesterov=True)
loss_function_MSE = nn.MSELoss()
# A single latent vector yields a batch of one image (the generator's Reshape uses -1
# for the leading axis). Give the target the same leading axis so MSELoss compares
# equal-sized tensors instead of broadcasting with a size-mismatch warning.
target_image = reference_image.unsqueeze(0)

# Starting point before optimisation.
generated_image = generator(input3)
scale_and_display(generated_image, size=4)
for i in range(1000):
    generated_image = generator(input3)
    loss = loss_function_MSE(generated_image, target_image)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Result after optimisation.
scale_and_display(generator(input3), size=4)

In [None]:
# Walk from input1 to input3 in five equal steps (six images, both end points
# included) and display the image generated at each step.
distance_vect = input3 - input1

for i in range(6):
    interpolated_input = input1 + distance_vect * i / 5
    generated_image = generator(interpolated_input)
    scale_and_display(generated_image)