In [None]:
import argparse
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms
from sklearn.cluster import MiniBatchKMeans

In [None]:
# Notebook-wide configuration constants.
dim0 = 28  # side length of the square input images
dim = dim0*dim0 # dim of the original img
kernel_size = 6  # side length of the square patches handed to k-means
stride = 3  # NOTE(review): not read anywhere in the visible code; Trans hard-codes a step of 3 -- confirm they are meant to agree
q_dim = 8192 # dim of the quantized img
# NOTE(review): q_dim is used both as the k-means cluster count and inside
# q_mask(), while the ACE helpers default to q_dim=2 -- confirm which value
# each stage is supposed to see.
k_fs = 5  # number of feature functions requested from ACE_fs
suffix = '_reg_divide'  # appended to every model/feature file name

_device = 'cuda'  # device string passed to torch.device() in main()

In [None]:
# Raw string: '\L' and '\d' are invalid escape sequences in a normal literal
# (DeprecationWarning today, SyntaxWarning on newer Pythons). The value is unchanged.
dataroot = r'D:\Lab\dataset'

# Full MNIST training split, converted to float tensors by ToTensor().
total_train_dataset = datasets.MNIST(dataroot, train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor()
                   ]))

# 80/20 random train/validation split of the training images.
total_train_size = len(total_train_dataset)
split_ratio = 0.8
train_size = int(split_ratio * total_train_size)
valid_size = total_train_size - train_size

train_dataset, valid_dataset = random_split(total_train_dataset,[train_size, valid_size])
print('n: {}'.format(len(train_dataset)))

In [None]:
# Use kmeans to quantize dataset
# Each epoch picks one random patch location, cuts that kernel_size x kernel_size
# patch out of every training image and feeds the patches to MiniBatchKMeans.
b_size = 48000
kmeans_epochs = 10000

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=b_size, shuffle=True)

kmeans = MiniBatchKMeans(n_clusters=q_dim, batch_size=b_size)
_loss = 1e10
_cts = np.zeros((q_dim, kernel_size**2))
_patience = max_patience = 2
print(kmeans)

for _epoch in range(kmeans_epochs):
    # Top-left corner of this epoch's patch, chosen so the patch stays inside the image.
    x = np.random.randint(0, dim0-kernel_size+1)
    y = np.random.randint(0, dim0-kernel_size+1)
    # Iterate the loader directly instead of `it.next()` plus a bare `except`:
    # the old loop only terminated through an accidental IndexError on its
    # second pass and would also have hidden any real partial_fit failure.
    for img, _ in train_loader:
        patch = img[:, 0, x:x+kernel_size, y:y+kernel_size]
        patch = patch.reshape(patch.shape[0], -1).numpy()
        assert patch.shape[1] == kernel_size**2
        kmeans.partial_fit(patch)
    # Squared centroid displacement since the previous epoch. A signed sum lets
    # positive and negative moves cancel, so it is not a convergence measure.
    _new_loss = np.sum((kmeans.cluster_centers_ - _cts)**2)
    print('Epoch {} finishes. Loss: {}'.format(_epoch, _new_loss))
    # Copy so the snapshot cannot alias an array the estimator updates in place.
    _cts = kmeans.cluster_centers_.copy()
    if _new_loss < _loss:
        _loss = _new_loss
        _patience = max_patience
    else:
        _patience -= 1
        if _patience == 0:
            break

In [None]:
# Learned cluster centres, as a NumPy array and as a torch tensor.
cts = kmeans.cluster_centers_
cts_ts = torch.from_numpy(cts)


def imshow(img):
    """Display a channels-first image on a 30x30 inch figure."""
    channels_last = np.transpose(img, (1, 2, 0))
    plt.figure(figsize=(30,30))
    plt.imshow(channels_last)
    plt.show()


# show images
# Each centre is reshaped to a 1x6x6 patch; the first 128 are tiled into one grid.
centre_patches = cts_ts.view((cts_ts.shape[0], 1, 6, 6))
centre_grid = torchvision.utils.make_grid(centre_patches[:128])
imshow(centre_grid)

In [None]:
class Trans(object):
    """Dataset transform that replaces an image by its grid of k-means codes.

    The image is cut into overlapping square patches and every patch is
    replaced by the index of its nearest k-means centre. The indices are
    returned as a flat float tensor (8*8 = 64 values for a 28x28 image with
    the default 6-pixel patches and step 3).

    Args:
        kmeans: fitted estimator exposing ``predict`` and ``cluster_centers_``.
        kernel_size: side length of the square patches (default 6).
        stride: step between neighbouring patches (default 3).
    """
    def __init__(self, kmeans=kmeans, kernel_size=6, stride=3):
        self.kmeans = kmeans
        self.kernel_size = kernel_size
        self.stride = stride

    def __call__(self, img):
        img = np.array(img)
        # When this is installed as the dataset transform it receives the raw
        # PIL image, i.e. integer pixels in 0..255, whereas the centres were
        # fitted on ToTensor() output in [0, 1]. Rescale integer input so the
        # nearest-centre lookup happens on the same scale.
        if np.issubdtype(img.dtype, np.integer):
            img = img / 255.0
        k, s = self.kernel_size, self.stride
        n_rows = (img.shape[0] - k) // s + 1
        n_cols = (img.shape[1] - k) // s + 1
        # Row-major patch order, matching the original res[x, y] layout.
        patches = np.stack([
            img[s*x:s*x+k, s*y:s*y+k].reshape(-1)
            for x in range(n_rows)
            for y in range(n_cols)
        ])
        # One batched predict call instead of one call per patch.
        patches = patches.astype(self.kmeans.cluster_centers_.dtype, copy=False)
        res = self.kmeans.predict(patches)
        return torch.from_numpy(np.asarray(res, dtype=np.float32))

In [None]:
# Swap the transform on the dataset underlying each split for the k-means
# code transform, then build the test set with the same kind of transform.
for _split in (train_dataset, valid_dataset):
    _split.dataset.transform = Trans()

test_dataset = datasets.MNIST(
    dataroot,
    train=False,
    download=True,
    transform=Trans(),
)
test_size = len(test_dataset)

In [None]:
# Mini-batch loaders for the three splits; all are shuffled and share one batch size.
b_size = 128
train_loader, valid_loader, test_loader = [
    torch.utils.data.DataLoader(_split, batch_size=b_size, shuffle=True)
    for _split in (train_dataset, valid_dataset, test_dataset)
]

In [None]:
# Pull a single batch for inspection. The built-in next() works on every
# iterator; the `.next()` method is not part of the Python 3 iterator protocol.
img, label = next(iter(train_loader))

print(img)
print(label)

In [None]:
# Training settings
n_epochs = 500  # upper bound on the number of passes made by main()
lr = 0.0001  # learning rate handed to Adam in main()
valid_step = 100  # train() validates every `valid_step` mini-batches
patience = max_patience = 20  # validations without improvement tolerated before stopping
best_loss = 100000  # declared global in train(); only referenced in commented-out code there
best_acc = 0  # best validation accuracy seen so far (updated by train())
model_file = './models/model_info' + suffix + '.md'  # checkpoint path used by torch.save / torch.load

class Net(nn.Module):
    """Two-layer perceptron: 64 inputs -> 50 hidden units -> 10 log-probabilities."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of 64-dim inputs."""
        hidden = F.relu(self.fc1(x))
        # Dropout (p=0.2) is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.2, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

def test_validate(model, device, test_loader, test_valid='Test'):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc = 100. * correct / len(test_loader.dataset)
    print('\n{} set: Average loss: {:.6f}, Accuracy: {}/{} ({:.6f}%)\n'.format(
        test_valid, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss, test_acc
    
def train(model, device, train_loader, valid_loader, optimizer, epoch):
    """Run one pass over `train_loader` with periodic validation and early stopping.

    For every mini-batch: forward pass, NLL loss, backward pass, optimizer
    step, and a progress line. Every `valid_step` batches the model is scored
    on `valid_loader`; a new best validation accuracy resets the module-level
    `patience` to `max_patience` and saves the whole model to `model_file`,
    otherwise `patience` is decremented and the function returns as soon as
    it reaches 0.

    Reads the module-level names `valid_step`, `max_patience` and `model_file`
    and mutates the globals `best_acc` and `patience`.
    """
    global best_loss, best_acc, patience
    for batch_idx, (data, target) in enumerate(train_loader):
        #print(data[0].numpy()[:28*28])
        #print(target[0])
        # test_validate() switches the model to eval mode, so training mode is
        # re-enabled at the top of every iteration.
        model.train()
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # `% 1` is always 0, so progress is printed for every batch.
        if batch_idx % 1 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        if (batch_idx+1) % valid_step == 0:
            print('Start Validating...')
            valid_loss, valid_acc = test_validate(model, device, valid_loader, 'Valid')
            # Very strange. Acc. improve, but loss increase.
            # if valid_loss < best_loss:
            #    best_loss = valid_loss
            # Model selection is therefore driven by validation accuracy only.
            if valid_acc > best_acc:
                best_acc = valid_acc
                patience = max_patience
                #print('Bese valid loss: {}'.format(best_loss))
                print('Improved! Reset patience.')
                print('Saving model...')
                torch.save(model, model_file)
            else:
                patience -= 1
                print('Not improved... Patience: {}'.format(patience))
                if patience == 0:
                    print('Out of patience. Stop training.')
                    return

def main():
    """Train a fresh `Net`, then evaluate the best saved checkpoint on the test set.

    Uses the module-level loaders, `_device`, `lr`, `n_epochs` and
    `model_file`. Training stops early once the global `patience` counter,
    maintained by `train()`, reaches 0.
    """
    # The unused local `dataroot` (a path literal with invalid escape
    # sequences) was removed; the datasets are already built at module level.
    print(len(train_loader.dataset))
    device = torch.device(_device)

    model = Net().to(device)
    print(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, n_epochs + 1):
        train(model, device, train_loader, valid_loader, optimizer, epoch)
        if patience == 0:
            break

    print('Start testing...')
    # Reload the checkpoint written at the best validation accuracy rather
    # than scoring the weights from the last training step.
    model = torch.load(model_file)
    test_validate(model, device, test_loader)

main()

In [None]:
# maps a pixel to quantized index
def q_map(x, threshold=0.4):
    """Return 0 when `x` is below `threshold`, otherwise 1."""
    if x < threshold:
        return 0
    else:
        return 1

# given a data matrix, returns the mask
def q_mask(X, n_levels=None):
    """Build the one-hot quantisation mask of a data matrix.

    Args:
        X: 2-D array, one row per sample and one column per pixel.
        n_levels: number of mask columns reserved per pixel. Defaults to the
            module-level ``q_dim``, which is what the function always used.

    Returns:
        (mask, norm_p): ``mask`` has shape ``(n, X.shape[1] * n_levels)`` with
        ``mask[i, j*n_levels + q_map(X[i, j])] == 1`` and zeros elsewhere;
        ``norm_p`` is the per-column mean of ``mask`` (fraction of samples
        that activate each column).
    """
    if n_levels is None:
        n_levels = q_dim
    n_rows, n_cols = X.shape[0], X.shape[1]
    # np.float was removed from NumPy; float64 is the dtype it aliased.
    mask = np.zeros((n_rows, n_cols * n_levels), dtype=np.float64)
    # Quantise every entry with q_map, then set all one-hot positions with a
    # single fancy-index assignment instead of a nested Python loop.
    levels = np.vectorize(q_map, otypes=[np.intp])(X)
    row_idx = np.arange(n_rows)[:, None]
    col_idx = np.arange(n_cols)[None, :] * n_levels + levels
    mask[row_idx, col_idx] = 1
    norm_p = np.sum(mask, axis=0) / X.shape[0]
    return mask, norm_p

# X: data matrix
# feature extracting func
# vanilla implementation too slow. use batch.
def ACE_step(X, f, lr, batch_size, mask, norm_p):
    """Perform one update of the feature vector `f`.

    A batch of `batch_size` distinct rows is drawn at random. For every mask
    column j the update term is the batch average, over the rows where column
    j is active, of the row's masked sum of `f`, minus column j's own batch
    average contribution. `f` is moved by `lr` times that term and then
    divided by ``sqrt(sum(new_f**2 * norm_p))``.

    Args:
        X: data matrix; only ``X.shape[0]`` is used, to sample row indices.
        f: 1-D feature vector with one entry per mask column.
        lr: step size.
        batch_size: number of rows sampled without replacement.
        mask: 0/1 one-hot matrix as returned by `q_mask`.
        norm_p: per-column activation frequency as returned by `q_mask`.

    Returns:
        The updated, rescaled feature vector (same shape as `f`).
    """
    # E[sum f(X_j)|X_i], res is f.dim = data.dim*q_dim
    def expect_cond(X, f, idx):
        # Gather the batch rows once; the previous per-column loop re-indexed
        # mask[idx, j] for every single column.
        batch_mask = mask[idx, :]
        f_mask = batch_mask * f          # broadcasting replaces np.tile
        sum_per_row = np.sum(f_mask, axis=1)
        sum_per_col = np.sum(f_mask, axis=0)
        active = (batch_mask == 1)
        counts = np.sum(active, axis=0)
        # For each column: total of sum_per_row over the rows where it is active.
        row_totals = active.T.astype(sum_per_row.dtype) @ sum_per_row
        res = np.zeros_like(f)
        seen = counts > 0                # columns never active in the batch stay 0
        res[seen] = (row_totals[seen] - sum_per_col[seen]) / counts[seen]
        return res

    # sqrt(E[sum f(X_i)^2])
    def expect_var(X, f):
        return np.sqrt(np.sum((f**2) * norm_p))

    batch_idx = np.random.choice(X.shape[0], batch_size, False)
    e_cond = expect_cond(X, f, batch_idx)
    new_f = f + lr*e_cond
    cov = expect_var(X, new_f)

    new_f = new_f / cov
    return new_f

def init_f(dim=28*28, q_dim=2):
    """Draw a random initial feature vector that is centred per pixel.

    Generates ``dim * q_dim`` uniform samples in [0, 1) and subtracts, from
    each consecutive group of `q_dim` values, that group's mean, so every
    group sums to zero.

    The mean is now taken over `q_dim` entries. The previous code divided the
    group sum by the literal 2, which only centred the groups when
    ``q_dim == 2`` (results for that default are unchanged).

    Returns:
        1-D array of length ``dim * q_dim``.
    """
    groups = np.random.rand(dim * q_dim).reshape(dim, q_dim)
    groups = groups - np.sum(groups, axis=1, keepdims=True) / q_dim
    return groups.reshape(-1)

def ACE_f(X, mask, norm_p, lr, batch_size, dim=28*28, q_dim=2, epsilon=0.01, max_step=500):
    """Iterate `ACE_step` from a random start until the update becomes small.

    Starts from ``init_f(dim, q_dim)`` and applies `ACE_step` up to
    `max_step` times, printing the squared change after each step and
    stopping early once it drops below `epsilon`. Returns the last iterate.
    """
    current = init_f(dim, q_dim)
    for step_no in range(max_step):
        candidate = ACE_step(X, current, lr, batch_size, mask, norm_p)
        # squared Euclidean distance between consecutive iterates
        delta = np.sum((current - candidate)**2)
        current = candidate
        print('Step {}: delta after ACE_step: {}'.format(step_no, delta))
        if delta < epsilon:
            break
    return current

def norm_dot(f1, f2, norm_p):
    """Return the `norm_p`-weighted inner product ``sum(f1 * f2 * norm_p)``."""
    weighted_terms = f1 * f2 * norm_p
    return np.sum(weighted_terms)

def ACE_fs(X, k, mask, norm_p, lr=0.8, batch_size=512, dim=28*28, q_dim=2, epsilon=0.01, max_step=500):
    """Build `k` feature vectors, removing from each its overlap with earlier ones.

    Every feature is obtained from an independent `ACE_f` run; afterwards the
    `norm_p`-weighted projections onto all previously accepted features are
    subtracted. Returns the features as a list, in construction order.
    """
    features = []
    for i in range(k):
        print('Building the {}-th feature function...'.format(i))
        f_i = ACE_f(X, mask, norm_p, lr, batch_size, dim, q_dim, epsilon, max_step)
        # NOTE(review): the projection coefficient is not divided by
        # norm_dot(f_m, f_m) and the residual is not re-normalised, so this is
        # an exact orthogonalisation only while earlier features have unit
        # weighted norm -- confirm that is intended.
        overlap = sum(
            (norm_dot(f_i, f_m, norm_p)*f_m for f_m in features),
            np.zeros_like(f_i),
        )
        f_i -= overlap
        features.append(f_i)
    return features

Load the entire MNIST training split into one big matrix

X: n x (28*28)
mask: n x (28*28*2)

In [None]:
# Fill X with the raw pixel values of every training image, one row per image.
# An earlier cell replaces the dataset transform with Trans, whose output is a
# flat code vector that cannot be transposed as a (C, H, W) image. Switch back
# to ToTensor() for the duration of this loop and restore the transform after.
_saved_transform = train_dataset.dataset.transform
train_dataset.dataset.transform = transforms.ToTensor()
try:
    X = np.zeros((train_size, dim))
    for idx, (img, _) in enumerate(train_dataset):
        X[idx, :] = img.numpy().transpose(1, 2, 0).flatten()
        if idx%5000 == 0: print(idx)
finally:
    train_dataset.dataset.transform = _saved_transform
# NOTE(review): q_mask sizes the mask from the module-level q_dim; the notes
# above describe a mask of width 28*28*2 -- confirm q_dim is 2 when this runs.
mask, norm_p = q_mask(X)

In [None]:
# Sanity check: array shapes plus columns 700..799 of the first mask row and of norm_p.
_window = slice(700, 800)
print(X.shape)
print(mask.shape)
print(mask[0][_window])
print(norm_p.shape)
print(norm_p[_window])

Build the feature-extraction functions f

In [None]:
import os
import pickle
feature_path = './models/feats' + suffix

# Cache guard instead of a hand-edited `if False:` switch: reuse the saved
# feature functions when the file exists, otherwise compute and store them.
if os.path.exists(feature_path):
    # NOTE: unpickling can execute arbitrary code; only load feature files
    # that were written by this notebook.
    with open(feature_path, 'rb') as fp:
        fs = pickle.load(fp)
else:
    fs = ACE_fs(X, k_fs, mask, norm_p, 1, train_size, epsilon=0.0000000001, max_step=10)
    # The output directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(feature_path), exist_ok=True)
    with open(feature_path, 'wb') as fp:
        pickle.dump(fs, fp)
print(len(fs))

In [None]:
class InfoTrans(object):
    """Map an image to the values its pixels select from each feature vector.

    Args:
        map_func: scalar function returning the quantisation level of a pixel.
        q_dim: number of entries each feature vector reserves per pixel.
        fs: sequence of 1-D feature vectors indexed as ``f[pixel*q_dim + level]``.
    """
    def __init__(self, map_func, q_dim, fs):
        self.map_func = map_func
        # Keep the constructor argument; __call__ previously ignored it and
        # read the module-level q_dim instead.
        self.q_dim = q_dim
        self.fs = fs
        self.k_fs = len(fs)

    def __call__(self, img):
        """Return a 1-D array of length ``k_fs * n_pixels`` for a (C, H, W) tensor.

        Entry ``k*n_pixels + i`` holds ``fs[k][i*q_dim + map_func(pixel_i)]``.
        """
        pixels = img.numpy().transpose(1, 2, 0).flatten()
        dim = len(pixels)

        # Position each pixel selects inside a feature vector; computed once
        # and reused for every feature instead of once per (pixel, feature).
        offsets = np.fromiter(
            (idx*self.q_dim + self.map_func(pixel) for idx, pixel in enumerate(pixels)),
            dtype=np.intp, count=dim)

        feats = np.zeros(self.k_fs*dim)
        for k in range(self.k_fs):
            feats[k*dim:(k+1)*dim] = np.asarray(self.fs[k])[offsets]

        return feats

In [None]:
# Output locations of the pre-computed feature matrices and of the
# per-feature min/max used to rescale them.
processed_train_file = './models/processed_train' + suffix + '.npz'
processed_valid_file = './models/processed_valid' + suffix + '.npz'
processed_test_file = './models/processed_test' + suffix + '.npz'
train_max_min_file = './models/train_max_min' + suffix + '.npz'

# One-off preprocessing, disabled by default: flip the condition to True to
# regenerate the .npz files from the datasets.
# NOTE(review): InfoTrans calls img.numpy().transpose(1, 2, 0), so the datasets
# presumably must yield (C, H, W) tensors here, not Trans() output -- confirm.
if False:
    processed_train = np.zeros((train_size, k_fs*dim))
    processed_label = np.zeros((train_size, 1))
    for idx, (img, label) in enumerate(train_dataset):
        processed_train[idx, :] = InfoTrans(q_map, q_dim, fs)(img)
        processed_label[idx, 0] = label
        if idx%1000==0: print(idx)

    # Per-feature extrema are taken from the training split only ...
    train_min, train_max = np.zeros(k_fs), np.zeros(k_fs)
    for k in range(k_fs):
        train_max[k] = np.max(processed_train[:, k*dim:(k+1)*dim])
        train_min[k] = np.min(processed_train[:, k*dim:(k+1)*dim])

    print(train_max)
    print(train_min)

    with open(train_max_min_file, 'wb') as fp:
        np.savez(fp, train_max=train_max, train_min=train_min)

    # ... and reused to rescale every split with (v - min) / (max - min) * 2 - 1.
    for k in range(k_fs):
        processed_train[:, k*dim:(k+1)*dim] = (processed_train[:, k*dim:(k+1)*dim] - train_min[k]) / (train_max[k] - train_min[k]) * 2 - 1

    with open(processed_train_file, 'wb') as fp:
        #torch.save(processed_train, fp)
        np.savez(fp, img=processed_train, label=processed_label)

    # Validation split: same transform, same training-set min/max.
    processed_valid = np.zeros((valid_size, k_fs*dim))
    processed_label = np.zeros((valid_size, 1))
    for idx, (img, label) in enumerate(valid_dataset):
        processed_valid[idx, :] = InfoTrans(q_map, q_dim, fs)(img)
        processed_label[idx, 0] = label
        if idx%1000==0: print(idx)

    for k in range(k_fs):
        processed_valid[:, k*dim:(k+1)*dim] = (processed_valid[:, k*dim:(k+1)*dim] - train_min[k]) / (train_max[k] - train_min[k]) * 2 - 1

    with open(processed_valid_file, 'wb') as fp:
        #torch.save(processed_valid, fp)
        np.savez(fp, img=processed_valid, label=processed_label)

    # Test split: same transform, same training-set min/max.
    processed_test = np.zeros((test_size, k_fs*dim))
    processed_label = np.zeros((test_size, 1))
    for idx, (img, label) in enumerate(test_dataset):
        processed_test[idx, :] = InfoTrans(q_map, q_dim, fs)(img)
        processed_label[idx, 0] = label
        if idx%1000==0: print(idx)

    for k in range(k_fs):
        processed_test[:, k*dim:(k+1)*dim] = (processed_test[:, k*dim:(k+1)*dim] - train_min[k]) / (train_max[k] - train_min[k]) * 2 - 1

    with open(processed_test_file, 'wb') as fp:
        #torch.save(processed_test, fp)
        np.savez(fp, img=processed_test, label=processed_label)

print('Done!')

In [None]:
class InfoMNIST(torch.utils.data.Dataset):
    """Dataset backed by one of the pre-computed feature ``.npz`` files.

    Args:
        train: which split to load. ``'train'`` (or ``True``) loads the
            training file and ``'valid'`` the validation file; any other
            value loads the test file.
        transform: optional callable applied to each feature tensor.
        target_transform: optional callable applied to each label.
        processed_train_file, processed_valid_file, processed_test_file:
            paths of the ``.npz`` archives, each holding ``img`` and ``label``.
    """
    def __init__(self, train=True, transform=None, target_transform=None, 
                 processed_train_file = processed_train_file,
                 processed_valid_file = processed_valid_file, 
                 processed_test_file = processed_test_file):
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        # The default `train=True` used to match neither string and silently
        # fell through to the test file; treat it as the training split.
        if self.train is True or self.train=='train':
            filename = processed_train_file
        elif self.train=='valid':
            filename = processed_valid_file
        else:
            filename = processed_test_file
        with open(filename, 'rb') as fp:
            npzfile = np.load(fp)
            self.length = npzfile['label'].shape[0]
            self.img = torch.from_numpy(npzfile['img']).float()
            self.label = torch.from_numpy(npzfile['label'].flatten()).long()

    def __getitem__(self, index):
        """Return ``(features, label)`` for `index`, after the optional transforms."""
        img, target = self.img[index], self.label[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of samples, taken from the label array of the loaded file."""
        return self.length

# Visualization of Features
# Computes per-feature mean / variance / max / min of the processed training
# features and stores the mean and variance in `mean_var_file`.

mean_var_file = './models/mean_var'

if True:
    # Note: this rebinds the module-level b_size used by earlier cells.
    b_size = 20000

    train_loader_vis = torch.utils.data.DataLoader(
        InfoMNIST('train'),
        batch_size=b_size, shuffle=True)
    # The string literal below is an earlier full-dataset implementation kept
    # for reference; it is never executed.
    '''
    train_mean = torch.zeros((1, k_fs*dim))
    train_var = torch.zeros((1, k_fs*dim))
    for img, label in train_loader_vis:
        train_mean += torch.sum(img, 0, keepdim=True)
    train_mean /= train_size

    for img, label in train_loader_vis:
        train_var += torch.sum((img - train_mean)**2, 0, keepdim=True)
    train_var = torch.sqrt(train_var) / train_size
    
    train_mean = train_mean.numpy()
    train_var = train_var.numpy()
    '''
    train_mean, train_var, train_min, train_max = np.zeros(k_fs), np.zeros(k_fs), np.zeros(k_fs), np.zeros(k_fs)
    # Statistics come from the first shuffled batch only (see the `break`),
    # i.e. from a random sample of at most b_size rows.
    for img, label in train_loader_vis:
        for k in range(k_fs):
            train_mean[k] = np.mean(img.numpy()[:, k*dim:(k+1)*dim])        
            train_var[k] = np.var(img.numpy()[:, k*dim:(k+1)*dim])
            train_max[k] = np.max(img.numpy()[:, k*dim:(k+1)*dim])
            train_min[k] = np.min(img.numpy()[:, k*dim:(k+1)*dim])
        break
    
    # Only mean and var are persisted; max and min are not written to the file.
    with open(mean_var_file, 'wb') as fp:
        np.savez(fp, mean=train_mean, var=train_var)
else:
    # NOTE(review): this branch restores only mean and var, so the
    # train_max / train_min prints below would fail if it were taken.
    with open(mean_var_file, 'rb') as fp:
        npzfile = np.load(fp)
        train_mean, train_var = npzfile['mean'], npzfile['var']

print(train_mean)
print(train_var) 
print(train_max)
print(train_min)

class NormalTrans(object):
    """Rescale each feature block of a flat feature vector with min/max scaling.

    Block ``k`` (entries ``k*dim .. (k+1)*dim``) is mapped through
    ``(v - _min[k]) / (_max[k] - _min[k]) * 2 - 1``. `mean` and `var` are
    stored but not used by the transform.

    Relies on the module-level `k_fs` and `dim`.
    """
    def __init__(self, mean, var, _max, _min):
        self.mean = mean
        self.var = var
        self._max = _max
        self._min = _min

    def __call__(self, img):
        # Work on a copy. `res = img` only aliased the input, so the slice
        # assignments below overwrote the caller's data; a dataset that hands
        # out views of its storage would be re-normalised on every access.
        res = img.clone() if torch.is_tensor(img) else np.array(img)
        for k in range(k_fs):
            res[k*dim:(k+1)*dim] = (img[k*dim:(k+1)*dim] - self._min[k]) / (self._max[k] - self._min[k]) * 2 - 1#/ self.var
        return res

In [None]:
# Inspect a few test batches. next(it) replaces it.next(), which is not part
# of the Python 3 iterator protocol.
# NOTE(review): the slices below assume k_fs blocks of 28*28 values per
# sample, so test_loader presumably wraps the processed features -- confirm.
it = iter(test_loader)
test_img, test_label = next(it)
print(test_img.shape)
for i in range(k_fs):
    print(torch.mean(test_img[0][i*28*28:(i+1)*28*28], 0))
print(test_img[0][:])
print(test_label.shape)
for i in range(5):
    test_img, test_label = next(it)
    # histogram of the second feature block of the first sample in the batch
    plt.hist(test_img[0][28*28:2*28*28])
    plt.show()

## Visualize the 1st feature

In [None]:
def imshow(npimg):
    """Show a flat 784-value array as a 28x28 grayscale image.

    Values are first mapped through ``v * 0.5 + 0.5``. This definition
    replaces the earlier `imshow` defined in this notebook.
    """
    rescaled = (npimg * 0.5 + 0.5).reshape((28,28))
    plt.imshow(rescaled, cmap='gray')
    plt.show()

In [None]:
# For the first feature block only, print and draw the first eight samples of
# the current batch together with their labels.
_block = 28*28
for k in range(1):
    for j in range(8):
        sample = test_img[j]
        print(sample)
        imshow(sample.numpy()[k*_block:(k+1)*_block])
        print(test_label[j])

In [None]:
# Shape, first 50 entries and value histogram of the first feature function.
first_f = fs[0]
print(first_f.shape)
print(first_f[:50])
plt.hist(first_f)
plt.show()