In [1]:
import numpy as np
import math
import cancer
import sklearn.preprocessing

In [2]:
X_train, y_train = cancer.get_train()
X_train = sklearn.preprocessing.scale(X_train)
X_test, y_test = cancer.get_test()
X_test = sklearn.preprocessing.scale(X_test)

In [6]:
def soft_threshold(a, delta):
    return np.array([np.sign(x) * max(np.abs(x) - delta, 0) for x in a])

In [18]:
def fit(X, y, delta):
    N, D = X.shape
    cs = np.unique(y)
    C = len(cs)
    sse = np.zeros(D)
    mus = []
    prior = []
    xbar = np.mean(X, axis = 0)

    for c in cs:
        i = y == c
        X_c = X[i]
        prior.append(np.mean(i))
        mu_c = np.mean(X_c, axis = 0)
        mus.append(mu_c)
        sse += np.sum((X_c - mu_c) ** 2, axis = 0)
        
    sigma2 = sse / (1.0 * N - C)
    sigma = np.sqrt(sigma2)
    s0 = np.median(sigma)
    
    for i, c in enumerate(cs):
        m = math.sqrt(1.0 / sum(y == c) - 1.0 / N)
        d = (mus[i] - xbar) / (m * (sigma + s0))
        d = soft_threshold(d, delta)
        mus[i] = xbar + m * (sigma + s0) * d
        
    return cs, np.log(prior), mus, sigma2
        
def predict(model, X):
    cs, log_prior, mus, sigma2 = model
    N, D = X.shape
    C = len(cs)
    loglik = np.empty((N, C))
    
    for i in range(C):
        Z = 0.5 * (X - mus[i]) ** 2 / sigma2
        loglik[:, i] = log_prior[i] - np.sum([Z[:, j] for j in range(D)], axis = 0)
        Z = (X - mus[i]) ** 2
        loglik[:, i] = 2 * log_prior[i] - np.sum([Z[:, j] for j in range(D)], axis = 0)
    
    return cs[np.argmax(loglik, axis = 1)]

In [19]:
model = fit(X_train, y_train, 4.3)
print np.sum(predict(model, X_test) != y_test), 'errors out of', len(y_test)

22 errors out of 54


In [33]:
import pandas as pd

url = 'http://www.bioinf.ucd.ie/people/aedin/R/full_datasets/khan_train.csv'
X = pd.read_csv(url).values[:, 1:].T
X.shape

(64, 2308)

In [36]:
X[:, -1]

array([ 0.2044,  0.299 ,  0.223 ,  0.0871,  0.2157,  0.2525,  0.4084,
        0.5724,  0.3923,  0.2685,  0.6007,  0.456 ,  0.61  ,  0.2413,
        0.32  ,  0.4385,  0.1131,  0.1593,  0.1457,  1.042 ,  0.1715,
        0.1469,  0.1085,  0.1073,  0.2676,  0.2655,  0.262 ,  0.2177,
        0.1281,  0.2709,  0.2358,  0.2586,  0.2481,  0.2091,  0.0597,
        0.0536,  0.0837,  0.2792,  0.3812,  0.1523,  0.1002,  0.1071,
        0.0446,  0.2242,  0.4338,  0.1913,  0.0589,  0.2034,  0.2344,
        0.3379,  0.105 ,  0.1035,  0.138 ,  0.2039,  0.2975,  0.2728,
        0.1345,  0.1346,  0.2221,  0.381 ,  0.8788,  0.1925,  0.4603,
        0.3379])