In [None]:
#Packages

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pymks import (generate_multiphase, plot_microstructures, PrimitiveTransformer,
                   TwoPointCorrelation,FlattenTransformer, PrimitiveBasis)
from pymks.stats import correlate
from pymks.tools import draw_microstructures
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
import glob

In [None]:
# Notebook configuration: render matplotlib figures inline and enable autoreload so edited
# Python modules are re-imported before each cell runs.
# NOTE(review): a later cell switches the backend to `%matplotlib notebook`, overriding this.
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [None]:
#Load every .tif in the data folder into one list of PIL images.
# NOTE(review): absolute, machine-specific path — consider making it relative/configurable.
IMAGE_GLOB = r'C:\Users\mikep\Desktop\project\data\fiji data\1\*.tif'

# glob returns matches in filesystem-dependent order; sort so that sample index i refers to
# the same file on every run (later cells slice/split samples by position).
image_files = sorted(glob.glob(IMAGE_GLOB))
if not image_files:
    raise FileNotFoundError(f"No .tif files matched {IMAGE_GLOB!r}")

images = []
for filename in image_files:
    # Image.open is lazy and keeps the file handle open; copy the pixel data, then close the file.
    with Image.open(filename) as im:
        images.append(im.copy())
np.shape(images[0])

In [None]:
#Convert each PIL image in "images" to a NumPy array, keeping the same order.
imarrays = list(map(np.array, images))
imarrays[0].shape

In [None]:
#Stack the per-image arrays along a new leading 'sample' axis into a single array.
# Reuse `imarrays` (already converted in the previous cell) rather than re-converting the
# PIL images; np.stack raises ValueError if the images do not all share one shape.
data_a = np.stack(imarrays, axis=0)
data_a.shape

In [None]:
#Show (up to) the first ten microstructures in greyscale, without colorbars.
plot_microstructures(*data_a[:10], colorbar=False, cmap='gray');

In [None]:
#Add the trailing "phases" axis with the PrimitiveTransformer class (n_state=2 phase channels).
# The transformer is configured for pixel values in [min_, max_] = [0, 1]. Check that up front:
# NOTE(review): values outside that range (e.g. if the TIFFs are 0/255 masks) presumably would not
# be encoded correctly, so fail loudly instead of continuing with a bad encoding.
if data_a.min() < 0.0 or data_a.max() > 1.0:
    raise ValueError(
        f"Expected pixel values in [0, 1] for PrimitiveTransformer, "
        f"got range [{data_a.min()}, {data_a.max()}]; rescale the images first."
    )
data = PrimitiveTransformer(n_state=2, min_=0.0, max_=1.0).transform(data_a)
data.shape

In [None]:
# Two-point statistics of the discretized microstructures. The requested pairs are the (0, 0)
# auto-correlation and the (0, 1) cross-correlation, in that order along the last axis,
# computed with periodic boundaries and cutoff=10.
data_corr = TwoPointCorrelation(
    correlations=[(0, 0), (0, 1)],
    cutoff=10,
    periodic_boundary=True,
).transform(data)
data_corr.shape

In [None]:
# Plot both correlation channels of the first sample: channel 0 is the (0, 0)
# auto-correlation, channel 1 the (0, 1) cross-correlation.
plot_microstructures(
    *(data_corr[0, :, :, channel] for channel in (0, 1)),
    titles=['Auto-correlation', 'Cross-correlation'],
    showticks=True,
);

In [None]:
# PCA on the flattened two-point statistics: one row per sample, every correlation value a feature.
# Use the actual sample count instead of a hard-coded 10 so the cell works for any number of
# images in the data folder (n_components=3 still needs at least 3 samples).
pc_scores = PCA(
    svd_solver='full',
    n_components=3,
    random_state=10
).fit_transform(data_corr.reshape(data_corr.shape[0], -1))

In [None]:
#PYTEST_VALIDATE_IGNORE_OUTPUT
# (The marker above presumably tells the nbval pytest plugin to skip output comparison here.)
# Switch matplotlib to the interactive `notebook` backend so the 3-D PCA scatter in the next cell
# can be rotated; this overrides the `%matplotlib inline` set in the configuration cell.
# NOTE(review): the `notebook` backend is reportedly unsupported in JupyterLab — consider
# `%matplotlib widget` (requires ipympl) if running there.
%matplotlib notebook

In [None]:
#PYTEST_VALIDATE_IGNORE_OUTPUT
# Split the PC scores into 4 consecutive, colour-coded groups of samples. np.array_split is used
# instead of np.split because np.split raises ValueError unless the sample count is a multiple
# of 4 (e.g. 10 samples -> groups of 3, 3, 2, 2 here).
pc1, pc2, pc3, pc4 = np.array_split(pc_scores, 4)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# One scatter call per group: first three principal components as x, y, z.
for group, colour in zip((pc1, pc2, pc3, pc4), ('r', 'g', 'b', 'k')):
    ax.scatter(group[:, 0], group[:, 1], group[:, 2], c=colour, marker='o')

ax.set_xlabel("PC1")
ax.set_ylabel("PC2")
ax.set_zlabel("PC3")
plt.show()

# The same discretize -> two-point-correlation chain as the cells above, packaged as one
# sklearn Pipeline. Every step must be a transformer object, not the output of calling
# `.transform(...)` on one.
# NOTE(review): unlike the earlier TwoPointCorrelation cell, no `cutoff` is set here — confirm intended.
model = Pipeline(steps=[
    ('discretize', PrimitiveTransformer(n_state=2, min_=0.0, max_=1.0)),
    ('correlations', TwoPointCorrelation(
        periodic_boundary=True,
        correlations=[(0, 0), (0, 1)],
    )),
])

# Feed the raw stacked microstructures (samples first) built earlier; `imarray` was never defined.
print(data_a.shape)
x_stats = model.transform(data_a).persist()
print(x_stats.shape)

# Discretize the stacked microstructures with the transformer API, for comparison with the
# legacy basis/stats API below. (`imarray` was never defined; `data_a` is the stacked data.)
X = PrimitiveTransformer(n_state=2, min_=0.0, max_=1.0).transform(data_a)
print(X.shape)

# Legacy pymks API. `prim_basis` was not defined anywhere in this notebook; build it with
# 2 states on the [0, 1] domain to match the transformer above.
# NOTE(review): positional order (n_states, domain) follows the legacy PrimitiveBasis — confirm
# against the installed pymks version.
prim_basis = PrimitiveBasis(2, [0, 1])
X_ = prim_basis.discretize(data_a)
# NOTE(review): the legacy `correlate` takes the raw microstructure plus the basis and appears to
# discretize internally, so it is given `data_a` rather than the already-discretized `X_`
# (which would otherwise be discretized twice).
X_corr = correlate(data_a, periodic_axes=[0, 1], basis=prim_basis)