In [None]:
from architectures import DENOISERS_ARCHITECTURES, get_architecture, IMAGENET_CLASSIFIERS, CLASSIFIERS_ARCHITECTURES
from datasets import get_dataset, DATASETS, get_num_classes, get_normalize_layer

import argparse
import numpy as np
import pandas as pd
import os
import sys
from PIL import Image

import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision import transforms, datasets, models

import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Notebook configuration. argparse is kept so the code can be lifted into a
# script unchanged; parse_args([]) ignores the kernel's own sys.argv and simply
# takes the defaults below.
parser = argparse.ArgumentParser()
# Directories / checkpoints
# NOTE(review): --data-dir is not read by any cell below (get_dataset is called
# with only the dataset name and split).
parser.add_argument('--data-dir', default='data/', help='data path')
parser.add_argument('--ckpt-dir', default='checkpoint_unet/', help='checkpoint path')
parser.add_argument('--pretrained', default='../denoised-smoothing/pretrained_models/cifar10_classifiers/ResNet110_90epochs/noise_0.00/checkpoint.pth.tar', help='pretrained classifier path')
parser.add_argument('--pretrained-denoiser', default='../denoised-smoothing/pretrained_models/trained_denoisers/cifar10/stab_obj/cifar10_smoothness_obj_adamThenSgd_6/multi_classifiers/dncnn/noise_0.12/checkpoint.pth.tar', help='pretrained denoiser path')
# Standard deviation of the Gaussian noise added to every input.
parser.add_argument('--noise-sd', type=float, default=0.12, help='sd for noise')
parser.add_argument('--name', type=str, default='mnet_unet0.00_csv1', help='name of saved checkpoints')
parser.add_argument('--dataset', default='cifar10', choices=DATASETS)
parser.add_argument('--split', default='test', choices=['train','test'])
# Model setup: 'ours' = model stored in <ckpt-dir>/<name>.ckpt followed by the
# classifier, 'ds' = pretrained denoiser followed by the classifier,
# 'rs' = classifier only.
parser.add_argument('--type', default='ds', choices=['ours','ds', 'rs'])

args = parser.parse_args([])

In [None]:
dataset = get_dataset(args.dataset, args.split)

# Two unshuffled loaders over the same split that differ only in batch size:
# 100 for `test_loader`, 256 for `data_loader`.
test_loader, data_loader = (
    torch.utils.data.DataLoader(dataset=dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=8)
    for batch_size in (100, 256)
)

# Class index -> human-readable name for CIFAR-10.
# NOTE(review): only meaningful when args.dataset == 'cifar10'.
classes = dict(enumerate([
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]))

In [None]:
class pipeline(nn.Module):
    """Noisy-input wrapper whose outputs are embedded with t-SNE.

    Optionally normalises the input, adds i.i.d. Gaussian noise with standard
    deviation ``noise_sd`` to every element, and runs the result through the
    first stage of ``model`` (``model[0]``).

    Args:
        model: an indexable module (e.g. ``nn.Sequential``); only ``model[0]``
            is called.
        noise_sd: noise standard deviation. ``None`` (default) reads the
            global ``args.noise_sd`` at call time, as the notebook always did.
        normalize_fn: callable applied to the input before the noise is added.
            ``None`` (default) applies the global ``normalize`` only when
            ``args.type == 'ours'``; pass ``nn.Identity()`` to disable it.
    """

    def __init__(self, model, noise_sd=None, normalize_fn=None):
        super(pipeline, self).__init__()
        self.model = model
        self.noise_sd = noise_sd
        self.normalize_fn = normalize_fn

    def forward(self, input):
        sd = args.noise_sd if self.noise_sd is None else self.noise_sd
        if self.normalize_fn is not None:
            x = self.normalize_fn(input)
        elif args.type == 'ours':
            x = normalize(input)
        else:
            x = input
        # Draw the noise from the CPU generator (so a restored
        # torch.set_rng_state still reproduces it), then move it to x's
        # device. The old unconditional .cuda() crashed on CPU-only machines.
        n = torch.randn(x.shape).to(x.device) * sd
        # NOTE(review): only the first stage is run. With the Sequential built
        # for 'ds' that is the denoiser alone, not denoiser + classifier.
        # Confirm this is intended.
        return self.model[0](x + n)

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Input normalisation applied by `pipeline` when args.type == 'ours'.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def _load_pretrained(ckpt_path, what):
    """Rebuild the architecture named in an 'arch'/'state_dict' checkpoint and load its weights.

    ``what`` is only used in the progress message. Raises ValueError when
    ``ckpt_path`` is empty, instead of failing later with a NameError.
    """
    if not ckpt_path:
        raise ValueError('a pretrained {} checkpoint path is required for --type {}'.format(what.lower(), args.type))
    print('=====> Loading Pretrained {}...'.format(what))
    # Load onto the CPU so GPU-saved checkpoints also open on CPU-only machines.
    # The model is moved to `device` once it has been assembled.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    net = get_architecture(ckpt['arch'], args.dataset)
    # NOTE(review): strict=False silently ignores missing/unexpected keys.
    net.load_state_dict(ckpt['state_dict'], strict=False)
    return net


if args.type == 'ours':
    print('=====> Loading trained model from checkpoint...')
    noise_sd = 0.25  # NOTE(review): unused; `pipeline` takes its noise level from args.noise_sd
    # map_location='cpu' keeps rng_state a CPU ByteTensor, as set_rng_state requires.
    checkpoint = torch.load(args.ckpt_dir + args.name + '.ckpt', map_location='cpu')
    model = checkpoint['model']
    rng_state = checkpoint['rng_state']
    torch.set_rng_state(rng_state)
elif args.type == 'ds':
    model = _load_pretrained(args.pretrained_denoiser, 'Denoiser')
elif args.type == 'rs':
    model = _load_pretrained(args.pretrained, 'Classifier')

if args.type in ['ours', 'ds']:
    # Chain the first-stage model with the pretrained classifier.
    cs_model = _load_pretrained(args.pretrained, 'Classifier')
    model = torch.nn.Sequential(model, cs_model)

# Move every variant to `device`; previously the 'rs' model was never moved.
model = model.to(device)
model.eval()

pipe = pipeline(model)

In [None]:
# Run every sample through the noisy pipeline, in data_loader's unshuffled
# order, and stack the outputs. no_grad avoids building autograd graphs for
# the whole dataset.
with torch.no_grad():
    batch_outputs = [pipe(X.to(device)).cpu() for X, _ in data_loader]
recon_prob = torch.cat(batch_outputs, dim=0)
# StandardScaler/TSNE need a 2-D (n_samples, n_features) array. Flatten any
# trailing dims, e.g. when the pipeline returns images rather than vectors.
recon_prob = recon_prob.reshape(recon_prob.shape[0], -1).numpy()

In [None]:
# Ground-truth labels for every sample, in the same unshuffled order as the
# rows of `recon_prob` (which is built from `data_loader`). Taking only the
# first test_loader batch would give 100 labels for the whole dataset's points.
y_label = torch.cat([y for _, y in data_loader], dim=0).numpy()
assert len(y_label) == len(recon_prob), 'labels and features are misaligned'

In [None]:
# One row per sample holding its class name (integer label mapped via `classes`).
df = pd.DataFrame({'label': [classes[idx] for idx in y_label]})

In [None]:
# Class-name column (reused as the plot hue) and its per-class counts.
label = df.loc[:, 'label']
label.value_counts()

In [None]:
%%time
train = StandardScaler().fit_transform(recon_prob)
tsne_res = tsne.fit_transform(train)

In [None]:
# Output directory prefix for the figure ('' = current working directory).
path = ''
fig, ax = plt.subplots(figsize=(10, 8))

# One point per sample, coloured by class name. Size the palette from the
# labels actually present rather than assuming exactly 10 classes.
sns.scatterplot(x=tsne_res[:, 0], y=tsne_res[:, 1], hue=label,
                palette=sns.hls_palette(label.nunique()), legend='full', ax=ax)
ax.set_title('t-SNE of pipeline outputs (type={}, noise_sd={})'.format(args.type, args.noise_sd))
# t-SNE coordinates have no intrinsic meaning, so hide the axes.
ax.axis('off')
fig.savefig(path + 'image{}.png'.format(args.noise_sd), bbox_inches='tight')
plt.show()