In [3]:
import numpy as np
import numpy.linalg as nplin
import os
import torch
import torch.nn as nn
import torch.optim as optim
import tensorboardX
import torch.utils.data as data

from shrink import l1shrink
from util import *
from networks import *
# from dae import DAE
from dset import PartialMNIST
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Scratch tensor of ones: 4 rows of 784 (= 28*28) values.
a = torch.ones(4, 784)

In [8]:
# Column-wise L2 norm has one entry per column of `a`.
torch.norm(a, dim=0).shape[0]

784

In [10]:
# Scratch check of Python's substring membership operator.
aa, b = 'abcd', 'a'
b in aa

True

In [None]:
class arguments():
    """Plain attribute bag holding every notebook setting (stands in for argparse).

    Args:
        root: project root directory. ``datadir``, ``log_dir``, ``save_dir`` and
            ``result_dir`` are all derived from it, so relocating the project only
            requires changing this one value. The default reproduces the
            original author's checkout location.
    """

    def __init__(self, root='/home/jehyuk/PycharmProjects/RobustDAE'):
        # Data arguments.
        # os.path.join(root, '') keeps a trailing separator, matching the original value.
        self.datadir = os.path.join(root, '')
        self.data = 'mnist'
        self.image_ch = 1
        # Model arguments
        self.model = 'dae'
        self.image_size = 28
        self.n_ch = 64
        self.dims = [784, 200, 20]
        # Per-layer conv geometry; CDAE walks these to size its decoder.
        self.kernels = [4, 4, 4, 4]
        self.strides = [1, 1, 1, 1]
        self.paddings = [0, 0, 0, 0]
        self.out_act_fn = 'sigmoid'
        self.act_fn = 'relu'
        self.use_fc = False
        self.embed_dim = 20
        self.bias = True
        # Train arguments
        self.batch_size = 64
        self.lr = 0.0001
        self.n_corrupt_rows = 5
        self.n_corrupt_cols = 5
        self.noise_method = None
        self.n_workers = 10
        self.device_num = 0
        # More than one id here switches the AEs to nn.parallel.data_parallel.
        self.multi_gpus = [0]
        # inner_epochs: AE training epochs per outer step; outer_epochs: L/S alternations.
        self.inner_epochs = 2000
        self.outer_epochs = 20
        self.log_dir = os.path.join(root, 'logs')
        self.save_dir = os.path.join(root, 'models')
        self.result_dir = os.path.join(root, 'results')
        self.mode = 'train'
        self.save = False
        self.load = False


In [None]:
# Instantiate the notebook-wide configuration (paths plus model/training hyper-parameters).
args = arguments()

In [None]:
# Use the configured GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs end-to-end (slowly) on machines without a GPU.
device = torch.device(f'cuda:{args.device_num}' if torch.cuda.is_available() else 'cpu')

In [None]:
class DAE(nn.Module):
    """Autoencoder built from the project's ``Encoder`` / ``Decoder`` networks.

    Trained by minimising the per-sample sum of squared reconstruction errors
    with Adam (``lr=args.lr``). The network architecture itself is configured by
    ``args`` inside ``networks`` and is not visible here.

    Args:
        args: config object (reads ``lr``, ``inner_epochs``, ``multi_gpus``,
            ``log_dir``, ``save_dir`` plus whatever Encoder/Decoder read).
        device: torch device the encoder/decoder are moved to.
    """

    def __init__(self, args, device):
        super(DAE, self).__init__()
        self.args = args
        self.device = device
        self.encoder = Encoder(args).to(device)
        self.decoder = Decoder(args).to(device)
        # Element-wise loss; the reduction is done explicitly in partial_fit.
        self.loss_func = nn.MSELoss(reduction='none')
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        self.opt = optim.Adam(params=params, lr=args.lr, betas=(0.5, 0.999))

    def initialize_param(self):
        """Apply the project's ``initialize_weights`` to every encoder/decoder submodule."""
        self.encoder.apply(initialize_weights)
        self.decoder.apply(initialize_weights)

    def fit(self, trn_loader):
        """Train for ``args.inner_epochs`` full passes over ``trn_loader``.

        Every 20 epochs the mean training loss is written to TensorBoard
        (``args.log_dir``) and printed.
        """
        writer = tensorboardX.SummaryWriter(self.args.log_dir)
        try:
            for epoch in range(self.args.inner_epochs):
                self.encoder.train()
                self.decoder.train()
                trn_loss = self.partial_fit(trn_loader)
                if (epoch + 1) % 20 == 0:
                    writer.add_scalar('trn_loss', trn_loss, global_step=epoch)
                    print(f"In epoch {epoch + 1}, trn_loss = {trn_loss:.4f}")
        finally:
            # Flush pending events and release the event file even if training fails.
            writer.close()

    def partial_fit(self, trn_loader):
        """Run ONE epoch: an optimiser step for every ``(image, label)`` batch.

        Returns:
            float: per-sample sum of squared errors, averaged over every sample
            seen this epoch; ``nan`` if the loader yields no batches.
        """
        total_loss, n_samples = 0.0, 0
        for image, _ in trn_loader:
            image = image.to(self.device)
            image_recon = self.reconstruct(image)
            batch_size = image.size(0)
            # Sum over all elements (any rank), averaged over the batch.
            trn_loss = self.loss_func(image_recon, image).sum() / batch_size
            self.opt.zero_grad()
            trn_loss.backward()
            self.opt.step()
            total_loss += trn_loss.item() * batch_size
            n_samples += batch_size
        return total_loss / n_samples if n_samples else float('nan')

    def get_embedding_vector(self, x):
        """Encode ``x``; spread over ``args.multi_gpus`` when x is on CUDA and >1 GPU is listed."""
        if x.is_cuda and len(self.args.multi_gpus) > 1:
            out = nn.parallel.data_parallel(self.encoder, x, device_ids=self.args.multi_gpus)
        else:
            out = self.encoder(x)
        return out

    def reconstruct(self, x):
        """Return ``decoder(encoder(x))``, data-parallel under the same condition as above."""
        if x.is_cuda and len(self.args.multi_gpus) > 1:
            z = nn.parallel.data_parallel(self.encoder, x, device_ids=self.args.multi_gpus)
            x_recon = nn.parallel.data_parallel(self.decoder, z, device_ids=self.args.multi_gpus)
        else:
            z = self.encoder(x)
            x_recon = self.decoder(z)
        return x_recon

    def save_model(self, save_path=None):
        """Save encoder/decoder state dicts as encoder.pkl / decoder.pkl (default dir: args.save_dir)."""
        if save_path is None:
            save_path = self.args.save_dir
        torch.save(self.encoder.state_dict(), os.path.join(save_path, 'encoder.pkl'))
        torch.save(self.decoder.state_dict(), os.path.join(save_path, 'decoder.pkl'))
        print('Save model!')

    def load_model(self, save_path=None):
        """Load state dicts written by ``save_model``.

        ``map_location`` lets checkpoints saved on another GPU (or a GPU at all)
        load onto ``self.device``. torch.load unpickles: only load trusted files.
        """
        if save_path is None:
            save_path = self.args.save_dir
        self.encoder.load_state_dict(
            torch.load(os.path.join(save_path, 'encoder.pkl'), map_location=self.device))
        self.decoder.load_state_dict(
            torch.load(os.path.join(save_path, 'decoder.pkl'), map_location=self.device))

In [None]:
class CDAE(nn.Module):
    """Convolutional autoencoder built from ``Encoder_conv2d`` / ``Decoder_conv2d``.

    Args:
        args: config object (reads ``image_size``, ``kernels``, ``strides``,
            ``paddings``, ``multi_gpus``, ``log_dir``, ``save_dir`` and the epoch
            count via ``_n_epochs``).
        device: torch device the encoder/decoder are moved to.
    """

    def __init__(self, args, device):
        super(CDAE, self).__init__()
        self.args = args
        self.device = device
        self.encoder = Encoder_conv2d(args).to(self.device)
        # Propagate the input size through every conv layer's (kernel, stride,
        # padding) so the decoder can be sized to the encoder's output
        # (conv2d_output_size is a project helper).
        embed_w, embed_h = args.image_size, args.image_size
        for i in range(len(args.kernels)):
            embed_w, embed_h = conv2d_output_size(embed_w, embed_h, args.kernels[i], args.kernels[i],
                                                  args.strides[i], args.paddings[i])
        # NOTE(review): out_act_fn=None overrides args.out_act_fn — confirm intended.
        self.decoder = Decoder_conv2d(args, embed_w, embed_h, out_act_fn=None).to(self.device)
        # Element-wise loss; the reduction is done explicitly in partial_fit.
        self.loss_func = nn.MSELoss(reduction='none')
        params = list(self.encoder.parameters()) + list(self.decoder.parameters())
        # NOTE(review): lr is hard-coded to 1e-3 here, whereas DAE uses args.lr.
        self.opt = optim.Adam(params=params, lr=0.001, betas=(0.5, 0.999))

    def initialize_param(self):
        """Apply the project's ``initialize_weights`` to every encoder/decoder submodule."""
        self.encoder.apply(initialize_weights)
        self.decoder.apply(initialize_weights)

    def _n_epochs(self):
        """Epoch count for ``fit``.

        Honours ``args.max_epochs`` when the config defines it; otherwise falls
        back to ``args.inner_epochs`` (the field ``arguments`` actually has, and
        the one DAE uses), instead of raising AttributeError.
        """
        n_epochs = getattr(self.args, 'max_epochs', None)
        return self.args.inner_epochs if n_epochs is None else n_epochs

    def fit(self, trn_loader):
        """Train for ``_n_epochs()`` full passes over ``trn_loader``, logging every 20 epochs."""
        writer = tensorboardX.SummaryWriter(self.args.log_dir)
        try:
            for epoch in range(self._n_epochs()):
                self.encoder.train()
                self.decoder.train()
                trn_loss = self.partial_fit(trn_loader)
                if (epoch + 1) % 20 == 0:
                    writer.add_scalar('trn_loss', trn_loss, global_step=epoch)
                    print(f"In epoch {epoch + 1}, trn_loss = {trn_loss:.4f}")
        finally:
            # Flush pending events and release the event file even if training fails.
            writer.close()

    def partial_fit(self, trn_loader):
        """Run ONE epoch: an optimiser step for every ``(image, label)`` batch.

        Returns:
            float: per-sample sum of squared errors, averaged over every sample
            seen this epoch; ``nan`` if the loader yields no batches.
        """
        total_loss, n_samples = 0.0, 0
        for image, _ in trn_loader:
            image = image.to(self.device)
            image_recon = self.reconstruct(image)
            batch_size = image.size(0)
            # Sum over all elements (any rank), averaged over the batch.
            trn_loss = self.loss_func(image_recon, image).sum() / batch_size
            self.opt.zero_grad()
            trn_loss.backward()
            self.opt.step()
            total_loss += trn_loss.item() * batch_size
            n_samples += batch_size
        return total_loss / n_samples if n_samples else float('nan')

    def get_embedding_vector(self, x):
        """Encode ``x``; spread over ``args.multi_gpus`` when x is on CUDA and >1 GPU is listed."""
        if x.is_cuda and len(self.args.multi_gpus) > 1:
            out = nn.parallel.data_parallel(self.encoder, x, device_ids=self.args.multi_gpus)
        else:
            out = self.encoder(x)
        return out

    def reconstruct(self, x):
        """Return ``decoder(encoder(x))``, data-parallel under the same condition as above."""
        if x.is_cuda and len(self.args.multi_gpus) > 1:
            z = nn.parallel.data_parallel(self.encoder, x, device_ids=self.args.multi_gpus)
            x_recon = nn.parallel.data_parallel(self.decoder, z, device_ids=self.args.multi_gpus)
        else:
            z = self.encoder(x)
            x_recon = self.decoder(z)
        return x_recon

    def save_model(self, save_path=None):
        """Save encoder/decoder state dicts as encoder.pkl / decoder.pkl (default dir: args.save_dir)."""
        if save_path is None:
            save_path = self.args.save_dir
        torch.save(self.encoder.state_dict(), os.path.join(save_path, 'encoder.pkl'))
        torch.save(self.decoder.state_dict(), os.path.join(save_path, 'decoder.pkl'))
        print('Save model!')

    def load_model(self, save_path=None):
        """Load state dicts written by ``save_model`` onto ``self.device``.

        torch.load unpickles: only load trusted files.
        """
        if save_path is None:
            save_path = self.args.save_dir
        self.encoder.load_state_dict(
            torch.load(os.path.join(save_path, 'encoder.pkl'), map_location=self.device))
        self.decoder.load_state_dict(
            torch.load(os.path.join(save_path, 'decoder.pkl'), map_location=self.device))

In [None]:
# Per-class sample counts handed to PartialMNIST: digits 0-9, 2000 each.
trn_class_dict = {digit: 2000 for digit in range(10)}

In [None]:
# Scratch tensor for experimenting with pixel-noise indexing: 2 x 6 x 5 of zeros.
noise = torch.zeros(2, 6, 5)

In [None]:
# Display the scratch noise tensor.
noise

In [None]:
# Element-wise sum of the two slices along the first axis.
torch.add(noise[0], noise[1])

In [None]:
# For every (row, col) pair drawn from {1, 2, 4}, write a fresh pair of U[0, 1)
# values into noise[:, row, col] (row-major order, one torch.rand(2) call each).
for row in (1, 2, 4):
    for col in (1, 2, 4):
        noise[:, row, col] = torch.rand(2)

In [None]:
# Display the scratch noise tensor after the random fill.
noise

In [None]:
class Noise(object):
    """Transform that overwrites selected pixels of an image tensor with random values.

    Placed after ``transforms.ToTensor()`` in a ``transforms.Compose`` pipeline;
    ``sample`` is indexed as ``sample[channel, row, col]``.

    Args:
        corrupt_col_idx: column indices to corrupt.
        corrupt_row_idx: row indices to corrupt (only used by ``'uniform'``).
        method: ``'fixed'`` fills each selected column, across every row and
            channel, with a single U[0, 1) value; ``'uniform'`` fills each
            selected (row, col) pixel with its own U[0, 1) value.
    """

    def __init__(self, corrupt_col_idx=(1, 2, 3, 4, 5), corrupt_row_idx=tuple(range(1, 6)), method='fixed'):
        # Copy into fresh lists so instances never share (or mutate) a default object.
        self.corrupt_col_idx = list(corrupt_col_idx)
        self.corrupt_row_idx = list(corrupt_row_idx)
        self.method = method

    def __call__(self, sample):
        """Corrupt ``sample`` IN PLACE and return it.

        Raises:
            ValueError: if ``self.method`` is neither ``'fixed'`` nor ``'uniform'``.
        """
        if self.method == 'fixed':
            for col in self.corrupt_col_idx:
                sample[:, :, col] = torch.rand(1)
        elif self.method == 'uniform':
            for col in self.corrupt_col_idx:
                for row in self.corrupt_row_idx:
                    sample[:, row, col] = torch.rand(1)
        else:
            raise ValueError('Enter the proper noise method')

        return sample

In [None]:
class NewDataset(data.Dataset):
    """In-memory dataset over already-loaded image and label tensors.

    Args:
        tensor_x: images; viewed as ``(-1, *image_shape)``, so it must be
            contiguous with a multiple of ``prod(image_shape)`` elements
            (e.g. flat ``(N, 784)`` rows for the default shape).
        tensor_y: labels, indexed in parallel with ``tensor_x``.
        transform: optional callable applied to each image in ``__getitem__``.
        image_shape: per-sample ``(channels, height, width)``; defaults to
            single-channel 28x28, matching the previously hard-coded shape.
    """

    def __init__(self, tensor_x, tensor_y, transform=None, image_shape=(1, 28, 28)):
        self.tensor_x = tensor_x.view(-1, *image_shape)
        self.tensor_y = tensor_y
        self.transform = transform

    def __len__(self):
        return len(self.tensor_x)

    def __getitem__(self, idx):
        image = self.tensor_x[idx]
        # Apply the stored transform (previously accepted but silently ignored).
        if self.transform is not None:
            image = self.transform(image)
        return image, self.tensor_y[idx]

In [None]:
# Noisy MNIST training set: per-class counts from `trn_class_dict` (PartialMNIST is a
# project class — sampling semantics not visible here), and every image passed through
# Noise('uniform'), which sets each corrupted pixel to its own U[0, 1) value.
# The root comes from the config so the data path is defined in one place.
trn_dset = PartialMNIST(root=args.datadir, sample_dict=trn_class_dict, train=True, download=True,
                        transform=transforms.Compose([transforms.ToTensor(), Noise(method='uniform')]))

In [None]:
class RCDAE(nn.Module):
    """Robust convolutional autoencoder: splits the data matrix X into L + S.

    Alternates between (1) training a ``CDAE`` on the "clean" part ``L = X - S``
    and (2) setting ``S`` to the shrunk reconstruction residual ``X - AE(L)``.

    Args:
        args: config object (reads ``batch_size``, ``n_workers``,
            ``outer_epochs``, ``image_ch``, ``image_size`` plus whatever CDAE reads).
        device: torch device holding the data matrix and the model.
        lambda_: sparsity weight; the shrinkage threshold is ``lambda_ / mu``.
        tol: tolerance for both stopping criteria ``c1`` and ``c2``.
    """

    def __init__(self, args, device, lambda_=1.0, tol=1e-7):
        super(RCDAE, self).__init__()
        self.args = args
        self.device = device
        self.lambda_ = lambda_
        self.tol = tol
        self.ae = CDAE(args, device)

    def make_loader(self, dset, rdae=True):
        """Build a DataLoader over ``dset``.

        ``rdae=True``: fixed order, no worker processes, keeps the last partial
        batch (used to materialise the full data matrix). ``rdae=False``:
        shuffled, ``args.n_workers`` workers, drops the last partial batch
        (used to train the autoencoder).
        """
        if rdae:
            return torch.utils.data.DataLoader(dset, batch_size=self.args.batch_size,
                                               shuffle=False, num_workers=0, drop_last=False)
        return torch.utils.data.DataLoader(dset, batch_size=self.args.batch_size,
                                           shuffle=True, num_workers=self.args.n_workers, drop_last=True)

    def _as_images(self, flat):
        """Reshape flat ``(N, C*H*W)`` rows to ``(N, C, H, W)``.

        The conv AE is trained on image-shaped batches (and its decoder geometry
        is derived from ``args.image_size``), so it must be fed that layout.
        """
        return flat.reshape(-1, self.args.image_ch, self.args.image_size, self.args.image_size)

    def _reconstruct_flat(self, flat):
        """Reconstruct flat rows through the AE in ``batch_size`` chunks, without autograd.

        Chunking bounds peak memory; ``no_grad`` keeps the result free of autograd
        history, so it can be handed to multi-worker DataLoaders and does not
        retain a graph over the whole dataset.

        Returns:
            ``(N, D_out)`` tensor of flattened reconstructions.
        """
        with torch.no_grad():
            parts = [self.ae.reconstruct(self._as_images(chunk)).flatten(start_dim=1)
                     for chunk in torch.split(flat, self.args.batch_size)]
        return torch.cat(parts, dim=0)

    def get_flatX(self, loader):
        """Concatenate every ``(image, label)`` batch from ``loader`` on ``self.device``.

        Returns:
            (flat_images, labels): images flattened to ``(N, D)``; labels
            concatenated along dim 0.
        """
        images, labels = [], []
        for image, label in loader:
            images.append(image.to(self.device).flatten(start_dim=1))
            labels.append(label.to(self.device))
        return torch.cat(images, dim=0), torch.cat(labels)

    def get_flat_recon_X(self, loader):
        """Like ``get_flatX`` but returns AE reconstructions (computed without autograd)."""
        recons, labels = [], []
        with torch.no_grad():
            for image, label in loader:
                recon = self.ae.reconstruct(image.to(self.device))
                recons.append(recon.flatten(start_dim=1))
                labels.append(label.to(self.device))
        return torch.cat(recons, dim=0), torch.cat(labels)

    def fit(self, trn_dset, verbose=True):
        """Alternate AE training and sparse-residual shrinkage for ``args.outer_epochs`` rounds.

        Args:
            trn_dset: dataset yielding ``(image, label)`` pairs.
            verbose: print the convergence criteria ``c1``/``c2`` each round.

        Returns:
            (L, S): ``(N, D)`` tensors on ``self.device``. Before any round has run
            (``outer_epochs == 0``) this is ``(X, 0)``.
        """
        X, Y = self.get_flatX(self.make_loader(trn_dset, rdae=True))
        S = torch.zeros_like(X)
        # mu = N*D / ||X||_1 (entry-wise l1 norm); shrinkage threshold is lambda_ / mu.
        mu = X.numel() / torch.norm(X, 1)
        print(f'shrink param: {self.lambda_ / mu}')
        LS0 = torch.zeros_like(X)  # previous round's L + S, for the c2 criterion
        XFnorm = torch.norm(X, 'fro')
        # Defined up front so a zero-round fit returns instead of raising AttributeError.
        self.L, self.S = X, S
        for i in range(self.args.outer_epochs):
            print(f">>{i+1}th epoch")
            L = X - S
            # Moved to CPU before being handed to the (multi-worker) training loader.
            ae_loader = self.make_loader(NewDataset(L.cpu(), Y.cpu()), rdae=False)
            print('>>>>start train ae')
            self.ae.fit(ae_loader)
            print('>>>>end train ae')
            L = self._reconstruct_flat(L)
            # NOTE(review): `shrink` is not defined in the visible cells (only `l1shrink`
            # is imported from the `shrink` module); presumably from `from util import *`.
            S = shrink(self.lambda_ / mu, X - L, self.device)

            c1 = torch.norm(X - L - S, 'fro') / XFnorm
            c2 = mu * torch.norm(LS0 - L - S) / XFnorm

            self.L, self.S = L, S
            if verbose:
                print(f"c1: {c1:.4f}, c2: {c2:.4f}")
            if c1 < self.tol and c2 < self.tol:
                print("Early break")
                break
            LS0 = L + S

        return self.L, self.S

    def transform(self, x):
        """Embed the clean part ``x - S`` of flat inputs ``x`` (same shape as ``self.S``)."""
        return self.ae.get_embedding_vector(self._as_images(x - self.S))

    def reconstruct(self, x):
        """Delegate to the underlying CDAE's ``reconstruct``."""
        return self.ae.reconstruct(x)

In [None]:
# Shuffled, multi-worker training loader over the noisy dataset (last partial batch dropped).
trn_loader = torch.utils.data.DataLoader(
    trn_dset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.n_workers,
    drop_last=True,
)

In [None]:
# Stand-alone convolutional autoencoder built from the shared config.
cdae = CDAE(args=args, device=device)

In [None]:
class RDAE(nn.Module):
    """Robust deep autoencoder: splits the data matrix X into L + S.

    Alternates between (1) training a ``DAE`` on the "clean" part ``L = X - S``
    and (2) setting ``S`` to the shrunk reconstruction residual ``X - AE(L)``.

    Args:
        args: config object (reads ``batch_size``, ``n_workers``,
            ``outer_epochs``, ``image_ch``, ``image_size`` plus whatever DAE reads).
        device: torch device holding the data matrix and the model.
        lambda_: sparsity weight; the shrinkage threshold is ``lambda_ / mu``.
        tol: tolerance for both stopping criteria ``c1`` and ``c2``.
    """

    def __init__(self, args, device, lambda_=1.0, tol=1e-7):
        super(RDAE, self).__init__()
        self.args = args
        self.device = device
        self.lambda_ = lambda_
        self.tol = tol
        self.ae = DAE(args, device)

    def make_loader(self, dset, rdae=True):
        """Build a DataLoader over ``dset``.

        ``rdae=True``: fixed order, no worker processes, keeps the last partial
        batch (used to materialise the full data matrix). ``rdae=False``:
        shuffled, ``args.n_workers`` workers, drops the last partial batch
        (used to train the autoencoder).
        """
        if rdae:
            return torch.utils.data.DataLoader(dset, batch_size=self.args.batch_size,
                                               shuffle=False, num_workers=0, drop_last=False)
        return torch.utils.data.DataLoader(dset, batch_size=self.args.batch_size,
                                           shuffle=True, num_workers=self.args.n_workers, drop_last=True)

    def _as_images(self, flat):
        """Reshape flat ``(N, C*H*W)`` rows to ``(N, C, H, W)``.

        The AE is trained on image-shaped batches (NewDataset views its rows as
        images), so inference feeds it the same layout.
        """
        return flat.reshape(-1, self.args.image_ch, self.args.image_size, self.args.image_size)

    def _reconstruct_flat(self, flat):
        """Reconstruct flat rows through the AE in ``batch_size`` chunks, without autograd.

        Chunking bounds peak memory; ``no_grad`` keeps the result free of autograd
        history, so it can be handed to multi-worker DataLoaders and does not
        retain a graph over the whole dataset.

        Returns:
            ``(N, D_out)`` tensor of flattened reconstructions.
        """
        with torch.no_grad():
            parts = [self.ae.reconstruct(self._as_images(chunk)).flatten(start_dim=1)
                     for chunk in torch.split(flat, self.args.batch_size)]
        return torch.cat(parts, dim=0)

    def get_flatX(self, loader):
        """Concatenate every ``(image, label)`` batch from ``loader`` on ``self.device``.

        Returns:
            (flat_images, labels): images flattened to ``(N, D)``; labels
            concatenated along dim 0.
        """
        images, labels = [], []
        for image, label in loader:
            images.append(image.to(self.device).flatten(start_dim=1))
            labels.append(label.to(self.device))
        return torch.cat(images, dim=0), torch.cat(labels)

    def get_flat_recon_X(self, loader):
        """Like ``get_flatX`` but returns AE reconstructions (computed without autograd)."""
        recons, labels = [], []
        with torch.no_grad():
            for image, label in loader:
                recon = self.ae.reconstruct(image.to(self.device))
                recons.append(recon.flatten(start_dim=1))
                labels.append(label.to(self.device))
        return torch.cat(recons, dim=0), torch.cat(labels)

    def fit(self, trn_dset, verbose=True):
        """Alternate AE training and sparse-residual shrinkage for ``args.outer_epochs`` rounds.

        Args:
            trn_dset: dataset yielding ``(image, label)`` pairs.
            verbose: print the convergence criteria ``c1``/``c2`` each round.

        Returns:
            (L, S): ``(N, D)`` tensors on ``self.device``. Before any round has run
            (``outer_epochs == 0``) this is ``(X, 0)``.
        """
        X, Y = self.get_flatX(self.make_loader(trn_dset, rdae=True))
        S = torch.zeros_like(X)
        # mu = N*D / ||X||_1 (entry-wise l1 norm); shrinkage threshold is lambda_ / mu.
        mu = X.numel() / torch.norm(X, 1)
        print(f'shrink param: {self.lambda_ / mu}')
        LS0 = torch.zeros_like(X)  # previous round's L + S, for the c2 criterion
        XFnorm = torch.norm(X, 'fro')
        # Defined up front so a zero-round fit returns instead of raising AttributeError.
        self.L, self.S = X, S
        for i in range(self.args.outer_epochs):
            print(f">>{i+1}th epoch")
            L = X - S
            # Moved to CPU before being handed to the (multi-worker) training loader.
            ae_loader = self.make_loader(NewDataset(L.cpu(), Y.cpu()), rdae=False)
            print('>>>>start train ae')
            self.ae.fit(ae_loader)
            print('>>>>end train ae')
            L = self._reconstruct_flat(L)
            # NOTE(review): `shrink` is not defined in the visible cells (only `l1shrink`
            # is imported from the `shrink` module); presumably from `from util import *`.
            S = shrink(self.lambda_ / mu, X - L, self.device)

            c1 = torch.norm(X - L - S, 'fro') / XFnorm
            c2 = mu * torch.norm(LS0 - L - S) / XFnorm

            self.L, self.S = L, S
            if verbose:
                print(f"c1: {c1:.4f}, c2: {c2:.4f}")
            if c1 < self.tol and c2 < self.tol:
                print("Early break")
                break
            LS0 = L + S

        return self.L, self.S

    def transform(self, x):
        """Embed the clean part ``x - S`` of flat inputs ``x`` (same shape as ``self.S``)."""
        return self.ae.get_embedding_vector(self._as_images(x - self.S))

    def reconstruct(self, x):
        """Delegate to the underlying DAE's ``reconstruct``."""
        return self.ae.reconstruct(x)

In [None]:
# Robust DAE with its default sparsity weight and tolerance spelled out.
rdae = RDAE(args, device, lambda_=1.0, tol=1e-7)

In [None]:
# Fit the L + S decomposition. Bind the result to names so the cell does not dump two
# full N x D tensors as its output; they stay available for inspection.
L_hat, S_hat = rdae.fit(trn_dset)

In [None]:
# Small, deterministic loader (no shuffling, no workers) used for visual inspection.
tmp_loader = torch.utils.data.DataLoader(
    trn_dset, batch_size=10, shuffle=False, num_workers=0, drop_last=False
)

In [None]:
# Take the first batch only and pick one sample: the noisy input and its RDAE reconstruction.
image_idx = 7
image, label = next(iter(tmp_loader))
image = image.to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    recon = rdae.ae.reconstruct(image)
x = image[image_idx].cpu()
x_recon = recon[image_idx].cpu()

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(x.squeeze())
fig.colorbar(im, ax=ax)
ax.set_title('Noisy input image')
plt.show()

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(x_recon.squeeze())
fig.colorbar(im, ax=ax)
ax.set_title('RDAE reconstruction')
plt.show()

In [None]:
# Baseline: a plain autoencoder with no L + S decomposition.
dae = DAE(args=args, device=device)

In [None]:
# Shuffled, multi-worker training loader over the noisy dataset (last partial batch dropped).
trn_loader = torch.utils.data.DataLoader(
    trn_dset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.n_workers,
    drop_last=True,
)

In [None]:
# Train the baseline DAE directly on the noisy training loader.
dae.fit(trn_loader)

In [None]:
# Small, deterministic loader (no shuffling, no workers) used for visual inspection.
tmp_loader = torch.utils.data.DataLoader(
    trn_dset, batch_size=10, shuffle=False, num_workers=0, drop_last=False
)

In [None]:
# Take the first batch only and pick one sample: the noisy input and its DAE reconstruction.
image_idx = 9
image, label = next(iter(tmp_loader))
image = image.to(device)
with torch.no_grad():  # inference only; no autograd graph needed
    recon = dae.reconstruct(image)
x = image[image_idx].cpu()
x_recon = recon[image_idx].cpu()

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(x.squeeze())
fig.colorbar(im, ax=ax)
ax.set_title('Noisy input image')
plt.show()

In [None]:
fig, ax = plt.subplots()
im = ax.imshow(x_recon.squeeze())
fig.colorbar(im, ax=ax)
ax.set_title('DAE reconstruction')
plt.show()