## Creating a generative AI with twinLab

In [None]:
# Standard imports
import pickle

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# twinLab
import twinlab as tl

In [None]:
# Parameters

# Data-set selector: exactly one `experiment` assignment should be active.
# "MNIST" uses images of the numbers 0-9.
experiment = "MNIST"

# To use images of objects from the CIFAR-10 database instead, comment out
# the assignment above and uncomment the one below.
# Data is from the website https://www.cs.toronto.edu/~kriz/cifar.html
# experiment = "CIFAR-10"

# Seed passed to NumPy's random-number generator
random_seed = 123

In [None]:
# Functions
def unpickle(file):
    """Load the object stored in a pickled CIFAR-10 batch file.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. Because the file is read with
        ``encoding="bytes"``, 8-bit strings pickled by Python 2 come back
        as ``bytes`` (which is why callers index with keys such as
        ``b"data"``).
    """
    # SECURITY: pickle.load can execute arbitrary code while loading; only
    # call this on files obtained from a trusted source.
    with open(file, "rb") as fo:
        # Named `batch` rather than `dict` so the builtin is not shadowed.
        batch = pickle.load(fo, encoding="bytes")
    return batch

def wrangle_image(linear_image, npix):
    """Reshape a flattened square image into a 2D (greyscale) or 3D (RGB) array.

    Args:
        linear_image: Array-like of pixel values. It is converted with
            ``np.asarray`` and flattened, so lists and NumPy arrays are both
            accepted. It must hold either ``npix**2`` values (greyscale) or
            ``3 * npix**2`` values (the full R plane, then G, then B).
        npix: Side length of the square image in pixels.

    Returns:
        An ``(npix, npix)`` array with the input dtype for greyscale input,
        or an ``(npix, npix, 3)`` ``uint8`` array for RGB input.

    Raises:
        ValueError: If the number of values matches neither layout.
    """
    pixels = np.asarray(linear_image).ravel()
    pix = npix**2
    if pixels.size == pix:
        return pixels.reshape(npix, npix)
    if pixels.size == 3*pix:
        # The three colour planes are stored one after another; stack them
        # along a new last axis so each pixel carries its (R, G, B) triple.
        red, green, blue = pixels.reshape(3, npix, npix)
        return np.dstack((red, green, blue)).astype(np.uint8)
    raise ValueError(
        f"Expected {pix} (greyscale) or {3*pix} (RGB) pixel values "
        f"for npix={npix}, got {pixels.size}."
    )

In [None]:
# Calculations
# Seed NumPy's global random-number generator with the value set in the
# parameters cell.
# NOTE(review): whether twinLab's sampling draws from this generator is not
# visible here -- confirm before relying on it for reproducibility.
np.random.seed(random_seed)

In [None]:
# Load the chosen data set into `df` and record the image side length `npix`.
if experiment == "MNIST":

    # 8x8 pixel pictures of the numbers 0 to 9 (1798 pictures).
    # NOTE(review): the CSV is assumed to hold a "number" label column plus
    # one column per pixel -- confirm against the file.
    npix = 8
    filepath = "MNIST/data.csv"
    df = pd.read_csv(filepath)

elif experiment == "CIFAR-10":

    # 32x32 pixel colour pictures of 10 different types of object
    npix = 32
    filepath = "CIFAR-10/data_batch_1"
    data = unpickle(filepath)

    # Each image is stored flattened, so the 3D picture (channel, row, column)
    # unpacks into one 2D dataframe row. Columns are named channel-row-column,
    # with the whole R plane first, then G, then B.
    df = pd.DataFrame(data[b"data"])
    df.columns = [
        f"{RGB}-{i}-{j}"
        for RGB in ["R", "G", "B"]
        for i in range(npix)
        for j in range(npix)
    ]
    df["number"] = data[b"labels"] # TODO: Try to insert this as the first column

else:

    # An experiment was specified that doesn't exist; report the offending
    # value so the typo is easy to spot.
    raise ValueError(f"Experiment not recognised: {experiment!r}")

# The label column is the model input; every other column is an output pixel
inputs = ["number"]
outputs = list(df.drop(columns=inputs).columns)

# Plot the first image as a sanity check
image = wrangle_image(df[outputs].iloc[0].to_numpy(), npix)
plt.figure(figsize=(2, 2))
plt.imshow(image, cmap="binary_r")
plt.xticks([])
plt.yticks([])
plt.show()

display(df)

Data campaign

In [None]:
# Campaign configuration: the "number" label is the only input and every
# remaining dataframe column is an output.
inputs = ["number"]
outputs = df.drop(columns=["number"]).columns.tolist()
setup_dict = dict(
    inputs=inputs,
    outputs=outputs,
    # Type of model to use
    estimator="gaussian_process_regression",
    # Switches output decomposition (the twinLab equivalent of PCA/SVD) on or off
    decompose_outputs=True,
    # Adjust this number to improve accuracy
    output_explained_variance=0.75,
)

# Create the campaign from the parameters above
campaign = tl.Campaign(**setup_dict)

# Train the campaign; a larger train_test_split uses more of the data
# to train the model.
train_dict = dict(
    df=df,
    train_test_split=200,
)
campaign.fit(**train_dict)

In [None]:
# One query row for each label 0-9
df_predict = pd.DataFrame({"number": list(range(10))})

# predict returns a pair; the first item is the mean prediction. The second
# (the standard deviation, per the original note) is not needed here.
df_mean, _ = campaign.predict(df_predict)
display(df_mean)

# Plot the mean image predicted for each label on a 2x5 grid. The axes come
# straight from plt.subplots, so no running subplot counter is needed.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for row, ax in enumerate(axes.flat):
    image = wrangle_image(df_mean.iloc[row].to_numpy(), npix)
    ax.imshow(image, cmap="binary_r")
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

Output

In [None]:
# Draw n random samples of each type of image from the trained campaign
n = 10
df_samples = campaign.sample(df_predict, n)
display(df_samples)

# Plot the samples: one grid row per sample, one grid column per query row
# of df_predict. npix was set when the data was loaded (8 for MNIST, 32 for
# CIFAR-10).
nrow, ncol = n, len(df_predict)
# squeeze=False keeps `axes` two-dimensional even when nrow or ncol is 1
fig, axes = plt.subplots(nrow, ncol, figsize=(10, 1*n), squeeze=False)
for row in range(ncol):
    # Level 1 of the column index selects the query row; this selection does
    # not depend on the sample, so do it once per grid column.
    # NOTE(review): assumes df_predict has the default 0..N-1 index -- confirm.
    row_samples = df_samples.xs(row, axis="columns", level=1, drop_level=True)
    for sample in range(nrow):
        linear_image = row_samples.iloc[sample].to_numpy()
        image = wrangle_image(linear_image, npix)
        ax = axes[sample, row]
        ax.imshow(image, cmap="binary_r")
        ax.set_xticks([])
        ax.set_yticks([])
plt.show()