In [None]:
import numpy as np
import torch 
from torch import nn
from torch.nn import functional as F

from typing import List, Callable, Union, Any, TypeVar, Tuple
# Alias used for type annotations. Points at the real tensor class: the previous
# TypeVar('torch.tensor') declared a type variable (not torch.Tensor) whose name
# string did not even match the variable it was bound to.
Tensor = torch.Tensor

import torch.optim as optim


# Data preprocessing utils : 
from acdc_dataset import ACDC_Dataset, One_hot_Transform, load_dataset
from torchvision.transforms import Compose
from torchvision import transforms

from torch.utils.data import DataLoader


# Visuals utils
import os
import matplotlib.pyplot as plt
from tqdm import tqdm


# my defined model
from vqVAE import VQVAE


## Preparing Dataset 

In [None]:

# Side length, in pixels, of the square images (L == W) fed to the model.
L = 64
# Mini-batch size shared by the train and test DataLoaders.
BATCH_SIZE = 64

In [None]:
# Root of the ACDC database. Set the ACDC_DATASET_PATH environment variable to run
# the notebook on another machine; the fallback is the original hard-coded location.
dataset_path = os.environ.get(
    "ACDC_DATASET_PATH",
    "/Users/kajou/OneDrive/Desktop/VQ-VAE/ACDC/database",
)

# Sub-folders holding the training and testing splits.
train_set_path = os.path.join(dataset_path, "training")
test_set_path = os.path.join(dataset_path, "testing")


In [None]:
# Load the raw training / testing splits from disk.
train_dataset = load_dataset(train_set_path)
test_dataset = load_dataset(test_set_path)

# Resize with nearest-neighbour interpolation so discrete pixel values are never
# blended (presumably these are label maps), then one-hot encode into 4 channels.
input_transforms = Compose([
    transforms.Resize(size=(L, L), interpolation=transforms.InterpolationMode.NEAREST),
    One_hot_Transform(num_classes=4),
])

TrainDataset = ACDC_Dataset(data=train_dataset, transforms=input_transforms)
TestDataset = ACDC_Dataset(data=test_dataset, transforms=input_transforms)

# Only the training data needs shuffling; a fixed evaluation order keeps
# validation passes reproducible from run to run.
TrainLoader = DataLoader(TrainDataset, batch_size=BATCH_SIZE, shuffle=True)
TestLoader = DataLoader(TestDataset, batch_size=BATCH_SIZE, shuffle=False)

## Preparing the model

In [None]:
K = 512  # num_embeddings (codebook size passed to VQVAE)
D = 64  # embedding_dim (size of each codebook vector)
in_channels = 4  # should match One_hot_Transform(num_classes=4) applied to the inputs
img_size = L  # reuse the resize target instead of repeating the literal 64

In [None]:
ACDC_VQVAE = VQVAE(in_channels, D, K)

# Random dummy batch of 16 images used only to smoke-test the model in the next
# cell; its shape follows the configured channel count and image size.
# NOTE: `input` shadows the Python builtin; the name is kept because the next cell uses it.
input = torch.rand(16, in_channels, img_size, img_size)

In [None]:
# Smoke-test the untrained model on the dummy batch. No gradients are needed for
# inspection, so skip building the autograd graph.
# NOTE(review): the model is still in train mode here; confirm no train-only layers
# (e.g. BatchNorm running stats) are affected by this random-input pass.
with torch.no_grad():
    y = ACDC_VQVAE(input)                  # full forward pass
    z_e = ACDC_VQVAE.encode(input)[0]      # encoder output, before quantisation
    z_q, _ = ACDC_VQVAE.vq_layer(z_e)      # quantised latents; second output discarded
codeBook = ACDC_VQVAE.vq_layer.embedding   # the VQ layer's `embedding` (codebook) module

## Training the Model

In [None]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda_ok = torch.cuda.is_available()
print(cuda_ok)

device = torch.device("cuda:0") if cuda_ok else torch.device("cpu")
print(device)

In [None]:
### Learning parameters


# nn.Module.to() moves the module in place and returns it, so `model` and
# `ACDC_VQVAE` refer to the same object.
model = ACDC_VQVAE.to(device)
# Kept for compatibility with any later cell reading it; the DataLoaders are
# built with BATCH_SIZE, so derive it instead of repeating the literal.
batch_size = BATCH_SIZE
lr = 1e-4
epochs = 2
# AdamW: Adam with decoupled weight decay.
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

In [None]:
###########################        Training ....      #################################

def evaluate_model(model, val_loader, device=None):
    """Compute the average per-sample validation loss of ``model`` over ``val_loader``.

    Runs without gradient tracking and restores the model's previous train/eval
    mode before returning, so it is safe to call from inside a training loop.

    Args:
        model: module whose forward returns ``(output, input, vq_loss)`` and which
            exposes ``loss_function(output, input, vq_loss)`` returning a dict
            with a ``'loss'`` tensor.
        val_loader: iterable of input batches; each batch is a tensor whose first
            dimension is the batch size.
        device: device batches are moved to. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        float: loss averaged per sample. Each batch's loss is weighted by its batch
        size, which gives the exact per-sample mean assuming ``loss_function``
        returns a batch-mean loss (TODO confirm against vqVAE.loss_function).

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = model.training
    model.eval()
    total_loss = 0.0
    n_samples = 0
    try:
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch.float().to(device)
                output, target, vq_loss = model(inputs)
                loss = model.loss_function(output, target, vq_loss)['loss']
                # Weight by batch size so a smaller final batch is not over-counted.
                n_batch = inputs.shape[0]
                total_loss += loss.item() * n_batch
                n_samples += n_batch
    finally:
        # Restore the caller's mode: a training loop that calls model.train() only
        # once would otherwise keep training in eval mode after this call.
        model.train(was_training)

    if n_samples == 0:
        raise ValueError("evaluate_model: val_loader yielded no samples")
    return total_loss / n_samples


def save_model(model, epoch, checkpoint_path=None):
    """Save ``model``'s weights and the epoch number to a checkpoint file.

    Args:
        model: module whose ``state_dict()`` is saved.
        epoch: epoch index stored alongside the weights under ``'epoch'``.
        checkpoint_path: destination file. Defaults to
            ``vqvae_100_bestmodel.pth`` in the current working directory.

    Returns:
        str: the path the checkpoint was written to.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(os.getcwd(), 'vqvae_100_bestmodel.pth')
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()}, checkpoint_path)
    return checkpoint_path



model.train()
train_loss_values = []
val_loss_values = []
best_val_loss = float('inf')

for epoch in range(epochs):
    # evaluate_model() switches the model to eval mode at the end of each epoch;
    # switch back so every epoch actually trains in training mode.
    model.train()

    train_loss = 0.0
    n_train_samples = 0

    # Iterate over the tqdm wrapper (not the raw loader) so the bar advances.
    with tqdm(TrainLoader, unit="batch", desc="Epoch {}".format(epoch)) as tepoch:
        for inputs in tepoch:
            inputs = inputs.float().to(device)  # Move data to the appropriate device (GPU/CPU)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass: the model returns [output, input, vq_loss].
            # NOTE: `input` shadows the builtin; name kept for the rest of the notebook.
            output, input, vq_loss = model(inputs)

            # Loss defined by the model itself, then backprop + parameter update.
            loss = model.loss_function(output, input, vq_loss)['loss']
            loss.backward()
            optimizer.step()

            # Weight the (assumed batch-mean) loss by batch size so the epoch
            # average is per sample even when the last batch is smaller.
            loss_value = loss.item()
            n_in_batch = inputs.shape[0]
            train_loss += loss_value * n_in_batch
            n_train_samples += n_in_batch

            # tqdm bar displays the current batch loss
            tepoch.set_postfix(loss=loss_value)

    epoch_loss = train_loss / n_train_samples if n_train_samples else float('nan')
    train_loss_values.append(epoch_loss)

    # Validation after each epoch
    val_loss = evaluate_model(model, TestLoader)
    val_loss_values.append(val_loss)

    # Save only when the validation loss improves, and remember the new best.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_model(model, epoch)

    print('Epoch {}: Train Loss: {:.4f} | Val Loss: {:.4f}'.format(epoch, epoch_loss, val_loss))

print("Training complete.")