In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
TB_writer = SummaryWriter()
from GAN_model.resnetGAN import *
from GAN_model.ProteoGAN import *
from utils.dataloader import *
from utils.AAV import *
from utils.torch_utils import *
from utils.metrics import *

In [2]:
# Run configuration.
# TODO: move these settings into an argument parser when converting to a script.
args_run_name = 'ProteoGAN_AAV_mutational'

# Data, reproducibility and evaluation
args_batch_size = 512
args_seed = 1
args_eval_n_seq = 500   # generated sequences scored per evaluation
args_epoch = 300

# GAN architecture
args_z_dim = 128        # latent noise dimension
args_dim = 256          # width argument passed to Discriminator / Generator

# Optimizers and loss
args_lr_Gen = 1e-4
args_lr_Disc = 8e-5
args_loss = 'hinge'

# Update the discriminator more often during the first epochs, following
# https://livebook.manning.com/book/gans-in-action/chapter-5/185
args_disc_iters_init = 2          # D steps per batch while epoch <= args_disc_iters_init_epoch
args_disc_iters = 1               # D steps per batch afterwards
args_disc_iters_init_epoch = 50

# Output directories: checkpoints, plots and generated sequences.
os.makedirs(f'checkpoints/{args_run_name}', exist_ok=True)
os.makedirs(f'out/{args_run_name}/plot_out/', exist_ok=True)
os.makedirs(f'out/{args_run_name}/seq_out/', exist_ok=True)

In [3]:
# Write each unique AAV mutation sequence that tokenizes to exactly
# AAV_N_MUTATION_TOKENS tokens to a FASTA file, tagging it with its score.
AAV_N_MUTATION_TOKENS = 29
df = pd.read_csv('AAV_data/AAV_library.csv')
_, y, Seqs = get_AAV_X_y_aa(df, large_only=True, return_str=True)
fasta_name = 'AAV_data/aav_mutational.fasta'
Seq_set = set(Seqs)   # sequences not yet written; used to drop duplicates
i = 0                 # number of records written so far (used in the record id)
with open(fasta_name, 'w') as outfile:
    for seq_idx, seq in enumerate(Seqs):
        if seq not in Seq_set:
            continue
        if len(tokenize_mutation_seq(seq)) != AAV_N_MUTATION_TOKENS:
            continue
        # The score must be taken at this sequence's own position (seq_idx), not
        # at the running record counter `i`: after any skipped sequence the two
        # diverge and scores would be attached to the wrong sequences.
        # (Assumes y is aligned with Seqs, as both come from get_AAV_X_y_aa.)
        seq_id = 'AAV_seq_' + str(i) + '_score_' + str(y[seq_idx].item())
        outfile.write('>' + seq_id + '\n')
        outfile.write(seq + '\n')
        Seq_set.remove(seq)
        i += 1

In [4]:
# Reproducibility and device selection.
seed_everything(args_seed)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
# The dataset itself stays on the CPU; each batch is moved to `device` in train().
AAV_data = AAV_dataset('AAV_data/aav_mutational.fasta', torch.device('cpu'))

### Creating data sets

In [5]:
# Seeded 85% / 15% train / validation split.
train_len = int(AAV_data.length - AAV_data.length * 0.15)
val_len = AAV_data.length - train_len
train_data, val_data = torch.utils.data.random_split(
    AAV_data,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(args_seed),
)

# Decode both splits back to strings (argmax over the last axis) so the
# evaluation metrics can compare generated sequences against them.
val_seqs = [decode_aav_mutation_seq(sample.argmax(dim=-1).detach().cpu())
            for sample in val_data]
tr_seqs = [decode_aav_mutation_seq(sample.argmax(dim=-1).detach().cpu())
           for sample in train_data]

# Shuffled loader over the training split only.
loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=args_batch_size,
    shuffle=True,
    num_workers=1,
    pin_memory=True,
)

# Fixed model dimensions (assumed to match the dataset's one-hot encoding).
n_chars = 21   # alphabet size
seq_len = 57   # encoded sequence length

### Evaluating positive and negative control sequences to obtain baseline metrics
Positive control: a sample of real training sequences, which simulates a perfect model.

Negative control: a sample that simulates the worst possible model for each metric \
(near-constant sequences for MMD; a single sequence repeated for the diversity measures).


In [6]:
# Positive control: as many real training sequences as there are validation sequences.
pos_ctrl_seqs = np.random.choice(tr_seqs, size=len(val_seqs), replace=False)


def _near_constant_seq():
    """Decode one random sequence whose positions take one of two adjacent symbol indices."""
    low = np.random.randint(0, high=19)
    return decode_aav_mutation_seq(np.random.randint(low, high=low + 2, size=seq_len))


# Negative control for MMD: low-variety (two-symbol) sequences.
neg_ctrl_seqs_MMD = [_near_constant_seq() for _ in range(len(val_seqs))]

# Negative control for the diversity measures: one random sequence, repeated.
rand_seq = decode_aav_mutation_seq(np.random.randint(0, high=21, size=seq_len))
neg_ctrl_seqs_diversity = [rand_seq] * len(val_seqs)

# Baseline scores against the validation set (same call order as the metrics expect).
MMD_pos = mmd(seq1=pos_ctrl_seqs, seq2=val_seqs)
MMD_neg = mmd(seq1=neg_ctrl_seqs_MMD, seq2=val_seqs)
Entropy_pos = abs(entropy(seq1=pos_ctrl_seqs, seq2=val_seqs))
Entropy_neg = abs(entropy(seq1=neg_ctrl_seqs_diversity, seq2=val_seqs))
Distance_pos = abs(distance(seq1=pos_ctrl_seqs, seq2=val_seqs))
Distance_neg = abs(distance(seq1=neg_ctrl_seqs_diversity, seq2=val_seqs))

In [7]:
# Report the control baselines; for all three metrics, lower is better.
print('>>> MMD | abs(D_Entropy) | abs(D_Distance)' )
print(f'>>> For positive control:\n'
      f'{MMD_pos:.5} , {Entropy_pos:.5} , {Distance_pos:.5}')
print(f'>>> For negative control:\n'
      f'{MMD_neg:.5} , {Entropy_neg:.5} , {Distance_neg:.5}')

>>> MMD | abs(D_Entropy) | abs(D_Distance)
>>> For postivie control:
0.0083223 , 8.8228e-05 , 0.01161
>>> For negative control:
0.75291 , 0.012981 , 0.86623


In [8]:
def save_checkpoints(discriminator, generator, optim_D, optim_G,
                     scheduler_D, scheduler_G, epoch):
    """Serialize the full training state for `epoch` to disk.

    Writes a single file, ``checkpoints/<args_run_name>/epoch_<epoch>``, holding
    the epoch number and the state dicts of both networks, both optimizers and
    both learning-rate schedulers. The target directory must already exist.
    """
    state = {'epoch': epoch}
    components = (
        ('disc_state_dict', discriminator),
        ('gen_state_dict', generator),
        ('optim_D_state_dict', optim_D),
        ('optim_G_state_dict', optim_G),
        ('scheduler_D_state_dict', scheduler_D),
        ('scheduler_G_state_dict', scheduler_G),
    )
    for key, component in components:
        state[key] = component.state_dict()
    checkpoint_path = os.path.join('checkpoints/' + args_run_name, f'epoch_{epoch}')
    torch.save(state, checkpoint_path)

def load_checkpoints(discriminator, generator, optim_D, optim_G,
                     scheduler_D, scheduler_G, epoch, map_location=None):
    """Restore, in place, the training state written by save_checkpoints for `epoch`.

    Reads ``checkpoints/<args_run_name>/epoch_<epoch>`` and loads the stored
    state dicts into both networks, both optimizers and both LR schedulers.

    Args:
        discriminator, generator: modules to receive the saved weights.
        optim_D, optim_G: optimizers to receive the saved optimizer state.
        scheduler_D, scheduler_G: schedulers to receive the saved scheduler state.
        epoch: which checkpoint file to read.
        map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load a
            checkpoint saved on GPU onto a CPU-only machine. ``None`` keeps
            torch.load's default behavior.

    Raises:
        FileNotFoundError: if no checkpoint exists for `epoch`.
    """
    # Must be the same path save_checkpoints writes: 'checkpoints/<run>/epoch_<n>'.
    checkpoint_path = os.path.join('checkpoints/' + args_run_name,
                                   'epoch_{}'.format(epoch))
    checkpoints = torch.load(checkpoint_path, map_location=map_location)
    discriminator.load_state_dict(checkpoints['disc_state_dict'])
    generator.load_state_dict(checkpoints['gen_state_dict'])
    optim_D.load_state_dict(checkpoints['optim_D_state_dict'])
    optim_G.load_state_dict(checkpoints['optim_G_state_dict'])
    scheduler_D.load_state_dict(checkpoints['scheduler_D_state_dict'])
    scheduler_G.load_state_dict(checkpoints['scheduler_G_state_dict'])
    return

## Init model

In [9]:
# Models: ProteoGAN discriminator and generator, placed on `device`.
Disc = Discriminator(args_dim, seq_len=seq_len, n_chars=n_chars).to(device)
Gen = Generator(args_dim, seq_len=seq_len, n_chars=n_chars,
                z_dim=args_z_dim).to(device)
# Alternative (presumably the ResNet variant from GAN_model.resnetGAN):
# Disc = ResNetSN_Discriminator(args_dim, seq_len, n_chars).to(device)
# Gen = ResNetSN_Generator(args_dim, seq_len, n_chars, args_z_dim).to(device)

# Adam with betas=(0.0, 0.9); only trainable discriminator parameters are optimized.
optim_disc = optim.Adam((p for p in Disc.parameters() if p.requires_grad),
                        lr=args_lr_Disc, betas=(0.0, 0.9))
optim_gen = optim.Adam(Gen.parameters(), lr=args_lr_Gen, betas=(0.0, 0.9))

# Exponential LR decay: each scheduler step multiplies the LR by 0.99
# (train() steps both schedulers once at the end of every epoch).
scheduler_d = optim.lr_scheduler.ExponentialLR(optim_disc, gamma=0.99)
scheduler_g = optim.lr_scheduler.ExponentialLR(optim_gen, gamma=0.99)

# To resume a run, call load_checkpoints(Disc, Gen, optim_disc, optim_gen,
# scheduler_d, scheduler_g, <epoch>) here.

In [10]:
def sample_generated(generator, epoch, run_name, n_seq=500, max_batch_size=100):
    """Sample `n_seq` sequences from `generator` and save them to a FASTA file.

    Noise vectors are drawn in batches of at most `max_batch_size`. Each
    generator output is decoded by taking the argmax over its last dimension.

    Args:
        generator: callable mapping a (batch, args_z_dim) noise tensor to
            per-position scores; argmax over the last axis gives symbol indices.
        epoch: used in the output file name and in every record id.
        run_name: output goes to ``out/<run_name>/seq_out/epoch_<epoch>.fasta``
            (the directory must already exist).
        n_seq: number of sequences to generate (need not be a multiple of
            `max_batch_size`).
        max_batch_size: largest number of noise vectors fed to the generator at once.

    Returns:
        list of `n_seq` decoded sequence strings.

    Raises:
        ValueError: if `max_batch_size` is smaller than 1.
    """
    if max_batch_size < 1:
        raise ValueError(f'max_batch_size must be >= 1, got {max_batch_size}')
    gen_seqs = []
    # Sampling only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        remaining = n_seq
        while remaining > 0:
            # The last batch may be smaller so exactly n_seq sequences are produced.
            batch_size = min(max_batch_size, remaining)
            z = torch.randn(batch_size, args_z_dim, device=device)
            gen_probs = generator(z)
            gen_seqs += [decode_aav_mutation_seq(gen_prob.argmax(-1).cpu().numpy())
                         for gen_prob in gen_probs]
            remaining -= batch_size
            torch.cuda.empty_cache()
    # Save as FASTA; `with` guarantees the file is flushed and closed on return.
    fasta_name = 'out/' + run_name + '/seq_out/epoch_' + str(epoch) + '.fasta'
    with open(fasta_name, 'w') as outfile:
        for i, seq in enumerate(gen_seqs):
            seq_id = 'Epoch' + str(epoch) + '_seq' + str(i)
            outfile.write('>' + seq_id + '\n')
            outfile.write(seq + '\n')
    return gen_seqs


def write_eval_TB(writer, MMD, Entropy, Distance, epoch):
    """Log the three evaluation metrics for `epoch` as scalars under `args_run_name`.

    Tags are ``<args_run_name>/MMD``, ``<args_run_name>/abs(D_Entropy)`` and
    ``<args_run_name>/abs(D_Distance)``, with `epoch` as the global step.
    """
    metrics = (
        ('/MMD', MMD),
        ('/abs(D_Entropy)', Entropy),
        ('/abs(D_Distance)', Distance),
    )
    for tag_suffix, value in metrics:
        writer.add_scalar(args_run_name + tag_suffix, value, epoch)

def evaluate(epoch):
    """Score `args_eval_n_seq` generated sequences against the validation set.

    Puts `Gen` in eval mode, samples (and saves) sequences with
    sample_generated, prints MMD, abs(D_Entropy) and abs(D_Distance), logs them
    to TensorBoard via write_eval_TB, then puts `Gen` back in train mode.
    """
    Gen.eval()
    gen_seqs = sample_generated(Gen, epoch, args_run_name, n_seq=args_eval_n_seq)

    MMD = mmd(seq1=gen_seqs, seq2=val_seqs)
    Entropy = abs(entropy(seq1=gen_seqs, seq2=val_seqs))
    Distance = abs(distance(seq1=gen_seqs, seq2=val_seqs))

    header = f'#========== evaluating for epoch {epoch} ==========#'
    summary = ('>>> MMD | abs(D_Entropy) | abs(D_Distance) \n'
               f'{MMD:.5} , {Entropy:.5} , {Distance:.5}')
    print(header)
    print(summary)

    write_eval_TB(TB_writer, MMD, Entropy, Distance, epoch)
    Gen.train()


In [11]:
def train(epoch):
    """Run one epoch of GAN training over `loader`, then step both LR schedulers.

    For every full batch (incomplete batches are skipped), the discriminator is
    updated `disc_iters` times on the same real batch, then the generator is
    updated once. `disc_iters` is `args_disc_iters_init` while
    ``epoch <= args_disc_iters_init_epoch`` and `args_disc_iters` afterwards.
    Losses and mean discriminator outputs are logged to TensorBoard every
    iteration and printed every 60 batches.

    Supported ``args_loss`` values are 'hinge' and 'wasserstein'.

    Raises:
        ValueError: if ``args_loss`` is not a supported loss.
    """
    if args_loss not in ('hinge', 'wasserstein'):
        raise ValueError(
            f"unsupported args_loss {args_loss!r}; expected 'hinge' or 'wasserstein'")
    # Train the discriminator more during the first epochs (fixed for the whole epoch).
    if epoch <= args_disc_iters_init_epoch:
        disc_iters = args_disc_iters_init
    else:
        disc_iters = args_disc_iters

    for batch_idx, data in enumerate(loader):
        if data.size(0) != args_batch_size:
            continue
        data = data.to(device)

        # --- discriminator update(s) ---
        for _ in range(disc_iters):
            z = torch.randn(args_batch_size, args_z_dim, device=device)
            optim_disc.zero_grad()
            optim_gen.zero_grad()
            real_D = Disc(data)
            # detach(): this step only updates Disc, so don't backprop into Gen.
            fake_D = Disc(Gen(z).detach())
            if args_loss == 'hinge':
                disc_loss = (torch.relu(1.0 - real_D).mean()
                             + torch.relu(1.0 + fake_D).mean())
            else:
                # Wasserstein critic loss, matching the generator's -D(G(z)) objective.
                disc_loss = fake_D.mean() - real_D.mean()
            disc_loss.backward()
            optim_disc.step()
        # Mean discriminator outputs from the last discriminator step (for logging).
        real_prob_avg = real_D.mean().item()
        fake_prob_avg = fake_D.mean().item()

        # --- generator update ---
        z = torch.randn(args_batch_size, args_z_dim, device=device)
        optim_disc.zero_grad()
        optim_gen.zero_grad()
        gen_loss = -Disc(Gen(z)).mean()   # same form for hinge and wasserstein
        gen_loss.backward()
        optim_gen.step()

        # --- logging: plain floats, no autograd tensors ---
        disc_loss_value = disc_loss.item()
        gen_loss_value = gen_loss.item()
        global_step = len(loader) * epoch + batch_idx
        TB_writer.add_scalars(
            args_run_name + '/GAN loss', {
                'Disc_loss': disc_loss_value,
                'Gen_loss': gen_loss_value,
            }, global_step)
        TB_writer.add_scalars(
            args_run_name + '/Discriminator out', {
                'Real prob': real_prob_avg,
                'Fake prob': fake_prob_avg,
            }, global_step)
        if batch_idx % 60 == 0:
            print('disc loss', disc_loss_value, 'gen loss', gen_loss_value)

    # One learning-rate decay step per epoch.
    scheduler_d.step()
    scheduler_g.step()
    return


In [12]:
# Main loop: train, checkpoint (keeping only the newest file), evaluate on even epochs.
for epoch in range(args_epoch):
    train(epoch)
    save_checkpoints(Disc, Gen, optim_disc, optim_gen,
                     scheduler_d, scheduler_g, epoch)
    if epoch > 0:
        # Drop the previous epoch's checkpoint so only the latest stays on disk.
        os.remove(os.path.join('checkpoints/' + args_run_name, f'epoch_{epoch - 1}'))
    if not epoch % 2:
        evaluate(epoch)

disc loss 1.9387624263763428 gen loss 0.05583495646715164
disc loss 0.5854005217552185 gen loss 0.4890280067920685
disc loss 0.37229788303375244 gen loss 0.820999801158905
disc loss 1.3679983615875244 gen loss -0.00958347786217928
>>> MMD | abs(D_Entropy) | abs(D_Distance) 
0.19332 , 0.00038986 , 0.29824
disc loss 1.5274876356124878 gen loss -0.08605532348155975
disc loss 1.6626719236373901 gen loss 0.33268213272094727
disc loss 1.695770263671875 gen loss 0.48372581601142883
disc loss 1.7839455604553223 gen loss 0.3881561756134033
disc loss 1.771636962890625 gen loss 0.42810916900634766
disc loss 1.665272831916809 gen loss 0.46294283866882324
disc loss 1.5260425806045532 gen loss 0.6490155458450317
disc loss 1.4940264225006104 gen loss 0.6560856103897095
>>> MMD | abs(D_Entropy) | abs(D_Distance) 
0.15485 , 0.0010983 , 0.052042
disc loss 1.4810562133789062 gen loss 0.603990912437439
disc loss 1.3460166454315186 gen loss 0.48290595412254333
disc loss 1.463431477546692 gen loss 0.5537452