In [20]:
import boto3
import os
from dotenv import load_dotenv

load_dotenv()

# Connection settings are taken from the environment (after load_dotenv() above).
s3_endpoint = os.getenv("S3_API_ENDPOINT", "localhost:9000")
access_key_id = os.getenv("ACCESS_KEY_ID")
secret_access_key = os.getenv("SECRET_ACCESS_KEY")

# If the endpoint starts with "minio:" (the Docker service name), swap it for
# localhost so that the notebook also works when it is run outside Docker.
if s3_endpoint.startswith("minio:"):
    s3_endpoint = s3_endpoint.replace("minio:", "localhost:")

minio_url = f"http://{s3_endpoint}"
print(f"Conectando a MinIO en: {minio_url}")

# S3-compatible client pointed at the MinIO endpoint built above.
minio_client = boto3.client(
    "s3",
    endpoint_url=minio_url,
    aws_access_key_id=access_key_id,
    aws_secret_access_key=secret_access_key,
)

# Bucket that holds the manifests.
minio_bucket = "training-preparation-zone"

# Base manifest: object name in the bucket and local download target.
manifest_name = "dataset_train.json"
local_file = "./dataset_train.json"

# Augmented manifest: object name in the bucket and local download target.
manifest_name_augmented = "dataset_train_augmented.json"
local_file_augmented = "./dataset_train_augmented.json"


Conectando a MinIO en: http://localhost:9000


In [21]:
def download_manifest_from_minio(bucket_name, object_name, local_path):
    """Download one object from the bucket to a local file.

    Uses the module-level ``minio_client``.

    Args:
        bucket_name: Name of the bucket that holds the object.
        object_name: Key of the object to download.
        local_path: Path of the local file to write.

    Returns:
        ``local_path``, once the download has succeeded.

    Raises:
        Exception: Whatever ``minio_client.download_file`` raised. The error is
            printed first and then re-raised, so a failed download is never
            reported as a usable path (which would make later cells read a
            missing or stale local file).
    """
    try:
        minio_client.download_file(bucket_name, object_name, local_path)
    except Exception as e:
        print(f"Error downloading {object_name} from bucket {bucket_name}: {e}")
        # Propagate: returning local_path here would hide the failure.
        raise
    return local_path

# Fetch the base manifest first, then the augmented one.
downloaded_path = download_manifest_from_minio(
    minio_bucket, manifest_name, local_file
)
downloaded_path_augmented = download_manifest_from_minio(
    minio_bucket, manifest_name_augmented, local_file_augmented
)

# One path per line.
print(downloaded_path, downloaded_path_augmented, sep="\n")

./dataset_train.json
./dataset_train_augmented.json


In [22]:
import pandas as pd

def load_manifest(manifest_path):
    """Read a JSON manifest file into a pandas DataFrame.

    Args:
        manifest_path: Path of the local JSON file to read.

    Returns:
        The DataFrame produced by ``pd.read_json``.

    Side effects:
        Prints the number of loaded entries.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # text encoding, which is not UTF-8 on every system.
    with open(manifest_path, 'r', encoding='utf-8') as f:
        data = pd.read_json(f)

    print(f"Loaded {len(data)} entries from the manifest.")
    return data

# Load the base manifest and print it.
df = load_manifest(downloaded_path)
print(df)

# Load the augmented manifest and print it.
df_augmented = load_manifest(downloaded_path_augmented)
print(df_augmented)


Loaded 664 entries from the manifest.
                                                 image  \
0    images/ISIC_0025899_rotated_-13_contrast_1.06.png   
1    images/ISIC_0026803_rotated_-2_contrast_1.19_f...   
2    images/ISIC_0026803_rotated_-2_contrast_1.19_f...   
3    images/ISIC_0029577_brightness_0.87_contrast_0...   
4                              images/ISIC_0031981.png   
..                                                 ...   
659  images/ISIC_0027219_rotated_10_brightness_1.03...   
660                            images/ISIC_0031380.png   
661                            images/ISIC_0025899.png   
662  images/ISIC_0029694_brightness_1.06_contrast_0...   
663  images/ISIC_0028103_rotated_2_contrast_0.89_fl...   

                              text     score  
0    texts/actinic_keratosis_0.txt  9934.281  
1    texts/actinic_keratosis_1.txt  9933.505  
2    texts/actinic_keratosis_2.txt  9932.875  
3    texts/actinic_keratosis_3.txt  9935.860  
4    texts/actinic_keratosis_4

## Hyperparameters

In [23]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.optim import AdamW
from tqdm import tqdm

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Checkpoint name passed to CLIPModel / CLIPProcessor.from_pretrained.
MODEL_ID = "openai/clip-vit-base-patch32"

# Training settings.
EPOCHS = 3
BATCH_SIZE = 8
LEARNING_RATE = 5e-6

## Data retrieval

The inputs variable is defined as it is because the model needs all of those parameters:

- `truncation=True` means that a text longer than the maximum sequence length (77 tokens, the usual maximum) is cut down to that length.

- `padding="max_length"` means that shorter inputs are filled with padding tokens up to the maximum length. Every sample must have the same length so that samples can be stacked into a batch (especially the text).



In [24]:
import io

class SkinLesionDataset(Dataset):
    """Image/description pairs whose files are stored in an S3-compatible bucket.

    Each row of ``dataframe`` provides the object key of an image (column
    ``'image'``) and of a text file (column ``'text'``). Both objects are
    fetched on access and passed through ``processor``.
    """

    def __init__(self, dataframe, processor, minio_client, bucket_name):
        """Store the manifest, the processor and the bucket access details."""
        self.df, self.processor = dataframe, processor
        self.minio_client, self.bucket_name = minio_client, bucket_name

    def __len__(self):
        """Number of rows in the manifest."""
        return len(self.df)

    def _read_object(self, key):
        """Return the raw bytes stored under ``key`` in the configured bucket."""
        response = self.minio_client.get_object(Bucket=self.bucket_name, Key=key)
        return response['Body'].read()

    def __getitem__(self, idx):
        """Fetch and preprocess the pair at positional index ``idx``.

        Returns a dict with one entry per processor output, each with its
        leading dimension removed via ``squeeze(0)``.
        """
        row = self.df.iloc[idx]
        img_key, txt_key = row['image'], row['text']

        # The image is fetched first, then the description.
        image = Image.open(io.BytesIO(self._read_object(img_key))).convert("RGB")
        description = self._read_object(txt_key).decode('utf-8').strip()

        encoded = self.processor(
            text=[description],
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        )

        # The processor was given a single pair, so drop the leading dimension.
        return {name: tensor.squeeze(0) for name, tensor in encoded.items()}


## Initialization

Here we initialise the smaller CLIP model from its pretrained checkpoint, wrap the manifest in the SkinLesionDataset class we created, and set up the optimizer. The particularity is that we use AdamW, a widely used optimizer for training Transformers. While the loss function tells the model where it needs to go, the optimizer decides how fast it goes.

In [25]:
# Pretrained CLIP weights, moved to the selected device.
model = CLIPModel.from_pretrained(MODEL_ID)
model.to(DEVICE)

# Matching processor for the same checkpoint.
processor = CLIPProcessor.from_pretrained(MODEL_ID)

# Training data: the base manifest, read through the MinIO client.
dataset = SkinLesionDataset(
    dataframe=df,
    processor=processor,
    minio_client=minio_client,
    bucket_name=minio_bucket,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# AdamW over all model parameters.
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE, weight_decay=0.1)



## Model training

Here we train the model using the hyperparameters and all the information provided in the previous cells.

In [37]:
# Mean training loss of each epoch, in order.
loss_history = []
model.train()

for epoch in range(EPOCHS):
    epoch_loss = 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")

    for batch in pbar:
        # Move every tensor of the batch to the training device.
        batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}

        optimizer.zero_grad()
        outputs = model(
            pixel_values=batch['pixel_values'],
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            return_loss=True,
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the bar.
        step_loss = loss.item()
        epoch_loss += step_loss
        pbar.set_postfix({"loss": step_loss})

    # Average over the number of batches in the loader.
    loss_history.append(epoch_loss / len(dataloader))

Epoch 1:   0%|          | 0/83 [00:00<?, ?it/s]


NoSuchKey: An error occurred (NoSuchKey) when calling the GetObject operation: The specified key does not exist.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from bert_score import score as bert_score_func
from sklearn.metrics import recall_score, f1_score, confusion_matrix

def extract_class_from_path(path):
    """Derive the class label from a text-file key such as ``texts/<class>_<n>.txt``.

    The directory part and the file extension are dropped, then a trailing
    ``_<digits>`` index is removed. Example:
    ``"texts/actinic_keratosis_0.txt"`` -> ``"actinic_keratosis"``.

    Dropping a fixed number of ``_``-separated tokens (as before) returned an
    empty string for names like the one above, because the class name itself
    contains underscores and only one trailing token is an index.

    Args:
        path: Object key or file path, using ``/`` as separator.

    Returns:
        The class label; the full stem when there is no trailing numeric index.
    """
    stem = path.rsplit("/", 1)[-1].rsplit(".", 1)[0]
    label, separator, index = stem.rpartition("_")
    if separator and index.isdigit():
        return label
    return stem

def _retrieval_ranks(sim_matrix):
    """Return the 1-based rank of candidate ``i`` within row ``i`` of a square matrix."""
    order = torch.argsort(sim_matrix, dim=1, descending=True)
    targets = torch.arange(sim_matrix.size(0), device=sim_matrix.device).unsqueeze(1)
    # Each row of `order` is a permutation, so exactly one column matches the target.
    positions = (order == targets).nonzero(as_tuple=True)[1]
    return (positions + 1).cpu().numpy()


def _macro_specificity(cm):
    """Return the mean over classes of TN / (TN + FP) for a confusion matrix."""
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (fp + fn + tp)
    # The small constant avoids a division by zero.
    return np.mean(tn / (tn + fp + 1e-10))


def _read_text_object(client, bucket, key):
    """Fetch ``key`` from ``bucket`` and return its stripped UTF-8 text."""
    response = client.get_object(Bucket=bucket, Key=key)
    return response['Body'].read().decode('utf-8').strip()


@torch.no_grad()
def get_comprehensive_metrics(model, dataloader, device, bert_sample_size=50, seed=42):
    """Compute retrieval, classification and BERTScore metrics for ``model``.

    The dataset behind ``dataloader`` is re-iterated WITHOUT shuffling, so
    embedding row ``i`` always corresponds to manifest row ``i``. Iterating a
    shuffled loader directly (as before) made the class and text lookups, which
    are indexed in manifest order, refer to the wrong samples.

    The manifest, client and bucket are taken from the dataset (``df``,
    ``minio_client``, ``bucket_name`` attributes) when it has them; otherwise
    the module-level ``df``, ``minio_client`` and ``minio_bucket`` are used.

    Args:
        model: Model exposing ``eval``, ``get_image_features`` and
            ``get_text_features``.
        dataloader: Loader whose ``dataset`` yields dicts containing
            ``pixel_values``, ``input_ids`` and ``attention_mask``.
        device: Device the batches are moved to.
        bert_sample_size: Maximum number of samples used for BERTScore.
        seed: Seed of the random generator that picks the BERTScore samples.

    Returns:
        dict mapping metric name to value:
        - Perspective 1 (text-to-image retrieval): Recall@1/5/10, Mean Rank,
          Median Rank, MRR, NDCG.
        - Perspective 2 (classification through the best-matching text):
          Sensitivity, Specificity, F1 Score.
        - Perspective 3: BERTScore F1.

    Raises:
        ValueError: If the manifest length differs from the number of
            embedded samples.
    """
    model.eval()

    dataset = dataloader.dataset
    manifest = dataset.df if hasattr(dataset, "df") else df
    client = dataset.minio_client if hasattr(dataset, "minio_client") else minio_client
    bucket = dataset.bucket_name if hasattr(dataset, "bucket_name") else minio_bucket

    # Fixed-order pass over the same dataset; the training loader may shuffle.
    ordered_loader = DataLoader(
        dataset,
        batch_size=dataloader.batch_size or 1,
        shuffle=False,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
    )

    image_chunks, text_chunks = [], []
    for batch in ordered_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        img_emb = model.get_image_features(pixel_values=batch['pixel_values'])
        txt_emb = model.get_text_features(input_ids=batch['input_ids'],
                                          attention_mask=batch['attention_mask'])

        # L2-normalise so that the dot products below are cosine similarities.
        image_chunks.append(F.normalize(img_emb, dim=-1))
        text_chunks.append(F.normalize(txt_emb, dim=-1))

    image_embeds = torch.cat(image_chunks)
    text_embeds = torch.cat(text_chunks)

    text_paths = manifest['text'].tolist()
    if len(text_paths) != image_embeds.size(0):
        raise ValueError(
            f"Manifest has {len(text_paths)} rows but {image_embeds.size(0)} samples were embedded."
        )

    # Perspective 1: Text-to-Image Retrieval
    ranks = _retrieval_ranks(text_embeds @ image_embeds.T)

    # Perspective 2: Safety (Clinical Classification)
    # We pass an image, retrieve the best text and check if classes match.
    sim_matrix_i2t = image_embeds @ text_embeds.T

    # Each manifest row pairs an image with its own text file, so the true
    # class of image i is the class encoded in text path i.
    text_classes = np.array([extract_class_from_path(p) for p in text_paths])
    image_classes = text_classes

    top_text_indices = torch.argmax(sim_matrix_i2t, dim=-1).cpu().numpy()
    predicted_classes = text_classes[top_text_indices]

    sensitivity = recall_score(image_classes, predicted_classes, average='macro')
    f1 = f1_score(image_classes, predicted_classes, average='macro')
    specificity = _macro_specificity(confusion_matrix(image_classes, predicted_classes))

    # Perspective 3: compare the retrieved text with the ground-truth text on a
    # reproducible random sample.
    n_items = len(text_paths)
    rng = np.random.default_rng(seed)
    sample_indices = rng.choice(n_items, size=min(bert_sample_size, n_items), replace=False)
    gt_texts = [_read_text_object(client, bucket, text_paths[i]) for i in sample_indices]
    predicted_classes_texts = [
        _read_text_object(client, bucket, text_paths[top_text_indices[i]])
        for i in sample_indices
    ]

    _, _, bert_f1 = bert_score_func(predicted_classes_texts, gt_texts, lang='en', verbose=False)

    metrics = {
        # Perspective 1
        "Recall@1":  np.mean(ranks <= 1),
        "Recall@5":  np.mean(ranks <= 5),
        "Recall@10": np.mean(ranks <= 10),
        "Mean Rank": np.mean(ranks),
        "Median Rank": np.median(ranks),
        "MRR": np.mean(1.0 / ranks),
        "NDCG": np.mean([1.0 / np.log2(r + 1) for r in ranks]),
        # Perspective 2
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "F1 Score": f1,
        # Perspective 3
        "BERTScore F1": bert_f1.mean().item()
    }

    return metrics

# Evaluate `model` on `dataloader` and print the resulting metrics dict.
eval_results = get_comprehensive_metrics(model, dataloader, DEVICE)
print("Evaluation Results:", eval_results)

Evaluation Results: {'Recall@1': np.float64(0.10843373493975904), 'Recall@5': np.float64(0.3614457831325301), 'Recall@10': np.float64(0.5542168674698795), 'Mean Rank': np.float64(16.236144578313255), 'Median Rank': np.float64(9.0), 'MRR': np.float64(0.24258346426257882), 'NDCG': np.float64(0.396844550944322)}
