In [None]:
""" MODELO PRE ENTRENADO SwinViT Base
Dataset: Patch Camelyon, de tensorflow preparado para HuggingFace completo
Fuente del modelo: HuggingFace
Modelo: microsoft/swin-base-patch4-window7-224
Entrenamiento: 3 épocas
"""

In [None]:
# Mount Google Drive inside the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Root project folder on Drive; the dataset and model paths in this notebook are built from it.
base_folder = "/content/drive/MyDrive/00 VIU/10 TFM"


In [None]:
# librerías
!pip install -qqq datasets evaluate keras_cv
!pip install --upgrade transformers -qqq
!pip install -qqq tensorflow
!pip install -qqq tf-keras


In [None]:
import numpy as np
import transformers
import evaluate
from transformers import AutoModelForImageClassification, AutoImageProcessor,TrainingArguments, Trainer
from datasets import Dataset as HuggingFaceDataset, Features
from PIL import Image as PILImage
import torch

In [None]:
# Import Hugging Face credentials
from google.colab import userdata
import os
from huggingface_hub import login

# Read the token stored in the Colab secrets under the name 'HF_TOKEN'.
# NOTE(review): in the cells visible here the token is never passed to `login`
# (and `os` is not used) — confirm whether an explicit login step is missing.
hf_token = userdata.get('HF_TOKEN')

## Cargar el dataset Patch Camelyon de Hugging Face

In [None]:
# Load the training and validation splits from the dataset folder on Drive.
train_hf_ds = HuggingFaceDataset.load_from_disk(f"{base_folder}/Datasets/pcamelyon_hf/train")
val_hf_ds = HuggingFaceDataset.load_from_disk(f"{base_folder}/Datasets/pcamelyon_hf/val")

In [None]:
# Identifier of the pretrained Swin checkpoint passed to `from_pretrained`.
model_name = "microsoft/swin-base-patch4-window7-224"
# Used as both the per-device training and evaluation batch size.
batch_size = 32

In [None]:
# Sanity check: number of examples in the training split.
print(len(train_hf_ds))


### Preprocesar los datos

In [None]:
# Number of output classes for the classification head.
num_labels = 2

# Image processor that matches the checkpoint.
image_processor = AutoImageProcessor.from_pretrained(
    model_name,
    do_rescale=True,
    use_fast=True,
)

# Pretrained backbone with a `num_labels`-way head; `ignore_mismatched_sizes`
# lets the checkpoint load even though its original head has a different shape.
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

In [None]:
# Checkpoint and basic hyper-parameters (defined before they are used below).
model_path = "microsoft/swin-base-patch4-window7-224"
batch_size = 32
num_labels = 2

# Load the pretrained weights with a `num_labels`-way classification head.
# Without `num_labels` / `ignore_mismatched_sizes` the checkpoint's original head
# would be loaded instead, and the model would not match the 2-class labels.
model = AutoModelForImageClassification.from_pretrained(
    model_path,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)
image_processor = AutoImageProcessor.from_pretrained(model_path, do_rescale=True, use_fast=True)

### Métricas

In [None]:
# Accuracy metric object, loaded once and reused on every evaluation call.
metric = evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for an evaluation pass.

    Args:
        p: object exposing `predictions` (per-class scores; the arg-max is taken
            along axis 1) and `label_ids` (reference labels).

    Returns:
        dict with a single key, "eval_accuracy", holding the accuracy value.
    """
    accuracy_scores = metric.compute(
        predictions=np.argmax(p.predictions, axis=1),
        references=p.label_ids,
    )
    print(f"Metric computation result: {accuracy_scores}")
    return {"eval_accuracy": accuracy_scores["accuracy"]}

### Collator

In [None]:
def custom_image_collator(batch):
    """Turn a list of dataset examples into a model-ready batch.

    Args:
        batch: sequence of examples, each providing an "image" entry that
            supports `.convert("RGB")` (assumed to be a PIL image) and a
            "label" entry.

    Returns:
        dict with "pixel_values" (output of the module-level `image_processor`)
        and "labels" (a `torch.long` tensor of the example labels).
    """
    # Force three channels before handing the images to the processor.
    rgb_images = [sample["image"].convert("RGB") for sample in batch]
    class_ids = [sample["label"] for sample in batch]

    encoded = image_processor(images=rgb_images, return_tensors="pt")

    return {
        "pixel_values": encoded["pixel_values"],
        "labels": torch.tensor(class_ids, dtype=torch.long),
    }

### Parámetros de entrenamiento

In [None]:
# Training configuration
args = TrainingArguments(
    # Where checkpoints are written.
    output_dir=base_folder+"/Modelos entrenados/Swin v1 Base 3epocas",
    # Keep all dataset columns: the custom collator reads "image" and "label" itself.
    remove_unused_columns=False,
    # Optimisation settings.
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    # Evaluate and checkpoint once per epoch; keep the best model by accuracy.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_accuracy",
    # No external experiment-tracking integration.
    report_to="none",
)

In [None]:
# Assemble the Trainer from the model, training arguments, datasets,
# metric function and collator defined in the previous cells.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_hf_ds,
    eval_dataset=val_hf_ds,
    compute_metrics=compute_metrics,
    data_collator=custom_image_collator,
    # `tokenizer=` is deprecated in recent transformers releases (this notebook
    # installs transformers with --upgrade); `processing_class` is its replacement.
    processing_class=image_processor,
)

In [None]:
%%time
# Run the training loop; the object returned by `trainer.train()` is kept in `train_results`.
print("\n--- Iniciando entrenamiento ---")
train_results = trainer.train()

In [None]:
# Save the trainer's model to the same Drive folder that is used as `output_dir`.
trainer.save_model(base_folder+"/Modelos entrenados/Swin v1 Base 3epocas")

In [None]:
# Evaluate the model

# Load the test split from the dataset folder on Drive.
test_hf_ds = HuggingFaceDataset.load_from_disk(base_folder+"/Datasets/pcamelyon_hf/test")


In [None]:
%%time
# Evaluate on the test split; the returned metrics are kept in `test_results` and printed.
test_results = trainer.evaluate(test_hf_ds)

print("\n--- Resultados de la evaluación en el dataset de test ---")
print(test_results)
