In [1]:
import numpy as np
import numpy.linalg as nplin
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tensorboardX
import torch.utils.data as data

from l1shrink import shrink
from util import *
from networks import *
# from dae import DAE
from partial_mnist import PartialMNIST
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class arguments():
    def __init__(self):
        self.datadir='/home/jehyuk/PycharmProjects/RobustDAE/'
        self.data='mnist'
        self.image_ch=1
        # Model arguments
        self.model='dae'
        self.image_size=28
        self.n_ch=64
        self.dims=[784,200,20]
        self.kernels=[4,4,4,4]
        self.strides=[1,1,1,1]
        self.paddings=[0,0,0,0]
        self.out_act_fn='tanh'
        self.act_fn='relu'
        self.use_fc=False
        self.embed_dim=20
        #Train arguments
        self.batch_size=64
        self.lr = 0.0001
        self.n_corrupt_rows=5
        self.n_corrupt_cols=5
        self.noise_method=None
        self.n_workers=10
        self.device_num=0
        self.multi_gpus=[0]
        self.inner_epochs=100
        self.outer_epochs=5
        self.log_dir='/home/jehyuk/PycharmProjects/RobustDAE/logs'
        self.save_dir='/home/jehyuk/PycharmProjects/RobustDAE/models'
        self.result_dir='/home/jehyuk/PycharmProjects/RobustDAE/results'
        self.mode='train'
        self.save=False
        self.load=False


In [3]:
args = arguments()

In [4]:
device = torch.device(f'cuda:{args.device_num}')

In [5]:
class DAE(nn.Module):
    def __init__(self, args, device):
        super(DAE, self).__init__()
        self.args = args
        self.device = device
        self.encoder = Encoder(args).to(device)
        self.decoder = Decoder(args).to(device)
        self.device = device
        self.loss_func = nn.MSELoss(reduction='none')
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.opt = optim.Adam(params=params, lr=args.lr, betas=(0.5, 0.999))
        
    def initialize_param(self):
        self.encoder.apply(initialize_weights)
        self.decoder.apply(initialize_weights)
    
    def fit(self, trn_loader):
        writer = tensorboardX.SummaryWriter(self.args.log_dir)
        for epoch in range(self.args.inner_epochs):
            self.encoder.train()
            self.decoder.train()
            trn_loss = self.partial_fit(trn_loader)
            if (epoch + 1) % 20 == 0:
                writer.add_scalar('trn_loss', trn_loss, global_step=epoch)
                print(f"In epoch {epoch + 1}, trn_loss = {trn_loss:.4f}")

    def partial_fit(self, trn_loader):
        for idx, (image, label) in enumerate(trn_loader):
            image, label = image.to(self.device), label.to(self.device)
            image_recon = self.reconstruct(image)
            trn_loss = torch.sum(self.loss_func(image_recon, image), dim=(0, 1, 2, 3)) / image.size()[0]
            self.opt.zero_grad()
            trn_loss.backward()
            self.opt.step()
            return trn_loss

    def get_embedding_vector(self, x):
        if x.is_cuda and len(self.args.multi_gpus) > 1:
            out = nn.parallel.data_parallel(self.encoder, x, device_ids=self.args.multi_gpus)
        else:
            out = self.encoder(x)
        return out

    def reconstruct(self, x):
        if x.is_cuda and len(self.args.multi_gpus) > 1:
            z = nn.parallel.data_parallel(self.encoder, x, device_ids=self.args.multi_gpus)
            x_recon = nn.parallel.data_parallel(self.decoder, z, device_ids=self.args.multi_gpus)
        else:
            z = self.encoder(x)
            x_recon = self.decoder(z)
        return x_recon

    def save_model(self, save_path=None):
        if save_path is None:
            save_path = self.args.save_dir
        torch.save(self.encoder.state_dict(), os.path.join(save_path, 'encoder.pkl'))
        torch.save(self.decoder.state_dict(), os.path.join(save_path, 'decoder.pkl'))
        print('Save model!')

    def load_model(self, save_path=None):
        if save_path is None:
            save_path = self.args.save_dir
        self.encoder.load_state_dict(torch.load(os.path.join(save_path, 'encoder.pkl')))
        self.decoder.load_state_dict(torch.load(os.path.join(save_path, 'decoder.pkl')))

In [6]:
trn_class_dict = dict()
trn_class_dict[0] = 2000
trn_class_dict[1] = 2000
trn_class_dict[2] = 2000
trn_class_dict[3] = 2000
trn_class_dict[4] = 2000
trn_class_dict[5] = 2000
trn_class_dict[6] = 2000
trn_class_dict[7] = 2000
trn_class_dict[8] = 2000
trn_class_dict[9] = 2000

In [7]:
class Noise(object):
    def __init__(self, corrupt_col_idx = [1,2,3,4,5], corrupt_row_idx = [x for x in range(1,6)], method='fixed'):
        self.corrupt_col_idx = corrupt_col_idx
        self.corrupt_row_idx = corrupt_row_idx
        self.method = method
        
    def __call__(self, sample):
        if self.method == 'fixed':
            for col in self.corrupt_col_idx:
                sample[:,:,col] = torch.rand(1)
        elif self.method == 'uniform':
            for col in self.corrupt_col_idx:
                for row in self.corrupt_row_idx:
                    sample[:,row,col] = torch.rand(1)
#         elif self.method == 'gaussian':
#             for col in self.corrupt_col_idx:
#                 for row in self.corrupt_row_idx:
#                     sample[:,row,col] = torch.randn(1)
        else:
            raise ValueError('Enter the proper noise method')
        
        return sample

In [8]:
class NewDataset(data.Dataset):
    def __init__(self, tensor_x, tensor_y, transform = None):
        self.tensor_x = tensor_x.view(-1, 1, 28, 28)
        self.tensor_y = tensor_y
        self.transform = transform
    
    def __len__(self):
        return len(self.tensor_x)
    
    def __getitem__(self, idx):
        return self.tensor_x[idx], self.tensor_y[idx]

In [9]:
trn_dset = PartialMNIST(root = '/home/jehyuk/PycharmProjects/RobustDAE/', sample_dict = trn_class_dict, train=True, download=True, 
                        transform = transforms.Compose([transforms.ToTensor(), Noise(method='uniform')]))
# trn_dset = PartialMNIST(root = '/home/jehyuk/PycharmProjects/RobustDAE/', sample_dict = trn_class_dict, train=True, download=True, 
#                         transform = transforms.Compose([transforms.ToTensor(), Noise(method='fixed'),
#                                                         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))

In [10]:
class RDAE(nn.Module):
    def __init__(self, args, device, lambda_=1.0, tol=1e-7):
        super(RDAE, self).__init__()
        self.args = args
        self.device = device
        self.lambda_ = lambda_
        self.tol = tol
        self.ae = DAE(args, device)
    
    def make_loader(self, dset, rdae=True):
        if rdae:
            loader = torch.utils.data.DataLoader(dset, batch_size = self.args.batch_size, 
                                                 shuffle=False, num_workers = 0, drop_last = False)
        else:
            loader = torch.utils.data.DataLoader(dset, batch_size = self.args.batch_size, 
                                                 shuffle=True, num_workers = self.args.n_workers, drop_last = True)
        return loader
    
    def get_flatX(self, loader):
        image_list, label_list = list(), list()
        for _, (image, label) in enumerate(loader):
            image, label = image.to(self.device), label.to(self.device)
            image_list.append(image)
            label_list.append(label)
        flat_images = torch.cat([x.flatten(start_dim=1) for x in image_list], dim=0)
        labels = torch.cat([x for x in label_list])

        return flat_images, labels
    
    def get_flat_recon_X(self, loader):
        recon_list, label_list = list(), list()
        for _, (image, label) in enumerate(loader):
            image, label = image.to(self.device), label.to(self.device)
            recon = self.ae.reconstruct(image)
            recon_list.append(recon)
            label_list.append(label)
        flat_images = torch.cat([x.flatten(start_dim=1) for x in recon_list], dim=0)
        labels = torch.cat([x for x in label_list])
        
        return flat_images, labels
    
    def fit(self, trn_dset, verbose=False):
        # Make data loader
        rdae_loader = self.make_loader(trn_dset, rdae=True)
        # Make flat data
        X, Y = self.get_flatX(rdae_loader)
        # Make L and S
        L = torch.zeros((X.size()[0], X.size()[1])).to(self.device)
        S = torch.zeros((X.size()[0], X.size()[1])).to(self.device)
        mu = (X.size()[0] * X.size()[1]) / torch.norm(X, 1)
        print(f'shrink param: {self.lambda_ / mu}')
        LS0 = L + S
        XFnorm = torch.norm(X, 'fro')
        for i in range(self.args.outer_epochs):
            L = X - S
            trn_dset = NewDataset(L.cpu(), Y.cpu())
            trn_loader = self.make_loader(trn_dset, rdae=False)
            self.ae.fit(trn_loader)
            L = self.ae.reconstruct(L)
            L = L.flatten(start_dim = 1)
            print(f"X.size(): {X.size()}")
            S = X-L
            S = shrink(self.lambda_ / mu, S, device)
#             S = shrink(self.lambda_/mu, S.view(X.size()[0], X.size()[1])).view(X.size()[0], X.size()[1])
            
            c1 = torch.norm(X-L-S, 'fro') / XFnorm
#             c2 = np.min([mu, np.sqrt(mu)]) * torch.norm(LS0 - L-S) / XFnorm
            c2 = mu * torch.norm(LS0 - L-S) / XFnorm
            
            self.L, self.S = L, S
            if verbose:
                print(f"c1: {c1}, c2: {c2}")
            if c1 < self.tol and c2 < self.tol:
                print("Early break")
                break
            LS0 = L + S
            
        return self.L, self.S
    
    def transform(self, x):
        L = x - self.S
        return self.ae.get_embedding_vector(L)
        
    def reconstruct(self, x):
        return self.ae.reconstruct(x)

In [11]:
rdae = RDAE(args, device)

In [12]:
rdae.fit(trn_dset)

shrink param: 0.14755019545555115


Traceback (most recent call last):
  File "/home/jehyuk/anaconda3/lib/python3.6/multiprocessing/util.py", line 262, in _run_finalizers
    finalizer()
  File "/home/jehyuk/anaconda3/lib/python3.6/multiprocessing/util.py", line 186, in __call__
    res = self._callback(*self._args, **self._kwargs)
  File "/home/jehyuk/anaconda3/lib/python3.6/shutil.py", line 484, in rmtree
    onerror(os.rmdir, path, sys.exc_info())
  File "/home/jehyuk/anaconda3/lib/python3.6/shutil.py", line 482, in rmtree
    os.rmdir(path)
OSError: [Errno 39] Directory not empty: '/tmp/pymp-llkj9cjl'


In epoch 20, trn_loss = 91.7201
In epoch 40, trn_loss = 69.8895
In epoch 60, trn_loss = 54.9350
In epoch 80, trn_loss = 52.1752
In epoch 100, trn_loss = 50.2719
X.size(): torch.Size([20000, 784])


TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.