#### Summary
Takes a trained model (from Luke).

On each "trial", takes 1 sample plus N (e.g. 20) test images, where exactly one of the N test images is of the same class as the sample.

Finds the one among the N test images that best matches the sample, based on the posterior predictive probabilities.

In [None]:
# Load the binarized-MNIST test split from the .amat file (not referenced by the cells below).
# The .amat files distributed with binarized MNIST are whitespace-separated ASCII 0/1 values,
# so they must be parsed as text; np.fromfile(..., dtype=np.int16) reinterprets the raw
# ASCII bytes as binary integers (garbage values, or a reshape failure).
# numpy is imported here because this cell runs before the main import cell.
import numpy as np

testdata_numpy = np.loadtxt("examples/mnist/data/binarized_mnist_test.amat", dtype=np.int16).reshape(-1, 1, 28, 28)

1) get encodings for all train characters

2) build model using those training characters

3) classify test characters




In [None]:
%matplotlib inline
import torch
import matplotlib.pyplot as plt
import examples.mnist as M
import numpy as np

# Load the trained model (from Luke) onto the CPU.
# NOTE(review): torch.load unpickles arbitrary Python objects - only load trusted checkpoints.
model = torch.load("./model.p",
                   map_location='cpu')  # LT
print("Loaded model.p")



## ==== REPLACE M.testdata WITH TORCHVISION MNIST DATA, so that we have accurate labels.
# Pixel intensities are scaled to [0, 1] and used as Bernoulli probabilities to binarize.


import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

mnist_test = datasets.MNIST(root='./data', train=False, download=True)

# Newer torchvision stores the tensors as `.data` / `.targets`; there the legacy
# `test_data` / `test_labels` names are read-only deprecated aliases, so assigning
# to them raises. Read whichever attribute exists and keep the processed tensors
# in local variables instead of writing them back onto the dataset object.
mnist_images = mnist_test.data if hasattr(mnist_test, 'data') else mnist_test.test_data
mnist_labels = mnist_test.targets if hasattr(mnist_test, 'targets') else mnist_test.test_labels
mnist_images = mnist_images.reshape(-1, 1, 28, 28)

# ===== BINARIZES THE DATASET
from torch.distributions.bernoulli import Bernoulli
BINARIZE_SEED = 0  # fixed so the binarized test set is identical on every run
binarize_gen = torch.Generator().manual_seed(BINARIZE_SEED)
mnist_bin = torch.bernoulli(mnist_images.to(torch.float64) / 255, generator=binarize_gen)

# Side-by-side check of one digit: grayscale (left) vs. binarized (right).
plt.figure()
plt.subplot(121)
plt.title('grayscale')
plt.imshow(mnist_images.to(torch.float64)[1][0])
plt.subplot(122)
plt.title('binarized')
plt.imshow(mnist_bin[1][0])

# ========= PERFORM REPLACEMENT
# Keep as many test items as M.testdata originally had.
n = M.testdata.shape[0]
M.testdata = mnist_bin[0:n]
M.testlabels = mnist_labels[0:n]


## ===== COMPARE original binarized dataset (first figure, train) vs. the one built above (second figure, test)

def plot_digit_row(images, indices, title):
    """Show images[ind] for every ind in `indices` side by side in one new figure.

    Each images[ind][0] is reshaped to 28x28 and drawn on a fixed [0, 1] colour
    scale so that binary images from different sources are directly comparable.
    """
    plt.figure(figsize=(10, 2))
    for col, ind in enumerate(indices):
        plt.subplot(1, len(indices), col + 1)
        plt.imshow(images[ind][0].numpy().reshape(28, 28), vmin=0, vmax=1)
        plt.axis('off')
    plt.suptitle(title)
    # once per figure, after all panels exist (calling it per panel is wasted work)
    plt.tight_layout()

indthis = np.random.randint(low=0, high=100, size=10)
plot_digit_row(M.data, indthis, 'M.data (original binarized train)')
plot_digit_row(M.testdata, indthis, 'M.testdata (torchvision test, binarized above)')

# Classification code below

## define a function that loads appropriate characters

def getIdx(charSamp, M, Nway=20):
    """Draw the test-set indices for one N-way trial.

    Args:
        charSamp: label of the sample (e.g. a digit 0..9).
        M: object whose `testlabels` holds one label per test item.
        Nway: number of test candidates; exactly one shares the sample's label.

    Returns:
        (idx_sample, idx_testmatch, idx_testnonmatch): two distinct indices whose
        label equals charSamp, and an array of Nway-1 distinct indices whose
        label differs from charSamp.
    """
    target = torch.tensor(np.array(charSamp))
    labels = M.testlabels

    # Two different items of the sample's class: the sample and its true match.
    same_class = np.where(labels == target)[0]
    idx_sample, idx_testmatch = np.random.choice(same_class, size=2, replace=False)

    # Nway-1 distractors, drawn without replacement from all other classes.
    other_class = np.where(labels != target)[0]
    idx_testnonmatch = np.random.choice(other_class, size=Nway - 1, replace=False)

    return idx_sample, idx_testmatch, idx_testnonmatch


# Draw one trial for digit 1 with the default Nway.
idx_sample, idx_testmatch, idx_testnonmatch = getIdx(charSamp=1, M=M)




In [None]:
charSamp = 4
Nway = 10

## Draw one trial and encode all of its images
idx_sample, idx_testmatch, idx_testnonmatch = getIdx(charSamp=charSamp, M=M, Nway=Nway)
# Concatenate so everything runs as one batch, in the order
# [sample, true match, Nway-1 non-matches].
idx_all = np.concatenate(([idx_sample], [idx_testmatch], idx_testnonmatch))

# Latents for every image of the trial (not only the 1-sample); positions
# 0..len(idx_all)-1 are passed as `i`.
c_all, _ = model.encoder(i=np.arange(len(idx_all)), x=M.testdata[idx_all])

In [None]:
# Likelihoods for the N test images - HERE: EACH IMAGE IS SCORED UNDER ITS OWN LATENT.
noiseval = 0.25
sample_probs = True
NiterNoise = 5

# Repeat the decoder pass to see how much the scores vary across repetitions.
# The loop variable is `_rep`, not `n`, so the test-set size `n` set when
# M.testdata was replaced is not clobbered.
score_all = []
for _rep in range(NiterNoise):
    x_decod, score = model.decoder(i=np.arange(len(idx_all)), c=c_all, x=M.testdata[idx_all],
                                   noise=noiseval, sample_probs=sample_probs)
    score_all.append(score.detach().numpy())

# One row per repetition, one column per trial image.
score_all = np.stack([ss.reshape(-1) for ss in score_all])

plt.figure(figsize=(20, 5))
plt.plot(score_all.T, '-o')
plt.xlabel('trial image (0 = sample, 1 = true match, rest = non-matches)')
plt.ylabel('model.decoder score')

# Original images, then the decodings from the LAST repetition.
# (`j` is kept as the loop name: a later scratch cell reads the leaked value.)
plt.figure(figsize=(20, 5))
for j, img in enumerate(M.testdata[idx_all]):
    plt.subplot(1, len(idx_all), j + 1)
    plt.title('orig%s' % j)
    plt.imshow(img.numpy().reshape(28, 28), vmin=0, vmax=1)

plt.figure(figsize=(20, 5))
for j, img in enumerate(x_decod):
    plt.subplot(1, len(idx_all), j + 1)
    plt.title('decod%s' % j)
    plt.imshow(img.detach().numpy().reshape(28, 28), vmin=0, vmax=1)

# TODO: underlay the score plot with the images of the characters

In [None]:
# DIFFERENT - HERE, EVERY IMAGE IS SCORED UNDER THE LATENT OF THE FIRST (SAMPLE) IMAGE.
# This illustrates the problem:
# the image that model.decoder() produces is entirely derived from the c that is input.
# No cell above defines `c` (it raised NameError on a fresh kernel); the encoder
# latents for this trial are `c_all`, so use the sample's latent c_all[0].
c_onlyfirst = [c_all[0] for _ in range(len(idx_all))]


noiseval = 0
sample_probs = True
NiterNoise = 5

# `_rep` instead of `n` so the test-set size `n` is not clobbered.
score_all = []
for _rep in range(NiterNoise):
    x_decod, score = model.decoder(i=np.arange(len(idx_all)), c=c_onlyfirst, x=M.testdata[idx_all],
                                   noise=noiseval, sample_probs=sample_probs)
    score_all.append(score.detach().numpy())

# One row per repetition, one column per trial image.
score_all = np.stack([ss.reshape(-1) for ss in score_all])

plt.figure(figsize=(20, 5))
plt.plot(score_all.T, '-o')
plt.xlabel('trial image (0 = sample, 1 = true match, rest = non-matches)')
plt.ylabel('model.decoder score (sample latent)')

# Original images, then the decodings from the LAST repetition.
# (`j` is kept as the loop name: a later scratch cell reads the leaked value.)
plt.figure(figsize=(20, 5))
for j, img in enumerate(M.testdata[idx_all]):
    plt.subplot(1, len(idx_all), j + 1)
    plt.title('orig%s' % j)
    plt.imshow(img.numpy().reshape(28, 28), vmin=0, vmax=1)

plt.figure(figsize=(20, 5))
for j, img in enumerate(x_decod):
    plt.subplot(1, len(idx_all), j + 1)
    plt.title('decod%s' % j)
    plt.imshow(img.detach().numpy().reshape(28, 28), vmin=0, vmax=1)

# TODO: underlay the score plot with the images of the characters

In [None]:
## HERE: using the same i, but conditioning on each model differently
## doesn't work - given some i, will give the same output each time... [for the decoder].

noiseval = 0
sample_probs = True
NiterNoise = 5

# `_rep` instead of `n` so the test-set size `n` is not clobbered.
score_all = []
for _rep in range(NiterNoise):
    # Passes idx_all[0] (the sample's test-set index) as `i` for every item.
    ctmp, x_decod = model.sample(i=[idx_all[0] for _ in range(len(idx_all))], x=M.testdata[idx_all],
                                 sample_probs=sample_probs)

    # NOTE(review): `ctmp` is never used - the scores below come from the encoder
    # latents `c_all` (same call as the "own latent" cell), so they do not depend
    # on the sampled latents. Confirm whether `c=ctmp` was intended.
    _, score = model.decoder(i=np.arange(len(idx_all)), c=c_all, x=M.testdata[idx_all],
                             noise=noiseval, sample_probs=sample_probs)
    score_all.append(score.detach().numpy())

# One row per repetition, one column per trial image.
score_all = np.stack([ss.reshape(-1) for ss in score_all])

plt.figure(figsize=(20, 5))
plt.plot(score_all.T, '-o')
plt.xlabel('trial image (0 = sample, 1 = true match, rest = non-matches)')
plt.ylabel('model.decoder score')

# Original images, then the images returned by model.sample in the LAST repetition.
# (`j` is kept as the loop name: a later scratch cell reads the leaked value.)
plt.figure(figsize=(20, 5))
for j, img in enumerate(M.testdata[idx_all]):
    plt.subplot(1, len(idx_all), j + 1)
    plt.title('orig%s' % j)
    plt.imshow(img.numpy().reshape(28, 28), vmin=0, vmax=1)

plt.figure(figsize=(20, 5))
for j, img in enumerate(x_decod):
    plt.subplot(1, len(idx_all), j + 1)
    plt.title('decod%s' % j)
    plt.imshow(img.detach().numpy().reshape(28, 28), vmin=0, vmax=1)

# TODO: underlay the score plot with the images of the characters

## Summary so far:
I am not sure how to "refit" a given model to different characters. The first method I tried (i.e. fixing c, then decoding conditioned on other characters) doesn't work, since c entirely determines the output. I think I need to use the posterior for a given character (e.g. infer 10 of them to estimate the posterior, then sample from that). The plan of doing inference to get c, and then refitting that to all images, does not work.



In [None]:
# ===== clear all model mixtures


In [None]:
# Call the model on the trial's first (sample) image, repeated len(idx_all) times
# along the batch axis, with indices 0..len(idx_all)-1.
model(
    i=np.arange(len(idx_all)),
    x=M.testdata[idx_all[0]].repeat(len(idx_all), 1, 1, 1),
)


In [None]:
# Model output for the trial's own images at indices 0..len(idx_all)-1
# (stored as log_probs; presumably log-probabilities - confirm against the model).
log_probs = model(i=np.arange(len(idx_all)), x=M.testdata[idx_all])
# Same call with the sample image repeated for every item, and indices shifted by 3.
model(i=3 + np.arange(len(idx_all)),
      x=M.testdata[idx_all[0]].repeat(len(idx_all), 1, 1, 1))


In [None]:
# Scratch: build a batch of 100 copies of one test image.
# NOTE(review): `idx_all` below is a bare expression that is not the cell's last
# line, so Jupyter does not display it.
idx_all
# NOTE(review): `j` is not assigned in this cell; it is whatever value the last
# plotting loop above leaked (presumably len(idx_all) - 1), and it is used here as
# a raw M.testdata index rather than a position in idx_all - confirm intent.
x=M.testdata[j:j+1].repeat(100, 1, 1, 1)

In [None]:
# `x_all` is otherwise only created inside getClassScore (as a local) or in a later
# scratch cell, so build it here from the current trial to keep this cell runnable.
x_all = [M.testdata[ii] for ii in idx_all]
# Score the sample under its own model (D and x are both the sample).
model.conditional(i=np.array(idx_all[0:1]), D=x_all[0:1], x=x_all[0:1])

In [None]:
idx_all = np.concatenate((idx_sample.reshape(1), idx_testmatch.reshape(1), idx_testnonmatch))
# `x_all` is not otherwise defined at top level at this point, so build it here.
x_all = [M.testdata[ii] for ii in idx_all]

# Use the sample's model (i and D from the sample) to score trial images 1 and 2
# (the true match and the first non-match).
a = [model.conditional(i=np.array(idx_all[0:1]), D=x_all[0:1], x=x_all[kk:kk+1]) for kk in range(1, 3)]
print(a)

# Join into one detached tensor (same values/dtype as a numpy round-trip, without the copy).
a = torch.cat([aa.detach() for aa in a])
print(a)


In [None]:
# `x_all` is not otherwise defined at top level at this point, so build it here.
x_all = [M.testdata[ii] for ii in idx_all]
# Sample's model scoring trial images 1 and 2 (raw per-call results, not concatenated).
a = [model.conditional(i=np.array(idx_all[0:1]), D=x_all[0:1], x=x_all[kk:kk+1]) for kk in range(1, 3)]
a

In [None]:
## ===== ZEROTH, extract indices in test set that you want to work with
def getClassScore(charSamp, Nway=10, frontierSize=5, nUpdates=10, plotON=False):
    """Run one N-way classification trial with freshly reset model memory.

    Draws a trial with getIdx (sample, true match, Nway-1 non-matches), empties
    the model's mixture components for every test item, runs model.makeUpdates
    on the trial's images, then scores the trial in both directions.

    Args:
        charSamp: label of the sample (which digit to use, e.g. 0..9).
        Nway: number of test candidates (only one matches the sample).
        frontierSize: how many mixture components to remember per item.
        nUpdates: forwarded to model.makeUpdates.
        plotON: if truthy, print the sample's mixture weights before and after
            the updates, and plot `scores`.

    Returns:
        scores: model.conditional using each trial image as D to predict the sample.
        scores_sampmodel: detached tensor; the sample as D predicting each trial image.
        idx_all: test-set indices, ordered [sample, true match, non-matches...].
        scores_sum: scores_sampmodel + scores.
    """
    from torch.nn import Parameter

    # ========== GET RANDOM SAMPLES
    # IN ORDER: (sample, same_as_sample, diff_from_sample)
    idx_sample, idx_testmatch, idx_testnonmatch = getIdx(charSamp=charSamp, M=M, Nway=Nway)
    idx_all = np.concatenate((idx_sample.reshape(1), idx_testmatch.reshape(1), idx_testnonmatch))

    # EXTRACT DATA
    x_samp = M.testdata[idx_all[0]]  # the sample
    x_all = [M.testdata[ii] for ii in idx_all]  # every trial image, sample first

    # ====== FIRST, EMPTY MIXTURE COMPONENTS (for every test item, not only this trial)
    n_items = len(M.testdata)
    model.frontierSize = frontierSize
    model.mixtureComponents = [[] for _ in range(n_items)]
    model.mixtureWeights = Parameter(torch.zeros(n_items, model.frontierSize))  # unnormalised log-q
    model.mixtureScores = [[] for _ in range(n_items)]  # most recent log joint
    model.nMixtureComponents = Parameter(torch.zeros(n_items))

    # ===== SECOND, update the model with posteriors for the novel stimuli
    if plotON:
        print(model.mixtureWeights[idx_all[0]])
    model.makeUpdates(i=idx_all, x=x_all, nUpdates=nUpdates)
    if plotON:
        print(model.mixtureWeights[idx_all[0]])

    # ===== THIRD, score the trial (classification accuracy)
    # Model for each D predicts x_samp.
    scores = model.conditional(i=idx_all, D=x_all, x=[x_samp] * len(idx_all))
    # Model for the sample predicts each trial image, one call per image.
    scores_sampmodel = torch.cat([
        model.conditional(i=np.array(idx_all[0:1]), D=x_all[0:1], x=x_all[kk:kk + 1]).detach()
        for kk in range(len(x_all))
    ])

    scores_sum = scores_sampmodel + scores

    if plotON:
        plt.figure()
        plt.plot(scores.detach().numpy(), '-ok')
        plt.xlabel('trial image (0 = sample, 1 = true match, rest = non-matches)')
        plt.ylabel('model.conditional score')
    print(scores)
    print(idx_all)
    return scores, scores_sampmodel, idx_all, scores_sum

hits = 0
total = 0
predictive = 0


In [None]:
import os
import pickle

Ntrials = 1000
charall = np.random.randint(0, 10, Ntrials)  # sample digit for each trial

saveInterval = 20  # checkpoint every this many trials
SAVE_DIR = './datsave'
os.makedirs(SAVE_DIR, exist_ok=True)  # open(..., 'wb') below fails if the directory is missing

# for w, f, u in zip([5, 5, 10, 10, 20, 20],[5, 10, 5, 10, 5, 10],[20, 20, 20, 20, 20, 20]):
for w, f, u in zip([20, 20], [5, 10], [20, 20]):  # (Nway, frontierSize, nUpdates)
    savepath = os.path.join(SAVE_DIR, 'datall_w%sf%su%s' % (w, f, u))
    scores_all = []
    scores_sampmodel_all = []
    scores_sum_all = []
    idx_all = []
    for i, cc in enumerate(charall):
        print(i)
        cc = np.array(cc)
        scores, scores_sampmodel, idx, scores_sum = \
            getClassScore(charSamp=cc, Nway=w, frontierSize=f, nUpdates=u, plotON=False)

        scores_all.append(scores.detach())
        scores_sampmodel_all.append(scores_sampmodel.detach())
        scores_sum_all.append(scores_sum.detach())
        idx_all.append(idx)

        if i % saveInterval == 0 and i > 0:
            ## ========= PERIODIC CHECKPOINT
            # Layout must match the loader cell:
            #   [scores_all, scores_sampmodel, scores_sampmodel_all, idx_all]
            # NOTE(review): slot 1 is only the LAST trial's scores_sampmodel, and
            # scores_sum_all is never saved - confirm before changing the layout.
            datall = [scores_all, scores_sampmodel, scores_sampmodel_all, idx_all]
            with open(savepath, 'wb') as fl:
                pickle.dump(datall, fl)
            print('==== SAVED!')

    ## ========= SAVE FINAL - once per configuration, after all trials
    # (previously indented inside the trial loop, so it re-pickled after every trial)
    datall = [scores_all, scores_sampmodel, scores_sampmodel_all, idx_all]
    with open(savepath, 'wb') as fl:
        pickle.dump(datall, fl)

In [None]:
# Load one configuration's saved results (file layout written by the trial loop).
# NOTE(review): the trial loop currently runs only w=20, so w=5 files must come from earlier runs.
# pickle.load can execute arbitrary code - only load files you produced yourself.
w = 5
f = 5
u = 20
# The handle is `fh`, not `f`, so the frontierSize value `f` is not overwritten by a file object.
with open('./datsave/datall_w%sf%su%s' % (w, f, u), 'rb') as fh:
    scores_all, scores_sampmodel, scores_sampmodel_all, idx_all = pickle.load(fh)

In [None]:
# Load a legacy results file.
# NOTE(review): no cell in this notebook writes './datsave/datall' (the trial loop writes
# 'datall_w{w}f{f}u{u}'), and w/f/u are not used to build this path - confirm intent.
# pickle.load can execute arbitrary code - only load files you produced yourself.
w = 5
f = 5
u = 20
# The handle is `fh`, not `f`, so the value `f` is not overwritten by a file object.
with open('./datsave/datall', 'rb') as fh:
    dat = pickle.load(fh)

In [None]:
## ============== OPTIONAL - LOAD PREVIOUSLY SAVED STUFF

In [None]:
## SUMMARIZE CLASSIFICATION ERROR
# ===== 1) PLOT the score curve of every trial, one panel per sample digit
plt.figure(figsize=(18, 5))

for nn in range(10):
    plt.subplot(2, 5, nn + 1)
    plt.title('int=%d' % nn)

    # trials whose sample (first index of the trial) has label nn
    trialsthis = [ii for ii in range(len(idx_all)) if M.testlabels[idx_all[ii][0]] == nn]

    for tt in trialsthis:
        plt.plot(scores_all[tt].detach().numpy(), '-')
    plt.xlabel('trial image')
    plt.ylabel('score')

plt.tight_layout()
        



In [None]:
# ===== PLOT - combine across integers
plt.figure(figsize=(10, 5))
for ss in scores_all:
    plt.plot(ss.detach().numpy(), '-')
plt.xlabel('trial image (0 = sample, 1 = true match, rest = non-matches)')
plt.ylabel('log(posterior predictive prob)')

# ==== Histogram of the "rank" of the true match among the candidates.
# rank = (number of non-matches) - (number of non-matches the true match strictly
# beats), so ties count against the match and 0 means it beat every non-match.
n_others = len(scores_all[0]) - 2
tmp = [int((scores_all[i][1] > scores_all[i][2:]).sum()) for i in range(len(scores_all))]
rank = n_others - np.array(tmp)  # ranges from 0 to n_others

# One bin centred on each integer rank (default bins split/merge integer values).
rank_bins = np.arange(n_others + 2) - 0.5

plt.figure()

plt.subplot(1, 2, 1)
plt.hist(rank, bins=rank_bins, density=True)
# scores_all holds the one-directional scores (each candidate's model predicting the sample),
# not the bidirectional sum.
plt.title('rank of actual match (candidate models predicting the sample)')
plt.xlabel('rank (0 ==> best match)')

plt.subplot(1, 2, 2)
plt.hist(rank, bins=rank_bins, density=False)
plt.xlabel('rank (0 ==> best match)')
plt.ylabel('number of trials')

    

## Below: scratch

In [None]:
# Scratch: one trial step by step (the same steps as getClassScore, with nUpdates=1),
# reusing charSamp and Nway set in the encoding cell above.

# ========== GET RANDOM SAMPLES
# IN ORDER: (sample, same_as_sample, diff_from_sample)
idx_sample, idx_testmatch, idx_testnonmatch = getIdx(charSamp=charSamp, M=M, Nway=Nway)
idx_all = np.concatenate((idx_sample.reshape(1), idx_testmatch.reshape(1), idx_testnonmatch))

# EXTRACT DATA
x_samp = M.testdata[idx_all[0]]  # the sample
x_all = [M.testdata[ii] for ii in idx_all]  # every trial image, sample first

# ====== FIRST, EMPTY MIXTURE COMPONENTS (for every test item)
from torch.nn import Parameter
model.frontierSize = 5
model.mixtureComponents = [[] for _ in range(len(M.testdata))]
model.mixtureWeights = Parameter(torch.zeros(len(M.testdata), model.frontierSize))  # unnormalised log-q
model.mixtureScores = [[] for _ in range(len(M.testdata))]  # most recent log joint
model.nMixtureComponents = Parameter(torch.zeros(len(M.testdata)))

## ===== SECOND, update the model with posteriors for the novel stimuli
nUpdates = 1
print(model.mixtureWeights[idx_all[0]])
model.makeUpdates(i=idx_all, x=x_all, nUpdates=nUpdates)
print(model.mixtureWeights[idx_all[0]])

## ===== THIRD, score the trial: the model for each D predicts x_samp
scores = model.conditional(i=idx_all, D=x_all, x=[x_samp] * len(idx_all))

plt.figure()
plt.plot(scores.detach().numpy(), '-ok')
plt.xlabel('trial image (0 = sample, 1 = true match, rest = non-matches)')
plt.ylabel('model.conditional score')

In [None]:
## ===== SECOND, get predictive probabilities
# THIS IS DOING FOR TRAINING DATA (M.data): for each of the first n_samples items,
# the item plus n_way-1 randomly chosen other items are each used as D to predict
# the item; it is a hit if no other item scores higher (ties split the credit).

import time
## HEWITT FUNCTION
n_samples = 100
n_way = 50

n_samples = min(n_samples, len(M.data))
n_way = min(n_way, len(M.data))
print("Evaluating %d-way classification accuracy"%n_way, flush=True)
starttime = time.time()

all_indices = np.arange(len(M.data))
hits = 0
total = 0
predictive = 0
for trueclass in range(n_samples):
    x = M.data[trueclass]
    # every index except trueclass (vectorised instead of rebuilding a Python list)
    others = np.delete(all_indices, trueclass)
    i_rearranged = [trueclass] + list(np.random.choice(others, size=n_way-1, replace=False))
    x_inp = [M.data[ii] for ii in i_rearranged]  # need to already be updated in the model.
    scores = model.conditional(i_rearranged, x_inp, [x]*n_way)
    predictive += scores[0].mean().item()  # .item() so we don't hold onto tensor for gradient information
    if not (scores > scores[0]).any():
        num_tied = (scores == scores[0]).sum().item()
        hits += 1/num_tied
    total += 1
print("Took", int(time.time()-starttime), "seconds")
if total:
    print(hits / total)  # hit rate
    print(predictive / total)  # mean score
else:
    # avoid ZeroDivisionError when M.data is empty
    print("No items evaluated (M.data is empty).")