In [None]:
from typing import *

import matplotlib.pyplot as plt
import torch
import random
import shap
import numpy as np
import gc

from tqdm import tqdm
from xai import *

# Release memory left over from earlier runs in this kernel. Collect unreachable
# Python objects first, so any tensors they still hold are freed. Then hand the
# now-unused cached blocks of PyTorch's CUDA allocator back to the driver.
# With the reverse order, tensors that are kept alive only by uncollected
# garbage would still be allocated when the cache is emptied.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# These cells rebuild the l32 architecture and copy the raw tensors from
# "l32-params.pt" into it. The result is saved as "asteroids-autoencoder-l32.pt".
# `device` is otherwise first assigned in the configuration cell further down.
# It is defined here (same value) so this cell also runs on a fresh kernel.
device = "cuda"

ae: AutoEncoder = AutoEncoder(
    data_shape=(210, 160, 3),
    latent_shape=(32,),
    hidden_layers=[512, 512],
    output_activation="Sigmoid",
    device=device,
)

ae

In [None]:
# Raw parameter checkpoint. The next cell copies its "params" entry into `ae`,
# position by position.
params = torch.load(
    "l32-params.pt",
)

params  # show what was loaded

In [None]:
# Copy the raw tensors into the model, matched by position with ae.parameters().
# `strict=True` makes a parameter-count mismatch an error, instead of silently
# leaving trailing parameters at their random initialisation. The explicit shape
# check stops a mismatched tensor from being broadcast into place.
# `copy_` also handles 0-d parameters, which `[:]` slicing cannot.
with torch.no_grad():
    for ae_param, param in zip(ae.parameters(), params["params"], strict=True):
        source = torch.as_tensor(param)
        if ae_param.shape != source.shape:
            raise ValueError(
                f"parameter shape mismatch: model has {tuple(ae_param.shape)}, "
                f"checkpoint has {tuple(source.shape)}"
            )
        ae_param.copy_(source)

In [None]:
# Persist the model that now holds the copied parameters. This is the checkpoint
# that the training cell's `load` setting and the comparison cell refer to.
# NOTE(review): this call passes a ".pt" suffix, while `AutoEncoder.load(...)`
# elsewhere omits it. Confirm whether save/load append the extension themselves.
ae.save("asteroids-autoencoder-l32.pt")

In [None]:
%%capture output

device = "cuda"
batch_size = 1024
checkpoint_interval = 600
lr = 0.001
epochs = 5_000_000
load: str|None = "asteroids-autoencoder-l32"

In [None]:
# Load the dataset onto the CPU and split it without shuffling:
# the first TRAIN_SIZE samples are for training, the remainder for validation.
TRAIN_SIZE = 79216

dataset_v2 = torch.load("dataset-v2.pt", map_location="cpu")

# An empty validation split would make the validation/early-stopping setup
# of the fit cell meaningless, so fail loudly here instead.
assert 0 < TRAIN_SIZE < len(dataset_v2), (
    f"dataset has {len(dataset_v2)} samples; need more than {TRAIN_SIZE} "
    "for a non-empty validation split"
)

X_train = dataset_v2[:TRAIN_SIZE]
X_val = dataset_v2[TRAIN_SIZE:]

In [None]:
# Resume from the checkpoint named by `load`, or build a fresh model when it is None.
if load is not None:
    ae: AutoEncoder = AutoEncoder.load(load)
else:
    ae: AutoEncoder = AutoEncoder(
        device=device,
        output_activation="Sigmoid",
        hidden_layers=[512, 512],
        latent_shape=(32,),
        data_shape=(210, 160, 3),
    )

ae

In [None]:
def loss(predict, target):
    """Class-balanced squared-error loss.

    Each element's squared error is scaled by how common the *other* group is
    in ``target``. Let ``N`` be the element count and ``n`` the non-zero count.
    Elements where ``target == 0`` are weighted by ``n / N``; the rest are
    weighted by ``(N - n) / N``. The rarer group therefore contributes more
    per element.

    Args:
        predict: model output, broadcastable against ``target``.
        target: reference tensor.

    Returns:
        A 0-d tensor: the mean of the weighted squared errors.
    """
    total = target.nelement()
    nonzero = target.count_nonzero()
    zero_weight = nonzero / total               # applied where target == 0
    nonzero_weight = (total - nonzero) / total  # applied where target != 0

    sq_err = (predict - target) ** 2
    weighted = torch.where(target == 0, zero_weight * sq_err, nonzero_weight * sq_err)
    return weighted.mean()


# Train the autoencoder to reproduce its input (targets are the inputs
# themselves), using the class-balanced `loss` defined above.
# NOTE(review): the configured `lr` and `checkpoint_interval` are not passed
# here. Confirm whether `adam()` / `fit()` are meant to receive them.
ae.adam().fit(
    # reconstruction objective: inputs double as targets
    X_train=X_train,
    Y_train=X_train,
    X_val=X_val,
    Y_val=X_val,
    # optimisation
    loss_criterion=LossModule(loss),
    batch_size=batch_size,
    epochs=epochs,
    early_stop_count=400,  # presumably epochs without validation improvement; confirm in xai
    # progress reporting
    info="Asteroids autoencoder-train",
    verbose=True,
)


In [None]:
# Plot the model's training history (presumably filled in by `fit` above) and
# write the current pyplot figure to disk.
# NOTE(review): assumes `train_history.figure` draws into the current pyplot
# figure, so that `plt.gcf()` returns it. Confirm in xai.
ae.train_history.figure("Asteroids autoencoder")
plt.gcf().savefig("asteroids-autoencoder-l32-train.png")

In [None]:
# Side-by-side comparison of one validation sample:
# original | "good" checkpoint reconstruction | "bad" checkpoint reconstruction.
# `idx` was previously only set on a commented-out line, so this cell raised
# NameError on a fresh kernel. A seeded RNG keeps the chosen sample reproducible;
# change the seed to inspect other samples.
idx = random.Random(0).randrange(len(X_val))

good_ae = AutoEncoder.load("asteroids-autoencoder-l32")
bad_ae = AutoEncoder.load("asteroids-autoencoder-l32-bad")

# X_val was loaded with map_location="cpu". Convert straight to NumPy instead of
# round-tripping through the GPU; the models receive the same array as before.
X = X_val[idx].numpy(force=True)
with torch.no_grad():
    good_y = good_ae(X)().squeeze(0).numpy(force=True)
    bad_y = bad_ae(X)().squeeze(0).numpy(force=True)

fig, ax = plt.subplots(dpi=250)
ax.imshow(np.hstack([X, good_y, bad_y]))
ax.set_title("original | good reconstruction | bad reconstruction")
ax.axis("off")

upscale = 4  # nearest-neighbour enlargement factor for the saved PNGs
for filename, image in [
    ("original_image.png", X),
    ("good-reconstruction.png", good_y),
    ("bad-reconstruction.png", bad_y),
]:
    plt.imsave(filename, image.repeat(upscale, 0).repeat(upscale, 1))