Importing Modules

In [1]:
import torch
import torchvision
import torch.nn as nn
import matplotlib
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.optim as optim
import os
import time
import numpy as np
import argparse
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
matplotlib.style.use('ggplot')
import sys, importlib as impL
sys.path.insert(1,'/home/wsubuntu/GitHub/keyhandshapediscovery')
import helperFuncs as funcH
import pandas as pd

In [2]:
def calc_bottleneck_acc(bottleneck_vec, lab_vec):
    """Use the strongest bottleneck unit as a cluster id and score it against labels.

    The predicted cluster of each sample is the argmax over the unit axis
    (assumes ``bottleneck_vec`` is (samples, units) — TODO confirm). Those
    predictions go through ``funcH.get_cluster_centroids`` and
    ``funcH.countPredictionsForConfusionMat``; accuracy is the share of the
    returned confusion matrix that lies on its diagonal (presumably the helper
    aligns clusters to classes first — verify in helperFuncs).

    Args:
        bottleneck_vec: array of bottleneck activations.
        lab_vec: ground-truth labels, one per sample.

    Returns:
        tuple ``(acc, bmx, bmn)``: accuracy in percent, and the maximum and
        minimum value found anywhere in ``bottleneck_vec``.
    """
    # one predicted cluster index per sample
    cluster_ids = np.argmax(bottleneck_vec.T, axis=0).T.squeeze()
    centroids = funcH.get_cluster_centroids(
        bottleneck_vec, cluster_ids, kluster_centers=None, verbose=0
    )
    conf_mat, _, _, _, _ = funcH.countPredictionsForConfusionMat(
        lab_vec, cluster_ids, centroid_info_pdf=centroids, labelNames=None
    )
    total = np.sum(np.sum(conf_mat))
    correct = np.sum(np.diag(conf_mat))
    accuracy = 100 * correct / total
    return accuracy, np.max(bottleneck_vec), np.min(bottleneck_vec)


funcH.setPandasDisplayOpts()

Constructing the Argument Parsers (disabled in the notebook; hyperparameters are set directly below)

In [3]:
# Hyperparameters. The argparse interface is kept (commented out) for reference;
# inside the notebook the values are assigned directly instead of parsed from argv.
#ap = argparse.ArgumentParser()
#ap.add_argument('-e', '--epochs', type=int, default=10, help='number of epochs to train our network for')
#ap.add_argument('-l', '--reg_param', type=float, default=0.001, help='regularization parameter `lambda`')
#ap.add_argument('-sc', '--add_sparse', type=str, default='yes', help='whether to add sparsity contraint or not')
#args = vars(ap.parse_args())
epochs = 100  # args['epochs']
reg_param = 0.001  # args['reg_param']
add_sparsity = 'yes'  # args['add_sparse']
learning_rate = 1e-4
batch_size = 32

# Fix the RNG seeds before the model is built and the loaders are iterated, so
# weight initialisation and shuffling are reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Add sparsity regularization: {add_sparsity}")

Add sparsity regularization: yes


Here I change the data loader to suit my needs.

In [4]:
# get the computation device
def get_device():
    """Return the device string to run on: the first CUDA GPU if present, else the CPU."""
    if torch.cuda.is_available():
        return 'cuda:0'
    return 'cpu'


device = get_device()
print(device)

cpu


In [5]:
# image transformations
# ToTensor turns each PIL image into a float tensor (scaled to [0, 1] for
# uint8 images, per the torchvision documentation).
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Data and experiment roots default to the original machine-specific paths.
# Set the SAE_DATA_DIR / SAE_EXPERIMENT_DIR environment variables to run this
# notebook on another machine without editing the cell.
FOLDERS = {
    "data": os.environ.get("SAE_DATA_DIR", '/media/wsubuntu/SSD_Data/DataPath'),
    "experiment": os.environ.get(
        "SAE_EXPERIMENT_DIR",
        '/media/wsubuntu/SSD_Data/vaesae_experiments/sparse_torch_ae_ws_002',
    ),
}
FOLDERS["model_save"] = os.path.join(FOLDERS["experiment"], "model")
FOLDERS["decoder_image_path_tr"] = os.path.join(FOLDERS["experiment"], "output_images_tr")
FOLDERS["decoder_image_path_va"] = os.path.join(FOLDERS["experiment"], "output_images_va")
funcH.createDirIfNotExist(FOLDERS["model_save"])
funcH.createDirIfNotExist(FOLDERS["decoder_image_path_tr"])
funcH.createDirIfNotExist(FOLDERS["decoder_image_path_va"])

trainset = datasets.FashionMNIST(
    root=FOLDERS["data"],
    train=True,
    download=True,
    transform=transform
)
testset = datasets.FashionMNIST(
    root=FOLDERS["data"],
    train=False,
    download=True,
    transform=transform
)

# trainloader: reshuffled every epoch
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True
)
# testloader: fixed order so validation outputs are comparable across epochs
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False
)

In [6]:
# define the autoencoder model
class SparseAutoencoder(nn.Module):
    """Fully connected autoencoder 784 -> 256 -> 128 -> 64 -> 32 -> 16 -> ... -> 784.

    Every layer, including the 16-unit bottleneck and the 784-unit output, is
    followed by a ReLU. Layers are registered in the order enc1..enc5 then
    dec1..dec5; ``model.children()`` yields them in that order, which the
    layer-by-layer sparsity penalty relies on.

    Args:
        loss_type: name of the sparsity penalty ('l1', 'l2' or 'kl'), stored as
            ``self.loss_type`` so the training loop can read it.
    """

    def __init__(self, loss_type):
        super().__init__()

        # encoder: 784 -> 256 -> 128 -> 64 -> 32 -> 16
        self.enc1 = nn.Linear(784, 256)
        self.enc2 = nn.Linear(256, 128)
        self.enc3 = nn.Linear(128, 64)
        self.enc4 = nn.Linear(64, 32)
        self.enc5 = nn.Linear(32, 16)

        # decoder: mirror image, 16 -> 32 -> 64 -> 128 -> 256 -> 784
        self.dec1 = nn.Linear(16, 32)
        self.dec2 = nn.Linear(32, 64)
        self.dec3 = nn.Linear(64, 128)
        self.dec4 = nn.Linear(128, 256)
        self.dec5 = nn.Linear(256, 784)

        self.loss_type = loss_type
        # Recorded only; the parameters are actually placed via .to(device).
        self.device = get_device()

    def forward(self, x):
        """Return ``(reconstruction, bottleneck)`` for a batch of flattened images."""
        hidden = x
        for layer in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5):
            hidden = F.relu(layer(hidden))
        bottleneck = hidden
        for layer in (self.dec1, self.dec2, self.dec3, self.dec4, self.dec5):
            hidden = F.relu(layer(hidden))
        return hidden, bottleneck
model = SparseAutoencoder(loss_type='l1').to(device)

In [7]:
# reconstruction loss: mean squared error between decoder output and input
criterion = nn.MSELoss()
# Adam over every parameter of the autoencoder
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [8]:
# Layers of the model in registration order (enc1..enc5, dec1..dec5).
# sparse_loss walks this list to recompute every layer's activations.
model_children = list(model.children())
# A plain loop instead of a list comprehension used for its side effect,
# which also displayed a useless list of None values.
for layer in model_children:
    print(layer)

Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=128, bias=True)
Linear(in_features=128, out_features=64, bias=True)
Linear(in_features=64, out_features=32, bias=True)
Linear(in_features=32, out_features=16, bias=True)
Linear(in_features=16, out_features=32, bias=True)
Linear(in_features=32, out_features=64, bias=True)
Linear(in_features=64, out_features=128, bias=True)
Linear(in_features=128, out_features=256, bias=True)
Linear(in_features=256, out_features=784, bias=True)


[None, None, None, None, None, None, None, None, None, None]

In [9]:
def loss_l1(bottleneck):
    """L1 sparsity penalty: the mean absolute value of the activations."""
    return bottleneck.abs().mean()

def loss_l2(bottleneck):
    """Root-mean-square of the activations: sqrt(mean(bottleneck ** 2)).

    Computed as ``||bottleneck||_2 / sqrt(numel)``, which is mathematically the
    same value. ``sqrt`` has an infinite derivative at 0, so writing it as
    ``mean(x ** 2).sqrt()`` gives NaN gradients whenever the whole activation is
    zero (easy to hit after a ReLU). The norm's backward returns a zero gradient
    at the origin instead. This form also needs no device-specific constant.
    """
    return torch.linalg.vector_norm(bottleneck) / bottleneck.numel() ** 0.5

def kl_divergence(bottleneck, rho=0.05, dim=1):
    """Bernoulli KL sparsity penalty: sum_j KL(rho || rho_hat_j).

    ``rho_hat`` is the mean of ``sigmoid(bottleneck)`` along ``dim``. The
    sigmoid maps activations into (0, 1) so they can be read as Bernoulli
    probabilities.

    Args:
        bottleneck: activation tensor; assumed (batch, units) — TODO confirm.
        rho: target mean activation (the desired sparsity level).
        dim: axis averaged to obtain ``rho_hat``. The default 1 (per-sample mean
            over units) keeps the notebook's previous averaging. The textbook
            sparse-autoencoder penalty averages each unit over the batch (``dim=0``).

    Returns:
        Scalar tensor with the summed divergence.
    """
    eps = 1e-7
    rho_hat = torch.mean(torch.sigmoid(bottleneck), dim)
    # sigmoid saturates to exactly 0/1 in float32; clamp so both logs stay finite
    rho_hat = rho_hat.clamp(eps, 1 - eps)
    # Explicit formula rather than F.kl_div: that function expects
    # log-probabilities as its input and only evaluates target*(log(target) - input),
    # so feeding it probabilities does not give this divergence.
    return torch.sum(rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat)))

In [10]:
# define the sparse loss function
def sparse_loss(autoencoder, images, print_info, loss_type):
    """Sum a sparsity penalty over the activations of every layer of ``autoencoder``.

    The images are pushed through each child layer in registration order,
    with a ReLU after every layer (as in ``SparseAutoencoder.forward``), and the
    selected penalty of each layer's activation is accumulated.

    Args:
        autoencoder: module whose ``children()`` are applied in sequence.
        images: flattened input batch fed to the first layer.
        print_info: if True, print the running penalty after each layer.
        loss_type: 'l1', 'l2' or 'kl' (selects loss_l1 / loss_l2 / kl_divergence).

    Returns:
        The summed penalty (a tensor when the module has at least one layer).

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    penalties = {'l1': loss_l1, 'l2': loss_l2, 'kl': kl_divergence}
    if loss_type not in penalties:
        # previously an unknown name silently produced a 0 penalty
        raise ValueError(f"unknown loss_type {loss_type!r}; expected one of {sorted(penalties)}")
    penalty_fn = penalties[loss_type]

    loss = 0
    values = images
    # Walk the module that was passed in rather than the global model_children,
    # so the penalty always matches the network being trained.
    for layer in autoencoder.children():
        values = F.relu(layer(values))
        loss += penalty_fn(values)
        if print_info:
            print(loss_type, loss)
    return loss

In [11]:
def save_decoded_image(img, name):
    """Reshape a batch of flat 784-value rows into 1x28x28 images and write them to ``name`` via save_image."""
    batch = img.view(img.size(0), 1, 28, 28)
    save_image(batch, name)

# define the training function
def fit(model, dataloader, epoch, print_losses_fit):
    """Train ``model`` for one epoch and summarise the epoch as a one-row DataFrame.

    Relies on the notebook-level ``optimizer``, ``criterion``, ``device``,
    ``add_sparsity``, ``reg_param`` and ``FOLDERS``. When ``add_sparsity == 'yes'``
    the optimised loss is ``mse_loss + reg_param * sparse_loss(...)``.

    Args:
        model: the autoencoder; ``model(img)`` must return ``(outputs, bottleneck)``.
        dataloader: yields ``(images, labels)`` batches.
        epoch: epoch index, used for logging and in the saved image's file name.
        print_losses_fit: if True, print the per-layer sparsity terms for the
            first batch of the epoch only.

    Returns:
        DataFrame with columns acc, bmx, bmn, mse, spr, run: bottleneck
        clustering accuracy, bottleneck max / min, summed per-batch MSE, summed
        sparsity penalty and summed total loss.
    """
    print('TrEpoch({:03d}) - '.format(epoch), end='')
    model.train()
    running_loss = 0.0

    lab_vec = []
    bottleneck_vec = []
    sparsity_loss_sum = 0
    mse_sum = 0

    for img, lb in dataloader:
        lab_vec.append(lb)

        img = img.to(device)
        img = img.view(img.size(0), -1)
        optimizer.zero_grad()
        outputs, bottleneck = model(img)
        # Detach and move to CPU per batch: storing the live tensor keeps a
        # reference to its autograd node (and its device memory) for the whole epoch.
        bottleneck_vec.append(bottleneck.detach().cpu())
        mse_loss = criterion(outputs, img)
        mse_sum += mse_loss.item()
        if add_sparsity == 'yes':
            sp_loss = sparse_loss(model, img, print_losses_fit, model.loss_type)
            sparsity_loss_sum += sp_loss.item()
            if print_losses_fit:
                print("sp_loss:", sparsity_loss_sum)
            # add the sparsity penalty
            loss = mse_loss + reg_param * sp_loss
        else:
            loss = mse_loss
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # only the first batch of the epoch prints the detailed breakdown
        print_losses_fit = False

    # labels never leave the CPU and bottlenecks were moved there above
    lab_vec = torch.cat(lab_vec).numpy()
    bottleneck_vec = torch.cat(bottleneck_vec).numpy()
    acc, bmx, bmn = calc_bottleneck_acc(bottleneck_vec, lab_vec)

    result_df = pd.DataFrame(
        np.array([[acc, bmx, bmn, mse_sum, sparsity_loss_sum, running_loss]]),
        columns=['acc', 'bmx', 'bmn', 'mse', 'spr', 'run'],
    )
    print("\n", result_df)
    # every 2 epochs, save the reconstructions of the epoch's last batch
    if epoch % 2 == 0:
        difn = os.path.join(FOLDERS["decoder_image_path_tr"], "train" + str(epoch).zfill(3) + ".png")
        save_decoded_image(outputs.cpu().data, difn)
    return result_df

In [12]:
# define the validation function
def validate(model, dataloader, epoch, print_losses_fit):
    """Evaluate reconstruction loss and bottleneck accuracy on ``dataloader`` without training.

    Relies on the notebook-level ``criterion``, ``device`` and ``FOLDERS``.

    Args:
        model: the autoencoder; ``model(img)`` must return ``(outputs, bottleneck)``.
        dataloader: yields ``(images, labels)`` batches.
        epoch: epoch index, used for logging and in the saved image's file name.
        print_losses_fit: not used; kept so the signature mirrors ``fit``.

    Returns:
        DataFrame with columns acc, bmx, bmn, run: bottleneck clustering
        accuracy, bottleneck max / min and the summed per-batch MSE.
    """
    print('ValEpoch({:03d}) - '.format(epoch), end='')
    model.eval()
    running_loss = 0.0
    lab_vec = []
    bottleneck_vec = []
    with torch.no_grad():
        for img, lb in dataloader:
            lab_vec.append(lb)
            img = img.to(device)
            img = img.view(img.size(0), -1)
            outputs, bottleneck = model(img)
            bottleneck_vec.append(bottleneck.cpu())
            loss = criterion(outputs, img)
            running_loss += loss.item()

    lab_vec = torch.cat(lab_vec).numpy()
    bottleneck_vec = torch.cat(bottleneck_vec).numpy()
    acc, bmx, bmn = calc_bottleneck_acc(bottleneck_vec, lab_vec)

    result_df = pd.DataFrame(np.array([[acc, bmx, bmn, running_loss]]), columns=['acc', 'bmx', 'bmn', 'run'])
    print("\n", result_df)

    # every 2 epochs (same cadence as fit), save the reconstructions of the last batch
    if epoch % 2 == 0:
        difn = os.path.join(FOLDERS["decoder_image_path_va"], "reconstruction" + str(epoch).zfill(3) + ".png")
        save_decoded_image(outputs.cpu().data, difn)
    return result_df

In [13]:
# train and validate the autoencoder neural network
start = time.time()
print_losses_fit = True

# (not filled below; kept in case other cells reference these names)
train_loss = []
trn_spars_loss = []
trn_bot_acc = []
val_loss = []
val_bot_acc = []

# Empty tables so the names exist even when epochs == 0.
result_df_tr_all = pd.DataFrame(columns=['acc','bmx','bmn','mse','spr','run'])
result_df_va_all = pd.DataFrame(columns=['acc','bmx','bmn','run'])
# DataFrame.append was removed in pandas 2.0: collect the per-epoch one-row
# frames in lists and rebuild the tables with pd.concat instead.
tr_frames = []
va_frames = []

print("stae_ws05_02 - l1 - loss = mse_loss **+** reg_param * sp_loss")

for epoch in range(epochs):
    print(f"*****\n Epoch {epoch} of {epochs}")
    result_df_tr = fit(model, trainloader, epoch, print_losses_fit)
    result_df_va = validate(model, testloader, epoch, print_losses_fit)
    # detailed loss printing is enabled for the epoch after each multiple of 5
    print_losses_fit = epoch % 5 == 0 and epoch > 0
    tr_frames.append(result_df_tr)
    va_frames.append(result_df_va)
    # rebuilt every epoch so the tables stay current if the run is interrupted
    result_df_tr_all = pd.concat(tr_frames, ignore_index=True)
    result_df_va_all = pd.concat(va_frames, ignore_index=True)

end = time.time()

print(f"{(end-start)/60:.3} minutes")

# save the trained model
mofn = os.path.join(FOLDERS["model_save"], "sparse_ae_" + str(epoch).zfill(3) + ".pth")
torch.save(model.state_dict(), mofn)

stae_ws05_02 - l1 - loss = mse_loss **+** reg_param * sp_loss
*****
 Epoch 0 of 100
TrEpoch(000) - l1 tensor(0.1170, grad_fn=<AddBackward0>)
l1 tensor(0.1725, grad_fn=<AddBackward0>)
l1 tensor(0.1969, grad_fn=<AddBackward0>)
l1 tensor(0.2304, grad_fn=<AddBackward0>)
l1 tensor(0.2745, grad_fn=<AddBackward0>)
l1 tensor(0.3567, grad_fn=<AddBackward0>)
l1 tensor(0.4071, grad_fn=<AddBackward0>)
l1 tensor(0.4479, grad_fn=<AddBackward0>)
l1 tensor(0.4787, grad_fn=<AddBackward0>)
l1 tensor(0.4975, grad_fn=<AddBackward0>)
sp_loss: 0.4974929690361023

      acc    bmx  bmn      mse       spr     run
0  10.34  3.569  0.0  198.533  2457.281  200.99
ValEpoch(000) - 
     acc    bmx  bmn     run
0  8.08  3.094  0.0  21.638
*****
 Epoch 1 of 100
TrEpoch(001) - 
       acc    bmx  bmn      mse       spr      run
0  18.938  3.564  0.0  114.773  2574.492  117.347
ValEpoch(001) - 
      acc    bmx  bmn     run
0  19.04  2.582  0.0  17.904
*****
 Epoch 2 of 100
TrEpoch(002) - 
       acc    bmx  bmn     m


       acc    bmx  bmn     mse       spr     run
0  26.633  1.113  0.0  50.537  1873.583  52.411
ValEpoch(026) - 
      acc    bmx  bmn    run
0  21.72  1.088  0.0  8.457
*****
 Epoch 27 of 100
TrEpoch(027) - 
       acc    bmx  bmn     mse       spr     run
0  26.742  1.082  0.0  50.021  1857.908  51.879
ValEpoch(027) - 
      acc    bmx  bmn    run
0  22.58  1.083  0.0  8.465
*****
 Epoch 28 of 100
TrEpoch(028) - 
       acc    bmx  bmn     mse       spr     run
0  26.663  1.097  0.0  49.883  1842.386  51.726
ValEpoch(028) - 
      acc    bmx  bmn   run
0  22.69  1.073  0.0  8.44
*****
 Epoch 29 of 100
TrEpoch(029) - 
       acc   bmx  bmn     mse       spr     run
0  21.932  1.09  0.0  49.014  1830.271  50.844
ValEpoch(029) - 
      acc   bmx  bmn    run
0  24.94  1.02  0.0  8.228
*****
 Epoch 30 of 100
TrEpoch(030) - 
       acc    bmx  bmn     mse       spr     run
0  16.448  1.033  0.0  48.678  1816.222  50.494
ValEpoch(030) - 
      acc    bmx  bmn    run
0  22.64  0.977  0.0  

TrEpoch(055) - 
       acc    bmx  bmn     mse       spr     run
0  24.675  0.827  0.0  41.169  1655.447  42.825
ValEpoch(055) - 
      acc    bmx  bmn    run
0  23.75  0.773  0.0  6.987
*****
 Epoch 56 of 100
TrEpoch(056) - l1 tensor(0.0400, grad_fn=<AddBackward0>)
l1 tensor(0.0680, grad_fn=<AddBackward0>)
l1 tensor(0.0927, grad_fn=<AddBackward0>)
l1 tensor(0.1320, grad_fn=<AddBackward0>)
l1 tensor(0.1864, grad_fn=<AddBackward0>)
l1 tensor(0.2494, grad_fn=<AddBackward0>)
l1 tensor(0.3292, grad_fn=<AddBackward0>)
l1 tensor(0.4562, grad_fn=<AddBackward0>)
l1 tensor(0.6040, grad_fn=<AddBackward0>)
l1 tensor(0.8584, grad_fn=<AddBackward0>)
sp_loss: 0.8584169745445251

       acc    bmx  bmn     mse       spr    run
0  25.005  0.855  0.0  41.038  1652.864  42.69
ValEpoch(056) - 
      acc    bmx  bmn    run
0  24.44  0.764  0.0  6.984
*****
 Epoch 57 of 100
TrEpoch(057) - 
       acc   bmx  bmn     mse       spr     run
0  25.048  0.82  0.0  40.894  1649.168  42.543
ValEpoch(057) - 
      

ValEpoch(081) - 
      acc    bmx  bmn    run
0  30.35  0.654  0.0  6.585
*****
 Epoch 82 of 100
TrEpoch(082) - 
       acc    bmx  bmn     mse       spr     run
0  20.583  0.679  0.0  38.495  1572.448  40.068
ValEpoch(082) - 
      acc    bmx  bmn    run
0  28.19  0.653  0.0  6.556
*****
 Epoch 83 of 100
TrEpoch(083) - 
       acc    bmx  bmn   mse       spr     run
0  30.773  0.691  0.0  38.4  1571.477  39.972
ValEpoch(083) - 
      acc    bmx  bmn    run
0  25.92  0.643  0.0  6.563
*****
 Epoch 84 of 100
TrEpoch(084) - 
       acc    bmx  bmn     mse      spr    run
0  24.827  0.673  0.0  38.299  1571.34  39.87
ValEpoch(084) - 
     acc    bmx  bmn    run
0  22.3  0.659  0.0  6.534
*****
 Epoch 85 of 100
TrEpoch(085) - 
       acc    bmx  bmn     mse       spr     run
0  30.987  0.684  0.0  38.189  1569.718  39.758
ValEpoch(085) - 
      acc    bmx  bmn    run
0  25.66  0.636  0.0  6.509
*****
 Epoch 86 of 100
TrEpoch(086) - l1 tensor(0.0425, grad_fn=<AddBackward0>)
l1 tensor(0.0728

In [14]:
# per-epoch training summary, shown with the notebook's rich DataFrame display
result_df_tr_all

       acc    bmx  bmn      mse       spr      run
0   10.340  3.569  0.0  198.533  2457.281  200.990
1   18.938  3.564  0.0  114.773  2574.492  117.347
2   20.305  2.683  0.0  102.260  2466.176  104.726
3   19.670  2.499  0.0   92.352  2412.836   94.764
4   23.177  2.160  0.0   75.472  2574.698   78.047
5   25.953  2.055  0.0   67.901  2545.863   70.447
6   28.560  2.031  0.0   63.970  2449.728   66.420
7   28.302  1.963  0.0   62.313  2371.121   64.684
8   26.335  1.835  0.0   61.034  2327.248   63.361
9   28.450  1.718  0.0   60.055  2286.423   62.342
10  29.287  1.616  0.0   58.706  2251.900   60.958
11  28.952  1.598  0.0   57.515  2219.898   59.735
12  20.335  1.563  0.0   56.886  2183.527   59.069
13  23.167  1.484  0.0   56.333  2150.861   58.484
14  24.813  1.454  0.0   55.731  2128.642   57.859
15  25.542  1.461  0.0   55.076  2110.468   57.187
16  27.237  1.396  0.0   54.503  2084.872   56.588
17  27.888  1.338  0.0   53.495  2062.080   55.557
18  25.517  1.304  0.0   52.495

In [15]:
# per-epoch validation summary, shown with the notebook's rich DataFrame display
result_df_va_all

      acc    bmx  bmn     run
0    8.08  3.094  0.0  21.638
1   19.04  2.582  0.0  17.904
2   19.65  2.818  0.0  16.543
3   23.52  2.340  0.0  13.886
4   24.35  2.030  0.0  11.809
5   23.21  1.985  0.0  10.947
6   24.62  1.791  0.0  10.574
7   24.75  1.798  0.0  10.334
8   29.59  1.571  0.0  10.156
9   27.54  1.593  0.0  10.000
10  27.37  1.561  0.0   9.736
11  23.20  1.488  0.0   9.592
12  25.00  1.477  0.0   9.525
13  24.12  1.365  0.0   9.442
14  21.17  1.504  0.0   9.335
15  25.20  1.432  0.0   9.209
16  21.03  1.408  0.0   9.147
17  19.41  1.360  0.0   8.950
18  26.42  1.317  0.0   8.795
19  26.07  1.239  0.0   8.764
20  26.41  1.302  0.0   8.747
21  21.40  1.233  0.0   8.692
22  23.06  1.198  0.0   8.682
23  21.58  1.143  0.0   8.640
24  19.30  1.164  0.0   8.620
25  21.19  1.103  0.0   8.586
26  21.72  1.088  0.0   8.457
27  22.58  1.083  0.0   8.465
28  22.69  1.073  0.0   8.440
29  24.94  1.020  0.0   8.228
30  22.64  0.977  0.0   8.224
31  23.69  0.983  0.0   8.130
32  24.48 