In [1]:
# Automatically reload imported modules before each cell is executed, so that
# edits to the genome_embeddings package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [14]:
from argparse import Namespace
from collections import defaultdict
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
import pickle
import ray
from ray import tune
import torch

from genome_embeddings import corrupt
from genome_embeddings import data_viz
from genome_embeddings import models
from genome_embeddings import pre_process
from genome_embeddings import trainable # import before ray (?)

In [9]:
# Location and date stamp of the saved .pt artifacts loaded by this cell.
# NOTE(review): DATA_DIR is an absolute, machine-specific path; point it at
# the local copy of the data before running this notebook elsewhere.
DATA_DIR = "/Users/natasha/Desktop"
SPLIT_DATE = "2020-09-29"


def _load_split_artifact(stem):
    """Load ``<DATA_DIR>/<stem>_<SPLIT_DATE>.pt`` with ``torch.load``.

    torch.load unpickles the file contents, so only use it on trusted files.
    """
    return torch.load(DATA_DIR + "/" + stem + "_" + SPLIT_DATE + ".pt")


tla_to_tnum, keepers = pre_process.genomes2include()
org_to_kos, n_kos_tot, all_kos = pre_process.load_kos(tla_to_tnum)
org_to_mod_to_kos, mod_sets = pre_process.load_mods()

# The saved copies below replace the all_kos and org_to_mod_to_kos values
# returned by pre_process above (presumably so they match the saved
# train/test split -- TODO confirm this override is intended).
all_kos = _load_split_artifact("all_kos")
org_to_mod_to_kos = _load_split_artifact("org_to_mod_to_kos")

# Saved train/test data and the genome identifiers stored alongside them.
train_data = _load_split_artifact("kegg_v2_train")
test_data = _load_split_artifact("kegg_v2_test")
train_genomes = _load_split_artifact("kegg_v2_train_genomes")
test_genomes = _load_split_artifact("kegg_v2_test_genomes")

mod_to_ko_clean = pre_process.clean_kos(mod_sets)

Total number of bacterial genomes in dataset: 2718
Total number of KOs in dataset: 9874


In [10]:
# Keep only genomes whose row sum in the data matrix is strictly greater than
# MIN_KOS_PER_GENOME; a genome with exactly that many KOs is removed too.
# Esp. important to remove genomes with 0 KOs (n=35)
MIN_KOS_PER_GENOME = 500

# The boolean masks index both the data rows and the genome lists, so the two
# must be aligned one-to-one before filtering.
assert len(train_genomes) == train_data.shape[0], "train_genomes / train_data length mismatch"
assert len(test_genomes) == test_data.shape[0], "test_genomes / test_data length mismatch"

good_idx_train = train_data.sum(axis=1) > MIN_KOS_PER_GENOME
good_idx_test = test_data.sum(axis=1) > MIN_KOS_PER_GENOME
train_data = train_data[good_idx_train,:]
test_data = test_data[good_idx_test,:]

# to numpy for boolean-mask indexing, then back to list for using
train_genomes = list(np.array(train_genomes)[good_idx_train])
test_genomes = list(np.array(test_genomes)[good_idx_test])

In [11]:
# Tag selecting which saved set of corrupted-data files to load.
date_to_load = "2020-10-16_10mods"
file_suffix = date_to_load + ".pt"

# Training files: the corrupted data and its accompanying genome file.
corrupted_train = torch.load("/Users/natasha/Desktop/corrupted_train_" + file_suffix)
c_train_genomes = torch.load("/Users/natasha/Desktop/c_train_genomes_" + file_suffix)

# Test files, same layout as the training pair.
corrupted_test = torch.load("/Users/natasha/Desktop/corrupted_test_" + file_suffix)
c_test_genomes = torch.load("/Users/natasha/Desktop/c_test_genomes_" + file_suffix)

In [18]:
# Call models.get_noise with the row count of train_data and ten times its
# column count.
# NOTE(review): this width differs from the z_dim computed in the training
# setup cell; confirm which one the generator expects.
n_rows = train_data.shape[0]
noise_width = train_data.shape[1] * 10
noise = models.get_noise(n_rows, noise_width)

In [21]:
# Inspect the dimensions of the corrupted test data.
corrupted_test.shape

torch.Size([27900, 19748])

In [47]:
# Build a generator and place it on the CPU (a fresh instance is created
# again in the training setup cell).
gen = models.Generator()
gen = gen.to('cpu')

In [None]:
# Build a critic and place it on the CPU (a fresh instance is created again
# in the training setup cell).
crit = models.Critic()
crit = crit.to('cpu')

In [48]:
# Display the generator's layer structure (its repr).
# NOTE(review): the repr shows BatchNorm2d directly after each Linear layer.
# BatchNorm2d expects 4-D (N, C, H, W) input, while a Linear layer fed
# (batch, features) data produces 2-D output, so BatchNorm1d is presumably
# what models.Generator should use. Confirm in models.py.
gen

Generator(
  (gen): Sequential(
    (0): Sequential(
      (0): Linear(in_features=59244, out_features=39496, bias=True)
      (1): BatchNorm2d(39496, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01, inplace=True)
    )
    (1): Sequential(
      (0): Linear(in_features=39496, out_features=19748, bias=True)
      (1): BatchNorm2d(19748, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01, inplace=True)
    )
    (2): Sequential(
      (0): Linear(in_features=19748, out_features=9874, bias=True)
      (1): Sigmoid()
    )
  )
)

In [None]:
# Per-batch loss histories; the training loop appends to both lists.
generator_losses = []
critic_losses = []

device = torch.device("cpu") #"cuda" if use_cuda else "cpu")

# num_features is half the column count of train_data; the noise vector fed
# to the generator is six times that wide.
# NOTE(review): the Generator repr shown earlier has in_features=59244
# (= 6 * 9874), which matches z_dim only if train_data has 19748 columns.
# TODO confirm.
num_features = train_data.shape[1] // 2
z_dim = num_features * 6

# get model instances on the selected device, set to train mode
gen = models.Generator().to(device)
crit = models.Critic().to(device)
gen.train()
crit.train()

# define optimizers
beta_1 = 0.5
beta_2 = 0.999
lr = 0.001
# weight_decay was previously used without being defined anywhere in the
# notebook. 0.01 is torch.optim.AdamW's own default, so this is equivalent to
# not passing the argument. TODO: set the intended value.
weight_decay = 0.01
gen_opt = torch.optim.AdamW(gen.parameters(), lr=lr, betas=(beta_1, beta_2), weight_decay=weight_decay)
crit_opt = torch.optim.AdamW(crit.parameters(), lr=lr, betas=(beta_1, beta_2), weight_decay=weight_decay)

In [None]:
# Create data loader
# The result is indexed as loaders["train"] by the training loop.
# NOTE(review): batch_size and kfolds are not assigned in any cell visible in
# this notebook; define them before running this cell on a fresh kernel.
loaders = trainable.cv_dataloader_SINGLE(batch_size, num_features, kfolds, train_data, test_data)

In [None]:
# Adversarial training loop with a gradient penalty on the critic: for every
# batch the critic is updated `crit_repeats` times, then the generator once.
# Per-batch losses are appended to critic_losses / generator_losses.
#
# NOTE(review): num_epochs, crit_repeats and c_lambda are not assigned in any
# visible cell, and get_gradient / gradient_penalty / get_crit_loss /
# get_gen_loss / auto_garbage_collect are called unqualified although nothing
# visible imports them. Define or import them before running this cell.
for epoch in range(num_epochs):

    # enumerate batches in epoch
    for batch_idx, (_, real) in enumerate(loaders["train"]):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            crit_opt.zero_grad()
            # get_noise is reached through the models module, the only place
            # this notebook shows it to be available.
            fake_noise = models.get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            # detach so the critic update does not backpropagate into the generator
            crit_fake_pred = crit(fake.detach())
            crit_real_pred = crit(real)

            # One random coefficient per sample, with singleton trailing
            # dimensions so it broadcasts against `real` whatever its rank.
            # (The previous fixed (B, 1, 1, 1) shape only suits 4-D batches.)
            epsilon_shape = (cur_batch_size,) + (1,) * (real.dim() - 1)
            epsilon = torch.rand(*epsilon_shape, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake.detach(), epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)
            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            # NOTE(review): retain_graph=True looks unnecessary because the
            # graph is rebuilt on every iteration; confirm before removing.
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            crit_opt.step()

        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = models.get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        gen_opt.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        # Periodic progress report. The previous report referenced objects
        # from a different training loop (model, loss, pred, target, cv_vae,
        # f1_score, train/test metric lists) that do not exist here, so it
        # raised NameError on the first batch. Report the GAN losses instead.
        if batch_idx % 100 == 0:
            print("epoch", epoch, "batch", batch_idx)
            print("gen_loss", generator_losses[-1], "crit_loss", critic_losses[-1])

# if memory usage is high, may be able to free up space by calling garbage collect
auto_garbage_collect()