In [1]:
device = "cpu"
import torch
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

In [2]:
import h5py

# Eagerly read the raw recordings and the material table into memory.
# Per the original note, samples are laid out as
# [materials, samples, time_steps, taxels_x, taxels_y].
# NOTE(review): TactileMaterialDataset re-opens the file itself, so these
# module-level arrays appear to be for interactive inspection only.
with h5py.File("tactmat.h5", 'r') as h5_file:
    materials = h5_file['materials'][:]
    samples = h5_file['samples'][:]

In [17]:
import numpy as np
from torch.utils.data import Dataset

class TactileMaterialDataset(Dataset):
    '''Frequency-domain tactile recordings of a single material, read from HDF5.

    Each item is ``(spectrum, label)``:

    * ``spectrum`` is a float32 tensor of shape ``(crop bins, n_taxels)``. With
      the default crop and a 4x4 taxel grid this is ``(940, 16)``.
    * ``label`` is ``material_idx``, the material the recording came from.

    Both tensors are moved to the module-level ``device`` when fetched.

    Args:
        file_path: path to an HDF5 file with a ``samples`` dataset laid out as
            [materials, samples, time_steps, ...taxel dims].
        material_idx: index along the first (material) axis to load. The
            default 1 matches the material previously hard-coded here.
        scale: divisor applied to the raw readings before the FFT.
            NOTE(review): 154 presumably is the sensor's full-scale value; confirm.
        crop: slice of FFT bins to keep along the time axis.
    '''

    def __init__(self, file_path, material_idx=1, scale=154., crop=slice(30, 970)):
        with h5py.File(file_path, 'r') as dataset:
            # Read only the selected material instead of the whole array.
            raw = dataset['samples'][material_idx]
        self.samples, self.labels = self._preprocess(raw, material_idx, scale, crop)

    @staticmethod
    def _preprocess(raw, material_idx, scale=154., crop=slice(30, 970)):
        '''Turn one material's raw recordings into ``(spectra, labels)`` tensors.

        ``raw`` has shape ``(n_samples, time_steps, *taxel_dims)``.
        ``spectra`` is float32 with shape ``(n_samples, len(crop bins), n_taxels)``.
        ``labels`` is a long tensor filled with ``material_idx``. Previously the
        labels were ``0`` for every item even though the data came from material 1.
        '''
        scaled = np.asarray(raw) / scale
        n_samples, n_steps = scaled.shape[:2]
        # Flatten the taxel grid so every sample is (time_steps, n_taxels).
        flat = scaled.reshape(n_samples, n_steps, -1)
        # Keep the real part of the FFT along time, restricted to the crop bins.
        spectra = np.fft.fft(flat, axis=1).real[:, crop, :]
        samples = torch.tensor(spectra).float()
        labels = torch.full((n_samples,), material_idx, dtype=torch.long)
        return samples, labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx].to(device), self.labels[idx].to(device)

In [18]:
from torch.utils.data import DataLoader
from tqdm import tqdm
train_dataset = TactileMaterialDataset("tactmat.h5")

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

In [7]:
import torch
import torch.nn as nn

# FrEIA imports
import FrEIA.framework as Ff
import FrEIA.modules as Fm


N_DIM = 16 * 940  # flattened size of one (940, 16) spectrum

def subnet_fc(dims_in, dims_out, hidden=128):
    '''Two-layer MLP used as the coupling subnet of the dense INN.

    Args:
        dims_in: input feature size (supplied by FrEIA).
        dims_out: output feature size (supplied by FrEIA).
        hidden: width of the hidden layer.

    The second Linear previously declared 32 input features while the first
    produced 128, which fails on the first forward pass.
    '''
    return nn.Sequential(nn.Linear(dims_in, hidden), nn.ReLU(),
                         nn.Linear(hidden, dims_out))

# Dense INN over the flattened (N_DIM,) spectrum.
inn = Ff.SequenceINN(N_DIM)
for k in range(4):
    # Hard permutations: FrEIA documents permute_soft=True as very slow above
    # ~512 dims (it samples a dense N_DIM x N_DIM rotation), and N_DIM is 15040.
    inn.append(Fm.AllInOneBlock, subnet_constructor=subnet_fc, permute_soft=False)
# Move to the device only after all blocks exist. Blocks appended after a
# `.to(device)` call would keep their parameters on the CPU.
inn = inn.to(device)

optimizer = torch.optim.Adam(inn.parameters(), lr=0.001)

for epoch in range(20):
    # The loader yields (spectrum, label) pairs; the unconditional INN ignores labels.
    for x, _ in tqdm(train_loader):
        # Reset gradients every step. Zeroing once per epoch would let
        # gradients from earlier batches leak into every later update.
        optimizer.zero_grad()
        # Flatten each sample into one N_DIM vector, keeping the batch axis.
        x = x.reshape(x.shape[0], N_DIM).float()

        # SequenceINN returns (output, log|det J|).
        z, log_jac_det = inn(x)

        # Negative log-likelihood under a standard-normal latent (up to a
        # constant), normalised per dimension.
        loss = 0.5 * torch.sum(z**2, 1) - log_jac_det
        loss = loss.mean() / N_DIM

        loss.backward()
        optimizer.step()


KeyboardInterrupt: 

In [None]:
# Draw one latent from the standard-normal prior and invert the flow.
# The latent must have the same dimensionality as the flattened input
# (N_DIM), not 2, and live on the same device as the model.
with torch.no_grad():
    z = torch.randn(1, N_DIM, device=device)
    samples_pred, _ = inn(z, rev=True)

In [33]:
import torch
import torch.nn as nn
import torch.optim

import FrEIA.framework as Ff
import FrEIA.modules as Fm

ndim_total = 940 * 16  # flattened size of one (940, 16) spectrum

def one_hot(labels, out=None, num_classes=36):
    '''
    Convert LongTensor class labels in [0, num_classes) to one-hot float vectors.

    Args:
        labels: 1-D (or any shape, flattened) LongTensor of class indices.
        out: optional preallocated (len(labels), C) float tensor to fill in place,
            which allows reusing GPU memory. When given, its width C is used and
            `num_classes` is ignored.
        num_classes: number of classes when `out` is not given (default 36).

    Returns:
        The one-hot tensor (`out` itself when it was supplied).
    '''
    if out is None:
        out = torch.zeros(labels.shape[0], num_classes, device=labels.device)
    else:
        # In-place fill is `zero_`; the previous `zeros_` does not exist on Tensor.
        out.zero_()

    # reshape (not view) also accepts non-contiguous label tensors.
    out.scatter_(dim=1, index=labels.reshape(-1, 1), value=1.)
    return out

class MNIST_cINN(nn.Module):
    '''Class-conditional INN over flattened tactile spectra.

    The class name (and the previous docstring) refer to MNIST; presumably this
    was adapted from an MNIST cINN example. Here the condition is a 36-way
    one-hot label produced by `one_hot`, and the input is a 940x16 spectrum.

    Args:
        lr: learning rate for the model's own Adam optimizer (`self.optimizer`).
    '''
    def __init__(self, lr):
        super().__init__()

        self.cinn = self.build_inn()

        self.trainable_parameters = [p for p in self.cinn.parameters() if p.requires_grad]
        # Re-initialise every trainable parameter with small Gaussian noise (std 0.01).
        for p in self.trainable_parameters:
            p.data = 0.01 * torch.randn_like(p)

        # Optimizer owned by the model.
        # NOTE(review): training code may build its own optimizer instead; confirm which is used.
        self.optimizer = torch.optim.Adam(self.trainable_parameters, lr=lr, weight_decay=1e-5)

    def build_inn(self):
        '''Build the invertible graph: flatten -> 5 x (random permute + conditional GLOW coupling).'''

        def subnet(ch_in, ch_out):
            # Coupling subnet: one hidden layer of width 512.
            return nn.Sequential(nn.Linear(ch_in, 512),
                                 nn.ReLU(),
                                 nn.Linear(512, ch_out))

        # 36-dimensional condition, matching the width of `one_hot` output.
        cond = Ff.ConditionNode(36)
        # NOTE(review): input is declared as (1, 940, 16), but the tactile
        # dataset yields (940, 16) per sample. Forward presumably still works
        # because Flatten collapses everything after the batch axis, while
        # reverse_sample likely returns (batch, 1, 940, 16). Confirm the
        # intended shape.
        nodes = [Ff.InputNode(1, 940, 16)]

        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}))

        # Seeded permutations make the graph reproducible across rebuilds.
        for k in range(5):
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom , {'seed':k}))
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':1.0},
                                 conditions=cond))

        return Ff.ReversibleGraphNet(nodes + [cond, Ff.OutputNode(nodes[-1])], verbose=False)

    def forward(self, x, l):
        '''Map data `x` with labels `l` to latent `z`; returns (z, log-Jacobian term).'''
        z,jac = self.cinn(x, c=one_hot(l), jac=True)
        return z, jac

    def reverse_sample(self, z, l):
        '''Run the graph in reverse from latent `z` under labels `l`.

        Returns whatever the graph returns with rev=True (presumably a
        (sample, log-Jacobian) pair; confirm against the FrEIA version in use).
        '''
        return self.cinn(z, c=one_hot(l), rev=True)

In [34]:
# Build the conditional INN (its internal optimizer uses lr=5e-3) and move it to the device.
cinn = MNIST_cINN(lr=5e-3)
cinn = cinn.to(device)

In [35]:
from tqdm import tqdm

In [37]:
# NOTE(review): this optimizer (lr=1e-3, no weight decay) replaces the model's
# own `cinn.optimizer`; confirm that is intended.
optimizer = torch.optim.Adam(cinn.parameters(), lr=0.001)
nll_mean = []  # per-step NLL values, one entry per batch across all epochs
for epoch in range(20):
    for (x, l) in tqdm(train_loader):
        # Reset gradients every step. Zeroing once per epoch would accumulate
        # gradients from every previous batch into each update.
        optimizer.zero_grad()
        # `.to(device)` works on CPU-only machines where `.cuda()` would raise.
        x, l = x.to(device), l.to(device)
        z, log_j = cinn(x, l)

        # Per-dimension NLL under a standard-normal latent (up to a constant).
        nll = torch.mean(z**2) / 2 - torch.mean(log_j) / ndim_total
        nll.backward()
        nll_mean.append(nll.item())
        optimizer.step()

100%|██████████| 100/100 [00:05<00:00, 19.44it/s]
100%|██████████| 100/100 [00:05<00:00, 19.34it/s]
100%|██████████| 100/100 [00:05<00:00, 19.31it/s]
100%|██████████| 100/100 [00:05<00:00, 19.35it/s]
100%|██████████| 100/100 [00:05<00:00, 19.35it/s]
100%|██████████| 100/100 [00:05<00:00, 19.23it/s]
100%|██████████| 100/100 [00:05<00:00, 19.21it/s]
100%|██████████| 100/100 [00:05<00:00, 19.29it/s]
100%|██████████| 100/100 [00:05<00:00, 19.31it/s]
100%|██████████| 100/100 [00:05<00:00, 19.32it/s]
100%|██████████| 100/100 [00:05<00:00, 19.24it/s]
100%|██████████| 100/100 [00:05<00:00, 19.35it/s]
100%|██████████| 100/100 [00:05<00:00, 19.32it/s]
100%|██████████| 100/100 [00:05<00:00, 19.35it/s]
100%|██████████| 100/100 [00:05<00:00, 19.38it/s]
100%|██████████| 100/100 [00:05<00:00, 19.37it/s]
100%|██████████| 100/100 [00:05<00:00, 19.38it/s]
100%|██████████| 100/100 [00:05<00:00, 19.22it/s]
100%|██████████| 100/100 [00:05<00:00, 19.36it/s]
100%|██████████| 100/100 [00:05<00:00, 19.34it/s]
