In [1]:
import torch
from torch import nn

from torch.optim import lr_scheduler

from torch.utils.data import random_split,Dataset,DataLoader

import torch.nn.functional as F
import torch.nn.init as init

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
from joblib.externals.loky.backend.context import get_context
import time
import copy
import random
import pickle
import tarfile

import time
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

In [2]:
import scanpy as sc
from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
from sklearn.model_selection import train_test_split

In [3]:
from unifan.networks import Encoder, Decoder, Set2Gene
import random

# Hyperparameters

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
AE_batch_size=128
AE_learning_rate = 0.001
AE_decay_factor = 0.9
AE_epochs = 300
AE_random_seed = 123
best_ari = 0
best_nmi = 0

# Preprocessing

In [5]:
def processing(adata):
    
    
    sc.pp.filter_cells(adata, min_genes=20)
    sc.pp.filter_genes(adata, min_cells=3)
    
    data = adata.copy()
    
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.normalize_total(data, target_sum=1e4)
    sc.pp.log1p(adata)
    
    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    
    data = data[:, adata.var.highly_variable]
    
    return data

# My Data

In [6]:
adata = sc.read("data/cortex.h5ad", dtype='float64')

In [7]:
adata = processing(adata)

In [8]:
gene_matrix = adata.X

In [9]:
G = adata.X.shape[1]

In [10]:
clusters_true = adata.obs["label"]

# Dataset&DataLoader

In [11]:
class MyDataset(Dataset):
    def __init__(self,gene_matrix):
        
        self.gene_matrix = torch.from_numpy(gene_matrix)
    
    def __len__(self):
        
        return self.gene_matrix.shape[0]

    def __getitem__(self,idx):
        
        X = self.gene_matrix[idx]
        
        return X

In [12]:
def create_loader(X,batch_size):

    #print(type(X))
    mydataset = MyDataset(X)

    dataloader = DataLoader(dataset=mydataset,batch_size=batch_size,shuffle=True)
    
    return dataloader

# autoencoder

In [13]:
class autoencoder(nn.Module):


    def __init__(self, input_dim: int = 10000,
                 hidden_dim:int = 128,encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 z_dim: int = 32,
                 dropout_rate:float = 0.1):

        super().__init__()
        
        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder,hidden_dim=encoder_dim)
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)
        

        # initialize loss
        self.mse_loss = nn.MSELoss()

    def forward(self, data):

        x = data

        z_e,_ = self.encoder(x)

        x_e,_ = self.decoder(z_e)

        return x_e, z_e

    def _loss_reconstruct(self, x, x_e):

        l_e = self.mse_loss(x, x_e)
        
        mse_l = l_e

        return mse_l

    def loss(self, x, x_e):
        
        l = self._loss_reconstruct(x, x_e)
        
        return l


# Pretrain Autoencoder

In [14]:
AE_dataloader = create_loader(gene_matrix,batch_size = AE_batch_size)

In [15]:
AE = autoencoder(input_dim = G,num_layers_decoder=3).to(device)

In [16]:
AE_optimizer = torch.optim.Adam(AE.parameters(), lr=AE_learning_rate)
AE_scheduler = torch.optim.lr_scheduler.StepLR(AE_optimizer, 1000,AE_decay_factor)

In [17]:
def cal_result_AE(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=0.5, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)

    return ari,nmi

In [18]:
for epoch in range(AE_epochs):
    
    total_loss = 0
    for batch_idx,X_batch in enumerate(AE_dataloader):
        X_batch = X_batch.to(device).float()
        
        AE_optimizer.zero_grad()

        x_e, z_e = AE(X_batch)

        loss = AE.loss(X_batch.float(), x_e.float())
        
        total_loss += loss
        
        
        loss.backward()
        AE_optimizer.step()
        AE_scheduler.step()

    ari,nmi = cal_result_AE(AE,AE_dataloader,clusters_true,gene_matrix)
    
    if best_ari < ari and best_nmi < nmi:
        
        best_ari = ari
        
        best_nmi = nmi
        
    print(f"epoch:{epoch+1} total_loss:{total_loss/len(AE_dataloader.sampler)} ari:{best_ari} nmi:{best_nmi}")

epoch:1 total_loss:0.16327348351478577 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:2 total_loss:0.1452455073595047 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:3 total_loss:0.1253722459077835 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:4 total_loss:0.10423865914344788 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:5 total_loss:0.09151306748390198 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:6 total_loss:0.0834527313709259 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:7 total_loss:0.06227501854300499 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:8 total_loss:0.06521636992692947 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:9 total_loss:0.06770803034305573 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:10 total_loss:0.04950794577598572 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:11 total_loss:0.044275205582380295 ari:0.4706816793932312 nmi:0.6633013612583779
epoch:12 total_loss:0.039720430970191956 ari:0.4706816793932312 n

epoch:97 total_loss:0.01940995268523693 ari:0.5641069674476631 nmi:0.6915853392852442
epoch:98 total_loss:0.019407309591770172 ari:0.565304948087691 nmi:0.6922163170525595
epoch:99 total_loss:0.019351104274392128 ari:0.565304948087691 nmi:0.6922163170525595
epoch:100 total_loss:0.019431812688708305 ari:0.565304948087691 nmi:0.6922163170525595
epoch:101 total_loss:0.019248533993959427 ari:0.565304948087691 nmi:0.6922163170525595
epoch:102 total_loss:0.019264720380306244 ari:0.565304948087691 nmi:0.6922163170525595
epoch:103 total_loss:0.01914284937083721 ari:0.565304948087691 nmi:0.6922163170525595
epoch:104 total_loss:0.01912068948149681 ari:0.565304948087691 nmi:0.6922163170525595
epoch:105 total_loss:0.019152432680130005 ari:0.565304948087691 nmi:0.6922163170525595
epoch:106 total_loss:0.019087698310613632 ari:0.565304948087691 nmi:0.6922163170525595
epoch:107 total_loss:0.019104955717921257 ari:0.565304948087691 nmi:0.6922163170525595
epoch:108 total_loss:0.019450409337878227 ari:0.

epoch:192 total_loss:0.01602635718882084 ari:0.6784445567900751 nmi:0.705557619747758
epoch:193 total_loss:0.01791669800877571 ari:0.6784445567900751 nmi:0.705557619747758
epoch:194 total_loss:0.016913410276174545 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:195 total_loss:0.016701264306902885 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:196 total_loss:0.019234267994761467 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:197 total_loss:0.018354618921875954 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:198 total_loss:0.018858924508094788 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:199 total_loss:0.016893258318305016 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:200 total_loss:0.016256151720881462 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:201 total_loss:0.01612790673971176 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:202 total_loss:0.016701025888323784 ari:0.7079299746789033 nmi:0.7216107751656855
epoch:203 total_loss:0.01644170470535

epoch:286 total_loss:0.014645556919276714 ari:0.7131274812263191 nmi:0.7225695717211571
epoch:287 total_loss:0.01593327522277832 ari:0.7131274812263191 nmi:0.7225695717211571
epoch:288 total_loss:0.017791781574487686 ari:0.7131274812263191 nmi:0.7225695717211571
epoch:289 total_loss:0.014721155166625977 ari:0.7131274812263191 nmi:0.7225695717211571
epoch:290 total_loss:0.014353224076330662 ari:0.716038964269005 nmi:0.7229155208359529
epoch:291 total_loss:0.01426575519144535 ari:0.716038964269005 nmi:0.7229155208359529
epoch:292 total_loss:0.01452100370079279 ari:0.716038964269005 nmi:0.7229155208359529
epoch:293 total_loss:0.015033071860671043 ari:0.716038964269005 nmi:0.7229155208359529
epoch:294 total_loss:0.014333422295749187 ari:0.716038964269005 nmi:0.7229155208359529
epoch:295 total_loss:0.014183432795107365 ari:0.716038964269005 nmi:0.7229155208359529
epoch:296 total_loss:0.014314747415482998 ari:0.716038964269005 nmi:0.7229155208359529
epoch:297 total_loss:0.014063561335206032 

In [19]:
AE.to(torch.device("cpu")).eval()
z_init,_ = AE.encoder(torch.from_numpy(gene_matrix))

In [20]:
z_init = z_init.detach().numpy()

In [21]:
adata = sc.AnnData(X=z_init)
adata.obsm['X_unifan'] = z_init
sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=AE_random_seed)
sc.tl.leiden(adata, resolution=0.5, random_state=AE_random_seed)
clusters_pre = adata.obs['leiden'].astype('int').values  # original as string

In [22]:
df_cluster = pd.DataFrame(z_init)
cluster_labels = np.unique(clusters_pre)
M = len(set(cluster_labels))  # set as number of clusters
df_cluster['cluster'] = clusters_pre

# get centroids
centroids = df_cluster.groupby('cluster').mean().values
centroids_torch = torch.from_numpy(centroids)

In [23]:
centroids = torch.from_numpy(centroids)

In [24]:
ari_smaller = adjusted_rand_score(clusters_pre,
                                  clusters_true)
nmi_smaller = adjusted_mutual_info_score(clusters_pre, clusters_true)

In [25]:
nmi_smaller

0.6900660378155391

In [26]:
ari_smaller

0.6470483905440783

# VQVAE

In [27]:
class VQVAE_T(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = True,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

In [28]:
class VQVAE_N(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = False,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

# Adam

In [29]:
def Adam_training(model,n_epochs,learning_rate,dataloader,clusters_true,gene_matrix,
                non_blocking = True,decay_factor = 0.9):
    
    model.to(device)
    model.train()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1000, decay_factor)
    
    for epoch in range(n_epochs):
        
        total_loss = 0
    
        for batch_idx, X_batch in enumerate(dataloader):
            
            X_batch = X_batch.to(device, non_blocking=non_blocking).float()

            optimizer.zero_grad(set_to_none=True)

            x_q, z_e, z_q = model(X_batch)

            l = model.loss(X_batch,x_q,z_e,z_q)
       
            total_loss += l

            l.backward(retain_graph=True)
            
            optimizer.step()
            
        ari,nmi = cal_result(model,dataloader,clusters_true,gene_matrix)
        global best_ari,best_nmi
        if best_ari < ari and best_nmi < nmi:

            best_ari = ari

            best_nmi = nmi
        
        print("------------------------------")
        print(f"epoch:{epoch+1} total_loss:{total_loss/len(dataloader)} \n ari:{ari.item()} nmi:{nmi.item()} \n best ari:{best_ari.item()} bset nmi:{best_nmi.item()}")
        
            
    # move network back to cpu and return
    model.cpu()
    
    return model

# GA

In [30]:
def crossover_and_mutation(parents, sigma=0.05):

    
    base_sd = parents[0].state_dict()
    keys = base_sd                    # use all layers to be affected
    
    # Sum of the weights of the parent
    for i in range(1, len(parents)):
        parent_sd = parents[i].state_dict()
        for key in keys:
            base_sd[key] = base_sd[key] + parent_sd[key]
            
    
    # Average and add mutation
    num_parents = len(parents)
    
    for key in keys:
        
        tensor_size = base_sd[key].size()
        random_tensor = torch.normal(mean=0.0, std=sigma, size=tensor_size).to(device)
        
        base_sd[key] = (base_sd[key] / num_parents) + random_tensor
    
    # create offspring
    
    if random.randint(0,1) == 0:
        offspring = VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
    else:
    
        offspring = VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
        
    return offspring
    

def create_offspring(population,fitness,rho,sigma):

    
    # Perform selection
    parents = random.choices(population, weights=fitness, k=rho) 
    
    # Perform crossover and mutation
    offspring = crossover_and_mutation(parents, sigma)
    
    
    return offspring

def sigmoid(x):
    
    return 1/(1 + np.exp(-x))


def GA_training(population, pop_size, offspring_size, elitist_level, rho, sigma, dataloader,clusters_true,gene_matrix):
    
    #Calculate fitness of trained population

    fitness = [cal_loss(population[i],dataloader)
                                for i in range(pop_size)]
    
    print(f"--- -- Finished fitness evaluation, length: {len(fitness)}")
    
    #Create offspring population
    fitness_weighted = [ 1/f for f in fitness]   # take inverse of loss so lower losses get higher fitness-values
    
    offspring_population = [create_offspring(population,fitness_weighted, rho, sigma) for i in range(offspring_size)]
    
    print("--- -- Finished creating offspring population")
    
    #Evaluate fitness of offsprings 
    
    offspring_fitness = [cal_loss(offspring_population[i],dataloader) 
                                                          for i in range(offspring_size)]
    
    print("--- -- Finished evaluating fitness of offspring population")
    
    # Combine fitness and population lists
    
    combined_fitness = fitness + offspring_fitness
    combined_population = population + offspring_population
    
    # sort and select population by their fitness values
    
    sorted_population = [pop for _, pop in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    sorted_fitness = [loss for loss, _ in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    
    m = int(pop_size * elitist_level)
    new_population = sorted_population[0:m]
    
    # Fill up rest of population
    difference = pop_size - m
    remaining_population = list(set(sorted_population) - set(new_population))
    filler_population = random.sample(remaining_population, difference)
    
    # assemble new population and return
    new_population = new_population + filler_population
    
    return new_population, sorted_fitness

In [31]:
def cal_loss(model,dataloader,non_blocking=True):
    
    model.to(device)

    total_loss = 0.0
    
    for batch_idx, X_batch in enumerate(dataloader):
        
        X_batch = X_batch.to(device, non_blocking=non_blocking).float()

        x_e, z_e, z_q = model(X_batch)

        l = model.loss(X_batch,x_e, z_e, z_q)
        
        total_loss += l
        

    return float(total_loss)

In [32]:
def cal_result(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e,z_q = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=0.5, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)
    
    ari = torch.from_numpy(np.array(ari)).to(device)
    nmi = torch.from_numpy(np.array(nmi)).to(device)
    
    return (ari,nmi)

# GA Neural

In [33]:
def GA_Neural_train(population,
                    pop_size,
                    max_generations, 
                    SGD_steps, GA_steps, 
                    offspring_size, elitist_level, rho,
                    learning_rate,
                    dataloader,
                    clusters_true,
                    gene_matrix):
    
    
    print(f"Starting with population of size: {pop_size}")

    for k in range(max_generations):
        print(f"Currently in generation {k+1}")
        
        #Adam
        print(f"--- Starting Adam")
        
        # Sequential version
        
        population_copy = []
        
        for i in range(pop_size):
            
            model = Adam_training(population[i],SGD_steps,learning_rate,dataloader,clusters_true,gene_matrix)
            population_copy.append(model)
        
        print(f"--- Finished Adam")
        
        population = population_copy
         
        # GA
        print(f"--- Starting Model GA")
        GA_start = time.time()
        sorted_fitness = []          # store the sorted fitness values to maybe use in data collection
        for i in range(0, GA_steps):
            
            sigma = 0.01 / (k+1)
            population, sorted_fitness = GA_training(population, 
                                                     pop_size, offspring_size, elitist_level, rho, sigma, dataloader,
                                                     clusters_true,gene_matrix)
        
        GA_end = time.time()
        
        print(f"--- Finished Model GA,Time:{(GA_end - GA_start) * 1000}ms")
        
        
    print(f"Finished training process")
    
    return population

# Train VQVAE

In [34]:
# Hyperparameters
#device_train = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 1024
pop_size_T = 5
pop_size_N = 5
max_generations = 50
SGD_steps = 10
GA_steps = 1
offspring_size = 30
elitist_level = 0.4
rho = 4
learning_rate = 1e-5
in_feats = 400
n_hidden = 200
weight_decay = 5e-4
best_nmi = 0
best_ari = 0

In [35]:
dataloader = create_loader(gene_matrix,batch_size = batch_size)
pop_size  = pop_size_T + pop_size_N

In [36]:
# Create population and start training process
population_T = [VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_T)]

In [37]:
population_N = [VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_N)]

In [38]:
population = population_T + population_N

In [None]:
Train_start = time.time()
trained_population = GA_Neural_train(population=population,
                                    pop_size = pop_size,
                                    max_generations=max_generations,
                                    SGD_steps=SGD_steps,GA_steps=GA_steps,
                                    offspring_size=offspring_size,elitist_level=elitist_level,rho=rho,
                                    learning_rate=learning_rate,
                                    dataloader=dataloader,clusters_true=clusters_true,gene_matrix=gene_matrix)
Train_end = time.time()
print(f"All Time:{(Train_end-Train_start)*1000}ms")

Starting with population of size: 10
Currently in generation 1
--- Starting Adam
------------------------------
epoch:1 total_loss:41.320098876953125 
 ari:0.6989767025033613 nmi:0.6692552408737789 
 best ari:0.6989767025033613 bset nmi:0.6692552408737789
------------------------------
epoch:2 total_loss:40.96586227416992 
 ari:0.6858101866845631 nmi:0.6665990628543681 
 best ari:0.6989767025033613 bset nmi:0.6692552408737789
------------------------------
epoch:3 total_loss:40.774654388427734 
 ari:0.7140051227490141 nmi:0.666465276992457 
 best ari:0.6989767025033613 bset nmi:0.6692552408737789
------------------------------
epoch:4 total_loss:40.62885284423828 
 ari:0.7279271039041839 nmi:0.6818900147191096 
 best ari:0.7279271039041839 bset nmi:0.6818900147191096
------------------------------
epoch:5 total_loss:40.38605880737305 
 ari:0.6979671714068743 nmi:0.6641408730459795 
 best ari:0.7279271039041839 bset nmi:0.6818900147191096
------------------------------
epoch:6 total_los

------------------------------
epoch:8 total_loss:40.25709533691406 
 ari:0.7014098373666928 nmi:0.6552440809567409 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:9 total_loss:40.10688018798828 
 ari:0.681490642957157 nmi:0.6489256974053764 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:10 total_loss:39.97138214111328 
 ari:0.7145804702237181 nmi:0.6668675844628787 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:1 total_loss:40.48075866699219 
 ari:0.6792261802740581 nmi:0.6364792388517426 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:2 total_loss:40.225074768066406 
 ari:0.7002944245809991 nmi:0.654008660719422 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:3 total_loss:40.03935623168945 
 ari:0.710667091141086 nmi:0.6559449192369544 
 best ari:0.75

------------------------------
epoch:6 total_loss:39.98693084716797 
 ari:0.6539377067044052 nmi:0.6107625434980248 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:7 total_loss:39.82053756713867 
 ari:0.6564460490134666 nmi:0.6189762111274486 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:8 total_loss:39.67375564575195 
 ari:0.6414655530433654 nmi:0.6117414467318151 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:9 total_loss:39.43836212158203 
 ari:0.6653659409063407 nmi:0.6295129788719043 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:10 total_loss:39.28074264526367 
 ari:0.6448111421646223 nmi:0.618706544182398 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
--- Finished Adam
--- Starting Model GA
--- -- Finished fitness evaluation, length: 10
--- -- Finished creating offspring populati

------------------------------
epoch:2 total_loss:40.12852096557617 
 ari:0.6388580414220764 nmi:0.5842599811574453 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:3 total_loss:39.989097595214844 
 ari:0.63693221478141 nmi:0.5838614857592129 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:4 total_loss:39.94097900390625 
 ari:0.5757647858949668 nmi:0.562042331662251 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:5 total_loss:39.881500244140625 
 ari:0.6143757347043066 nmi:0.5760879233412888 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:6 total_loss:39.804710388183594 
 ari:0.6418377385908237 nmi:0.5983910618949444 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:7 total_loss:39.724403381347656 
 ari:0.6032398445090031 nmi:0.5876856751554262 
 best ari:0.

------------------------------
epoch:10 total_loss:39.390098571777344 
 ari:0.5422762875619499 nmi:0.612691487686512 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:1 total_loss:40.12989807128906 
 ari:0.6468395398462696 nmi:0.601528718735103 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:2 total_loss:40.08815002441406 
 ari:0.6563523414006311 nmi:0.6084638256340459 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:3 total_loss:39.968536376953125 
 ari:0.6521892488581041 nmi:0.6045517224191183 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:4 total_loss:39.86040115356445 
 ari:0.6662159692742389 nmi:0.6341282874358692 
 best ari:0.7514624178784669 bset nmi:0.6915206707031415
------------------------------
epoch:5 total_loss:39.7926139831543 
 ari:0.6408539356014195 nmi:0.6096719714566855 
 best ari:0.7