# FID

In [1]:
import numpy as np
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy.random import random
from scipy.linalg import sqrtm

In [2]:
# calculate frechet inception distance
# notice that at least one of real or generated needs to have a shape greater than or equal to (2,x)
def calculate_fid(real, generated):
	# calculate mean and covariance statistics
	mu1, sigma1 = real.mean(axis=0), cov(real, rowvar=False)
	mu2, sigma2 = generated.mean(axis=0), cov(generated, rowvar=False)
	# calculate sum squared difference between means
	ssdiff = np.sum((mu1 - mu2)**2.0)
	# calculate sqrt of product between cov
	covmean = sqrtm(sigma1.dot(sigma2))
	# check and correct imaginary numbers from sqrt
	if iscomplexobj(covmean):
		covmean = covmean.real
	# calculate score
	fid = ssdiff + trace(sigma1 + sigma2 - 2.0 * covmean)
	return fid

# read in training

In [3]:
import pickle

In [4]:
train = []
file = open('../notes/modernist_notes.pickle', 'rb')
train = pickle.load(file)
file.close()
train = np.array(train)
print('original train shape:', train.shape)

train = train[:, 400:600, 12:60]
print('modified train shape:', train.shape)

train = train.reshape((207, 200*48))
print('reshaped train shape:', train.shape)


original train shape: (207, 1000, 84)
modified train shape: (207, 200, 48)
reshaped train shape: (207, 9600)


# read in result

In [5]:
result = []
file = open('./results/result_modernist', 'rb')
result = pickle.load(file)
file.close()
result = np.array(result)
print(result.shape)

lastPiece = result[-1][0][400:600,12:60,0]
print(lastPiece.shape)
lastPiece = lastPiece.reshape((1, 200*48))
print(lastPiece.shape)

(402, 8, 1000, 84, 1)
(200, 48)
(1, 9600)


# calculate FID

In [6]:
random1 = random(1*9600)
random1 = random1.reshape((1*9600))
# act2 = random(1*100*48)
# act2 = act2.reshape((1,100*48))

In [7]:
# fid between act1 and act1
fid = calculate_fid(train, train)
print('FID (same): %.3f' % fid)
# fid between act1 and act2
fid = calculate_fid(train, lastPiece)
print('FID (different): %.3f' % fid)

FID (same): -0.002
FID (different): 5939.203


In [8]:
fid_rand = calculate_fid(train, random1)
print('FID (different): %.3f' % fid_rand)

FID (different): 20610.990
