In [1]:
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from tqdm import tqdm
from torchvision import transforms

from sklearn.metrics import mean_squared_error
from utils import load_wildfire_data


In [2]:
# Run on the GPU when CUDA is usable in this environment; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Resize transform with a 128x128 target size; it is handed to the data loader
# helper in the next cell.
transform = transforms.Resize(size=(128, 128))

In [4]:
# Unpack the three objects returned by the project helper into the train /
# validation / test loaders. The meaning of each keyword is defined in
# utils.load_wildfire_data (not shown here) -- check there before changing them.
train_dl, val_dl, test_dl = load_wildfire_data(
    batch_size=32,
    shuffle=True,
    train_ratio=0.8,
    root='data/',
    seq_length=1,
    sample_rate=1,
    transform=transform,
    time_embedding=False,
    verbose=True,
)

Size of raw data: (12500, 256, 256)
Size of raw data after reshape: (125, 100, 256, 256)
Size of raw data: (5000, 256, 256)
Size of raw data after reshape: (50, 100, 256, 256)
Train size: 9900
Validation size: 2475
Total data size: 12375
Train input size: torch.Size([1, 128, 128])
Train target size: torch.Size([1, 128, 128])


## Model: VQ-VAE 2D Linear

In [None]:
class VQVAE(nn.Module):
    """Fully connected vector-quantised autoencoder for fixed-size images.

    Each sample is flattened, encoded by an MLP into a single
    ``embedding_dim``-sized latent vector, snapped to its nearest codebook
    entry, and decoded by a mirrored MLP back to the original image shape.
    The decoder ends in a sigmoid, so reconstructions lie in ``[0, 1]``.

    Args:
        num_embeddings: Number of vectors in the codebook.
        embedding_dim: Size of each codebook vector (and of the encoder output).
        commitment_cost: Weight of the commitment term in the returned VQ loss.
        input_shape: Per-sample shape ``(C, H, W)`` (any shape works; only the
            product of its entries matters for the layer sizes). Defaults to
            ``(1, 128, 128)``, the per-sample size this notebook's data
            pipeline produces after the resize transform. Pass
            ``(1, 256, 256)`` for un-resized frames.

    Raises:
        ValueError: If ``input_shape`` contains a non-positive dimension.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25,
                 input_shape=(1, 128, 128)):
        super(VQVAE, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        self.input_shape = tuple(int(d) for d in input_shape)
        if any(d <= 0 for d in self.input_shape):
            raise ValueError(f"input_shape must contain positive sizes, got {input_shape}")
        # Number of scalar values per flattened sample.
        num_features = 1
        for d in self.input_shape:
            num_features *= d
        self.num_features = num_features

        # Encoder
        # NOTE(review): the trailing ReLU keeps z_e non-negative while the
        # codebook below is initialised symmetrically around zero -- confirm
        # this is intended.
        self.encoder = nn.Sequential(
            nn.Linear(self.num_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, embedding_dim),
            nn.ReLU()
        )

        # Decoder (mirror of the encoder; sigmoid bounds the output to [0, 1])
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, self.num_features),
            nn.Sigmoid()
        )

        # Codebook
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)
        self.codebook.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def encode(self, x):
        """Map flattened inputs ``(B, num_features)`` to latents ``(B, embedding_dim)``."""
        z_e = self.encoder(x)
        return z_e

    def quantize(self, z_e):
        """Replace every latent vector by its nearest codebook entry.

        Args:
            z_e: Encoder output of shape ``(B, embedding_dim)``.

        Returns:
            Tuple ``(z_q, encoding_indices)``: the selected codebook vectors,
            shaped like ``z_e``, and their indices with shape ``(B, 1)``.
        """
        # Squared Euclidean distance, expanded as |a|^2 + |b|^2 - 2 a.b so the
        # full (B, num_embeddings) table comes from a single matmul.
        distances = (z_e ** 2).sum(dim=1, keepdim=True) + (self.codebook.weight ** 2).sum(dim=1) - 2 * torch.matmul(
            z_e, self.codebook.weight.t())
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        z_q = self.codebook(encoding_indices).view(z_e.shape)
        return z_q, encoding_indices

    def decode(self, z_q):
        """Map latents ``(B, embedding_dim)`` to flattened reconstructions."""
        x_recon = self.decoder(z_q)
        return x_recon

    def forward(self, x):
        """Encode, quantise and reconstruct a batch.

        Args:
            x: Batch whose per-sample element count equals
                ``prod(input_shape)``, e.g. ``(B, *input_shape)``.

        Returns:
            Tuple ``(x_recon, loss, z_e, z_q)`` where ``x_recon`` has shape
            ``(B, *input_shape)``, ``loss`` is the scalar codebook loss plus
            ``commitment_cost`` times the commitment loss (the reconstruction
            loss is left to the caller), ``z_e`` is the encoder output and
            ``z_q`` the selected codebook vectors.

        Raises:
            ValueError: If the per-sample size of ``x`` does not match
                ``input_shape``.
        """
        batch_size = x.size(0)
        x = x.flatten(start_dim=1)
        if x.size(1) != self.num_features:
            raise ValueError(
                f"Expected {self.num_features} values per sample "
                f"(input_shape={self.input_shape}), got {x.size(1)}; "
                f"pass a matching input_shape when constructing VQVAE."
            )

        z_e = self.encode(x)
        z_q, _ = self.quantize(z_e)

        # Straight-through estimator: the forward value is z_q, but the
        # gradient of the reconstruction loss is copied onto z_e, because the
        # argmin lookup itself is not differentiable.
        z_q_st = z_e + (z_q - z_e).detach()

        # Reshape the *decoder output* (not the input) back to image form.
        x_recon = self.decode(z_q_st)
        x_recon = x_recon.view(batch_size, *self.input_shape)

        # Commitment loss (pulls the encoder towards the codebook) and
        # codebook loss (pulls the codebook towards the encoder).
        e_latent_loss = F.mse_loss(z_e, z_q.detach())
        q_latent_loss = F.mse_loss(z_q, z_e.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        return x_recon, loss, z_e, z_q

print("Done")


## Training

In [None]:
# Initialize the model and move it to the target device *before* building the
# optimizer, so the optimizer is constructed over the parameters in their
# final location (the order recommended by the torch.optim documentation).
model = VQVAE(num_embeddings=128, embedding_dim=64, commitment_cost=0.25)
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0

    for inputs, targets in tqdm(train_dl, desc=f"Epoch {epoch+1}/{num_epochs} [Training]"):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # `loss` is the codebook/commitment term returned by the model; the
        # reconstruction term against the targets is added here.
        outputs, loss, z_e, z_q = model(inputs)
        recon_loss = F.mse_loss(outputs, targets)
        total_loss = recon_loss + loss

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()

    # Mean loss per batch; max(..., 1) avoids a ZeroDivisionError on an empty loader.
    epoch_loss /= max(len(train_dl), 1)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
