In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import torchvision
import fk
import h5py
import matplotlib.pyplot as plt

In [40]:
import os 
import numpy as np
class FkDataset():
    def __init__(self, root, n_frames_in=5, n_frames_out=10, step=1, keys=None):
        self.root = root
        self.n_frames_in = n_frames_in
        self.n_frames_out = n_frames_out
        self.n_frames = n_frames_in + n_frames_out
        self.step = step
        
        filenames = [os.path.join(root, name) for name in sorted(os.listdir(root))]
        if keys is not None:
            filenames = [name for name in filenames 
                 if os.path.basename(name) in keys ]
        self.datasets = [Simulation(filename) for filename in filenames]
        self.cumulative_sizes = self.cumsum(self.datasets)
        
    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = torch.utils.data.dataset.bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        start = sample_idx
        end = start + self.n_frames_in + self.n_frames_out
        if end > len(self):
            return None
        return self.datasets[dataset_idx][start:end:self.step]

    def close(self):
        for dataset in self.datasets:
            dataset.states.file.close()
    
class Simulation():
    def __init__(self, filename):
        self.filename = filename
        file = h5py.File(filename, "r")
        self.states = file["states"]
        self.stimuli = fk.io.load_stimuli(file)
        for i in range(len(self.stimuli)):
            self.stimuli[i]["field"] = torch.as_tensor(self.stimuli[i]["field"])
        self.shape = self.states.shape[-2:]
    
    def __getitem__(self, idx):
        states = self.states[idx]
        unstimulated = np.zeros(self.states.shape[-2:])
        stimuli = torch.stack([self.stimulus_at_t(t) for t in range(idx.start, idx.stop, idx.step)])
        return states, stimuli
    
    def __len__(self):
        return len(self.states)
    
    def stimulus_at_t(self, t):
        stimulated = torch.zeros(self.shape)
        for stimulus in self.stimuli:
            active = t >= stimulus["start"]
            active &= ((stimulus["start"] - t + 1) % stimulus["period"]) < stimulus["duration"]
            stimulated = torch.where(stimulus["field"] * (active) > 0, stimulus["field"], stimulated)
        return stimulated

In [68]:
import torch
from torch import nn


class DeepExcite(nn.Module):
    def __init__(self, n_states_in=5, n_states_out=10):
        super(DeepExcite, self).__init__()
        self.states_encoder = Encoder(n_states_in)
        self.stimuli_encoder = Encoder(n_states_in)
        self.states_decoder = Decoder(64, n_states_out)
        
    def forward(self, X_state, X_stim, Y_state=None):
        states = self.states_encoder(X_state)
        stimuli = self.stimuli_encoder(X_stim)
        latent = torch.cat([states, stimuli], dim=1)
        new_states = self.states_decoder(latent)
        if Y_state is not None:
            loss = torch.nn.functional.mse_loss(new_states, Y_state)
            return new_states, loss
        return new_states

class StochasticInference(nn.Module):
    def __init__(self, n_states_in, n_states_out):
        super(StochasticInference, self).__init__()
        self.u_encoder = Encoder(n_states_in, n_states_out)
        self.v_encoder = Encoder(n_states_in, n_states_out)
        self.w_encoder = Encoder(n_states_in, n_states_out)
        self.mu = nn.Linear(32, 1)
        self.logvar = nn.Linear(32, 1)

        # init with normal weights
        torch.nn.init.normal(self.mu)
        torch.nn.init.normal(self.logvar)
        return

    @staticmethod
    def kl_divergence(z, mu, logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    @staticmethod
    def reparameterize(mu, logvar):
        sigma = torch.exp(0.5 * logvar)
        eps = torch.randn_like(sigma)
        return mu + eps * sigma

    def get_loss(self, z, mu, logvar):
        loss = self.kl_divergence(z, mu, logvar)
        return loss

    def forward(self, X):
        X = self.encode(X)
        mu = self.mu(X)
        logvar = self.logvar(X)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar


class Encoder(nn.Module):
    def __init__(self, channels_in):
        super(Encoder, self).__init__()
        self.encoder = nn.ModuleList([
            nn.Conv2d(channels_in, 8, kernel_size=(3, 3), stride=(2, 2)),
            nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2)),
            nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2)),
            ])

    def forward(self, X):
        for module in self.encoder:
            X = module(X)
        return X


class Decoder(nn.Module):
    def __init__(self, channels_in, channels_out):
        super(Decoder, self).__init__()
        self.decoder = nn.ModuleList([
            nn.ConvTranspose2d(channels_in, 16, kernel_size=(3, 3), stride=(2, 2)),
            nn.ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(2, 2)),
            nn.ConvTranspose2d(8, channels_out, kernel_size=(3, 3), stride=(2, 2))
            ])

    def forward(self, X):
        for module in self.encoder:
            X = module(X)
        return X


In [69]:
root = "/media/ep119/DATADRIVE3/epignatelli/deepexcite/train_dev_set"
dataset = FkDataset(root, keys=["spiral_params5.hdf5"])
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

In [70]:
model = DeepExcite()

In [71]:
for sample in loader:
    states, stimuli = sample
    X_state = states[:5]
    X_stim = stimuli
    Y_state = states[5:]
    y_hat, loss = model(X_state, X_stim, Y_state)
    break

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [8, 5, 3, 3], but got 5-dimensional input of size [4, 15, 3, 1200, 1200] instead