In [None]:
# # Download Results data if needed
# !curl "https://drive.usercontent.google.com/download?id={1NvnpyMYoV0GNq0hC0EokHKindbpU65Qs}&confirm=xxx" -o "./data/generated_images_VAE.npy"

# # Download trained model if needed
# !curl "https://drive.usercontent.google.com/download?id={10eJvrjtVV6hE2-pNdgun8RIJ8YNQJX81}&confirm=xxx" -o "./data/Final_Linear.pth"

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from livelossplot import PlotLosses
from sklearn.metrics import mean_squared_error
from torchsummary import summary

# import custom packages
from WildfireThomas.WildfireGenerate.models import VAE
from WildfireThomas.WildfireGenerate.task2functions import training, predict
from WildfireThomas.WildfireDA.task3functions import assimilate


In [14]:
# Use the GPU when available, otherwise fall back to the CPU
dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

#### Load training and test data

In [16]:
data_path = 'data/Ferguson_fire_train.npy'
train_data = np.load(data_path)
test_path = 'data/Ferguson_fire_test.npy'
test_data = np.load(test_path)

For task 2 we chose a VAE (variational autoencoder) trained on 3D data (Ferguson_fire_train.npy) that includes a time dimension. From each original series of 100 time steps we took a sequence of 20 frames spaced 5 steps apart. We then built the datasets so that the input is the first 19 frames of a sequence and the target is the last 19 frames, which correspond to t and t+1 respectively. The model therefore takes inputs of shape (19, 256, 256) and produces outputs of shape (19, 256, 256), i.e. 19 images at the corresponding time steps, and is trained to predict what happens 5 time steps later. We similarly split Ferguson_fire_test.npy into t and t+1 sequences of 19 time steps and used it for validation.

The VAE we settled on is a simple linear one with 3 layers in each of the encoder and decoder, a latent size of 64, and a loss combining KL divergence with MSE. We used the Adam optimizer with a learning rate of 0.001. We also tested other structures, including a convolutional VAE, as well as 2D input with images fed in independently and without time labels, but none came close to the validation MSE of 0.148 obtained here.
 

## 1. Training

#### Create data loaders

The dataloader takes a set of sequences of 100 images each (our training set has 125 sequences) and keeps one image every `split_size` frames from each sequence. The result is a list of 3D arrays of shape (19, 256, 256). Here we create four dataloader objects: train (t), train (t+1), test (t) and test (t+1).
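The subsampling and the t / t+1 shift can be sketched in plain NumPy. This is only an illustration, not the code inside `training.create_dataloaders`: the helper name `make_shifted_pairs` is hypothetical, and it assumes the raw array stacks all frames along the first axis, sequence after sequence.

```python
import numpy as np

def make_shifted_pairs(frames, seq_length=100, split_size=5):
    """Hypothetical helper: subsample each sequence and build (t, t+1) pairs.

    Assumes `frames` has shape (n_sequences * seq_length, H, W).
    """
    n_seq = frames.shape[0] // seq_length
    h, w = frames.shape[1:]
    # Group the flat stack of frames into whole sequences
    sequences = frames[: n_seq * seq_length].reshape(n_seq, seq_length, h, w)
    # Keep one frame every `split_size` steps: 100 frames -> 20 frames
    sampled = sequences[:, ::split_size]
    # Inputs are the first 19 frames (t), targets the last 19 frames (t+1)
    return sampled[:, :-1], sampled[:, 1:]

# Toy data: 2 sequences of 100 frames, 4 x 4 pixels each
toy = np.arange(2 * 100 * 4 * 4, dtype=np.float32).reshape(200, 4, 4)
inputs, targets = make_shifted_pairs(toy)
print(inputs.shape, targets.shape)
```

Each target frame is the input frame one subsampled step later, i.e. 5 original time steps ahead.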

In [None]:
split_size = 5
batch_size = 16
seq_length = 100

train_loader, train_shifted_loader, test_loader, test_shifted_loader = training.create_dataloaders(train_data, test_data, seq_length, split_size, batch_size)

# Check dataset shapes and lengths
print(f'Train dataset shape: {train_loader.dataset.shape}')
print(f'Test dataset shape: {test_loader.dataset.shape}')

del train_data, test_data

#### Define model

We now instantiate the model and summarize its layers. The latent dimension is 64, and the channel size is the number of time steps in each series (19).

In [None]:
torch.cuda.empty_cache()
channel_size = seq_length // split_size - 1  # 19 time steps per input sample
latent_dim = 64
image_size = 256
print(channel_size)
model = VAE(latent_dim=latent_dim, channel_size=channel_size).to(device)

summary(model, (1, channel_size, image_size, image_size))

The model was trained previously, originally for 200 epochs, and produced the livelossplot loss curve shown below.

In [None]:
model_name = 'data/Final_Linear.pth'

if os.path.exists(model_name):
    print(f"The model {model_name} was already trained.")
    plot_filename = 'data/Task2LogLoss.png'

    # Load the model
    model.load_state_dict(torch.load(model_name, map_location=torch.device('cpu')))
    # model = model_info['model']

    # Plot loss plot
    loss_plot_image = plt.imread(plot_filename)
    plt.imshow(loss_plot_image)
    plt.axis('off')  # Turn off axis
    plt.show()
    
else:
    print("Training new model.")
    num_epochs = 20
    liveloss = PlotLosses()
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        logs = {}
        train_loss = training.train(model, train_loader, train_shifted_loader)
        # Move the loss to the CPU before converting, so this also works on a GPU
        logs['log loss'] = train_loss.detach().cpu().numpy()
        train_losses.append(logs['log loss'])

        val_loss = training.validate(model, test_loader, test_shifted_loader)
        logs['val_log loss'] = val_loss.detach().cpu().numpy()
        val_losses.append(logs['val_log loss'])

        liveloss.update(logs)
        liveloss.draw()

    # Save the trained weights so the branch above can reload them
    torch.save(model.state_dict(), model_name)

![alt text](misc/Task2LogLoss.png "Log loss")

## 2. Prediction

### Extract random sample from latent space

Images are generated by drawing a random torch tensor of shape (number of images, latent_dim) and passing it through the model's decoder to produce the chosen number of samples. The samples are then displayed with the display function in the `predict` module.
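The shape flow of this step can be sketched without the trained model. The snippet below is a NumPy stand-in, not `predict.predict_samples`: `toy_decoder` is a hypothetical single random linear layer followed by a sigmoid, and it assumes the latent vectors are drawn from a standard normal distribution.

```python
import numpy as np

rng = np.random.default_rng(0)

n_images, latent_dim = 2, 64
channels, height, width = 19, 8, 8  # toy image size; the notebook uses 256 x 256

# Hypothetical stand-in for a trained decoder: random linear layer + sigmoid
weights = rng.standard_normal((latent_dim, channels * height * width)) * 0.1
bias = np.zeros(channels * height * width)

def toy_decoder(z):
    """Map latent vectors (n, latent_dim) to images (n, channels, H, W)."""
    logits = z @ weights + bias
    pixels = 1.0 / (1.0 + np.exp(-logits))  # squash values into (0, 1)
    return pixels.reshape(-1, channels, height, width)

# Draw one latent vector per requested image (assumed standard normal)
z = rng.standard_normal((n_images, latent_dim))
samples = toy_decoder(z)
print(samples.shape)
```

Each latent vector decodes to a full stack of 19 frames, which is why two requested samples yield two sequences rather than two single images.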

In [None]:
predicted_samples = predict.predict_samples(model, 2, latent_dim)
print(predicted_samples[0].shape)
print(predicted_samples[1].shape)


In [None]:
predict.display_samples(predicted_samples, channel_size)

Display images after applying threshold

In [None]:
threshold = 0.1

binary_image = (predicted_samples > threshold).astype(int)
predict.display_samples(binary_image, channel_size)


## 3. MSE with Satellite Observed Data

See the notebook `Chosing_Image.ipynb`. In that notebook we built a metric to choose the best AI-generated images for each of the satellite and background images.

To find the time step that best corresponds to each satellite image, we calculated the MSE between each pair of images and computed a metric (refer to the 'Chosing_image.ipynb' notebook). The lowest MSE for the first background image was obtained by the 8th generated image (40th time step), and the second satellite image corresponded to the 10th generated image (50th time step). This is consistent with the satellite data, which was observed every 10 time steps, i.e. every second generated image, so the fire in our model grows in a very similar way. Hence, we chose the 8th, 10th, 12th, 14th and 16th images for data assimilation. We also attempted to compute the MSE between our model and the satellite data, but found that matching images to the background was more appropriate for our purpose, since data assimilation should incorporate more information from the model.
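The matching step can be sketched as an argmin over per-frame MSE values. This is a minimal illustration on synthetic data, not the metric built in the other notebook; `best_matching_frames` is a hypothetical name.

```python
import numpy as np

def best_matching_frames(references, generated):
    """For each reference image, return the index of the closest generated frame.

    references: (n_ref, H, W), generated: (n_gen, H, W).
    """
    # Broadcast to (n_ref, n_gen, H, W) and average the squared error per pair
    diff = references[:, None] - generated[None]
    mse = (diff ** 2).mean(axis=(2, 3))
    return mse.argmin(axis=1), mse

rng = np.random.default_rng(0)
generated = rng.random((19, 8, 8))  # 19 synthetic "generated" frames
# Synthetic references: noisy copies of frames 8 and 10 (indices 7 and 9)
references = generated[[7, 9]] + 0.01 * rng.standard_normal((2, 8, 8))

best, mse = best_matching_frames(references, generated)
print(best + 1)  # 1-based frame numbers
```

With real data the minimum is usually less clear-cut than in this synthetic case, so it can help to inspect the whole `mse` row for each reference rather than only its argmin.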

Here we load those 5 best images and calculate their MSE against the satellite observations:

In [17]:
obs_dataset = np.load('data/Ferguson_fire_obs.npy')

In [18]:
obs_dataset.shape

(5, 256, 256)

In [20]:
best_generated_images = np.load('data/generated_images_VAE.npy')

In [21]:
best_generated_images.shape

(5, 1, 256, 256)

In [22]:
best_generated_images = best_generated_images.reshape(5,256,256)

In [23]:
best_generated_images.shape

(5, 256, 256)

In [24]:
assimilate.mse(obs_dataset, best_generated_images)

0.07300988149632023

After comparing the satellite images with our AI-generated images, we obtained a combined mean squared error (MSE) of about 0.073 (sum of 5 images). This low value indicates close agreement between the satellite images and the images produced by our model, and suggests that the model's prediction of the state 5 time steps ahead is accurate. The model therefore appears to capture the patterns and dynamics seen in the satellite imagery, which supports its use for data assimilation.
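How `assimilate.mse` aggregates the per-image errors is not shown in this notebook, so the sketch below only illustrates the two common conventions, a sum and a mean over images, on synthetic arrays. The helper `per_image_mse` is hypothetical.

```python
import numpy as np

def per_image_mse(a, b):
    """MSE of each image pair; a and b have shape (n_images, H, W)."""
    return ((a - b) ** 2).mean(axis=(1, 2))

rng = np.random.default_rng(0)
obs = rng.random((5, 8, 8))  # synthetic stand-in for the observations
gen = rng.random((5, 8, 8))  # synthetic stand-in for the generated images

errors = per_image_mse(obs, gen)
summed = errors.sum()     # "combined" score as a sum over the 5 images
averaged = errors.mean()  # same as the MSE over all pixels of all images
print(errors, summed, averaged)
```

The two conventions differ by a factor equal to the number of images, so it is worth stating which one a reported score uses.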

![alt text](misc/VAEnothreshold.png "VAE with no threshold")

![alt text](misc/VAEThreshold.png "VAE with threshold")