## Dataset

In [None]:
# !pip install torch torchvision ftfy regex tqdm pillow
# !pip install git+https://github.com/openai/CLIP.git

from ddgs import DDGS
import os
import requests
import imagehash
from PIL import Image
import glob
import os
import os, shutil, glob
import torch
import clip
from PIL import Image
from tqdm import tqdm


def image_extension_for_url(url: str, default: str = ".jpg") -> str:
    """Return a lower-case image file extension derived from *url*.

    The query string is ignored. URLs with no extension, or with one that is not
    a common raster-image extension (e.g. ``.php``), get *default* instead, so
    the later pipeline steps (which only look at image extensions) still see
    the downloaded file.
    """
    ext = os.path.splitext(url.split("?")[0])[1].lower()
    if ext in (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"):
        return ext
    return default


def collect_url_and_download_images(specie: str, max_pages: int = 50, out_folder: str = 'dataset'):
    """Search for *specie* images with ``DDGS().images`` and download every unique hit.

    Two queries are issued (``specie`` and ``"dinosaur <specie>"``), each for
    pages ``0 .. max_pages - 1``. Image URLs are de-duplicated in first-seen
    order and saved as ``<out_folder>/<specie>/raw/<NNNN><ext>``.

    Args:
        specie: Species name used in the queries and as output sub-folder.
        max_pages: Number of result pages requested per query.
        out_folder: Root dataset folder.

    Returns:
        The ``raw`` folder the images were written to.
    """
    outdir = f"{out_folder}/{specie}/raw"
    os.makedirs(outdir, exist_ok=True)

    results = []
    for query in [specie, f"dinosaur {specie}"]:
        for page in range(max_pages):
            try:
                results.extend(DDGS().images(
                    query=query,
                    region="us-en",
                    safesearch="off",
                    max_results=1000,
                    page=page))
            except Exception as e:
                # A failing page (rate limit, no more results, ...) would otherwise
                # abort the whole species; keep what we have and try the next query.
                print(f"Search failed for '{query}' page {page} -> {e}")
                break

    # dict.fromkeys de-duplicates while keeping first-seen order; iterating a set
    # of strings gives a different order on each interpreter run (hash seed).
    unique = list(dict.fromkeys(hit["image"] for hit in results if hit.get("image")))
    for idx, url in enumerate(unique, start=1):
        try:
            r = requests.get(url, timeout=20, headers={"User-Agent": "Mozilla/5.0"})
            r.raise_for_status()
            filename = os.path.join(outdir, f"{idx:04d}{image_extension_for_url(url)}")
            with open(filename, "wb") as f:
                f.write(r.content)
            print(f"Downloaded {filename}")
        except Exception as e:
            # Best effort: a single bad URL must not stop the download loop.
            print(f"Failed {url} -> {e}")

    return outdir

def remove_duplicate_images(folder_path: str, similarity_threshold: int = 4):
    """
    Delete duplicate / near-duplicate images from *folder_path* in place.

    Uses perceptual hashing (pHash): an image is deleted when its hash differs
    in at most ``similarity_threshold`` bits from an image that was kept
    earlier. Files that cannot be opened as images are deleted too. Only
    top-level files with an image extension (case-insensitive) are
    considered. Files are processed in sorted order, so the first file of a
    duplicate group in that order is the one kept.

    Args:
        folder_path: Path to folder containing images.
        similarity_threshold: Maximum Hamming distance between two pHashes for
            the images to count as duplicates. Lower = stricter
            (0 = identical hash, default 4, 10+ = lenient).
    """
    image_exts = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp')
    image_files = sorted(
        f for f in glob.glob(os.path.join(folder_path, "*"))
        if f.lower().endswith(image_exts)
    )

    if not image_files:
        print("No images found in folder")
        return

    print(f"Checking {len(image_files)} images for duplicates...")

    kept_hashes = {}  # pHash -> path of the kept image that produced it
    duplicates = []
    unreadable = 0

    for img_path in image_files:
        name = os.path.basename(img_path)
        try:
            with Image.open(img_path) as img:
                rgb = img if img.mode == 'RGB' else img.convert('RGB')
                img_hash = imagehash.phash(rgb)
        except Exception as e:
            # Anything that cannot be decoded is treated as junk and removed.
            print(f"Error processing {name}: {e}")
            print(f"Deleting {name}")
            os.remove(img_path)
            unreadable += 1
            continue

        # Linear scan: similarity is a distance threshold, not an exact-key lookup.
        if any(abs(img_hash - h) <= similarity_threshold for h in kept_hashes):
            duplicates.append(img_path)
            print(f"Duplicate: {name}")
        else:
            kept_hashes[img_hash] = img_path

    for duplicate in duplicates:
        os.remove(duplicate)

    # Unreadable files were deleted as well, so they are not "remaining".
    remaining = len(image_files) - len(duplicates) - unreadable
    print(f"Removed {len(duplicates)} duplicates. {remaining} unique images remain.")
    
def two_way_margin(pos_logits, neg_logits):
    """Compare the best positive prompt with the best negative prompt.

    A softmax over the two max logits is taken.

    Returns:
        ``(max_pos, max_neg, p_pos, margin)``, where ``margin = p_pos - p_neg``
        lies in [-1, 1] (0 = tie, > 0 favours the positive prompts).
    """
    max_pos = pos_logits.max().item()
    max_neg = neg_logits.max().item()
    probs = torch.softmax(torch.tensor([max_pos, max_neg]), dim=0)
    p_pos = probs[0].item()
    p_neg = probs[1].item()
    return max_pos, max_neg, p_pos, p_pos - p_neg


def print_prompt_logits(title, prompts, logits):
    """Print one logit per prompt, starring the highest-scoring prompt."""
    print(f"  [{title}]")
    best = logits.argmax().item()
    for i, (s, logit) in enumerate(zip(prompts, logits.squeeze(0).tolist())):
        mark = "★" if i == best else " "
        print(f"    {mark} {s:<50} {logit:.3f}")


@torch.no_grad()
def score_image(img_path, model, preprocess, device,
                TXT_IS_DINO, TXT_NOT_DINO, TXT_REAL, TXT_NONREAL,
                is_dino_prompts, not_dino_prompts, realistic_prompts, non_realistic_prompts,
                verbose=False):
    """Score one image with CLIP on two binary questions.

    The normalized image embedding is compared with four groups of text
    embeddings: dinosaur vs. not-dinosaur, and realistic vs. non-realistic.
    For each question, the best-matching (max-logit) prompt of each side is
    taken and a two-way softmax is applied.

    Args:
        img_path: Path of the image file.
        model: CLIP model (``encode_image`` and ``logit_scale`` are used).
        preprocess: CLIP image preprocessing transform.
        device: Device the image tensor is moved to.
        TXT_IS_DINO, TXT_NOT_DINO, TXT_REAL, TXT_NONREAL: Text embeddings, one
            row per prompt (assumed L2-normalized, as ``filter_folder``
            produces them).
        is_dino_prompts, not_dino_prompts, realistic_prompts,
        non_realistic_prompts: Prompt strings; only used for verbose output.
        verbose: Print per-prompt logits and the derived probabilities.

    Returns:
        ``(is_dino_margin, realism_margin)``, each in [-1, 1]. Returns
        ``(None, None)`` when the image cannot be opened or preprocessed.
    """
    try:
        # `with` closes the file handle once the tensor has been built.
        with Image.open(img_path) as pil_img:
            image = preprocess(pil_img.convert("RGB")).unsqueeze(0).to(device)
    except Exception:
        return None, None

    img_feat = model.encode_image(image)
    img_feat /= img_feat.norm(dim=-1, keepdim=True)

    # CLIP temperature sharpens the cosine similarities; the same scale applies
    # to every prompt group.
    scale = model.logit_scale.exp()
    L_is = (img_feat @ TXT_IS_DINO.T) * scale      # [1, num_prompts]
    L_not = (img_feat @ TXT_NOT_DINO.T) * scale
    L_real = (img_feat @ TXT_REAL.T) * scale
    L_non = (img_feat @ TXT_NONREAL.T) * scale

    # 1) Dinosaur classification; 2) realism classification.
    max_is_dino, max_not_dino, p_is_dino, is_dino_margin = two_way_margin(L_is, L_not)
    max_real, max_non_real, p_real, realism_margin = two_way_margin(L_real, L_non)

    if verbose:
        print(f"\n👉 {os.path.basename(img_path)}")
        print_prompt_logits("DINO LOGITS", is_dino_prompts, L_is)
        print_prompt_logits("NOT_DINO LOGITS", not_dino_prompts, L_not)
        print_prompt_logits("REALISTIC LOGITS", realistic_prompts, L_real)
        print_prompt_logits("NON_REALISTIC LOGITS", non_realistic_prompts, L_non)
        print(f"  Max dino: {max_is_dino:.3f} vs Max not-dino: {max_not_dino:.3f}")
        print(f"  Max real: {max_real:.3f} vs Max non-real: {max_non_real:.3f}")
        print(f"  p_is_dino={p_is_dino:.3f}  p_real={p_real:.3f}")
        print(f"  is_dino_margin={is_dino_margin:.3f}  realism_margin={realism_margin:.3f}")

    return is_dino_margin, realism_margin

def list_image_files(folder, extensions=(".jpg", ".jpeg", ".png", ".webp", ".bmp", ".gif")):
    """Return the sorted paths of the files in *folder* (not recursive) whose
    extension, compared case-insensitively, is in *extensions*."""
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(extensions) and os.path.isfile(os.path.join(folder, name))
    )


def filter_folder(in_dir, out_good, out_bad, specie,
                  dino_thr=0.20, realism_thr=0.15, verbose=False):
    """Split the images of *in_dir* into *out_good* / *out_bad* using CLIP.

    Each image is scored by ``score_image`` on two prompt contrasts ("is it a
    dinosaur?" and "is it realistic?"). An image is copied to *out_good* when
    both margins reach their thresholds. Otherwise it goes to *out_bad*, which
    includes images that cannot be read. Source files are left in place.

    Args:
        in_dir: Folder with the candidate images (top level only).
        out_good: Destination for accepted images; created if missing.
        out_bad: Destination for rejected images; created if missing.
        specie: Species name inserted into the species-specific prompts.
        dino_thr: Minimum dinosaur margin (range [-1, 1]) to keep an image.
        realism_thr: Minimum realism margin (range [-1, 1]) to keep an image.
        verbose: Forwarded to ``score_image`` to print per-prompt logits.

    Returns:
        ``(kept, rejected)`` image counts.
    """
    os.makedirs(out_good, exist_ok=True)
    os.makedirs(out_bad, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    # ---------- PROMPTS ----------
    # 1) "Is it a dinosaur?"
    is_dino_prompts = [
        "a realistic illustration of a dinosaur",
        "a realistic toy dinosaur figure",
        "a photo of a dinosaur",
        "a paleoart illustration of a dinosaur",
        f"a realistic illustration of a {specie}",
        f"a realistic toy {specie} figure",
        f"a photo of a {specie}",
        f"a paleoart illustration of a {specie}",
    ]
    not_dino_prompts = [
        "a photo of a modern animal",
        "a person or human",
        "a landscape without animals",
        "a vehicle or building",
        "a toy",
        "a cloth",
        "an abstract image",
        "a geometric figure",
    ]

    # 2) Realistic vs. cartoon / plush toy / pixel art / line drawing
    realistic_prompts = [
        "a realistic illustration of a dinosaur",
        "a dinosaur fossil skeleton",
        "a realistic toy dinosaur figure",
        "a high-quality render of a dinosaur",

        f"a realistic illustration of a {specie}",
        f"a {specie} fossil skeleton",
        f"a realistic toy {specie} figure",
        f"a high-quality render of a {specie}",
    ]
    non_realistic_prompts = [
        "a cartoon dinosaur for kids",
        "a peluche toy dinosaur figure",
        "a pixel art dinosaur",
        "a simple line drawing of a dinosaur",
        "a non realistic drawing of a dinosaur"
    ]

    @torch.no_grad()
    def encode_text_batch(prompts):
        # L2-normalized CLIP text embeddings, one row per prompt.
        tokens = clip.tokenize(prompts).to(device)
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    TXT_IS_DINO = encode_text_batch(is_dino_prompts)
    TXT_NOT_DINO = encode_text_batch(not_dino_prompts)
    TXT_REAL = encode_text_batch(realistic_prompts)
    TXT_NONREAL = encode_text_batch(non_realistic_prompts)

    # Case-insensitive and sorted, so runs are reproducible across platforms.
    imgs = list_image_files(in_dir)
    kept = rejected = 0

    for p in tqdm(imgs, desc=f"Filtering {os.path.basename(in_dir)}"):
        res = score_image(p, model, preprocess, device,
                          TXT_IS_DINO, TXT_NOT_DINO, TXT_REAL, TXT_NONREAL,
                          is_dino_prompts, not_dino_prompts, realistic_prompts, non_realistic_prompts,
                          verbose=verbose)
        # score_image reports an unreadable image as (None, None), not as None;
        # comparing None with a threshold would raise a TypeError.
        readable = res is not None and None not in res
        is_good = readable and res[0] >= dino_thr and res[1] >= realism_thr
        if is_good:
            kept += 1
        else:
            rejected += 1
        shutil.copy(p, os.path.join(out_good if is_good else out_bad, os.path.basename(p)))

    return kept, rejected



In [None]:
# Species to collect; each gets its own dataset/<species>/raw folder with
# "clean" and "rejected" sub-folders produced by the CLIP filter.
species = [
    "Ankylosaurus",
    "Brachiosaurus",
    "Compsognathus",
    "Corythosaurus",
    "Dilophosaurus",
    "Dimorphodon",
    "Gallimimus",
    "Microceratus",
    "Pachycephalosaurus",
    "Parasaurolophus",
    "Spinosaurus",
    "Stegosaurus",
    "Triceratops",
    "Tyrannosaurus",
    "Velociraptor"
]

for specie in species:
    print(f"processing dinosaur {specie}")
    specie_raw_dir = collect_url_and_download_images(specie=specie)
    remove_duplicate_images(folder_path=specie_raw_dir)
    print(f"filtering dinosaur {specie}")
    kept, rej = filter_folder(
        specie_raw_dir,
        out_good=os.path.join(specie_raw_dir, "clean"),
        out_bad=os.path.join(specie_raw_dir, "rejected"),
        specie=specie,
        dino_thr=0.20,
        realism_thr=0.20,
    )
    print(f"Specie: {specie}. Kept {kept}, rejected {rej}")

In [None]:
# Copy each species' CLIP-accepted images (raw/clean) into dataset/to_phone/<species>.
for dinosaur in species:
    print(f'dino: {dinosaur}')
    folder = f'./dataset/{dinosaur}/raw/clean'
    out_folder = f'dataset/to_phone/{dinosaur}'
    # shutil.copy does not create missing directories.
    os.makedirs(out_folder, exist_ok=True)
    for file_name in tqdm(os.listdir(folder)):
        source = os.path.join(folder, file_name)
        if os.path.isfile(source):
            shutil.copy(source, os.path.join(out_folder, file_name))
    

### Dataset Split

In [None]:
import random
import os
from PIL import Image

curated_path = './dataset/hand_curated_datasets'
final_dataset = './dataset/dataset'
seed = 42
random.seed(seed)

# Hold out 15% of every species folder as the test split and re-save every
# image as PNG under <final_dataset>/<train|test>/<species>/.
# NOTE(review): os.listdir order is filesystem-dependent, so the seeded split
# is only reproducible on the same filesystem.
for species_sub_folder in os.listdir(curated_path):
    species_dir = os.path.join(curated_path, species_sub_folder)
    if not os.path.isdir(species_dir):
        continue  # stray files (e.g. .DS_Store) next to the class folders
    print(species_sub_folder)
    species_images = [image for image in os.listdir(species_dir) if not image.startswith('.')]
    # A set gives O(1) membership checks in the loop below.
    test_indices = set(random.sample(range(len(species_images)), int(0.15 * len(species_images))))
    for split in ('train', 'test'):
        os.makedirs(os.path.join(final_dataset, split, species_sub_folder), exist_ok=True)

    for idx, image in enumerate(species_images):
        split = 'test' if idx in test_indices else 'train'
        png_filename = os.path.splitext(image)[0] + '.png'
        dst_path = os.path.join(final_dataset, split, species_sub_folder, png_filename)
        try:
            # `with` releases the source file handle after saving.
            with Image.open(os.path.join(species_dir, image)) as img:
                if img.mode == 'CMYK':
                    img = img.convert('RGB')  # PNG cannot store CMYK
                img.save(dst_path, 'PNG')
        except Exception as e:
            print(f"Error processing {image}: {e}")
        

### Visualize Images

In [None]:
import random
import os
import shutil
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.models import ResNet18_Weights, resnet18
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Use ImageNet normalization that matches the pretrained weights
weights = ResNet18_Weights.DEFAULT
base_tfms = weights.transforms()
size = 256

def denormalize_image(image_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo mean/std normalization of a CHW (or 1xCHW) tensor.

    Returns:
        An HWC numpy array clipped to [0, 1], ready for ``plt.imshow``.
    """
    if image_tensor.dim() == 4:
        image_tensor = image_tensor.squeeze(0)
    # detach() lets tensors that require grad be converted to numpy as well.
    image_np = np.transpose(image_tensor.detach().cpu().numpy(), (1, 2, 0))
    image_np = image_np * np.asarray(std) + np.asarray(mean)
    return np.clip(image_np, 0, 1)


def visualize_image(image_tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Show a normalized PyTorch image tensor with matplotlib.

    Args:
    - image_tensor: tensor of shape [3, H, W] or [1, 3, H, W], normalized with
      *mean* / *std*.
    - mean, std: per-channel normalization to undo (ImageNet by default).
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(denormalize_image(image_tensor, mean, std))
    plt.axis('off')
    plt.show()

def letterbox_geometry(w, h, size=256, margin=0.8):
    """Compute resize target and padding that letterbox a w x h image.

    The longer side is scaled to ``size * margin``, keeping the aspect ratio,
    and the result is centred on a ``size`` x ``size`` canvas.

    Returns:
        ``(new_w, new_h, padding)``, where ``padding`` is
        ``[left, top, right, bottom]`` in the order ``F.pad`` expects.
    """
    scale = size / max(w, h)
    # max(1, ...) stops extremely thin images from collapsing to a zero-size side.
    new_w = max(1, int(round(w * scale * margin)))
    new_h = max(1, int(round(h * scale * margin)))
    pad_left = (size - new_w) // 2
    pad_top = (size - new_h) // 2
    return new_w, new_h, [pad_left, pad_top, size - new_w - pad_left, size - new_h - pad_top]


def letterbox_to_square(img, size=256, fill=0, margin=0.8):
    """Resize *img* (PIL image) to fit a square canvas and pad it with *fill*.

    With the default ``margin=0.8`` the content spans at most 80% of the
    canvas. Presumably this is so a later 224-px crop of the 256-px canvas
    does not cut into the image (TODO confirm).
    """
    w, h = img.size
    new_w, new_h, padding = letterbox_geometry(w, h, size=size, margin=margin)
    img = F.resize(img, (new_h, new_w), antialias=True)
    return F.pad(img, padding, fill=fill)

# Training pipeline: letterbox onto a `size` x `size` canvas, then light
# augmentation (random resized crop, horizontal flip, color jitter).
#
# NOTE(review): `base_tfms` is the ResNet18 weights' complete inference preset
# (per the torchvision docs: resize to 256, center-crop 224, rescale to [0, 1],
# ImageNet normalization), not only ToTensor + Normalize. Every 224-px crop
# below is therefore resized to 256 and center-cropped back to 224, which trims
# its border. Confirm this is intended; otherwise use
# transforms.ToTensor() + transforms.Normalize(mean, std) instead.
train_tfms = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),
    transforms.Lambda(lambda im: letterbox_to_square(im, size=size, fill=0)),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    base_tfms,  # weights preset: resize 256 -> center-crop 224 -> tensor -> normalize
])

# Evaluation pipeline: same letterboxing, deterministic center crop, no augmentation.
test_tfms = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),
    transforms.Lambda(lambda im: letterbox_to_square(im, size=size, fill=0)),
    transforms.CenterCrop(224),
    base_tfms,
])

import multiprocessing

root = "./dataset/dataset"

train_ds = datasets.ImageFolder(root=f"{root}/train", transform=train_tfms)
test_ds = datasets.ImageFolder(root=f"{root}/test", transform=test_tfms)

# The transforms above contain lambdas, which can only reach worker processes
# when they are forked. Under spawn/forkserver (e.g. Windows), num_workers > 0
# fails to pickle them, so load in the main process there.
num_workers = 4 if multiprocessing.get_start_method() == "fork" else 0
# Pinned host memory only helps for copies to a CUDA device.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)


In [None]:
# Show four random test images (test_tfms has no random augmentation) with
# their class names; the ImageNet normalization is undone for display.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

fig, axes = plt.subplots(1, 4, figsize=(12, 4))
for ax in axes:
    img, label = random.choice(test_ds)
    shown = (img * imagenet_std + imagenet_mean).clamp(0, 1)
    print(img.shape)
    ax.imshow(shown.permute(1, 2, 0).numpy())
    ax.set_title(test_ds.classes[label])
    ax.axis("off")

plt.tight_layout()
plt.show()

## Experiments

In [None]:
# Placeholder cell: only prints a fixed string.
message = 'ciao'
print(message)

In [None]:
import torch
from torch import nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)

# Fall back to the CPU so this cell also runs without a GPU
# (model.cuda() raises there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.train()

In [None]:
# List the name of every parameter tensor of the current model.
for param_name, _param in model.named_parameters():
    print(param_name)

In [None]:
import torch
from lora_pytorch import LoRA

# Wrap the ViT with LoRA adapters (rank 5), then replace its ImageNet head with
# a fresh classifier for our classes.
num_classes = 8  # NOTE(review): must equal the number of class folders used for training
in_features = model.heads.head.in_features
backbone_device = next(model.parameters()).device
lora_model = LoRA.from_module(model, rank=5)
# A new nn.Linear is created on the CPU; move it next to the backbone so a
# CUDA model does not hit a device mismatch in the forward pass.
lora_model.module.heads.head = nn.Linear(in_features, num_classes).to(backbone_device)
print(lora_model)

In [None]:
# Show every parameter of the LoRA-wrapped model together with its shape.
for param_name, param in lora_model.named_parameters():
    print(f"{param_name} {param.shape}")

In [None]:
from torchvision.models import ResNet50_Weights, resnet50
# ImageNet-pretrained ResNet-50, switched to inference mode (eval() returns the model).
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).eval()
model

In [None]:
import torch
from lora_pytorch import LoRA

# Wrap the current model with rank-5 LoRA adapters
# (see the lora_pytorch docs for which layer types get adapted).
lora_rank = 5
lora_model = LoRA.from_module(model, rank=lora_rank)
print(lora_model)

In [None]:
# ResNet34_Weights is not imported by any other cell of this notebook.
from torchvision.models import ResNet34_Weights

# Initialize the weight transforms (the preprocessing preset of the pretrained weights).
weights = ResNet34_Weights.DEFAULT
preprocess = weights.transforms()
preprocess

In [None]:
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """Backbone without its last child module, returning flat per-image features.

    ``model.children()`` minus the last child is wrapped in ``nn.Sequential``.
    For torchvision ResNets this drops ``fc`` and keeps the global average
    pool. The output is flattened to ``[batch, features]``.
    """

    def __init__(self, model):
        super().__init__()
        self.features = nn.Sequential(*list(model.children())[:-1])

    def forward(self, x):
        # flatten(1) instead of squeeze(): squeeze() would also drop the batch
        # dimension when the batch contains a single image.
        return self.features(x).flatten(1)

# Wrap the currently loaded `model`; eval() returns the module itself.
feature_extractor = FeatureExtractor(model).eval()
feature_extractor

In [None]:
# ImageNet-pretrained ResNet-18, switched to inference mode (eval() returns the model).
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights).eval()
model

In [None]:
from torchvision.models import EfficientNet_V2_S_Weights, efficientnet_v2_s

# ImageNet-pretrained EfficientNetV2-S, switched to inference mode (eval() returns the model).
weights = EfficientNet_V2_S_Weights.DEFAULT
model = efficientnet_v2_s(weights=weights).eval()
model

In [None]:
import torch.nn as nn

class FeatureExtractor(nn.Module):
    """Backbone without its last child module, returning flat per-image features.

    ``model.children()`` minus the last child is wrapped in ``nn.Sequential``,
    which removes the final classifier layer. The output is flattened to
    ``[batch, features]``.
    """

    def __init__(self, model):
        super().__init__()
        # Remove the final FC / classifier layer.
        self.features = nn.Sequential(*list(model.children())[:-1])

    def forward(self, x):
        # flatten(1) instead of squeeze(): squeeze() would also drop the batch
        # dimension when the batch contains a single image.
        return self.features(x).flatten(1)

# Use the feature extractor on the currently loaded `model`;
# eval() returns the module itself.
feature_extractor = FeatureExtractor(model).eval()
feature_extractor

In [None]:
# Display one random (augmented) training image as a batch of one.
sample, label = random.choice(train_ds)
img = sample.unsqueeze(0)
visualize_image(img)

In [None]:
# NOTE(review): `val_loader` is only created in a later cell of this notebook
# (the earlier loader cell defines `test_loader`); run that cell first.
batch = next(iter(val_loader))
x, y = batch

with torch.no_grad():
    # Feature width depends on the wrapped backbone, e.g. 512 for ResNet-18,
    # 2048 for ResNet-50, 1280 for EfficientNetV2-S.
    features = feature_extractor(x)  # typically [batch_size, feature_dim]

In [None]:
features.size()  # same torch.Size as `.shape`

In [None]:
from omegaconf import OmegaConf  # otherwise only imported in a later cell
from RexNet import RexNet

config = OmegaConf.load('config/config_rexnet.yaml')

# Load directly from the class: load_from_checkpoint builds the instance itself,
# so constructing a RexNet beforehand is unnecessary.
rexnet_ckpt_path = r'C:\Users\nicco\OneDrive\Documenti\JurassicClass\models\cross_validation\last_block_finetune\RexNet\fold_1\best-epoch=11-val_loss=0.5470-val_acc=0.8599.ckpt'
model = RexNet.load_from_checkpoint(
    rexnet_ckpt_path,
    config=config.model,  # arguments that are not saved in the checkpoint must be passed in
    strict=False,  # ignore missing/extra weights -- mismatches go unreported
)
model.eval()

In [None]:
# NOTE(review): `val_loader` is only created in a later cell of this notebook
# (the earlier loader cell defines `test_loader`); run that cell first.
batch = next(iter(val_loader))
x, y = batch

with torch.no_grad():
    # Feature width depends on the wrapped backbone, e.g. 512 for ResNet-18,
    # 2048 for ResNet-50, 1280 for EfficientNetV2-S.
    features = feature_extractor(x)  # typically [batch_size, feature_dim]

In [None]:
features.size()  # same torch.Size as `.shape`

In [None]:
with torch.no_grad():
    prediction = model(img).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()

# `weights.meta["categories"]` lists ImageNet classes and only fits a plain
# ImageNet backbone. For a model with a different number of outputs, use the
# dataset's ImageFolder class names instead.
categories = weights.meta["categories"]
if len(categories) != prediction.numel():
    categories = train_ds.classes
category_name = categories[class_id]
print(f"{category_name}: {100 * score}%")

#### Training Trial

In [None]:
def evaluate_model(model, dataloader, idx_to_class, device="cpu"):
    """Run *model* over *dataloader* and collect per-image predictions.

    Args:
        model: Classifier returning logits of shape ``[batch, num_classes]``.
        dataloader: Iterable of ``(x, y)`` batches with integer class labels.
        idx_to_class: Maps a class index to its class name.
        device: Device the model and the batches are moved to.

    Returns:
        dict with per-image lists ``scores`` (softmax probability of the
        predicted class), ``predicted_names``, ``real_names`` and ``correct``
        (1.0 / 0.0). It also holds ``overall_accuracy``: the fraction of
        correct predictions, or ``None`` if the dataloader yielded no samples.
    """
    model.eval()
    model.to(device)

    results = {
        "scores": [],
        "predicted_names": [],
        "real_names": [],
        "correct": [],
        "overall_accuracy": None,
    }

    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)

            probs = model(x).softmax(1)
            # max over classes yields the predicted class and its probability.
            score, predicted_class_id = probs.max(1)
            correct = (predicted_class_id == y).float()

            # Move to CPU once for the Python-side bookkeeping.
            pred_cpu = predicted_class_id.cpu()
            y_cpu = y.cpu()
            correct_cpu = correct.cpu()

            results["scores"].extend(score.cpu().tolist())
            results["predicted_names"].extend(idx_to_class[int(c)] for c in pred_cpu)
            results["real_names"].extend(idx_to_class[int(c)] for c in y_cpu)
            results["correct"].extend(correct_cpu.tolist())

            total_correct += int(correct_cpu.sum().item())
            total_samples += len(y_cpu)

    # Guard against an empty dataloader instead of dividing by zero.
    if total_samples:
        results["overall_accuracy"] = total_correct / total_samples

    return results


def run_train(model, train_loader, val_loader, config, out_fold, in_fold, lora):
    """Fit *model* on one inner CV fold and return the best checkpoint path.

    Checkpoints, TensorBoard logs and the resolved config are written under
    ``./models/cross_validation/<experiment_name>/<model name>/``.

    Args:
        model: LightningModule to train.
        train_loader, val_loader: Training / validation dataloaders.
        config: OmegaConf config; ``experiment_name`` and ``model.name``
            determine the output folders.
        out_fold, in_fold: Outer / inner fold indices, used in folder names.
        lora: Not used here; kept for interface compatibility.

    Returns:
        Path of the checkpoint with the lowest ``val_loss`` (``ModelCheckpoint.best_model_path``).
    """
    experiment_dir = f"./models/cross_validation/{config.experiment_name}/{config.model.name}"
    fold_dir = f"{experiment_dir}/fold_{out_fold}.{in_fold}"

    early_stopping_callback = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=8)

    checkpoint_callback = ModelCheckpoint(
        dirpath=fold_dir,
        filename="best-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}",  # loss and accuracy in the file name
        save_top_k=1,
        monitor="val_loss",
        mode="min"
    )

    tb_logger = TensorBoardLogger(
        save_dir=experiment_dir,
        name="tb_logs",
        version=f"fold_{out_fold}.{in_fold}",  # one TensorBoard run per fold
        default_hp_metric=False,
    )

    trainer = Trainer(
        # Same folder as the checkpoints. The original used model.model_name,
        # which presumably does not exist on a LoRA-wrapped model (TODO confirm).
        default_root_dir=fold_dir,
        logger=tb_logger,
        callbacks=[early_stopping_callback, checkpoint_callback],
        enable_checkpointing=True,
        log_every_n_steps=10,
        # NOTE(review): quick-run settings (25% of the batches, 4 epochs);
        # the intended value is presumably config.training.max_epochs.
        limit_train_batches=0.25,
        limit_val_batches=0.25,
        max_epochs=4,
    )

    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    # The checkpoint folder only exists if a checkpoint was written.
    os.makedirs(fold_dir, exist_ok=True)
    OmegaConf.save(config, f"{fold_dir}/config.yaml")

    return checkpoint_callback.best_model_path
    

def run_test(model_class, ckpt_path, test_loader, idx_to_class, config, device=None):
    """Restore *model_class* from *ckpt_path*, evaluate it and pickle the results.

    The dict returned by ``evaluate_model`` is saved as ``test_results.pkl``
    next to the checkpoint and also returned. *device* defaults to CUDA when
    available, otherwise CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    restored = model_class.load_from_checkpoint(ckpt_path, config=config, strict=False)
    results = evaluate_model(restored, test_loader, idx_to_class, device=device)

    results_file = os.path.join(os.path.dirname(ckpt_path), "test_results.pkl")
    with open(results_file, "wb") as fh:
        pickle.dump(results, fh)

    return results


def resolve_model_class(network):
    """Map a network name (case-insensitive) to its model class."""
    key = network.lower()
    if key == 'rexnet':
        return RexNet
    if key == 'efficentrex':
        return EfficentRex
    raise ValueError(f"Unknown network {network!r}; expected 'RexNet' or 'EfficentRex'")


def build_grid_config(base_cfg, classifier_lr, factor, experiment_name):
    """Return a deep copy of *base_cfg* set up for one grid-search point.

    The first entry of ``model.layers_to_finetune`` gets ``classifier_lr``.
    Every other entry gets ``classifier_lr / factor``. ``experiment_name`` is
    set, and the model name gets a ``_firstlr<lr>_fac<factor>`` suffix so each
    grid point writes to its own folder. *base_cfg* itself is not modified.
    """
    cfg = copy.deepcopy(base_cfg)
    layers = cfg["model"]["layers_to_finetune"]
    layer_names = list(layers.keys())
    layers[layer_names[0]]["lr"] = classifier_lr
    backbone_lr = classifier_lr / factor
    for layer_name in layer_names[1:]:
        layers[layer_name]["lr"] = backbone_lr
    cfg["experiment_name"] = experiment_name
    cfg["model"]["name"] = f"{cfg['model']['name']}_firstlr{classifier_lr:g}_fac{factor:g}"
    return cfg


def write_grid_summary_csv(acc_by_params, csv_path):
    """Write mean / std / count of the accuracies per ``(classifier_lr, factor)`` key."""
    os.makedirs(os.path.dirname(csv_path) or ".", exist_ok=True)
    with open(csv_path, mode="w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["classifier_lr", "factor", "mean_test_accuracy", "std_test_accuracy", "n_runs"])
        for (lr_cls, factor), acc_list in acc_by_params.items():
            writer.writerow([lr_cls, factor, float(np.mean(acc_list)), float(np.std(acc_list)), len(acc_list)])


def run_outer_loop(train_ds, val_ds, config, network='EfficentRex', n_out=1, n_in=1, lora=False, rank=5,
                   experiment_name="lora"):
    """Nested stratified 5-fold cross-validation with a small learning-rate grid.

    For every grid point of each inner fold, a fresh model is trained with
    ``run_train`` and its best checkpoint is evaluated on the outer fold's
    held-out part with ``run_test``. A mean/std summary per grid point is
    written to ``./models/cross_validation/<experiment_name>/grid_search_results.csv``.

    Args:
        train_ds, val_ds: The same samples (with ``targets`` and
            ``class_to_idx``), using training- and evaluation-time transforms.
        config: Only ``config.training.batch_size`` is read here; per-run
            configs are loaded from ``config/config_<network>.yaml``.
        network: 'RexNet' or 'EfficentRex' (case-insensitive).
        n_out, n_in: Index of the last outer / inner fold to run. Because the
            loops break on ``fold > n``, ``n + 1`` folds run.
            NOTE(review): use ``>=`` if these are meant as fold counts.
        lora: Wrap each model with LoRA adapters of the given *rank*.
        rank: LoRA rank used when *lora* is True.
        experiment_name: Output sub-folder under ``./models/cross_validation``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Loop-invariant setup: fail fast on an unknown network and load the base config once.
    model_class = resolve_model_class(network)
    base_cfg = OmegaConf.load(f'config/config_{network.lower()}.yaml')
    idx_to_class = {idx: class_ for class_, idx in train_ds.class_to_idx.items()}
    batch_size = config.training.batch_size
    param_grid = {
        "classifier_lr": [5e-3, 3e-3],
        "factor": [5, 10],
    }

    acc_by_params = {}
    kfold_out = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    # outer loop
    for out_fold, (train_idx_out, test_idx_out) in enumerate(kfold_out.split(train_ds, train_ds.targets)):
        if out_fold > n_out:
            break
        print(f'============ out fold {out_fold} ============')
        print('====================================')
        test_loader = DataLoader(Subset(val_ds, test_idx_out), batch_size=batch_size, shuffle=False, num_workers=0)

        outer_train_labels = [train_ds.targets[i] for i in train_idx_out]
        kfold_in = StratifiedKFold(n_splits=5, shuffle=True, random_state=43)

        # inner loop (StratifiedKFold only uses the length of X and the labels)
        for in_fold, (train_idx_in, val_idx_in) in enumerate(kfold_in.split(train_idx_out, outer_train_labels)):
            if in_fold > n_in:
                break
            print(f'\t============ in fold {in_fold} ============')

            # Map positions inside the outer-train split back to train_ds indices.
            train_indices = [train_idx_out[i] for i in train_idx_in]
            val_indices = [train_idx_out[i] for i in val_idx_in]

            train_loader = DataLoader(Subset(train_ds, train_indices), batch_size=batch_size, shuffle=True, num_workers=0)
            val_loader = DataLoader(Subset(val_ds, val_indices), batch_size=batch_size, shuffle=False, num_workers=0)

            for params in ParameterGrid(param_grid):
                lr_cls = params["classifier_lr"]
                factor = params["factor"]
                if lr_cls == 5e-3 and factor == 10:
                    continue  # this combination is deliberately skipped

                cfg = build_grid_config(base_cfg, lr_cls, factor, experiment_name)

                model = model_class(config=cfg.model, num_classes=cfg.num_classes)
                if lora:
                    model = LoRA.from_module(model, rank=rank)
                model.eval()
                model.to(device)

                ckpt_path = run_train(model, train_loader, val_loader, cfg, out_fold, in_fold, lora)

                results = run_test(
                    model_class=model_class,
                    ckpt_path=ckpt_path,
                    test_loader=test_loader,
                    idx_to_class=idx_to_class,
                    config=cfg.model,
                    device=device,
                )

                acc = results["overall_accuracy"]
                print(f"\t\tAccuracy on test set: {acc}")
                acc_by_params.setdefault((lr_cls, factor), []).append(acc)

    # Global summary file for the whole experiment.
    csv_path = os.path.join(f"./models/cross_validation/{experiment_name}", "grid_search_results.csv")
    write_grid_summary_csv(acc_by_params, csv_path)
    print(f"\nSaved grid-search summary to: {csv_path}\n")


In [None]:
import numpy as np
from argparse import ArgumentParser
from omegaconf import OmegaConf
import random
import copy
import pickle
import csv

import torch
from torchvision.models import ResNet18_Weights, EfficientNet_V2_S_Weights
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold, StratifiedKFold, ParameterGrid
from torch.utils.data import Subset
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from lightning import Trainer, Callback
from lightning.pytorch.loggers import TensorBoardLogger
from lora_pytorch import LoRA

from scripts.utils import visualize_image, letterbox_to_square
from scripts.EfficentRex import EfficentRex
from scripts.RexNet import RexNet
from scripts.ViTRex import ViTRex_FullFT
from scripts.RexNet_FullFT import RexNet_FullFT
from scripts.LoRaViTRex import LoRaViTRex

In [None]:
# LoRaViTRex (from scripts.LoRaViTRex), built from its YAML config.
config_path = 'config/config_loravitrex.yaml'
config = OmegaConf.load(config_path)
model = LoRaViTRex(config=config.model, num_classes=config.num_classes)

In [None]:
# RexNet_FullFT (from scripts.RexNet_FullFT), built from its YAML config.
config_path = 'config/config_rexnet_fft.yaml'
config = OmegaConf.load(config_path)
model = RexNet_FullFT(config=config.model, num_classes=config.num_classes)

In [None]:
# ViTRex_FullFT (from scripts.ViTRex), built from its YAML config.
config_path = 'config/config_vitrex_fft.yaml'
config = OmegaConf.load(config_path)
model = ViTRex_FullFT(config=config.model, num_classes=config.num_classes)

In [None]:
from functools import partial
from operator import methodcaller

# Bind the letterbox size now. The original lambdas read the global `config`
# on every image, so they silently changed if another config cell was run
# later. Unlike lambdas, partial/methodcaller objects can also be pickled,
# which DataLoader worker processes need under spawn (e.g. Windows).
letterbox = partial(letterbox_to_square, size=config.size, fill=0)
to_rgb = methodcaller("convert", "RGB")

train_tfms = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Lambda(letterbox),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    model.base_tfms,  # the model's own tensor conversion / normalization (defined in scripts/)
])

val_tfms = transforms.Compose([
    transforms.Lambda(to_rgb),
    transforms.Lambda(letterbox),
    transforms.CenterCrop(224),
    model.base_tfms,
])


In [None]:

root = "./dataset/dataset"

# Two views of the same train folder: train_tfms for fitting, val_tfms for validation.
train_ds = datasets.ImageFolder(root=f"{root}/train", transform=train_tfms)
val_ds = datasets.ImageFolder(root=f"{root}/train", transform=val_tfms)

# Seeded permutation -> reproducible 80/20 split by sample index.
g = torch.Generator().manual_seed(42)
perm = torch.randperm(len(train_ds), generator=g).tolist()
split = int(0.8 * len(train_ds))
train_indices, held_out_indices = perm[:split], perm[split:]

train_subset = Subset(train_ds, train_indices)
test_subset = Subset(val_ds, held_out_indices)


In [None]:
# Display one random (augmented) training-subset image as a batch of one.
sample, label = random.choice(train_subset)
img = sample.unsqueeze(0)
visualize_image(img)

In [None]:
train_loader = DataLoader(train_subset, batch_size=config.training.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(test_subset, batch_size=config.training.batch_size, shuffle=False, num_workers=0)

# Checkpoints and TensorBoard logs of this run live under one folder.
run_dir = f"./models/prova/{config.model.name}"

early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=12)

checkpoint_callback = ModelCheckpoint(
    dirpath=run_dir,
    filename="best-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}",  # loss and accuracy in the file name
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

tb_logger = TensorBoardLogger(
    save_dir=run_dir,
    name="tb_logs",
    version="lora_l",  # label of this TensorBoard run
    default_hp_metric=False,
)

trainer = Trainer(
    default_root_dir=f"./models/prova/{model.model_name}",
    logger=tb_logger,
    callbacks=[early_stopping_callback, checkpoint_callback],
    max_epochs=config.training.max_epochs,
    enable_checkpointing=True,
    log_every_n_steps=10,
)

trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
train_loader = DataLoader(train_subset, batch_size=config.training.batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(test_subset, batch_size=config.training.batch_size, shuffle=False, num_workers=0)

# Checkpoints and TensorBoard logs of this run live under one folder.
run_dir = f"./models/prova/{config.model.name}"

early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", patience=12)

checkpoint_callback = ModelCheckpoint(
    dirpath=run_dir,
    filename="best-{epoch:02d}-{val_loss:.4f}-{val_acc:.4f}",  # loss and accuracy in the file name
    save_top_k=1,
    monitor="val_loss",
    mode="min",
)

tb_logger = TensorBoardLogger(
    save_dir=run_dir,
    name="tb_logs",
    version="long_warmup_2",  # label of this TensorBoard run
    default_hp_metric=False,
)

trainer = Trainer(
    default_root_dir=f"./models/prova/{model.model_name}",
    logger=tb_logger,
    callbacks=[early_stopping_callback, checkpoint_callback],
    max_epochs=config.training.max_epochs,
    enable_checkpointing=True,
    log_every_n_steps=10,
)

trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)