In [1]:
EXPERIMENT_NAME = "augmented-baseline-teacher"

In [2]:
import boto3
import os
from dotenv import load_dotenv

load_dotenv()
access_key_id = os.getenv("ACCESS_KEY_ID")
secret_access_key = os.getenv("SECRET_ACCESS_KEY")
minio_url = "http://" + os.getenv("S3_API_ENDPOINT")


minio_client = boto3.client(
    "s3",
    aws_access_key_id=access_key_id,
    aws_secret_access_key=secret_access_key,
    endpoint_url=minio_url
)

minio_bucket = "training-preparation-zone"
manifest_name = "dataset_train_augmented.json"
local_file = "./dataset_train_augmented.json"

In [3]:
print(access_key_id)
print(secret_access_key)
print(minio_url)
print(minio_bucket)
print(manifest_name)
print(local_file)


minioadmin
minioadmin
http://localhost:9000
training-preparation-zone
dataset_train_augmented.json
./dataset_train_augmented.json


In [4]:
def download_manifest_from_minio(bucket_name, object_name, local_path):
    try:
        minio_client.download_file(bucket_name, object_name, local_path)
    except Exception as e:
        print(f"Error downloading {object_name} from bucket {bucket_name}: {e}")
    return local_path

downloaded_path = download_manifest_from_minio(minio_bucket, manifest_name, local_file)

In [5]:
import pandas as pd

def load_manifest(manifest_path):
    with open(manifest_path, 'r') as f:
        data = pd.read_json(f)
    
    print(f"Loaded {len(data)} entries from the manifest.")
    return data

df = load_manifest(downloaded_path)
print(df)

Loaded 1992 entries from the manifest.
                                                  image  \
0                               images/ISIC_0025899.png   
1       images/ISIC_0025899_brightness_1.12_flipped.png   
2     images/ISIC_0025899_rotated_-13_contrast_1.06.png   
3                               images/ISIC_0026803.png   
4     images/ISIC_0026803_rotated_-2_contrast_1.19_f...   
...                                                 ...   
1987  images/ISIC_0029694_brightness_1.06_contrast_0...   
1988      images/ISIC_0029694_contrast_1.15_flipped.png   
1989                            images/ISIC_0028103.png   
1990                  images/ISIC_0028103_augmented.png   
1991  images/ISIC_0028103_rotated_2_contrast_0.89_fl...   

                               text     score  
0     texts/actinic_keratosis_0.txt  9934.281  
1     texts/actinic_keratosis_0.txt  9934.281  
2     texts/actinic_keratosis_0.txt  9934.281  
3     texts/actinic_keratosis_1.txt  9933.505  
4     texts/

## Hyperparameters

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torch.optim import AdamW
from tqdm import tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "openai/clip-vit-large-patch14" # Baseline CLIP model
BATCH_SIZE = 8
LEARNING_RATE = 5e-6
EPOCHS = 3

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0
Is CUDA available: False


## Data retrieval

The inputs variable is defined as it is because the model needs all of those parameters:

- Truncation=True means that if we provide more than 77 tokens (the usual maximum) it truncates the data

- Padding=max_length means that we add zeros to fill the max_length. We need to provide the same length for all the data (specially in text).



In [8]:
import io

class SkinLesionDataset(Dataset):
    def __init__(self, dataframe, processor, minio_client, bucket_name):
        self.df = dataframe
        self.processor = processor
        self.minio_client = minio_client
        self.bucket_name = bucket_name

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_key = self.df.iloc[idx]['image']
        txt_key = self.df.iloc[idx]['text']

        img_response = self.minio_client.get_object(Bucket=self.bucket_name, Key=img_key)
        img_bytes = img_response['Body'].read()
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        txt_response = self.minio_client.get_object(Bucket=self.bucket_name, Key=txt_key)
        description = txt_response['Body'].read().decode('utf-8').strip()

        inputs = self.processor(
            text=[description], 
            images=image, 
            return_tensors="pt", 
            padding="max_length", 
            truncation=True
        )
        
        return {k: v.squeeze(0) for k, v in inputs.items()}


## Initialization

Here we train the smaller clip model. We load it from the SkinLesionDataset class we created and the particularity is that we use AdamW. The AdamW is a widely used optimitzer for training Transformers. While the loss function tells the model where it needs to go, the optimitzer decides how fast it goes.

In [None]:
model = CLIPModel.from_pretrained(MODEL_ID).to(DEVICE)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
dataset = SkinLesionDataset(df, processor, minio_client, minio_bucket)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.1)

  [2m2026-01-02T23:33:29.355421Z[0m [33m WARN[0m  [33mReqwest(reqwest::Error { kind: Request, url: "https://transfer.xethub.hf.co/xorbs/default/0690af2dd565ebca8a4ba9275970509e204dbaa610eb192d62aeb1056904508b?X-Xet-Signed-Range=bytes%3D0-49584236&X-Xet-Session-Id=01KE0GR8E9224PKHVENT3X38R8&Expires=1767400215&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly90cmFuc2Zlci54ZXRodWIuaGYuY28veG9yYnMvZGVmYXVsdC8wNjkwYWYyZGQ1NjVlYmNhOGE0YmE5Mjc1OTcwNTA5ZTIwNGRiYWE2MTBlYjE5MmQ2MmFlYjEwNTY5MDQ1MDhiP1gtWGV0LVNpZ25lZC1SYW5nZT1ieXRlcyUzRDAtNDk1ODQyMzYmWC1YZXQtU2Vzc2lvbi1JZD0wMUtFMEdSOEU5MjI0UEtIVkVOVDNYMzhSOCIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NzQwMDIxNX19fV19&Signature=MEUCIQDE04ItUnCQNwHwWpzi~tCoeRZRs8~57nBHVhWy2x~6cgIgPENeAWzfxhonqO32I~b4r1hvCqQek7jcY~mJN~uGNjU_&Key-Pair-Id=K3TGA0E1JYVXF7", source: hyper_util::client::legacy::Error(Connect, ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: nodename nor ser

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 92c67af5-fb1c-4925-a4a3-b8bfc888a2d2)')' thrown while requesting HEAD https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json
Retrying in 1s [Retry 1/5].


  [2m2026-01-02T23:33:36.846331Z[0m [33m WARN[0m  [33mReqwest(reqwest::Error { kind: Request, url: "https://transfer.xethub.hf.co/xorbs/default/4690799c77191352e1df12c044fcbd0beadf1853276e9f218a9f25cf53f15a85?X-Xet-Signed-Range=bytes%3D0-44567244&X-Xet-Session-Id=01KE0GR8E9224PKHVENT3X38R8&Expires=1767400216&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly90cmFuc2Zlci54ZXRodWIuaGYuY28veG9yYnMvZGVmYXVsdC80NjkwNzk5Yzc3MTkxMzUyZTFkZjEyYzA0NGZjYmQwYmVhZGYxODUzMjc2ZTlmMjE4YTlmMjVjZjUzZjE1YTg1P1gtWGV0LVNpZ25lZC1SYW5nZT1ieXRlcyUzRDAtNDQ1NjcyNDQmWC1YZXQtU2Vzc2lvbi1JZD0wMUtFMEdSOEU5MjI0UEtIVkVOVDNYMzhSOCIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NzQwMDIxNn19fV19&Signature=MEUCIQDoxuxHecshirG3lHn71OzN7ceXT7eHZIzHTU74R~GIhwIgeetWddnKT~q6~lwZ-J4~~cx~udaKWeaqXj9DBGS~ggk_&Key-Pair-Id=K3TGA0E1JYVXF7", source: hyper_util::client::legacy::Error(Connect, ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: nodename nor ser

'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /openai/clip-vit-large-patch14/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x17a80e8c0>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 06cb91f4-0ac4-4684-8b00-9684a742ee5e)')' thrown while requesting HEAD https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json
Retrying in 2s [Retry 2/5].


  [2m2026-01-02T23:33:38.382719Z[0m [33m WARN[0m  [33mReqwest(reqwest::Error { kind: Request, url: "https://transfer.xethub.hf.co/xorbs/default/7e86fc9a6eb0fa83c71c7daeeab37343ad7cb4711b821320810b33e1cce83a51?X-Xet-Signed-Range=bytes%3D0-45071040&X-Xet-Session-Id=01KE0GR8E9224PKHVENT3X38R8&Expires=1767400216&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly90cmFuc2Zlci54ZXRodWIuaGYuY28veG9yYnMvZGVmYXVsdC83ZTg2ZmM5YTZlYjBmYTgzYzcxYzdkYWVlYWIzNzM0M2FkN2NiNDcxMWI4MjEzMjA4MTBiMzNlMWNjZTgzYTUxP1gtWGV0LVNpZ25lZC1SYW5nZT1ieXRlcyUzRDAtNDUwNzEwNDAmWC1YZXQtU2Vzc2lvbi1JZD0wMUtFMEdSOEU5MjI0UEtIVkVOVDNYMzhSOCIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NzQwMDIxNn19fV19&Signature=MEUCIH3094PFRGRHlq0umS9b5P~CgMVJMubzoashaWU7nMOSAiEAsqJNN-62qEDk6uinJdof~zJWP980u8QNkV~zyURmN1I_&Key-Pair-Id=K3TGA0E1JYVXF7", source: hyper_util::client::legacy::Error(Connect, ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: nodename nor ser

'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /openai/clip-vit-large-patch14/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x314334d00>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 7c909d65-2827-4b03-9357-80f4e66aca71)')' thrown while requesting HEAD https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json
Retrying in 4s [Retry 3/5].
'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /openai/clip-vit-large-patch14/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x314334340>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 137d3003-bf5c-4239-a3fc-e5d0c1399717)')' thrown while requesting HEAD https://huggingface.co

  [2m2026-01-02T23:33:43.440348Z[0m [33m WARN[0m  [33mReqwest(reqwest::Error { kind: Request, url: "https://transfer.xethub.hf.co/xorbs/default/c331e7eefc7660a633c577bc26c25e0c0105843d46f521e81cb6cd20583608a4?X-Xet-Signed-Range=bytes%3D0-44745431&X-Xet-Session-Id=01KE0GR8E9224PKHVENT3X38R8&Expires=1767400216&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly90cmFuc2Zlci54ZXRodWIuaGYuY28veG9yYnMvZGVmYXVsdC9jMzMxZTdlZWZjNzY2MGE2MzNjNTc3YmMyNmMyNWUwYzAxMDU4NDNkNDZmNTIxZTgxY2I2Y2QyMDU4MzYwOGE0P1gtWGV0LVNpZ25lZC1SYW5nZT1ieXRlcyUzRDAtNDQ3NDU0MzEmWC1YZXQtU2Vzc2lvbi1JZD0wMUtFMEdSOEU5MjI0UEtIVkVOVDNYMzhSOCIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NzQwMDIxNn19fV19&Signature=MEUCIDTWdggNSIhEZxzDTxBFQ2TWWvQSaiyO9mlLVFaCW1VqAiEAk29dXdzzTMYY4ECJaXzSXmXac~p4ufz-lfS~NSUUjbU_&Key-Pair-Id=K3TGA0E1JYVXF7", source: hyper_util::client::legacy::Error(Connect, ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: nodename nor ser

'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /openai/clip-vit-large-patch14/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x3143341c0>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 0a53ab54-f235-4b40-ad07-0f54804ab4ec)')' thrown while requesting HEAD https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json
Retrying in 8s [Retry 5/5].


  [2m2026-01-02T23:33:56.989401Z[0m [33m WARN[0m  [33mReqwest(reqwest::Error { kind: Request, url: "https://transfer.xethub.hf.co/xorbs/default/c5b02ee2caa8a656ac73ce3212ba2dc6f4cbbedc22fb47e15af1bddaaba151f8?X-Xet-Signed-Range=bytes%3D0-44605794&X-Xet-Session-Id=01KE0GR8E9224PKHVENT3X38R8&Expires=1767400215&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly90cmFuc2Zlci54ZXRodWIuaGYuY28veG9yYnMvZGVmYXVsdC9jNWIwMmVlMmNhYThhNjU2YWM3M2NlMzIxMmJhMmRjNmY0Y2JiZWRjMjJmYjQ3ZTE1YWYxYmRkYWFiYTE1MWY4P1gtWGV0LVNpZ25lZC1SYW5nZT1ieXRlcyUzRDAtNDQ2MDU3OTQmWC1YZXQtU2Vzc2lvbi1JZD0wMUtFMEdSOEU5MjI0UEtIVkVOVDNYMzhSOCIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc2NzQwMDIxNX19fV19&Signature=MEQCIGQZVmzordCAFIVI91h7CBneDoTWeGVl2F3xdCsXhjfBAiAutHfitWdCmVlZ-zk-1Nsrox1R1Vbco5LLjA9KmzR7NA__&Key-Pair-Id=K3TGA0E1JYVXF7", source: hyper_util::client::legacy::Error(Connect, ConnectError("dns error", Custom { kind: Uncategorized, error: "failed to lookup address information: nodename nor ser

'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /openai/clip-vit-large-patch14/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x314334370>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: 720b7a28-165d-4660-b13f-a49bad61aab8)')' thrown while requesting HEAD https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/config.json
'(MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /openai/clip-vit-large-patch14/resolve/main/config.json (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x17a80eb30>: Failed to resolve \'huggingface.co\' ([Errno 8] nodename nor servname provided, or not known)"))'), '(Request ID: ed1737d9-f76a-486a-85da-49468c1c524e)')' thrown while requesting HEAD https://huggingface.co/openai/clip-vit-large-patch

In [9]:
def get_trainable_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    all_params = sum(p.numel() for p in model.parameters())

    print(f"Trainable parameters: {trainable_params}")
    print(f"All parameters: {all_params}")
    print(f"Percentage of trainable parameters: {100 * trainable_params / all_params:.2f}%")

    return trainable_params, all_params

trainable, total = get_trainable_parameters(model)

Trainable parameters: 427616513
All parameters: 427616513
Percentage of trainable parameters: 100.00%


## Model training

Here we train the model using the hyperparameters and all the information provided in the previous cells.

In [10]:
def get_peak_vram(device):
    if device == 'cuda':
        return torch.cuda.max_memory_allocated(device) / (1024 ** 3)  # Convert to GB
    else:
        return 0

In [11]:
loss_history = []
model.train()

if DEVICE == 'cuda':
    torch.cuda.reset_peak_memory_stats(DEVICE)

for epoch in range(EPOCHS):
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}")
    epoch_loss = 0
    
    for batch in pbar:
        optimizer.zero_grad()
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        
        outputs = model(
            input_ids=batch['input_ids'],
            pixel_values=batch['pixel_values'],
            attention_mask=batch['attention_mask'],
            return_loss=True
        )
        
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        pbar.set_postfix({"loss": loss.item()})
    
    loss_history.append(epoch_loss / len(dataloader))

peak_mem = get_peak_vram(DEVICE)
print(f"Peak VRAM usage during training: {peak_mem:.2f} GB")

Epoch 1: 100%|██████████| 52/52 [12:37<00:00, 14.57s/it, loss=1.29] 
Epoch 2: 100%|██████████| 52/52 [12:32<00:00, 14.47s/it, loss=0.395]
Epoch 3: 100%|██████████| 52/52 [12:32<00:00, 14.47s/it, loss=0.0211]

Peak VRAM usage during training: 9.15 GB





In [12]:
import time
import numpy as np

def calculate_inference_latency(model, dataloader, device, num_samples=50):
    model.eval()
    latencies = []
    
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= num_samples: break
            
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            
            start_time = time.time()
            _ = model.get_image_features(pixel_values=batch['pixel_values'])
            end_time = time.time()
            
            latencies.append((end_time - start_time) * 1000)
            
    avg_latency = np.mean(latencies)
    print(f"Average Inference Latency: {avg_latency:.2f} ms")
    return avg_latency
latency = calculate_inference_latency(model, dataloader, DEVICE)

Average Inference Latency: 2109.88 ms


In [13]:
import torch
import torch.nn.functional as F
import numpy as np
from bert_score import score as bert_score_func
from sklearn.metrics import recall_score, f1_score, confusion_matrix

def extract_class_from_path(path):
    return "_".join(path.split("/")[-1].split("_")[:-3])

@torch.no_grad()
def get_comprehensive_metrics(model, dataloader, device):
    model.eval()
    all_image_embeds = []
    all_text_embeds = []
    all_ground_truth_texts = []

    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        img_emb = model.get_image_features(pixel_values=batch['pixel_values'])
        txt_emb = model.get_text_features(input_ids=batch['input_ids'], 
                                        attention_mask=batch['attention_mask'])
        
        all_image_embeds.append(F.normalize(img_emb, dim=-1))
        all_text_embeds.append(F.normalize(txt_emb, dim=-1))

    image_embeds = torch.cat(all_image_embeds)
    text_embeds = torch.cat(all_text_embeds)

    # Perspective 1: Text-to-Image Retrieval
    sim_matrix = text_embeds @ image_embeds.T
    
    num_queries = sim_matrix.size(0)
    ranks = []
    
    for i in range(num_queries):
        sorted_indices = torch.argsort(sim_matrix[i], descending=True)
        rank = (sorted_indices == i).nonzero(as_tuple=True)[0].item() + 1
        ranks.append(rank)
    
    ranks = np.array(ranks)

    # Perspective 2: Safety (Clinical Classification)
    # We pass an image, retrieve the best text and check if classes match.
    sim_matrix_i2t = image_embeds @ text_embeds.T
    
    # Map all text files in the dataset to their classes
    all_text_paths = df['text'].tolist()
    text_classes = np.array([extract_class_from_path(p) for p in all_text_paths])
    image_classes = np.array([extract_class_from_path(p) for p in df['text'].tolist()])

    top_text_indices = torch.argmax(sim_matrix_i2t, dim=-1).cpu().numpy()
    predicted_classes = text_classes[top_text_indices]

    sensitivity = recall_score(image_classes, predicted_classes, average='macro')
    f1 = f1_score(image_classes, predicted_classes, average='macro')

    cm = confusion_matrix(image_classes, predicted_classes)
    fp = cm.sum(axis=0) - np.diag(cm)
    fn = cm.sum(axis=1) - np.diag(cm)
    tp = np.diag(cm)
    tn = cm.sum() - (fp + fn + tp)
    specificity = np.mean(tn / (tn + fp + 1e-10))

    def get_text_content(path, client, bucket):
        response = client.get_object(Bucket=bucket, Key=path)
        return response['Body'].read().decode('utf-8').strip()

    sample_indices = np.random.choice(len(df), min(50, len(df)), replace=False)
    gt_texts = [get_text_content(df.iloc[i]['text'], minio_client, minio_bucket) for i in sample_indices]
    predicted_classes_texts = [get_text_content(df.iloc[top_text_indices[i]]['text'], minio_client, minio_bucket) for i in sample_indices]

    P, R, F1 = bert_score_func(predicted_classes_texts, gt_texts, lang='en', verbose=False)

    metrics = {
        # Perspective 1
        "Recall@1":  np.mean(ranks <= 1),
        "Recall@5":  np.mean(ranks <= 5),
        "Recall@10": np.mean(ranks <= 10),
        "Mean Rank": np.mean(ranks),
        "Median Rank": np.median(ranks),
        "MRR": np.mean(1.0 / ranks),
        "NDCG": np.mean([1.0 / np.log2(r + 1) for r in ranks]),
        # Perspective 2
        "Sensitivity": sensitivity,
        "Specificity": specificity,
        "F1 Score": f1,
        # Perspective 3
        "BertScore Precision": P.mean().item(),
        "BertScore Recall": R.mean().item(),
        "BERTScore F1": F1.mean().item(),
        "Trainable Parameters": trainable,
        "Total Parameters": total,
        "Inference Latency (ms)": latency,
        "Peak VRAM Usage (GB)": peak_mem
    }
    
    return metrics

eval_results = get_comprehensive_metrics(model, dataloader, DEVICE)
print("Evaluation Results:", eval_results)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Evaluation Results: {'Recall@1': np.float64(0.15903614457831325), 'Recall@5': np.float64(0.5036144578313253), 'Recall@10': np.float64(0.653012048192771), 'Mean Rank': np.float64(10.026506024096385), 'Median Rank': np.float64(5.0), 'MRR': np.float64(0.3208132351130705), 'NDCG': np.float64(0.46559744979785117), 'Sensitivity': 0.3849298997210845, 'Specificity': np.float64(0.8328873753800096), 'F1 Score': 0.34356208965893487, 'BertScore Precision': 0.8850935101509094, 'BertScore Recall': 0.8794289231300354, 'BERTScore F1': 0.8820855021476746, 'Trainable Parameters': 427616513, 'Total Parameters': 427616513, 'Inference Latency (ms)': np.float64(2109.8761320114136), 'Peak VRAM Usage (GB)': 9.151876449584961}


In [14]:
import json
import datetime
import matplotlib.pyplot as plt
os.makedirs('../results', exist_ok=True)

final_experiment_data = {
    "metadata": {
        "model_name": MODEL_ID,
        "device_used": DEVICE,
        "hyperparameters": {
            "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "epochs": EPOCHS
        }
    },
    "metrics": eval_results,
    "loss_history": loss_history
}

json_filename = f"../results/{EXPERIMENT_NAME}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
png_filename = f"../results/{EXPERIMENT_NAME}_loss_curve_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"

with open(json_filename, 'w') as f:
    json.dump(final_experiment_data, f, indent=4)

plt.figure(figsize=(10, 5))
plt.plot(loss_history, marker='o', linestyle='-', color='#2ca02c', label='Training Loss')
plt.title("Skin Cancer Model: Fine-Tuning Learning Curve")
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.grid(True, alpha=0.3)
plt.legend()
plt.savefig(png_filename, bbox_inches='tight')
plt.close()

In [15]:
import json
import datetime
import io

RESULTS_BUCKET = "visualization-zone"

def setup_results_storage(client, bucket_name):
    try:
        client.head_bucket(Bucket=bucket_name)
        print(f"Bucket '{bucket_name}' already exists.")
    except:
        print(f"Creating bucket '{bucket_name}'...")
        client.create_bucket(Bucket=bucket_name)

def upload_experiment_assets(client, bucket_name, results_dir, json_file, png_file):
    try:
        client.head_bucket(Bucket=bucket_name)
    except:
        client.create_bucket(Bucket=bucket_name)
        print(f"Created bucket: {bucket_name}")

    assets = [json_file, png_file]
    
    for asset_name in assets:
        local_path = os.path.join(results_dir, asset_name)
        
        if os.path.exists(local_path):
            # We store them in a folder named after the run_id for the viz page
            object_key = f"{EXPERIMENT_NAME}/{asset_name}"
            try:
                client.upload_file(local_path, bucket_name, object_key)
                print(f"Successfully uploaded {asset_name} to {object_key}")
            except Exception as e:
                print(f"Failed to upload {asset_name}: {e}")
        else:
            print(f"Warning: Asset not found at {local_path}")


json_file = os.path.basename(json_filename)
png_file = os.path.basename(png_filename)

setup_results_storage(minio_client, RESULTS_BUCKET)
upload_experiment_assets(minio_client, RESULTS_BUCKET, "../results", json_file, png_file)

Bucket 'visualization-zone' already exists.
Successfully uploaded baseline_20260102_140258.json to baseline/baseline_20260102_140258.json
Successfully uploaded baseline_loss_curve_20260102_140258.png to baseline/baseline_loss_curve_20260102_140258.png
