In [None]:
import torch
from torch import optim
from torch import nn 
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import math
from sklearn.linear_model._base import LinearModel, LinearClassifierMixin
import numpy as np 
from torch.nn.modules.loss import _Loss
from timeit import default_timer as timer
import torch.nn.functional as F
from scipy import optimize
import copy


# Kaggle output directory (metrics / model weights) and dataset download root.
outpath, data_path = '/kaggle/working/', '/kaggle/working/data/'


In [None]:
# layers
class Linear(nn.Linear, LinearModel, LinearClassifierMixin):
    """`nn.Linear` that can also be fitted like a scikit-learn linear classifier.

    `fit` minimizes `criterion(self(x), y)` plus an optional "l2"/"l1" penalty on
    the weight with scipy's L-BFGS-B; `decision_function`, `predict` and
    `predict_proba` expose the sklearn classifier API.

    Args:
        in_features, out_features: as for `nn.Linear`.
        alpha: penalty strength; `fit` rescales it by n_features / n_samples
            (and by scale_bias ** 2 when the layer has a bias).
        fit_bias: whether the layer has a bias (passed as `nn.Linear`'s `bias`).
        penalty: "l2" or "l1"; applied to the weight only, never to the bias.
        maxiter: maximum number of L-BFGS-B iterations in `fit`.
    """

    def __init__(self, in_features, out_features, alpha=0.0, fit_bias=True,
                 penalty="l2", maxiter=1000):
        super(Linear, self).__init__(in_features, out_features, fit_bias)
        self.alpha = alpha
        self.fit_bias = fit_bias
        self.penalty = penalty
        self.maxiter = maxiter

    def forward(self, input, scale_bias=1.0):
        """Return `input @ weight.T + scale_bias * bias` (no bias term if bias is None)."""
        bias = None if self.bias is None else scale_bias * self.bias
        return F.linear(input, self.weight, bias)

    def _as_param_tensor(self, a):
        """Convert a numpy array or tensor to a tensor on the weight's device."""
        if isinstance(a, np.ndarray):
            a = torch.from_numpy(a)
        return a.to(self.weight.device)

    def fit(self, x, y, criterion=None):
        """Fit weight (and bias) with L-BFGS-B on `criterion` + penalty.

        Args:
            x: features, n_samples x in_features (numpy array or tensor).
            y: targets accepted by `criterion` (class indices for the default
                cross-entropy), numpy array or tensor.
            criterion: loss module; defaults to `nn.CrossEntropyLoss()`.

        Side effects: sets `self.real_alpha` and `self.scale_bias`, overwrites
        the parameters in place and leaves their gradients zeroed.

        Raises:
            ValueError: if a non-zero penalty is requested with an unknown
                `self.penalty`.
        """
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        # Each input is converted on its own so mixed numpy/tensor inputs work.
        # Features follow the weight dtype (numpy features are often float64);
        # targets keep theirs since criteria expect e.g. int64 class indices.
        x = self._as_param_tensor(x).to(self.weight.dtype)
        y = self._as_param_tensor(y)

        alpha = self.alpha * x.shape[1] / x.shape[0]
        # The bias is optimized in units of the mean sample norm so that it is
        # scaled like the weights; without a bias no rescaling is needed.
        scale_bias = 1.0
        if self.bias is not None:
            scale_bias = (x ** 2).mean(-1).sqrt().mean().item()
            alpha *= scale_bias ** 2
        if alpha != 0.0 and self.penalty not in ("l2", "l1"):
            raise ValueError(
                "penalty must be 'l2' or 'l1', got {!r}".format(self.penalty))
        self.real_alpha = alpha
        self.scale_bias = scale_bias

        def set_params(flat_w, bias_scale):
            # Copy a flat vector laid out row-wise as [weight | bias] into the
            # parameters, multiplying the bias column by `bias_scale`.
            w = torch.from_numpy(flat_w.reshape((self.out_features, -1)))
            if self.bias is None:
                self.weight.data.copy_(w)
            else:
                self.weight.data.copy_(w[:, :-1])
                self.bias.data.copy_(bias_scale * w[:, -1])

        def eval_loss(flat_w):
            # Objective for L-BFGS-B; also leaves d(loss)/d(params) in .grad,
            # which eval_grad reads.
            set_params(flat_w, 1.0)
            self.weight.grad = None
            if self.bias is not None:
                self.bias.grad = None
            y_pred = self(x, scale_bias=scale_bias).squeeze(-1)
            loss = criterion(y_pred, y)
            loss.backward()
            if alpha != 0.0:
                if self.penalty == "l2":
                    # Its gradient is added analytically in eval_grad.
                    penalty = 0.5 * alpha * torch.norm(self.weight) ** 2
                else:
                    penalty = alpha * torch.norm(self.weight, p=1)
                    penalty.backward()
                loss = loss + penalty
            return loss.item()

        def eval_grad(flat_w):
            # Assumes eval_loss was just evaluated at the same point, as scipy's
            # L-BFGS-B driver computes the objective before the gradient.
            dw = self.weight.grad.data
            if alpha != 0.0 and self.penalty == "l2":
                dw = dw + alpha * self.weight.data
            if self.bias is not None:
                dw = torch.cat((dw, self.bias.grad.data.view(-1, 1)), dim=1)
            return dw.cpu().numpy().ravel().astype("float64")

        w_init = self.weight.data
        if self.bias is not None:
            w_init = torch.cat(
                (w_init, self.bias.data.view(-1, 1) / scale_bias), dim=1)
        w_init = w_init.cpu().numpy().astype("float64").ravel()

        w = optimize.fmin_l_bfgs_b(
            eval_loss, w_init, fprime=eval_grad, maxiter=self.maxiter, disp=0)
        if isinstance(w, tuple):
            w = w[0]

        # Undo the bias rescaling so that forward() with the default
        # scale_bias=1.0 reproduces the fitted model.
        set_params(w, scale_bias)
        for param in (self.weight, self.bias):
            if param is not None and param.grad is not None:
                param.grad.zero_()

    def decision_function(self, x):
        """Raw class scores for `x` (numpy array or tensor), as a numpy array."""
        x = self._as_param_tensor(x).to(self.weight.dtype)
        with torch.no_grad():
            return self(x).cpu().numpy()

    def predict(self, x):
        """Index of the highest-scoring class for each row of `x`."""
        return np.argmax(self.decision_function(x), axis=1)

    def predict_proba(self, x):
        """Probabilities via sklearn's `_predict_proba_lr` on the decision scores."""
        return self._predict_proba_lr(x)

    @property
    def coef_(self):
        """Weight matrix as a numpy array (out_features x in_features)."""
        return self.weight.data.cpu().numpy()

    @property
    def intercept_(self):
        """Bias as a numpy array; zeros when the layer has no bias."""
        if self.bias is None:
            return np.zeros(self.out_features, dtype=self.coef_.dtype)
        return self.bias.data.cpu().numpy()
    
    
    
    
    
class CKNLayer(nn.Conv2d):
    """A convolutional kernel network (CKN) layer.

    `forward` chains three steps:
      1. `_conv_layer`: x_out = ||x|| * kappa(Z^T x / ||x||) over every patch,
         where Z are the filters in `self.weight` (kept at unit norm by
         `normalize_`);
      2. `_pool_layer`: Gaussian pooling with stride `subsampling`;
      3. `_mult_layer`: multiplication by kappa(Z Z^T)^(-1/2).
    The filters are learned without supervision by `unsup_train_` (spherical
    k-means on sampled patches).

    NOTE: `self.kappa` is a lambda, so the module object itself cannot be
    serialized with the standard `pickle`; `state_dict()` is unaffected.
    """
    def __init__(self, in_channels, out_channels, kernel_size,
        padding="SAME", dilation=1, groups=1, subsampling=1, bias=False,
        kernel_func="exp", kernel_args=[0.5], kernel_args_trainable=False):
        """Define a CKN layer
        Args:
            in_channels, out_channels: as for nn.Conv2d.
            kernel_size: int filter size (used as kernel_size // 2 for "SAME").
            padding: "SAME" pads by kernel_size // 2; any other value means no padding.
            dilation, groups: forwarded to nn.Conv2d.
            subsampling: stride of the Gaussian pooling (<= 1 disables pooling).
            bias: if True, learn a patch offset `ckn_bias` that is subtracted
                from each patch before the kernel is applied.
            kernel_func: "exp" for k(x) = exp(alpha * (x - 1)); any other value
                selects k(x) = x ** 2.
            kernel_args: an iterable of parameters for the kernel function (a
                single int/float is accepted). For "exp", each value sigma is
                stored as alpha = 1 / sigma ** 2.
            kernel_args_trainable: store the kernel parameters as nn.Parameters.
        """
        if padding == "SAME":
            padding = kernel_size // 2
        else:
            padding = 0
        super(CKNLayer, self).__init__(in_channels, out_channels, kernel_size,
            stride=1, padding=padding, dilation=dilation, groups=groups, bias=False)
        self.normalize_()
        self.subsampling = subsampling

        self.patch_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]

        # kappa(Z Z^T)^(-1/2) depends only on the filters; it is cached in eval mode.
        self._need_lintrans_computed = True

        self.kernel_args_trainable = kernel_args_trainable
        self.kernel_func = kernel_func
        if isinstance(kernel_args, (int, float)):
            kernel_args = [kernel_args]
        if kernel_func == "exp":
            kernel_args = [1./kernel_arg ** 2 for kernel_arg in kernel_args]
        self.kernel_args = kernel_args
        if kernel_args_trainable:
            self.kernel_args = nn.ParameterList(
                [nn.Parameter(torch.Tensor([kernel_arg])) for kernel_arg in kernel_args])

        kernel_func = self._get_kernel(kernel_func)
        # self.kernel_args is looked up at call time so trainable values are used.
        self.kappa = lambda x: kernel_func(x, *self.kernel_args)

        # All-ones filter: convolving x ** 2 with it yields squared patch norms.
        self.register_buffer("ones",
            torch.ones(1, self.in_channels // self.groups, *self.kernel_size))
        self.init_pooling_filter()

        self.ckn_bias = None
        if bias:
            self.ckn_bias = nn.Parameter(
                torch.zeros(1, self.in_channels // self.groups, *self.kernel_size))

        # Zero-initialized rather than uninitialized memory so state_dict() is
        # deterministic; _compute_lintrans fills it before it is read.
        self.register_buffer("lintrans",
            torch.zeros(out_channels, out_channels))

    def _get_kernel(self, kernel_func):
        """Return the element-wise kernel function selected by `kernel_func`."""
        def exp(x, alpha):
            """Element wise non-linearity
            kernel_exp is defined as k(x)=exp(alpha * (x-1))
            return:
                same shape tensor as x
            """
            return torch.exp(alpha*(x - 1.))

        def poly(x, alpha=None):
            # The argument is ignored: this kernel is always x ** 2.
            return x.pow(2)
        return exp if kernel_func == "exp" else poly

    def _gaussian_filter_1d(self, size, sigma=None):
        """Create a normalized (sums to 1) 1D Gaussian filter of length `size`."""
        if size == 1:
            return torch.ones(1)
        if sigma is None:
            sigma = (size - 1.) / (2.*math.sqrt(2))
        m = (size - 1) / 2.
        filt = torch.arange(-m, m+1)
        filt = torch.exp(-filt.pow(2)/(2.*sigma*sigma))
        return filt/torch.sum(filt)

    def init_pooling_filter(self):
        """Register a (2s+1) x (2s+1) separable Gaussian filter per output channel."""
        size = 2 * self.subsampling + 1
        pooling_filter = self._gaussian_filter_1d(size, self.subsampling/math.sqrt(2)).view(-1, 1)
        pooling_filter = pooling_filter.mm(pooling_filter.t())
        pooling_filter = pooling_filter.expand(self.out_channels, 1, size, size).clone()
        self.register_buffer("pooling_filter", pooling_filter)

    def train(self, mode=True):
        """Set train/eval mode, invalidate the cached lintrans and return self."""
        super(CKNLayer, self).train(mode)
        self._need_lintrans_computed = True
        # nn.Module.train (and therefore eval) is expected to return the module.
        return self

    def _compute_lintrans(self):
        """Compute the linear transformation factor kappa(ZtZ)^(-1/2)
        In eval mode the result is cached in the `lintrans` buffer until the
        mode changes or the filters are retrained.
        Returns:
            lintrans: out_channels x out_channels
        """
        if not self._need_lintrans_computed:
            return self.lintrans
        lintrans = self.weight.view(self.out_channels, -1)
        lintrans = lintrans.mm(lintrans.t())
        lintrans = self.kappa(lintrans)
        lintrans = matrix_inverse_sqrt(lintrans)
        if not self.training:
            self._need_lintrans_computed = False
            self.lintrans.data = lintrans.data

        return lintrans

    def _conv_layer(self, x_in):
        """Convolution layer
        Compute x_out = ||x_in|| x kappa(Zt x_in/||x_in||)
        (with `ckn_bias`, x_in is replaced by x_in - ckn_bias patch-wise).
        Args:
            x_in: batch_size x in_channels x H x W
            self.filters: out_channels x in_channels x *kernel_size
            x_out: batch_size x out_channels x (H - kernel_size + 1) x (W - kernel_size + 1)
        """
        if self.ckn_bias is not None:
            # compute || x - b ||
            patch_norm_x = F.conv2d(x_in.pow(2), self.ones, bias=None,
                                    stride=1, padding=self.padding,
                                    dilation=self.dilation,
                                    groups=self.groups)
            patch_norm = patch_norm_x - 2 * F.conv2d(x_in, self.ckn_bias, bias=None,
                stride=1, padding=self.padding, dilation=self.dilation,
                groups=self.groups)
            patch_norm = patch_norm + self.ckn_bias.pow(2).sum()
            patch_norm = torch.sqrt(patch_norm.clamp(min=1e-6))

            x_out = super(CKNLayer, self).forward(x_in)
            bias = torch.sum(
                (self.weight * self.ckn_bias).view(self.out_channels, -1), dim=-1)
            bias = bias.view(1, self.out_channels, 1, 1)
            x_out = x_out - bias
            x_out = x_out / patch_norm.clamp(min=1e-6)
            x_out = patch_norm * self.kappa(x_out)
            return x_out

        patch_norm = torch.sqrt(F.conv2d(x_in.pow(2), self.ones, bias=None,
            stride=1, padding=self.padding, dilation=self.dilation,
            groups=self.groups).clamp(min=1e-6))

        x_out = super(CKNLayer, self).forward(x_in)
        x_out = x_out / patch_norm.clamp(min=1e-6)
        x_out = patch_norm * self.kappa(x_out)
        return x_out

    def _mult_layer(self, x_in, lintrans):
        """Multiplication layer
        Compute x_out = kappa(ZtZ)^(-1/2) x x_in at every spatial position.
        Args:
            x_in: batch_size x in_channels x H x W
            lintrans: in_channels x in_channels
            x_out: batch_size x in_channels x H x W
        """
        batch_size, in_c, H, W = x_in.size()
        # (in_c x in_c) @ (batch x in_c x HW): matmul broadcasts over the batch,
        # so no per-batch copy of lintrans is needed.
        x_out = torch.matmul(lintrans, x_in.reshape(batch_size, in_c, -1))
        return x_out.view(batch_size, in_c, H, W)

    def _pool_layer(self, x_in):
        r"""Pooling layer
        Compute I(z) = \sum_{z'} phi(z') x exp(-\beta_1 ||z'-z||_2^2)
        with the fixed Gaussian `pooling_filter` and stride `subsampling`.
        Args:
            x_in: batch_size x out_channels x H x W
        """
        if self.subsampling <= 1:
            return x_in
        x_out = F.conv2d(x_in, self.pooling_filter, bias=None,
            stride=self.subsampling, padding=self.subsampling,
            groups=self.out_channels)
        return x_out

    def forward(self, x_in):
        """Encode function for a CKN layer
        Args:
            x_in: batch_size x in_channels x H x W
        """
        x_out = self._conv_layer(x_in)
        x_out = self._pool_layer(x_out)
        lintrans = self._compute_lintrans()
        x_out = self._mult_layer(x_out, lintrans)
        return x_out

    def extract_2d_patches(self, x):
        """All kh x kw patches of `x` (stride 1, no padding or dilation,
        unlike the convolution in `_conv_layer`).
        x: batch_size x C x H x W
        out: (batch_size * nH * nW) x (C * kh * kw), rows ordered by image, then
             column position, then row position.
        """
        h, w = self.kernel_size
        return x.unfold(2, h, 1).unfold(3, w, 1).transpose(1, 3).contiguous().view(-1, self.patch_dim)

    def sample_patches(self, x_in, n_sampling_patches=1000):
        """Randomly sample patches (without replacement) from the given Tensor
        Args:
            x_in (batch_size x in_channels x H x W)
            n_sampling_patches (int): number of patches to sample
        Returns:
            patches: min(n_sampling_patches, n_patches) x patch_dim, where
            n_patches = batch_size x (H - kh + 1) x (W - kw + 1)
        """
        patches = self.extract_2d_patches(x_in)
        n_sampling_patches = min(patches.size(0), n_sampling_patches)
        # Pick random rows instead of a prefix: extraction order puts every patch
        # of the first image (starting with its first columns) first, so the
        # leading rows would be a heavily biased sample.
        indices = torch.randperm(patches.size(0))[:n_sampling_patches]
        return patches[indices.to(patches.device)]

    def unsup_train_(self, patches):
        """Unsupervised training for a CKN layer
        Args:
            patches: n x (in_channels x *kernel_size); modified in place
                (centered when a ckn_bias is used, then row-normalized).
        Updates:
            filters: out_channels x in_channels x *kernel_size
        """
        if self.ckn_bias is not None:
            print("estimating bias")
            m_patches = patches.mean(0)
            self.ckn_bias.data.copy_(m_patches.view_as(self.ckn_bias.data))
            patches -= m_patches
        patches = normalize_(patches)
        # Large patches: compute k-means similarities in chunks to bound memory.
        block_size = None if self.patch_dim < 1000 else 10 * self.patch_dim
        weight = spherical_kmeans(patches, self.out_channels, block_size=block_size)
        weight = weight.view_as(self.weight.data)
        self.weight.data.copy_(weight)
        self._need_lintrans_computed = True

    def normalize_(self):
        """Rescale each filter to unit L2 norm in place (norms clamped at 1e-6)."""
        norm = self.weight.data.view(
            self.out_channels, -1).norm(p=2, dim=-1).view(-1, 1, 1, 1)
        self.weight.data.div_(norm.clamp_(min=1e-6))

    def extra_repr(self):
        s = super(CKNLayer, self).extra_repr()
        s += ', subsampling={}'.format(self.subsampling)
        s += ', kernel=({}, {})'.format(self.kernel_func, self.kernel_args)
        return s

In [None]:
# utils

def create_dataset(train=True, dataugmentation=False):
    """Load an MNIST split as 3-channel 32x32 tensors normalized to [-1, 1].

    Each image is converted to RGB and resized to 32. With `dataugmentation`,
    a random 32x32 crop (4-pixel padding) and a random horizontal flip are
    applied to the resized PIL image. The image is then turned into a tensor
    and normalized with mean 0.5 / std 0.5 per channel. The data is downloaded
    to `data_path` if missing.

    Args:
        train: load the training split (True) or the test split (False).
        dataugmentation: enable the random crop / flip augmentation.
    """
    steps = [
        transforms.Lambda(lambda x: x.convert('RGB')),
        transforms.Resize(32),
    ]
    if dataugmentation:
        # Augmentations operate on PIL images, so they go before ToTensor; the
        # list is composed only at the end (a Compose cannot be concatenated).
        steps += [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    dataset = torchvision.datasets.MNIST(
        data_path,
        train=train,
        transform=transforms.Compose(steps),
        download=True,
    )
    return dataset

def count_parameters(model):
    """Total number of scalar entries in all parameters of `model` (Python int).

    Uses `numel()`, which is 1 for 0-dim parameters (np.prod of an empty
    shape would give the float 1.0).
    """
    return sum(param.numel() for param in model.parameters())

def normalize_(x, p=2, dim=-1):
    """Scale `x` in place so every slice along `dim` has unit p-norm.

    Norms smaller than 1e-6 are clamped to 1e-6 to avoid division by zero.
    Returns the same tensor object that was passed in.
    """
    denom = x.norm(p=p, dim=dim, keepdim=True).clamp(min=1e-6)
    return x.div_(denom)

def spherical_kmeans(x, n_clusters, max_iters=100, block_size=None, verbose=True, init=None):
    """Spherical kmeans: cluster rows of `x` by dot-product (cosine) similarity.

    Args:
        x (Tensor n_samples x n_features): data points; similarities are plain
            dot products, so rows are expected to be L2-normalized.
        n_clusters (int): number of clusters
        max_iters (int): maximum number of assignment/update rounds.
        block_size (int or None): rows per chunk when computing similarities
            (bounds memory); None or 0 processes all rows at once.
        verbose (bool): print the input shape and progress every 10 iterations.
        init (Tensor n_clusters x n_features or None): initial centroids (not
            modified); when None, n_clusters distinct rows of x are drawn at random.

    Returns:
        Tensor n_clusters x n_features of unit-norm centroids.

    Raises:
        ValueError: if init is None and x has fewer rows than n_clusters, or if
            init has the wrong shape.
    """
    n_samples, n_features = x.size()
    if verbose:
        print(x.shape)
    if init is None:
        if n_samples < n_clusters:
            raise ValueError("need at least n_clusters={} samples, got {}".format(
                n_clusters, n_samples))
        indices = torch.randperm(n_samples)[:n_clusters].to(x.device)
        clusters = x[indices]
    else:
        clusters = init.to(device=x.device, dtype=x.dtype).clone()
        if tuple(clusters.shape) != (n_clusters, n_features):
            raise ValueError("init must have shape ({}, {}), got {}".format(
                n_clusters, n_features, tuple(clusters.shape)))

    if not block_size:
        block_size = n_samples
    best_sim = x.new_empty(n_samples)
    assign = x.new_empty(n_samples, dtype=torch.long)
    prev_sim = float("inf")

    for n_iter in range(max_iters):
        # assign data points to their most similar centroid, chunk by chunk
        for start in range(0, n_samples, block_size):
            stop = min(start + block_size, n_samples)
            best_sim[start:stop], assign[start:stop] = x[start:stop].mm(clusters.t()).max(dim=-1)
        sim = best_sim.mean().item()
        if verbose and (n_iter + 1) % 10 == 0:
            print("Spherical kmeans iter {}, objective value {}".format(
                n_iter + 1, sim))

        # update all non-empty centroids at once from per-cluster sums and counts
        counts = torch.bincount(assign, minlength=n_clusters)
        sums = x.new_zeros(n_clusters, n_features).index_add_(0, assign, x)
        nonempty = counts > 0
        means = sums[nonempty] / counts[nonempty].unsqueeze(1).to(x.dtype)
        clusters[nonempty] = means / means.norm(dim=1, keepdim=True).clamp(min=1e-6)
        # re-seed each empty centroid with the currently worst-represented point;
        # marking it as well-represented keeps two empty clusters from taking it
        for j in torch.nonzero(~nonempty).flatten().tolist():
            idx = best_sim.argmin()
            clusters[j] = x[idx]
            best_sim[idx] = 1.

        if abs(prev_sim - sim) / (abs(sim) + 1e-20) < 1e-6:
            break
        prev_sim = sim
    return clusters


class MatrixInverseSqrt(torch.autograd.Function):
    """Matrix inverse square root for a symmetric positive (semi-)definite matrix.

    With input = V diag(e) V^T, forward returns
    V diag(1 / (sqrt(max(e, 0)) + eps)) V^T; eps regularizes small eigenvalues.
    """
    @staticmethod
    def forward(ctx, input, eps=1e-2):
        device = input.device
        # The eigendecomposition runs on CPU; results are moved back to the
        # input's device.
        e, v = torch.linalg.eigh(input.cpu())
        e, v = e.to(device), v.to(device)
        e_sqrt = e.clamp(min=0).sqrt().add(eps)
        ctx.save_for_backward(e_sqrt, v)
        return v.mm(torch.diag(e_sqrt.reciprocal()).mm(v.t()))

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors replaces the deprecated saved_variables accessor.
        e_sqrt, v = ctx.saved_tensors
        ei = e_sqrt.expand_as(v)
        ej = e_sqrt.view(-1, 1).expand_as(v)
        # Divided differences of g(e) = 1/sqrt(e): with s = sqrt(e),
        # (g(e_i) - g(e_j)) / (e_i - e_j) = -1 / (s_i s_j (s_i + s_j)), and the
        # diagonal limit g'(e_i) = -1 / (2 s_i^3) matches. Here s includes eps,
        # so the gradient is exact only for eps = 0.
        f = torch.reciprocal((ei + ej) * ei * ej)
        grad_input = -v.mm((f * (v.t().mm(grad_output.mm(v)))).mm(v.t()))
        return grad_input, None


def matrix_inverse_sqrt(input, eps=1e-2):
    """Differentiable inverse square root of a symmetric PSD matrix.

    Thin functional wrapper around the MatrixInverseSqrt autograd Function;
    `eps` is added to the square roots of the eigenvalues before inversion.
    """
    inverse_sqrt_fn = MatrixInverseSqrt.apply
    return inverse_sqrt_fn(input, eps)

In [None]:
# models


class CKNSequential(nn.Module):
    """A stack of CKNLayer modules applied one after the other.

    Args:
        in_channels: channels of the input images.
        out_channels_list, kernel_sizes, subsamplings: one entry per layer.
        kernel_funcs: per-layer kernel names (default "exp" for every layer).
        kernel_args_list: per-layer kernel arguments (default 0.5 for every layer).
        kernel_args_trainable: forwarded to every CKNLayer.
        **kwargs: extra CKNLayer arguments (e.g. padding, bias).
    """
    def __init__(self, in_channels, out_channels_list, kernel_sizes,
                 subsamplings, kernel_funcs=None, kernel_args_list=None,
                 kernel_args_trainable=False, **kwargs):

        assert len(out_channels_list) == len(kernel_sizes) == len(subsamplings), "incompatible dimensions"
        super(CKNSequential, self).__init__()

        self.n_layers = len(out_channels_list)
        self.in_channels = in_channels
        self.out_channels = out_channels_list[-1]

        ckn_layers = []
        for i in range(self.n_layers):
            kernel_func = "exp" if kernel_funcs is None else kernel_funcs[i]
            kernel_args = 0.5 if kernel_args_list is None else kernel_args_list[i]
            ckn_layers.append(CKNLayer(
                in_channels, out_channels_list[i], kernel_sizes[i],
                subsampling=subsamplings[i], kernel_func=kernel_func,
                kernel_args=kernel_args,
                kernel_args_trainable=kernel_args_trainable, **kwargs))
            # Each layer consumes the previous layer's output channels.
            in_channels = out_channels_list[i]

        self.ckn_layers = nn.Sequential(*ckn_layers)

    def __getitem__(self, idx):
        return self.ckn_layers[idx]

    def __len__(self):
        return len(self.ckn_layers)

    def __iter__(self):
        return iter(self.ckn_layers)

    def forward_at(self, x, i=0):
        """Apply only layer `i` to `x`."""
        assert x.size(1) == self.ckn_layers[i].in_channels, "bad dimension"
        return self.ckn_layers[i](x)

    def forward(self, x):
        return self.ckn_layers(x)

    def representation(self, x, n=0):
        """Output of the first `n` layers (`n == -1` means all layers)."""
        if n == -1:
            n = self.n_layers
        for i in range(n):
            x = self.forward_at(x, i)
        return x

    def normalize_(self):
        """Re-normalize the filters of every layer to unit norm."""
        for module in self.ckn_layers:
            module.normalize_()

    def _collect_patches(self, layer_idx, data_loader, n_sampling_patches,
                         n_patches_per_batch, use_cuda, top_layers):
        """Gather up to n_sampling_patches input patches of layer `layer_idx`."""
        ckn_layer = self.ckn_layers[layer_idx]
        patches = torch.empty(n_sampling_patches, ckn_layer.patch_dim)
        if use_cuda:
            patches = patches.cuda()
        n_patches = 0
        for data, _ in data_loader:
            if use_cuda:
                data = data.cuda()
            if top_layers is not None:
                data = top_layers(data)
            # Input of this layer = output of the (already trained) layers below.
            data = self.representation(data, layer_idx)
            data_patches = ckn_layer.sample_patches(data, n_patches_per_batch)
            size = min(data_patches.size(0), n_sampling_patches - n_patches)
            patches[n_patches: n_patches + size] = data_patches[:size]
            n_patches += size
            if n_patches >= n_sampling_patches:
                break
        return patches[:n_patches]

    def unsup_train_(self, data_loader, n_sampling_patches=100000, use_cuda=False, top_layers=None):
        """Initialize every layer's filters without supervision, bottom-up.

        For layer i, each batch is passed through `top_layers` (if given) and
        layers 0..i-1, up to `n_sampling_patches` patches are sampled, and the
        layer's own `unsup_train_` runs on them.

        Args:
            data_loader: iterable of (data, target) batches, data being
                batch x C x H x W.
            n_sampling_patches: total number of patches collected per layer.
            use_cuda: move the module and the batches to the GPU.
            top_layers: optional module applied to each batch before this stack.
        """
        self.train(False)
        if use_cuda:
            self.cuda()
        # Spread the patch budget evenly over the batches (ceil division); fall
        # back to 1000 per batch when the loader has no usable length.
        try:
            n_batches = len(data_loader)
        except TypeError:
            n_batches = 0
        if n_batches > 0:
            n_patches_per_batch = (n_sampling_patches + n_batches - 1) // n_batches
        else:
            n_patches_per_batch = 1000

        with torch.no_grad():
            for i, ckn_layer in enumerate(self.ckn_layers):
                print()
                print('-------------------------------------')
                print('   TRAINING LAYER {}'.format(i + 1))
                print('-------------------------------------')
                patches = self._collect_patches(
                    i, data_loader, n_sampling_patches, n_patches_per_batch,
                    use_cuda, top_layers)
                print("total number of patches: {}".format(patches.size(0)))
                ckn_layer.unsup_train_(patches)
    
class CKNet(nn.Module):
    """CKN feature extractor (`features`) followed by a `Linear` classifier.

    The classifier input size is the number of CKN output channels times the
    spatial area left after all pooling layers.

    Args:
        nclass: number of classes.
        in_channels, out_channels_list, kernel_sizes, subsamplings,
        kernel_funcs, kernel_args_list, kernel_args_trainable, **kwargs:
            forwarded to CKNSequential.
        image_size: side length of the square input images.
        fit_bias, alpha, maxiter: forwarded to the Linear classifier.
    """
    def __init__(self, nclass, in_channels, out_channels_list, kernel_sizes,
                 subsamplings, kernel_funcs=None, kernel_args_list=None,
                 kernel_args_trainable=False, image_size=32,
                 fit_bias=True, alpha=0.0, maxiter=1000, **kwargs):
        super(CKNet, self).__init__()
        self.features = CKNSequential(
            in_channels, out_channels_list, kernel_sizes,
            subsamplings, kernel_funcs, kernel_args_list,
            kernel_args_trainable, **kwargs)

        out_features = out_channels_list[-1]
        # Each pooling maps a side H to (H - 1) // s + 1; composing them equals
        # one division by the product of all strides.
        factor = 1
        for s in subsamplings:
            factor *= s
        factor = (image_size - 1) // factor + 1
        self.out_features = factor * factor * out_features
        self.nclass = nclass

        self.initialize_scaler()
        self.classifier = Linear(
            self.out_features, nclass, fit_bias=fit_bias, alpha=alpha, maxiter=maxiter)

    def initialize_scaler(self, scaler=None):
        """Hook for feature scaling; currently a no-op (no `scaler` is set)."""
        pass

    def forward(self, input):
        features = self.representation(input)
        return self.classifier(features)

    def representation(self, input):
        """Flattened CKN features, passed through `self.scaler` if one is set."""
        features = self.features(input).view(input.shape[0], -1)
        if hasattr(self, 'scaler'):
            features = self.scaler(features)
        return features

    def unsup_train_ckn(self, data_loader, n_sampling_patches=1000000,
                        use_cuda=False):
        """Unsupervised initialization of all CKN layers' filters."""
        self.features.unsup_train_(data_loader, n_sampling_patches, use_cuda=use_cuda)

    def unsup_train_classifier(self, data_loader, criterion=None, use_cuda=False):
        """Fit the classifier on the CKN features of every sample in data_loader."""
        encoded_train, encoded_target = self.predict(
            data_loader, only_representation=True, use_cuda=use_cuda)
        self.classifier.fit(encoded_train, encoded_target, criterion)

    def predict(self, data_loader, only_representation=False, use_cuda=False):
        """Run the model in eval mode over data_loader, collecting results on CPU.

        Returns:
            (output, target): output is n x d, holding features when
            `only_representation` is True and class scores otherwise; target
            holds the matching labels. n is the number of samples actually
            yielded (at most len(data_loader.dataset)).

        Raises:
            ValueError: if data_loader yields no batches.
        """
        self.eval()
        if use_cuda:
            self.cuda()
        n_samples = len(data_loader.dataset)
        output = target_output = None
        batch_start = 0
        for data, target in data_loader:
            batch_size = data.shape[0]
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                if only_representation:
                    batch_out = self.representation(data).cpu()
                else:
                    batch_out = self(data).cpu()
            if output is None:
                # Preallocate once the output width is known.
                output = batch_out.new_empty(n_samples, batch_out.shape[-1])
                target_output = target.new_empty(n_samples)
            output[batch_start:batch_start+batch_size] = batch_out
            target_output[batch_start:batch_start+batch_size] = target
            batch_start += batch_size
        if output is None:
            raise ValueError("data_loader yielded no batches")
        # Drop rows that were never filled (e.g. when the loader uses drop_last).
        return output[:batch_start], target_output[:batch_start]

    def normalize_(self):
        """Re-normalize the CKN filters to unit norm."""
        self.features.normalize_()

    def print_norm(self):
        """Print the Frobenius norm of each CKN layer's filters and of the classifier weight."""
        norms = []
        with torch.no_grad():
            for module in self.features:
                norms.append(module.weight.norm().item())
            norms.append(self.classifier.weight.norm().item())
        print(norms)

class SupCKNetMnist10_5(CKNet):
    """Five-layer CKNet configuration used for MNIST in this notebook.

    3 input channels, 10 classes, 128 filters per layer. 3x3 'exp' layers
    alternate with 1x1 'poly' layers, and the pooling strides are 2, 1, 2, 1, 3.
    The classifier is fitted with a bias and up to 5000 L-BFGS-B iterations.
    """
    def __init__(self, alpha=0.0, **kwargs):
        layer_specs = [
            # (kernel_size, n_filters, subsampling, kernel_func, kernel_arg)
            (3, 128, 2, 'exp', 0.5),
            (1, 128, 1, 'poly', 2),
            (3, 128, 2, 'exp', 0.5),
            (1, 128, 1, 'poly', 2),
            (3, 128, 3, 'exp', 0.5),
        ]
        kernel_sizes, filters, subsamplings, kernel_funcs, kernel_args_list = (
            list(column) for column in zip(*layer_specs))
        super().__init__(
            10, 3, filters, kernel_sizes, subsamplings,
            kernel_funcs=kernel_funcs, kernel_args_list=kernel_args_list,
            fit_bias=True, alpha=alpha, maxiter=5000, **kwargs)


In [None]:
# loss function
class HingeLoss(_Loss):
    """Multi-class one-vs-all (squared) hinge loss.

    Each class-index target is expanded to a +1/-1 one-hot vector t, and the
    per-sample loss is sum_k l(1 - t_k * input_k), with l = relu, or
    0.5 * relu ** 2 when `squared` is True.

    Args:
        nclass: number of classes (width of the one-hot targets).
        weight: optional per-class weights, broadcast against (N, nclass).
        size_average, reduce: legacy arguments forwarded to `_Loss`.
        reduction: 'none', 'mean' (alias 'elementwise_mean') or 'sum'.
        pos_weight: optional weight on the positive (true-class) terms; the
            negative terms keep weight 1.
        squared: use the squared hinge.
    """
    def __init__(self, nclass=10, weight=None, size_average=None, reduce=None,
                 reduction='elementwise_mean', pos_weight=None, squared=True):
        super(HingeLoss, self).__init__(size_average, reduce, reduction)
        self.nclass = nclass
        self.squared = squared
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)

    def forward(self, input, target):
        """input: N x nclass scores; target: N class indices (int64).

        Raises:
            ValueError: on a batch-size mismatch or an unknown reduction.
        """
        if target.size(0) != input.size(0):
            raise ValueError(
                "Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))
        one_hot = F.one_hot(target, num_classes=self.nclass).to(input.dtype)
        signs = 2 * one_hot - 1
        loss = F.relu(1. - signs * input)
        if self.squared:
            loss = 0.5 * loss ** 2
        if self.weight is not None:
            loss = loss * self.weight
        if self.pos_weight is not None:
            # pos_weight on the true-class entry, 1 elsewhere (uses the 0/1
            # one-hot mask, not the raw class indices).
            loss = loss * (1 + (self.pos_weight - 1) * one_hot)
        loss = loss.sum(dim=-1)
        if self.reduction == 'none':
            return loss
        # 'mean' is PyTorch's standard name (also produced by the legacy
        # size_average/reduce arguments); 'elementwise_mean' is its old alias.
        if self.reduction in ('mean', 'elementwise_mean'):
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        raise ValueError("unknown reduction: {!r}".format(self.reduction))

In [None]:
# training
def _run_phase(model, loader, criterion, optimizer, train, use_cuda=False):
    """One pass over `loader`; returns (mean per-sample loss, accuracy).

    When `train` is True each batch takes an optimizer step, after which the
    CKN filters are re-normalized with `model.normalize_()`; otherwise no
    autograd graph is built.
    """
    running_loss = 0.0
    running_acc = 0
    for data, target in loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        size = data.size(0)
        if train:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            model.normalize_()
        else:
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
        pred = output.detach().argmax(dim=1)
        running_loss += loss.item() * size
        running_acc += torch.sum(pred == target).item()
    n_samples = len(loader.dataset)
    return running_loss / n_samples, running_acc / n_samples


def sup_train(model, data_loader, n_epochs=15, n_sampling_patches=150000, use_cuda=False):
    """Train a CKNet end to end and keep the weights of the best validation epoch.

    The CKN filters are first initialized without supervision on
    data_loader['init']. Each epoch then refits the linear classifier on
    data_loader['train'], runs one SGD epoch over 'train' (hinge loss), and
    evaluates on data_loader['val'].

    Args:
        model: CKNet-like model (needs unsup_train_ckn, unsup_train_classifier
            and normalize_).
        data_loader: dict with 'init', 'train' and 'val' DataLoaders.
        n_epochs: number of epochs; the LR milestones (60, 85, 100) assume a
            full run of 105 epochs.
        n_sampling_patches: patches sampled per layer for the unsupervised init.
        use_cuda: run the model and batches on the GPU.

    Returns:
        Best validation accuracy as a fraction in [0, 1].
    """
    criterion = HingeLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [60, 85, 100], gamma=0.1)
    if use_cuda:
        model.cuda()

    print("Initialing CKN")
    tic = timer()
    model.unsup_train_ckn(
        data_loader['init'], n_sampling_patches, use_cuda=use_cuda)
    toc = timer()
    print("Finished, elapsed time: {:.2f}min".format((toc - tic)/60))

    best_loss = float('inf')
    best_acc = 0
    best_epoch = None
    best_weights = None
    for epoch in range(n_epochs):
        print('Epoch {}/{}'.format(epoch + 1, n_epochs))
        print('-' * 10)

        model.train(False)
        tic = timer()
        model.unsup_train_classifier(
            data_loader['train'], criterion=criterion, use_cuda=use_cuda)
        toc = timer()
        print('Last layer trained, elapsed time: {:.2f}s'.format(toc - tic))

        for phase in ['train', 'val']:
            if phase == 'train':
                # MultiStepLR is stepped once per epoch (after the first one).
                if epoch > 0:
                    lr_scheduler.step()
                print("current LR: {}".format(
                            optimizer.param_groups[0]['lr']))
                model.train()
            else:
                print("Evaluating...")
                model.eval()

            tic = timer()
            epoch_loss, epoch_acc = _run_phase(
                model, data_loader[phase], criterion, optimizer,
                train=(phase == 'train'), use_cuda=use_cuda)
            toc = timer()

            print('{} Loss: {:.4f} Acc: {:.2f}% Elapsed time: {:.2f}s'.format(
                    phase, epoch_loss, epoch_acc * 100, toc - tic))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_loss = epoch_loss
                best_epoch = epoch
                best_weights = copy.deepcopy(model.state_dict())
        print()

    if best_weights is None:
        # Validation accuracy never rose above 0: keep the final weights.
        print('No epoch improved validation accuracy; keeping final weights.')
    else:
        print('Best epoch: {}'.format(best_epoch + 1))
        print('Best val Acc: {:.4f}'.format(best_acc))
        print('Best val loss: {:.4f}'.format(best_loss))
        model.load_state_dict(best_weights)

    return best_acc

In [None]:
# Seed torch before any weights are created or batches are shuffled.
torch.manual_seed(1234)

# MNIST splits: 'init' drives unsupervised filter initialization, 'train' the
# supervised epochs, and the test split is used for validation.
init_dset = create_dataset()
train_dset = create_dataset(dataugmentation=False)
test_dset = create_dataset(train=False)

init_loader = DataLoader(init_dset, batch_size=64, shuffle=False, num_workers=0)
train_loader = DataLoader(train_dset, batch_size=128, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dset, batch_size=128, shuffle=False, num_workers=0)

model = SupCKNetMnist10_5(alpha=0.001)
print(model)
n_params = count_parameters(model)
print('number of paramters: {}'.format(n_params))

data_loader = dict(init=init_loader, train=train_loader, val=test_loader)
tic = timer()
score = sup_train(model, data_loader)
toc = timer()
training_time = (toc - tic) / 60
print("Final accuracy: {:6.2f}%, elapsed time: {:.2f}min".format(
    score * 100, training_time))


In [None]:

import csv
import os

# Persist the final metrics and the trained weights under `outpath`.
table = {'acc': score, 'training time': training_time}
# newline='' is required by the csv module for files it writes (otherwise
# platforms with \r\n line endings get blank rows).
with open(os.path.join(outpath, 'metric.csv'), 'w', newline='') as f:
    w = csv.DictWriter(f, table.keys())
    w.writeheader()
    w.writerow(table)

torch.save({
    'state_dict': model.state_dict()},
    os.path.join(outpath, 'model.pkl'))

##### 

In [None]:
# init_dset = create_dataset()


# init_loader = DataLoader(
#         init_dset, batch_size=64, shuffle=False, num_workers=0)




# import matplotlib.pyplot as plt

# # fig, axs = plt.subplots(nrows=int(np.ceil(np.sqrt(3))), ncols=int(np.ceil(np.sqrt(3))), figsize= (30,30))
# # fig.suptitle('All reconstructed faces')
# # axs = axs.flatten()
# # print(test_loader.size())
# for i, (data, target) in enumerate(init_loader):
# #     print(target)
# #     print(data)
#     plt.imshow(data[0].T)
#     break
# #     print(data.size())
# # #     face_plot = np.reshape(data, (200, 200)).T
# #     ax.imshow(data[0], cmap='viridis')
# #     ax.set_xticks([])
# #     ax.set_yticks([])
# #     ax.set_title(target)

In [None]:
# test_dset

In [None]:
# init_dset = create_dataset()
# init_loader = DataLoader(
#         init_dset, batch_size=64, shuffle=False, num_workers=0)

# for data, _ in init_loader:
#     print('x')
#     break

In [None]:
# n_patches_per_batch = (150 + len(init_loader) - 1) // len(init_loader) 
# n_patches_per_batch

In [None]:
# train_loader = DataLoader(
#         train_dset, batch_size=128, shuffle=True, num_workers=4)
# len(train_loader)

In [None]:
# def create_dataset2( train=True, dataugmentation=False):
#     # load dataset

#     mean_pix = [x/255.0 for x in [125.3, 123.0, 113.9]]
#     std_pix = [x/255.0 for x in [63.0, 62.1, 66.7]]
#     tr = [transforms.ToTensor(), transforms.Normalize(mean=mean_pix, std=std_pix)]
#     if dataugmentation:
#         dt = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
#         tr = dt + tr
#     dataset = torchvision.datasets.CIFAR10(
#         './data/',
#         train=train,
#         transform=transforms.Compose(tr),
#         download=True,
#     )
#     return dataset
    
# init_dset2 = create_dataset2()

# init_loader2 = DataLoader(
#         init_dset, batch_size=64, shuffle=False, num_workers=2)
# len(init_loader2)

In [None]:
# len(init_loader2)

In [None]:
# for data, _ in init_loader:
#     print(data.shape)
#     print(data)
#     print(data.unfold(2, 3, 1).unfold(3, 3, 1).transpose(1, 3).contiguous().shape)


#     y = data.unfold(2, 3, 1).unfold(3, 3, 1).transpose(1, 3).contiguous().view(-1, 27)


#     break

In [None]:
# for data, _ in init_loader2:
#     print(data.shape)
#     print(data)
    
#     y = data.unfold(2, 3, 1).unfold(3, 3, 1).transpose(1, 3).contiguous().view(-1, 27)
#     print(data.unfold(2, 3, 1).unfold(3, 3, 1).transpose(1, 3).contiguous().shape)
#     print(y.shape)
#     print(y)
#     break

### Bibliography:
https://github.com/Teo777Andrei/Image-Kernel<br>
https://github.com/SaturdayGenfo/convolutionalKernelNet<br>
https://github.com/ryanaleksander/kernel-convolution<br>
https://github.com/cjones6/yesweckn<br>
https://github.com/lemonhu/RE-CNN-pytorch<br>
https://github.com/claying/CKN-Pytorch-image <br>
https://github.com/YF-W/MU-Net<br>
https://github.com/hsouri/DCNN_SVM<br>