In [None]:
import urllib.request
import gzip
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
import pickle
import sys

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Preprocessing

In [None]:
from gym_sts.data.state_log_loader import StateLogLoader

loader = StateLogLoader()

data_url = '../gym-sts/out/states/states_20221106-233706.json.gz'

with open(data_url, "rb") as url:
    with gzip.GzipFile(fileobj=url, mode='r') as f:
        loader.load_file(f)

In [None]:
with open("state_log_loader_small.pkl.gz", "wb") as f:
    with gzip.GzipFile(fileobj=f, mode='w') as f2:
        pickle.dump(loader, f2)

In [None]:
with open("loader_state_data_dict.pkl.gz", "wb") as f:
    with gzip.GzipFile(fileobj=f, mode='w') as f2:
        pickle.dump(loader.state_data, f2)

In [None]:
with open("state_log_loader_small.pkl.gz", "rb") as f:
    with gzip.GzipFile(fileobj=f, mode='r') as f2:
        loader = pickle.load(f2)

In [None]:
with open("loader_state_data_dict.pkl.gz", "rb") as f:
    with gzip.GzipFile(fileobj=f, mode='r') as f2:
        loader_state_data_dict = pickle.load(f2)

# Experiment 1 - Embedding Autoencoder

In [5]:
with open("../gym-sts/loader_state_data.npz", "rb") as f:
    with np.load(f) as fz:
        for _, v in fz.items():
            loader_state_data = v

In [6]:
loader_state_data.shape

(5001, 47233)

In [7]:
total_size = loader_state_data.shape[0]
train_size = round(total_size * 0.8)
train_set, val_set = torch.utils.data.random_split(np.array(loader_state_data).astype(np.float32), [train_size, total_size - train_size])

In [8]:
train_set

<torch.utils.data.dataset.Subset at 0x7f2a563a4c40>

In [9]:
LOSS_FN = nn.MSELoss()

class Encoder(nn.Module):
    def __init__(self, num_inputs : int, latent_dim : int):
        super().__init__()
        self.num_inputs = num_inputs
        self.latent_dim = latent_dim
        
        self.net = nn.Sequential(
            nn.Linear(num_inputs, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, latent_dim),
            nn.LeakyReLU()
        )
    
    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, num_inputs : int, latent_dim : int):
        super().__init__()
        self.num_inputs = num_inputs
        self.latent_dim = latent_dim
        
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, 1024),
            nn.LeakyReLU(),
            nn.Linear(1024, num_inputs),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.net(x)

class AutoEncoder(nn.Module):
    def __init__(self, num_inputs : int, latent_dim : int):
        super().__init__()
        self.num_inputs = num_inputs
        self.latent_dim = latent_dim
        
        self.encoder = Encoder(num_inputs, latent_dim)
        self.decoder = Decoder(num_inputs, latent_dim)
    
    def forward(self, x):
        z = self.encoder(x)
        xhat = self.decoder(z)
        return xhat

class LinearAutoEncoder(nn.Module):
    def __init__(self, num_inputs : int, latent_dim : int):
        super().__init__()
        self.num_inputs = num_inputs
        self.latent_dim = latent_dim
        
        self.net = nn.Sequential(
            nn.Linear(num_inputs, latent_dim),
            nn.LeakyReLU(),
            nn.Linear(latent_dim, num_inputs),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        xhat = self.net(x)
        return xhat

In [10]:
dists = loader_state_data.sum(axis=0) / loader_state_data.shape[1]

print(f"Proportion of features that are all 0: {sum(dists == 0) / loader_state_data.shape[1]}")
print(f"Proportion of features that are all 1: {sum(dists == 1) / loader_state_data.shape[1]}")

# Pretend we're running a model f(X) = constant
dud_val_loss = LOSS_FN(torch.tensor(val_set), torch.tensor(np.repeat(dists.reshape((1,len(dists))), len(val_set), axis=0)))

print(f"Dud validation loss: {dud_val_loss}")

Proportion of features that are all 0: 0.8525818813117947
Proportion of features that are all 1: 0.0


  dud_val_loss = LOSS_FN(torch.tensor(val_set), torch.tensor(np.repeat(dists.reshape((1,len(dists))), len(val_set), axis=0)))


Dud validation loss: 0.0836987930342024


In [12]:
data_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True, drop_last=True)
input_dim = val_set[0].shape[0]
auto_encoder = LinearAutoEncoder(num_inputs=input_dim, latent_dim=512)
optimizer = torch.optim.Adam(auto_encoder.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=20, min_lr=5e-6)

In [None]:
training_losses = []
val_losses = []

def train():
    total_batches = len(data_loader)
    for batch_num, batch in enumerate(data_loader):
        loss = LOSS_FN(batch, auto_encoder.forward(batch))
        loss.backward()
        print(f"Batch: {batch_num+1}/{total_batches} loss: {loss.double()}", end="\r")
        training_losses.append(float(loss))
        optimizer.step()
    val_loss = LOSS_FN(torch.tensor(val_set), auto_encoder.forward(torch.tensor(val_set)))
    print(f"\nValidation loss: {val_loss}")
    val_losses.append(float(val_loss))
    scheduler.step(val_loss)

for epoch in range(10):
    print(f"Epoch {epoch}")
    train()

Epoch 0
Batch: 40/40 loss: 0.014221844263374805
Validation loss: 0.012453131377696991
Epoch 1
Batch: 40/40 loss: 0.0046832067891955385
Validation loss: 0.004682841710746288
Epoch 2
Batch: 40/40 loss: 0.0047970125451684386
Validation loss: 0.004653583746403456
Epoch 3
Batch: 2/40 loss: 0.004727833438664675

In [None]:
val_loss