# Imports

In [None]:
# Math
import torch
import numpy as np

# Flair Device
# Pin flair's global device setting to CPU (flair reads `flair.device` to decide
# where its models/embeddings run).
# NOTE(review): flair is not used directly in this notebook; presumably the
# data/*_preparation modules rely on flair embeddings — confirm.
import flair
flair.device = torch.device('cpu')

# Analysis
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

# Autoencoder and Dataset Preparation
from data.imdb_preparation import IMDB_preparation
from data.ag_news_preparation import AG_NEWS_preparation
from autoencoders.autoencoder import Autoencoder

# Downloading Packages
# Fetch the NLTK data packages (tokenizer, WordNet + multilingual WordNet,
# POS tagger, stopword lists), in this order.
# NOTE(review): presumably required by the data/*_preparation modules — confirm.
import nltk
for _nltk_package in (
    'punkt',
    'wordnet',
    'omw-1.4',
    'averaged_perceptron_tagger',
    'stopwords',
):
    nltk.download(_nltk_package)

# 1. Vanilla Autoencoder

## 1.1 Loading Data

In [None]:
# AG News dataset; the training loop unpacks each batch into (x_in, x_out).
ds = AG_NEWS_preparation()
dl = torch.utils.data.DataLoader(dataset=ds, shuffle=True, batch_size=32)

## 1.2 Training loop

In [None]:
# Plain (non-variational) autoencoder; `Ls` collects one training-loss value per batch.
# NOTE(review): the meaning of the positional arguments (100, 100, 100, 4, 3) is
# defined in autoencoders.autoencoder and is not visible here.
device = 'cpu'
Ls = []
m = Autoencoder(
    100, 100, 100, 4, 3,
    variational=False,
    max_log2len=8,
).to(device)

In [None]:
# Adamax over every model parameter, with weight decay (L2 penalty) of 1e-4.
opt = torch.optim.Adamax(params=m.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
# Train the plain autoencoder for EPOCHS passes over `dl`.
# Uses `m`, `dl`, `opt`, `device` and `Ls` from the cells above and appends
# the loss of every batch to `Ls`.
m = m.train()
EPOCHS = 3
for e in range(EPOCHS):
    print(f"Epoch: {e}")
    pbar = tqdm(dl)
    for x_in, x_out in pbar:
        opt.zero_grad()
        x_in = x_in.to(device)
        x_out = x_out.to(device)
        mx = m(x_in)
        # Squared error summed over dims 1 and 2, then averaged over dim 0
        # (presumably the batch dimension — confirm against the dataset output).
        L = (mx - x_out).pow(2).sum((1, 2)).mean()
        L.backward()
        opt.step()

        # Convert the 0-d loss tensor to a Python float once per batch
        # (`.item()` synchronizes when the tensor lives on an accelerator).
        loss_value = L.item()
        pbar.set_description(f"L: {loss_value}")
        Ls.append(loss_value)

In [None]:
# Natural log of the per-batch training loss recorded in `Ls`.
plt.plot(np.log(Ls))
plt.xlabel("batch")
plt.ylabel("log(loss)")

# 2. Variational Autoencoder (VAE)

## 2.1 Loading Data

In [None]:
# AG News dataset (same setup as section 1); batches unpack into (x_in, x_out).
ds = AG_NEWS_preparation()
dl = torch.utils.data.DataLoader(dataset=ds, shuffle=True, batch_size=32)

## 2.2 Training loop

In [None]:
# Variational autoencoder plus per-batch histories of the total loss (Ls),
# the reconstruction term (Lsmse) and the unweighted KL term (Lsvar).
# NOTE(review): positional-argument meanings live in autoencoders.autoencoder.
device = 'cpu'
Ls, Lsmse, Lsvar = [], [], []
m = Autoencoder(
    100, 100, 100, 4, 3,
    variational=True,
    max_log2len=8,
).to(device)

In [None]:
# Adamax over every VAE parameter, with weight decay (L2 penalty) of 1e-4.
opt = torch.optim.Adamax(params=m.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
# Train the VAE for EPOCHS passes over `dl` with loss = MSE + beta * KL.
# Uses `m`, `dl`, `opt`, `device`, `Ls`, `Lsmse`, `Lsvar` from the cells above
# and appends each batch's total, reconstruction and KL values to them.
m = m.train()
EPOCHS = 3
beta = 0.1  # weight of the KL term relative to the reconstruction term
for e in range(EPOCHS):
    print(f"Epoch: {e}")
    pbar = tqdm(dl)
    for x_in, x_out in pbar:
        opt.zero_grad()
        x_in = x_in.to(device)
        x_out = x_out.to(device)
        # With return_Z=True the model also returns the latent Gaussian parameters.
        mx, (mu, logvar) = m(x_in, return_Z=True)
        # Reconstruction: squared error summed over dims 1 and 2, averaged over dim 0.
        Lmse = (mx - x_out).pow(2).sum((1, 2)).mean()
        # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over dims 1 and 2
        # and averaged over dim 0 (assumes mu/logvar are 3-D — TODO confirm).
        Lvar = - 0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum((1, 2)).mean()
        L = Lmse + beta * Lvar
        L.backward()
        opt.step()

        # Convert each 0-d tensor to a Python float exactly once per batch.
        loss_value = L.item()
        mse_value = Lmse.item()
        kl_value = Lvar.item()
        pbar.set_description(f"L: {loss_value} Lmse: {mse_value} Lvar: {kl_value}")

        Lsmse.append(mse_value)
        Lsvar.append(kl_value)
        Ls.append(loss_value)

In [None]:
# Natural log of each per-batch loss component. The third curve is the
# beta-weighted KL divergence term (beta * Lvar) — it is not a variance.
plt.plot(np.log(Ls), label="total")
plt.plot(np.log(Lsmse), label="MSE")
plt.plot(np.log(beta * np.array(Lsvar)), label="beta*KL")
plt.xlabel("batch")
plt.ylabel("log(loss)")
plt.legend()

# 3. Denoising Autoencoder

## 3.1 Loading Data

In [None]:
# AG News with an (empty) augmentation-parameter dict for the denoising setup.
# NOTE(review): presumably passing aug_params makes x_in a corrupted copy of
# x_out using the preparation class's defaults — confirm in data/ag_news_preparation.
ds = AG_NEWS_preparation(aug_params={})
dl = torch.utils.data.DataLoader(dataset=ds, shuffle=True, batch_size=32)

## 3.2 Training loop

In [None]:
# Non-variational autoencoder for the denoising objective; `Ls` collects one
# training-loss value per batch.
# NOTE(review): positional-argument meanings live in autoencoders.autoencoder.
device = 'cpu'
Ls = []
m = Autoencoder(
    100, 100, 100, 4, 3,
    variational=False,
    max_log2len=8,
).to(device)

In [None]:
# Adamax over every model parameter, with weight decay (L2 penalty) of 1e-4.
opt = torch.optim.Adamax(params=m.parameters(), lr=1e-3, weight_decay=1e-4)

In [None]:
# Train the denoising autoencoder for EPOCHS passes over `dl`: the model sees
# x_in and is scored against x_out (presumably the clean target — confirm).
# Uses `m`, `dl`, `opt`, `device` and `Ls` from the cells above and appends
# the loss of every batch to `Ls`.
m = m.train()
EPOCHS = 3
for e in range(EPOCHS):
    print(f"Epoch: {e}")
    pbar = tqdm(dl)
    for x_in, x_out in pbar:
        opt.zero_grad()
        x_in = x_in.to(device)
        x_out = x_out.to(device)
        mx = m(x_in)
        # Squared error summed over dims 1 and 2, then averaged over dim 0.
        L = (mx - x_out).pow(2).sum((1, 2)).mean()
        L.backward()
        opt.step()

        # Convert the 0-d loss tensor to a Python float once per batch.
        loss_value = L.item()
        pbar.set_description(f"L: {loss_value}")
        Ls.append(loss_value)

In [None]:
# Natural log of the per-batch denoising training loss recorded in `Ls`.
plt.plot(np.log(Ls))
plt.xlabel("batch")
plt.ylabel("log(loss)")