In [19]:
import numpy as np
import torchvision
import torchvision.datasets as datasets
from sklearn.decomposition import PCA
import scipy.stats as stats
import random
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt

### Part I: Train GMM with no dataset augmentation ###

In [3]:
# Download (if not already cached under ./data) and open both MNIST splits.
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor()]),
)

# Keep the images and labels of each split as NumPy arrays for scikit-learn.
# NOTE(review): `.data` / `.targets` are read directly from the dataset objects,
# so the `transform` above is presumably not applied to these arrays - confirm.
raw_trainX, raw_trainY = mnist_trainset.data.numpy(), mnist_trainset.targets.numpy()
raw_testX, raw_testY = mnist_testset.data.numpy(), mnist_testset.targets.numpy()

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


#### Model 1: We assume that the images of each digit class in MNIST share features that follow a Gaussian distribution. Hence we feed in these classes and use the GMMs trained on the K classes to predict the labels of unseen images.

In [4]:
# Flatten every image to a single row (the original images are 28 * 28) and
# project the rows onto 60 principal components. The PCA is fitted on the
# training rows only; the test rows are projected with that same fitted model.
pca = PCA(n_components=60)
trainX = pca.fit_transform(raw_trainX.reshape(raw_trainX.shape[0], -1))
testX = pca.transform(raw_testX.reshape(raw_testX.shape[0], -1))

In [5]:
def pretrained_GMM(x, y, stage1_labels, comp=5, class_size=5000):
    """Fit one Gaussian mixture per label in `stage1_labels` (supervised stage 1).

    Parameters
    ----------
    x : array
        Feature rows; indexed with the boolean mask `y == label`.
    y : array
        Label of every row of `x`.
    stage1_labels : iterable
        Labels to fit a model for.
    comp : int, default 5
        Number of mixture components of every model.
    class_size : int, default 5000
        At most this many rows (the first ones) of each class are used for fitting.

    Returns
    -------
    list
        Fitted `GaussianMixture` objects, in the order of `stage1_labels`.
    """
    return [
        GaussianMixture(n_components=comp).fit(x[y == label][:class_size])
        for label in stage1_labels
    ]

def GMMs_predict(x, gmms, k, threshold):
    """Assign every row of `x` to the best-scoring of the first `k` models.

    Parameters
    ----------
    x : array
        Rows to classify; `x.shape[0]` is the number of rows.
    gmms : sequence
        Models exposing `score_samples(x)`; only `gmms[0] .. gmms[k-1]` are used.
    k : int
        Number of models to score against.
    threshold : float
        A row is only assigned when its best score is strictly greater than this.

    Returns
    -------
    numpy.ndarray
        For each row, the index of the model with the highest score, or -1 when
        that highest score does not exceed `threshold`.
    """
    n_rows = x.shape[0]
    scores = np.empty((n_rows, k), dtype=float)
    # One column of per-row scores for each model.
    for model_idx in range(k):
        scores[:, model_idx] = gmms[model_idx].score_samples(x)

    best_model = scores.argmax(axis=1)
    best_score = scores.max(axis=1)
    accepted = best_score > threshold
    return np.where(accepted, best_model, -1)

def GMM_classify(x, y, cnum):
    """Cluster `x` into `cnum` groups with one GMM and label each group by majority.

    Parameters
    ----------
    x : array
        Feature rows to cluster.
    y : array
        Label of every row of `x`; only used to name the clusters.
    cnum : int
        Number of mixture components, and therefore of clusters.

    Returns
    -------
    list of tuple
        One `(rows, mode)` pair per cluster index `0 .. cnum-1`, where `rows`
        are the rows of `x` assigned to the cluster and `mode` is the object
        returned by `scipy.stats.mode` for the corresponding entries of `y`.
    """
    assignment = GaussianMixture(n_components=cnum).fit_predict(x)
    clusters = []
    for cluster_id in range(cnum):
        in_cluster = assignment == cluster_id
        majority = stats.mode(y[in_cluster])
        clusters.append((x[in_cluster], majority))
    return clusters

In [16]:
def _mode_to_label(mode_result):
    """Return the modal value of a `scipy.stats.mode` result as a scalar.

    Depending on the SciPy version the mode is a scalar or a 1-element array;
    `np.ravel` makes both cases indexable the same way.
    """
    return np.ravel(mode_result[0])[0]


def GMM_solution(x, y, total_class, K, gmm_components=5, threshold=-380, rest_components=10):
    """Two-stage GMM training: K supervised classes, the rest discovered unsupervised.

    Stage 1 fits one GMM on each of `K` randomly chosen class ids (via
    `pretrained_GMM`). Stage 2 re-assigns every row of `x` with `GMMs_predict`,
    refits one GMM per stage-1 model on the rows assigned to it, and uses the
    rows that were rejected (prediction -1) to build models for the remaining
    `total_class - K` classes.

    Parameters
    ----------
    x : array
        Feature rows.
    y : array
        Label of every row of `x`; used for stage 1 and to name each model by
        the majority label of its rows.
    total_class : int
        Total number of classes; class ids are taken from `range(total_class)`.
    K : int
        Number of classes fed in for the supervised stage.
    gmm_components : int, default 5
        Mixture components of the stage-1 models and of their stage-2 refits.
    threshold : float, default -380
        Score threshold passed to `GMMs_predict`; rows whose best score does
        not exceed it are treated as belonging to an unseen class.
    rest_components : int, default 10
        Mixture components of the models fitted for the remaining classes.

    Returns
    -------
    (list, list)
        The fitted `GaussianMixture` models and, aligned with them, the
        majority label of each model as a scalar.
    """
    # stage 1 supervised training on a random K-subset of range(total_class)
    stage1_classes = random.sample(range(total_class), K)
    stage1_model = pretrained_GMM(x, y, stage1_classes, gmm_components)

    # stage 2 unsupervised training: re-assign every row with the stage-1 models
    predict = GMMs_predict(x, stage1_model, K, threshold)

    new_GMMs = []
    labels = []
    for i in range(K):
        members = predict == i
        new_GMMs.append(GaussianMixture(n_components=gmm_components).fit(x[members]))
        labels.append(_mode_to_label(stats.mode(y[members])))

    # use the rows that could not be classified to train the remaining classes
    remaining = total_class - K
    if remaining > 0:
        rejected = predict == -1
        if remaining == 1:
            # Same (rows, mode result) shape as GMM_classify returns, so one
            # code path below handles both cases.
            unlabel_predict = [(x[rejected], stats.mode(y[rejected]))]
        else:
            unlabel_predict = GMM_classify(x[rejected], y[rejected], remaining)

        for data, label in unlabel_predict[:remaining]:
            new_GMMs.append(GaussianMixture(n_components=rest_components).fit(data))
            labels.append(_mode_to_label(label))
    return new_GMMs, labels

def GMM_validation(gmms, labels, trainX, trainY, testX, testY):
    """Measure the accuracy of a list of labelled GMMs on a train and a test set.

    Every row is assigned to the best-scoring model by `GMMs_predict` (with a
    threshold of `-inf`) and the model index is then translated to a class
    label through `labels`.

    Parameters
    ----------
    gmms : sequence
        Fitted models, one per entry of `labels`.
    labels : sequence
        Class label of each model; entries may be scalars or 1-element arrays.
    trainX, trainY : array
        Training rows and their true labels.
    testX, testY : array
        Test rows and their true labels.

    Returns
    -------
    (train_accuracy, test_accuracy)
        Each value is the number of correct predictions divided by the
        corresponding label array's `shape`, i.e. a one-element array for 1-D
        label arrays.
    """
    # Flatten every label to a scalar so model indices can be mapped to labels
    # with plain array indexing, whatever form the labels were stored in.
    label_lookup = np.array([np.ravel(label)[0] for label in labels])
    total_class = len(labels)

    train_predict = label_lookup[GMMs_predict(trainX, gmms, total_class, -np.inf)]

    # Score the test set with the models passed in as `gmms`; this used to read
    # a notebook-level global instead of the argument.
    test_predict = label_lookup[GMMs_predict(testX, gmms, total_class, -np.inf)]

    return np.sum(train_predict == trainY) / trainY.shape, np.sum(test_predict == testY) / testY.shape

In [22]:
# Baseline: all 10 classes are fed to the supervised stage (K == total_class),
# then the resulting models are scored on the training and test sets.
new_GMMs, labels = GMM_solution(trainX, raw_trainY, 10, 10)
train_accuracy, test_accuracy = GMM_validation(
    new_GMMs, labels, trainX, raw_trainY, testX, raw_testY
)
print(
    "GMM's Accuracy on training set: {0}".format(train_accuracy),
    "GMM's Accuracy on test set: {0}".format(test_accuracy),
    sep="\n",
)

GMM's Accuracy on training set: [0.96223333]
GMM's Accuracy on test set: [0.9602]


In [29]:
# Sweep the number of pre-fed (supervised) classes K and record the train/test
# accuracy of the resulting models, using 7 mixture components per model.
ks = [5,6,7,8,9,10]
train_accuracy = []
test_accuracy = []
for k in ks:
    print('Pre feed classes {0}'.format(k))
    new_GMMs, labels = GMM_solution(trainX, raw_trainY, 10, k, 7)
    tra, tta = GMM_validation(new_GMMs, labels, trainX, raw_trainY, testX, raw_testY)
    train_accuracy.append(tra)
    test_accuracy.append(tta)

# Accuracy as a function of K; the ticks come from `ks` so the axis always
# matches the values that were actually swept.
plt.plot(ks, train_accuracy, label='train_accuracy')
plt.plot(ks, test_accuracy, label='test_accuracy')
plt.xlabel('number of pre-fed classes (K)')
plt.ylabel('accuracy')
plt.legend()
plt.xticks(ks)
plt.show()

Pre feed classes 5


KeyboardInterrupt: 