## Imports

In [None]:
from therapeutic_enzyme_engineering_with_generative_neural_networks.models import VAE
from therapeutic_enzyme_engineering_with_generative_neural_networks.SeqLikeDataset import SeqLikeDataset

import torch
import torch.nn
import torch.optim
from torch.utils.data import DataLoader, random_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm as tqdm

from Bio.SeqIO import parse
from seqlike import aaSeqLike
from seqlike.alphabets import AA

device = torch.device('cuda') if torch.cuda.is_available() else 'cpu'

import seaborn as sns; sns.set_style("ticks")
%config InlineBackend.figure_formats = ['svg']
import matplotlib as mpl
mpl.rcParams['image.interpolation'] = 'nearest'

## Dataset

In [None]:
# The following file will exist if you've run the included get_homologous_seqs
# and remove_gappy_seqs notebooks.
input_file = '../data/tr-B5LY47-B5LY47-ECOLX_blast_nr_5000_aligned.fasta'

# Parse every FASTA record into an amino-acid SeqLike, then drop any sequence
# containing the letters 'B' or 'X' (IUPAC ambiguity codes).
seqs = pd.Series([aaSeqLike(s, alphabet=AA) for s in parse(input_file, 'fasta')])
seqs = pd.Series([x for x in seqs if 'B' not in x and 'X' not in x])

dataset = SeqLikeDataset(seqs)

# 80/20 train/test split. A fixed-seed generator makes the split reproducible
# across kernel restarts, so test metrics are comparable between runs.
SPLIT_SEED = 0
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)
# Evaluation does not depend on sample order, so keep it deterministic.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128)


In [None]:
# Quick visual check of ten randomly chosen filtered sequences. The `.seq`
# Series accessor is presumably registered by the seqlike package — confirm.
(
    seqs
    .sample(10)
    .seq
    .plot()
)

## Model and optimizer creation

In [None]:
# Build the VAE with its input length set to the longest sequence in `seqs`,
# then move it to the selected device.
v = VAE(
    sequence_length=seqs.map(len).max(),
    layer_sizes=[128, 96, 64],
    z_size=64,
).to(device)
# Loss history: the training loop appends one dict per record here.
v.losses = []

# One Adam (AMSGrad variant) optimizer each for the encoder and decoder.
# NOTE(review): any VAE parameters living outside .encoder/.decoder would
# never be stepped by these optimizers — confirm against the VAE definition.
vae_encoder_optimizer, vae_decoder_optimizer = (
    torch.optim.Adam(submodule.parameters(), amsgrad=True)
    for submodule in (v.encoder, v.decoder)
)

# Grab one training batch (e.g. for interactive shape inspection).
batch, _ = next(iter(train_loader))

## Training loop

In [None]:
N_EPOCHS = 30
GRAD_CLIP_NORM = 2
# The KL weight ramps linearly with the number of training batches seen and
# reaches 1.0 after this many batches (assumed intent of the warm-up; capped
# so it cannot keep growing on larger datasets / longer runs).
KL_ANNEAL_BATCHES = 1300.0
# Reconstructions are clamped to [RECON_EPS, 1 - RECON_EPS] before taking logs,
# so saturated outputs give a large finite loss and finite gradients instead
# of inf/nan. Values already inside that range are unaffected.
RECON_EPS = 1e-7


def kl_anneal_weight(step):
    """Return the KL weight for the training batch with 0-based index ``step``.

    The weight is ``(step + 1) / KL_ANNEAL_BATCHES``, capped at 1.0.
    """
    return min(1.0, (step + 1) / KL_ANNEAL_BATCHES)


def vae_loss_terms(recon, data, mu, logvar, kl_weight):
    """Compute per-sample reconstruction and weighted KL losses.

    Args:
        recon: model reconstruction, used as predicted probabilities; assumed
            to have the same shape as ``data`` — TODO confirm.
        data: input batch, used as the binary cross-entropy target.
        mu, logvar: latent posterior parameters; assumed (batch, z_size) —
            TODO confirm.
        kl_weight: scalar multiplier applied to the KL term.

    Returns:
        ``(recon_loss, kl_loss)``, each a 1-D tensor with one value per sample.
    """
    p = recon.clamp(RECON_EPS, 1 - RECON_EPS)
    bce = -(data * torch.log(p) + (1 - data) * torch.log(1 - p))
    # Average over every non-batch dimension. Unlike .squeeze(), this keeps the
    # batch axis when a batch contains a single example.
    recon_loss = bce.reshape(bce.shape[0], -1).mean(dim=1)

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2);
    # here it is averaged (not summed) over latent dimensions.
    kl_element = 1 + logvar - mu ** 2 - logvar.exp()
    kl_loss = -0.5 * kl_element.mean(dim=1) * kl_weight
    return recon_loss, kl_loss


batches = 0
outer = tqdm(range(N_EPOCHS))
for epoch in outer:
    inner = tqdm(train_loader, leave=False)
    v.train()

    for data, _ in inner:
        data = data.to(device)

        vae_encoder_optimizer.zero_grad()
        vae_decoder_optimizer.zero_grad()

        recon, mu, logvar, z_sample = v(data)
        recon_loss, kl_loss = vae_loss_terms(
            recon, data, mu, logvar, kl_anneal_weight(batches)
        )
        total_loss = torch.mean(recon_loss + kl_loss)

        # Log plain Python floats rather than device tensors, so the history
        # does not hold GPU memory and matches the format of the test records.
        train_kl = kl_loss.mean().item()
        train_recon = recon_loss.mean().item()
        train_total = total_loss.item()
        v.losses.append(
            {
                "batch_id": batches,
                "train_kl": train_kl,
                "train_recon": train_recon,
                "train_total": train_total,
            }
        )

        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(v.parameters(), GRAD_CLIP_NORM)
        vae_encoder_optimizer.step()
        vae_decoder_optimizer.step()

        inner.set_description(
            "KL[{:02.6f}], recon=[{:02.6f}], total=[{:02.6f}]".format(
                train_kl, train_recon, train_total
            )
        )

        batches += 1

    # Score the entire held-out set on its own data (not the last training
    # batch), without building an autograd graph. Losses are summed per sample
    # and divided by the sample count, so batch sizes are weighted correctly.
    v.eval()
    kl_weight = kl_anneal_weight(batches)
    test_recon_sum = 0.0
    test_kl_sum = 0.0
    test_count = 0
    with torch.no_grad():
        for test_data, _ in test_loader:
            test_data = test_data.to(device)
            test_out, mu, logvar, z_sample = v(test_data)
            recon_loss, kl_loss = vae_loss_terms(
                test_out, test_data, mu, logvar, kl_weight
            )
            test_recon_sum += recon_loss.sum().item()
            test_kl_sum += kl_loss.sum().item()
            test_count += recon_loss.shape[0]

    if test_count:  # skip logging when the test split is empty
        v.losses.append(
            {
                "batch_id": batches,
                "test_total": (test_recon_sum + test_kl_sum) / test_count,
                "test_kl": test_kl_sum / test_count,
                "test_recon": test_recon_sum / test_count,
            }
        )
