In [7]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset, DataLoader

In [8]:
import matplotlib.pyplot as plt

In [9]:
from PIL import Image

In [10]:
import sys
sys.path.append("../figures")

from where import WhereFill, WhereShift, RetinaBackground, RetinaMask, RetinaWhiten 
from where import RetinaTransform, WhereNet, CollTransform, MNIST, Normalize, WhereTrainer

In [11]:
# Build the run configuration with the project's `init` helper, passing
# '../data/2019-06-05' as the filename, and display the resulting settings.
from main import init
args = init(filename='../data/2019-06-05')
args

{'w': 28,
 'minibatch_size': 100,
 'train_batch_size': 50000,
 'test_batch_size': 10000,
 'noise_batch_size': 1000,
 'mean': 0.1307,
 'std': 0.3081,
 'N_pic': 128,
 'offset_std': 30,
 'offset_max': 34,
 'noise': 1.0,
 'contrast': 0.7,
 'sf_0': 0.1,
 'B_sf': 0.1,
 'N_theta': 6,
 'N_azimuth': 24,
 'N_eccentricity': 10,
 'N_phase': 2,
 'rho': 1.41,
 'bias_deconv': True,
 'p_dropout': 0.0,
 'dim1': 1000,
 'dim2': 1000,
 'lr': 0.005,
 'do_adam': True,
 'bn1_bn_momentum': 0.5,
 'bn2_bn_momentum': 0.5,
 'momentum': 0.3,
 'epochs': 60,
 'num_processes': 1,
 'no_cuda': True,
 'log_interval': 100,
 'verbose': 1,
 'filename': '../data/2019-06-05',
 'seed': 2019,
 'N_cv': 10,
 'do_compute': True}

In [12]:
# Construct the retina object from the run configuration; later cells read its
# transform vectors and call its *_invert methods.
# NOTE(review): the recorded run of this cell ended in a FileNotFoundError for
# '/tmp/retina_6_24_10_2_1.41_128_inverse_transform.npy', so the constructor
# apparently expects that file to exist already — confirm in retina.py how it
# is meant to be generated.
from retina import Retina
retina = Retina(args)

Inversing retina transform...


FileNotFoundError: [Errno 2] No such file or directory: '/tmp/retina_6_24_10_2_1.41_128_inverse_transform.npy'

In [8]:
# Load the stored accuracy map; it is inspected in the next cells and later
# handed to WhereFill when building the target transform.
# NOTE(review): the recorded run failed here because
# '../data/MNIST_accuracy.npy' did not exist — the file has to be produced
# before the notebook can run from top to bottom.
accuracy_map = np.load('../data/MNIST_accuracy.npy')

FileNotFoundError: [Errno 2] No such file or directory: '../data/MNIST_accuracy.npy'

In [None]:
# Smallest value stored in the accuracy map.
accuracy_map.min()

In [None]:
# Render the accuracy map as an image.
plt.imshow(accuracy_map)

In [None]:
# Line plot of the accuracy map; binding the result to `_` keeps the returned
# line handles out of the cell output.
_ = plt.plot(accuracy_map)

## Unit tests

In [None]:
import math

In [None]:
# Fixed offsets passed to WhereShift in the two transform cells below.
# None was the alternative value noted next to each assignment
# (NOTE(review): presumably None lets WhereShift choose the offset itself —
# confirm in where.py).
i_offset = -18
j_offset = 18

In [None]:
# Input pipeline applied to every MNIST image, in this order.
retina_steps = [
    WhereFill(N_pic=args.N_pic),
    # Alternative shifts that were tried instead of explicit offsets:
    #   WhereShift(args, theta=3*math.pi/4)
    #   WhereShift(args, radius = 18)
    WhereShift(args, i_offset=i_offset, j_offset=j_offset),
    RetinaBackground(),
    RetinaMask(N_pic=args.N_pic),
    RetinaWhiten(N_pic=args.N_pic),
    RetinaTransform(retina.retina_transform_vector),
    # Normalize() is currently left out of the pipeline.
]
transform = transforms.Compose(retina_steps)

In [None]:
# Target pipeline: the label is replaced by the accuracy map, shifted with the
# same offsets as the input, then passed through CollTransform.
colliculus_steps = [
    WhereFill(accuracy_map=accuracy_map, N_pic=args.N_pic),
    # Alternative shifts that were tried instead of explicit offsets:
    #   WhereShift(args, theta=3*math.pi/4, baseline = 0.1)
    #   WhereShift(args, radius = 18, baseline = 0.1)
    WhereShift(args, i_offset=i_offset, j_offset=j_offset, baseline = 0.1),
    CollTransform(retina.colliculus_transform_vector),
]
target_transform = transforms.Compose(colliculus_steps)

In [None]:
# Training split of the project's MNIST wrapper (downloaded into '../data' if
# absent), with the input and target pipelines defined above attached.
dataset_train = MNIST(
    '../data',
    train=True, download=True,
    transform=transform, target_transform=target_transform,
)

In [None]:
# Shuffled minibatches over the training split; DataLoader is the same class
# already imported from torch.utils.data at the top of the notebook.
train_loader = DataLoader(
    dataset_train, shuffle=True, batch_size=args.minibatch_size,
)

In [None]:
# Held-out split of the project's MNIST wrapper, built with the same input and
# target pipelines as the training split.
dataset_test = MNIST(
    '../data',
    train=False, download=True,
    transform=transform, target_transform=target_transform,
)

In [None]:
# Minibatches over the held-out split. Shuffling is enabled here as well, so
# the batch order differs between runs.
test_loader = DataLoader(
    dataset_test, shuffle=True, batch_size=args.minibatch_size,
)

In [None]:
# Pull a single minibatch from the training loader for inspection.
data, label = next(iter(train_loader))

In [None]:
# Shape of the input tensor of that minibatch.
data.shape

In [None]:
#plt.imshow(data[i,:,:].detach().numpy())

In [None]:
# Shape of the target tensor of that minibatch.
label.shape

In [None]:
# Index of the minibatch sample inspected by this cell and the ones that follow.
i = 7
# Line plot of that sample's input values (the output of the input pipeline).
plt.plot(data[i,:].detach().numpy())


In [None]:
# Pass sample i through retina.retina_invert and display the result as an image.
plt.imshow(retina.retina_invert(data[i,:].detach().numpy()))

In [None]:
# Line plot of the target values of sample i (the output of the target pipeline).
plt.plot(label[i,:].detach().numpy())

In [None]:
# Pass the target of sample i through retina.accuracy_invert and display the
# result as an image.
plt.imshow(retina.accuracy_invert(label[i,:].detach().numpy()))

In [None]:
# Histogram of the input values of sample i (flattened to 1-D first).
plt.hist(data[i,:].detach().numpy().flatten())

## WhereTrainer class test

In [None]:
# Instantiate the project's trainer. With generate_data=True the recorded run
# prints "Generating training dataset" followed by one progress line per
# minibatch (batch index and cumulative sample count).
whereTrainer = WhereTrainer(args, generate_data=True)

Generating training dataset
0 100
1 200
2 300
3 400
4 500
5 600
6 700
7 800
8 900
9 1000
10 1100
11 1200
12 1300
13 1400
14 1500
15 1600
16 1700
17 1800
18 1900
19 2000
20 2100
21 2200
22 2300
23 2400
24 2500
25 2600
26 2700
27 2800
28 2900
29 3000
30 3100
31 3200
32 3300
33 3400
34 3500
35 3600
36 3700
37 3800
38 3900
39 4000
40 4100
41 4200
42 4300
43 4400
44 4500
45 4600
46 4700


In [None]:
# Display the model held by the trainer.
whereTrainer.model

In [None]:
# Display the loss function held by the trainer.
whereTrainer.loss_func

In [None]:
# Display the optimizer held by the trainer.
whereTrainer.optimizer

In [None]:
# Draw one minibatch from the trainer's test loader; this rebinds `data` and
# `label` from the earlier inspection cells.
data, label = next(iter(whereTrainer.test_loader))

In [None]:
# Shape of the test minibatch inputs.
data.shape

In [None]:
# Smallest target value in the test minibatch.
label.min()

In [None]:
# Forward pass of the test minibatch through the trainer's model.
output = whereTrainer.model(data)

In [None]:
# Shape of the model output for the minibatch.
output.shape

In [None]:
# Largest value in the model output.
output.max()

In [None]:
# Binary cross-entropy on raw model outputs: BCEWithLogitsLoss applies the
# sigmoid itself, so `output` is passed as-is.
loss_func = nn.BCEWithLogitsLoss()
loss = loss_func(input=output, target=label)

In [None]:
# Call the module-level train() from where.py once, on the CPU, using the
# trainer's model, training loader, loss function and optimizer.
# NOTE(review): the trailing 1 is presumably the epoch number — confirm
# against train()'s signature.
from where import train

train(
    args,
    whereTrainer.model,
    "cpu",
    whereTrainer.train_loader,
    whereTrainer.loss_func,
    whereTrainer.optimizer,
    1,
)

In [None]:
# Disabled alternative to the previous cell: train and evaluate through the
# trainer's own train()/test() methods. The `if False:` guard means this
# never executes.
if False:
    for epoch in range(1, 2):  # one epoch only; the full run would be range(1, args.epochs + 1)
        whereTrainer.train(epoch)
        whereTrainer.test()

In [None]:
# Evaluate the current model through the trainer's test() method.
whereTrainer.test()

## Dataset generation

In [None]:
# Materialise the whole transformed training set as two tensors.
#
# The loop is no longer wrapped in `if False:` because the cells that follow
# read `full_data` and `full_label` unconditionally; with the guard in place
# they fail with a NameError on a fresh kernel.
# NOTE: this pass over train_loader applies the full input/target pipelines to
# every sample and is therefore slow.
#
# Batches are collected in lists and concatenated once at the end, instead of
# re-concatenating the growing tensor on every iteration. The loop uses its
# own names (batch_idx, batch, batch_label) so that it does not overwrite the
# sample index `i` or the `data` / `label` minibatch used by other cells.
batch_data, batch_labels = [], []
for batch_idx, (batch, batch_label) in enumerate(train_loader):
    # Progress: batch index and cumulative number of samples processed.
    print(batch_idx, (batch_idx + 1) * args.minibatch_size)
    batch_data.append(batch)
    batch_labels.append(batch_label)

full_data = torch.cat(batch_data, 0)
full_label = torch.cat(batch_labels, 0)

In [None]:
# Wrap the concatenated tensors in a TensorDataset and serve them in
# minibatches; no shuffle argument is given, so DataLoader's default ordering
# applies.
dataset = TensorDataset(full_data, full_label)
data_loader = DataLoader(dataset, batch_size=args.minibatch_size)

In [None]:
# Shape of the concatenated input tensor.
full_data.shape

In [None]:
# First minibatch served from the in-memory dataset; rebinds `data` and `label`.
data, label = next(iter(data_loader))

In [None]:
# Shape of that minibatch's inputs.
data.shape

In [None]:
# Line plot of the target values of sample i in this minibatch; `i` is the
# index set in an earlier cell.
plt.plot(label[i,:].detach().numpy())

In [None]:
# Seed NumPy's legacy global generator, then draw one standard-normal sample;
# with the seed fixed, the displayed value is the same on every run.
np.random.seed(seed=26722)
np.random.standard_normal()

In [None]:
# Import two helpers from the project's display module and try minmax on a
# sample pair of values; `pe` is imported but not used in this cell.
from display import pe, minmax
minmax(-15, 10)