In [1]:
%matplotlib notebook
from time import time

import torch
import torch.nn as nn
import torch.optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from FrEIA.framework import InputNode, OutputNode, Node, ReversibleGraphNet
from FrEIA.modules import GLOWCouplingBlock, PermuteRandom, AffineCouplingOneSided

import data

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
ndim_tot = 40
ndim_x = 2
ndim_y = 36
ndim_z = 2

def subnet_fc(c_in, c_out):
    return nn.Sequential(nn.Linear(c_in, 512), nn.ReLU(),
                         nn.Linear(512,  c_out))

nodes = [InputNode(ndim_tot, name='input')]

for k in range(8):
    nodes.append(Node(nodes[-1],
                      GLOWCouplingBlock,
                      {'subnet_constructor':subnet_fc, 'clamp':2.0},
                      name=F'coupling_{k}'))
    nodes.append(Node(nodes[-1],
                      PermuteRandom,
                      {'seed':k},
                      name=F'permute_{k}'))

nodes.append(OutputNode(nodes[-1], name='output'))

model = ReversibleGraphNet(nodes, verbose=False)

In [3]:
import torch.optim
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim
import torch.nn.functional as F

class TactileMaterialDataset(Dataset):
    def __init__(self, file_path, train = True):

        with h5py.File(file_path, 'r') as dataset:
            raw_samples = dataset['samples'][:] / 154.  # Normalize
            materials = [i.decode() for i in dataset['materials'][:]]  # Decode material names
        
        self.samples = raw_samples.reshape(36, 100, 16, 1000)
        
        
        if(train):
            self.samples = self.samples[:,:30,:,:]
            # self.samples = self.samples[:,:80,:,:].reshape(-1, 1000)
            self.labels = torch.tensor(np.repeat(range(len(materials)), 40*16*1000)) 
        else:
            
            self.samples = self.samples[:,95:,:,:].reshape(-1, 1000)
            self.labels = torch.tensor(np.repeat(range(len(materials)), 10*16*1000))             
            
        # Apply FFT along the last axis and separate real & imaginary parts
        fft_transformed = np.fft.fft(self.samples, axis=1)
        self.samples = np.stack((fft_transformed.real, fft_transformed.imag), axis=0) 
        self.samples = self.samples.reshape(-1, 1000, 2).reshape(-1, 2)

        self.samples = torch.tensor(self.samples)
        
        self.labels = F.one_hot(torch.tensor(self.labels), num_classes=36)
            
        # np.fft.fft
    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx], self.labels[idx]

In [4]:
# Training parameters
n_epochs = 50
n_its_per_epoch = 8
batch_size = 1000

lr = 1e-3
l2_reg = 2e-5

y_noise_scale = 1e-1
zeros_noise_scale = 5e-2

# relative weighting of losses:
lambd_predict = 3.
lambd_latent = 300.
lambd_rev = 400.


trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_parameters, lr=lr, betas=(0.8, 0.9),
                             eps=1e-6, weight_decay=l2_reg)


def MMD_multiscale(x, y):
    xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t())

    rx = (xx.diag().unsqueeze(0).expand_as(xx))
    ry = (yy.diag().unsqueeze(0).expand_as(yy))

    dxx = rx.t() + rx - 2.*xx
    dyy = ry.t() + ry - 2.*yy
    dxy = rx.t() + ry - 2.*zz

    XX, YY, XY = (torch.zeros(xx.shape).to(device),
                  torch.zeros(xx.shape).to(device),
                  torch.zeros(xx.shape).to(device))

    for a in [0.05, 0.2, 0.9]:
        XX += a**2 * (a**2 + dxx)**-1
        YY += a**2 * (a**2 + dyy)**-1
        XY += a**2 * (a**2 + dxy)**-1

    return torch.mean(XX + YY - 2.*XY)


def fit(input, target):
    return torch.mean((input - target)**2)

loss_backward = MMD_multiscale
loss_latent = MMD_multiscale
loss_fit = fit

In [5]:
train_loader = torch.utils.data.DataLoader(TactileMaterialDataset("tactmat.h5"), batch_size=1000, shuffle=False)

  self.labels = F.one_hot(torch.tensor(self.labels), num_classes=36)


In [6]:
test_loader = torch.utils.data.DataLoader(TactileMaterialDataset("tactmat.h5", False), batch_size=1000, shuffle=False)

  self.labels = F.one_hot(torch.tensor(self.labels), num_classes=36)


In [7]:
def train(i_epoch=0):
    model.train()

    l_tot = 0
    batch_idx = 0
    
    t_start = time()
    
    # If MMD on x-space is present from the start, the model can get stuck.
    # Instead, ramp it up exponetially.  
    loss_factor = min(1., 2. * 0.002**(1. - (float(i_epoch) / n_epochs)))

    for x, y in train_loader:
        batch_idx += 1
        if batch_idx > n_its_per_epoch:
            break

        x, y = x.float().to(device), y.float().to(device)
        
        y_clean = y.clone()
        pad_x = zeros_noise_scale * torch.randn(batch_size, ndim_tot -
                                                ndim_x, device=device)
        pad_yz = zeros_noise_scale * torch.randn(batch_size, ndim_tot -
                                                 ndim_y - ndim_z, device=device)

        y += y_noise_scale * torch.randn(batch_size, ndim_y, dtype=torch.float, device=device)

        x, y = (torch.cat((x, pad_x),  dim=1),
                torch.cat((torch.randn(batch_size, ndim_z, device=device), pad_yz, y),
                          dim=1))
        

        optimizer.zero_grad()
        print(x.shape)
        # Forward step:
        output = model(x)[0]
        print("ANAN")
        # Shorten output, and remove gradients wrt y, for latent loss
        y_short = torch.cat((y[:, :ndim_z], y[:, -ndim_y:]), dim=1)

        l = lambd_predict * loss_fit(output[:, ndim_z:], y[:, ndim_z:])

        output_block_grad = torch.cat((output[:, :ndim_z],
                                       output[:, -ndim_y:].data), dim=1)

        l += lambd_latent * loss_latent(output_block_grad, y_short)
        l_tot += l.data.item()

        l.backward()

        # Backward step:
        pad_yz = zeros_noise_scale * torch.randn(batch_size, ndim_tot -
                                                 ndim_y - ndim_z, device=device)
        y = y_clean + y_noise_scale * torch.randn(batch_size, ndim_y, device=device)

        orig_z_perturbed = (output.data[:, :ndim_z] + y_noise_scale *
                            torch.randn(batch_size, ndim_z, device=device))
        y_rev = torch.cat((orig_z_perturbed, pad_yz,
                           y), dim=1)
        y_rev_rand = torch.cat((torch.randn(batch_size, ndim_z, device=device), pad_yz,
                                y), dim=1)
        
        output_rev = model(y_rev, rev=True)[0]
        output_rev_rand = model(y_rev_rand, rev=True)[0]

        l_rev = (
            lambd_rev
            * loss_factor
            * loss_backward(output_rev_rand[:, :ndim_x],
                            x[:, :ndim_x])
        )

        l_rev += lambd_predict * loss_fit(output_rev, x)
        
        l_tot += l_rev.data.item()
        l_rev.backward()

        for p in model.parameters():
            if(p.grad is not None):
                p.grad.data.clamp_(-15.00, 15.00)

        optimizer.step()

    return l_tot / batch_idx

In [8]:
for param in trainable_parameters:
    param.data = 0.05*torch.randn_like(param)
            
model.to(device)

fig, axes = plt.subplots(1, 2, figsize=(8,4))
axes[0].set_xticks([])
axes[0].set_yticks([])
axes[0].set_title('Predicted labels (Forwards Process)')
axes[1].set_xticks([])
axes[1].set_yticks([])
axes[1].set_title('Generated Samples (Backwards Process)')
fig.show()
fig.canvas.draw()

N_samp = int(5760000/2)

x_samps = torch.cat([x for x,y in test_loader], dim=0)[:N_samp].float()
y_samps = torch.cat([y for x,y in test_loader], dim=0)[:N_samp].float()
c = np.where(y_samps)[1]
y_samps += y_noise_scale * torch.randn(N_samp, ndim_y)
y_samps = torch.cat([torch.randn(N_samp, ndim_z),
                     zeros_noise_scale * torch.zeros(N_samp, ndim_tot - ndim_y - ndim_z), 
                     y_samps], dim=1)
y_samps = y_samps.to(device)
            
try:
    t_start = time()
    for i_epoch in tqdm(range(n_epochs), ascii=True, ncols=80):

        train(i_epoch)

        rev_x = model(y_samps, rev=True)[0]
        rev_x = rev_x.cpu().data.numpy()
        
        pred_c = model(torch.cat((x_samps, torch.zeros(N_samp, ndim_tot - ndim_x)),
                                 dim=1).to(device)).data[:, -8:].argmax(dim=1)

        axes[0].clear()
        axes[0].scatter(x_samps.cpu()[:,0], x_samps.cpu()[:,1], c=pred_c.cpu(), cmap='Set1', s=1., vmin=0, vmax=9)
        axes[0].axis('equal')
        axes[0].axis([-3,3,-3,3])
        axes[0].set_xticks([])
        axes[0].set_yticks([])

        axes[1].clear()
        axes[1].scatter(rev_x[:,0], rev_x[:,1], c=c, cmap='Set1', s=1., vmin=0, vmax=9)
        axes[1].axis('equal')
        axes[1].axis([-3,3,-3,3])
        axes[1].set_xticks([])
        axes[1].set_yticks([])
        
        fig.canvas.draw()


except KeyboardInterrupt:
    pass
finally:
    print(f"\n\nTraining took {(time()-t_start)/60:.2f} minutes\n")

<IPython.core.display.Javascript object>

  0%|                                                    | 0/50 [00:00<?, ?it/s]

torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN
torch.Size([1000, 40])
ANAN


  0%|                                                    | 0/50 [00:01<?, ?it/s]




Training took 0.02 minutes



RuntimeError: Node 'coupling_7': [(40,)] -> GLOWCouplingBlock -> [(40,)] encountered an error.