In [1]:
import tensorflow as tf
import tensorly as tl
from tensorly.decomposition import parafac
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import numpy as np
import time
import matplotlib.pyplot as plt
import cv2

In [2]:
# function that applies CP Decomposition to an image tensor and reshapes it for ML algorithm
def decompose_and_flatten(tensor, rank):
    """Return the flattened rank-`rank` CP (PARAFAC) reconstruction of `tensor`.

    The tensor is decomposed with tensorly's `parafac`, rebuilt into a full tensor
    from its factors, and flattened to 1-D. The result therefore has as many
    elements as `tensor` itself: it is a low-rank approximation of the input, not
    a compressed factor-based feature vector.

    Parameters
    ----------
    tensor : array-like
        Tensor to decompose (in this notebook, a single image).
    rank : int
        Number of rank-one components used by the CP decomposition.

    Returns
    -------
    1-D array (type depends on the active tensorly backend; presumably a
    numpy.ndarray with the default backend).
    """
    cp_tensor = parafac(tensor, rank)
    # tensorly renamed `kruskal_to_tensor` to `cp_to_tensor`; the old name was
    # deprecated and is missing from newer releases. Prefer the new name and fall
    # back to the old one so the notebook runs on both old and new tensorly.
    to_full_tensor = getattr(tl, "cp_to_tensor", None) or tl.kruskal_to_tensor
    return to_full_tensor(cp_tensor).flatten()



# Download/load CIFAR-10; load_data() yields a (train, test) pair of (images, labels) tuples.
train_split, test_split = tf.keras.datasets.cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [3]:
%%time
# normalize pixel values to be between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0

# flatten the labels
y_train, y_test = y_train.flatten(), y_test.flatten()

# select a subset of the data
subset_size = 1000
x_train_subset = x_train[:subset_size]
y_train_subset = y_train[:subset_size]

# apply CP Decomposition to each image in the subset
rank = 5
processed_x_train_subset = np.array([decompose_and_flatten(image, rank) for image in x_train_subset])

CPU times: user 1h 23min 43s, sys: 3min 9s, total: 1h 26min 52s
Wall time: 14min 34s


In [4]:
# Hold out 20% of the decomposed subset for validation.
# - The distinct names (X_fit / y_fit) leave the full CIFAR-10 `y_train` untouched,
#   so the preprocessing cell can safely be re-run.
# - A fixed random_state makes the split, and hence the accuracy, reproducible.
X_fit, X_val, y_fit, y_val = train_test_split(
    processed_x_train_subset, y_train_subset, test_size=0.2, random_state=0
)

# train a simple logistic regression model
# NOTE(review): in the recorded run the solver hit max_iter=1000 without converging
# (ConvergenceWarning); consider scaling the features or raising max_iter.
model = LogisticRegression(max_iter=1000)
model.fit(X_fit, y_fit)

# evaluate the model on the held-out split
accuracy = model.score(X_val, y_val)
print(f"Model Accuracy: {accuracy}")

Model Accuracy: 0.335


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [None]:
# Resolution sweep: starting from the full-size subset, CP-decompose every image,
# fit and score a fresh classifier, and record the metrics. Then halve the image
# height and width and repeat, stopping after the 1-pixel-tall iteration.
# results maps "(height, width)" -> [accuracy, decomposition wall time (s), cumulative CPU time (s)]
results = {}
reshaped_x_train_subset = x_train_subset
while reshaped_x_train_subset[0].shape[0] >= 1:
    height, width = reshaped_x_train_subset[0].shape[:2]
    current_dimensions = f"({height}, {width})"
    start = time.perf_counter()

    # Reprocess the data. The CP rank passed here is the image's number of axes
    # (image.ndim), which is exactly what int(tf.rank(image)) computed, minus the
    # per-image conversion to a TF tensor.
    # NOTE(review): this is not the `rank = 5` used in the preprocessing cell;
    # confirm which CP rank the comparison is meant to use.
    processed_x_train_subset = np.array(
        [decompose_and_flatten(image, image.ndim) for image in reshaped_x_train_subset]
    )

    # Wall-clock time of the decomposition step only.
    duration = time.perf_counter() - start
    # process_time() is the kernel process's total CPU time so far, not just this
    # iteration's; the CPU-time plot is labelled "cumulative" accordingly.
    cpu_time = time.process_time()

    # Retrain and retest with a fresh model. The same random_state as the baseline
    # cell gives every resolution the same train/val indices, so accuracies are comparable.
    X_fit, X_val, y_fit, y_val = train_test_split(
        processed_x_train_subset, y_train_subset, test_size=0.2, random_state=0
    )
    model = LogisticRegression(max_iter=1000)
    model.fit(X_fit, y_fit)
    accuracy = model.score(X_val, y_val)

    # save to results dictionary
    results[current_dimensions] = [accuracy, duration, cpu_time]
    print(results)

    # Stop once the images are a single pixel tall; otherwise halve both dimensions.
    if height <= 1:
        break
    # cv2.resize takes dsize as (width, height), i.e. (shape[1], shape[0]).
    reshaped_x_train_subset = np.array(
        [cv2.resize(image, (image.shape[1] // 2, image.shape[0] // 2)) for image in reshaped_x_train_subset]
    )

{'(32, 32)': [0.325, 831.2861759662628, 27523.210382]}
{'(32, 32)': [0.325, 831.2861759662628, 27523.210382], '(16, 16)': [0.275, 735.2455909252167, 32300.946484]}


In [None]:
# Split each results entry [accuracy, runtime, cpu_time] into parallel lists,
# in the dict's insertion order, plus the matching "(height, width)" labels.
accuracies, runtimes, cpu_times = [], [], []
for metrics in results.values():
    accuracies.append(metrics[0])
    runtimes.append(metrics[1])
    cpu_times.append(metrics[2])
x_labels = [*results]

In [None]:
# Show the cumulative CPU times through the cell's rich output instead of print().
cpu_times

In [None]:
# Decomposition wall time per resolution. An explicit figure keeps this plot from
# being drawn onto another cell's figure when run outside the inline backend.
fig, ax = plt.subplots()
ax.scatter(x_labels, runtimes, color='purple')
ax.set(title='CP decomposition runtime by image size',
       xlabel='Image dimensions', ylabel='Runtime (s)')
fig.savefig('cp_runtimes', bbox_inches='tight', orientation='landscape')

In [None]:
# Cumulative process CPU time recorded after each resolution's decomposition.
# An explicit figure keeps this plot separate from the other plotting cells.
fig, ax = plt.subplots()
ax.scatter(x_labels, cpu_times, color='purple')
ax.set(title='Cumulative CPU time by image size',
       xlabel='Image dimensions', ylabel='CPU usage time (cumulative, s)')
fig.savefig('cp_cpu_times', bbox_inches='tight', orientation='landscape')

In [None]:
# Validation accuracy of the logistic-regression model per resolution.
# An explicit figure keeps this plot separate from the other plotting cells.
fig, ax = plt.subplots()
ax.scatter(x_labels, accuracies, color='purple')
ax.set(title='Model accuracy by image size',
       xlabel='Image dimensions', ylabel='Model accuracy')
fig.savefig('cp_accuracies', bbox_inches='tight', orientation='landscape')