In [0]:
# Importing Libraries

import argparse
import os
import sys
import random
import shutil
import time
import warnings
import numpy as np
import pickle
import inspect
import numpy as np
import math
import logging
from PIL import Image
from glob import glob
import matplotlib as mpl
if os.environ.get('DISPLAY','') == '':
    print('no display found. Using non-interactive Agg backend')
    mpl.use('Agg')
import matplotlib.pyplot as plt
import scipy.io as sio

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torchvision import transforms
from torchvision import datasets

In [0]:
# Transformer

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn
    
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)


    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
        if not mask==None:
            mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, attn
    
class GELU(nn.Module):
    """
    BERT used the GELU instead of RELU
    """

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

In [0]:
class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)
        self.activation = GELU()
    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(self.activation(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output
    
class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output *= non_pad_mask

        enc_output = self.pos_ffn(enc_output)
        enc_output *= non_pad_mask

        return enc_output, enc_slf_attn


In [0]:
# Model

class attention_pooling(nn.Module):
    ''' A encoder model with self attention mechanism. '''

    def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):

        super().__init__()

        self.layer_stack = nn.ModuleList([EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) for _ in range(n_layers)])

    def forward(self, src_seq, return_attns=False):

        enc_slf_attn_list = []
        enc_output = src_seq
        
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output,non_pad_mask=1,slf_attn_mask=None)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output[:, 0, :]

class SelfieModel(nn.Module):
    def __init__(self, n_layers, n_heads, d_in, d_model, d_ff, n_split, dropout=0.1, use_cuda=True, gpu = None):
        super(SelfieModel, self).__init__()
        self.n_split = n_split
        self.at_pool = attention_pooling(n_layers + 1, n_heads, d_in, d_in, d_model, d_ff)

        if use_cuda:
            if gpu is None:
                self.row_embeddings = nn.Parameter(torch.randn(n_split, d_model).cuda())
                self.column_embeddings = nn.Parameter(torch.zeros(n_split, d_model).cuda())
                self.u0 = nn.Parameter(torch.zeros(1,1,d_model).cuda())
            else:
                self.row_embeddings = nn.Parameter(torch.randn(n_split, d_model).cuda(gpu))
                self.column_embeddings = nn.Parameter(torch.zeros(n_split, d_model).cuda(gpu))
                self.u0 = nn.Parameter(torch.zeros(1,1,d_model).cuda(gpu))
        else:
            self.row_embeddings = nn.Parameter(torch.randn(n_split, d_model))
            self.column_embeddings = nn.Parameter(torch.zeros(n_split, d_model))
            self.u0 = nn.Parameter(torch.zeros(1,1,d_model))

    def forward(self, src_seq, pos, return_attns=False):
        u = self.u0.repeat((src_seq.shape[0], 1, 1))
        src_seq = torch.cat([u, src_seq], dim=1)
        before_embeddings =  self.at_pool(src_seq)
        final = []
        rows = map(lambda x: np.trunc(x / self.n_split).astype("int"), pos)
        cols = map(lambda x: np.mod(x, self.n_split).astype("int"), pos)
        for (i,j) in zip(rows, cols):
            sum_up = before_embeddings + self.row_embeddings[i,:] + self.column_embeddings[j,:]
            sum_up = sum_up.unsqueeze(1)
            final.append(sum_up)

        return torch.cat(final, 1)

In [0]:
# Resnet

class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.relu = nn.ReLU(inplace = True)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn0 = nn.BatchNorm2d(in_planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        pre = self.relu(self.bn0(x))
        out = self.relu(self.bn1(self.conv1(pre)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.conv3(out)
        if len(self.shortcut)==0:
            out += self.shortcut(x)
        else:
            out += self.shortcut(pre)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

class P(nn.Module):
    def __init__(self, block, num_blocks):
        super(P, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        return out

def get_P_model():
    return P(Bottleneck, [3,4,6])

def ResNet50(number_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes=number_classes)


In [0]:
# Miscellaneous

class NormalizeByChannelMeanStd(nn.Module):
    def __init__(self, mean, std):
        super(NormalizeByChannelMeanStd, self).__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return normalize_fn(tensor, self.mean, self.std)

    def extra_repr(self):
        return 'mean={}, std={}'.format(self.mean, self.std)


def normalize_fn(tensor, mean, std):
    """Differentiable version of torchvision.functional.normalize"""
    # here we assume the color channel is in at dim=1
    mean = mean[None, :, None, None]
    std = std[None, :, None, None]
    return tensor.sub(mean).div(std)

class stats:
    def __init__(self, path, start_epoch):
        if start_epoch != 0:
           stats_ = sio.loadmat(os.path.join(path,'stats.mat'))
           data = stats_['data']
           content = data[0,0]
           self.trainObj = content['trainObj'][:,:start_epoch].squeeze().tolist()
           self.trainTop1 = content['trainTop1'][:,:start_epoch].squeeze().tolist()
           self.valObj = content['valObj'][:,:start_epoch].squeeze().tolist()
           self.valTop1 = content['valTop1'][:,:start_epoch].squeeze().tolist()

           self.avalObj = content['adv_valObj'][:,:start_epoch].squeeze().tolist()
           self.avalTop1 = content['adv_prec1'][:,:start_epoch].squeeze().tolist()
           if start_epoch == 1:
               self.trainObj = [self.trainObj]
               self.trainTop1 = [self.trainTop1]
               self.valObj = [self.valObj]
               self.valTop1 = [self.valTop1]
               self.avalObj = [self.avalObj]
               self.avalTop1 = [self.avalTop1]
        else:
           self.trainObj = []
           self.trainTop1 = []
           self.valObj = []
           self.valTop1 = []
           self.avalObj = []
           self.avalTop1 = []
           
    def _update(self, trainObj, top1, valObj, prec1, avalObj, aprec1):
        self.trainObj.append(trainObj)
        self.trainTop1.append(top1.cpu().numpy())
        self.valObj.append(valObj)
        self.valTop1.append(prec1.cpu().numpy())
        self.avalObj.append(avalObj)
        self.avalTop1.append(aprec1.cpu().numpy())


def plot_curve(stats, path, iserr):
    
    trainObj = np.array(stats.trainObj)
    valObj = np.array(stats.valObj)
    avalObj = np.array(stats.avalObj)
    if iserr:
        trainTop1 = 100 - np.array(stats.trainTop1)
        valTop1 = 100 - np.array(stats.valTop1)
        avalTop1 = 100 - np.array(stats.avalTop1)
        titleName = 'error'
    else:
        trainTop1 = np.array(stats.trainTop1)
        valTop1 = np.array(stats.valTop1)
        avalTop1 = np.array(stats.avalTop1)
        titleName = 'accuracy'
        
    epoch = len(trainObj)
    figure = plt.figure()
    obj = plt.subplot(1,2,1)
    obj.plot(range(1,epoch+1),trainObj,'o-',label = 'train')
    obj.plot(range(1,epoch+1),valObj,'o-',label = 'val')
    obj.plot(range(1,epoch+1),avalObj,'o-',label = 'adv_val')
    plt.xlabel('epoch')
    plt.title('objective')
    handles, labels = obj.get_legend_handles_labels()
    obj.legend(handles[::-1], labels[::-1])
    top1 = plt.subplot(1,2,2)
    top1.plot(range(1,epoch+1),trainTop1,'o-',label = 'train')
    top1.plot(range(1,epoch+1),valTop1,'o-',label = 'val')
    top1.plot(range(1,epoch+1),avalTop1,'o-',label = 'adv_val')
    plt.title('top1'+titleName)
    plt.xlabel('epoch')
    handles, labels = top1.get_legend_handles_labels()
    top1.legend(handles[::-1], labels[::-1])
    filename = os.path.join(path, 'net-train.pdf')
    figure.savefig(filename, bbox_inches='tight')
    plt.close()


In [0]:
# Dataset

class CifarDataset(Dataset):

    def __init__(self, _dir, train, transform, percent):

        self.dir=osp.join( _dir, 'cifar-10-batches-py')
        self.transforms=transform
        train_filenames = ['data_batch_{}'.format(ii + 1) for ii in range(5)]
        eval_filename = 'test_batch'

        if train:
            data_images = np.zeros((50000, 32, 32, 3), dtype='uint8')
            data_labels = np.zeros(50000, dtype='int32')
            for ii, fname in enumerate(train_filenames):
                cur_images, cur_labels = self._load_datafile(osp.join(self.dir, fname))
                data_images[ii * 10000 : (ii+1) * 10000, ...] = cur_images
                data_labels[ii * 10000 : (ii+1) * 10000, ...] = cur_labels
            permutation = np.random.permutation(50000)
            self.choose_images=data_images[permutation]
            self.choose_target=data_labels[permutation]
            choose = []
            all_indexes = np.array(range(50000))
            self.number = int(50000 * percent)
            for i in range(10):
                indexes = all_indexes[self.choose_target == i]
                choose.append(indexes[:int(len(indexes) * percent)])
            choose = np.concatenate(choose, 0)
            self.choose_images = self.choose_images[choose]
            self.choose_target = self.choose_target[choose]
        else:
            data_images, data_labels = self._load_datafile(osp.join(self.dir, eval_filename))
            self.number=int(10000*percent)
            permutation = np.random.permutation(10000)
            self.choose_images=data_images[permutation[:self.number]]
            self.choose_target=data_labels[permutation[:self.number]]
            self.number = 10000

    def __len__(self):
        return self.number

    def __getitem__(self, index):
        img=self.choose_images[index]
        target = self.choose_target[index]
        img = self.transforms(img)

        return img,target

    @staticmethod
    def _load_datafile(filename):
        with open(filename, 'rb') as fo:
            data_dict = pickle.load(fo, encoding='bytes')
            assert data_dict[b'data'].dtype == np.uint8
            image_data = data_dict[b'data']
            image_data = image_data.reshape((10000, 3, 32, 32)).transpose(0, 2, 3, 1)
            return image_data, np.array(data_dict[b'labels'])
        
class ImageNetDataset(Dataset):

    def __init__(self, _dir, train, transform, percent):

        self.dir=_dir
        self.transforms=transform
        train_filenames = ['train_data_batch_{}'.format(ii + 1) for ii in range(10)]
        eval_filename = 'val_data'

        if train:
            data_images = []
            data_labels = []
            for ii, fname in enumerate(train_filenames):
                cur_images, cur_labels = self._load_datafile(osp.join(self.dir, fname))
                data_images.append(cur_images)
                data_labels.append(cur_labels)
            data_images = np.concatenate(data_images, axis = 0)
            data_labels = np.concatenate(data_labels, axis = 0)
            n = data_images.shape[0]
            permutation = np.random.permutation(n)
            self.choose_images=data_images[permutation]
            self.choose_target=data_labels[permutation]
            choose = []
            all_indexes = np.array(range(n))
            for i in range(1000):
                indexes = all_indexes[self.choose_target == i]
                choose.append(indexes[:int(len(indexes) * percent)])
            choose = np.concatenate(choose, 0)
            self.choose_images = self.choose_images[choose]
            self.choose_target = self.choose_target[choose]
        else:
            data_images, data_labels = self._load_datafile(osp.join(self.dir, eval_filename))
            n = data_images.shape[0]
            self.number=int(n*percent)
            self.choose_images=data_images[:self.number]
            self.choose_target=data_labels[:self.number]

    def __len__(self):
        return len(self.choose_images)

    def __getitem__(self, index):
        img = self.choose_images[index]
        target = self.choose_target[index]
        img = self.transforms(img)

        return img,target -1

    @staticmethod
    def _load_datafile(filename):
        with open(filename, 'rb') as fo:
            data_dict = pickle.load(fo, encoding='bytes')
            assert data_dict['data'].dtype == np.uint8
            image_data = data_dict['data']
            image_data = image_data.reshape((image_data.shape[0], 3, 32, 32)).transpose(0, 2, 3, 1)
            return image_data, np.array(data_dict['labels'])


In [0]:
# Arguments

parser = argparse.ArgumentParser(description='Selfie')
  
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                    help='model architecture: ')
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of steps of selfie')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--lr-method', default='step', type=str,
                    help='method of learning rate')
parser.add_argument('--lr-params', default=[], dest='lr_params',nargs='*',type=float,
                    action='append', help='params of lr method')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=50, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--data',default="./data/",
                    help='path to dataset')
parser.add_argument('--dataset', type=str, default="cifar") 
parser.add_argument('--modeldir', default="imagenet_adv_selfie", type=str,
                    help='director of checkpoint')
parser.add_argument('--store-model-everyepoch', dest='store_model_everyepoch', action='store_true',
                    help='store checkpoint in every epoch')
parser.add_argument('--percent', type=float,
                    help="Used data percent", default=1.0)
parser.add_argument('--evaluation', action="store_true")
parser.add_argument('--classification-model', type=str, default="")
parser.add_argument('--split-gpu', action="store_true")
parser.add_argument('--resume', action="store_true")
parser.add_argument('--finetune', action="store_true")
parser.add_argument('--evaluation-selfie', action="store_true")
parser.add_argument('--num-classes', type=int, default=10)
parser.add_argument('--seed', type=int, default=10)
best_prec1 = 0

In [0]:
def main():
    global args, best_prec1

    args = parser.parse_args()
    print(args)
    
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)
    torch.cuda.set_device(int(args.gpu))

    setup_seed(args.seed)

    # Data Preprocess
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    data_transforms = {
        'train': transforms.Compose([
            transforms.ToPILImage(),
            transforms.Pad(2),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
        ]),
        'val': transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ])
    } 
    if args.dataset == 'cifar':
        train_dataset = CifarDataset(args.data, True, data_transforms['train'], args.percent)
        test_dataset = CifarDataset(args.data, False, data_transforms['val'], 1)
    elif args.dataset == 'imagenet':
        train_dataset = ImageNetDataset(args.data, True, data_transforms['train'], args.percent)
        test_dataset = ImageNetDataset(args.data, False, data_transforms['val'], 1)

    elif args.dataset == 'imagenet224':
        data_transforms = {
            'train': transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]),
            'val': transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ])
        } 
        train_dataset = datasets.ImageNet(args.data, 'train', True, data_transforms['train'])
        test_dataset = datasets.ImageNet(args.data, 'train', True, data_transforms['val'])

    valid_size = 0.1
    indices = list(range(len(train_dataset)))
    split = int(np.floor(valid_size*len(train_dataset)))
    np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = torch.utils.data.Subset(train_dataset, train_idx)
    valid_sampler = torch.utils.data.Subset(train_dataset, valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_sampler,
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        valid_sampler,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)


    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # define model 
    n_split = 4
    selfie_model = get_selfie_model(n_split)
    selfie_model = selfie_model.cuda()

    P=get_P_model()
    normalize = NormalizeByChannelMeanStd(
    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    P = nn.Sequential(normalize, P)
    P = P.cuda()

    #define optimizer and scheduler 
    params_list = [{'params': selfie_model.parameters(), 'lr': args.lr,
                        'weight_decay': args.weight_decay},]
    params_list.append({'params': P.parameters(), 'lr': args.lr, 'weight_decay': args.weight_decay})
    optimizer = torch.optim.SGD(params_list, lr=args.lr,
                                momentum=0.9,
                                weight_decay=args.weight_decay, nesterov = True)

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: cosine_annealing(
            step,
            args.epochs * len(train_loader),
            1,  # since lr_lambda computes multiplicative factor
            1e-7 / args.lr))

    print("Training model.")
    step = 0
    if os.path.exists(args.modeldir) is not True:
        os.mkdir(args.modeldir)
    stats_ = stats(args.modeldir, args.start_epoch)

    if args.epochs > 0:

        #order of patches 
        all_seq=[np.random.permutation(16) for ind in range(400)]
        pickle.dump(all_seq, open(os.path.join(args.modeldir, 'img_test_seq.pkl'),'wb'))
        # all_seq=pickle.load(open(os.path.join(args.modeldir, 'img_test_seq.pkl'),'rb'))
        
        print("Begin selfie training...")
        for epoch in range(args.start_epoch, args.epochs):
            print("The learning rate is {}".format(optimizer.param_groups[0]['lr']))
            trainObj, top1 = train_selfie(train_loader, selfie_model, P, criterion, optimizer, epoch, scheduler)

            valObj, prec1 = val_selfie(val_loader, selfie_model, P, criterion, all_seq)

            stats_._update(trainObj, top1, valObj, prec1, valObj, prec1)

            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            if is_best:
                torch.save(
                    {
                    'epoch': epoch,
                    'P_state': P.state_dict(),
                    'selfie_state': selfie_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                    }, os.path.join(args.modeldir, 'std_selfie_TA_model_best.pth.tar'))

            torch.save(
                    {
                    'epoch': epoch,
                    'P_state': P.state_dict(),
                    'selfie_state': selfie_model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_prec1': best_prec1,
                    }, os.path.join(args.modeldir, 'std_selfie_checkpoint.pth.tar'))

            plot_curve(stats_, args.modeldir, True)
            data = stats_
            sio.savemat(os.path.join(args.modeldir,'stats.mat'), {'data':data})
   

        print("testing TA best selfie model from checkpoint...")
        model_path = os.path.join(args.modeldir, 'std_selfie_TA_model_best.pth.tar')
        model_loaded = torch.load(model_path)

        P.load_state_dict(model_loaded['P_state'])
        selfie_model.load_state_dict(model_loaded['selfie_state'])
        print("Best TA selfie model loaded! ")
        
        valObj, prec1 = val_selfie(test_loader, selfie_model, P, criterion, all_seq)
        

In [0]:
def train_selfie(train_loader, selfie_model, P, criterion, optimizer, epoch, scheduler):
    global args
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    selfie_model.train()
    P.train()

    for index, (input, _) in enumerate(train_loader):
        data_time.update(time.time() - end)

        cur_batch_size = input.size(0)

        # if epoch == 0:
        #     warmup_lr(index, optimizer, 200)
        
        total=16
        seq = np.random.permutation(total)
        t = seq[:(total // 4)]
        v = seq[(total // 4):]
        v = torch.from_numpy(v).cuda()
        pos = t
        t = torch.from_numpy(np.array(pos)).cuda()

        input = input.cuda()

        #selfie forward
        batches = split_image_selfie(input, 8)

        batches = list(map(lambda x: x.unsqueeze(1), batches))
        batches = torch.cat(batches, 1) # (B, L, C, H, W)

        input_batches = torch.split(batches, 1, 1)
        input_batches = list(map(lambda x: x.squeeze(1), input_batches))
        input_batches = torch.cat(input_batches, 0)

        output_batches = P(input_batches)

        output_batches = output_batches.unsqueeze(1)
        output_batches = torch.split(output_batches, cur_batch_size, 0)
        output_batches = torch.cat(output_batches,1)

        output_decoder = output_batches.index_select(1, t)
        
        output_encoder = output_batches.index_select(1, v)
        output_encoder = selfie_model(output_encoder, pos)

        features = []
        for i in range(len(pos)):
            feature = output_decoder[:, i, :]
            feature = feature.unsqueeze(2)
            features.append(feature)

        features = torch.cat(features, 2) # (B, F, NP)
        patch_loss = 0

        for i in range(len(t)):
            activate = output_encoder[:, i, :].unsqueeze(1)
            pre = torch.bmm(activate, features)
            logit = nn.functional.softmax(pre, 2).view(-1, len(t))
            temptarget = torch.ones(logit.shape[0]).cuda() * i
            temptarget = temptarget.long()
            loss_ = criterion(logit, temptarget)
            patch_loss += loss_
            prec1_adv, _ = accuracy(logit, temptarget, topk=(1,3))
            top1.update(prec1_adv[0], 1)

        optimizer.zero_grad()
        patch_loss.backward()
        optimizer.step()
        scheduler.step()

        all_loss = patch_loss.float()
        losses.update(all_loss.item(), input.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if index % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                   epoch, index, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
    return losses.avg, top1.avg


In [0]:
def val_selfie(val_loader, selfie_model, P, criterion, all_seq):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()
    selfie_model.eval()
    P.eval()
    with torch.no_grad():
        for index, (input, _) in enumerate(val_loader):
            #print(input)
            data_time.update(time.time() - end)
            input = input.cuda()

            cur_batch_size = input.size(0)

            total=16
            seq = all_seq[index]
            t = seq[:(total // 4)]
            v = seq[(total // 4):]
            v = torch.from_numpy(v).cuda()
            pos = t

            t = torch.from_numpy(np.array(pos)).cuda()

            #selfie forward
            batches = split_image_selfie(input, 8)

            batches = list(map(lambda x: x.unsqueeze(1), batches))
            batches = torch.cat(batches, 1) # (B, L, C, H, W)

            input_batches = torch.split(batches, 1, 1)
            input_batches = list(map(lambda x: x.squeeze(1), input_batches))
            input_batches = torch.cat(input_batches, 0)

            output_batches = P(input_batches)

            output_batches = output_batches.unsqueeze(1)
            output_batches = torch.split(output_batches, cur_batch_size, 0)
            output_batches = torch.cat(output_batches,1)

            output_decoder = output_batches.index_select(1, t)
            
            output_encoder = output_batches.index_select(1, v)
            output_encoder = selfie_model(output_encoder, pos)

            features = []
            for i in range(len(pos)):
                feature = output_decoder[:, i, :]
                feature = feature.unsqueeze(2)
                features.append(feature)

            features = torch.cat(features, 2) # (B, F, NP)
            patch_loss = 0

            for i in range(len(t)):
                activate = output_encoder[:, i, :].unsqueeze(1)
                pre = torch.bmm(activate, features)
                logit = nn.functional.softmax(pre, 2).view(-1, len(t))
                temptarget = torch.ones(logit.shape[0]).cuda() * i
                temptarget = temptarget.long()
                loss_ = criterion(logit, temptarget)

                prec1, _ = accuracy(logit, temptarget, topk=(1,3))

                losses.update(loss_.item(), 1)
                top1.update(prec1[0], 1)

            batch_time.update(time.time() - end)
            end = time.time()

            if index % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'.format(
                           index, len(val_loader), batch_time=batch_time, loss=losses,
                           top1=top1))
            #raise NotImplementedError
        print(' * Prec@1 {top1.avg:.3f}'
              .format(top1=top1))
        return losses.avg, top1.avg
        

In [0]:
# Utils

def get_selfie_model(n_split):
    n_layers = 12
    d_model = 1024 #vector length after the patch routed in P
    d_in = 64
    n_heads = d_model// d_in
    d_ff = 2048
    model = SelfieModel(n_layers, n_heads, d_in, d_model, d_ff, n_split)
    return model

def cosine_annealing(step, total_steps, lr_max, lr_min):
    return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))

def accuracy(output, target, topk=(1,)):
    #print(output.shape)
    #print(target.shape)
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        #print(target)
        if (target.dim() > 1):
            target = torch.argmax(target, 1)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename[0])
    if is_best:
        shutil.copyfile(filename[0], filename[1])

def setup_seed(seed): 
    torch.manual_seed(seed) 
    torch.cuda.manual_seed_all(seed) 
    np.random.seed(seed) 
    random.seed(seed) 
    torch.backends.cudnn.deterministic = True 

def warmup_lr(step, optimizer, speed):
    lr = 0.01+step*(0.1-0.01)/speed
    lr = min(lr,0.1)
    for p in optimizer.param_groups:
        p['lr']=lr

def split_image_selfie(image, N):
    """
    image: (B, C, W, H)
    """
    batches = []

    for i in list(torch.split(image, N, dim=2)):
        batches.extend(list(torch.split(i, N, dim=3)))

    return batches

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


In [0]:
if __name__ == '__main__':
    main()