In [1]:
import torch
from torch import nn

from torch.optim import lr_scheduler

from torch.utils.data import random_split,Dataset,DataLoader

import torch.nn.functional as F
import torch.nn.init as init

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed
from joblib.externals.loky.backend.context import get_context
import time
import copy
import random
import pickle
import tarfile

import time
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# Ignore warnings
import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

In [2]:
import scanpy as sc
from sklearn.metrics.cluster import adjusted_rand_score, adjusted_mutual_info_score
from sklearn.model_selection import train_test_split

In [3]:
from unifan.networks import Encoder, Decoder, Set2Gene
import random

# Hyperparameters

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available else "cpu")
AE_batch_size=128
AE_learning_rate = 0.001
AE_decay_factor = 0.9
AE_epochs = 300
AE_random_seed = 123
best_ari = 0
best_nmi = 0

# My Data

In [5]:
adata = sc.read("data/cortex.h5ad", dtype='float64')

In [6]:
gene_matrix = adata.X

In [7]:
G = adata.X.shape[1]

In [8]:
clusters_true = adata.obs["label"]

# Dataset&DataLoader

In [9]:
class MyDataset(Dataset):
    def __init__(self,gene_matrix):
        
        self.gene_matrix = torch.from_numpy(gene_matrix)
    
    def __len__(self):
        
        return self.gene_matrix.shape[0]

    def __getitem__(self,idx):
        
        X = self.gene_matrix[idx]
        
        return X

In [10]:
def create_loader(X,batch_size):

    #print(type(X))
    mydataset = MyDataset(X)

    dataloader = DataLoader(dataset=mydataset,batch_size=batch_size,shuffle=True)
    
    return dataloader

# autoencoder

In [11]:
class autoencoder(nn.Module):


    def __init__(self, input_dim: int = 10000,
                 hidden_dim:int = 128,encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 z_dim: int = 32,
                 dropout_rate:float = 0.1):

        super().__init__()
        
        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder,hidden_dim=encoder_dim)
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

        # initialize loss
        self.mse_loss = nn.MSELoss()

    def forward(self, data):

        x = data

        z_e,_ = self.encoder(x)

        x_e,_ = self.decoder(z_e)

        return x_e, z_e

    def _loss_reconstruct(self, x, x_e):

        l_e = self.mse_loss(x, x_e)
        
        mse_l = l_e

        return mse_l

    def loss(self, x, x_e):
        
        l = self._loss_reconstruct(x, x_e)
        
        return l


# Pretrain Autoencoder

In [12]:
AE_dataloader = create_loader(gene_matrix,batch_size = AE_batch_size)

In [13]:
AE = autoencoder(input_dim = G).to(device)

In [14]:
AE_optimizer = torch.optim.Adam(AE.parameters(), lr=AE_learning_rate)
AE_scheduler = torch.optim.lr_scheduler.StepLR(AE_optimizer, 1000,AE_decay_factor)

In [15]:
def cal_result_AE(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=1, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)

    return ari,nmi

In [16]:
for epoch in range(AE_epochs):
    
    total_loss = 0
    for batch_idx,X_batch in enumerate(AE_dataloader):
        X_batch = X_batch.to(device).float()
        
        AE_optimizer.zero_grad()

        x_e, z_e = AE(X_batch)

        loss = AE.loss(X_batch.float(), x_e.float())
        
        total_loss += loss
        
        
        loss.backward()
        AE_optimizer.step()
        AE_scheduler.step()

    ari,nmi = cal_result_AE(AE,AE_dataloader,clusters_true,gene_matrix)
    
    if best_ari < ari and best_nmi < nmi:
        
        best_ari = ari
        
        best_nmi = nmi
        
    print(f"epoch:{epoch+1} total_loss:{total_loss/len(AE_dataloader.sampler)} ari:{best_ari} nmi:{best_nmi}")

epoch:1 total_loss:0.578575849533081 ari:0.1354668020671266 nmi:0.350115993041843
epoch:2 total_loss:0.4777005612850189 ari:0.1354668020671266 nmi:0.350115993041843
epoch:3 total_loss:0.24634608626365662 ari:0.1354668020671266 nmi:0.350115993041843
epoch:4 total_loss:0.1783933937549591 ari:0.1354668020671266 nmi:0.350115993041843
epoch:5 total_loss:0.17381924390792847 ari:0.1354668020671266 nmi:0.350115993041843
epoch:6 total_loss:0.1695861518383026 ari:0.1354668020671266 nmi:0.350115993041843
epoch:7 total_loss:0.16246333718299866 ari:0.1354668020671266 nmi:0.350115993041843
epoch:8 total_loss:0.15150859951972961 ari:0.1354668020671266 nmi:0.350115993041843
epoch:9 total_loss:0.13707037270069122 ari:0.1354668020671266 nmi:0.350115993041843
epoch:10 total_loss:0.13028106093406677 ari:0.1354668020671266 nmi:0.350115993041843
epoch:11 total_loss:0.12571397423744202 ari:0.1354668020671266 nmi:0.350115993041843
epoch:12 total_loss:0.12524577975273132 ari:0.1354668020671266 nmi:0.3501159930

epoch:96 total_loss:0.046766847372055054 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:97 total_loss:0.04572204872965813 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:98 total_loss:0.04616906866431236 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:99 total_loss:0.05821797251701355 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:100 total_loss:0.062024082988500595 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:101 total_loss:0.04951942712068558 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:102 total_loss:0.045553192496299744 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:103 total_loss:0.042643703520298004 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:104 total_loss:0.039803870022296906 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:105 total_loss:0.03604646027088165 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:106 total_loss:0.03486742824316025 ari:0.2007141364282903 nmi:0.45785123606882006
epoch:107 total_loss:0.03628405

epoch:189 total_loss:0.026857364922761917 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:190 total_loss:0.02648441307246685 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:191 total_loss:0.026217784732580185 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:192 total_loss:0.02613312005996704 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:193 total_loss:0.026132192462682724 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:194 total_loss:0.026009725406765938 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:195 total_loss:0.026218293234705925 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:196 total_loss:0.02689262293279171 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:197 total_loss:0.026413673534989357 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:198 total_loss:0.026313144713640213 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:199 total_loss:0.0257274117320776 ari:0.29086062247198063 nmi:0.5488297512251205
epoch:200 total_loss:0.026

epoch:283 total_loss:0.023505432531237602 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:284 total_loss:0.023157795891165733 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:285 total_loss:0.023089606314897537 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:286 total_loss:0.023495810106396675 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:287 total_loss:0.023581087589263916 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:288 total_loss:0.023626849055290222 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:289 total_loss:0.022976554930210114 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:290 total_loss:0.023284701630473137 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:291 total_loss:0.02346745692193508 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:292 total_loss:0.023128047585487366 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:293 total_loss:0.022786790505051613 ari:0.3112921929585136 nmi:0.5794293557866982
epoch:294 total_loss:0.0228490643

In [17]:
AE.to(torch.device("cpu")).eval()
z_init,_ = AE.encoder(torch.from_numpy(gene_matrix))

In [18]:
z_init = z_init.detach().numpy()

In [19]:
adata = sc.AnnData(X=z_init)
adata.obsm['X_unifan'] = z_init
sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=AE_random_seed)
sc.tl.leiden(adata, resolution=1, random_state=AE_random_seed)
clusters_pre = adata.obs['leiden'].astype('int').values  # original as string

In [20]:
df_cluster = pd.DataFrame(z_init)
cluster_labels = np.unique(clusters_pre)
M = len(set(cluster_labels))  # set as number of clusters
df_cluster['cluster'] = clusters_pre

# get centroids
centroids = df_cluster.groupby('cluster').mean().values
centroids_torch = torch.from_numpy(centroids)

In [21]:
centroids = torch.from_numpy(centroids)

In [22]:
ari_smaller = adjusted_rand_score(clusters_pre,
                                  clusters_true)
nmi_smaller = adjusted_mutual_info_score(clusters_pre, clusters_true)

In [23]:
nmi_smaller

0.5750889813077029

In [24]:
ari_smaller

0.30627858255204143

# VQVAE

In [25]:
class VQVAE_T(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = True,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

In [26]:
class VQVAE_N(nn.Module):

    def __init__(self, input_dim: int = 10000, z_dim: int = 32,
                  encoder_dim: int = 128, emission_dim: int = 128,
                 num_layers_encoder: int = 2,num_layers_decoder: int = 2,
                 beta: float = 1.0,gama:float =0.25,
                 n_clusters: int = 16,
                 hidden_dim: int = 128, dropout_rate: float = 0.1, use_t_dist: bool = False,
                 centroids: torch.Tensor = None):

        super().__init__()

        # initialize parameters
        self.z_dim = z_dim
        self.beta = beta
        self.gama = gama
        self.n_clusters = n_clusters
        self.use_t_dist = use_t_dist
        
        # initialize centroids embeddings
        if centroids is not None:
            self.embeddings = nn.Parameter(centroids, requires_grad=True)
        else:
            self.embeddings = nn.Parameter(torch.randn(self.n_clusters, self.z_dim) * 0.05, requires_grad=True)

        # initialize loss
        self.mse_loss = nn.MSELoss()

        self.encoder = Encoder(input_dim, z_dim, num_layers=num_layers_encoder, hidden_dim=encoder_dim,
                                   dropout_rate=dropout_rate)
        
        self.decoder = Decoder(z_dim, input_dim, num_layers=num_layers_decoder,hidden_dim=emission_dim)

    def forward(self, x):

        # get encoding
        z_e, _ = self.encoder(x)

        # get the index of embedding closed to the encoding
        k, z_dist, dist_prob = self._get_clusters(z_e)

        # get embeddings (discrete representations)
        z_q = self._get_embeddings(k)

        # decode embedding (discrete representation) and encoding
        x_q, _ = self.decoder(z_e + (z_q-z_e).detach())

        return x_q, z_e, z_q

    def _get_clusters(self, z_e):


        _z_dist = (z_e.unsqueeze(1) - self.embeddings.unsqueeze(0)) ** 2
        z_dist = torch.sum(_z_dist, dim=-1)
        if self.use_t_dist:
            dist_prob = self._t_dist_sim(z_dist, df=10)
            k = torch.argmax(dist_prob, dim=-1)
        else:
            k = torch.argmin(z_dist, dim=-1)
            dist_prob = None

        return k, z_dist, dist_prob

    def _t_dist_sim(self, z_dist, df=10):


        _factor = - ((df + 1) / 2)
        dist_prob = torch.pow((1 + z_dist / df), _factor)
        dist_prob = dist_prob / dist_prob.sum(axis=1).unsqueeze(1)

        return dist_prob

    def _get_embeddings(self, k):


        k = k.long()
        _z_q = []
        for i in range(len(k)):
            _z_q.append(self.embeddings[k[i]])

        z_q = torch.stack(_z_q)

        return z_q


    def _loss_reconstruct(self,x,x_q, z_e, z_q):


        l_x = self.mse_loss(x, x_q)
        
        l_q = self.mse_loss(z_e.detach(),z_q)
        
        l_e = self.mse_loss(z_e,z_q.detach())
        
        mse_l = l_e + self.beta*l_q + self.gama*l_x
        
        return mse_l


    def loss(self,x ,x_e ,z_e ,z_q):

        mse_l = self._loss_reconstruct(x,x_e,z_e,z_q)

        return mse_l

# Adam

In [27]:
def Adam_training(model,n_epochs,learning_rate,dataloader,clusters_true,gene_matrix,
                non_blocking = True,decay_factor = 0.9):
    
    model.to(device)
    model.train()
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1000, decay_factor)
    
    for epoch in range(n_epochs):
        
        total_loss = 0
    
        for batch_idx, X_batch in enumerate(dataloader):
            
            X_batch = X_batch.to(device, non_blocking=non_blocking).float()

            optimizer.zero_grad(set_to_none=True)

            x_q, z_e, z_q = model(X_batch)

            l = model.loss(X_batch,x_q,z_e,z_q)
       
            total_loss += l

            l.backward(retain_graph=True)
            
            optimizer.step()
            
        ari,nmi = cal_result(model,dataloader,clusters_true,gene_matrix)
        global best_ari,best_nmi
        if best_ari < ari and best_nmi < nmi:

            best_ari = ari

            best_nmi = nmi
        
        print("------------------------------")
        print(f"epoch:{epoch+1} total_loss:{total_loss/len(dataloader)} \n ari:{ari.item()} nmi:{nmi.item()} \n best ari:{best_ari.item()} bset nmi:{best_nmi.item()}")
        
            
    # move network back to cpu and return
    model.cpu()
    
    return model

# GA

In [28]:
def crossover_and_mutation(parents, sigma=0.05):

    
    base_sd = parents[0].state_dict()
    keys = base_sd                    # use all layers to be affected
    
    # Sum of the weights of the parent
    for i in range(1, len(parents)):
        parent_sd = parents[i].state_dict()
        for key in keys:
            base_sd[key] = base_sd[key] + parent_sd[key]
            
    
    # Average and add mutation
    num_parents = len(parents)
    
    for key in keys:
        
        tensor_size = base_sd[key].size()
        random_tensor = torch.normal(mean=0.0, std=sigma, size=tensor_size).to(device)
        
        base_sd[key] = (base_sd[key] / num_parents) + random_tensor
    
    # create offspring
    
    if random.randint(0,1) == 0:
        offspring = VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
    else:
    
        offspring = VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0])
    
        offspring.load_state_dict(base_sd)
        
    return offspring
    

def create_offspring(population,fitness,rho,sigma):

    
    # Perform selection
    parents = random.choices(population, weights=fitness, k=rho) 
    
    # Perform crossover and mutation
    offspring = crossover_and_mutation(parents, sigma)
    
    
    return offspring

def sigmoid(x):
    
    return 1/(1 + np.exp(-x))


def GA_training(population, pop_size, offspring_size, elitist_level, rho, sigma, dataloader,clusters_true,gene_matrix):
    
    #Calculate fitness of trained population

    fitness = [cal_result(population[i],dataloader,clusters_true,gene_matrix)[0]+cal_result(population[i],dataloader,clusters_true,gene_matrix)[1] 
                                for i in range(pop_size)]
    
    print(f"--- -- Finished fitness evaluation, length: {len(fitness)}")
    
    #Create offspring population
    fitness_weighted = [f for f in fitness]   # take inverse of loss so lower losses get higher fitness-values
    
    offspring_population = [create_offspring(population,fitness_weighted, rho, sigma) for i in range(offspring_size)]
    
    print("--- -- Finished creating offspring population")
    
    #Evaluate fitness of offsprings 
    
    offspring_fitness = [cal_result(offspring_population[i],dataloader,clusters_true,gene_matrix)[0]+cal_result(offspring_population[i],dataloader,clusters_true,gene_matrix)[1] 
                                                          for i in range(offspring_size)]
    
    print("--- -- Finished evaluating fitness of offspring population")
    
    # Combine fitness and population lists
    
    combined_fitness = fitness + offspring_fitness
    combined_population = population + offspring_population
    
    # sort and select population by their fitness values
    
    sorted_population = [pop for _, pop in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    sorted_fitness = [loss for loss, _ in sorted(zip(combined_fitness, combined_population), key=lambda pair: pair[0])]
    
    m = int(pop_size * elitist_level)
    new_population = sorted_population[0:m]
    
    # Fill up rest of population
    difference = pop_size - m
    remaining_population = list(set(sorted_population) - set(new_population))
    filler_population = random.sample(remaining_population, difference)
    
    # assemble new population and return
    new_population = new_population + filler_population
    
    return new_population, sorted_fitness

In [29]:
def cal_loss(model,dataloader,non_blocking=True):
    
    model.to(device)

    total_loss = 0.0
    
    for batch_idx, X_batch in enumerate(dataloader):
        
        X_batch = X_batch.to(device, non_blocking=non_blocking).float()

        x_e, z_e, z_q = model(X_batch)

        l = model.loss(X_batch,x_e, z_e, z_q)
        
        total_loss += l
        
    model.cpu()

    return float(total_loss)

In [30]:
def cal_result(model,dataloader,clusters_true,gene_matrix):
    
    model.to(device)
    model.eval()
    
    gene_matrix = torch.from_numpy(gene_matrix).to(device)

    x_e,z_e,z_q = model(gene_matrix)
    
    z_e = z_e.detach().cpu().numpy()
    
    adata.obsm['X_unifan'] = z_e
    
    sc.pp.neighbors(adata, n_pcs=32,use_rep='X_unifan', random_state=123)
    
    sc.tl.leiden(adata, resolution=1, random_state=123)
    
    clusters_pre = adata.obs['leiden'].astype('int').values  # original as string
    
    ari = adjusted_rand_score(clusters_pre, clusters_true)

    nmi = adjusted_mutual_info_score(clusters_pre, clusters_true)
    
    ari = torch.from_numpy(np.array(ari)).to(device)
    nmi = torch.from_numpy(np.array(nmi)).to(device)
    
    return (ari,nmi)

# GA Neural

In [31]:
def GA_Neural_train(population,
                    pop_size,
                    max_generations, 
                    SGD_steps, GA_steps, 
                    offspring_size, elitist_level, rho,
                    learning_rate,
                    dataloader,
                    clusters_true,
                    gene_matrix):
    
    
    print(f"Starting with population of size: {pop_size}")

    for k in range(max_generations):
        print(f"Currently in generation {k+1}")
        
        #Adam
        print(f"--- Starting Adam")
        
        # Sequential version
        
        population_copy = []
        
        for i in range(pop_size):
            
            model = Adam_training(population[i],SGD_steps,learning_rate,dataloader,clusters_true,gene_matrix)
            population_copy.append(model)
        
        print(f"--- Finished Adam")
        
        population = population_copy
         
        # GA
        print(f"--- Starting Model GA")
        GA_start = time.time()
        sorted_fitness = []          # store the sorted fitness values to maybe use in data collection
        for i in range(0, GA_steps):
            
            sigma = 0.01 / (k+1)
            population, sorted_fitness = GA_training(population, 
                                                     pop_size, offspring_size, elitist_level, rho, sigma, dataloader,
                                                     clusters_true,gene_matrix)
        
        GA_end = time.time()
        
        print(f"--- Finished Model GA,Time:{(GA_end - GA_start) * 1000}ms")
        
        
    print(f"Finished training process")
    
    return population

# Train VQVAE

In [32]:
# Hyperparameters
#device_train = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_size = 1024
pop_size_T = 5
pop_size_N = 5
max_generations = 50
SGD_steps = 10
GA_steps = 1
offspring_size = 30
elitist_level = 0.4
rho = 4
learning_rate = 1e-5
in_feats = 400
n_hidden = 200
weight_decay = 5e-4
best_nmi = 0
best_ari = 0

In [33]:
dataloader = create_loader(gene_matrix,batch_size = batch_size)
pop_size  = pop_size_T + pop_size_N

In [34]:
# Create population and start training process
population_T = [VQVAE_T(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_T)]

In [35]:
population_N = [VQVAE_N(input_dim = gene_matrix.shape[1],
                            beta = 1,gama=0.25, n_clusters = centroids.shape[0],centroids = centroids).to(device)
                     for i in range(pop_size_N)]

In [36]:
population = population_T + population_N

In [None]:
Train_start = time.time()
trained_population = GA_Neural_train(population=population,
                                    pop_size = pop_size,
                                    max_generations=max_generations,
                                    SGD_steps=SGD_steps,GA_steps=GA_steps,
                                    offspring_size=offspring_size,elitist_level=elitist_level,rho=rho,
                                    learning_rate=learning_rate,
                                    dataloader=dataloader,clusters_true=clusters_true,gene_matrix=gene_matrix)
Train_end = time.time()
print(f"All Time:{(Train_end-Train_start)*1000}ms")

Starting with population of size: 10
Currently in generation 1
--- Starting Adam
------------------------------
epoch:1 total_loss:116.62741088867188 
 ari:0.27351531056082723 nmi:0.4398371786005851 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:2 total_loss:114.07743835449219 
 ari:0.19649953505169468 nmi:0.4161126552097777 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:3 total_loss:112.17552185058594 
 ari:0.24446408813638498 nmi:0.4234200971548962 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:4 total_loss:110.59776306152344 
 ari:0.25564905243911284 nmi:0.42014647978772546 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:5 total_loss:109.08436584472656 
 ari:0.2397818305482188 nmi:0.4210009006399322 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epo

------------------------------
epoch:7 total_loss:109.08184051513672 
 ari:0.1611259018207618 nmi:0.3524989976413478 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:8 total_loss:107.55641174316406 
 ari:0.19751229646754775 nmi:0.3662451158929468 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:9 total_loss:105.89459228515625 
 ari:0.16111378438818824 nmi:0.34561020559336414 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:10 total_loss:104.30343627929688 
 ari:0.16300387244603642 nmi:0.34405635671716894 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:1 total_loss:114.5887680053711 
 ari:0.17667774236424766 nmi:0.38371788886947683 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
------------------------------
epoch:2 total_loss:112.44002532958984 
 ari:0.19300955128354774 nmi:0.401845271522

------------------------------
epoch:4 total_loss:112.48332214355469 
 ari:0.1881445968085869 nmi:0.37580548435177913 
 best ari:0.27351531056082723 bset nmi:0.4398371786005851
