In [1]:
import torch
from torch import nn

from torch.optim import lr_scheduler

from torch.utils.data import random_split,Dataset,DataLoader

import torch.nn.functional as F
import torch.nn.init as init

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
from joblib.externals.loky.backend.context import get_context
import time
import copy
import random
import pickle
import tarfile

import time
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

In [2]:
import scanpy as sc
from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
from sklearn.model_selection import train_test_split

In [3]:
from unifan.networks import Encoder, Decoder, Set2Gene
import random

# Hyperparameters

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
AE_batch_size=128
AE_learning_rate = 0.001
AE_decay_factor = 0.9
AE_epochs = 300
AE_random_seed = 123
best_ari = 0
best_nmi = 0

# Processing

In [5]:
def processing(adata):
    
    
    sc.pp.filter_cells(adata, min_genes=20)
    sc.pp.filter_genes(adata, min_cells=3)
    
    data = adata.copy()
    
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.normalize_total(data, target_sum=1e4)
    sc.pp.log1p(adata)
    
    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    
    data = data[:, adata.var.highly_variable]
    
    return data

# My Data

In [6]:
adata = sc.read("data/cortex.h5ad", dtype='float64')

In [7]:
adata = processing(adata)

In [8]:
gene_matrix = adata.X

In [9]:
G = adata.X.shape[1]

In [10]:
clusters_true = adata.obs["label"]

# Dataset&DataLoader

In [11]:
class MyDataset(Dataset):
    def __init__(self,gene_matrix):
        
        self.gene_matrix = torch.from_numpy(gene_matrix)
    
    def __len__(self):
        
        return self.gene_matrix.shape[0]

    def __getitem__(self,idx):
        
        X = self.gene_matrix[idx]
        
        return X

In [12]:
def create_loader(X,batch_size):

    #print(type(X))
    mydataset = MyDataset(X)

    dataloader = DataLoader(dataset=mydataset,batch_size=batch_size,shuffle=True)
    
    return dataloader

# autoencoder

In [13]:
class autoencoder(nn.Module):


    def __init__(self, input_dim: int = 10000,
                 hidden_dim:int = 128,encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 z_dim: int = 32,
                 dropout_rate:float = 0.1):

        super().__init__()
        
        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder,hidden_dim=encoder_dim)
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

        # initialize loss
        self.mse_loss = nn.MSELoss()

    def forward(self, data):

        x = data

        z_e,_ = self.encoder(x)

        x_e,_ = self.decoder(z_e)

        return x_e, z_e

    def _loss_reconstruct(self, x, x_e):

        l_e = self.mse_loss(x, x_e)
        
        mse_l = l_e

        return mse_l

    def loss(self, x, x_e):
        
        l = self._loss_reconstruct(x, x_e)
        
        return l


# Pretrain Autoencoder

In [14]:
AE_dataloader = create_loader(gene_matrix,batch_size = AE_batch_size)

In [15]:
AE = autoencoder(input_dim = G).to(device)

In [16]:
AE_optimizer = torch.optim.Adam(AE.parameters(), lr=AE_learning_rate)
AE_scheduler = torch.optim.lr_scheduler.StepLR(AE_optimizer, 1000,AE_decay_factor)

In [17]:
def cal_result_AE(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=0.5, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)

    return ari,nmi

In [18]:
for epoch in range(AE_epochs):
    
    total_loss = 0
    for batch_idx,X_batch in enumerate(AE_dataloader):
        X_batch = X_batch.to(device).float()
        
        AE_optimizer.zero_grad()

        x_e, z_e = AE(X_batch)

        loss = AE.loss(X_batch.float(), x_e.float())
        
        total_loss += loss
        
        
        loss.backward()
        AE_optimizer.step()
        AE_scheduler.step()

    ari,nmi = cal_result_AE(AE,AE_dataloader,clusters_true,gene_matrix)
    
    if best_ari < ari and best_nmi < nmi:
        
        best_ari = ari
        
        best_nmi = nmi
        
    print(f"epoch:{epoch+1} total_loss:{total_loss/len(AE_dataloader.sampler)} ari:{best_ari} nmi:{best_nmi}")

epoch:1 total_loss:0.16198071837425232 ari:0.4385318183303632 nmi:0.6440550666826844
epoch:2 total_loss:0.15056836605072021 ari:0.4719195599193421 nmi:0.6608487559518773
epoch:3 total_loss:0.11188949644565582 ari:0.4719195599193421 nmi:0.6608487559518773
epoch:4 total_loss:0.11074287444353104 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:5 total_loss:0.0888274759054184 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:6 total_loss:0.07658809423446655 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:7 total_loss:0.06418208032846451 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:8 total_loss:0.05020730942487717 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:9 total_loss:0.04467375949025154 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:10 total_loss:0.04604581370949745 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:11 total_loss:0.04678075388073921 ari:0.4830729583086395 nmi:0.6683135741697493
epoch:12 total_loss:0.038607120513916016 ari:0.4830729583086395 

epoch:97 total_loss:0.01970222406089306 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:98 total_loss:0.020132508128881454 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:99 total_loss:0.019804608076810837 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:100 total_loss:0.01950179785490036 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:101 total_loss:0.019255321472883224 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:102 total_loss:0.019282741472125053 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:103 total_loss:0.01909669116139412 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:104 total_loss:0.019315244629979134 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:105 total_loss:0.01927422359585762 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:106 total_loss:0.019159190356731415 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:107 total_loss:0.019535494968295097 ari:0.5467359764141766 nmi:0.6780064374309415
epoch:108 total_loss:0.0196328293532133

epoch:191 total_loss:0.015586808323860168 ari:0.7076960237439903 nmi:0.7128205862423791
epoch:192 total_loss:0.015764176845550537 ari:0.7076960237439903 nmi:0.7128205862423791
epoch:193 total_loss:0.015903722494840622 ari:0.7076960237439903 nmi:0.7128205862423791
epoch:194 total_loss:0.01683981716632843 ari:0.7095326832188015 nmi:0.722439884038592
epoch:195 total_loss:0.016231661662459373 ari:0.7095326832188015 nmi:0.722439884038592
epoch:196 total_loss:0.01633261702954769 ari:0.7095326832188015 nmi:0.722439884038592
epoch:197 total_loss:0.01583130657672882 ari:0.7095326832188015 nmi:0.722439884038592
epoch:198 total_loss:0.015835827216506004 ari:0.7095326832188015 nmi:0.722439884038592
epoch:199 total_loss:0.015767568722367287 ari:0.7095326832188015 nmi:0.722439884038592
epoch:200 total_loss:0.015564650297164917 ari:0.7095326832188015 nmi:0.722439884038592
epoch:201 total_loss:0.015613867901265621 ari:0.7095326832188015 nmi:0.722439884038592
epoch:202 total_loss:0.015538076870143414 a

epoch:285 total_loss:0.014533805660903454 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:286 total_loss:0.014478502795100212 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:287 total_loss:0.014414207078516483 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:288 total_loss:0.014238283038139343 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:289 total_loss:0.014317541383206844 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:290 total_loss:0.014217257499694824 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:291 total_loss:0.014362969435751438 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:292 total_loss:0.014194171875715256 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:293 total_loss:0.01420255284756422 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:294 total_loss:0.014100761152803898 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:295 total_loss:0.014256492257118225 ari:0.7192656772487659 nmi:0.7261499343231906
epoch:296 total_loss:0.0141683761

In [19]:
AE.to(torch.device("cpu")).eval()
z_init,_ = AE.encoder(torch.from_numpy(gene_matrix))

In [20]:
z_init = z_init.detach().numpy()

In [21]:
adata = sc.AnnData(X=z_init)
adata.obsm['X_unifan'] = z_init
sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=AE_random_seed)
sc.tl.leiden(adata, resolution=0.5, random_state=AE_random_seed)
clusters_pre = adata.obs['leiden'].astype('int').values  # original as string

In [22]:
df_cluster = pd.DataFrame(z_init)
cluster_labels = np.unique(clusters_pre)
M = len(set(cluster_labels))  # set as number of clusters
df_cluster['cluster'] = clusters_pre

# get centroids
centroids = df_cluster.groupby('cluster').mean().values
centroids_torch = torch.from_numpy(centroids)

In [23]:
centroids = torch.from_numpy(centroids)

In [24]:
ari_smaller = adjusted_rand_score(clusters_pre,
                                  clusters_true)
nmi_smaller = adjusted_mutual_info_score(clusters_pre, clusters_true)

In [25]:
nmi_smaller

0.714440227274026

In [26]:
ari_smaller

0.7046204954312052

# VQVAE

In [27]:
class VQVAE_T(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = True,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

In [28]:
class VQVAE_N(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = False,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

# Adam

In [29]:
def Adam_training(model,n_epochs,learning_rate,dataloader,clusters_true,gene_matrix,
                non_blocking = True,decay_factor = 0.9):
    
    model.to(device)
    model.train()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(n_epochs):
        
        total_loss = 0
    
        for batch_idx, X_batch in enumerate(dataloader):
            
            X_batch = X_batch.to(device, non_blocking=non_blocking).float()

            optimizer.zero_grad(set_to_none=True)

            x_q, z_e, z_q = model(X_batch)

            l = model.loss(X_batch,x_q,z_e,z_q)
       
            total_loss += l

            l.backward(retain_graph=True)
            
            optimizer.step()
            
        ari,nmi = cal_result(model,dataloader,clusters_true,gene_matrix)
        global best_ari,best_nmi
        if best_ari < ari and best_nmi < nmi:

            best_ari = ari

            best_nmi = nmi

        print("------------------------------")
        print(f"epoch:{epoch+1} total_loss:{total_loss/len(dataloader)} \n ari:{ari} nmi:{nmi} \n best ari:{best_ari} bset nmi:{best_nmi}")

            
            
    # move network back to cpu and return
    model.cpu()
    
    return model

# GA

In [30]:
def crossover_and_mutation(parents, sigma=0.01):

    
    base_sd = parents[0].state_dict()
    keys = base_sd                    # use all layers to be affected
    
    # Sum of the weights of the parent
    for i in range(1, len(parents)):
        parent_sd = parents[i].state_dict()
        for key in keys:
            base_sd[key] = base_sd[key] + parent_sd[key]
            
    
    # Average and add mutation
    num_parents = len(parents)
    
    for key in keys:
        
        tensor_size = base_sd[key].size()
        random_tensor = torch.normal(mean=0.0, std=sigma, size=tensor_size)
        
        base_sd[key] = (base_sd[key] / num_parents) + random_tensor
    
    # create offspring
    
    if random.randint(0,1) == 0:
        offspring = VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1.0,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
    else:
    
        offspring = VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1.0,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
        
    return offspring
    

def create_offspring(population,fitness,rho,sigma):

    
    # Perform selection
    parents = random.choices(population, weights=fitness, k=rho) 
    
    # Perform crossover and mutation
    offspring = crossover_and_mutation(parents, sigma)
    
    
    return offspring

def sigmoid(x):
    
    return 1/(1 + np.exp(-x))


def GA_training(population, pop_size, offspring_size, elitist_level, rho, sigma, dataloader):
    
    #Calculate fitness of trained population

    fitness = [cal_loss(population[i],dataloader) for i in range(pop_size)]
    
    print(f"--- -- Finished fitness evaluation, length: {len(fitness)}")
    
    #Create offspring population
    fitness_weighted = [1/f for f in fitness]   # take inverse of loss so lower losses get higher fitness-values
    
    offspring_population = [create_offspring(population,fitness_weighted, rho, sigma) for i in range(offspring_size)]
    
    print("--- -- Finished creating offspring population")
    
    #Evaluate fitness of offsprings 
    
    offspring_fitness = [cal_loss(offspring_population[i],dataloader) for i in range(offspring_size)]
    
    print("--- -- Finished evaluating fitness of offspring population")
    
    # Combine fitness and population lists
    
    combined_fitness = fitness + offspring_fitness
    combined_population = population + offspring_population
    
    # sort and select population by their fitness values
    
    sorted_population = [pop for _, pop in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    sorted_fitness = [loss for loss, _ in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    
    m = int(pop_size * elitist_level)
    new_population = sorted_population[0:m]
    
    # Fill up rest of population
    difference = pop_size - m
    remaining_population = list(set(sorted_population) - set(new_population))
    filler_population = random.sample(remaining_population, difference)
    
    # assemble new population and return
    new_population = new_population + filler_population
    
    return new_population, sorted_fitness

In [31]:
def cal_loss(model,dataloader,non_blocking=True):
    
    model.to(device)

    total_loss = 0.0
    
    for batch_idx, X_batch in enumerate(dataloader):
        
        X_batch = X_batch.to(device, non_blocking=non_blocking).float()

        x_e, z_e, z_q = model(X_batch)

        l = model.loss(X_batch,x_e, z_e, z_q)
        
        total_loss += l
        
    model.cpu()

    return float(total_loss)

In [32]:
def cal_result(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e,z_q = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=0.5, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)

    return ari,nmi

# GA Neural

In [33]:
def GA_Neural_train(population,
                    pop_size,
                    max_generations, 
                    SGD_steps, GA_steps, 
                    offspring_size, elitist_level, rho,
                    learning_rate,
                    dataloader,
                    clusters_true,
                    gene_matrix):
    
    
    print(f"Starting with population of size: {pop_size}")

    for k in range(max_generations):
        print(f"Currently in generation {k+1}")
        
        #Adam
        print(f"--- Starting Adam")
        
        # Sequential version
        
        population_copy = []
        
        for i in range(pop_size):
            
            model = Adam_training(population[i],SGD_steps,learning_rate,dataloader,clusters_true,gene_matrix)
            population_copy.append(model)
        
        print(f"--- Finished Adam")
        
        population = population_copy
         
        # GA
        print(f"--- Starting Model GA")
        GA_start = time.time()
        sorted_fitness = []          # store the sorted fitness values to maybe use in data collection
        for i in range(0, GA_steps):
            
            sigma = 0.01 / (k+1)
            population, sorted_fitness = GA_training(population, 
                                                     pop_size, offspring_size, elitist_level, rho, sigma, dataloader)
        
        GA_end = time.time()
        
        print(f"--- Finished Model GA,Time:{(GA_end - GA_start) * 1000}ms")
        
        
    print(f"Finished training process")
    
    return population

# Train VQVAE

In [34]:
# Hyperparameters
#device_train = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 1024
pop_size_T = 0
pop_size_N = 1
max_generations = 1
SGD_steps = 500
GA_steps = 0
offspring_size = 30
elitist_level = 0.6
rho = 2
learning_rate = 1e-5
in_feats = 400
n_hidden = 200
weight_decay = 5e-4
best_nmi = 0
best_ari = 0

In [35]:
dataloader = create_loader(gene_matrix,batch_size = batch_size)
pop_size  = pop_size_T + pop_size_N

In [36]:
# Create population and start training process
population_T = [VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25,n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_T)]

In [37]:
population_N = [VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25,n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_N)]

In [38]:
population = population_T + population_N

In [39]:
Train_start = time.time()
trained_population = GA_Neural_train(population=population,
                                    pop_size = pop_size,
                                    max_generations=max_generations,
                                    SGD_steps=SGD_steps,GA_steps=GA_steps,
                                    offspring_size=offspring_size,elitist_level=elitist_level,rho=rho,
                                    learning_rate=learning_rate,
                                    dataloader=dataloader,clusters_true=clusters_true,gene_matrix=gene_matrix)
Train_end = time.time()
print(f"All Time:{(Train_end-Train_start)*1000}ms")

Starting with population of size: 1
Currently in generation 1
--- Starting Adam
------------------------------
epoch:1 total_loss:36.134857177734375 
 ari:0.6772370592904112 nmi:0.6121642027312777 
 best ari:0.6772370592904112 bset nmi:0.6121642027312777
------------------------------
epoch:2 total_loss:35.831905364990234 
 ari:0.6826938664239893 nmi:0.6183947468005668 
 best ari:0.6826938664239893 bset nmi:0.6183947468005668
------------------------------
epoch:3 total_loss:35.644744873046875 
 ari:0.6938278729387327 nmi:0.626682049883329 
 best ari:0.6938278729387327 bset nmi:0.626682049883329
------------------------------
epoch:4 total_loss:35.49232482910156 
 ari:0.6947865832715392 nmi:0.6318854112972214 
 best ari:0.6947865832715392 bset nmi:0.6318854112972214
------------------------------
epoch:5 total_loss:35.37330627441406 
 ari:0.6923821885012691 nmi:0.6302219286574287 
 best ari:0.6947865832715392 bset nmi:0.6318854112972214
------------------------------
epoch:6 total_loss

------------------------------
epoch:48 total_loss:25.051494598388672 
 ari:0.4829195089816966 nmi:0.5560772593139007 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:49 total_loss:24.81011962890625 
 ari:0.48084755442524146 nmi:0.5689824669525987 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:50 total_loss:24.49390983581543 
 ari:0.48678657843021195 nmi:0.5593734083920926 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:51 total_loss:24.252958297729492 
 ari:0.4784684677110127 nmi:0.5688589756024476 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:52 total_loss:23.984975814819336 
 ari:0.4373966991881455 nmi:0.5374004729185485 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:53 total_loss:23.653120040893555 
 ari:0.4727384916989271 nmi:0.5763602824583697 
 

------------------------------
epoch:95 total_loss:14.049200057983398 
 ari:0.3166301493951569 nmi:0.47226286924874206 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:96 total_loss:13.92136287689209 
 ari:0.3337249891624915 nmi:0.4657366386055758 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:97 total_loss:13.699689865112305 
 ari:0.315184788057983 nmi:0.45288046698638357 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:98 total_loss:13.528215408325195 
 ari:0.3170365393085463 nmi:0.4649384878029993 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:99 total_loss:13.437296867370605 
 ari:0.3135383329762061 nmi:0.46490476581015855 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:100 total_loss:13.232236862182617 
 ari:0.3465168662931704 nmi:0.45847904693196795

------------------------------
epoch:142 total_loss:9.199562072753906 
 ari:0.27448820682117375 nmi:0.41783301857608646 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:143 total_loss:9.120628356933594 
 ari:0.2541817083649987 nmi:0.4119281423485207 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:144 total_loss:9.059402465820312 
 ari:0.2558836969040897 nmi:0.4068830318704283 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:145 total_loss:9.022734642028809 
 ari:0.27902864701361146 nmi:0.41974410929062445 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:146 total_loss:8.986193656921387 
 ari:0.28173457302703786 nmi:0.40946927172676256 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:147 total_loss:8.922274589538574 
 ari:0.25312405238343816 nmi:0.418978855275

------------------------------
epoch:189 total_loss:7.322054862976074 
 ari:0.21302158003415653 nmi:0.37273110201372106 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:190 total_loss:7.307912349700928 
 ari:0.2445144645957647 nmi:0.3834948649692337 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:191 total_loss:7.280489921569824 
 ari:0.23580711481716443 nmi:0.3849900941510891 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:192 total_loss:7.273406982421875 
 ari:0.24413969370675231 nmi:0.3891501323184272 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:193 total_loss:7.2384796142578125 
 ari:0.24330504694330476 nmi:0.3834546738867981 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:194 total_loss:7.183522701263428 
 ari:0.21828034221202586 nmi:0.370066571513

------------------------------
epoch:236 total_loss:6.4161696434021 
 ari:0.20304563148657717 nmi:0.34883584248042543 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:237 total_loss:6.386185646057129 
 ari:0.22067407612730117 nmi:0.34189794190077843 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:238 total_loss:6.3736114501953125 
 ari:0.21468544523318359 nmi:0.3512225469660391 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:239 total_loss:6.371387481689453 
 ari:0.21236157437406078 nmi:0.34149451713881873 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:240 total_loss:6.326209545135498 
 ari:0.21435171275795722 nmi:0.33954095253278793 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:241 total_loss:6.3269453048706055 
 ari:0.20855741556169052 nmi:0.339476392

------------------------------
epoch:283 total_loss:5.953849792480469 
 ari:0.20990951274208203 nmi:0.3281108323450888 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:284 total_loss:5.9431047439575195 
 ari:0.21926623255288707 nmi:0.3381734342342376 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:285 total_loss:5.937417030334473 
 ari:0.20449335941337266 nmi:0.3183396361759942 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:286 total_loss:5.931924819946289 
 ari:0.20501541010997532 nmi:0.32669801207949206 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:287 total_loss:5.896142482757568 
 ari:0.20409369597152302 nmi:0.3256942312635394 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:288 total_loss:5.906036376953125 
 ari:0.20517195852307332 nmi:0.32940163288

------------------------------
epoch:330 total_loss:5.696202754974365 
 ari:0.2032778752557641 nmi:0.3025765899284591 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:331 total_loss:5.6785712242126465 
 ari:0.20016658801925186 nmi:0.30047479597959775 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:332 total_loss:5.6856513023376465 
 ari:0.2006401085139985 nmi:0.2993991131998042 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:333 total_loss:5.661650657653809 
 ari:0.2159606274259462 nmi:0.30630548797151563 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:334 total_loss:5.723639488220215 
 ari:0.1817205942696441 nmi:0.2904332798117376 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:335 total_loss:5.6666460037231445 
 ari:0.18784860684639101 nmi:0.293114083586

------------------------------
epoch:377 total_loss:5.524589538574219 
 ari:0.14702495709643976 nmi:0.23320191782945965 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:378 total_loss:5.514186859130859 
 ari:0.0855400837991878 nmi:0.1693470443490333 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:379 total_loss:5.5104875564575195 
 ari:0.08718933452451777 nmi:0.16226385129037033 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:380 total_loss:5.514497756958008 
 ari:0.09246878021794674 nmi:0.16429777475774524 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:381 total_loss:5.495824813842773 
 ari:0.09409609403959468 nmi:0.16742115784703307 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:382 total_loss:5.5064697265625 
 ari:0.13500868454534212 nmi:0.19524046044

------------------------------
epoch:424 total_loss:5.430575370788574 
 ari:0.03594133974551447 nmi:0.06831544400687627 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:425 total_loss:5.388092041015625 
 ari:0.02440546943707453 nmi:0.0492160484203903 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:426 total_loss:5.377652645111084 
 ari:0.0325425549532934 nmi:0.06663754114531004 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:427 total_loss:5.421689510345459 
 ari:0.034020933586557986 nmi:0.06656841822822115 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:428 total_loss:5.425565719604492 
 ari:0.05141091443772239 nmi:0.089128581916123 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:429 total_loss:5.448884963989258 
 ari:0.01858856081124277 nmi:0.04501675781

------------------------------
epoch:470 total_loss:5.327437400817871 
 ari:0.022031191557698976 nmi:0.05699888312316193 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:471 total_loss:5.319453716278076 
 ari:0.013855574930540637 nmi:0.05072765382266024 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:472 total_loss:5.318053722381592 
 ari:0.018037776673574583 nmi:0.04898694010614676 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:473 total_loss:5.351808071136475 
 ari:0.022942097707551262 nmi:0.06253614847668003 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:474 total_loss:5.359147548675537 
 ari:0.022813644476780734 nmi:0.04786428429722017 
 best ari:0.7453310045156862 bset nmi:0.6843188961076703
------------------------------
epoch:475 total_loss:5.312175273895264 
 ari:0.024782220869324946 nmi:0.05