In [1]:
# Automatically reload edited project modules (e.g. genome_embeddings.*)
# before running each cell, so package edits are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from argparse import Namespace
from collections import defaultdict
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pickle
import ray
from ray import tune
import torch

from genome_embeddings import corrupt
from genome_embeddings import data_viz
from genome_embeddings import models
from genome_embeddings import pre_process
from genome_embeddings import trainable # import before ray (?)

In [None]:
# Root directory for the corrupted datasets. Paths below are built by plain
# string concatenation, so the trailing slash is required.
DATA_FP = "/home/ndudek/projects/def-dprecup/ndudek/hp_tuning_04-11-2020/"

In [3]:
# Build genome / KO / module lookups from the raw KEGG inputs.
tla_to_tnum, keepers = pre_process.genomes2include()
org_to_kos, n_kos_tot, all_kos = pre_process.load_kos(tla_to_tnum)
org_to_mod_to_kos, mod_sets = pre_process.load_mods()

# Load the saved 2020-09-29 snapshots of the KO list, module map and
# train/test split.
# NOTE(review): all_kos and org_to_mod_to_kos computed just above are
# overwritten here, and these are local-machine paths while DATA_FP points at
# a different filesystem -- confirm which location this cell should read from.
LOCAL_FP = "/Users/natasha/Desktop/"
SNAPSHOT_DATE = "2020-09-29"
all_kos = torch.load(f"{LOCAL_FP}all_kos_{SNAPSHOT_DATE}.pt")
org_to_mod_to_kos = torch.load(f"{LOCAL_FP}org_to_mod_to_kos_{SNAPSHOT_DATE}.pt")
train_data = torch.load(f"{LOCAL_FP}kegg_v2_train_{SNAPSHOT_DATE}.pt")
test_data = torch.load(f"{LOCAL_FP}kegg_v2_test_{SNAPSHOT_DATE}.pt")
train_genomes = torch.load(f"{LOCAL_FP}kegg_v2_train_genomes_{SNAPSHOT_DATE}.pt")
test_genomes = torch.load(f"{LOCAL_FP}kegg_v2_test_genomes_{SNAPSHOT_DATE}.pt")

# Cleaned module -> KO mapping, derived from the freshly loaded module sets.
mod_to_ko_clean = pre_process.clean_kos(mod_sets)

Total number of bacterial genomes in dataset: 2718
Total number of KOs in dataset: 9874


In [4]:
# Drop sparse genomes: keep only rows with MORE than MIN_KOS KOs, so a genome
# with exactly MIN_KOS KOs is dropped as well. This is especially important
# for genomes with 0 KOs (n=35).
# NOTE(review): assumes each row sum is that genome's KO count -- confirm.
MIN_KOS = 500

# The genome-id lists must line up row-for-row with the data matrices,
# otherwise the mask below would pair ids with the wrong rows.
assert len(train_genomes) == train_data.shape[0], "train ids/rows misaligned"
assert len(test_genomes) == test_data.shape[0], "test ids/rows misaligned"

good_idx_train = train_data.sum(axis=1) > MIN_KOS
good_idx_test = test_data.sum(axis=1) > MIN_KOS
train_data = train_data[good_idx_train, :]
test_data = test_data[good_idx_test, :]

# Index the id lists via numpy, converting the mask explicitly rather than
# relying on numpy to coerce a tensor index, then go back to a list.
train_genomes = list(np.array(train_genomes)[np.asarray(good_idx_train)])
test_genomes = list(np.array(test_genomes)[np.asarray(good_idx_test)])

In [5]:
# Tag identifying which saved corrupted-dataset snapshot to load.
# (Earlier local copies of the same files lived under /Users/natasha/Desktop/.)
date_to_load = "2020-10-16_10mods"

# Every file follows the pattern <DATA_FP><name>_<date_to_load>.pt
corrupted_train = torch.load(f"{DATA_FP}corrupted_train_{date_to_load}.pt")
c_train_genomes = torch.load(f"{DATA_FP}c_train_genomes_{date_to_load}.pt")
corrupted_test = torch.load(f"{DATA_FP}corrupted_test_{date_to_load}.pt")
c_test_genomes = torch.load(f"{DATA_FP}c_test_genomes_{date_to_load}.pt")

In [10]:
# Set up the generator/critic pair and their AdamW optimizers.
device = torch.device("cpu")  # use torch.device("cuda") when a GPU is available

# Feature count: half the width of a corrupted_train row. This matches the
# n_features computed from corrupted_train later (9874, the KO count reported
# at load time). Deriving it from train_data.shape[1] / 2 disagreed with that.
# NOTE(review): assumes corrupted_train rows are two KO-length halves -- confirm.
assert corrupted_train.shape[1] % 2 == 0, "expected an even-width corrupted_train"
num_features = corrupted_train.shape[1] // 2
# Width of the noise vector fed to the generator (see trainable.get_noise usage).
z_dim = num_features * 6

# get model instances on the selected device, set to train mode
gen = models.Generator().to(device)
crit = models.Critic().to(device)
gen.train()
crit.train()

# define optimizers (both networks share the same AdamW hyperparameters)
beta_1 = 0.5
beta_2 = 0.999
lr = 0.001
weight_decay = 0.001
gen_opt = torch.optim.AdamW(gen.parameters(), lr=lr, betas=(beta_1, beta_2), weight_decay=weight_decay)
crit_opt = torch.optim.AdamW(crit.parameters(), lr=lr, betas=(beta_1, beta_2), weight_decay=weight_decay)

In [7]:
# Batch size and fold count for the (presumably cross-validation) loaders
# built from the corrupted train/test tensors.
batch_size, kfolds = 128, 10
loaders = trainable.cv_dataloader_SINGLE(
    batch_size,
    num_features,
    kfolds,
    corrupted_train,
    corrupted_test,
)

In [13]:
# Half the row width of corrupted_train; displayed as the cell output.
n_features = corrupted_train.shape[1] // 2
n_features

9874

In [23]:
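# NOTE: this whole cell is a disabled, manual WGAN-GP training loop kept for
# reference; the Ray Tune cells below train via trainable.train_GAN instead.
# If re-enabled, auto_garbage_collect() (last line) must be defined or imported
# first -- nothing above in this notebook defines it.
# num_epochs = 1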
# crit_repeats = 1 # num times to update critic per generator update
# gen_repeats = 1 # num times to update generator per critic update
# c_lambda = 10 # the weight of the gradient penalty 
# critic_losses = []
# generator_losses = []
# z_dim = n_features * 6
# steps = 0

# for epoch in range(num_epochs):
#     for batch_idx, (_, real) in enumerate(loaders["train"]):
#         cur_batch_size = len(real)
#         real = real.to(device)

#         mean_iteration_critic_loss = 0
#         for _ in range(crit_repeats):
#             ### Update critic ###
#             crit_opt.zero_grad()
#             fake_noise = trainable.get_noise(cur_batch_size, z_dim, device=device)
#             fake = gen(fake_noise)
#             crit_fake_pred = crit(fake.detach())
#             crit_real_pred = crit(real)
            
#             # Calculate loss for this iteration
#             epsilon = torch.rand(len(real), 1, device=device, requires_grad=True)
#             gradient = trainable.get_gradient(crit, real, fake.detach(), epsilon)
#             gp = trainable.gradient_penalty(gradient)
#             crit_loss = trainable.get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)
#             mean_iteration_critic_loss += crit_loss.item() / crit_repeats

#             # Update gradients
#             crit_loss.backward(retain_graph=True)
#             # Update optimizer
#             crit_opt.step()
            
#         critic_losses += [mean_iteration_critic_loss]
        
#         mean_iteration_gen_loss = 0
#         for _ in range(gen_repeats):
#             ### Update generator ###
#             gen_opt.zero_grad()
#             fake_noise_2 = trainable.get_noise(cur_batch_size, z_dim, device=device)
#             fake_2 = gen(fake_noise_2)
#             crit_fake_pred = crit(fake_2)

#             gen_loss = trainable.get_gen_loss(crit_fake_pred)
#             gen_loss.backward()
#             mean_iteration_gen_loss += gen_loss.item() / gen_repeats
#             # Update the weights
#             gen_opt.step()

#         # Keep track of the average generator loss
#         generator_losses += [mean_iteration_gen_loss]
        
#         steps += 1

#         if batch_idx % 1 == 0:
#             print("epoch",epoch,"batch",batch_idx,"gen_loss",gen_loss.item(), 
#                      "mean_iteration_critic_loss",mean_iteration_critic_loss)
            
#     # if memory usage is high, may be able to free up space by calling garbage collect
#     auto_garbage_collect()  

[autoreload of genome_embeddings.trainable failed: Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 245, in check
    superreload(m, reload, self.old_objects)
  File "/usr/local/lib/python3.7/site-packages/IPython/extensions/autoreload.py", line 394, in superreload
    module = reload(module)
  File "/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/imp.py", line 314, in reload
    return importlib.reload(module)
  File "/usr/local/Cellar/python/3.7.6_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 630, in _exec
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/Users/natasha/Desktop/mcgill_postdoc/ncbi_genomes/genome_embeddings/trainabl

In [19]:
# data_viz.gan_learning_curve(generator_losses, critic_losses)

In [15]:
### Define the search space and train the GAN with Ray Tune

In [16]:
# Memory budgets for Ray, in bytes.
memory = 2000 * 1024 ** 2                     # worker heap memory
object_store_memory = 200 * 1024 ** 2         # shared object store
driver_object_store_memory = 100 * 1024 ** 2  # object store for the driver

# Drop any Ray session from an earlier run of this cell, then start a fresh
# one in local mode (per Ray docs, tasks run serially in this process).
ray.shutdown()
ray.init(
    local_mode=True,
    memory=memory,
    object_store_memory=object_store_memory,
    driver_object_store_memory=driver_object_store_memory,
    num_cpus=10,
)

{}

In [17]:
# Search space passed to trainable.train_GAN via tune.run; each trial samples
# one value from every tune.* distribution below.
config = {
    "num_epochs": 10,
    "kfolds": 10,
    # Fixed at 128. A duplicate "batch_size" key used to silently discard
    # tune.choice([32, 64, 128, 256]); use that here to tune it instead.
    "batch_size": 128,
    "lr": tune.loguniform(1e-4, 1e-1),
    "beta1": tune.loguniform(0.4, 0.6),
    "beta2": tune.uniform(0.9, 1),
    "weight_decay": tune.loguniform(1e-5, 1e-2),
    "architecture": tune.choice([1, 2, 3]),
    "crit_to_gen_repeats": tune.choice([0, 1]),
    # tune.uniform takes (lower, upper); a single list argument raises TypeError.
    "c_lambda": tune.uniform(0, 50),
}

In [None]:
# Run the hyperparameter search: num_samples trials of trainable.train_GAN,
# each drawing its hyperparameters from `config`. Each trial requests all 10
# CPUs given to ray.init, so only one trial holds resources at a time;
# queue_trials=True lets pending trials wait for resources instead of failing.
analysis = tune.run(
    trainable.train_GAN,
    name="nov4_gan",
    config=config,
    verbose=2,
    resources_per_trial={"cpu": 10, "gpu": 0},
    num_samples=20,
    queue_trials=True,
    # Results go next to the input data. Derived from DATA_FP (same path as
    # before) so there is a single place to edit when switching machines.
    local_dir=DATA_FP + "TUNE_RESULT_DIR",
)