In [1]:
import torch
import glob
import tqdm

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from utilities import utils, train_utils
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from torch.utils.data import DataLoader, TensorDataset
from pythae.models import AE, AEConfig
from pythae.trainers import BaseTrainerConfig
from pythae.pipelines.training import TrainingPipeline
from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_utils import ModelOutput

In [2]:
# Collect the embedding files in a deterministic (sorted) order.
paths = sorted(glob.glob('./data/environmental_embeddings_0001/0001/*.msgpack'))
if not paths:
    # Without this guard an empty glob only surfaces later as an opaque
    # "need at least one array to concatenate" ValueError from numpy.
    raise FileNotFoundError(
        "No '*.msgpack' embedding files found under "
        "'./data/environmental_embeddings_0001/0001/'; check the working directory."
    )
device = train_utils.get_device()

# Each file yields a pair of embeddings; only the first (positive) one is kept.
X = []
for path in tqdm.tqdm(paths):
    pos_emb, _neg_emb = utils.read_embedding_data(path)
    X.append(pos_emb)
X = np.concatenate(X, axis=0)

# Flatten to rows of 768 values so that every row is L2-normalized on its own.
# The per-row norms are returned in `l2_norm`.
X = X.reshape(-1, 768)
X, l2_norm = normalize(X, norm='l2', axis=1, return_norm=True)

# Restore the (samples, 77, 768) layout expected by the model.
X = X.reshape(-1, 77, 768)

# Fixed random_state keeps the 80/20 train/validation split reproducible.
Xtr, Xvl = train_test_split(X, test_size=0.2, random_state=42)

train_data = torch.tensor(Xtr, dtype=torch.float32)
val_data = torch.tensor(Xvl, dtype=torch.float32)

train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:05<00:00, 198.71it/s]


In [3]:
class Encoder(nn.Module):
    """Fully connected encoder applied independently to every token vector.

    Args:
        input_dim: Size of each input token vector (last input dimension).
        latent_dim: Size of each encoded token vector.
        seq_len: Number of token vectors per sample. Defaults to 77.

    Input shape:  (batch, seq_len, input_dim)
    Output shape: (batch, seq_len, latent_dim)
    """

    def __init__(self, input_dim, latent_dim, seq_len=77):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.seq_len = seq_len

        # Three hidden fully connected layers of width 1024 with ReLU,
        # followed by a linear projection to the latent size (no activation).
        self.encoder_layers = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.latent_dim),
        )

    def forward(self, x):
        """Encode `x` of shape (batch, seq_len, input_dim).

        Raises:
            AssertionError: if `x` does not have the expected shape.
        """
        assert x.dim() == 3 and tuple(x.shape[1:]) == (self.seq_len, self.input_dim), (
            f"expected input of shape (batch, {self.seq_len}, {self.input_dim}), "
            f"got {tuple(x.shape)}"
        )
        batch_size = x.shape[0]

        # Merge batch and token axes so the linear layers see one token per row.
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        tokens = x.reshape(batch_size * self.seq_len, self.input_dim)
        encoded = self.encoder_layers(tokens)

        # Restore the per-sample layout: (batch, seq_len, latent_dim).
        return encoded.reshape(batch_size, self.seq_len, self.latent_dim)
    
class Decoder(nn.Module):
    """Fully connected decoder applied independently to every latent token vector.

    Args:
        input_dim: Size of each reconstructed token vector (last output dimension).
        latent_dim: Size of each latent token vector (last input dimension).
        seq_len: Number of token vectors per sample. Defaults to 77.

    Input shape:  (batch, seq_len, latent_dim)
    Output shape: (batch, seq_len, input_dim), values in [-1, 1] because of
    the final tanh.
    """

    def __init__(self, input_dim, latent_dim, seq_len=77):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.seq_len = seq_len

        # Three hidden fully connected layers of width 1024 with ReLU,
        # followed by a projection back to the input size squashed by tanh.
        self.decoder_layers = nn.Sequential(
            nn.Linear(self.latent_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.input_dim),
            nn.Tanh()
        )

    def forward(self, z):
        """Decode `z` of shape (batch, seq_len, latent_dim).

        Raises:
            AssertionError: if `z` does not have the expected shape.
        """
        assert z.dim() == 3 and tuple(z.shape[1:]) == (self.seq_len, self.latent_dim), (
            f"expected latent of shape (batch, {self.seq_len}, {self.latent_dim}), "
            f"got {tuple(z.shape)}"
        )
        batch_size = z.shape[0]

        # Merge batch and token axes so the linear layers see one token per row.
        # `reshape` (unlike `view`) also accepts non-contiguous inputs.
        latents = z.reshape(batch_size * self.seq_len, self.latent_dim)
        decoded = self.decoder_layers(latents)

        # Restore the per-sample layout: (batch, seq_len, input_dim).
        return decoded.reshape(batch_size, self.seq_len, self.input_dim)
    
class AutoEncoder(nn.Module):
    """Chains an `Encoder` and a `Decoder` built with the same dimensions.

    After every forward pass the latent code of the most recent input is
    available as `self.z`.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()

        self.encoder = Encoder(input_dim, latent_dim)
        self.decoder = Decoder(input_dim, latent_dim)

    def forward(self, x):
        """Return the reconstruction of `x`, caching its latent code in `self.z`."""
        latent = self.encoder(x)
        # Cached on the module so callers can read the latent after the call.
        self.z = latent
        return self.decoder(latent)

In [4]:
# Build the autoencoder and place it on the target device; `Module.to` moves
# the parameters in place, so the optimizer below tracks the same tensors.
model = AutoEncoder(input_dim=768, latent_dim=128).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# Learning-rate scheduling is currently disabled:
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, threshold=0.01)

In [8]:
# --- experiment configuration ---
# Training schedule: `epoch` is the total number of epochs to run.
epoch = 300
batch_size = 128
# Weight of the L1 penalty on the latent code; 0 disables it.
sparsity_penalty_weight = 0
# Directory handed to the training loop as `save_path`.
save_path = 'experiments/fc_exp4/'

In [9]:
# loss function
class MSELossFC(nn.Module):
    """Summed squared reconstruction error plus an optional L1 penalty on the latent.

    Args:
        sparsity_penalty_weight: Multiplier applied to the mean absolute value
            of the latent code. A weight of 0 (the default) disables the penalty.
    """

    def __init__(self, sparsity_penalty_weight=0.0):
        super().__init__()
        self.sparsity_penalty_weight = sparsity_penalty_weight

    def forward(self, y, y_hat, z=None):
        """Compute the loss for one batch.

        Args:
            y: Target tensor of shape (batch, tokens, features).
            y_hat: Reconstruction with exactly the same shape as `y`.
            z: Latent code; only needed when the sparsity weight is non-zero.

        Returns:
            Scalar tensor: squared error summed over all elements, plus
            `sparsity_penalty_weight * mean(|z|)` when the penalty is enabled.

        Raises:
            AssertionError: if `y` is not 3-D or `y_hat` has a different shape.
            ValueError: if the penalty is enabled but `z` is not provided.
        """
        assert y.dim() == 3, f"expected a 3-D target, got shape {tuple(y.shape)}"
        # Identical shapes rule out silent broadcasting inside the loss.
        assert y_hat.shape == y.shape, (
            f"reconstruction shape {tuple(y_hat.shape)} does not match "
            f"target shape {tuple(y.shape)}"
        )

        # Functional form avoids building a fresh nn.MSELoss module per call.
        reconstruction_loss = F.mse_loss(y_hat, y, reduction='sum')

        if not self.sparsity_penalty_weight:
            # Penalty disabled: the latent is not needed at all.
            return reconstruction_loss

        if z is None:
            raise ValueError("z is required when sparsity_penalty_weight is non-zero")

        sparsity_loss = torch.abs(z).mean() * self.sparsity_penalty_weight
        return reconstruction_loss + sparsity_loss

# Loss instance shared by the training and evaluation cells.
criterion = MSELossFC(sparsity_penalty_weight=sparsity_penalty_weight)

In [None]:
# Mini-batch loaders: the training split is reshuffled, the validation split
# keeps its order.
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Hand everything to the project's training loop; no LR scheduler is used.
train_utils.train_loop(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=None,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=epoch,
    device=device,
    save_path=save_path
)

In [15]:
model.eval()
with torch.no_grad():
    # Transfer the validation set once and reuse it for the forward pass and the loss.
    val_inputs = val_data.to(device)
    recon = model(val_inputs)

# The criterion sums over every element, so dividing by the number of
# validation samples gives a per-sample loss.
# NOTE(review): this replaces a hard-coded `/ 200`, assumed to be the
# validation-set size — confirm that len(val_data) == 200 for this dataset.
criterion(val_inputs, recon, model.z) / len(val_data)

tensor(14.0679, device='cuda:0')