# Set up a simple training loop

In [None]:
from utils import *
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt; plt.style.use('dark_background')
from tqdm import tqdm

In [None]:
N_SAMPLES = 100  # number of samples in every generated signal

# --- quantization parameters ---
EPSI = 0.15      # quantization step
MAX_SIG = 2.5    # maximum value of the signal
# number of quantization levels: int(2*MAX_SIG/EPSI) halved, plus one
# NOTE(review): with the defaults this is 17, which looks like the count of
# non-negative levels (0, EPSI, ..., 16*EPSI) — confirm against HLoss2
NLEVELS = 1 + int(2 * MAX_SIG / EPSI) // 2
print(f"Quantization step: {EPSI}, Number of levels: {NLEVELS}")

In [None]:
# architecture
LATENT_DIM = 8  # size of the bottleneck between encoder and decoder


class Net(nn.Module):
    """Small MLP: Linear(in_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, out_dim).

    Used both as encoder (N_SAMPLES -> LATENT_DIM) and decoder
    (LATENT_DIM -> N_SAMPLES) in this notebook.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=64):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)

In [None]:
# dataset of random signals, generated once up front and held in memory
class SigDS(Dataset):
    """In-memory dataset of ``n_ds`` signals from ``create_random_signal(N_SAMPLES)``.

    The signals are stacked into a single tensor ``self.data`` whose first
    dimension indexes the samples.
    """

    def __init__(self, n_ds):
        self.n_ds = n_ds
        signals = [create_random_signal(N_SAMPLES) for _ in range(n_ds)]
        self.data = th.stack(signals)

    def __len__(self):
        return self.n_ds

    def __getitem__(self, idx):
        return self.data[idx]


ds = SigDS(20000)
dl = DataLoader(ds, batch_size=16, shuffle=True)

In [None]:
# plot the first 10 (randomly generated) signals of the dataset in a 5x2 grid
plt.figure(figsize=(10, 10))
for idx in range(10):
    ax = plt.subplot(5, 2, idx + 1)
    ax.plot(ds[idx])
plt.show()

In [None]:
# entropy-style loss on the residual; the gradient comes from autograd
class HLoss1(nn.Module):
    """Negative softmax entropy of the residual ``x1 - x2``, divided by ``size(0)``.

    The residual is turned into a distribution with a softmax along the last
    dimension, and the loss is ``sum(p * log p)`` over all elements divided
    by the size of dimension 0. Because ``sum(p * log p) <= 0``, minimizing
    this pushes the residual distribution towards higher entropy.
    """

    def forward(self, x1, x2):
        residual = x1 - x2
        plogp = F.softmax(residual, dim=-1) * F.log_softmax(residual, dim=-1)
        return plogp.sum() / residual.size(0)

In [None]:
# training loop: jointly train encoder and decoder; MSE, L1 and H losses are
# all tracked every batch, but only the one assigned to `loss` is optimized
enc = Net(N_SAMPLES, LATENT_DIM)
dec = Net(LATENT_DIM, N_SAMPLES)
opt = optim.Adam([*enc.parameters(), *dec.parameters()], lr=3e-4)

mse_loss = nn.MSELoss()
l1_loss = nn.L1Loss()
# H_loss = HLoss1()
H_loss = HLoss2(EPSI, MAX_SIG)

n_epochs = 5
lmses, lL1s, lHs = [], [], []  # per-epoch mean of each loss
n_batches = len(dl)
for epoch in range(n_epochs):
    sum_mse = sum_l1 = sum_h = 0
    for x in dl:
        opt.zero_grad()

        x̂ = dec(enc(x))  # reconstruction through the LATENT_DIM bottleneck

        lmse, lL1, lH = mse_loss(x̂, x), l1_loss(x̂, x), H_loss(x̂, x)

        sum_mse += lmse.item()
        sum_l1 += lL1.item()
        sum_h += lH.item()

        loss = lH  # swap in lmse or lL1 to optimize a different objective
        loss.backward()
        opt.step()

    lmses.append(sum_mse / n_batches)
    lL1s.append(sum_l1 / n_batches)
    lHs.append(sum_h / n_batches)

    print(f'ep {epoch}-> mse:{lmses[-1]:.4f}, L1:{lL1s[-1]:.4f}, H:{lHs[-1]:.4f}')

In [None]:
# plot losses: each curve is divided by its own maximum so the three curves
# share a comparable (relative, not absolute) scale.
# NOTE(review): max-normalization + log y-axis assumes each history keeps one
# sign and has a non-zero max — confirm for the H loss.
# The raw lists lmses/lL1s/lHs are left untouched so this cell can be re-run.
curves = {'mse': th.tensor(lmses), 'L1': th.tensor(lL1s), 'H': th.tensor(lHs)}
for label, curve in curves.items():
    plt.figure(figsize=(10, 2))
    plt.plot(curve / th.max(curve), label=label)
    plt.yscale('log')
    plt.legend()
plt.show()
# report the actual (un-normalized) final-epoch training losses
print(f'train losses: mse:{lmses[-1]:.4f}, L1:{lL1s[-1]:.4f}, H:{lHs[-1]:.4f}')


# evaluate on unseen data using mse loss: reconstruct the SAME signals we
# compare against (the original generated a second, unrelated dataset here)
test_in = SigDS(100).data
with th.no_grad():  # inference only: no autograd graph needed
    test_out = dec(enc(test_in))
    test_loss = mse_loss(test_out, test_in)
print(f'test loss: {test_loss.item():.4f}')

In [None]:
# compare input and output on 20 freshly generated signals and report the
# mean per-signal reconstruction MSE in the figure title
plt.figure(figsize=(10, 10))
err = []
with th.no_grad():  # inference only: no autograd graph needed
    for i in range(20):
        plt.subplot(10, 2, i + 1)
        # as_tensor avoids the copy/warning th.tensor() gives on a tensor input
        x = th.as_tensor(create_random_signal(N_SAMPLES))
        x̂ = dec(enc(x.view(1, N_SAMPLES))).view(N_SAMPLES)
        # compute the error in torch on both operands (no tensor/ndarray mixing)
        err.append(th.mean((x - x̂) ** 2).item())
        plt.plot(x.numpy(), label='input')
        plt.plot(x̂.numpy(), label='output')
        # plt.grid()
plt.suptitle(f'mse: {sum(err) / len(err):.4f}')
plt.legend()  # legend is attached to the last subplot only
plt.show()