In [1]:
import torch
from torch.utils.tensorboard import SummaryWriter
import torch.optim as optim
TB_writer = SummaryWriter()
from GAN_model.resnetGAN import *
from GAN_model.ProteoGAN import *
from utils.dataloader import *
from utils.torch_utils import *
from utils.metrics import *

In [2]:
args_run_name = 'ProteoGAN_AAV_MSA'
# args; TODO: move these to parser for script
args_batch_size=64; args_seed=42; args_eval_n_seq = 500
args_epoch = 300
# GAN specific param
args_z_dim = 128; args_dim = 256
# optimizer specific param
args_lr_Gen = 1e-4; args_lr_Disc = 8e-5; args_loss = 'hinge'
# train discriminator more initially
# according to https://livebook.manning.com/book/gans-in-action/chapter-5/185
args_disc_iters_init = 2; args_disc_iters = 1; args_disc_iters_init_epoch = 50
# make dir
os.makedirs('checkpoints/'+args_run_name, exist_ok=True)
os.makedirs('out/'+args_run_name+'/plot_out/', exist_ok=True)
os.makedirs('out/'+args_run_name+'/seq_out/', exist_ok=True)

In [3]:
seed_everything(args_seed)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# keep dataset in cpu
MSA_data = MSA_dataset('AAV_data/AAV_target_MSA.fasta', 
    torch.device('cpu'), pseudo_MSA=False) 

100%|██████████| 466/466 [00:00<00:00, 7832.03it/s]


### Creating data sets

In [4]:
# creating data set with train/val split
train_len = int(MSA_data.length*0.85-MSA_data.length)
val_len = MSA_data.length - train_len
train_data, val_data = torch.utils.data.random_split(
    MSA_data, [train_len, val_len], 
    generator=torch.Generator().manual_seed(args_seed))
# decoding train and validation sequences for evaluation
val_seqs = [decode_one_seq(torch.argmax(data, dim=-1).detach().cpu()) 
            for data in val_data]
tr_seqs = [decode_one_seq(torch.argmax(data, dim=-1).detach().cpu()) 
           for data in train_data]
# creating dataloader for training data
loader = torch.utils.data.DataLoader(
    train_data, batch_size=args_batch_size, shuffle=True,
    num_workers=1, pin_memory=True)
# fixed parameters
n_chars = 21; seq_len = len(val_seqs[0])

### Evaluating positive and negative control sequences to obtain baseline metrics
Positive: sample of real sequences and simulates a perfect model

Negative: a sample that simulates the worst possible model for each metric: \
    (constant sequence for MMD, repeated sequences for diversity measures).


In [5]:
pos_ctrl_seqs = np.random.choice(tr_seqs, size=len(val_seqs), replace=False)
# constant sequences for MMD
neg_ctrl_seqs_MMD = []
for _ in range(len(val_seqs)):
    rand_aa_idx = np.random.randint(0, high=19)
    rand_seq = decode_one_seq(
        np.random.randint(rand_aa_idx, high=rand_aa_idx+2, size=seq_len))
    neg_ctrl_seqs_MMD.append(rand_seq)
# repeated sequences for diversity measures
rand_seq = decode_one_seq(np.random.randint(0, high=21, size=seq_len))
neg_ctrl_seqs_diversity = [rand_seq for _ in range(len(val_seqs))]
# report baseline
MMD_pos = mmd(seq1=pos_ctrl_seqs, seq2=val_seqs)
MMD_neg = mmd(seq1=neg_ctrl_seqs_MMD, seq2=val_seqs)
Entropy_pos = abs(entropy(seq1=pos_ctrl_seqs, seq2=val_seqs))
Entropy_neg = abs(entropy(seq1=neg_ctrl_seqs_diversity, seq2=val_seqs))
Distance_pos = abs(distance(seq1=pos_ctrl_seqs, seq2=val_seqs))
Distance_neg = abs(distance(seq1=neg_ctrl_seqs_diversity, seq2=val_seqs))
Pearson_cor_pos = pearson_cor(seq1=pos_ctrl_seqs, seq2=val_seqs, random_n=1000)[1].item()

In [6]:
## Pearson corr higher the better, all other lower the better
print('>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr' )
print(f'>>> For postivie control:\n'
      f'{MMD_pos:.5} , {Entropy_pos:.5} , {Distance_pos:.5} , {Pearson_cor_pos:.5}')
print(f'>>> For negative control:\n'
      f'{MMD_neg:.5} , {Entropy_neg:.5} , {Distance_neg:.5} , 0.000')

>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr
>>> For postivie control:
0.2118 , 0.0017741 , 0.10012 , 0.95812
>>> For negative control:
0.33255 , 0.018137 , 1.9 , 0.000


In [7]:
def save_checkpoints(discriminator, generator, optim_D, optim_G,  
               scheduler_D, scheduler_G, epoch):
    torch.save({
        'epoch': epoch,
        'disc_state_dict': discriminator.state_dict(),
        'gen_state_dict': generator.state_dict(),
        'optim_D_state_dict': optim_D.state_dict(),
        'optim_G_state_dict': optim_G.state_dict(),
        'scheduler_D_state_dict': scheduler_D.state_dict(),
        'scheduler_G_state_dict': scheduler_G.state_dict(),
    }, os.path.join('checkpoints/'+args_run_name, 'epoch_{}'.format(epoch)))
    return

def load_checkpoints(discriminator, generator, optim_D, optim_G,
                     scheduler_D, scheduler_G, epoch):
    checkpoints = torch.load(
        os.path.join('checkpoints'+args_run_name, 'epoch_{}'.format(epoch)))  
    discriminator.load_state_dict(checkpoints['disc_state_dict'])
    generator.load_state_dict(checkpoints['gen_state_dict'])
    optim_D.load_state_dict(checkpoints['optim_D_state_dict'])
    optim_G.load_state_dict(checkpoints['optim_G_state_dict'])
    scheduler_D.load_state_dict(checkpoints['scheduler_D_state_dict'])
    scheduler_G.load_state_dict(checkpoints['scheduler_G_state_dict'])
    return

## Init model

In [8]:
# init model and optimizer

Disc = Discriminator(args_dim, seq_len=seq_len, n_chars=n_chars).to(device)
Gen = Generator(args_dim, seq_len=seq_len, n_chars=n_chars, 
                z_dim=args_z_dim).to(device)

# Disc = ResNetSN_Discriminator(args_dim, seq_len, n_chars).to(device)
# Gen = ResNetSN_Generator(args_dim, seq_len, n_chars, args_z_dim).to(device)

optim_disc = optim.Adam(filter(lambda p: p.requires_grad, Disc.parameters()),
                        lr=args_lr_Disc, betas=(0.0, 0.9))
optim_gen = optim.Adam(Gen.parameters(), lr=args_lr_Gen, betas=(0.0, 0.9))
# use an exponentially decaying learning rate
scheduler_d = optim.lr_scheduler.ExponentialLR(optim_disc, gamma=0.99)
scheduler_g = optim.lr_scheduler.ExponentialLR(optim_gen, gamma=0.99)

#load_checkpoints(Disc, Gen, optim_disc, optim_gen, scheduler_d, scheduler_g, 0)
#save_checkpoints(Disc, Gen, optim_disc, optim_gen, scheduler_d, scheduler_g, 0)

In [9]:
def sample_generated(generator, epoch, run_name, n_seq=500):
    # random sampling n_seq noise, batched 100 noise at a time
    # return sampled n_seq sequences, also saving them to fasta file
    gen_seqs = []
    max_batch_size = 100
    for _ in range(n_seq//max_batch_size):
        z = torch.randn(max_batch_size, args_z_dim, device=device)
        gen_probs = generator(z)
        batched_seq = [decode_one_seq(
            gen_prob.argmax(-1).cpu().numpy()) for gen_prob in gen_probs]
        gen_seqs += batched_seq
        torch.cuda.empty_cache()
    # saving seqs in fasta
    fasta_name = 'out/' + run_name + '/seq_out/epoch_' + str(epoch) + '.fasta'
    outfile = open(fasta_name, 'w')
    for i, seq in enumerate(gen_seqs):
        seq_id = 'Epoch'+str(epoch)+'_seq'+str(i)
        outfile.write('>' + seq_id + '\n')
        outfile.write(seq + '\n')
    return gen_seqs


def write_eval_TB(writer, MMD, Entropy, Distance, Pearson, epoch):
    writer.add_scalar(args_run_name+'/MMD',  MMD, epoch)
    writer.add_scalar(args_run_name+'/abs(D_Entropy)', Entropy, epoch)
    writer.add_scalar(args_run_name+'/abs(D_Distance)', Distance, epoch)
    writer.add_scalar(args_run_name+'/Pearson Correlation', Pearson, epoch)
    return

def evaluate(epoch):
    # eval mode
    Gen.eval()
    # sample args_eval_n_seq sequences
    gen_seqs = sample_generated(Gen, epoch, args_run_name, n_seq=args_eval_n_seq)
    MMD = mmd(seq1=gen_seqs, seq2=val_seqs)
    Entropy = abs(entropy(seq1=gen_seqs, seq2=val_seqs))
    Distance = abs(distance(seq1=gen_seqs, seq2=val_seqs))
    Pearson = pearson_cor(
        seq1=gen_seqs, seq2=val_seqs, random_n=args_eval_n_seq)[1].item()
    print(f'#========== evaluating for epoch ' f'{epoch} ' f'==========#')
    print('>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr \n'
          f'{MMD:.5} , {Entropy:.5} , {Distance:.5} , {Pearson:.5}')
    # save the gap distribution plot
    plot_gap_dist(seqs=gen_seqs, run_name=args_run_name, epoch=epoch)
    write_eval_TB(TB_writer, MMD, Entropy, Distance, Pearson, epoch)
    # revert to train mode
    Gen.train()
    return


In [10]:
def train(epoch):
    for batch_idx, data in enumerate(loader):
        if data.size()[0] != args_batch_size:
            continue
        else:
            data = data.to(device)
        # train discriminator more initially
        if epoch <= args_disc_iters_init_epoch: 
            disc_iters = args_disc_iters_init
        else:
            disc_iters = args_disc_iters
        # update discriminator
        for _ in range(disc_iters):
            z = torch.randn(args_batch_size, args_z_dim, device=device)
            optim_disc.zero_grad()
            optim_gen.zero_grad()
            real_D = Disc(data)
            fake_D = Disc(Gen(z))
            if args_loss == 'hinge':
                disc_loss = nn.ReLU()(1.0 - real_D).mean() + \
                nn.ReLU()(1.0 + fake_D).mean()
            disc_loss.backward()
            optim_disc.step()
        # record output for both fake and real from discriminator
        real_prob_avg = real_D.mean()
        fake_prob_avg = fake_D.mean()
        
        # update generator
        z = torch.randn(args_batch_size, args_z_dim, device=device)
        optim_disc.zero_grad()
        optim_gen.zero_grad()
        if args_loss == 'hinge' or args_loss == 'wasserstein':
            gen_loss = -Disc(Gen(z)).mean()
        gen_loss.backward()
        optim_gen.step()
        
        # write to Tensor board
        iter = len(loader)*epoch + batch_idx
        TB_writer.add_scalars(
            args_run_name+'/GAN loss', {
            'Disc_loss': disc_loss,
            'Gen_loss': gen_loss,
        }, iter)
        TB_writer.add_scalars(
            args_run_name+'/Discriminator out', {
            'Real prob': real_prob_avg,
            'Fake prob': fake_prob_avg,
        }, iter)
        if batch_idx % 60 == 0:
            print('disc loss', disc_loss.item(), 'gen loss', gen_loss.item())
    # update lr scheduler
    scheduler_d.step()
    scheduler_g.step()
    return


In [11]:
for epoch in range(args_epoch):
    train(epoch)
    save_checkpoints(Disc, Gen, optim_disc, optim_gen, 
                     scheduler_d, scheduler_g, epoch)
    if epoch != 0:
        os.remove(os.path.join('checkpoints/'+args_run_name, 
                               'epoch_{}'.format(epoch-1)))
    # eval every 2 epoch
    if epoch % 2 == 0:
        evaluate(epoch)

disc loss 1.9991246461868286 gen loss -0.054017454385757446
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.22209 , 0.008184 , 0.090369 , 0.93752
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.9679758548736572 gen loss -0.05976656824350357
disc loss 1.9316887855529785 gen loss -0.0715143159031868
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.22273 , 0.0081778 , 0.090379 , 0.93789
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.890212059020996 gen loss -0.08376333862543106
disc loss 1.8348057270050049 gen loss -0.10136439651250839
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.22191 , 0.0081911 , 0.090432 , 0.94154
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.8080930709838867 gen loss -0.11484336853027344
disc loss 1.7638585567474365 gen loss -0.13120988011360168
>>> MMD

  c /= stddev[:, None]
  c /= stddev[None, :]


>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.26815 , 0.0012224 , 0.023949 , nan
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.6989924907684326 gen loss 0.25580134987831116
disc loss 1.7221341133117676 gen loss 0.2547169327735901


  c /= stddev[:, None]
  c /= stddev[None, :]


>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27534 , 0.0011185 , 0.024047 , nan
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.7391870021820068 gen loss 0.28230565786361694
disc loss 1.7309980392456055 gen loss 0.2904701232910156


  c /= stddev[:, None]
  c /= stddev[None, :]


>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27379 , 0.0013533 , 0.020629 , nan
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.7155722379684448 gen loss 0.2812163233757019
disc loss 1.7795085906982422 gen loss 0.25814318656921387
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27382 , 0.001424 , 0.021772 , 0.78331
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.6974475383758545 gen loss 0.24623362720012665
disc loss 1.6808074712753296 gen loss 0.257803350687027


  c /= stddev[:, None]
  c /= stddev[None, :]


>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.26945 , 0.0014446 , 0.010417 , nan
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.726703405380249 gen loss 0.2630341649055481
disc loss 1.7036428451538086 gen loss 0.279459148645401
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27245 , 0.0013835 , 0.015444 , 0.79188
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.669114112854004 gen loss 0.27078232169151306
disc loss 1.7303396463394165 gen loss 0.23327898979187012
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.2665 , 0.0016907 , 0.0045451 , 0.81572
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.7436842918395996 gen loss 0.22369948029518127
disc loss 1.6588973999023438 gen loss 0.259907603263855
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.26643 , 0.0015044 , 

  c /= stddev[:, None]
  c /= stddev[None, :]


>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27654 , 0.0014598 , 0.00030053 , nan
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.6501352787017822 gen loss 0.26434993743896484
disc loss 1.6380215883255005 gen loss 0.24868249893188477
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27721 , 0.0016021 , 0.00065517 , 0.80791
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.6448767185211182 gen loss 0.27524012327194214
disc loss 1.6391286849975586 gen loss 0.2521844208240509
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.27663 , 0.0014562 , 0.00030744 , 0.79639
The loci with percentage of gap >= 95% in the cleaned MSA: 
 []
Their percentage gap are:
 []
disc loss 1.6067246198654175 gen loss 0.2794768512248993
disc loss 1.654712438583374 gen loss 0.2775317430496216
>>> MMD | abs(D_Entropy) | abs(D_Distance) | Pearson Corr 
0.28575 , 0.0