In [1]:
import torch
from torch import nn

from torch.optim import lr_scheduler

from torch.utils.data import random_split,Dataset,DataLoader

import torch.nn.functional as F
import torch.nn.init as init

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
from joblib.externals.loky.backend.context import get_context
import time
import copy
import random
import pickle
import tarfile

import time
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

In [2]:
import scanpy as sc
from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
from sklearn.model_selection import train_test_split

In [3]:
from unifan.networks import Encoder, Decoder, Set2Gene
import random

# Hyperparameters

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
AE_batch_size=128
AE_learning_rate = 0.001
AE_decay_factor = 0.9
AE_epochs = 300
AE_random_seed = 123
best_ari = 0
best_nmi = 0

# Processing

In [5]:
def processing(adata):
    
    
    sc.pp.filter_cells(adata, min_genes=20)
    sc.pp.filter_genes(adata, min_cells=3)
    
    data = adata.copy()
    
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.normalize_total(data, target_sum=1e4)
    sc.pp.log1p(adata)
    
    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    
    data = data[:, adata.var.highly_variable]
    
    return data

# My Data

In [6]:
adata = sc.read("data/cortex.h5ad", dtype='float64')

In [7]:
adata = processing(adata)

In [8]:
gene_matrix = adata.X

In [9]:
G = adata.X.shape[1]

In [10]:
clusters_true = adata.obs["label"]

# Dataset&DataLoader

In [11]:
class MyDataset(Dataset):
    def __init__(self,gene_matrix):
        
        self.gene_matrix = torch.from_numpy(gene_matrix)
    
    def __len__(self):
        
        return self.gene_matrix.shape[0]

    def __getitem__(self,idx):
        
        X = self.gene_matrix[idx]
        
        return X

In [12]:
def create_loader(X,batch_size):

    #print(type(X))
    mydataset = MyDataset(X)

    dataloader = DataLoader(dataset=mydataset,batch_size=batch_size,shuffle=True)
    
    return dataloader

# autoencoder

In [13]:
class autoencoder(nn.Module):


    def __init__(self, input_dim: int = 10000,
                 hidden_dim:int = 128,encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 z_dim: int = 32,
                 dropout_rate:float = 0.1):

        super().__init__()
        
        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder,hidden_dim=encoder_dim)
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

        # initialize loss
        self.mse_loss = nn.MSELoss()

    def forward(self, data):

        x = data

        z_e,_ = self.encoder(x)

        x_e,_ = self.decoder(z_e)

        return x_e, z_e

    def _loss_reconstruct(self, x, x_e):

        l_e = self.mse_loss(x, x_e)
        
        mse_l = l_e

        return mse_l

    def loss(self, x, x_e):
        
        l = self._loss_reconstruct(x, x_e)
        
        return l


# Pretrain Autoencoder

In [14]:
AE_dataloader = create_loader(gene_matrix,batch_size = AE_batch_size)

In [15]:
AE = autoencoder(input_dim = G).to(device)

In [16]:
AE_optimizer = torch.optim.Adam(AE.parameters(), lr=AE_learning_rate)
AE_scheduler = torch.optim.lr_scheduler.StepLR(AE_optimizer, 1000,AE_decay_factor)

In [17]:
def cal_result_AE(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=0.5, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)

    return ari,nmi

In [18]:
for epoch in range(AE_epochs):
    
    total_loss = 0
    for batch_idx,X_batch in enumerate(AE_dataloader):
        X_batch = X_batch.to(device).float()
        
        AE_optimizer.zero_grad()

        x_e, z_e = AE(X_batch)

        loss = AE.loss(X_batch.float(), x_e.float())
        
        total_loss += loss
        
        
        loss.backward()
        AE_optimizer.step()
        AE_scheduler.step()

    ari,nmi = cal_result_AE(AE,AE_dataloader,clusters_true,gene_matrix)
    
    if best_ari < ari and best_nmi < nmi:
        
        best_ari = ari
        
        best_nmi = nmi
        
    print(f"epoch:{epoch+1} total_loss:{total_loss/len(AE_dataloader.sampler)} ari:{best_ari} nmi:{best_nmi}")

epoch:1 total_loss:0.16099269688129425 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:2 total_loss:0.14011918008327484 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:3 total_loss:0.10924022644758224 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:4 total_loss:0.09797268360853195 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:5 total_loss:0.0807659849524498 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:6 total_loss:0.057419005781412125 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:7 total_loss:0.06376472115516663 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:8 total_loss:0.049345821142196655 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:9 total_loss:0.049604665488004684 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:10 total_loss:0.05567879229784012 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:11 total_loss:0.044946715235710144 ari:0.5141890490459206 nmi:0.6787273902418247
epoch:12 total_loss:0.03868047520518303 ari:0.51418904904592

epoch:96 total_loss:0.019235758110880852 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:97 total_loss:0.019248245283961296 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:98 total_loss:0.01891438290476799 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:99 total_loss:0.018950751051306725 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:100 total_loss:0.018684634938836098 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:101 total_loss:0.018460039049386978 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:102 total_loss:0.01841580681502819 ari:0.5499972930937969 nmi:0.6875672701267156
epoch:103 total_loss:0.018349478021264076 ari:0.6802982612481075 nmi:0.7124181576900883
epoch:104 total_loss:0.01843211241066456 ari:0.6802982612481075 nmi:0.7124181576900883
epoch:105 total_loss:0.019538162276148796 ari:0.6955721375816963 nmi:0.7263188439075597
epoch:106 total_loss:0.020017368718981743 ari:0.6955721375816963 nmi:0.7263188439075597
epoch:107 total_loss:0.0185900833457708

epoch:190 total_loss:0.016581712290644646 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:191 total_loss:0.01633116416633129 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:192 total_loss:0.016152450814843178 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:193 total_loss:0.01674424111843109 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:194 total_loss:0.016911501064896584 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:195 total_loss:0.017354145646095276 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:196 total_loss:0.016402319073677063 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:197 total_loss:0.016237346455454826 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:198 total_loss:0.015758024528622627 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:199 total_loss:0.015686558559536934 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:200 total_loss:0.01590256579220295 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:201 total_loss:0.016041381284

epoch:284 total_loss:0.013927982188761234 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:285 total_loss:0.014433654025197029 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:286 total_loss:0.015396786853671074 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:287 total_loss:0.017681526020169258 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:288 total_loss:0.015292495489120483 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:289 total_loss:0.014918343164026737 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:290 total_loss:0.014597473666071892 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:291 total_loss:0.014692779630422592 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:292 total_loss:0.014470295049250126 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:293 total_loss:0.014311570674180984 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:294 total_loss:0.014002897776663303 ari:0.7222927430717622 nmi:0.7416291499338649
epoch:295 total_loss:0.013875942

In [19]:
AE.to(torch.device("cpu")).eval()
z_init,_ = AE.encoder(torch.from_numpy(gene_matrix))

In [20]:
z_init = z_init.detach().numpy()

In [21]:
adata = sc.AnnData(X=z_init)
adata.obsm['X_unifan'] = z_init
sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=AE_random_seed)
sc.tl.leiden(adata, resolution=0.5, random_state=AE_random_seed)
clusters_pre = adata.obs['leiden'].astype('int').values  # original as string

In [22]:
df_cluster = pd.DataFrame(z_init)
cluster_labels = np.unique(clusters_pre)
M = len(set(cluster_labels))  # set as number of clusters
df_cluster['cluster'] = clusters_pre

# get centroids
centroids = df_cluster.groupby('cluster').mean().values
centroids_torch = torch.from_numpy(centroids)

In [23]:
centroids = torch.from_numpy(centroids)

In [24]:
ari_smaller = adjusted_rand_score(clusters_pre,
                                  clusters_true)
nmi_smaller = adjusted_mutual_info_score(clusters_pre, clusters_true)

In [25]:
nmi_smaller

0.6963418870614873

In [26]:
ari_smaller

0.5567875907200872

# VQVAE

In [27]:
class VQVAE_T(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = True,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l


In [28]:
class VQVAE_N(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = False,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

# Adam

In [29]:
def Adam_training(model,n_epochs,learning_rate,dataloader,clusters_true,gene_matrix,
                non_blocking = True,decay_factor = 0.9):
    
    model.to(device)
    model.train()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1000, decay_factor)
    
    for epoch in range(n_epochs):
        
        total_loss = 0
    
        for batch_idx, X_batch in enumerate(dataloader):
            
            X_batch = X_batch.to(device, non_blocking=non_blocking).float()

            optimizer.zero_grad(set_to_none=True)

            x_q,z_e, z_q = model(X_batch)

            l = model.loss(X_batch,x_q,z_e,z_q)
       
            total_loss += l

            l.backward(retain_graph=True)
            
            optimizer.step()
            
        ari,nmi = cal_result(model,dataloader,clusters_true,gene_matrix)
        global best_ari,best_nmi
        if best_ari < ari and best_nmi < nmi:

            best_ari = ari

            best_nmi = nmi

        print("------------------------------")
        print(f"epoch:{epoch+1} total_loss:{total_loss/len(dataloader)} \n ari:{ari} nmi:{nmi} \n best ari:{best_ari} bset nmi:{best_nmi}")

                
    # move network back to cpu and return
    model.cpu()
    
    return model

# GA

In [30]:
def crossover_and_mutation(parents, sigma=0.01):

    
    base_sd = parents[0].state_dict()
    keys = base_sd                    # use all layers to be affected
    
    # Sum of the weights of the parent
    for i in range(1, len(parents)):
        parent_sd = parents[i].state_dict()
        for key in keys:
            base_sd[key] = base_sd[key] + parent_sd[key]
            
    
    # Average and add mutation
    num_parents = len(parents)
    
    for key in keys:
        
        tensor_size = base_sd[key].size()
        random_tensor = torch.normal(mean=0.0, std=sigma, size=tensor_size)
        
        base_sd[key] = (base_sd[key] / num_parents) + random_tensor
    
    # create offspring
    
    if random.randint(0,1) == 0:
        offspring = VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1.0,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
    else:
    
        offspring = VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1.0,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
        
    return offspring
    

def create_offspring(population,fitness,rho,sigma):

    
    # Perform selection
    parents = random.choices(population, weights=fitness, k=rho) 
    
    # Perform crossover and mutation
    offspring = crossover_and_mutation(parents, sigma)
    
    
    return offspring

def sigmoid(x):
    
    return 1/(1 + np.exp(-x))


def GA_training(population, pop_size, offspring_size, elitist_level, rho, sigma, dataloader):
    
    #Calculate fitness of trained population

    fitness = [cal_loss(population[i],dataloader) for i in range(pop_size)]
    
    print(f"--- -- Finished fitness evaluation, length: {len(fitness)}")
    
    #Create offspring population
    fitness_weighted = [1/f for f in fitness]   # take inverse of loss so lower losses get higher fitness-values
    
    offspring_population = [create_offspring(population,fitness_weighted, rho, sigma) for i in range(offspring_size)]
    
    print("--- -- Finished creating offspring population")
    
    #Evaluate fitness of offsprings 
    
    offspring_fitness = [cal_loss(offspring_population[i],dataloader) for i in range(offspring_size)]
    
    print("--- -- Finished evaluating fitness of offspring population")
    
    # Combine fitness and population lists
    
    combined_fitness = fitness + offspring_fitness
    combined_population = population + offspring_population
    
    # sort and select population by their fitness values
    
    sorted_population = [pop for _, pop in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    sorted_fitness = [loss for loss, _ in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    
    m = int(pop_size * elitist_level)
    new_population = sorted_population[0:m]
    
    # Fill up rest of population
    difference = pop_size - m
    remaining_population = list(set(sorted_population) - set(new_population))
    filler_population = random.sample(remaining_population, difference)
    
    # assemble new population and return
    new_population = new_population + filler_population
    
    return new_population, sorted_fitness

In [31]:
def cal_loss(model,dataloader,non_blocking=True):
    
    model.to(device)

    total_loss = 0.0
    
    for batch_idx, X_batch in enumerate(dataloader):
        
        X_batch = X_batch.to(device, non_blocking=non_blocking).float()

        x_e, z_e, z_q = model(X_batch)

        l = model.loss(X_batch,x_e, z_e,z_q)
        
        total_loss += l
        
    model.cpu()

    return float(total_loss)

In [32]:
def cal_result(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e,z_q = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=0.5, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)

    return ari,nmi

# GA Neural

In [33]:
def GA_Neural_train(population,
                    pop_size,
                    max_generations, 
                    SGD_steps, GA_steps, 
                    offspring_size, elitist_level, rho,
                    learning_rate,
                    dataloader,
                    clusters_true,
                    gene_matrix):
    
    
    print(f"Starting with population of size: {pop_size}")

    for k in range(max_generations):
        print(f"Currently in generation {k+1}")
        
        #Adam
        print(f"--- Starting Adam")
        
        # Sequential version
        
        population_copy = []
        
        for i in range(pop_size):
            
            model = Adam_training(population[i],SGD_steps,learning_rate,dataloader,clusters_true,gene_matrix)
            population_copy.append(model)
        
        print(f"--- Finished Adam")
        
        population = population_copy
         
        # GA
        print(f"--- Starting Model GA")
        GA_start = time.time()
        sorted_fitness = []          # store the sorted fitness values to maybe use in data collection
        for i in range(0, GA_steps):
            
            sigma = 0.01 / (k+1)
            population, sorted_fitness = GA_training(population, 
                                                     pop_size, offspring_size, elitist_level, rho, sigma, dataloader)
        
        GA_end = time.time()
        
        print(f"--- Finished Model GA,Time:{(GA_end - GA_start) * 1000}ms")
        
        
    print(f"Finished training process")
    
    return population

# Train VQVAE

In [34]:
# Hyperparameters
#device_train = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 1024
pop_size_T = 1
pop_size_N = 0
max_generations = 1
SGD_steps = 500
GA_steps = 0
offspring_size = 30
elitist_level = 0.4
rho = 4
learning_rate = 1e-5
in_feats = 400
n_hidden = 200
weight_decay = 5e-4
best_nmi = 0
best_ari = 0

In [35]:
dataloader = create_loader(gene_matrix,batch_size = batch_size)
pop_size  = pop_size_T + pop_size_N

In [36]:
# Create population and start training process
population_T = [VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1.0,gama=0.25,n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_T)]

In [37]:
population_N = [VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1.0,gama=0.25, n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_N)]

In [38]:
population = population_T + population_N

In [39]:
Train_start = time.time()
trained_population = GA_Neural_train(population=population,
                                    pop_size = pop_size,
                                    max_generations=max_generations,
                                    SGD_steps=SGD_steps,GA_steps=GA_steps,
                                    offspring_size=offspring_size,elitist_level=elitist_level,rho=rho,
                                    learning_rate=learning_rate,
                                    dataloader=dataloader,clusters_true=clusters_true,gene_matrix=gene_matrix)
Train_end = time.time()
print(f"All Time:{(Train_end-Train_start)*1000}ms")

Starting with population of size: 1
Currently in generation 1
--- Starting Adam
------------------------------
epoch:1 total_loss:39.21955490112305 
 ari:0.6556233683717171 nmi:0.6188162434724475 
 best ari:0.6556233683717171 bset nmi:0.6188162434724475
------------------------------
epoch:2 total_loss:39.00996017456055 
 ari:0.675597976854282 nmi:0.6109253710174847 
 best ari:0.6556233683717171 bset nmi:0.6188162434724475
------------------------------
epoch:3 total_loss:38.805694580078125 
 ari:0.7336493803718099 nmi:0.6676322329806802 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:4 total_loss:38.666748046875 
 ari:0.6870166054882278 nmi:0.6328723359716261 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:5 total_loss:38.45074462890625 
 ari:0.6815927127260328 nmi:0.6274357902450303 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:6 total_loss:38

------------------------------
epoch:48 total_loss:27.552427291870117 
 ari:0.4657672747049723 nmi:0.5352077251736199 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:49 total_loss:27.31348419189453 
 ari:0.4441363633618123 nmi:0.5238208866517975 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:50 total_loss:26.982147216796875 
 ari:0.4071348987554789 nmi:0.49463635203279466 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:51 total_loss:26.729164123535156 
 ari:0.3917246395767387 nmi:0.49104978093290347 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:52 total_loss:26.399532318115234 
 ari:0.38816889032877916 nmi:0.502841144371197 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:53 total_loss:26.16261100769043 
 ari:0.35906198930662336 nmi:0.46874444873171667 

------------------------------
epoch:95 total_loss:15.976348876953125 
 ari:0.2657703663381344 nmi:0.4097170979657187 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:96 total_loss:15.794008255004883 
 ari:0.26111759397519535 nmi:0.4095500026305438 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:97 total_loss:15.620155334472656 
 ari:0.28116430048906155 nmi:0.4093190753793526 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:98 total_loss:15.470423698425293 
 ari:0.2773905427728152 nmi:0.42609110020094837 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:99 total_loss:15.336413383483887 
 ari:0.2968234494861195 nmi:0.4152914246950944 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:100 total_loss:15.144491195678711 
 ari:0.2919336493157044 nmi:0.428479323052712

------------------------------
epoch:142 total_loss:10.309150695800781 
 ari:0.31060639301391224 nmi:0.41666217241272585 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:143 total_loss:10.24964714050293 
 ari:0.26818620705761786 nmi:0.40007668547583947 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:144 total_loss:10.180763244628906 
 ari:0.254132843823876 nmi:0.39860641692408555 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:145 total_loss:10.085018157958984 
 ari:0.26126294807956785 nmi:0.4008755650529237 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:146 total_loss:10.034904479980469 
 ari:0.24338568589260534 nmi:0.3996362765548727 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:147 total_loss:9.950241088867188 
 ari:0.27662931133496976 nmi:0.40934981

------------------------------
epoch:189 total_loss:7.8811469078063965 
 ari:0.21283257365266386 nmi:0.37138991194727744 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:190 total_loss:7.856992244720459 
 ari:0.24429055187166657 nmi:0.362951627540253 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:191 total_loss:7.8093414306640625 
 ari:0.20675100863749957 nmi:0.36578827126627056 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:192 total_loss:7.773139953613281 
 ari:0.2440874495146037 nmi:0.37303550449508677 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:193 total_loss:7.789043426513672 
 ari:0.22501349689550978 nmi:0.37446584652826637 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:194 total_loss:7.717839241027832 
 ari:0.22067101827044722 nmi:0.358003644

------------------------------
epoch:236 total_loss:6.734661102294922 
 ari:0.2216330144076983 nmi:0.3660290303083152 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:237 total_loss:6.646989345550537 
 ari:0.20907461934336388 nmi:0.35645679274221287 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:238 total_loss:6.669534206390381 
 ari:0.20109226813948983 nmi:0.35851141648803625 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:239 total_loss:6.651716232299805 
 ari:0.2046581714098745 nmi:0.36127402758458427 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:240 total_loss:6.6168012619018555 
 ari:0.21648137666892334 nmi:0.37136766866780846 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:241 total_loss:6.564619064331055 
 ari:0.22020540753043355 nmi:0.3774427728

------------------------------
epoch:283 total_loss:6.083150386810303 
 ari:0.20063489471643753 nmi:0.34881206664257813 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:284 total_loss:6.037908554077148 
 ari:0.2130507708472535 nmi:0.3581094074088259 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:285 total_loss:6.0167436599731445 
 ari:0.2066806033843023 nmi:0.3537046760396345 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:286 total_loss:6.0313849449157715 
 ari:0.19228737732013024 nmi:0.3468960107801847 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:287 total_loss:6.004184246063232 
 ari:0.20866089594227705 nmi:0.3431788148980132 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:288 total_loss:6.044660568237305 
 ari:0.2059616722212698 nmi:0.3571744261763

------------------------------
epoch:330 total_loss:5.719821929931641 
 ari:0.19427543137405168 nmi:0.32145514624639887 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:331 total_loss:5.735854148864746 
 ari:0.20376558121417176 nmi:0.31882474937193267 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:332 total_loss:5.740372657775879 
 ari:0.1928605482598188 nmi:0.33562860796919564 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:333 total_loss:5.708521842956543 
 ari:0.2080041330508392 nmi:0.3231819193479275 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:334 total_loss:5.736978054046631 
 ari:0.1922255906910044 nmi:0.3196028859633242 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:335 total_loss:5.699875831604004 
 ari:0.17965697560934105 nmi:0.3140733563892

------------------------------
epoch:377 total_loss:5.553556442260742 
 ari:0.15115208990465356 nmi:0.25759469917778216 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:378 total_loss:5.530569076538086 
 ari:0.17461439615804422 nmi:0.2833012656973836 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:379 total_loss:5.525123596191406 
 ari:0.14794329400349002 nmi:0.2464900484414518 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:380 total_loss:5.544450283050537 
 ari:0.20793184913832258 nmi:0.31344908415101375 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:381 total_loss:5.551052093505859 
 ari:0.17770121747313275 nmi:0.28854516559333865 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:382 total_loss:5.494585990905762 
 ari:0.17445234730278864 nmi:0.2944469917

------------------------------
epoch:424 total_loss:5.380226135253906 
 ari:0.10345195022486003 nmi:0.18076214731916498 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:425 total_loss:5.362119197845459 
 ari:0.13678505393251988 nmi:0.22751690938143682 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:426 total_loss:5.41951847076416 
 ari:0.11317197281527304 nmi:0.1678369457381147 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:427 total_loss:5.407985210418701 
 ari:0.10737499661698334 nmi:0.16566362314220284 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:428 total_loss:5.423600196838379 
 ari:0.1092803145079074 nmi:0.1880679833055547 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:429 total_loss:5.420503616333008 
 ari:0.09740354263405197 nmi:0.166124807181

------------------------------
epoch:471 total_loss:5.308261394500732 
 ari:0.03141150112880761 nmi:0.04939478127989821 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:472 total_loss:5.321782112121582 
 ari:0.01890478840376891 nmi:0.040256728937815796 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:473 total_loss:5.316341400146484 
 ari:0.05036880175455292 nmi:0.11426295181054416 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:474 total_loss:5.280928611755371 
 ari:0.02017218463792696 nmi:0.054801823632977 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:475 total_loss:5.312276840209961 
 ari:0.023172565481481177 nmi:0.04862906297150746 
 best ari:0.7336493803718099 bset nmi:0.6676322329806802
------------------------------
epoch:476 total_loss:5.315445899963379 
 ari:0.025012831158145028 nmi:0.0481240