# WAE on Proteins

# Colab

## Paths

Set `colab` to `False` if the notebook is not running on Colab.

In [0]:
# Runtime switch: True when the notebook runs on Google Colab. It gates the
# Drive mount, the package installs and the ngrok/TensorBoard setup.
colab = True

In [0]:
# Project root: the mounted Drive folder on Colab, the parent directory otherwise.
ROOT = '/content/gdrive/My Drive/Colab/PGM/Project/' if colab else '../'

In [0]:
# Sub-directories of the project root. Each keeps its trailing slash so that
# file names can be appended to it directly.
MODULE_PATH, DATA_PATH, NOTEBOOK_PATH, MODELS_PATH = (
    ROOT + subdir
    for subdir in ('autoencoders/', 'data/', 'notebooks/', 'models/')
)

## Access to Drive

In [4]:
# On Colab, mount Google Drive at /content/gdrive so that the project folder
# referenced by ROOT becomes reachable. The call asks for an authorization code.
if colab:
    from google.colab import drive
    drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


## Installations

In [0]:
# Install the packages used by this notebook when running on Colab.
# NOTE(review): the command's output is not shown in the cell, so a failed
# install would presumably only surface at import time. Confirm, and consider
# pinning versions.
if colab:
    get_ipython().system_raw('pip install torch torchvision tensorboardX')

In [0]:
# Fetch and unpack the ngrok binary used to expose TensorBoard from Colab.
# `wget -nc` skips the download when the archive is already present and
# `unzip -o` overwrites without prompting, so the cell can be re-run safely.
if colab:
    get_ipython().system_raw('wget -nc https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip')
    get_ipython().system_raw('unzip -o ngrok-stable-linux-amd64.zip')

## Tensorboard

In [0]:
# Directory passed to TensorBoard as its --logdir.
LOG_DIR = NOTEBOOK_PATH + 'runs/'

if colab:
    # Launch TensorBoard, then the ngrok tunnel to its port, as background
    # shell jobs (both commands end with `&`).
    for background_cmd in (
        'tensorboard --logdir="{}" --host 0.0.0.0 --port 6006 &'.format(LOG_DIR),
        './ngrok http 6006 &',
    ):
        get_ipython().system_raw(background_cmd)

Get the public URL for TensorBoard:

In [8]:
# Query ngrok's local API on port 4040 and print the public URL of the first
# tunnel it reports.
! curl -s http://localhost:4040/api/tunnels | python3 -c \
    "import sys, json; print(json.load(sys.stdin)['tunnels'][0]['public_url'])"

https://009cdff7.ngrok.io


# Import statements

## Access to modules

In [0]:
import sys
import os

In [0]:
# Make the modules stored under MODULE_PATH importable from this notebook.
module_dir = os.path.abspath(MODULE_PATH)
sys.path.append(module_dir)

## Importation of modules

In [0]:
# "Magic" commands for automatic reloading of module, perfect for prototyping
%reload_ext autoreload
%autoreload 2

import wasserstein
import proteins

In [0]:
import numpy as np

from tqdm import tqdm

In [0]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

# Dataset

First, we need to unzip the dataset:

In [0]:
# On Colab, extract the zipped dataset from Drive into a local `data/` folder
# and point DATA_PATH at that local copy.
if colab:
    # Build the archive path from ROOT rather than DATA_PATH: DATA_PATH is
    # re-pointed at the end of this cell, so reading it here would make a
    # second run of the cell look for the archive in the wrong place.
    archive = ROOT + 'data/' + 'pgm-dataset.npy.zip'
    # `mkdir -p` tolerates an existing folder and `unzip -o` overwrites without
    # prompting, which keeps the cell repeatable.
    get_ipython().system_raw('mkdir -p data')
    get_ipython().system_raw('unzip -o "{path}" -d data'.format(path=archive))
    DATA_PATH = 'data/'

In [0]:
# Read the raw array from disk, then wrap it in the project's dataset class.
protein_array = np.load(DATA_PATH + 'pgm-dataset.npy')
complete_set = proteins.ProteinDataset(data=protein_array)

In [0]:
# Fixed seed so that the train/validation partition is reproducible.
np.random.seed(0)

n = len(complete_set)
# Index where the shuffled order is cut: first 80% training, remainder validation.
split = int(n * .8)

# Random permutation of all sample indices.
shuffled = np.arange(n)
np.random.shuffle(shuffled)

training, validation = np.split(shuffled, [split])

In [17]:
batch_size = 100

# Training samples are shuffled by the loader; validation order is kept fixed.
train_loader = DataLoader(complete_set[training], batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(complete_set[validation], batch_size=batch_size, shuffle=False)

# Report how many batches each loader yields per pass.
print('>> total training batch number: {}'.format(len(train_loader)))
print('>> total validation batch number: {}'.format(len(validation_loader)))

>> total training batch number: 400
>> total validation batch number: 100


# Wasserstein Auto-Encoder

In [0]:
# Instantiate the auto-encoder from the project's `proteins` module with ksi=10.
# This instance is inspected and traced for TensorBoard; a fresh one is created
# again before training.
model = proteins.WassersteinAutoEncoder(ksi=10)

In [0]:
# IPython magic: show the source code of the model class.
%psource proteins.WassersteinAutoEncoder

In [20]:
# Print the module's layer structure.
print(model)

WassersteinAutoEncoder(
  (fc1): Linear(in_features=1968, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=2, bias=True)
  (fc5): Linear(in_features=2, out_features=32, bias=True)
  (fc6): Linear(in_features=32, out_features=128, bias=True)
  (fc7): Linear(in_features=128, out_features=512, bias=True)
  (fc8): Linear(in_features=512, out_features=1968, bias=True)
)


# Training procedure

## TensorBoard

First, we create a `SummaryWriter` instance (in order to use tensorboard):

In [0]:
from tensorboardX import SummaryWriter
# Writer for this run; it logs under LOG_DIR in the 'proteins-lr-5' sub-folder.
writer = SummaryWriter(LOG_DIR + 'proteins-lr-5')

To visualize the model graph in TensorBoard, we run the next cell:

In [32]:
# Trace the model on one random input so its graph can be shown in TensorBoard.
# A plain tensor is passed directly: the `torch.autograd.Variable` wrapper is
# deprecated and no longer needed since tensors and variables were merged.
# The width 1968 matches the in_features of the model's first linear layer.
dummy_input = torch.rand(1, 1968)
writer.add_graph(model, dummy_input)



## Training and testing

We define the device used during the gradient descent:

In [0]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [0]:
# Fresh, untrained model placed on the selected device.
model = proteins.WassersteinAutoEncoder(ksi=10).to(device)

We define the learning rate:

In [0]:
# Learning rate handed to the Adam optimizer created in the next cell.
learning_rate = 1e-5

And the optimizer:

In [0]:
# Adam over all model parameters, with weight decay as regularization.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)

## Sampling space

We want to sample from the latent space.

In [0]:
# Regular 16 x 16 grid of 2-D points covering [-1.5, 1.5]^2.
# Rows are ordered with x varying fastest: (x0, y0), (x1, y0), ..., (x15, y15).
grid_axis = np.linspace(-1.5, 1.5, 16)

# `np.float` was only an alias of the builtin float and has been removed from
# NumPy; `np.float64` is the same dtype and works on every NumPy version.
space = np.array(
    [[x, y] for y in grid_axis for x in grid_axis],
    dtype=np.float64
)

# Convert to a float32 CPU tensor of shape (256, 2).
space = torch.from_numpy(space).type(torch.FloatTensor)

# Training

In [38]:
# Run the project's train and test routines once per epoch, logging through
# `writer`. The start of the range was missing (a syntax error); epochs are
# numbered 1..49, matching the 49 iterations shown by the recorded progress bar.
for epoch in tqdm(range(1, 50), ascii=True, ncols=100):
    proteins.train(epoch, model, optimizer, train_loader, device, writer)
    proteins.test(epoch, model, validation_loader, device, writer)

100%|###############################################################| 49/49 [04:42<00:00,  5.76s/it]


# Saving

In [0]:
# Persist only the learned parameters (state dict), not the whole module.
weights_file = MODELS_PATH + 'wae.weights'
torch.save(model.state_dict(), weights_file)

In [0]:
# Rebuild the model and restore the saved parameters.
# `map_location='cpu'` remaps tensors that were saved from a GPU, so the file
# also loads on a machine without CUDA. The freshly constructed model is on the
# CPU, and `load_state_dict` copies the values into its parameters.
model = proteins.WassersteinAutoEncoder(ksi=10.)
model.load_state_dict(
    torch.load(MODELS_PATH + 'wae.weights', map_location='cpu')
)

In [45]:
# Print the reloaded module's layer structure.
print(model)

WassersteinAutoEncoder(
  (fc1): Linear(in_features=1968, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=32, bias=True)
  (fc4): Linear(in_features=32, out_features=2, bias=True)
  (fc5): Linear(in_features=2, out_features=32, bias=True)
  (fc6): Linear(in_features=32, out_features=128, bias=True)
  (fc7): Linear(in_features=128, out_features=512, bias=True)
  (fc8): Linear(in_features=512, out_features=1968, bias=True)
)
