In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from livelossplot import PlotLosses
from torch.utils.data import DataLoader,TensorDataset
from tqdm import tqdm 
import torch.utils.data as data
import torch.nn as nn

from cvae_torch import CVAE

In [2]:
# Locations of the Ferguson fire arrays (.npy files), given relative to the
# directory the notebook is run from.
train_path = 'data/Ferguson_fire_train.npy'
val_path = 'data/Ferguson_fire_test.npy'  # the "test" file serves as the validation split
test_path = 'data/Ferguson_fire_obs.npy'  # only referenced by commented-out code in the cells visible here

In [3]:
def find_decreasing_images(images):
    '''
    Return the indices at which the number of elements equal to 1 drops.

    Each image after the first is compared with the image directly before it;
    its index is reported when it holds strictly fewer elements equal to 1.

    images: sequence of arrays supporting ``len`` and integer indexing.
    Returns a list of ints in ascending order (empty when nothing decreases).
    '''
    # One pass to count the ones in every image ...
    ones_per_image = [
        np.count_nonzero(images[idx] == 1) for idx in range(len(images))
    ]
    # ... and a second pass comparing each count with its predecessor.
    return [
        idx
        for idx in range(1, len(ones_per_image))
        if ones_per_image[idx] < ones_per_image[idx - 1]
    ]

def create_x_y(data, indices):
    '''
    Build (current, next) pairs along the first axis without crossing segment borders.

    ``data`` is cut at ``indices`` with ``np.split``. Inside every segment the
    inputs are all entries except the last and the targets are all entries
    except the first, so ``x[k]`` and ``y[k]`` are always neighbours from the
    same segment.

    data: array to split along axis 0.
    indices: split positions, passed unchanged to ``np.split``.
    Returns the tuple ``(x, y)`` of concatenated inputs and targets.
    '''
    segments = np.split(data, indices)
    inputs = [segment[:-1] for segment in segments]
    targets = [segment[1:] for segment in segments]
    return np.concatenate(inputs), np.concatenate(targets)

In [4]:
# Load and process data
#
# train_path / val_path / test_path are defined once in the paths cell above
# and are not repeated here.

BATCH_SIZE = 128  # mini-batch size for training; validation uses twice this

# np.load is given the path directly so it opens AND closes the file itself;
# passing open(path, 'rb') left the file handle open for the kernel's lifetime.
train_data = np.load(train_path)
# Pair each frame with the next one, cutting the sequence wherever the count
# of ones drops so that no pair straddles such a break.
train_data_x, train_data_y = create_x_y(train_data, find_decreasing_images(train_data))
tensor_x = torch.tensor(train_data_x, dtype=torch.float32)
tensor_y = torch.tensor(train_data_y, dtype=torch.float32)
train_dataset = TensorDataset(tensor_x, tensor_y)

val_data = np.load(val_path)
val_data_x, val_data_y = create_x_y(val_data, find_decreasing_images(val_data))
# NOTE: tensor_x / tensor_y are deliberately re-bound, so after this cell they
# refer to the validation tensors (same as before this cleanup).
tensor_x = torch.tensor(val_data_x, dtype=torch.float32)
tensor_y = torch.tensor(val_data_y, dtype=torch.float32)
val_dataset = TensorDataset(tensor_x, tensor_y)

# TODO: the observation set at test_path is not wired up yet. Earlier sketch:
# test_data = np.load(test_path)
# test_data_1D = np.reshape(test_data, (np.shape(test_data)[0],np.shape(test_data)[1]*np.shape(test_data)[2]))
# test_data_1D_shifted = test_data_1D[1:]
# test_data_1D = test_data_1D[:-1]

train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# No gradients are kept during validation, so a larger batch is used there.
val_loader = data.DataLoader(dataset=val_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)
# test_loader = data.DataLoader(dataset=test_dataset, batch_size=2 * BATCH_SIZE, shuffle=False)

In [5]:
# Convolutional VAE: encoder, decoder and the combined model

In [6]:
class VAE_Encoder_Conv(nn.Module):
    '''
    Convolutional encoder (image -> latent feature maps).

    Three shared conv / GELU / max-pool stages feed two parallel heads with the
    same layout: one produces ``mu`` and the other ``sigma``. Every stage
    halves height and width, so a 256x256 input ends up as 16x16 maps while
    the channel count grows from 1 to 120.
    '''

    def __init__(self):
        super().__init__()
        # Stages are created in the order they were originally declared, so the
        # parameter names and the seeded initialisation stay the same.
        self.layer1 = self._conv_pool(1, 20, 5, 2)          # 256x256 -> 128x128
        self.layer2 = self._conv_pool(20, 40, 5, 2)         # 128x128 -> 64x64
        self.layer3 = self._conv_pool(40, 60, 3, 1)         # 64x64 -> 32x32
        self.layerMu = self._conv_pool(60, 120, 3, 1)       # 32x32 -> 16x16
        self.layerSigma = self._conv_pool(60, 120, 3, 1)    # 32x32 -> 16x16

    @staticmethod
    def _conv_pool(in_channels, out_channels, kernel_size, padding):
        '''
        One encoder stage: a size-preserving convolution (the padding is chosen
        by the caller to match the kernel), GELU, then a 2x2 max-pool that
        halves the spatial dimensions.
        '''
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding),
            nn.GELU(),
            nn.MaxPool2d(2, stride=2),
        )

    def forward(self, x, print_shape=True):
        '''
        x: [float] batch of single-channel images, e.g. shape (N, 1, H, W)
        print_shape: accepted for compatibility; it is not used.

        Returns the pair (mu, sigma), both computed from the same shared features.
        '''
        features = self.layer3(self.layer2(self.layer1(x)))
        return self.layerMu(features), self.layerSigma(features)
    
class VAE_Decoder_Conv(nn.Module):
    '''
    Convolutional decoder (latent feature maps -> image).

    Four transposed-convolution stages each double height and width
    (16x16 -> 256x256) while the channels shrink 120 -> 10; a last
    size-preserving transposed convolution followed by Tanh produces the
    single-channel output.
    '''

    def __init__(self):
        super().__init__()
        # Stages are created in the order they were originally declared, so the
        # parameter names and the seeded initialisation stay the same.
        self.layer1 = self._upsample(120, 60)   # 16x16 -> 32x32
        self.layer2 = self._upsample(60, 40)    # 32x32 -> 64x64
        self.layer3 = self._upsample(40, 20)    # 64x64 -> 128x128
        self.layer4 = self._upsample(20, 10)    # 128x128 -> 256x256
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(10, 1, 5, stride=1, padding=2),  # keeps 256x256
            nn.Tanh(),
        )

    @staticmethod
    def _upsample(in_channels, out_channels):
        '''
        One decoder stage: a kernel-4, stride-2, padding-1 transposed
        convolution (doubles the spatial size) followed by GELU.
        '''
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
            nn.GELU(),
        )

    def forward(self, x):
        '''
        x: [float] latent feature maps with 120 channels
        Returns the decoded single-channel image batch.
        '''
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        return x
 
    
class VAE_Conv(nn.Module):
    '''
    Combines the Encoder and the Decoder with a VAE latent space.

    device: device on which the latent noise is placed; it should be the
            device the module's parameters live on.
    '''

    def __init__(self, device):
        super(VAE_Conv, self).__init__()
        self.device = device
        self.encoder = VAE_Encoder_Conv()
        self.decoder = VAE_Decoder_Conv()
        self.distribution = torch.distributions.Normal(0, 1)  # Source of N(0,1) noise

    def sample_latent_space(self, mu, sigma):
        '''
        Draw z ~ N(mu, sigma^2) with the reparameterisation trick and compute
        the regularisation term.

        mu: [float] mean of the latent distribution
        sigma: [float] standard deviation, same shape as mu; must be strictly
               positive because its logarithm is taken

        Returns (z, kl_div), where kl_div is KL(N(mu, sigma^2) || N(0, 1))
        summed over every element of the batch.
        '''
        noise = self.distribution.sample(mu.shape).to(self.device)
        z = mu + sigma * noise
        # Closed form of the KL divergence to a standard normal:
        #   0.5 * (sigma^2 + mu^2 - 1) - log(sigma)
        # It is exactly zero at mu = 0, sigma = 1. The earlier expression
        # (sigma^2 + mu^2 - log(sigma) - 1/2) was not, and pulled sigma towards
        # 1/sqrt(2) instead of 1.
        kl_div = (0.5 * (sigma ** 2 + mu ** 2 - 1) - torch.log(sigma)).sum()
        return z, kl_div

    def forward(self, x):
        '''
        x - [float] A batch of images from the data-loader

        Returns (reconstruction, kl_div).
        '''
        mu, raw_sigma = self.encoder(x)  # Run the image through the Encoder
        # The encoder's sigma head ends in GELU + max-pool, so it can emit zero
        # or slightly negative values, for which log(sigma) is inf/NaN.
        # Softplus maps the raw output smoothly onto (0, inf); the small offset
        # guards against underflow to exactly zero.
        sigma = F.softplus(raw_sigma) + 1e-6
        z, kl_div = self.sample_latent_space(mu, sigma)  # Latent sample + regulariser
        return self.decoder(z), kl_div  # Decoder output is the predicted image

In [None]:
def _vae_batch_loss(autoencoder, batch, device, kl_div_on):
    '''
    Loss of one mini-batch: summed squared reconstruction error, plus the
    model's KL term when kl_div_on is true.
    '''
    batch = batch.to(device)
    # Insert the channel axis expected by the convolutional layers: (N, H, W) -> (N, 1, H, W)
    batch = batch.reshape(batch.shape[0], 1, batch.shape[1], batch.shape[2])
    x_hat, kl = autoencoder(batch)
    reconstruction = ((batch - x_hat) ** 2).sum()
    return reconstruction + kl if kl_div_on else reconstruction


def train(autoencoder, train_data, val_data, kl_div_on=True, epochs=10, device='cpu', patience=3):
    '''
    Train a VAE with Adam, live-plotting the losses and stopping early.

    autoencoder: model whose call returns (reconstruction, kl_term)
    train_data:  iterable of (batch, label) pairs with a len(), e.g. a DataLoader
    val_data:    same, used for validation
    kl_div_on:   add the KL term to the loss when True; when False only the
                 reconstruction error is optimised and reported
    epochs:      maximum number of epochs
    device:      device each batch is moved to
    patience:    stop after this many consecutive epochs without a new best
                 validation loss

    Returns the same autoencoder object, holding the weights of the last epoch
    that ran (not necessarily those with the best validation loss).

    NOTE(review): the label of every pair is ignored - the loss compares the
    reconstruction with the input batch itself. If the model is meant to
    predict the label (the following frame), the target must be changed.
    '''
    opt = torch.optim.Adam(autoencoder.parameters())
    liveloss = PlotLosses()
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    for epoch in range(epochs):
        logs = {}

        # Training
        autoencoder.train()
        train_loss = 0.0
        for batch, _ in tqdm(train_data):
            opt.zero_grad()
            loss = _vae_batch_loss(autoencoder, batch, device, kl_div_on)
            loss.backward()
            opt.step()
            train_loss += loss.item()
        # Average per batch; max(..., 1) avoids a ZeroDivisionError on an empty loader.
        train_loss /= max(len(train_data), 1)
        logs['loss'] = train_loss

        # Validation
        autoencoder.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch, _ in tqdm(val_data):
                val_loss += _vae_batch_loss(autoencoder, batch, device, kl_div_on).item()
        val_loss /= max(len(val_data), 1)
        logs['val_loss'] = val_loss

        liveloss.update(logs)
        liveloss.send()

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping: No improvement in validation loss for {patience} epochs.")
                break
    return autoencoder

# Build the convolutional VAE on the CPU and train it for at most 10 epochs
# (train() can stop sooner through its early-stopping check).
device = 'cpu'
# `vae` is bound before training starts, so the model object stays reachable
# even if the run below is interrupted; train() returns that same object.
vae = VAE_Conv(device).to(device)
vae = train(vae, train_loader, val_loader, epochs=10, device=device)

 42%|██████████████████▏                        | 41/97 [05:23<07:34,  8.12s/it]