# RICA: Reconstruction Independent Component Analysis

In [1]:
import numpy as np
import torch
from torch.nn import Parameter
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data import TensorDataset, DataLoader
import itertools

In [8]:
"""
Reproduces Reconstruction ICA with PyTorch

1. Modify `torchvision_path_cifar10` to your cifar10 path, or just any folder (it will download dataset automatically)
2. If you do not have a GPU, set `use_gpu=False`. Training on the CPU takes more than a few minutes. To speed
   things up a bit:
   - change lambdas to just [2.4], this runs the script with a single lambda value only, and gives decent result
   - maybe reduce num_epochs to 100
   - if you want to run all the lambda values
     - reduce patch_size to 8, which is probably 2x faster than 16
     - reduce num_epochs to 40
"""

# ---- Run configuration -------------------------------------------------
use_gpu = False                          # set to True to run on a CUDA device
num_epochs = 100                         # epochs per lambda value; 200 is probably overkill
num_steps = 20                           # upper bound of the lambda sweep shown in the comment below
patch_size = 16                          # side length of the square patches to extract; 16 is the maximum
weight_size = patch_size * patch_size    # number of pixels in one patch (derived, do not change)
num_filters = weight_size                # complete ICA: as many filters as there are pixels
# To sweep several lambda values instead of a single one, use:
# lambdas = [l*0.4 for l in range(1,num_steps)]
lambdas = [2.4]                          # lambda values, tried one by one
torchvision_path_cifar10 = 'torchvision_cifar10/'
batch_size = 1000                        # images per batch

def maybe_gpu(data):
    """Return ``data.cuda()`` when the module-level ``use_gpu`` flag is set, else ``data`` unchanged."""
    if not use_gpu:
        return data
    return data.cuda()

# Use the CIFAR-10 training split as the image source (downloaded on first use).
dataset = torchvision.datasets.CIFAR10(
    torchvision_path_cifar10,
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        # Per-channel normalisation, intended to give roughly zero mean and unit variance.
        transforms.Normalize(mean=[0.469, 0.481, 0.451], std=[0.239, 0.245, 0.272]),
    ]),
    download=True,
)
# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when the data is actually going to the GPU.
loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=2, pin_memory=use_gpu)

# Load the entire dataset into a single tensor; this speeds things up quite a
# bit compared with going through the loader every epoch. Labels are not needed.
data_all = torch.cat([imgs for imgs, _ in loader])
data_all = data_all.mean(1)         # average over the channel dimension -> greyscale
data_all = maybe_gpu(data_all)

Files already downloaded and verified


In [13]:
# Number of whole batches in the dataset; a trailing partial batch is not counted.
data_size = data_all.size(0)
num_batches = data_size // batch_size

In [14]:
# Sanity checks on the loaded data and the derived sizes.
print(data_all.shape)
print(weight_size)
print(num_filters)
print(type(data_all.size(0)))
# Bare trailing expression: in a notebook its value is displayed as the cell result.
data_all.size(0)/1000

torch.Size([50000, 32, 32])
256
256
<class 'int'>


50.0

In [15]:

# print(len(dataset))
# # print(dataset.targets)
# print(len(next(iter(dataset))))
# # z_chisq_npy = np.load("pnong/np-1d_zeta_fields/z_chisq_seeds741785_501982.npy")
# # z_chisq = torch.Tensor(z_chisq_npy)
# # dataset = TensorDataset(z_chisq) # create your datset
# # loader = DataLoader(dataset) # create your dataloader

# loader = DataLoader(dataset, batch_size=1000, num_workers=2, pin_memory=True)


# # load the entire dataset into a single Tensor, this speeds things up quite a bit
# data_all = []
# for imgs, labels in loader:
#     data_all.append(imgs)
# data_all = torch.cat(data_all)      # merge into single tensor
# print(len(data_all))
# # print(dataset.targets)
# print(len(next(iter(data_all))))
# # data_all = data_all.mean(1)         # make black-white
# # data_all = maybe_gpu(data_all)


In [16]:
def doit(lambd=1, epochs=num_epochs):
    """Train one RICA filter bank for a given lambda and save the filters as an image grid.

    A ``(weight_size, num_filters)`` weight matrix is initialised with normal
    noise scaled by ``1/patch_size`` and optimised with RMSprop on the loss

        lambd * mean((patches @ W @ W.T - patches) ** 2) + mean(|patches @ W|)

    where ``patches`` are flattened ``patch_size`` x ``patch_size`` crops taken
    from the module-level ``data_all`` tensor (indexed as ``[image, row, col]``).
    Only the first ``num_batches`` whole batches of ``data_all`` are visited.

    After every epoch the epoch index and the total, reconstruction and latent
    losses of the *last* batch are printed. When training finishes the filters
    are written to ``rica_weight_images_<lambd>.jpg``.

    Args:
        lambd: weight of the reconstruction term in the loss.
        epochs: number of passes over the data.

    Returns:
        CPU tensor of shape ``(num_filters, 1, patch_size, patch_size)`` holding
        one image per learned filter.
    """
    weight    = Parameter(maybe_gpu(1.0/patch_size*torch.Tensor(weight_size, num_filters).normal_()))
    optimizer = torch.optim.RMSprop([weight], 0.001, momentum=0.9)

    # Top-left (x, y) corners of the patches cut from every image. They do not
    # change during training, so build the list once.
    # NOTE(review): these offsets assume every image is at least
    # 16 + patch_size pixels wide and high - confirm before changing the dataset.
    offsets = (0, 8, 16)
    corners = list(itertools.product(offsets, offsets))

    for epoch in range(epochs):
        for batch in range(num_batches):
            # Select the batch. The slice must use batch_size (not a literal
            # 1000) because num_batches is derived from batch_size.
            start = batch * batch_size
            imgs = data_all[start:start + batch_size]
            # Capture a few patches per image and stack them along dim 0.
            patches = torch.cat([imgs[:, y:y+patch_size, x:x+patch_size] for x, y in corners])
            patches = maybe_gpu(patches)
            # Flatten each patch into a row vector of weight_size pixels.
            patches = patches.view(patches.size(0), -1)
            latents = patches.matmul(weight)
            output = latents.matmul(weight.t())
            diff = output - patches
            loss_recon = (diff * diff).mean()
            loss_latent = latents.abs().mean()
            loss = lambd * loss_recon + loss_latent
            # optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Progress report: values are those of the last batch of this epoch.
        print(epoch)
        print(loss.item())
        print(loss_recon.item())
        print(loss_latent.item())
    # Each column of weight is one filter; transpose so that each row is a
    # filter, then reshape every row into a single-channel image.
    weight_images = weight.data.t().contiguous().view(num_filters, 1, patch_size, patch_size).cpu()
    vutils.save_image(weight_images, 'rica_weight_images_{}.jpg'.format(lambd), nrow=patch_size, normalize=True)
    print('Finished lambda={}'.format(lambd))

    return weight_images


# Train once per configured lambda; `weight` ends up holding the filter
# images returned by the last run.
for lambd in lambdas:
    weight = doit(lambd)



0
1.011399745941162
0.15843912959098816
0.6311457753181458
1
0.6897774934768677
0.05696240812540054
0.553067684173584
2
0.5739295482635498
0.034652501344680786
0.4907635748386383
3
0.49961450695991516
0.024129385128617287
0.4417039752006531
4
0.45104655623435974
0.018258504569530487
0.4072261452674866
5
0.42007017135620117
0.014783014543354511
0.38459092378616333
6
0.39749059081077576
0.012620189227163792
0.36720213294029236
7
0.3802899718284607
0.011223114095628262
0.35335448384284973
8
0.3664909303188324
0.010323027148842812
0.34171566367149353
9
0.3549286723136902
0.009826271794736385
0.3313456177711487
10
0.3466581106185913
0.0099002905189991
0.32289740443229675
11
0.34012675285339355
0.010082520544528961
0.31592869758605957
12
0.3341059386730194
0.010154243558645248
0.3097357451915741
13
0.33015984296798706
0.010510358959436417
0.3049349784851074
14
0.3274437189102173
0.010794878005981445
0.30153602361679077
15
0.32568126916885376
0.01094879861921072
0.2994041442871094
16
0.324018

In [18]:
# `weight` is the filter-image tensor returned by the last doit() call.
print(weight.shape)

torch.Size([256, 1, 16, 16])
