<a href="https://colab.research.google.com/github/gabrielcarcedo/SargazoClassification_ViT/blob/main/MeIA_Sargazo_Clasificaci%C3%B3n_ViT_19062025.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import pandas as pd
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from google.colab import drive
import torch
from transformers import ViTImageProcessor
import matplotlib.pyplot as plt
import random
import torch.nn as nn
from transformers import ViTForImageClassification
from torch.optim import AdamW
import time
import torch.nn.functional as F
import numpy as np
import cv2
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from tqdm import tqdm

In [None]:
# Mount Google Drive at /content/drive so the dataset folder becomes reachable
drive.mount(mountpoint='/content/drive')

# Revisar bien la ruta a la carpeta de Drive: la siguiente celda asume que el dataset está en `MyDrive/MeIA Sargazo Dataset/`.

In [None]:
cd drive/MyDrive/MeIA Sargazo Dataset/

In [None]:
# Labelled (augmented) training list and the test list
df_train = pd.read_csv('labels/labels_augmented.csv')
df_test = pd.read_csv('labels/test.csv')

images_list = os.listdir('resized_images')  # NOTE(review): not read anywhere below in this notebook

# Full relative paths of the images each split will load
images_train = [f'resized_images/{name}' for name in df_train['image_name'].tolist()]
images_test = [f'resized_images/{name}' for name in df_test['image_name'].tolist()]

# Placeholder label column (0.0 for every row) so FilenameMappedDataset can look up a
# "label" for each test image; presumably the test CSV has no ground truth — confirm.
df_test['label_num'] = np.zeros(len(df_test))

# [image_name, label_num] pairs consumed by FilenameMappedDataset
images_train_label = df_train[['image_name', 'label_num']].to_numpy().tolist()
images_test_label = df_test[['image_name', 'label_num']].to_numpy().tolist()

In [None]:
class FilenameMappedDataset(Dataset):
    """Image dataset whose label for each image is looked up by its file name.

    Args:
        image_paths: Paths of the images to load, in dataset order.
        filename_label_pairs: Iterable of ``(filename, label)`` pairs. Only the
            base name of each image path is used for the lookup, so two images
            with the same file name in different folders share one label.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_paths, filename_label_pairs, transform=None):
        self.image_paths = image_paths
        self.transform = transform

        # Map file name -> label, e.g. {'ID_0001.png': 0, 'ID_0002.jpg': 1, ...}
        self.label_map = {filename: label for filename, label in filename_label_pairs}

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image at position *idx*.

        Raises:
            KeyError: if the image's file name has no entry in ``label_map``.
        """
        img_path = self.image_paths[idx]
        filename = os.path.basename(img_path)
        try:
            label = self.label_map[filename]
        except KeyError:
            raise KeyError(f"No label for image {filename!r} (path {img_path!r})") from None

        # convert() returns an independent copy, so the file handle can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# The normalisation statistics come from the pretrained checkpoint's image processor,
# so inputs are scaled the same way as during the backbone's pre-training.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# Pipeline: resize -> float tensor in [0, 1] -> per-channel mean/std normalisation
transform_fn = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # 224x224 input size of the vit-base-patch16-224 model
        transforms.ToTensor(),
        transforms.Normalize(
            mean=processor.image_mean,
            std=processor.image_std,
        ),
    ]
)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the labelled images for validation (fixed seed for reproducibility).
# NOTE(review): the split is random and not stratified. If labels_augmented.csv holds
# augmented copies of the same source image, copies can land in both splits and
# inflate validation scores — consider splitting by source image instead.
(images_train, images_val,
 images_train_label, images_val_label) = train_test_split(
    images_train,
    images_train_label,
    test_size=0.2,
    random_state=250,
)

In [None]:
# One filename-labelled dataset per split, all sharing the same preprocessing
train_dataset = FilenameMappedDataset(images_train, images_train_label, transform=transform_fn)
val_dataset = FilenameMappedDataset(images_val, images_val_label, transform=transform_fn)
test_dataset = FilenameMappedDataset(images_test, images_test_label, transform=transform_fn)

# Batches of 32; only the training loader shuffles, so val/test keep file order
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# --- Training configuration ---
num_epochs = 10        # upper bound on fine-tuning epochs
patience = 5           # epochs without validation-F1 improvement before early stopping
learning_rate = 2e-6   # AdamW learning rate; small because training resumes from a checkpoint
num_classes = 5        # class ids 0..4: nada, bajo, moderado, abundante, excesivo
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Per-epoch metric rows, appended by the training loop
train_stats = []
val_stats = []

# Checkpoints and metric CSVs go to a folder stamped with a fixed date.
# To stamp it with today's date instead, use time.strftime("%Y_%m_%d").
fecha = '2025_06_18'
checkpoint_dir = "ViT_checkpoints_" + fecha
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Build the ViT architecture with a num_classes-way head, then overwrite its weights
# with the previously saved checkpoint so the training cell resumes from it.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes,
)

# Starting checkpoint
best_model_path = os.path.join(checkpoint_dir, "ViT_best_model_Luis.pth")
# The file is a plain state_dict, so weights_only=True suffices: it avoids unpickling
# arbitrary objects and matches the default of recent PyTorch releases.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

model.to(device)

# Optimizer and loss function
optimizer = AdamW(model.parameters(), lr=learning_rate)
# NOTE(review): BCEWithLogitsLoss on one-hot targets treats the classes as independent
# binary problems, while each image has exactly one class and inference uses
# softmax/argmax. nn.CrossEntropyLoss is the usual choice for this setup; BCE is kept
# here, presumably to match how the checkpoint was trained — confirm.
criterion = nn.BCEWithLogitsLoss()

# Entrenamiento

In [None]:
def _confusion_counts(predicted, labels, n_classes):
    """Return an (n_classes, n_classes) count matrix: rows = true label, columns = prediction."""
    flat_index = labels.long() * n_classes + predicted.long()
    counts = torch.bincount(flat_index, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)


def _macro_metrics(cm):
    """Return macro-averaged one-vs-rest (sensitivity, specificity, F1) from a confusion matrix.

    Only classes that occur as a true label or as a prediction are averaged (like
    scikit-learn's macro F1 without an explicit label list). An empty matrix gives zeros.
    """
    cm = cm.double()
    tp = cm.diag()
    fn = cm.sum(dim=1) - tp  # true class c, predicted as another class
    fp = cm.sum(dim=0) - tp  # predicted as c, true class is another one
    tn = cm.sum() - tp - fn - fp
    present = (tp + fn + fp) > 0
    if not bool(present.any()):
        return 0.0, 0.0, 0.0
    eps = 1e-8
    sensitivity = (tp / (tp + fn + eps))[present].mean().item()
    specificity = (tn / (tn + fp + eps))[present].mean().item()
    f1 = (2 * tp / (2 * tp + fp + fn + eps))[present].mean().item()
    return sensitivity, specificity, f1


def _run_epoch(loader, train):
    """Run one pass of the global `model` over `loader`.

    With train=True the weights are updated through `optimizer`; otherwise the pass
    runs in eval mode without gradients.
    Returns (mean_loss, accuracy, confusion_matrix_on_cpu, elapsed_seconds).
    """
    start = time.time()
    model.train(train)  # model.train(False) is the same as model.eval()
    loss_sum, n_seen = 0.0, 0
    cm = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)

    with torch.set_grad_enabled(train):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            # criterion (BCEWithLogitsLoss in the setup cell) needs a float one-hot target
            labels_one_hot = F.one_hot(labels, num_classes).float()

            outputs = model(images).logits
            loss = criterion(outputs, labels_one_hot)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            cm += _confusion_counts(predicted, labels, num_classes)
            n_seen += labels.size(0)

    n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
    accuracy = cm.diag().sum().item() / n_seen
    return loss_sum / n_seen, accuracy, cm.cpu(), time.time() - start


best_val_acc, patience_counter, best_val_f1 = 0.0, 0, 0.0
best_model_file = os.path.join(checkpoint_dir, "ViT_best_model_19062025.pth")

for epoch in range(num_epochs):
    # ----- Training -----
    train_loss, train_acc, train_cm, train_time = _run_epoch(train_loader, train=True)
    train_sensitivity, train_specificity, _ = _macro_metrics(train_cm)

    # ----- Validation -----
    val_loss, val_acc, val_cm, val_time = _run_epoch(val_loader, train=False)
    val_sensitivity, val_specificity, val_f1 = _macro_metrics(val_cm)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f}, "
          f"Val Loss={val_loss:.4f}, Acc={val_acc:.4f}, F1={val_f1:.4f}")

    # Record metrics before the early-stopping check so the last epoch is not lost
    train_stats.append([epoch+1, train_loss, train_acc, train_sensitivity, train_specificity, train_time])
    val_stats.append([epoch+1, val_loss, val_acc, val_sensitivity, val_specificity, val_time, val_f1])

    # Save whenever the validation macro-F1 improves (the model-selection criterion)
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_file)
        print(f"Mejor modelo guardado en epoch {epoch+1} con val F1 = {val_f1:.4f} (val acc = {val_acc:.4f})")
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping en epoch {epoch+1}; mejor val F1 = {best_val_f1:.4f}")
            break

# Save metrics
train_df = pd.DataFrame(train_stats, columns=['Epoch', 'Train_Loss', 'Train_Acc', 'Train_Sensitivity', 'Train_Specificity', 'Train_Time'])
val_df = pd.DataFrame(val_stats, columns=['Epoch', 'Val_Loss', 'Val_Acc', 'Val_Sensitivity', 'Val_Specificity', 'Val_Time', 'Val_f1'])

train_df.to_csv(os.path.join(checkpoint_dir, 'ViT_train_stats_Luis.csv'), index=False)
val_df.to_csv(os.path.join(checkpoint_dir, 'ViT_val_stats_Luis.csv'), index=False)

print(f"Entrenamiento finalizado. Modelo óptimo guardado como {best_model_file}")

# Validación

In [None]:
fecha = '2025_06_18'
checkpoint_dir = f"ViT_checkpoints_{fecha}"

# Checkpoint to evaluate.
# NOTE(review): this is the same file the setup cell loads *before* training, not the
# "ViT_best_model_19062025.pth" written by the training cell — confirm which is intended.
best_model_path = os.path.join(checkpoint_dir, "ViT_best_model_Luis.pth")

# Rebuild the architecture and load the saved weights. The file is a plain state_dict,
# so weights_only=True suffices and avoids unpickling arbitrary objects.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes
)
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
model.to(device)
model.eval()

y_true = []
y_pred = []

# Predict every image of the validation split (val_loader from the data-loader cell)
with torch.no_grad():
    for images, labels in val_loader:
        outputs = model(images.to(device))
        probs = F.softmax(outputs.logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# List every class id explicitly so the report and the confusion matrix always use the
# full num_classes x num_classes layout, even when a class is absent from this split.
class_ids = list(range(num_classes))

print("\n--- Classification Report ---")
print(classification_report(y_true, y_pred, labels=class_ids, zero_division=0))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_ids)
disp.plot(cmap=plt.cm.Blues)
plt.title("Matriz de Confusión - Validación Set")
plt.show()

# Test

In [None]:
# Labelled training list and test list of the new data version
df_train = pd.read_csv('new_version/labels/labels.csv')
df_test = pd.read_csv('new_version/labels/test.csv')

# NOTE(review): this lists 'new_version/images', while the paths below point to
# 'new_version/resized_images'; images_list is not read later in this notebook.
images_list = os.listdir('new_version/images')

images_train = [f'new_version/resized_images/{name}' for name in df_train['image_name'].tolist()]

# images_test keeps the bare file names; images_test_paths feeds the dataset
images_test = df_test['image_name'].tolist()
images_test_paths = [f'new_version/resized_images/{name}' for name in images_test]

In [None]:
# Class name -> integer id used as the model target.
# NOTE(review): .map() yields NaN for any label not listed here — check the CSV's spelling.
label_dict = {'nada': 0, 'bajo': 1, 'moderado': 2, 'abundante': 3, 'excesivo': 4}
df_train['label_num'] = df_train['label'].map(label_dict)

# Placeholder 0.0 label for every test image so the dataset lookup succeeds
df_test['label_num'] = np.zeros(len(df_test))

# [image_name, label_num] pairs consumed by FilenameMappedDataset
images_train_label = df_train[['image_name', 'label_num']].to_numpy().tolist()
images_test_label = df_test[['image_name', 'label_num']].to_numpy().tolist()

train_dataset = FilenameMappedDataset(images_train, images_train_label, transform=transform_fn)
test_dataset = FilenameMappedDataset(images_test_paths, images_test_label, transform=transform_fn)

# Test order must stay fixed (shuffle=False): predictions are later paired with df_test by position
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
fecha = '2025_06_18'
checkpoint_dir = f"ViT_checkpoints_{fecha}"
num_classes = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Checkpoint used to predict the test set
best_model_path = os.path.join(checkpoint_dir, "ViT_best_model_Luis.pth")

# Rebuild the architecture and load the saved weights (a plain state_dict, so
# weights_only=True suffices and avoids unpickling arbitrary objects)
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=num_classes
)
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
model.to(device)
model.eval()

y_true = []
y_pred = []

# Predict the test set with test_loader from the previous cell (shuffle=False).
# Its labels are the 0.0 placeholders, so y_true carries no ground truth.
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        probs = F.softmax(outputs.logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Sanity check: the next cells pair y_pred with df_test['image_name'] by position, so
# the images actually fed to the model must match df_test row for row. (Comparing
# df_test with a list derived from the same column would always pass.)
if len(y_pred) != len(df_test):
    print('ERROR: number of predictions', len(y_pred), '!= number of test rows', len(df_test))
for img_path, img_name in zip(test_loader.dataset.image_paths, df_test['image_name'].tolist()):
    if os.path.basename(img_path) != img_name:
        print('ERROR', img_path)

In [None]:
# One row per test image: file name and predicted class id (row order = df_test order)
df_prediction = pd.DataFrame(
    {
        'image_name': df_test['image_name'].tolist(),
        'label': y_pred,
    }
)
df_prediction.head()

In [None]:
label_dict = {'nada': 0, 'bajo': 1, 'moderado': 2, 'abundante': 3, 'excesivo': 4}
# Inverse lookup (class id -> class name), derived so the two dicts cannot drift apart
label_dict_inv = {class_id: name for name, class_id in label_dict.items()}

# Convert only while the column still holds numeric ids. Re-running this cell would
# otherwise map the already-converted names to NaN and blank the exported labels.
if pd.api.types.is_numeric_dtype(df_prediction['label']):
    df_prediction['label'] = df_prediction['label'].map(label_dict_inv)
df_prediction.head()

In [None]:
# Write the predictions file (image_name, label) without the DataFrame index
csv_path = 'ViT_predictions_19062025.csv'
df_prediction.to_csv(path_or_buf=csv_path, index=False)