## Trabajo Fin de Máster <br/> Diseño de una arquitectura multimodal para descripción textual de pares imagen-audio

## Script 6. Entrenamiento del modelo conjunto con inputs de imagen, texto y audio

En este notebook, usamos la base de datos que hemos definido en el Script 5 para entrenar un modelo que acepta imágenes, piezas de texto y audios como inputs. Este modelo pretende diferenciar las distintas personas que han participado en la creación de la misma.

### Paso 1. Montamos el almacenamiento

Damos permiso a Colab para acceder a mi unidad de Drive y nos situamos en la carpeta donde tenemos los scripts y la librería que hemos creado con las clases propias.

In [1]:
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

set_seed(0)

In [2]:
import os
os.chdir('..')
os.getcwd()

'/mnt/batch/tasks/shared/LS_root/mounts/clusters/tfm-cpu/code/Users/jose.puche/Scripts'

### Paso 2. Iniciamos sesión para registrar los resultados en wandb


In [3]:
import wandb
!wandb login 1b8abaacf33b7b5812267384768c22a1eef3c11e

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/azureuser/.netrc


### Paso 2. Importación de paquetes

Instalamos las librerías necesarias (entre ellas, necesitamos el modelo CLIP, que descargamos directamente desde github), e importamos otras necesarias.

También importamos el dataset y el modelo que hemos definido para nuestro problema, y que se encuentran en

In [4]:
import clip
import torch
import pandas as pd
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset, SubsetRandomSampler, DataLoader

from tqdm import tqdm

from tfm_lib.audio_processing import AudioUtil, AudioAugmentation
from tfm_lib.datasets import CustomDataset
from tfm_lib.modelos import AudioCLIP
from tfm_lib.EarlyStopping import EarlyStopping



In [5]:
# Función de pérdida
def loss_fn(logits, labels):
    """
    logits: Las salidas del modelo (predicciones) para cada clase.
    labels: Las etiquetas verdaderas (números enteros) para cada ejemplo.
    """
    criterion = nn.CrossEntropyLoss()  # Función de pérdida de entropía cruzada
    return criterion(logits, labels)

# Ejemplo de cómo usar la función de pérdida
logits = torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.3, 0.2, 0.5]])
labels = torch.tensor([0, 1, 2])

loss = loss_fn(logits, labels)
print("Pérdida:", loss.item())

Pérdida: 0.7991690635681152


### Paso 3. Definición de parámetros y configuración

In [6]:
folder_path = './../Final_Database'
num_epochs = 20
BATCH_SIZE = 16
data_augmentation = True
da = "_DA" if data_augmentation else ""
lr = 1e-4
output_dim = 20
selected_model = 'RN50'

model_parameters_file = f"./modelos/multimodal/FULL_{selected_model.replace('/','')}_{output_dim}pers_lr{f'{lr:.0e}'}_bs{BATCH_SIZE}_{num_epochs}ep{da}.pt"
print(model_parameters_file)

./modelos/multimodal/FULL_RN50_20pers_lr1e-04_bs16_20ep_DA.pt


In [7]:
# WandB – Initialize a new run
run_name = model_parameters_file.split("/")[-1].replace('.pt', '')
wandb.init(entity="josealbertoap", project='TFM', name = run_name, tags=["multimodal"])

# WandB – Config is a variable that holds and saves hyperparameters and inputs
config = wandb.config          # Initialize config
config.batch_size = BATCH_SIZE          # input batch size for training (default: 64)
config.test_batch_size = BATCH_SIZE    # input batch size for testing (default: 1000)
config.epochs = num_epochs             # number of epochs to train (default: 10)
config.lr = lr              # learning rate (default: 0.01)
config.momentum = 0          # SGD momentum (default: 0.5)
config.no_cuda = True         # disables CUDA training
config.seed = 0               # random seed (default: 42)
config.log_interval = 1     # how many batches to wait before logging training status
config.num_classes = output_dim

[34m[1mwandb[0m: Currently logged in as: [33mjosealbertoap[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Paso 4. Definición de modelo y base de datos

In [8]:
from torchvision.transforms import Resize, Compose, ColorJitter, RandomHorizontalFlip, \
                                   RandomResizedCrop, RandomRotation, Normalize, ToTensor

def train_test_dataloaders(database_df, model, num_classes, data_augmentation=False, BATCH_SIZE=32, test_split=0.2):

    dataset = CustomDataset(database_df, num_classes, image_transform = model.preprocess)

    train_idx, test_idx = train_test_split(list(range(len(dataset))), test_size=test_split,
                                           stratify=dataset.database_info.classID, random_state=42)
    train_sampler = SubsetRandomSampler(train_idx)

    # test_subset = Subset(dataset, test_idx) # En caso de que quisiéramos un Dataset y no un Dataloader
    test_sampler = SubsetRandomSampler(test_idx)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)

    # En caso de tener data augmentation, cambiamos el dataset para el Dataloader de train
    if data_augmentation:

      augmentation = Compose([
            RandomHorizontalFlip(p=0.3),
            RandomRotation(degrees=(0, 45), fill=0),
            RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0), ratio=(0.8, 1.2)),
            # ColorJitter(brightness=.3, contrast=.1, saturation=.1, hue=.1),
            ToTensor(),
            Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])

      dataset = CustomDataset(database_df, num_classes, image_transform = augmentation)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)

    return train_loader, test_loader, dataset.labelencoder.classes_

# Por si hay que meter la data augmentation para los audios
# aug_sgram = AudioUtil.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

In [9]:
# Descargamos el modelo pre-entrenado y procesador de CLIP
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

model = AudioCLIP(selected_model, device, output_dim).to(device)
for param in model.parameters():
    param.requires_grad = True

train_loader, test_loader, classes = train_test_dataloaders(pd.read_csv(f'{folder_path}/finalDB_train.csv'),
                                                            model, output_dim, data_augmentation, BATCH_SIZE, 0.2)

Device: cpu


### Paso 5. Entrenamiento del modelo

In [10]:
# Inicializa el optimizador
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience = 3)

train_loss = {}
test_loss = {}
train_acc = {}
test_acc = {}

# Creamos la lista de descripciones para evaluar el modelo
print(f"People:{classes}\n")
eval_descriptions = torch.cat([clip.tokenize(f"a photo of {c}") for c in classes])

early_stopping = EarlyStopping(patience=5, verbose=True, delta=0.01, path=model_parameters_file)

wandb.watch(model, log="all")

for epoch in range(num_epochs):

    model.train()

    epoch_loss = 0.0
    total_correct = 0
    total_samples = 0

    train_steps = tqdm(train_loader, unit="batch")

    for images, audios, labels in train_steps:

        train_steps.set_description(f"Epoch [{epoch+1}/{num_epochs}]. Training")

        optimizer.zero_grad()
        text_desc = eval_descriptions.to(device)
        audios = audios.to(device)
        images = images.to(device)
        labels = labels.to(device)

        output = model(images, text_desc, audios)

        # Cálculo de la accuracy
        predictions = output.argmax(dim=1, keepdim=True).squeeze()
        correct = (predictions == labels).sum().item()

        total_samples += labels.size(0)
        total_correct += correct

        # Cálculo de la función de pérdida y actualización del modelo
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_steps.set_postfix(mean_loss=epoch_loss/total_samples, mean_accuracy = total_correct / total_samples)

    train_loss[epoch+1] = epoch_loss / len(train_loader)
    train_acc[epoch+1] = total_correct / total_samples

    # Evaluación en el conjunto de prueba
    model.eval()  # Cambiamos al modo de evaluación
    epoch_loss = 0.0
    total_correct = 0
    total_samples = 0

    test_steps = tqdm(test_loader, unit="batch")

    with torch.no_grad():
        for images, audios, labels in test_steps:  # Itera sobre los datos de prueba

            test_steps.set_description(f"Epoch [{epoch+1}/{num_epochs}]. Validation")

            text_desc = eval_descriptions.to(device)
            audios = audios.to(device)
            images = images.to(device)
            labels = labels.to(device)

            output = model(images, text_desc, audios)

            # Cálculo de la accuracy
            predictions = output.argmax(dim=1, keepdim=True).squeeze()
            correct = (predictions == labels).sum().item()

            total_samples += labels.size(0)
            total_correct += correct

            # Cálculo de la función de pérdida y actualización del modelo
            loss = loss_fn(output, labels)
            epoch_loss += loss.item()

            test_steps.set_postfix(mean_loss=epoch_loss/total_samples, mean_accuracy = total_correct / total_samples)

        test_loss[epoch+1] = epoch_loss / len(test_loader)
        test_acc[epoch+1] = total_correct / total_samples

        print(f'Epoch [{epoch+1}/{num_epochs}]:')
        print(f'- Training. Loss = {train_loss[epoch+1]}; Accuracy = {train_acc[epoch+1]}.')
        print(f'- Validation. Loss = {test_loss[epoch+1]}; Accuracy = {test_acc[epoch+1]}.')
        print()

        wandb.log({
                        'Epoch': epoch+1,
                        'Training Loss': train_loss[epoch+1],
                        'Training Accuracy': train_acc[epoch+1],
                        'Evaluation Loss': test_loss[epoch+1],
                        'Evaluation Accuracy': test_acc[epoch+1],
                    })

        # Llamar a early_stopping con la pérdida de validación actual y el modelo
        early_stopping(test_loss[epoch+1], model)
        print('')

        # Si se alcanza el criterio de early stopping, romper el bucle
        if early_stopping.early_stop:
            print("Early stopping")
            break

        # Reducir el learning rate en caso de que no esté mejorando la pérdida
        scheduler.step(test_loss[epoch+1])

print({'train_acc': train_acc, 'train_loss': train_loss, 'val_acc': test_acc, 'val_loss': test_loss})

wandb.save(model_parameters_file)

People:['Alba Azorin Zafrilla' 'Alfonso Girona Palao' 'Alfonso Vidal Lopez'
 'Ana Azorin Puche' 'Ana Puche Palao' 'Angela Espinosa Martinez'
 'Clara Hidalgo Lopez' 'Cristina Carpena Ortiz' 'David Azorin Soriano'
 'Diego Molina Puche' 'Eva Jimenez Mariscal'
 'Francisco Jose Maldonado Montiel' 'Genesis Reyes Arteaga'
 'Irene Gutierrez Perez' 'Irene Molina Puche' 'Irene Ponte Ibanez'
 'Iria Alonso Alves' 'Javier Lopez Martinez' 'Jonathan Gonzalez Lopez'
 'Jose Alberto Azorin Puche']

Epoch [1/20]:
- Training. Loss = 2.8522090510680127; Accuracy = 0.147005444646098.
- Validation. Loss = 3.1212102266458364; Accuracy = 0.0893719806763285.

Validation loss decreased (inf --> 3.121210).  Saving model ...

Epoch [2/20]:
- Training. Loss = 2.3500846211726847; Accuracy = 0.2909860859044162.
- Validation. Loss = 4.037777277139517; Accuracy = 0.07246376811594203.

EarlyStopping counter: 1 out of 5

Epoch [3/20]:
- Training. Loss = 2.010368389578966; Accuracy = 0.3877797943133696.
- Validation. Loss

Epoch [1/20]. Training: 100%|██████████| 104/104 [18:40<00:00, 10.78s/batch, mean_accuracy=0.147, mean_loss=0.179]
Epoch [1/20]. Validation: 100%|██████████| 26/26 [03:25<00:00,  7.92s/batch, mean_accuracy=0.0894, mean_loss=0.196]
Epoch [2/20]. Training: 100%|██████████| 104/104 [18:01<00:00, 10.40s/batch, mean_accuracy=0.291, mean_loss=0.148]
Epoch [2/20]. Validation: 100%|██████████| 26/26 [03:19<00:00,  7.67s/batch, mean_accuracy=0.0725, mean_loss=0.254]
Epoch [3/20]. Training: 100%|██████████| 104/104 [18:22<00:00, 10.60s/batch, mean_accuracy=0.388, mean_loss=0.126]
Epoch [3/20]. Validation: 100%|██████████| 26/26 [03:18<00:00,  7.63s/batch, mean_accuracy=0.109, mean_loss=0.197]
Epoch [4/20]. Training: 100%|██████████| 104/104 [17:48<00:00, 10.27s/batch, mean_accuracy=0.564, mean_loss=0.0909]
Epoch [4/20]. Validation: 100%|██████████| 26/26 [03:26<00:00,  7.93s/batch, mean_accuracy=0.645, mean_loss=0.0752]
Epoch [5/20]. Training: 100%|██████████| 104/104 [17:34<00:00, 10.14s/batch,

### Evaluación del modelo entrenado

In [None]:
test_dataset = CustomDataset(pd.read_csv(f'{folder_path}/finalDB_test.csv'), 
                            output_dim, image_transform = model.preprocess)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=1048, shuffle=True)

In [None]:
# ----------------------------
# Inference
# ----------------------------
def inference (model, test_dl):
  correct_prediction = 0
  total_prediction = 0

  # Disable gradient updates
  with torch.no_grad():

    predictions = []
    label_list = []
    for data in test_dl:
      # Get the input features and target labels, and put them on the GPU
      images, audios, labels = data[0].to(device), data[1].to(device), data[2].to(device)
      texts = eval_descriptions.to(device)

      # Get predictions
      outputs = model(images, texts, audios)

      # Get the predicted class with the highest score
      _, prediction = torch.max(outputs,1)
      # Count of predictions that matched the target label
      correct_prediction += (prediction == labels).sum().item()
      total_prediction += prediction.shape[0]

      predictions.extend(prediction)
      label_list.extend(data[2])

  acc = correct_prediction/total_prediction
  print(f'Accuracy: {acc:.2f}, Total items: {total_prediction}')

  return predictions, label_list

# Run inference on trained model with the validation set
model.load_state_dict(torch.load(model_parameters_file, map_location=torch.device('cpu')))
result = inference(model, test_dl)

In [None]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, f1_score, confusion_matrix
import seaborn as sn
import numpy as np
import matplotlib.pyplot as plt
import re

def extraer_iniciales(name):
    name_words = name.split(' ')
    r = re.compile("^[A-Z][A-z]*")
    valid_words = list(filter(r.match, name_words))
    if len(valid_words) <=3:
        name = valid_words[0]
        valid_words.remove(valid_words[0])
    else:
        name = f'{valid_words[0]} {valid_words[1]}'
        valid_words.remove(valid_words[0])
        valid_words.remove(valid_words[1])
    surname = re.sub('(?<=[A-Z])[A-z]+', '.', ' '.join(valid_words))
    return f'{name} {surname}'

def font_scale(num_classes):
    if num_classes <= 10:
        return 1.0
    elif num_classes <= 20:
        return 0.75
    elif num_classes <= 30:
        return 0.65
    else:
        return 0.45

def plot_confusion_matrix(y_true, y_pred):
    cf_matrix = confusion_matrix(y_true, y_pred)
    people = list(map(extraer_iniciales, test_dataset.labelencoder.classes_))

    df_cm = pd.DataFrame((cf_matrix / np.sum(cf_matrix, axis=1)[:, None]).round(3), index=people, columns=people)
    
    plt.figure(figsize=(8, 6))  
    sn.set(font_scale = font_scale(df_cm.shape[0]))  
    heatmap = sn.heatmap(df_cm, annot=True, cbar=False, cmap='Purples', fmt='g', xticklabels=False)

    # Ajusta la rotación y alineación de los ticks de los ejes
    heatmap.set_yticklabels(heatmap.get_yticklabels(), rotation=0, ha='right')

    plt.tight_layout()  # Asegura que todo se ajuste bien en la figura
    plt.savefig(model_parameters_file.replace('/modelos/', '/results/').replace('.pt', '.png'))

    return plt.gcf()

def get_metrics(result):
    accuracy = accuracy_score(result[1], result[0])
    precision = precision_score(result[1], result[0], average='macro')
    recall = recall_score(result[1], result[0], average='macro')
    f1 = f1_score(result[1], result[0], average='macro')

    metrics = {
        'Test accuracy': accuracy,
        'Test precision': precision,
        'Test recall': recall,
        'F1-score': f1
    }

    print(metrics)

    metrics['Confusion Matrix'] = wandb.Image(plot_confusion_matrix(result[1],result[0]))
    metrics['Test metrics'] = wandb.Table(columns=["Metric name", "Value"], 
                                          data=[["Test accuracy", accuracy], ["Test precision", precision],
                                                ["Test recall", recall], ["Test F1-Score", f1]])

    return metrics

metrics = get_metrics(result)
wandb.log(metrics)

In [None]:
from PIL import Image
image_results = []

audio_file = './../Final_Database/audio/Jose Alberto Azorin Puche/audio0000.ogg'
aud = AudioUtil.open(audio_file)
aud = AudioUtil.resample(aud, 16000)
aud = AudioUtil.rechannel(aud, 1)
aud = AudioAugmentation.pad_trunc(aud, 4)
sgram_1 = AudioUtil.spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None).unsqueeze(0).to(device)

audio_file = './../Final_Database/audio/Jose Alberto Azorin Puche/audio_prueba.ogg'
aud = AudioUtil.open(audio_file)
aud = AudioUtil.resample(aud, 16000)
aud = AudioUtil.rechannel(aud, 1)
aud = AudioAugmentation.pad_trunc(aud, 4)
sgram_2 = AudioUtil.spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None).unsqueeze(0).to(device)

with torch.no_grad():

    prueba = False

    for sgram in [sgram_1, sgram_2]:
        audio_name = 'Prueba' if prueba else 'Original'

        for i in range(4):

            read_image = Image.open(f'./../Test_images/IMG_000{i}.jpg')
            image = model.preprocess(read_image).unsqueeze(0).to(device)

            output = model(image, eval_descriptions, sgram)
            probs = torch.round(output.softmax(dim=-1), decimals=4)
            pred_prob = torch.max(probs).item()
            pred_person = classes[torch.argmax(probs)]
            my_prob = probs.squeeze()[list(classes).index('Jose Alberto Azorin Puche')].item()

            image_results.append([f'Imagen {i+1}', audio_name, pred_person, pred_prob, my_prob])

        prueba = True
        
print(image_results)
wandb.log({"Test images results": wandb.Table(columns=["Imagen", "Audio", "Persona", "Probabilidad", "Prob (Joseal)"], data=image_results)})        


In [None]:
wandb.finish()