In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [44]:
from mlutils.layers.readouts import PointPooled2d
from mlutils.layers.cores import Stacked2dCoreDropOut

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import Parameter


class Encoder(nn.Module):

    def __init__(self, core, readout):
        super().__init__()
        self.core = core
        self.readout = readout

    @staticmethod
    def get_readout_in_shape(core, shape):
        train_state = core.training
        core.eval()
        tmp = torch.Tensor(*shape).normal_()
        nout = core(tmp).size()[1:]
        core.train(train_state)
        return nout

    def forward(self, x):
        x = self.core(x)
        x = self.readout(x)
        return F.elu(x) + 1


class Ensemble(nn.Module):
    def __init__(self,train_loader, config):
        super().__init__()
        self.seeds = config['seeds']
        del config['seeds']
        self.config = config
        self.train_loader = train_loader

        self.model0 = create_model(self.train_loader, self.seeds[0],  **self.config)#.to('cuda:0')
        self.model1 = create_model(self.train_loader, self.seeds[1], **self.config)#.to('cuda:1')

    def forward(self, x):
        out = [self.model0(x), self.model1(x)]
        return out, torch.cat(out)

    @staticmethod
    def diff_real(preds, true):
        return F.mse_loss(preds, true, reduction='mean')

    @staticmethod
    def variance(preds):
        return preds.std(dim=0)


def create_model(train_loader, seed=0, **config):
    np.random.seed(seed)
    in_shape = train_loader.img_shape
    n_neurons = train_loader.n_neurons
    transformed_mean = train_loader.transformed_mean
    core = Stacked2dCoreDropOut(input_channels=1,
                                hidden_channels=32,
                                input_kern=15,
                                hidden_kern=7,
                                dropout_p=config['dropout_p'],
                                layers=3,
                                gamma_hidden=config['gamma_hidden'],
                                gamma_input=config['gamma_input'],
                                skip=3,
                                final_nonlinearity=False,
                                bias=False,
                                momentum=0.9,
                                pad_input=False,
                                batch_norm=True,
                                hidden_dilation=1,
                                laplace_padding=0,
                                input_regularizer="LaplaceL2norm")
    ro_in_shape = Encoder.get_readout_in_shape(core, in_shape)

    readout = PointPooled2d(ro_in_shape, n_neurons,
                            pool_steps=2, pool_kern=4,
                            bias=True, init_range=0.2)

    gamma_readout = 0.1

    def regularizer():
        return readout.feature_l1() * gamma_readout

    readout.regularizer = regularizer

    ## Model init
    model = Encoder(core, readout)
    r_mean = transformed_mean
    model.readout.bias.data = r_mean
    model.core.initialize()
    model.train()
    return model


def create_ensemble(train_loader, **config):

    ensemble = Ensemble(train_loader, config)

    return ensemble

In [45]:
from mlutils.data.datasets import StaticImageSet
from mlutils.data.transforms import Subsample, ToTensor


import numpy as np

from torch.utils.data import DataLoader, Subset
from torch.utils.data.sampler import SubsetRandomSampler
def create_dataloaders(file, batch_size):

    dat = StaticImageSet(file, 'images', 'responses')
    idx = (dat.neurons.area == 'V1') & (dat.neurons.layer =='L2/3')
    dat.transforms = [Subsample(np.where(idx)[0]), ToTensor(cuda=False)]
    
    train_loader = DataLoader(dat,
                              sampler=SubsetRandomSampler(np.where(dat.tiers == 'train')[0]),
                              batch_size=batch_size)
    train_loader.img_shape = dat.img_shape
    train_loader.n_neurons = dat.n_neurons
    _, train_loader.transformed_mean = dat.transformed_mean()
    
    val_loader = DataLoader(dat,
                              sampler=SubsetRandomSampler(np.where(dat.tiers == 'validation')[0]),
                              batch_size=batch_size)
    val_loader.img_shape = dat.img_shape
    val_loader.n_neurons = dat.n_neurons

    test_loader = DataLoader(dat,
                              sampler=SubsetRandomSampler(np.where(dat.tiers == 'test')[0]),
                              batch_size=batch_size)
    
    test_loader.img_shape = dat.img_shape
    test_loader.n_neurons = dat.n_neurons
    
    loaders = {'train': train_loader, 'val': val_loader, 'test': test_loader}

    return loaders

In [46]:
model_config = dict(dropout_p=0.5, gamma_hidden=100, gamma_input=1 )
model_config['seeds'] =[5, 6]

loaders = create_dataloaders('./data/static20892-3-14-preproc0.h5', batch_size=64)

In [47]:
ens = create_ensemble(loaders['train'], **model_config)

In [48]:
from torch.utils.data import Dataset, DataLoader

In [49]:
    class CatData(Dataset):
        def __init__(self, dataset, images, responses, n_gpu):
            super().__init__()
            self.dataset = dataset
            self.images = images
            self.responses = responses
            self.n_gpu = n_gpu
            

        def __getitem__(self, item):
            resp = torch.stack(
                       [torch.tensor(dat.responses[item]).to('cuda:{}'.format(i))
                                                      for i in range(self.n_gpu)])
            return self.images[item], resp
        def __len__(self):
            return self.images.size(0)

In [50]:
dat = StaticImageSet('./data/static20892-3-14-preproc0.h5', 'images', 'responses')
idx = (dat.neurons.area == 'V1') & (dat.neurons.layer =='L2/3')

dat.transforms = [Subsample(np.where(idx)[0]), ToTensor(cuda=False)]

In [51]:
mydat = CatData(dat, dat[()].images, dat[()].responses, 1)

In [52]:
train_loader = DataLoader(mydat, batch_size=2)

In [53]:
from mlutils.measures import PoissonLoss
optimizer = torch.optim.SGD(ens.parameters(), lr=0.1)

In [75]:
for batch, labels in train_loader:
    _, echt = ens(batch)  
    print(F.mse_loss(echt, torch.cat([sep[0], sep[1]])))
    break

torch.Size([2, 1, 36, 64])
tensor(0.0158, grad_fn=<MeanBackward0>)


In [61]:
sep[1].size()

torch.Size([2, 8198])

In [65]:
cata[:2] ==sep[0]

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])

In [67]:
import torch.nn.functional as F

In [70]:
F.mse_loss(echt, torch.cat([sep[0], sep[1]]))


tensor(0.0114, grad_fn=<MeanBackward0>)

In [72]:
torch.cat([sep[0], sep[1]]) == cata

tensor([[True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True],
        [True, True, True,  ..., True, True, True]])