## Trabajo Fin de Máster <br/> Diseño de una arquitectura multimodal para descripción textual de pares imagen-audio

## Script 6. Entrenamiento del modelo conjunto con inputs de imagen, texto y audio

En este notebook, usamos la base de datos que hemos definido en el Script 5 para entrenar un modelo que acepta imágenes, piezas de texto y audios como inputs. Este modelo pretende diferenciar las distintas personas que han participado en la creación de la misma.

### Paso 1. Montamos el almacenamiento

Damos permiso a Colab para acceder a mi unidad de Drive y nos situamos en la carpeta donde tenemos los scripts y la librería que hemos creado con las clases propias.

In [3]:
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

set_seed(0)

In [4]:
import os
os.chdir('..')
os.getcwd()

'/mnt/batch/tasks/shared/LS_root/mounts/clusters/tfm-cpu/code/Users/jose.puche/Scripts'

### Paso 2. Iniciamos sesión para registrar los resultados en wandb


In [1]:
import wandb
!wandb login 1b8abaacf33b7b5812267384768c22a1eef3c11e

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/azureuser/.netrc


### Paso 2. Importación de paquetes

Instalamos las librerías necesarias (entre ellas, necesitamos el modelo CLIP, que descargamos directamente desde github), e importamos otras necesarias.

También importamos el dataset y el modelo que hemos definido para nuestro problema, y que se encuentran en

In [10]:
import clip
import torch
import pandas as pd
import torch.nn as nn
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset, SubsetRandomSampler, DataLoader

from tqdm import tqdm

from tfm_lib.datasets import CustomDataset
from tfm_lib.modelos import AudioCLIP
from tfm_lib.EarlyStopping import EarlyStopping

In [6]:
# Función de pérdida
def loss_fn(logits, labels):
    """
    logits: Las salidas del modelo (predicciones) para cada clase.
    labels: Las etiquetas verdaderas (números enteros) para cada ejemplo.
    """
    criterion = nn.CrossEntropyLoss()  # Función de pérdida de entropía cruzada
    return criterion(logits, labels)

# Ejemplo de cómo usar la función de pérdida
logits = torch.tensor([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.3, 0.2, 0.5]])
labels = torch.tensor([0, 1, 2])

loss = loss_fn(logits, labels)
print("Pérdida:", loss.item())

Pérdida: 0.7991690635681152


### Paso 3. Definición de parámetros y configuración

In [7]:
folder_path = './../Final_Database'
num_epochs = 20
BATCH_SIZE = 16
data_augmentation = True
da = "_DA" if data_augmentation else ""
lr = 1e-4
output_dim = 2
selected_model = 'RN50'

model_parameters_file = f"./modelos/multimodal/FULL_{selected_model.replace('/','')}_{output_dim}pers_lr{f'{lr:.0e}'}_bs{BATCH_SIZE}_{num_epochs}ep{da}.pt"
print(model_parameters_file)

./modelos/multimodal/FULL_RN50_2pers_lr1e-04_bs16_20ep_DA.pt


In [8]:
# WandB – Initialize a new run
run_name = model_parameters_file.split("/")[-1].replace('.pt', '')
wandb.init(entity="josealbertoap", project='TFM', name = run_name, tags=["multimodal"])

# WandB – Config is a variable that holds and saves hyperparameters and inputs
config = wandb.config          # Initialize config
config.batch_size = BATCH_SIZE          # input batch size for training (default: 64)
config.test_batch_size = BATCH_SIZE    # input batch size for testing (default: 1000)
config.epochs = num_epochs             # number of epochs to train (default: 10)
config.lr = lr              # learning rate (default: 0.01)
config.momentum = 0          # SGD momentum (default: 0.5)
config.no_cuda = True         # disables CUDA training
config.seed = 0               # random seed (default: 42)
config.log_interval = 1     # how many batches to wait before logging training status
config.num_classes = output_dim

[34m[1mwandb[0m: Currently logged in as: [33mjosealbertoap[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Paso 4. Definición de modelo y base de datos

In [9]:
from torchvision.transforms import Resize, Compose, ColorJitter, RandomHorizontalFlip, \
                                   RandomResizedCrop, RandomRotation, Normalize, ToTensor

def train_test_dataloaders(folder_path, output_dim, model, data_augmentation=False, BATCH_SIZE=32, test_split=0.2):

    dataset = CustomDataset(database_path = folder_path, num_classes = output_dim, image_transform = model.preprocess, audio_transform = None)

    train_idx, test_idx = train_test_split(list(range(len(dataset))), test_size=test_split,
                                           stratify=dataset.database_info.classID, random_state=42)
    train_sampler = SubsetRandomSampler(train_idx)

    # test_subset = Subset(dataset, test_idx) # En caso de que quisiéramos un Dataset y no un Dataloader
    test_sampler = SubsetRandomSampler(test_idx)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=test_sampler)

    # En caso de tener data augmentation, cambiamos el dataset para el Dataloader de train
    if data_augmentation:

      augmentation = Compose([
            RandomHorizontalFlip(p=0.3),
            RandomRotation(degrees=(0, 45), fill=0),
            RandomResizedCrop(size=(224, 224), scale=(0.2, 1.0), ratio=(0.8, 1.2)),
            # ColorJitter(brightness=.3, contrast=.1, saturation=.1, hue=.1),
            ToTensor(),
            Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])

      dataset = CustomDataset(database_path = folder_path, num_classes = output_dim, image_transform = augmentation, audio_transform = None)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, sampler=train_sampler)

    return train_loader, test_loader, dataset.labelencoder.classes_

# Por si hay que meter la data augmentation para los audios
# aug_sgram = AudioUtil.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

In [42]:
# Descargamos el modelo pre-entrenado y procesador de CLIP
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

model = AudioCLIP(selected_model, device, output_dim).to(device)
for param in model.parameters():
    param.requires_grad = True

train_loader, test_loader, classes = train_test_dataloaders(pd.read_csv(f'{folder_path}/finalDB_train.csv'),
                                                            output_dim, model, data_augmentation, BATCH_SIZE, 0.2)

Device: cuda:0


In [43]:
train_loader, test_loader, classes = train_test_dataloaders('../Final_Database_mini', output_dim, model, data_augmentation, BATCH_SIZE, 0.1)
text_desc = torch.cat([clip.tokenize(f"a photo of {c}") for c in classes]).to(device)
print(classes)

['Genesis Reyes Arteaga' 'Jose Alberto Azorin Puche' 'Juan Cuesta Lopez'
 'Juanjo Bautista Ibanez' 'Maria Jose Morales Forte'
 'Noelia Sanchez Alonso']


### Paso 5. Entrenamiento del modelo

In [44]:
# Inicializa el optimizador

# from torch.optim.lr_scheduler import StepLR
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
# scheduler = StepLR(optimizer, step_size=5, gamma=0.05)
# loss_fn = nn.CrossEntropyLoss()

early_stopping = EarlyStopping(patience=10, verbose=True, path=model_parameters_file)

train_loss = {}
test_loss = {}
train_acc = {}
test_acc = {}

# wandb.watch(model, log="all")

for epoch in range(num_epochs):

    model.train()

    epoch_loss = 0.0
    total_correct = 0
    total_samples = 0

    train_steps = tqdm(train_loader, unit="batch")

    for images, audios, labels in train_steps:

        train_steps.set_description(f"Epoch [{epoch+1}/{num_epochs}]. Training")

        optimizer.zero_grad()
        text_desc = text_desc.to(device)
        audios = audios.to(device)
        images = images.to(device)
        labels = labels.to(device)

        output = model(images, text_desc, audios)

        # Cálculo de la accuracy
        predictions = output.argmax(dim=1, keepdim=True).squeeze()
        correct = (predictions == labels).sum().item()

        total_samples += labels.size(0)
        total_correct += correct

        # Cálculo de la función de pérdida y actualización del modelo
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()
        #scheduler.step()

        epoch_loss += loss.item()

        train_steps.set_postfix(mean_loss=epoch_loss/total_samples, mean_accuracy = total_correct / total_samples)

    train_loss[epoch+1] = epoch_loss / len(train_loader)
    train_acc[epoch+1] = total_correct / total_samples

    # Evaluación en el conjunto de prueba
    model.eval()  # Cambiamos al modo de evaluación
    epoch_loss = 0.0
    total_correct = 0
    total_samples = 0

    test_steps = tqdm(test_loader, unit="batch")

    with torch.no_grad():
        for images, audios, labels in test_steps:  # Itera sobre los datos de prueba

            test_steps.set_description(f"Epoch [{epoch+1}/{num_epochs}]. Evaluation")

            text_desc = text_desc.to(device)
            audios = audios.to(device)
            images = images.to(device)
            labels = labels.to(device)

            output = model(images, text_desc, audios)

            # Cálculo de la accuracy
            predictions = output.argmax(dim=1, keepdim=True).squeeze()
            correct = (predictions == labels).sum().item()

            total_samples += labels.size(0)
            total_correct += correct

            # Cálculo de la función de pérdida y actualización del modelo
            loss = loss_fn(output, labels)
            epoch_loss += loss.item()

            test_steps.set_postfix(mean_loss=epoch_loss/total_samples, mean_accuracy = total_correct / total_samples)

        test_loss[epoch+1] = epoch_loss / len(test_loader)
        test_acc[epoch+1] = total_correct / total_samples

        print(f'Epoch [{epoch+1}/{num_epochs}]:')
        print(f'- Training. Loss = {train_loss[epoch+1]}; Accuracy = {train_acc[epoch+1]}.')
        print(f'- Evaluation. Loss = {test_loss[epoch+1]}; Accuracy = {test_acc[epoch+1]}.')
        print()

        # Llamar a early_stopping con la pérdida de validación actual y el modelo
        early_stopping(test_loss[epoch+1], model)
        print('')

        # Si se alcanza el criterio de early stopping, romper el bucle
        if early_stopping.early_stop:
            print("Early stopping")
            break

# wandb.log({'Training loss': train_loss, 'Training accuracy': train_acc, 'Evaluation loss': test_loss, 'Evaluation accuracy': test_acc})

# Guardamos el modelo
# wandb.save(model_parameters_file)

Epoch [1/100]. Training: 100%|██████████| 35/35 [00:16<00:00,  2.13batch/s, mean_accuracy=0.379, mean_loss=0.1]
Epoch [1/100]. Evaluation: 100%|██████████| 4/4 [00:01<00:00,  2.66batch/s, mean_accuracy=0.774, mean_loss=0.0581]


Epoch [1/100]:
- Training. Loss = 1.5873861057417733; Accuracy = 0.37906137184115524.
- Evaluation. Loss = 0.9003170728683472; Accuracy = 0.7741935483870968.

Validation loss decreased (inf --> 0.900317).  Saving model ...



Epoch [2/100]. Training: 100%|██████████| 35/35 [00:17<00:00,  1.96batch/s, mean_accuracy=0.903, mean_loss=0.0411]
Epoch [2/100]. Evaluation: 100%|██████████| 4/4 [00:02<00:00,  1.92batch/s, mean_accuracy=1, mean_loss=0.0302]


Epoch [2/100]:
- Training. Loss = 0.6503684895379203; Accuracy = 0.9025270758122743.
- Evaluation. Loss = 0.46790148317813873; Accuracy = 1.0.

Validation loss decreased (0.900317 --> 0.467901).  Saving model ...



Epoch [3/100]. Training: 100%|██████████| 35/35 [00:17<00:00,  2.05batch/s, mean_accuracy=0.975, mean_loss=0.0251]
Epoch [3/100]. Evaluation: 100%|██████████| 4/4 [00:01<00:00,  2.44batch/s, mean_accuracy=1, mean_loss=0.0191]


Epoch [3/100]:
- Training. Loss = 0.3970865547657013; Accuracy = 0.9747292418772563.
- Evaluation. Loss = 0.2966003492474556; Accuracy = 1.0.

Validation loss decreased (0.467901 --> 0.296600).  Saving model ...



Epoch [4/100]. Training: 100%|██████████| 35/35 [00:17<00:00,  1.99batch/s, mean_accuracy=0.982, mean_loss=0.0157]
Epoch [4/100]. Evaluation: 100%|██████████| 4/4 [00:02<00:00,  1.92batch/s, mean_accuracy=1, mean_loss=0.0109]


Epoch [4/100]:
- Training. Loss = 0.24847270590918405; Accuracy = 0.9819494584837545.
- Evaluation. Loss = 0.16871321201324463; Accuracy = 1.0.

Validation loss decreased (0.296600 --> 0.168713).  Saving model ...



  0%|          | 0/35 [00:00<?, ?batch/s]


KeyboardInterrupt: 

In [18]:
wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [None]:
import matplotlib.pyplot as plt

plt.plot(*zip(*sorted(train_loss.items())))
plt.plot(*zip(*sorted(test_loss.items())))
plt.show()

### Pruebas

In [39]:
import clip
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from PIL import Image
import numpy as np
from IPython.display import Audio
import torchaudio
from tfm_lib.audio_processing import AudioUtil, AudioAugmentation
from tfm_lib.modelos import AudioCLIP

# Descarga el modelo pre-entrenado y procesador de CLIP
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AudioCLIP(selected_model, device, output_dim).to(device)

# model_parameters_file = '/content/drive/MyDrive/TFM/Proyecto/Scripts/modelos/FULL_RN50_6pers_lr1e-06_bs16_20epDA.h5'

model.load_state_dict(torch.load(model_parameters_file))
model.eval()

AudioCLIP(
  (clip_model): CLIP(
    (visual): VisionTransformer(
      (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (transformer): Transformer(
        (resblocks): Sequential(
          (0): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
            )
            (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (mlp): Sequential(
              (c_fc): Linear(in_features=768, out_features=3072, bias=True)
              (gelu): QuickGELU()
              (c_proj): Linear(in_features=3072, out_features=768, bias=True)
            )
            (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          )
          (1): ResidualAttentionBlock(
            (attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantiz

In [40]:
# read_image = Image.open('/content/drive/MyDrive/TFM/Proyecto/Final_Database_mini/image/Jose Alberto Azorin Puche/frame00001.jpg')
read_image = Image.open('/content/drive/MyDrive/TFM/Proyecto/IMG_0003.jpg')
aud = AudioUtil.open('/content/drive/MyDrive/TFM/Proyecto/Final_Database_mini/audio/Jose Alberto Azorin Puche/audio0000.ogg')
aud = AudioUtil.resample(aud, 48000)
aud = AudioUtil.rechannel(aud, 1)
aud = AudioAugmentation.pad_trunc(aud, 4)
sgram = AudioUtil.spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None).unsqueeze(0).to(device)

image = model.preprocess(read_image).unsqueeze(0).to(device)
people = ['Genesis Reyes Arteaga', 'Jose Alberto Azorin Puche', 'Juan Cuesta Lopez',
          'Juanjo Bautista Ibanez', 'Maria Jose Morales Forte', 'Noelia Sanchez Alonso']
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in people]).to(device)

with torch.no_grad():
  output = model(image, text, sgram)

print(output.softmax(dim=-1))

tensor([[0.0019, 0.6802, 0.0247, 0.0222, 0.2579, 0.0132]], device='cuda:0')


In [29]:
class CustomDataset(Dataset):
    def __init__(self, database_path='', num_classes=40, image_transform=None, audio_transform=None):

      # Atributos derivados de los parametros
      self.database_path = database_path
      self.image_transform = image_transform
      self.audio_transform = audio_transform
      self.num_classes = num_classes
      self.database_info = self.filter_classes(pd.read_csv(f'{database_path}/final_db.csv'), num_classes).iloc[::5]
      self.database_info.reset_index(drop=True, inplace=True)
      self.classes = list(self.database_info['classID'].unique())

      # Atributos relacionados con audio
      self.duration = 4
      self.sr = 48000
      self.channel = 1
      self.shift_pct = 0.7

      # Codificación de las clases en valores numéricos
      le = preprocessing.LabelEncoder()
      self.labelencoder = le.fit(self.database_info["classID"])

    # --------------------------------
    # Numero de elementos del dataset
    # --------------------------------
    def __len__(self):
      return len(self.database_info)

    # --------------------------------
    # Selección del elemento i-esimo
    # --------------------------------
    def __getitem__(self, idx):
      # Definición de los paths donde leer imagen y audio del i-esimo dato
      image_file = self.database_info.loc[idx, 'image_path']
      audio_file = self.database_info.loc[idx, 'audio_path']

      # Obtenemos la etiqueta para el i-esimo dato
      db_df = self.database_info.copy()
      db_df['classID'] = self.labelencoder.transform(db_df['classID'])
      class_id = db_df.loc[idx, 'classID']

      # Obtención del tensor correspondiente a la i-esima imagen
      read_image = Image.open(image_file)
      image = self.image_transform(read_image)

      # Obtención del espectrograma del audio correspondiente al i-esimo dato
      aud = AudioUtil.open(audio_file)
      aud = AudioUtil.resample(aud, self.sr)
      aud = AudioUtil.rechannel(aud, self.channel)
      aud = AudioAugmentation.pad_trunc(aud, self.duration)
      sgram = AudioUtil.spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None)

      return image, sgram, class_id

    # ----------------------------
    # Filter only the number of classes of interest
    # ----------------------------
    def filter_classes(self, df, num_classes):
      # Obtener todas las clases únicas
      unique_classes = df['classID'].unique()

      # Seleccionar las primeras num_classes clases
      selected_classes = list(unique_classes)[:num_classes]

      # Filtrar el DataFrame para incluir solo las clases seleccionadas
      filtered_df = df[df['classID'].isin(selected_classes)]

      return filtered_df