In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader,TensorDataset
from sklearn.metrics import mean_squared_error

from utils import train, create_dataloader

import cVAE

### Load and preprocess data:

In [2]:
batch_size = 32

# Training / validation loaders come from the project helper.
# NOTE(review): the "_test.npy" file is used for validation and the "_obs.npy"
# file as the held-out test set below -- confirm this naming is intended.
train_path = 'data/Ferguson_fire_train.npy'
train_loader = create_dataloader(train_path, batch_size)

val_path = 'data/Ferguson_fire_test.npy'
val_loader = create_dataloader(val_path, batch_size, mode='val')

# Test set: consecutive entries along axis 0 are paired as (input t, target t+1)
# for one-step-ahead forecasting (assumes axis 0 is time -- TODO confirm).
test_path = 'data/Ferguson_fire_obs.npy'
# Passing the path lets np.load open *and close* the file; np.load already
# returns an ndarray, so no extra np.array copy is needed.
test_data = np.load(test_path)
# Flatten each frame to a 1D vector: (n_frames, H, W) -> (n_frames, H*W).
test_data_1D = test_data.reshape(test_data.shape[0], -1)
test_data_1D_shifted = torch.Tensor(test_data_1D[1:])  # targets: frames 1..N-1
test_data_1D = torch.Tensor(test_data_1D[:-1])         # inputs:  frames 0..N-2
test_dataset = TensorDataset(test_data_1D, test_data_1D_shifted)
# shuffle=False keeps the pairs in chronological order.
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2 * batch_size, shuffle=False)

### Train the ConvVAE model (or load cached weights if a checkpoint exists):

In [None]:
# Train the convolutional VAE, or reload cached weights if a checkpoint exists.
# Delete MODEL_PATH to force retraining.
device = 'cpu'
MODEL_DIR = 'models'
MODEL_PATH = os.path.join(MODEL_DIR, 'cvae.pt')
EPOCHS = 1

cvae = cVAE.VAE_Conv(device).to(device)

if os.path.exists(MODEL_PATH):
    # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
    cvae.load_state_dict(torch.load(MODEL_PATH, map_location=device))
else:
    cvae = train(cvae, train_loader, val_loader, epochs=EPOCHS, device=device)
    # Create the directory only if needed, but always save the trained weights
    # (saving must not depend on whether the directory already existed).
    os.makedirs(MODEL_DIR, exist_ok=True)
    torch.save(cvae.state_dict(), MODEL_PATH)

Epoch 1 of 1
Train:


 66%|███████████████████████████              | 255/387 [08:51<04:42,  2.14s/it]

### Plot validation results (actual vs forecasted):

In [None]:
# Compare actual vs forecasted frames on a batch from the *validation* loader.
images, labels = next(iter(val_loader))
cvae.eval()

# Pick 5 distinct samples (randperm, not randint, so none is shown twice);
# a dedicated seeded generator keeps the figure reproducible without touching
# the global RNG state.
rng = torch.Generator().manual_seed(0)
sample_idx = torch.randperm(images.shape[0], generator=rng)[:5]

fig, ax = plt.subplots(2, 5, figsize=[18.5, 6])
fig.tight_layout(pad=4)
with torch.no_grad():  # inference only: no autograd graph needed
    for n, idx in enumerate(sample_idx):
        recon, _ = cvae(images[idx].unsqueeze(0))
        if n == 2:  # a single centred title per row
            ax[0, n].set_title('(Validation) Actual:', fontsize=20, pad=20)
            ax[1, n].set_title('(Validation) Forecasted:', fontsize=20, pad=20)
        ax[0, n].imshow(labels[idx].squeeze())
        ax[1, n].imshow(recon.cpu().squeeze())

for a in ax.flat:
    a.axis('off')

### Plot test results (actual vs forecasted):

In [None]:
# Sanity-check the one-step shift built in the data cell: every (input, target)
# pair must have matching shapes, and each target must be the next pair's input.
prev_target = None
for image, label in test_dataset:
    assert image.shape == label.shape, (image.shape, label.shape)
    if prev_target is not None:
        assert torch.equal(prev_target, image), 'test pairs are not consecutive'
    prev_target = label
print(f'{len(test_dataset)} (input, target) test pairs verified')

In [None]:
# Compare actual vs forecasted next frames on the first test batch.
images, labels = next(iter(test_loader))
cvae.eval()

# test_loader yields flattened frames; restore the original frame shape
# (taken from the raw test_data array) so they can be drawn with imshow.
frame_shape = test_data.shape[1:]
# test_loader is not shuffled, so these are the first consecutive frames.
n_show = min(5, len(images))

fig, ax = plt.subplots(2, 5, figsize=[18.5, 6])
fig.tight_layout(pad=4)
with torch.no_grad():  # inference only: no autograd graph needed
    for n in range(n_show):
        # NOTE(review): the model receives the flattened vector exactly as the
        # dataset stores it -- confirm VAE_Conv accepts this input shape.
        recon, _ = cvae(images[n].unsqueeze(0))
        if n == 2:  # a single centred title per row
            ax[0, n].set_title('(Test) Actual:', fontsize=20, pad=20)
            ax[1, n].set_title('(Test) Forecasted:', fontsize=20, pad=20)
        ax[0, n].imshow(labels[n].reshape(frame_shape))
        ax[1, n].imshow(recon.cpu().reshape(frame_shape))

for a in ax.flat:
    a.axis('off')

### Test-set MSE (forecasted vs. actual next frame):

In [None]:
# Test-set MSE between the model's one-step-ahead forecasts and the true next
# frames, accumulated over every batch of test_loader.
cvae.eval()
forecast_batches, target_batches = [], []
with torch.no_grad():
    for batch_inputs, batch_targets in test_loader:
        batch_recon, _ = cvae(batch_inputs)
        # Flatten per sample so forecasts and targets align regardless of the
        # shape the model returns.
        forecast_batches.append(batch_recon.cpu().reshape(len(batch_inputs), -1))
        target_batches.append(batch_targets.reshape(len(batch_targets), -1))

test_mse = mean_squared_error(
    torch.cat(target_batches).numpy(),
    torch.cat(forecast_batches).numpy(),
)
print(f'Test MSE: {test_mse:.6f}')