In [1]:
# https://stackoverflow.com/questions/45113245/how-to-get-mini-batches-in-pytorch-in-a-clean-and-efficient-way

In [2]:
import numpy as np
import pandas as pd
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader

from sklearn.preprocessing import LabelEncoder

In [15]:
SESSION = 0
SPLIT = 128

EPOCHS = 1000
BATCH_SIZE = 128
INPUT_SIZE = 59
ENCODE_DIM = 8
INSIDE_BATCH = False

BASE_FOLDER = Path('../data')
SESSIONS = {0: 22, 1: 153, 2: 153}

device = torch.device("cpu")

In [16]:
class GaitDataset(Dataset):
    def __init__(self, filename):
        df = pd.read_csv(BASE_FOLDER.joinpath(Path(filename)), header=None)
        y = df[df.columns[-1]].values
        df.drop([df.columns[-1]], axis=1, inplace=True)
        y = LabelEncoder().fit_transform(y)
        
        self.Xdata = df
        self.Ydata = y
        
    def __len__(self):
        return len(self.Xdata)
    
    def __getitem__(self, index):
        vector = self.Xdata.iloc[index, :].values.astype(np.float32)
        label  = self.Ydata[index]
        
        return vector, label

In [17]:
def get_csv(session, split):
    return 'zju_gaitaccel_session_' + str(session) + '_' + str(split) + '.csv'

In [18]:
train_dataset = GaitDataset(get_csv(SESSION, SPLIT))
dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
lossloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [19]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(INPUT_SIZE, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, ENCODE_DIM),
            nn.ReLU(True))

        self.decoder = nn.Sequential(             
            nn.Linear(ENCODE_DIM, 16),
            nn.ReLU(True),
            nn.Linear(16, 32),
            nn.ReLU(True),
            nn.Linear(32, INPUT_SIZE),
            nn.Sigmoid())

    def forward(self,x):
        x_enc = self.encoder(x)
        x_dec = self.decoder(x_enc)
        return x_dec, x_enc

In [23]:
class SeparatorLoss(nn.Module):
    def __init__(self, loader, encoder):
        super(SeparatorLoss, self).__init__()
        self.loader = loader
        self.loader_iter = iter(loader)
        self.encoder = encoder
        self.pdist = nn.PairwiseDistance(p=2)
        
    def forward(self, x_pred, x_true, encoded, labels):    
        if INSIDE_BATCH:
            # Option A: inside current batch
            
            # Prepare similarity
            sep = torch.pdist(encoded, 2)
        else:
            # Option B: with other batch
        
            # Get a batch from the same dataset
            try:
                batch_X, batch_y = next(self.loader_iter)
            except StopIteration:
                self.loader_iter = iter(self.loader)
                batch_X, batch_y = next(self.loader_iter)
                
            # Encode it
            batch_encoded = self.encoder(batch_X)

            # Prepare same class vector
            X_labels = labels.view(-1, 1).repeat(1, BATCH_SIZE).view(1, BATCH_SIZE * BATCH_SIZE)
            batch_labels = batch_y.view(-1, 1).repeat(1, BATCH_SIZE).view(1, BATCH_SIZE * BATCH_SIZE)
            same = (X_labels == batch_labels).to(device, dtype=torch.float32)
            same = same * SESSIONS[SESSION] - 1
                
            # Prepare/repeat matrices
            dist_X = encoded.repeat(1, BATCH_SIZE).view(-1, ENCODE_DIM)
            dist_batch = batch_encoded.expand(BATCH_SIZE, BATCH_SIZE, ENCODE_DIM).reshape(-1, ENCODE_DIM)

            # Calculate pairwise distances
            sep = (self.pdist(dist_X, dist_batch) * same).sum() / BATCH_SIZE
 
        mse = F.mse_loss(x_pred, x_true)
#         print('sep: {:.4f}, mse: {:.4f}'.format(sep, mse))
        return sep + mse

In [24]:
model = Autoencoder().cpu()
# distance = nn.MSELoss()
distance = SeparatorLoss(lossloader, model.encoder)
# optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-5)
optimizer = torch.optim.SGD(model.parameters(), lr = 0.001, momentum = 0.9)

# print(model)

In [25]:
%%time
for epoch in range(EPOCHS):
    for data in dataloader:
        vec, labels = data
        vec = Variable(vec).cpu()
        # ===================forward=====================
        dec, enc = model(vec)
        loss = distance(dec, vec, enc, labels)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss: {:.4f}'.format(epoch + 1, EPOCHS, loss.item()))

epoch [1/1000], loss: -1.5027
epoch [2/1000], loss: 0.2008
epoch [3/1000], loss: -3.5080
epoch [4/1000], loss: 8.2182
epoch [5/1000], loss: -8.8318
epoch [6/1000], loss: 18.2291
epoch [7/1000], loss: 0.1820
epoch [8/1000], loss: -1.3934
epoch [9/1000], loss: 0.3211
epoch [10/1000], loss: -0.1066
epoch [11/1000], loss: 1.7930
epoch [12/1000], loss: -0.0880
epoch [13/1000], loss: 0.7170
epoch [14/1000], loss: 0.4973
epoch [15/1000], loss: 0.1754
epoch [16/1000], loss: -0.0485
epoch [17/1000], loss: 0.7357
epoch [18/1000], loss: 0.0204
epoch [19/1000], loss: 0.3859
epoch [20/1000], loss: 0.7475
epoch [21/1000], loss: 0.0345
epoch [22/1000], loss: 0.9075
epoch [23/1000], loss: 1.2093
epoch [24/1000], loss: 1.5316
epoch [25/1000], loss: -0.3024
epoch [26/1000], loss: 0.0169
epoch [27/1000], loss: 0.8899
epoch [28/1000], loss: -0.0367
epoch [29/1000], loss: 0.3610
epoch [30/1000], loss: 0.3663
epoch [31/1000], loss: 0.5559
epoch [32/1000], loss: 0.4449
epoch [33/1000], loss: 0.3047
epoch [34

epoch [268/1000], loss: 0.2902
epoch [269/1000], loss: 0.2853
epoch [270/1000], loss: 0.2956
epoch [271/1000], loss: 0.2830
epoch [272/1000], loss: 0.2842
epoch [273/1000], loss: 0.2850
epoch [274/1000], loss: 0.2885
epoch [275/1000], loss: 0.2876
epoch [276/1000], loss: 0.2864
epoch [277/1000], loss: 0.2894
epoch [278/1000], loss: 0.2884
epoch [279/1000], loss: 0.2868
epoch [280/1000], loss: 0.2907
epoch [281/1000], loss: 0.2864
epoch [282/1000], loss: 0.2815
epoch [283/1000], loss: 0.2845
epoch [284/1000], loss: 0.2844
epoch [285/1000], loss: 0.2830
epoch [286/1000], loss: 0.2919
epoch [287/1000], loss: 0.2858
epoch [288/1000], loss: 0.2846
epoch [289/1000], loss: 0.2829
epoch [290/1000], loss: 0.2798
epoch [291/1000], loss: 0.2795
epoch [292/1000], loss: 0.2821
epoch [293/1000], loss: 0.2831
epoch [294/1000], loss: 0.2815
epoch [295/1000], loss: 0.2856
epoch [296/1000], loss: 0.2807
epoch [297/1000], loss: 0.2772
epoch [298/1000], loss: 0.2847
epoch [299/1000], loss: 0.2838
epoch [3

epoch [533/1000], loss: 0.1997
epoch [534/1000], loss: 0.2061
epoch [535/1000], loss: 0.2026
epoch [536/1000], loss: 0.2033
epoch [537/1000], loss: 0.1987
epoch [538/1000], loss: 0.1979
epoch [539/1000], loss: 0.2021
epoch [540/1000], loss: 0.2050
epoch [541/1000], loss: 0.1953
epoch [542/1000], loss: 0.2012
epoch [543/1000], loss: 0.1983
epoch [544/1000], loss: 0.2008
epoch [545/1000], loss: 0.1958
epoch [546/1000], loss: 0.1969
epoch [547/1000], loss: 0.1978
epoch [548/1000], loss: 0.1983
epoch [549/1000], loss: 0.2032
epoch [550/1000], loss: 0.1958
epoch [551/1000], loss: 0.1940
epoch [552/1000], loss: 0.1972
epoch [553/1000], loss: 0.1924
epoch [554/1000], loss: 0.1927
epoch [555/1000], loss: 0.1975
epoch [556/1000], loss: 0.1933
epoch [557/1000], loss: 0.1937
epoch [558/1000], loss: 0.1898
epoch [559/1000], loss: 0.1910
epoch [560/1000], loss: 0.1898
epoch [561/1000], loss: 0.1949
epoch [562/1000], loss: 0.1931
epoch [563/1000], loss: 0.1931
epoch [564/1000], loss: 0.1929
epoch [5

epoch [798/1000], loss: 0.1238
epoch [799/1000], loss: 0.1241
epoch [800/1000], loss: 0.1243
epoch [801/1000], loss: 0.1222
epoch [802/1000], loss: 0.1204
epoch [803/1000], loss: 0.1244
epoch [804/1000], loss: 0.1269
epoch [805/1000], loss: 0.1252
epoch [806/1000], loss: 0.1215
epoch [807/1000], loss: 0.1184
epoch [808/1000], loss: 0.1215
epoch [809/1000], loss: 0.1242
epoch [810/1000], loss: 0.1204
epoch [811/1000], loss: 0.1224
epoch [812/1000], loss: 0.1225
epoch [813/1000], loss: 0.1234
epoch [814/1000], loss: 0.1217
epoch [815/1000], loss: 0.1192
epoch [816/1000], loss: 0.1215
epoch [817/1000], loss: 0.1186
epoch [818/1000], loss: 0.1197
epoch [819/1000], loss: 0.1239
epoch [820/1000], loss: 0.1221
epoch [821/1000], loss: 0.1230
epoch [822/1000], loss: 0.1201
epoch [823/1000], loss: 0.1201
epoch [824/1000], loss: 0.1226
epoch [825/1000], loss: 0.1212
epoch [826/1000], loss: 0.1199
epoch [827/1000], loss: 0.1178
epoch [828/1000], loss: 0.1220
epoch [829/1000], loss: 0.1188
epoch [8