# In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import utils
import scipy.sparse as sparse
import time
from sklearn.preprocessing import normalize
from sklearn.neighbors import kneighbors_graph
import random
from tqdm import tqdm
import os
import csv
import time
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from scipy.io import loadmat
import scipy
from sklearn.utils import shuffle

class MLP(nn.Module):
    """Fully connected network with ReLU hidden layers and a tanh output.

    The hidden stack is stored in ``self.layers`` as alternating
    ``nn.Linear`` / ``nn.ReLU`` modules; the final projection lives in
    ``self.out_layer`` and its output is squashed with ``tanh``.

    Args:
        input_dims: Size of each input feature vector.
        hid_dims: Sequence of hidden layer widths (at least one entry).
        out_dims: Size of the produced embedding.
        kaiming_init: When true, re-initialise the weights through
            ``reset_parameters`` instead of keeping PyTorch's defaults.
    """

    def __init__(self, input_dims, hid_dims, out_dims, kaiming_init=False):
        super().__init__()
        self.input_dims = input_dims
        self.hid_dims = hid_dims
        self.output_dims = out_dims

        # First hidden layer maps the input; the rest chain consecutive widths.
        stack = [nn.Linear(self.input_dims, self.hid_dims[0]), nn.ReLU()]
        for fan_in, fan_out in zip(self.hid_dims[:-1], self.hid_dims[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.layers = nn.ModuleList(stack)

        self.out_layer = nn.Linear(self.hid_dims[-1], self.output_dims)
        if kaiming_init:
            self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform hidden weights, Xavier-uniform output weights, zero biases."""
        hidden_linears = (m for m in self.layers if isinstance(m, nn.Linear))
        for linear in hidden_linears:
            init.kaiming_uniform_(linear.weight)
            init.zeros_(linear.bias)
        init.xavier_uniform_(self.out_layer.weight)
        init.zeros_(self.out_layer.bias)

    def forward(self, x):
        """Run ``x`` through the hidden stack and return the tanh of the output layer."""
        out = x
        for module in self.layers:
            out = module(out)
        # In-place tanh on the freshly produced output-layer activations.
        return self.out_layer(out).tanh_()

class SENet(nn.Module):
    """Pair of MLPs producing query/key embeddings and their dot-product coefficients.

    ``net_q`` embeds queries and ``net_k`` embeds keys; both are built with the
    same dimensions. The coefficient matrix is the matrix product of the query
    embeddings with the transposed key embeddings.

    Args:
        input_dims: Size of each input feature vector.
        hid_dims: Sequence of hidden layer widths passed to both MLPs.
        out_dims: Embedding size produced by both MLPs.
        kaiming_init: Forwarded to both MLPs.
    """

    def __init__(self, input_dims, hid_dims, out_dims, kaiming_init=True):
        super().__init__()
        self.input_dims = input_dims
        self.hid_dims = hid_dims
        self.out_dims = out_dims
        self.kaiming_init = kaiming_init

        mlp_kwargs = dict(input_dims=self.input_dims,
                          hid_dims=self.hid_dims,
                          out_dims=self.out_dims,
                          kaiming_init=self.kaiming_init)
        # Query network is created first, then the key network.
        self.net_q = MLP(**mlp_kwargs)
        self.net_k = MLP(**mlp_kwargs)

    def query_embedding(self, queries):
        """Embed ``queries`` with the query network."""
        return self.net_q(queries)

    def key_embedding(self, keys):
        """Embed ``keys`` with the key network."""
        return self.net_k(keys)

    def get_coeff(self, q_emb, k_emb):
        """Return ``q_emb @ k_emb^T`` for 2-D embedding matrices."""
        return torch.mm(q_emb, k_emb.t())

    def forward(self, queries, keys):
        """Embed both inputs and return their coefficient matrix."""
        q_emb = self.query_embedding(queries)
        k_emb = self.key_embedding(keys)
        return self.get_coeff(q_emb=q_emb, k_emb=k_emb)


def regularizer(c, lmbd=1.0):
    """Elastic-net style penalty on the coefficient tensor ``c``.

    Returns ``lmbd * sum(|c|) + (1 - lmbd) / 2 * sum(c ** 2)`` as a tensor.
    """
    l1_term = c.abs().sum()
    l2_term = c.pow(2).sum()
    return lmbd * l1_term + (1.0 - lmbd) / 2.0 * l2_term


def get_sparse_rep(senet, data, batch_size=10, chunk_size=100, non_zeros=1000):
    """Build a sparse N x N coefficient matrix from a trained network.

    For every row block of ``batch_size`` samples the full row of
    coefficients ``get_coeff(query_embedding(rows), key_embedding(all))`` is
    assembled chunk by chunk, the self-coefficient (diagonal entry) is set to
    zero, and only the ``non_zeros`` entries of largest magnitude are kept.

    Args:
        senet: Module exposing ``query_embedding``, ``key_embedding`` and
            ``get_coeff``. It is switched to ``eval()`` mode and left there.
        data: 2-D tensor with one sample per row.
        batch_size: Number of query rows processed at once; must divide N.
        chunk_size: Number of key rows processed at once; must divide N.
        non_zeros: Entries kept per row (capped at N).

    Returns:
        ``scipy.sparse.csr_matrix`` of shape (N, N) with ``non_zeros`` stored
        entries per row (column indices inside a row are not sorted).

    Raises:
        ValueError: If ``batch_size`` or ``chunk_size`` is not a positive
            factor of the dataset size.
    """
    N, _ = data.shape

    # Validate before allocating anything.
    if batch_size <= 0 or chunk_size <= 0:
        raise ValueError("batch_size and chunk_size must be positive.")
    if N % batch_size != 0:
        raise ValueError("batch_size should be a factor of dataset size.")
    if N % chunk_size != 0:
        raise ValueError("chunk_size should be a factor of dataset size.")
    if N == 0:
        # Nothing to embed; np.concatenate below cannot handle an empty list.
        return sparse.csr_matrix((0, 0))

    non_zeros = min(N, non_zeros)

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(senet.parameters(), None)
    device = first_param.device if first_param is not None else data.device

    # Row buffer reused for every batch; every column is overwritten each time.
    C = torch.empty([batch_size, N])
    local_rows = torch.arange(batch_size)

    values = []
    col_indices = []
    with torch.no_grad():
        senet.eval()
        for i in range(N // batch_size):
            row_start = i * batch_size
            q = senet.query_embedding(data[row_start:row_start + batch_size].to(device))
            for j in range(N // chunk_size):
                col_start = j * chunk_size
                k = senet.key_embedding(data[col_start:col_start + chunk_size].to(device))
                C[:, col_start:col_start + chunk_size] = senet.get_coeff(q, k).cpu()

            # Drop each sample's coefficient with itself before ranking.
            C[local_rows, local_rows + row_start] = 0.0
            _, index = torch.topk(torch.abs(C), dim=1, k=non_zeros)

            values.append(C.gather(1, index).reshape(-1).cpu().numpy())
            col_indices.append(index.reshape(-1).cpu().numpy())

    values = np.concatenate(values, axis=0)
    col_indices = np.concatenate(col_indices, axis=0)
    # Every row stores exactly `non_zeros` entries.
    indptr = np.arange(N + 1) * non_zeros

    return sparse.csr_matrix((values, col_indices, indptr), shape=(N, N))


def get_knn_Aff(C_sparse_normalized, k=3, mode='symmetric', csn_array=None):
    """Append the dense k-NN connectivity graph of the coefficient rows to a list.

    Args:
        C_sparse_normalized: Row-normalised coefficient matrix; each row is
            treated as a feature vector by ``kneighbors_graph``.
        k: Number of neighbours per sample.
        mode: Accepted for interface compatibility; not used by this
            implementation.
        csn_array: List that receives the new dense graph. A new list is
            created when omitted, so the function no longer depends on a
            module-level ``csn_array`` variable existing.

    Returns:
        The list (``csn_array`` or the newly created one) with the dense
        connectivity matrix appended.
    """
    C_knn = kneighbors_graph(C_sparse_normalized, k, mode='connectivity', include_self=False, n_jobs=10)
    if csn_array is None:
        csn_array = []
    csn_array.append(C_knn.toarray())
    return csn_array


def evaluate(csn_array,senet, data, labels, num_subspaces, spectral_dim, non_zeros=1000, n_neighbors=3,
             batch_size=10000, chunk_size=10000, affinity='nearest_neighbor', knn_mode='symmetric'):
    """Compute the k-NN graph of the learned coefficients and collect it.

    The sparse coefficient matrix from ``get_sparse_rep`` is row-normalised,
    cast to float32, and turned into a dense k-NN connectivity graph that is
    appended to ``csn_array``.

    Args:
        csn_array: List accumulating one dense graph per call (``None``
            starts a new list).
        senet: Trained network passed to ``get_sparse_rep``.
        data: Samples to evaluate, one per row.
        labels, num_subspaces, spectral_dim, affinity: Accepted for interface
            compatibility; not used by this implementation (no clustering
            metrics are computed here).
        non_zeros: Coefficients kept per row.
        n_neighbors: Neighbours per sample in the k-NN graph.
        batch_size, chunk_size: Block sizes forwarded to ``get_sparse_rep``.
        knn_mode: Forwarded to ``get_knn_Aff`` as ``mode``.

    Returns:
        The list of dense k-NN connectivity graphs, including the new one.
    """
    C_sparse = get_sparse_rep(senet=senet, data=data, batch_size=batch_size,
                              chunk_size=chunk_size, non_zeros=non_zeros)
    C_sparse_normalized = normalize(C_sparse).astype(np.float32)
    # Pass the accumulator explicitly rather than relying on a global name.
    return get_knn_Aff(C_sparse_normalized, k=n_neighbors, mode=knn_mode, csn_array=csn_array)


def same_seeds(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Disable autotuning and request deterministic cuDNN kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


if __name__ == "__main__":
    # Script entry point: trains a SENet on the ORL training arrays, then
    # computes one k-NN connectivity graph per evaluation batch and saves the
    # graphs, the batched data and the batched labels as .npy files.
    # Everything stays at module level (no main() function) so the variables
    # below remain module globals, exactly as in the original script.

    # argparse and pickle are used below but are not imported at file level.
    import argparse
    import pickle

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="ORL")
    parser.add_argument('--num_subspaces', type=int, default=40)
    parser.add_argument('--gamma', type=float, default=200.0)
    parser.add_argument('--lmbd', type=float, default=0.9)
    # nargs='+' so a command-line value is parsed into a list like the default.
    parser.add_argument('--hid_dims', type=int, nargs='+', default=[1024, 1024, 1024])
    parser.add_argument('--out_dims', type=int, default=1024)
    parser.add_argument('--total_iters', type=int, default=100000)
    parser.add_argument('--save_iters', type=int, default=200000)
    parser.add_argument('--eval_iters', type=int, default=200000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_min', type=float, default=0.0)
    parser.add_argument('--batch_size', type=int, default=400)
    parser.add_argument('--chunk_size', type=int, default=40000)
    parser.add_argument('--non_zeros', type=int, default=40000)
    parser.add_argument('--n_neighbors', type=int, default=3)
    parser.add_argument('--spectral_dim', type=int, default=15)
    parser.add_argument('--affinity', type=str, default="nearest_neighbor")
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Report the values actually in use (identical text for the defaults).
    fit_msg = "Experiments on {}, numpy_seed={}, total_iters={}, lambda={}, gamma={}".format(
        args.dataset, args.seed, args.total_iters, args.lmbd, args.gamma)
    print(fit_msg)

    same_seeds(args.seed)

    full_samples = np.load( '../../../datasets/ORL/orltraindataforSb.npy')
    full_labels = np.load( '../../../datasets/ORL/orltrainlabelsforSb.npy')
    full_labels = full_labels.reshape(len(full_labels), )

    # Shift labels so the smallest one is 0.
    full_labels = full_labels - np.min(full_labels)

    # Output directory for results.csv, the pickled samples/labels and the
    # checkpoints. `folder` was referenced but never defined in this script.
    # NOTE(review): the directory name is an assumption - confirm.
    folder = "{}_result".format(args.dataset)
    os.makedirs(folder, exist_ok=True)

    # evaluate() in this file returns only the list of k-NN graphs, so no
    # clustering metrics exist; NaN placeholders keep the log lines and the
    # results.csv columns unchanged.
    acc = nmi = ari = float("nan")

    # Accumulator handed to evaluate(); reset before the final evaluation.
    csn_array = []

    # newline='' is what the csv module expects for files it writes to.
    with open('{}/results.csv'.format(folder), 'w', newline='') as result:
        writer = csv.writer(result)
        writer.writerow(["N", "ACC", "NMI", "ARI"])

        global_steps = 0

        start = time.time()
        print("SDSC At start torch.cuda.memory_allocated for orl train: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
        print("SDSC At start torch.cuda.memory_reserved for orl train: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
        print("SDSC At start torch.cuda.max_memory_reserved for orl train: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))
        for N in [len(full_samples)]:
            # Random subset of size N without replacement (a permutation when
            # N equals the dataset size).
            sampled_idx = np.random.choice(full_samples.shape[0], N, replace=False)
            samples, labels = full_samples[sampled_idx], full_labels[sampled_idx]

            block_size = min(N, 10000)

            with open('{}/{}_samples_{}.pkl'.format(folder, args.dataset, N), 'wb') as f:
                pickle.dump(samples, f)
            with open('{}/{}_labels_{}.pkl'.format(folder, args.dataset, N), 'wb') as f:
                pickle.dump(labels, f)

            all_samples, ambient_dim = samples.shape[0], samples.shape[1]

            data = torch.from_numpy(samples).float()
            data = utils.p_normalize(data)

            n_iter_per_epoch = samples.shape[0] // args.batch_size
            if n_iter_per_epoch == 0:
                raise ValueError(
                    "batch_size ({}) is larger than the number of samples ({}).".format(
                        args.batch_size, samples.shape[0]))
            n_step_per_iter = all_samples // block_size
            n_epochs = args.total_iters // n_iter_per_epoch

            senet = SENet(ambient_dim, args.hid_dims, args.out_dims, kaiming_init=True).cuda()
            optimizer = optim.Adam(senet.parameters(), lr=args.lr)
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs, eta_min=args.lr_min)

            n_iters = 0
            pbar = tqdm(range(n_epochs), ncols=120)

            for epoch in pbar:
                pbar.set_description(f"Epoch {epoch}")

                for i in range(n_iter_per_epoch):
                    senet.train()

                    # NOTE(review): the original drew a shuffled batch and then
                    # immediately overwrote it with this in-order slice; the
                    # effective (sequential) behaviour is kept and the dead
                    # shuffle/transfer removed. Confirm sequential batching is
                    # intended.
                    batch = data[i * args.batch_size:(i + 1) * args.batch_size].cuda()

                    q_batch = senet.query_embedding(batch)
                    k_batch = senet.key_embedding(batch)

                    # Accumulate the reconstruction and the regulariser over
                    # all key blocks of the dataset.
                    rec_batch = torch.zeros_like(batch).cuda()
                    reg = torch.zeros([1]).cuda()
                    for j in range(n_step_per_iter):
                        block = data[j * block_size: (j + 1) * block_size].cuda()
                        k_block = senet.key_embedding(block)
                        c = senet.get_coeff(q_batch, k_block)
                        rec_batch = rec_batch + c.mm(block)
                        reg = reg + regularizer(c, args.lmbd)

                    # Remove every sample's coefficient with itself from both
                    # the reconstruction and the penalty.
                    diag_c = (q_batch * k_batch).sum(dim=1, keepdim=True)
                    rec_batch = rec_batch - diag_c * batch
                    reg = reg - regularizer(diag_c, args.lmbd)

                    rec_loss = torch.sum(torch.pow(batch - rec_batch, 2))
                    loss = (0.5 * args.gamma * rec_loss + reg) / args.batch_size

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(senet.parameters(), 0.001)
                    optimizer.step()

                    global_steps += 1
                    n_iters += 1

                    if n_iters % args.save_iters == 0:
                        with open('{}/SENet_{}_N{:d}_iter{:d}.pth.tar'.format(folder, args.dataset, N, n_iters), 'wb') as f:
                            torch.save(senet.state_dict(), f)
                        print("Model Saved.")

                    if n_iters % args.eval_iters == 0:
                        print("Evaluating on sampled data...")
                        # evaluate() returns the graph list only.
                        csn_array = evaluate(csn_array, senet, data=data, labels=labels,
                                             num_subspaces=args.num_subspaces, affinity=args.affinity,
                                             spectral_dim=args.spectral_dim, non_zeros=args.non_zeros,
                                             n_neighbors=args.n_neighbors, batch_size=block_size,
                                             chunk_size=block_size, knn_mode='symmetric')
                        print("ACC-{:.6f}, NMI-{:.6f}, ARI-{:.6f}".format(acc, nmi, ari))

                pbar.set_postfix(loss="{:3.4f}".format(loss.item()),
                                 rec_loss="{:3.4f}".format(rec_loss.item() / args.batch_size),
                                 reg="{:3.4f}".format(reg.item() / args.batch_size))
                scheduler.step()
            end = time.time()
            print('SDSC time to converge C on orl train ', end - start)

            print("SDSC after convergence torch.cuda.memory_allocated for orl train: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
            print("SDSC After convergence torch.cuda.memory_reserved for orl train: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
            print("SDSC After convergence torch.cuda.max_memory_reserved for orl train: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))

            print("Evaluating on {} full....".format(args.dataset))
            full_data = torch.from_numpy(full_samples).float()
            full_data = utils.p_normalize(full_data)

            # Start from an empty accumulator so graphs from any in-training
            # evaluation are not mixed into the saved output.
            csn_array = []
            data_batches = []
            label_batches = []
            start = time.time()
            # Only full batches are evaluated; a trailing partial batch is skipped.
            for i in range(len(full_data) // args.batch_size):
                lo = i * args.batch_size
                hi = lo + args.batch_size
                databatch = full_data[lo:hi]
                labelbatch = full_labels[lo:hi]
                data_batches.append(databatch.numpy())
                label_batches.append(labelbatch)
                csn_array = evaluate(csn_array, senet, data=databatch, labels=labelbatch,
                                     num_subspaces=args.num_subspaces, affinity=args.affinity,
                                     spectral_dim=args.spectral_dim, non_zeros=args.non_zeros,
                                     n_neighbors=args.n_neighbors, batch_size=args.batch_size,
                                     chunk_size=args.batch_size, knn_mode='symmetric')
            print("N-{:d}: ACC-{:.6f}, NMI-{:.6f}, ARI-{:.6f}".format(N, acc, nmi, ari))
            end = time.time()
            print('time to compute C on orl train is : ', end - start)
            print(np.array(csn_array).shape)

            np.save( '../../../datasets/ORL/nSb_orltrain.npy',csn_array)
            np.save('../../../datasets/ORL/nd_orltrain.npy', np.array(data_batches))
            np.save('../../../datasets/ORL/nl_orltrain.npy', np.array(label_batches))

            print("SDSC At end torch.cuda.memory_allocated for orl train: %fGB" % (torch.cuda.memory_allocated(0) / 1024 / 1024 / 1024))
            print("SDSC At end torch.cuda.memory_reserved for orl train: %fGB" % (torch.cuda.memory_reserved(0) / 1024 / 1024 / 1024))
            print("SDSC At end torch.cuda.max_memory_reserved for orl train: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))

            writer.writerow([N, acc, nmi, ari])
            result.flush()

            with open('{}/SENet_{}_N{:d}.pth.tar'.format(folder, args.dataset, N), 'wb') as f:
                torch.save(senet.state_dict(), f)

            torch.cuda.empty_cache()
