In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [62]:
from mlutils.layers.readouts import PointPooled2d
from mlutils.layers.cores import Stacked2dCoreDropOut

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter


class Encoder(nn.Module):

    def __init__(self, core, readout):
        super().__init__()
        self.core = core
        self.readout = readout

    @staticmethod
    def get_readout_in_shape(core, shape):
        train_state = core.training
        core.eval()
        tmp = torch.Tensor(*shape).normal_()
        nout = core(tmp).size()[1:]
        core.train(train_state)
        return nout

    def forward(self, x):
        x = self.core(x)
        x = self.readout(x)
        return F.elu(x) + 1


class Ensemble(nn.Module):
    def __init__(self, n_models, seeds, train_loader, config):
        super().__init__()
        self.n_models = n_models
        self.seeds = seeds
        self.models = nn.ModuleList()
        self.config = config
        self.train_loader = train_loader

    def init_ensemble(self):
        for rand_seed in self.seeds:
            self.models.append(create_model(self.train_loader, rand_seed, **self.config))
            """for i, model in enumerate(self.models):
                self.add_module('model_{}'.format(i), model)"""

    def forward(self, x):
        out = [model(x) for model in self.models]
        return torch.cat(out)

    @staticmethod
    def diff_real(preds, true):
        return F.mse_loss(preds, true, reduction='mean')

    @staticmethod
    def variance(preds):
        return preds.std(dim=0)

def create_model(train_loader, seed, **config):
    np.random.seed(seed)
    in_shape = train_loader.img_shape
    n_neurons = train_loader.n_neurons
    transformed_mean = train_loader.transformed_mean
    core = Stacked2dCoreDropOut(input_channels=1,
                                hidden_channels=32,
                                input_kern=15,
                                hidden_kern=7,
                                dropout_p=config['dropout_p'],
                                layers=3,
                                gamma_hidden=config['gamma_hidden'],
                                gamma_input=config['gamma_input'],
                                skip=3,
                                final_nonlinearity=False,
                                bias=False,
                                momentum=0.9,
                                pad_input=False,
                                batch_norm=True,
                                hidden_dilation=1,
                                laplace_padding=0,
                                input_regularizer="LaplaceL2norm")
    ro_in_shape = Encoder.get_readout_in_shape(core, in_shape)

    readout = PointPooled2d(ro_in_shape, n_neurons,
                            pool_steps=2, pool_kern=4,
                            bias=True, init_range=0.2)

    gamma_readout = 0.1

    def regularizer():
        return readout.feature_l1() * gamma_readout

    readout.regularizer = regularizer

    ## Model init
    model = Encoder(core, readout)
    r_mean = transformed_mean
    model.readout.bias.data = r_mean
    model.core.initialize()
    model.train()
    return model


def create_ensemble(n_models, seeds, train_loader, **config):

    ensemble = Ensemble(n_models, seeds, train_loader, config)

    ensemble.init_ensemble()

    return ensemble

In [63]:
from mlutils.data.datasets import StaticImageSet
from mlutils.data.transforms import Subsample, ToTensor


import numpy as np

from torch.utils.data import DataLoader, Subset
from torch.utils.data.sampler import SubsetRandomSampler
def create_dataloaders(file, batch_size):

    dat = StaticImageSet(file, 'images', 'responses')
    idx = (dat.neurons.area == 'V1') & (dat.neurons.layer =='L2/3')
    dat.transforms = [Subsample(np.where(idx)[0]), ToTensor(cuda=False)]
    
    train_loader = DataLoader(dat,
                              sampler=SubsetRandomSampler(np.where(dat.tiers == 'train')[0]),
                              batch_size=batch_size)
    train_loader.img_shape = dat.img_shape
    train_loader.n_neurons = dat.n_neurons
    _, train_loader.transformed_mean = dat.transformed_mean()
    
    val_loader = DataLoader(dat,
                              sampler=SubsetRandomSampler(np.where(dat.tiers == 'validation')[0]),
                              batch_size=batch_size)
    val_loader.img_shape = dat.img_shape
    val_loader.n_neurons = dat.n_neurons

    test_loader = DataLoader(dat,
                              sampler=SubsetRandomSampler(np.where(dat.tiers == 'test')[0]),
                              batch_size=batch_size)
    
    test_loader.img_shape = dat.img_shape
    test_loader.n_neurons = dat.n_neurons
    
    loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

    return loaders

In [64]:
model_config = dict(dropout_p=0.5, gamma_hidden=100, gamma_input=1 )
loaders = create_dataloaders('./data/static20892-3-14-preproc0.h5', batch_size=64)



In [65]:
ens = create_ensemble(10, range(10), loaders['train'], **model_config)

In [179]:
for i,model in enumerate(ens.models):
    print(i)

0
1
2
3
4
5
6
7
8
9


In [70]:
x = torch.ones(64, 1, 32, 32)

In [181]:
ens(x).size()

torch.Size([640, 8198])

In [31]:
ens.state_dict()

OrderedDict()

In [123]:
class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 1)
        self.fc2 = nn.Linear(2, 1)
    def forward(self, x):
        return torch.cat([self.fc1(x), self.fc2(x)])
    

In [124]:
z = torch.ones(2)

In [133]:
net = TestNet()
net.fc2.weight = nn.Parameter(torch.zeros(1,2))
net.fc2.bias = nn.Parameter(torch.zeros(1))
net.fc1.weight = nn.Parameter(torch.zeros(1,2))
net.fc1.bias = nn.Parameter(torch.ones(1))

In [135]:
F.mse_loss(net(z), torch.ones(2, 1)).backward()

  """Entry point for launching an IPython kernel.


In [138]:
net.fc2.weight.grad

tensor([[-1., -1.]])

In [137]:
net(z)

tensor([1., 0.], grad_fn=<CatBackward>)

In [132]:
net.fc2.weight

Parameter containing:
tensor([[0., 0.]], requires_grad=True)

In [142]:

torch.cat([torch.ones((64, 3, 8, 8)).float()]*ens.n_models).size()

torch.Size([640, 3, 8, 8])

In [258]:
    class CatData(Dataset):
        def __init__(self, dataset, images, responses, n_models):
            super().__init__()
            self.dataset = dataset
            self.images = images
            self.responses = responses
            self.n_models = n_models

        def __getitem__(self, item):
            return self.images[item], torch.stack([torch.tensor(dat.responses[item])] * self.n_models)
        def __len__(self):
            return self.images.size(0)

In [259]:
from torch.utils.data import Dataset, DataLoader

In [260]:
dat = StaticImageSet('./data/static20892-3-14-preproc0.h5', 'images', 'responses')
idx = (dat.neurons.area == 'V1') & (dat.neurons.layer =='L2/3')

dat.transforms = [Subsample(np.where(idx)[0]), ToTensor(cuda=False)]

In [261]:
mydat = CatData(dat, dat[()].images, dat[()].responses, 10)

In [262]:
train_loader = DataLoader(mydat, batch_size=2)

In [263]:
from mlutils.measures import PoissonLoss
optimizer = torch.optim.SGD(ens.parameters(), lr=0.1)

In [264]:
criterion = PoissonLoss()

In [266]:
        for images, responses in train_loader:
            optimizer.zero_grad()
            print(responses.size())
            print(ens(images.float()).size())
            loss = criterion(ens(images.float()), responses.view(20, -1))
            loss.backward()
            optimizer.step()

torch.Size([2, 10, 8198])
torch.Size([20, 8198])
torch.Size([2, 10, 8198])
torch.Size([20, 8198])
torch.Size([2, 10, 8198])
torch.Size([20, 8198])
torch.Size([2, 10, 8198])
torch.Size([20, 8198])


KeyboardInterrupt: 

In [257]:
torch.stack([torch.tensor(dat.responses[1])] * 10).shape

torch.Size([10, 8198])

In [184]:
ens(torch.tensor(dat.images[0]).view(1, 1, 36, 64)).size()

torch.Size([10, 8198])