In [4]:
### Importing libraries ###

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.distributions.bernoulli import Bernoulli

import torchvision
from torchvision.datasets import MNIST
from torchvision.datasets import FashionMNIST
from torchvision import transforms

import matplotlib.pyplot as plt
import time

In [276]:
class PosNegMulti(object):
    """Encode signed values as bin indices in two channels (positive, then negative).

    Each channel splits its magnitude range into ``t_max`` equal-width bins whose
    edges run from the outermost magnitude (edge 0) inwards to the threshold:
    ``max_pos -> cut_th`` for the positive channel and ``max_neg -> -cut_th`` for
    the negative channel. A value lying exactly on an edge takes that edge's index;
    otherwise it takes the index of the bin just outside the first edge its
    magnitude exceeds. Magnitudes beyond ``max_pos`` / ``max_neg`` are clamped into
    bin 0.

    The output has ``2 * N`` entries for an input of ``N`` elements (any shape,
    flattened): the first ``N`` hold the positive-channel index of values
    ``> cut_th``, the last ``N`` the negative-channel index of values
    ``< -cut_th``; every other entry is ``inf``.

    Args:
        t_max: number of bins per channel.
        max_pos: largest positive value; maps to bin 0.
        max_neg: most negative value (assumed < 0); maps to bin 0.
        cut_th: magnitude threshold (assumed >= 0); values in
            ``[-cut_th, cut_th]`` produce no entry.
    """

    def __init__(self, t_max, max_pos, max_neg, cut_th):
        self.t_max = t_max
        self.max_pos = max_pos
        self.max_neg = max_neg
        self.cut_th = cut_th

    @staticmethod
    def _bin_index(past):
        """Return, per row, the bin index implied by distances to the bin edges.

        ``past[r, j]`` is positive when row ``r``'s magnitude exceeds edge ``j`` and
        zero when the value sits exactly on it; edges are ordered from the largest
        magnitude (index 0) inwards.
        """
        on_edge, edge_idx = (past == 0).long().max(1)
        _, first_passed = (past > 0).long().max(1)
        idx = torch.where(on_edge.bool(), edge_idx, first_passed - 1)
        # A magnitude beyond edge 0 passes every edge (first_passed == 0 -> -1);
        # clamp it into the outermost bin instead of emitting an out-of-range -1.
        return idx.clamp(min=0)

    def __call__(self, tensor):
        # Flatten so multi-dimensional inputs are encoded element-wise; 1-D inputs
        # are unaffected.
        values = tensor.reshape(-1)
        datasize = values.shape[0]
        steps = self.t_max + 1

        pos_bins = torch.linspace(self.max_pos, self.cut_th, steps=steps)
        # The negative edges must end at -cut_th (the same threshold used to select
        # negative values below) so both channels get mirrored, equal-width bins.
        neg_bins = torch.linspace(self.max_neg, -self.cut_th, steps=steps)

        column = values.reshape(datasize, 1)  # broadcasts against the (steps,) edges
        pos_idx = self._bin_index(column - pos_bins)
        neg_idx = self._bin_index(neg_bins - column)

        no_spike = torch.full((datasize,), float('inf'))
        pos_half = torch.where(values > self.cut_th, pos_idx.float(), no_spike)
        neg_half = torch.where(values < -self.cut_th, neg_idx.float(), no_spike)
        # First half: positive channel; second half: negative channel.
        return torch.cat((pos_half, neg_half))

In [581]:
class BinEncoding(object):
    """Encode each signed value as a single spike time inside its own block of slots.

    Every input element (any shape, flattened) owns ``2 * resolution`` consecutive
    output slots. The first ``resolution`` slots belong to positive values, the last
    ``resolution`` to negative values. Magnitudes are split into ``resolution``
    equal-width, half-open bins over ``[0, max_pos]`` (resp. ``[max_neg, 0]``).
    Larger magnitudes go to slots nearer the block's outer ends, which carry time 0;
    small magnitudes go to slots near the middle, which carry time ``tmax``.
    Magnitudes at or beyond ``max_pos`` / ``max_neg`` saturate into the
    largest-magnitude bin of their own sign. All other slots, and all slots of a
    zero value, hold ``inf``.

    Args:
        resolution: number of bins per sign.
        tmax: latest spike time.
        max_pos: upper end of the positive magnitude range.
        max_neg: lower end of the negative range (assumed < 0).
    """

    def __init__(self, resolution, tmax, max_pos, max_neg):
        self.resolution = resolution
        self.max_pos = max_pos
        self.max_neg = max_neg
        self.tmax = tmax

    @staticmethod
    def _bin_index(past, resolution):
        """Return, per row, the bin whose outer edge is the first edge beyond the value.

        ``past[r, j] > 0`` means edge ``j`` lies strictly beyond row ``r``'s value
        (edges ordered from 0 outwards). A row with no such edge is at or beyond the
        outermost edge and is saturated into the last bin, ``resolution - 1``.
        """
        found, first_beyond = (past > 0).long().max(1)
        last = torch.full_like(first_beyond, resolution - 1)
        return torch.where(found.bool(), first_beyond - 1, last)

    def __call__(self, tensor):
        values = tensor.reshape(-1)  # encode any input shape element-wise
        datasize = values.shape[0]
        res = self.resolution
        block = 2 * res  # output slots per input element

        pos_edges = torch.linspace(0, self.max_pos, steps=res + 1)
        neg_edges = torch.linspace(0, self.max_neg, steps=res + 1)
        # Slot -> spike time: positive slots run 0..tmax, negative slots tmax..0.
        time_bin = torch.cat((torch.linspace(0, self.tmax, steps=res),
                              torch.linspace(self.tmax, 0, steps=res)))

        column = values.reshape(datasize, 1)
        pos_bin = self._bin_index(pos_edges - column, res)
        neg_bin = self._bin_index(column - neg_edges, res)

        # Slot inside each element's block; -1 marks elements (zeros) with no spike.
        slot = torch.full((datasize,), -1, dtype=torch.long)
        is_pos = values > 0
        is_neg = values < 0
        slot[is_pos] = (res - 1) - pos_bin[is_pos]
        slot[is_neg] = res + neg_bin[is_neg]

        output = torch.full((datasize * block,), float('inf'))
        rows = torch.nonzero(slot >= 0).squeeze(1)
        output[rows * block + slot[rows]] = time_bin[slot[rows]]
        return output

In [202]:
### Preprocessing ###

class PosNeg(object):
    """Split a tensor into two complementary spike-time channels by a relative threshold.

    An element "passes" when it is ``>= pn_threshold * max(tensor)``. The result is
    the concatenation along dim 0 of a positive channel (0 where the element passes,
    inf otherwise) and its complement (inf where it passes, 0 otherwise). The input
    tensor is not modified.

    Args:
        pn_threshold: fraction of the tensor's maximum used as the cut-off.
    """

    def __init__(self, pn_threshold):
        self.pn_thresh = pn_threshold

    def __call__(self, tensor):
        passes = tensor >= self.pn_thresh * torch.max(tensor)
        # Build fresh outputs rather than thresholding ``tensor`` in place, so the
        # caller's data is left untouched. Selecting directly on ``passes`` (instead
        # of marking sub-threshold elements with a ``-max`` sentinel) also keeps an
        # input whose maximum is 0 consistent with the ``>=`` rule.
        dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
        zeros = torch.zeros(tensor.shape, dtype=dtype, device=tensor.device)
        no_spike = torch.full(tensor.shape, float('Inf'), dtype=dtype, device=tensor.device)

        pos = torch.where(passes, zeros, no_spike)
        neg = torch.where(passes, no_spike, zeros)
        return torch.cat([pos, neg], dim=0)





In [585]:
# Demo: encode one negative value. resolution = tmax + 1 puts the time bins on
# the integers 0..tmax; pass ``tmax`` itself (not a duplicated literal) so the two
# stay in sync if tmax is edited.
tmax = 4
b = BinEncoding(resolution=tmax + 1, tmax=tmax, max_pos=10, max_neg=-10)

b(torch.tensor([-9]))


tensor([inf, inf, inf, inf, inf, inf, inf, inf, inf, 0.])

In [614]:
# Scratch check: mask[i, j] is True when input time j and output time i are both
# finite and input j is no later than output i; the mask then selects entries of w.
w = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
intimes = torch.tensor([1.0, 2.0, 0.0])
outtimes = torch.tensor([1.0, 5.0, 6.0, 0.0])

intimes = torch.flatten(intimes)
q = outtimes.shape[0]
p = intimes.shape[0]
ec_in = intimes.repeat(q, 1)                  # (q, p): ec_in[i, j] = intimes[j]
li_out = outtimes.repeat(p, 1).permute(1, 0)  # (q, p): li_out[i, j] = outtimes[i]; matches w's shape

# Compute the mask once instead of repeating the expression at every use.
causal_mask = (ec_in != float('Inf')) & (li_out != float('Inf')) & (ec_in <= li_out)

print(w[causal_mask])
print(causal_mask)

print(w)
w[:, 1:][causal_mask[:, 1:]]

tensor([ 1,  3,  4,  5,  6,  7,  8,  9, 12])
tensor([[ True, False,  True],
        [ True,  True,  True],
        [ True,  True,  True],
        [False, False,  True]])
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])


tensor([ 3,  5,  6,  8,  9, 12])