In [None]:
# Automatically re-import edited project modules (e.g. discriminator_dataset,
# train, image_regressor_model) before each cell runs, so changes to their .py
# files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision

from discriminator_dataset import DiscriminatorDataset
from image_regressor_model import ImageRegressor
from train import train

In [None]:
# Fix the random seeds so weight initialisation, the random sample drawn at the
# end of the notebook, and (presumably) any torch/NumPy-based shuffling inside
# train() repeat across Restart & Run All.
# NOTE(review): bit-exact GPU runs may additionally need cuDNN determinism flags.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the regressor and move it to the selected device.
# NOTE(review): the argument 1 presumably means a single output value (the
# abstract-vs-realistic score) — confirm against ImageRegressor.
# The cache flushes release unused CUDA memory the caching allocator still holds
# (e.g. from a previous run of this cell); they are safe on CPU-only machines.
torch.cuda.empty_cache()
model = ImageRegressor(1).to(device)
torch.cuda.empty_cache()

parameterCount = model.getNumberOfParameters()
print("Number of parameters:", parameterCount)

In [None]:
# Training data, plus an Adam optimiser over every trainable parameter of the model.
LEARNING_RATE = 0.0001

dataset = DiscriminatorDataset()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [None]:
def loss(prediction, label):
    """Summed binary cross-entropy between predicted probabilities and 0/1 labels.

    Before taking logs, predictions are squashed affinely from [0, 1] into
    [0.0005, 0.9995]. This keeps ``log(p)`` and ``log(1 - p)`` finite when the model
    outputs exactly 0 or 1.

    Args:
        prediction: tensor of predicted probabilities in [0, 1].
        label: tensor of target labels (1 = positive class), broadcastable
            against ``prediction``.
            NOTE(review): if prediction is (batch, 1) and label is (batch,), the
            product silently broadcasts to (batch, batch) — confirm the shapes
            train() passes in.

    Returns:
        0-dim tensor: the negative log-likelihood summed over all elements (not
        averaged, so its scale grows with batch size).
    """
    squashed = prediction * 0.999 + 0.0005
    logPositive = torch.log(squashed)
    logNegative = torch.log(1 - squashed)
    logLikelihood = label * logPositive + (1 - label) * logNegative
    return -torch.sum(logLikelihood)

In [None]:
# Run the project training loop; presumably EPOCHS passes over the dataset in
# mini-batches of BATCH_SIZE — confirm against train().
# NOTE(review): not idempotent — re-running this cell feeds the already-trained
# model and the same optimizer back into train(), continuing rather than restarting.
EPOCHS = 15
BATCH_SIZE = 16

model = train(model, optimizer, loss, dataset, device,
              epochs=EPOCHS, batchSize=BATCH_SIZE)

In [None]:
# Persist the whole trained model object (pickled by torch.save). Create the
# output directory first so a fresh checkout without trained_models/ doesn't fail.
# NOTE(review): saving model.state_dict() is more portable across code changes,
# but would change the file format that existing loaders of this .pkl expect.
MODEL_PATH = Path('trained_models') / 'abstract_image_discriminator.pkl'
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, MODEL_PATH)

In [None]:
# Sanity check: show one random dataset image and the discriminator's verdict on it.
# Assumes dataset items are (image_tensor, label) pairs whose image is something
# torchvision's ToPILImage accepts — TODO confirm against DiscriminatorDataset.
randomSample = dataset[np.random.randint(len(dataset))]
image = randomSample[0]
pilImage = torchvision.transforms.ToPILImage()(image)

fig, ax = plt.subplots()
ax.imshow(pilImage)
ax.set_title("Random dataset sample")
ax.axis("off")
plt.show()

# Inference: eval() switches any dropout/batch-norm layers to inference behaviour,
# and no_grad() skips building an autograd graph we never backpropagate through.
# The model's previous train/eval mode is restored afterwards, even if the
# forward pass raises.
wasTraining = model.training
model.eval()
try:
    with torch.no_grad():
        prediction = model(torch.unsqueeze(image, dim=0).to(device)).item()
finally:
    model.train(wasTraining)

# The output is read as P(abstract): >= 0.5 is reported as abstract.
if prediction >= 0.5:
    print("Predicted to be abstract with {:.2f}% probability".format(prediction*100))
else:
    print("Predicted to be realistic with {:.2f}% probability".format((1-prediction)*100))