# Lens Distortion Model Training

##### **Load libraries and check CUDA**



In [1]:
import torch
import os
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import math

from typing import Tuple

from dataloaderNetS import get_loader
from modelNetS import EncoderNet, ModelNet, EPELoss
from models.utils import load_config

# Location of config.ini. Defaults to the original development machine's path;
# set the IMAGE_ENHANCEMENT_CONFIG environment variable to run elsewhere.
config = load_config(os.environ.get(
    'IMAGE_ENHANCEMENT_CONFIG',
    'C:/Users/JoelVP/Desktop/UPV/ImageEnhancementTFG/imageenhancementtfg/data/config.ini',
))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Report which CUDA device (if any) the training will run on.
if not torch.cuda.is_available():
    print("No GPU available")
else:
    gpu_info = torch.cuda.get_device_properties(0)
    print(f"GPU Name: {gpu_info.name}")
    print(f"GPU Memory: {gpu_info.total_memory / 2**30:.2f} GB")

GPU Name: NVIDIA GeForce GTX 1050 with Max-Q Design
GPU Memory: 4.00 GB


##### **Define Params**

In [3]:
# Define the argument values directly in the notebook (originally written for Colab)
dataset_type = 0  # selects the distortion in "Prepare models": 0=barrel, 1=pincushion, 2=rotation, 3=shear, 4=projective, 5=wave
batch_size = 32 # Also change it in modelNetS
epochs = 50
lr = 0.0001  # initial learning rate; StepLR below halves it every 2 epochs
dataset_size = 256  # image size; used in the exported model file names and the W&B config
checkpoint_interval = 5  # Save a checkpoint every 5 epochs
checkpoint_dir = config['lens_distortion']['checkpoints_dir']
dataset_dir = config['lens_distortion']['dataset_dir']

# **Auxiliary functions**

In [4]:
def model_paths(distortion_type, lr, image_size) -> Tuple[str,str]:
    """Build the file paths used to export the two trained sub-models.

    Args:
        distortion_type: sequence whose first element is the distortion name.
        lr: learning rate; its dots are replaced by underscores in the file name.
        image_size: image size embedded in the file name.

    Returns:
        ``(model1_path, model2_path)`` under ``./models/``.
    """
    suffix = "{}_{}_{}".format(distortion_type[0], str(lr).replace('.', '_'), image_size)
    return tuple(f'./models/model{index}_{suffix}.pth' for index in (1, 2))


def load_weights(model, path):
    """Load a saved state dict into ``model`` and return it in eval mode.

    The model is moved to CUDA when available, otherwise it stays on the CPU.
    ``map_location`` sends the stored tensors to that same device, so weights
    saved on a GPU machine can also be loaded on a CPU-only one.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    return model.eval()


def save_checkpoint(epoch, model_1, model_2, optimizer, loss, checkpoint_dir):
    """Save a resumable checkpoint and keep only the newest ``last_checkpoint_epoch_*``.

    The file is written to ``<checkpoint_dir>/last_checkpoint_epoch_{epoch}.pth``.
    It contains the epoch, both model state dicts, the optimizer state dict
    and ``loss``. Other ``last_checkpoint_epoch_*`` files in the directory are
    deleted afterwards; ``best_checkpoint.pth`` is not touched.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_name = f'last_checkpoint_epoch_{epoch}.pth'
    last_checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name)

    # Write the new checkpoint BEFORE deleting older ones, so a failed save
    # (disk full, interruption) never leaves the directory without a checkpoint.
    torch.save({
        'epoch': epoch,
        'model_1_state_dict': model_1.state_dict(),
        'model_2_state_dict': model_2.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss
    }, last_checkpoint_path)

    # Remove the previous "last" checkpoints now that the new one exists.
    for file in os.listdir(checkpoint_dir):
        if file.startswith('last_checkpoint_epoch_') and file != checkpoint_name:
            os.remove(os.path.join(checkpoint_dir, file))

    print(f'Last checkpoint saved at {last_checkpoint_path}')


def save_best_checkpoint(epoch, model_1, model_2, optimizer, loss, checkpoint_dir):
    """Overwrite ``<checkpoint_dir>/best_checkpoint.pth`` with the current training state.

    The saved dict holds the epoch, both model state dicts, the optimizer
    state dict and ``loss``. The directory is created if it does not exist.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    state = dict(
        epoch=epoch,
        model_1_state_dict=model_1.state_dict(),
        model_2_state_dict=model_2.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(state, os.path.join(checkpoint_dir, 'best_checkpoint.pth'))
    print(f'Best checkpoint saved in epoch {epoch}, LOSS {loss}')


def load_checkpoint(checkpoint_path, model_1, model_2, optimizer):
    """Restore models and optimizer from a checkpoint written by ``save_checkpoint``.

    Returns:
        ``(model_1, model_2, optimizer, epoch, loss)``, where ``epoch`` and
        ``loss`` are the values stored in the checkpoint.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"No se encontró el archivo de checkpoint en {checkpoint_path}")

    # Load onto the CPU so a GPU-saved checkpoint also resumes on a CPU-only
    # machine; load_state_dict copies values into the already-placed parameters
    # and Optimizer.load_state_dict casts its state to the parameters' device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_1.load_state_dict(checkpoint['model_1_state_dict'])
    model_2.load_state_dict(checkpoint['model_2_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint cargado. Última epoch={epoch}, pérdida={loss}")
    return model_1, model_2, optimizer, epoch, loss


def get_best_loss(best_checkpoint_path):
    """Return the ``loss`` value stored in a best-checkpoint file.

    Raises:
        FileNotFoundError: if ``best_checkpoint_path`` does not exist.
    """
    if not os.path.exists(best_checkpoint_path):
        raise FileNotFoundError(f"No se encontró el archivo de best_checkpoint en {best_checkpoint_path}")

    # map_location='cpu' avoids requiring CUDA just to read a scalar from a GPU-saved file.
    best_checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
    best_loss = best_checkpoint['loss']
    print(f"Best loss encontrado: {best_loss}")
    return best_loss


def select_model_path(checkpoint_dir, best_or_last):
    """Return the path of the best or the most recent checkpoint in ``checkpoint_dir``.

    Args:
        checkpoint_dir: directory holding the checkpoints.
        best_or_last: ``"best"`` selects ``best_checkpoint.pth``; any other value
            selects the ``last_checkpoint_*.pth`` file with the highest epoch number.

    Raises:
        FileNotFoundError: if the requested checkpoint does not exist (also
            raised by ``os.listdir`` when ``checkpoint_dir`` itself is missing).
    """
    import re

    if best_or_last == "best":
        best_checkpoint_path = os.path.join(checkpoint_dir, 'best_checkpoint.pth')
        if not os.path.exists(best_checkpoint_path):
            raise FileNotFoundError(f"No se encontró el archivo 'best_checkpoint.pth' en {checkpoint_dir}.")
        print("Best checkpoint in ", best_checkpoint_path)
        return best_checkpoint_path

    last_checkpoint_paths = [f for f in os.listdir(checkpoint_dir) if f.startswith('last_checkpoint_') and f.endswith('.pth')]
    if not last_checkpoint_paths:
        raise FileNotFoundError(f"No se encontró ningún last_checkpoint en {checkpoint_dir}.")

    def _epoch_key(name):
        # Compare by the numeric epoch suffix; a plain string sort would rank
        # "..._epoch_5.pth" after "..._epoch_10.pth".
        match = re.search(r'(\d+)\.pth$', name)
        return (int(match.group(1)) if match else -1, name)

    last_checkpoint_path = os.path.join(checkpoint_dir, max(last_checkpoint_paths, key=_epoch_key))
    print("Last checkpoint in ", last_checkpoint_path)
    return last_checkpoint_path


def save_model(checkpoint_path, model1_path, model2_path):
    """Export the two model state dicts of a checkpoint as standalone files.

    Args:
        checkpoint_path: checkpoint written by ``save_checkpoint`` or
            ``save_best_checkpoint``.
        model1_path: destination for ``model_1_state_dict``.
        model2_path: destination for ``model_2_state_dict``.

    Missing parent directories of the destinations (e.g. ``./models``) are created.
    """
    # Load on the CPU so the exported weights are device-independent.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Get both model states first, so a malformed checkpoint fails before anything is written.
    model_1_state_dict = checkpoint['model_1_state_dict']
    model_2_state_dict = checkpoint['model_2_state_dict']

    for state_dict, target_path in ((model_1_state_dict, model1_path), (model_2_state_dict, model2_path)):
        parent_dir = os.path.dirname(target_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        torch.save(state_dict, target_path)



##### **Prepare models for training**

In [5]:
# Map the numeric dataset_type (params cell) to the distortion name passed to the
# loaders, the model and the export file names.
DISTORTION_TYPES = ['barrel', 'pincushion', 'rotation', 'shear', 'projective', 'wave']

# Fail loudly instead of leaving distortion_type undefined (or stale from a
# previous kernel run) for an unknown dataset_type.
if dataset_type not in range(len(DISTORTION_TYPES)):
    raise ValueError(f"dataset_type must be an integer in 0..{len(DISTORTION_TYPES) - 1}, got {dataset_type!r}")

distortion_type = [DISTORTION_TYPES[dataset_type]]

# Cached CUDA flag; the validation loop checks it before moving batches to the GPU.
use_GPU = torch.cuda.is_available()

# Training data comes from the train_* folders, validation data from the test_* folders.
train_loader = get_loader(distortedImgDir = f'{dataset_dir}/train_distorted',
                  flowDir   = f'{dataset_dir}/train_flow',
                  batch_size = batch_size,
                  distortion_type = distortion_type)

val_loader = get_loader(distortedImgDir = f'{dataset_dir}/test_distorted',
                  flowDir   = f'{dataset_dir}/test_flow',
                  batch_size = batch_size,
                  distortion_type = distortion_type)

# Destination files for the exported weights (used by the final "SAVE" cell).
model1_path, model2_path = model_paths(distortion_type, lr, dataset_size)

model_1 = EncoderNet([1,1,1,1,2])
# NOTE(review): ModelNet receives the batch size (the params cell says to keep it in
# sync with modelNetS); a smaller final batch may not be supported — TODO confirm.
model_2 = ModelNet(distortion_type[0], batch_size)
criterion = EPELoss()

print('dataset type:',distortion_type)
print('batch size:', batch_size)
print('epochs:', epochs)
print('lr:', lr)
print('train_loader',len(train_loader))
print('val_loader', len(val_loader))
print('path model 1', model1_path)
print('path model 2', model2_path)


# Only model_1 is wrapped for multi-GPU use.
# NOTE(review): with DataParallel, model_1.state_dict() keys are presumably prefixed
# with 'module.', which matters when exported weights are loaded into a bare EncoderNet.
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model_1 = nn.DataParallel(model_1)

if torch.cuda.is_available():
    model_1 = model_1.cuda()
    model_2 = model_2.cuda()
    criterion = criterion.cuda()

# NOTE(review): only model_1's parameters are optimised; presumably model_2 has no
# trainable parameters — confirm.
optimizer = torch.optim.Adam(model_1.parameters(), lr=lr)
# Halves the learning rate every 2 scheduler steps (stepped once per epoch in the training loop).
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
# len(train_loader) is the number of batches, so this is the logging interval in
# iterations used by train_one_epoch, not literally the number of steps per epoch.
n_steps_per_epoch = math.ceil(len(train_loader) / batch_size)

dataset type: ['barrel']
batch size: 32
epochs: 10
lr: 0.0001
train_loader 625
val_loader 62
path model 1 ./models/model1_barrel_0_0001_256.pth
path model 2 ./models/model2_barrel_0_0001_256.pth


## **TRAIN**

**Checkpoint**

In [7]:
# Resume from the checkpoint if training was interrupted.
start_with_checkpoint = False
start_epoch = 0
best_loss = float('inf')

if start_with_checkpoint:
    last_checkpoint_path = select_model_path(checkpoint_dir, 'last')
    best_checkpoint_path = select_model_path(checkpoint_dir, 'best')

    # The best loss comes from best_checkpoint.pth. The loss stored in the last
    # checkpoint is only that epoch's loss and must not overwrite it.
    best_loss = get_best_loss(best_checkpoint_path)
    model_1, model_2, optimizer, start_epoch, _ = load_checkpoint(last_checkpoint_path, model_1, model_2, optimizer)

    print(f'Resumiendo el entrenamiento desde el epoch {start_epoch}')

else:
    # Fresh run: remove stale checkpoints independently, so a missing "best"
    # checkpoint does not leave an old "last" checkpoint behind (or vice versa).
    deleted_old_checkpoints = False
    for checkpoint_kind in ('last', 'best'):
        try:
            os.remove(select_model_path(checkpoint_dir, checkpoint_kind))
            deleted_old_checkpoints = True
        except FileNotFoundError:
            pass

    if deleted_old_checkpoints:
        print('Starting new training, and deleting old checkpoints')
    else:
        print('Starting new training')

Starting new training


In [None]:
# Ejecutar en caso de error con wandb.login()
!pip install wandb -qU

In [None]:
import wandb
# Authenticate with the API key read from config.ini (it is not hardcoded here).
wandb.login(
    key=config['api_keys']['w&b'],
)

In [None]:
# Start (or resume, because resume="allow") the W&B run the training loop logs to.
# NOTE(review): the run id only encodes lr and batch size, so runs for different
# distortion types with the same hyper-parameters presumably map onto the same
# W&B run — consider adding distortion_type[0] to name/id.
wandb.init(
    project="tfg",
    name=f"Train_lr_{lr}_bs_{batch_size}",
    id=f"id_train_lr_{lr}_bs_{batch_size}",
    resume="allow",
    config={
        "distortion_type": distortion_type[0],
        "image_size": dataset_size,
        "learning rate": lr,
        "batch size": batch_size,
        "train loader": len(train_loader),
        "val loader": len(val_loader),
        "epochs": epochs,
    },
)

# Plot every "epoch/*" metric against "epoch/step" rather than W&B's global step.
wandb.define_metric("epoch/step")
wandb.define_metric("epoch/*", step_metric="epoch/step")

**Start training**

In [None]:
# Define the function that trains one epoch.
def train_one_epoch(epoch):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Every ``n_steps_per_epoch`` iterations, the mean loss of that window is
    printed. It is also logged to W&B as ``train_loss``, together with a
    global ``step`` counter (``epoch * len(train_loader) + iteration``).

    Relies on the notebook globals ``train_loader``, ``model_1``, ``model_2``,
    ``criterion``, ``optimizer`` and ``n_steps_per_epoch``.
    """
    epoch_loss_sum = 0.0
    window_loss_sum = 0.0

    for batch_index, batch in enumerate(train_loader):
        disimgs, disx, disy = batch
        if torch.cuda.is_available():
            disimgs, disx, disy = disimgs.cuda(), disx.cuda(), disy.cuda()

        optimizer.zero_grad()

        # Ground-truth flow: x and y components concatenated along dim 1.
        flow_truth = torch.cat([disx, disy], dim=1)

        # Forward through both sub-models, then compute the loss against the ground truth.
        flow_output = model_2(model_1(disimgs))
        loss = criterion(flow_output, flow_truth)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        epoch_loss_sum += loss_value
        window_loss_sum += loss_value

        iteration = batch_index + 1
        if iteration % n_steps_per_epoch == 0:
            window_mean = window_loss_sum / n_steps_per_epoch
            print(f"Iter {iteration} Loss {window_mean}")
            wandb.log({"train_loss": window_mean, "step": epoch * len(train_loader) + iteration})
            window_loss_sum = 0.0  # start a new logging window

    average_loss = epoch_loss_sum / len(train_loader)
    print("Average Epoch Loss", average_loss)
    return average_loss


In [None]:
# Train from start_epoch (0 for a fresh run, or the epoch stored in the resumed
# checkpoint) up to `epochs`, validating after every epoch.
# NOTE(review): the StepLR scheduler state is not stored in the checkpoints, so after
# resuming its step counter restarts from 0 — confirm this is acceptable.
for epoch in range(start_epoch, epochs):
    print("\nEpoch", epoch)

    model_1.train()
    model_2.train()

    # Train for one epoch.
    avg_train_loss = train_one_epoch(epoch)

    running_val_loss = 0.0
    # Evaluation mode: disables dropout and uses population statistics for batch normalization.
    model_1.eval()
    model_2.eval()

    # Disable gradient computation to reduce memory consumption.
    with torch.no_grad():
        for disimgs, disx, disy in val_loader:
            if use_GPU:
                torch.cuda.empty_cache()
                disimgs = disimgs.cuda()
                disx = disx.cuda()
                disy = disy.cuda()

            flow_truth = torch.cat([disx, disy], dim=1)
            flow_output = model_2(model_1(disimgs))

            # Accumulate a Python float (not a device tensor) so the average
            # prints and logs as a plain number.
            running_val_loss += criterion(flow_output, flow_truth).item()

    avg_val_loss = running_val_loss / len(val_loader)
    print(f'EPOCH {epoch}, LOSS train {avg_train_loss} LOSS val {avg_val_loss}')
    wandb.log({
        "epoch/avg_train_loss": avg_train_loss,
        "epoch/avg_val_loss": avg_val_loss,
        "epoch/step": epoch,
    })

    # Save a resumable checkpoint every checkpoint_interval epochs. The stored
    # epoch is epoch + 1, i.e. the next epoch to run, which start_epoch resumes from.
    if (epoch + 1) % checkpoint_interval == 0:
        save_checkpoint(epoch + 1, model_1, model_2, optimizer, avg_train_loss, checkpoint_dir)

    # Keep the checkpoint with the lowest VALIDATION loss. Selecting on the
    # training loss would simply favour the latest (possibly overfit) epoch.
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        save_best_checkpoint(epoch, model_1, model_2, optimizer, avg_val_loss, checkpoint_dir)

    scheduler.step()

# Close the W&B session once training finishes.
wandb.finish()

## **SAVE BEST/LAST MODEL**

In [None]:
# Export one checkpoint's two model state dicts to model1_path / model2_path.
final_model_type = "best"  # which checkpoint to export: "last" or "best"

checkpoint_path = select_model_path(checkpoint_dir, final_model_type)
save_model(checkpoint_path, model1_path, model2_path)

Saving in models the best checkpoint
./checkpoints/best_checkpoint.pth
