In [None]:
# Enable IPython's autoreload extension so edits to the local helper modules
# imported below are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules before executing each cell.
%autoreload 2

In [None]:
import torch
import numpy as np
import torchvision
import matplotlib.pyplot as plt
import gc
from train import train
from input_transformer import InputTransformer
from pytorch_pretrained_biggan import BigGAN
from input_noise_dataset import NoiseDataset

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print tensors with plain decimals instead of scientific notation.
torch.set_printoptions(sci_mode=False)

In [None]:
# Build the input transformer and move it to the selected device. The CUDA
# cache is flushed before and after so unused cached blocks are handed back
# to the driver around the allocation.
torch.cuda.empty_cache()
model = InputTransformer()
model = model.to(device)
torch.cuda.empty_cache()

parameterCount = model.getNumberOfParameters()
print("Number of parameters:", parameterCount)

In [None]:
# Training inputs come from the project's NoiseDataset (defined in
# input_noise_dataset.py, not visible here).
dataset = NoiseDataset()

# Adam updates only the input transformer's own parameters.
learningRate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)

In [None]:
# Load the pretrained 512x512 BigGAN-deep generator and move it to the device.
biggan = BigGAN.from_pretrained('biggan-deep-512').to(device)

# The generator is only a fixed, differentiable part of the loss: the optimizer
# never updates it. Freezing its parameters stops autograd from allocating and
# accumulating `.grad` buffers for every generator weight on each backward pass
# (nothing ever zeroes them), while gradients still flow through to the inputs
# produced by the input transformer.
biggan.eval()
for generatorParameter in biggan.parameters():
    generatorParameter.requires_grad_(False)

torch.cuda.empty_cache()

In [None]:
def _loadFrozenModel(path):
    """Load a pickled model onto `device` and freeze it for use inside the loss.

    SECURITY: `torch.load` unpickles the file, which can execute arbitrary
    code. Only load checkpoints that come from a trusted source.

    NOTE(review): recent PyTorch releases default `torch.load` to
    `weights_only=True`, which rejects whole-module pickles like these; if
    loading fails there, pass `weights_only=False` for these trusted files.

    Args:
        path: Location of a file written with `torch.save(module, path)`.

    Returns:
        The loaded module, on `device`, in eval mode, with all parameters
        excluded from gradient computation.
    """
    # map_location remaps the stored tensors onto the active device, so one
    # call covers both the CPU-only case and checkpoints saved on a GPU index
    # that does not exist on this machine.
    loadedModel = torch.load(path, map_location=device).to(device)
    # The critics only score generated images; they are never trained here.
    # Eval mode makes dropout/batch-norm layers (if any) deterministic, and
    # freezing the weights avoids accumulating unused gradients for them.
    loadedModel.eval()
    for frozenParameter in loadedModel.parameters():
        frozenParameter.requires_grad_(False)
    return loadedModel


discriminator = _loadFrozenModel("trained_models/abstract_image_discriminator.pkl")
feedbackPredictor = _loadFrozenModel("trained_models/feedback_predictor.pkl")
torch.cuda.empty_cache()

In [None]:
def loss(prediction, label):
    """Score the BigGAN image generated from the transformer's output.

    The prediction is fed through the global `biggan` generator (truncation
    0.99); the resulting image is scored by the global `discriminator` and
    `feedbackPredictor`. The loss is the sum over all elements of
    `(1 - isAbstract) + (1 - feedback / 6)`, so it shrinks as either score grows.

    Args:
        prediction: Indexable pair passed to BigGAN as its first two
            arguments (presumably noise vector and class vector — confirm
            against InputTransformer).
        label: Unused; accepted so the signature matches what `train` passes.

    Returns:
        A scalar tensor holding the summed loss.
    """
    firstInput, secondInput = prediction[0], prediction[1]
    image = biggan(firstInput, secondInput, 0.99)

    abstractScore = discriminator(image)
    feedbackScore = feedbackPredictor(image)

    # Drop the local reference to the image and release cached memory before
    # the caller runs the backward pass.
    del image
    gc.collect()
    torch.cuda.empty_cache()

    abstractPenalty = 1 - abstractScore
    feedbackPenalty = 1 - (feedbackScore / 6)
    return torch.sum(abstractPenalty + feedbackPenalty)

In [None]:
# Run the project's training loop (train.py, not visible here) for a single
# epoch with one sample per batch; accuracy counting is switched off.
model = train(
    model,
    optimizer,
    loss,
    dataset,
    device,
    epochs=1,
    batchSize=1,
    countAccuracy=False,
)

In [None]:
# Persist the whole trained module (not just its state_dict) next to the
# other pickled models.
savedModelPath = 'trained_models/input_transformer.pkl'
torch.save(model, savedModelPath)

In [None]:
# Generate one sample image from the first dataset item, score it, and show it.
# Pure inference: no_grad avoids building an autograd graph through BigGAN and
# both critics, which would otherwise hold activations in memory for nothing.
with torch.no_grad():
    initialInput = dataset[0][0].to(device)
    transformedInput = model(initialInput)
    generatedImage = biggan(transformedInput[0], transformedInput[1], 0.99)
    isAbstract = discriminator(generatedImage).item()
    feedback = feedbackPredictor(generatedImage).item()

    # BigGAN emits pixel values in [-1, 1], but ToPILImage expects float
    # tensors in [0, 1]; converting without rescaling wraps the negative
    # values and corrupts the displayed colours.
    displayTensor = torch.squeeze(generatedImage.cpu(), dim=0)
    displayTensor = (displayTensor.clamp(-1, 1) + 1) / 2

pilImage = torchvision.transforms.ToPILImage()(displayTensor)
del generatedImage
gc.collect()
torch.cuda.empty_cache()
plt.imshow(pilImage)
plt.show()

print("Abstractness: {:.2f}".format(isAbstract))
print("Feedback: {:.2f}".format(feedback))