# **2. Model Fine-Tuning with Triplet Loss**

**Project:** FashionCLIP (The Seeker)
**Author:** [Your Name]
**Goal:** This notebook uses the dataset created in the previous step to fine-tune a CLIP-based `ImageEncoderNetwork` using a custom triplet loss function.

---

### **Overview**

The goal is to train our model to map images into a high-dimensional vector space where the distance between vectors reflects semantic similarity. An anchor image's embedding should be closer to its positive partner's than to its negative partner's.

This is achieved through the following steps:
1.  **Setup**: Import libraries and configure paths.
2.  **Data Loading**: Load the generated dataset and prepare it for PyTorch.
3.  **Model & Loss**: Initialize the `ImageEncoderNetwork` and our custom `TripletSemiPosMarginWithDistanceLoss`.
4.  **Training**: Execute the training loop, validate on a hold-out set, and save the best model.
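
The embedding objective described above can be sketched with the standard triplet margin loss. This is only an illustration on random toy tensors: the project's own `TripletSemiPosMarginWithDistanceLoss` is not shown in this notebook and presumably differs (for example in how it uses the `semipos` score). The tensor shapes and the margin value here are assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy embeddings standing in for encoder outputs (batch of 4, dimension 8).
anchor = torch.randn(4, 8)
positive = anchor + 0.05 * torch.randn(4, 8)  # close to the anchor
negative = torch.randn(4, 8)                  # unrelated to the anchor

# Euclidean distance from each anchor to its positive and negative partner.
d_pos = F.pairwise_distance(anchor, positive, p=2)
d_neg = F.pairwise_distance(anchor, negative, p=2)

margin = 1.0
# Standard triplet margin loss: max(d_pos - d_neg + margin, 0), averaged over the batch.
manual_loss = torch.clamp(d_pos - d_neg + margin, min=0).mean()

# The built-in functional form computes the same quantity.
builtin_loss = F.triplet_margin_loss(anchor, positive, negative, margin=margin, p=2)
print(manual_loss.item(), builtin_loss.item())
```

The loss reaches zero only when every negative is at least `margin` farther from the anchor than the corresponding positive.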

In [3]:
import sys
import os
import torch
import torch.nn as nn
# from dataclasses import dataclass, field

from transformers import (
    AutoImageProcessor,
    AutoModel)

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from src.models.models import ImageEncoderNetwork, learning_loop



# Resolve the project root: if the notebook runs from a known subfolder, step up one level.
current_path = os.getcwd()
current_dir = os.path.basename(current_path)
if current_dir in ['research', 'dataprep', 'src']:
    current_path = os.path.dirname(current_path)
PROJECT_PATH = current_path

### **2.1. Data Loading and Preparation**

Here, we load the dataset created by `01_generate_triplet_input.ipynb`. We then create PyTorch `DataLoader` objects for the training and validation sets. These loaders will handle batching, shuffling, and feeding the data to the GPU efficiently.
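
As a rough illustration of what a `DataLoader` does with dictionary-style samples like ours, the sketch below uses a hypothetical `ToyTripletDataset` with fake tensors. The field names mirror the real dataset, but the class, sizes, and values are made up for the example.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyTripletDataset(Dataset):
    """Hypothetical stand-in for the triplet dataset: each item is a dict."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {
            "anchor": f"anchor_{idx}.png",         # file path (string)
            "anchor_image": torch.zeros(3, 4, 4),  # fake pixel tensor
            "semipos": torch.tensor(0.1 * idx),    # scalar score
        }


loader = DataLoader(ToyTripletDataset(), batch_size=4, shuffle=False)
batch = next(iter(loader))

# The default collate function stacks tensors along a new batch dimension
# and gathers strings into a plain Python list.
print(batch["anchor_image"].shape)
print(batch["semipos"].shape)
print(batch["anchor"])
print(len(loader))  # number of batches per epoch
```

Because string fields come back as lists while tensor fields are stacked, the file paths stay aligned with the embeddings computed for the same batch.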

In [None]:
from datasets import load_from_disk
from src.models.datasets import mac_path_format

dataset_path = f'{PROJECT_PATH}/data/datasets/uncropped_triplet_toy_semipos'
dataset = load_from_disk(dataset_path)

# `device` is set in the model-initialization cell (section 2.3); run that cell first.
if device == 'mps':
    dataset = dataset.map(mac_path_format)
# dataset[0]

# Will show something like this
# {'anchor': 'path_to_anchor.png',
#  'pos': 'path_to_pos_image.png',
#  'neg': 'path_to_neg_image.png',
#  'semipos': 0.08,
#  'caption': 'than it otherwise might."'}

In [None]:
from PIL import Image

from src.models.utils import display_triplet
# display sample of images (hidden)    
# display_triplet(dataset, 1305)

In [9]:
from datasets import load_from_disk
from src.models.datasets import load_images

# dataset = dataset.map(load_images, 
#                       fn_kwargs={'image_processor': image_processor}).with_format("torch")
# dataset = dataset.train_test_split(test_size=0.2)

# dataset.save_to_disk(f'{PROJECT_PATH}/data/datasets/uncropped_final')
dataset = load_from_disk(f'{PROJECT_PATH}/data/datasets/uncropped_final')
dataset

DatasetDict({
    train: Dataset({
        features: ['anchor', 'pos', 'neg', 'semipos', 'caption', 'anchor_image', 'pos_image', 'neg_image'],
        num_rows: 1216
    })
    test: Dataset({
        features: ['anchor', 'pos', 'neg', 'semipos', 'caption', 'anchor_image', 'pos_image', 'neg_image'],
        num_rows: 304
    })
})

In [None]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset['train'], batch_size=64, shuffle=True)
val_dataloader = DataLoader(dataset['test'], batch_size=64, shuffle=False)  # no need to shuffle for validation
# next(iter(train_dataloader))

In [9]:
b = next(iter(train_dataloader))
b.keys()

dict_keys(['anchor', 'pos', 'neg', 'semipos', 'caption', 'anchor_image', 'pos_image', 'neg_image'])

### **2.3. Model, Optimizer, and Loss Function**

This is the core of our training setup:

1.  **Model**: We initialize `ImageEncoderNetwork`, which uses a pre-trained CLIP vision model as its backbone. Its weights are copied via `load_from_clip` from the local checkpoint `data/models/OCR_clip-roberta-finetuned`, so training starts from a pre-trained backbone rather than from scratch.
2.  **Loss Function**: We use our custom `TripletSemiPosMarginWithDistanceLoss` from `src/models/losses.py` with `margin=1`. This loss teaches the model the desired embedding structure.
3.  **Optimizer**: We use the Adam optimizer (`optim.Adam`) with a learning rate of `1e-5`.
4.  **Scheduler**: `learning_loop` accepts an optional learning rate scheduler such as `ReduceLROnPlateau`, which lowers the learning rate when the validation loss stops improving. In this run no scheduler is used (`lr_scheduler = None`), so the learning rate stays at `1e-5` throughout.
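
For reference, this is a minimal sketch of how an optimizer and a `ReduceLROnPlateau` scheduler are usually wired together in PyTorch. The linear model, the `factor` and `patience` values, and the validation losses are placeholders chosen for the example, not settings taken from this project.

```python
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(8, 4)  # placeholder model, not the project's encoder
optimizer = optim.AdamW(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=2
)

# Made-up validation losses that stop improving after the second epoch.
fake_val_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]
lrs = []
for val_loss in fake_val_losses:
    # The scheduler is stepped with the monitored metric once per epoch and
    # lowers the learning rate after more than `patience` epochs without improvement.
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs)
```

Unlike most schedulers, `ReduceLROnPlateau.step` takes the monitored value as an argument, so it has to be called after validation rather than after each batch.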

In [4]:
device = 'mps' if torch.backends.mps.is_available()  else 'cuda' if torch.cuda.is_available() else 'cpu'

model_path = f'{PROJECT_PATH}/data/models/OCR_clip-roberta-finetuned'
clip = AutoModel.from_pretrained(model_path)

image_processor = AutoImageProcessor.from_pretrained(model_path)
# processor = VisionTextDualEncoderProcessor(image_processor, tokenizer)
model = ImageEncoderNetwork()
model.load_from_clip(clip)
model = model.to(device)
del clip

### **2.4. The Training Loop**

The `learning_loop` function encapsulates the entire training process. For each epoch, it performs a full pass over the training data to update the model's weights and then evaluates the model on the validation set.

Key features of this loop:
-   **Validation**: After each epoch, performance is measured on the validation set.
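-   **Early Stopping**: The loop monitors the validation loss and stops training if it fails to improve for `max_bad_epochs` consecutive epochs, which limits overfitting. In this run `max_bad_epochs` is set equal to `max_epochs`, so early stopping never triggers and all 30 epochs run.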
-   **Model Checkpointing**: The version of the model with the best validation score is saved to `best.pt`. This ensures we always keep the best-performing model.
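
The bookkeeping behind early stopping and checkpointing can be sketched in plain Python. This is not the code of `learning_loop` (which is not shown here). The function name, the `>=` stopping rule, and the dictionary standing in for the model state are all assumptions made for illustration.

```python
import copy


def run_with_early_stopping(val_losses, max_bad_epochs):
    """Toy loop: keep the best state and stop after too many bad epochs in a row.

    `val_losses` stands in for per-epoch validation results; a real loop
    would train and evaluate the model at each step instead.
    """
    best_loss = float("inf")
    best_state = None
    bad_epochs = 0
    last_epoch = 0
    state = {"epoch": 0}  # placeholder for model.state_dict()

    for epoch, val_loss in enumerate(val_losses, start=1):
        last_epoch = epoch
        state["epoch"] = epoch
        if val_loss < best_loss:
            best_loss = val_loss
            # Real code would write the weights to disk here, e.g. to best.pt.
            best_state = copy.deepcopy(state)
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= max_bad_epochs:
                break  # early stop: no improvement for too long
    return best_loss, best_state, last_epoch


best_loss, best_state, last_epoch = run_with_early_stopping(
    [1.08, 1.10, 1.07, 0.95, 0.96, 0.97, 0.98, 0.50], max_bad_epochs=3
)
print(best_loss, best_state, last_epoch)
```

In this toy run the loop stops before reaching the final value of 0.50, which shows the trade-off: a small `max_bad_epochs` saves time but can miss a late improvement.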

In [7]:
import torch.optim as optim

from src.models.models import learning_loop
from src.models.losses import TripletSemiPosMarginWithDistanceLoss


optimizer = optim.Adam(model.parameters(), lr=0.00001)
# criterion = TripletMarginLoss(margin=1, p=2)
lr_scheduler = None
criterion = TripletSemiPosMarginWithDistanceLoss(margin=1)
max_epochs = 30
max_bad_epochs = max_epochs


history = learning_loop(model=model, device=device, 
                        optimizer=optimizer, lr_scheduler=lr_scheduler, criterion=criterion, 
                        max_epochs=max_epochs, max_bad_epochs=max_bad_epochs, 
                        train_dataloader=train_dataloader, val_dataloader=val_dataloader)

Learning phase
Used device: cuda
--------------
Finished epoch 001/030 - Train loss: 1.0609421 - Valid loss: 1.0793333 - SAVED (NEW) BEST MODEL. Duration: 53.953 s
Finished epoch 002/030 - Train loss: 0.9988939 - Valid loss: 1.1044528 - NUMBER OF BAD EPOCH.S: 1. Duration: 50.153 s
Finished epoch 003/030 - Train loss: 0.9127808 - Valid loss: 1.0760817 - SAVED (NEW) BEST MODEL. Duration: 49.747 s
Finished epoch 004/030 - Train loss: 0.8126458 - Valid loss: 0.9516670 - SAVED (NEW) BEST MODEL. Duration: 50.428 s
Finished epoch 005/030 - Train loss: 0.7705238 - Valid loss: 0.9300179 - SAVED (NEW) BEST MODEL. Duration: 50.022 s
Finished epoch 006/030 - Train loss: 0.7221752 - Valid loss: 0.8467252 - SAVED (NEW) BEST MODEL. Duration: 50.597 s
Finished epoch 007/030 - Train loss: 0.6796598 - Valid loss: 0.9378610 - NUMBER OF BAD EPOCH.S: 1. Duration: 50.358 s
Finished epoch 008/030 - Train loss: 0.7256453 - Valid loss: 0.8041363 - SAVED (NEW) BEST MODEL. Duration: 50.656 s
Finished epoch 009/030 - Train loss: 0.6702419 - Valid loss: 0.7603895 - SAVED (NEW) BEST MODEL. Duration: 50.244 s
Finished epoch 010/030 - Train loss: 0.6141637 - Valid loss: 0.6843726 - SAVED (NEW) BEST MODEL. Duration: 50.237 s
Finished epoch 011/030 - Train loss: 0.5546740 - Valid loss: 0.6448086 - SAVED (NEW) BEST MODEL. Duration: 50.856 s
Finished epoch 012/030 - Train loss: 0.5130757 - Valid loss: 0.5833591 - SAVED (NEW) BEST MODEL. Duration: 50.616 s
Finished epoch 013/030 - Train loss: 0.4672229 - Valid loss: 0.5730131 - SAVED (NEW) BEST MODEL. Duration: 50.566 s
Finished epoch 014/030 - Train loss: 0.4399947 - Valid loss: 0.5429432 - SAVED (NEW) BEST MODEL. Duration: 50.008 s
Finished epoch 015/030 - Train loss: 0.4084833 - Valid loss: 0.4930957 - SAVED (NEW) BEST MODEL. Duration: 50.375 s
Finished epoch 016/030 - Train loss: 0.3615993 - Valid loss: 0.4343035 - SAVED (NEW) BEST MODEL. Duration: 50.112 s
Finished epoch 017/030 - Train loss: 0.3495908 - Valid loss: 0.4114026 - SAVED (NEW) BEST MODEL. Duration: 50.353 s
Finished epoch 018/030 - Train loss: 0.3029065 - Valid loss: 0.3969352 - SAVED (NEW) BEST MODEL. Duration: 49.925 s
Finished epoch 019/030 - Train loss: 0.3003090 - Valid loss: 0.3709843 - SAVED (NEW) BEST MODEL. Duration: 50.484 s
Finished epoch 020/030 - Train loss: 0.2736409 - Valid loss: 0.3933453 - NUMBER OF BAD EPOCH.S: 1. Duration: 49.905 s
Finished epoch 021/030 - Train loss: 0.2632096 - Valid loss: 0.3556644 - SAVED (NEW) BEST MODEL. Duration: 50.048 s
Finished epoch 022/030 - Train loss: 0.2613410 - Valid loss: 0.3438296 - SAVED (NEW) BEST MODEL. Duration: 50.154 s
Finished epoch 023/030 - Train loss: 0.2347304 - Valid loss: 0.3310875 - SAVED (NEW) BEST MODEL. Duration: 50.431 s
Finished epoch 024/030 - Train loss: 0.2273799 - Valid loss: 0.3604387 - NUMBER OF BAD EPOCH.S: 1. Duration: 49.912 s
Finished epoch 025/030 - Train loss: 0.2206394 - Valid loss: 0.3243571 - SAVED (NEW) BEST MODEL. Duration: 49.790 s
Finished epoch 026/030 - Train loss: 0.2187609 - Valid loss: 0.3331232 - NUMBER OF BAD EPOCH.S: 1. Duration: 50.440 s
Finished epoch 027/030 - Train loss: 0.2291116 - Valid loss: 0.3112364 - SAVED (NEW) BEST MODEL. Duration: 49.645 s
Finished epoch 028/030 - Train loss: 0.2072894 - Valid loss: 0.3235063 - NUMBER OF BAD EPOCH.S: 1. Duration: 50.242 s
Finished epoch 029/030 - Train loss: 0.1981710 - Valid loss: 0.3079689 - SAVED (NEW) BEST MODEL. Duration: 50.086 s
Finished epoch 030/030 - Train loss: 0.1783606 - Valid loss: 0.2730992 - SAVED (NEW) BEST MODEL. Duration: 50.290 s


### **2.5. Results and Conclusion**

Training ran for all 30 epochs. The training loss fell from 1.061 to 0.178 and the validation loss from 1.079 to 0.273. The best checkpoint, saved to disk as `best.pt`, is the one from epoch 30.

This model can now be used to extract semantically rich embeddings from new fashion images, powering applications like image search, product recommendation, and thematic clustering.
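
As a hedged illustration of the image-search use case, the sketch below ranks a small gallery of embeddings by Euclidean distance to a query. The gallery, the identifiers, and the embedding dimension are invented for the example. In practice the embeddings would come from the fine-tuned encoder.

```python
import torch

torch.manual_seed(0)

# Hypothetical gallery of 5 embeddings (dimension 8) and their identifiers.
gallery = torch.randn(5, 8)
gallery_ids = ["img_0", "img_1", "img_2", "img_3", "img_4"]

# A query that is a slightly perturbed copy of gallery item 3.
query = gallery[3] + 0.01 * torch.randn(8)

# Euclidean distance from the query to every gallery embedding.
distances = torch.cdist(query.unsqueeze(0), gallery, p=2).squeeze(0)

# Indices of the k nearest neighbours, closest first.
k = 3
nearest = torch.topk(distances, k=k, largest=False).indices.tolist()
print([gallery_ids[i] for i in nearest])
```

The same distance is used in the inspection cells that follow, so a retrieval built this way ranks images consistently with what the triplet loss optimizes.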


In [66]:
from tqdm import tqdm

# NOTE: `dataloader` is not defined in the cells shown here; use any loader that
# yields the same batch dictionary (e.g. `val_dataloader`).
for step, element in enumerate(tqdm(dataloader)):
    anchor_img = element['anchor_image']['pixel_values'].to(device)
    positive_img = element['pos_image']['pixel_values'].to(device)
    negative_img = element['neg_image']['pixel_values'].to(device)
    semipos = element['semipos'].to(device)
    anchor_paths = element['anchor']
    positive_paths = element['pos']
    negative_paths = element['neg']

    # Embed the three images without tracking gradients.
    # squeeze(1) drops the extra per-sample dimension (assumed to sit at index 1)
    # while keeping the batch dimension, even for a batch of size 1.
    with torch.no_grad():
        anchor_out = model(anchor_img.squeeze(1))
        positive_out = model(positive_img.squeeze(1))
        negative_out = model(negative_img.squeeze(1))

    break  # a single batch is enough for this inspection

  0%|          | 0/23 [00:01<?, ?it/s]


In [65]:
len(positive_paths)

1

In [74]:
import pandas as pd

# Euclidean distance from each anchor embedding to its negative and positive partner.
negatives = (anchor_out - negative_out).norm(p=2, dim=-1).cpu().numpy()
positives = (anchor_out - positive_out).norm(p=2, dim=-1).cpu().numpy()
semipos_numpy = semipos.cpu().numpy()
df_res = pd.DataFrame({'pos_distances': positives, 'neg_distances': negatives, 'semipos': semipos_numpy,
                       'difference': negatives - positives,
                       'anchor_paths': anchor_paths, 'pos_paths': positive_paths, 'neg_paths': negative_paths})

In [84]:
well_classified = df_res[df_res['difference'] > 0]
print(len(well_classified)/len(df_res)) # 92%
well_classified[well_classified.semipos > 0].describe()

0.921875


| | pos_distances | neg_distances | semipos | difference |
|---|---|---|---|---|
| count | 25.0 | 25.0 | 25.0 | 25.0 |
| mean | 0.727154 | 1.510402 | 0.122493 | 0.783248 |
| std | 0.44777 | 0.260478 | 0.041703 | 0.46958 |
| min | 0.213573 | 0.723187 | 0.0625 | 0.003665 |
| 25% | 0.341336 | 1.417784 | 0.090909 | 0.375635 |
| 50% | 0.577792 | 1.561155 | 0.125 | 0.864405 |
| 75% | 0.973428 | 1.707536 | 0.153846 | 1.192614 |
| max | 1.823732 | 1.885392 | 0.2 | 1.386325 |


In [85]:
well_classified[well_classified.semipos == 0].describe()

| | pos_distances | neg_distances | semipos | difference |
|---|---|---|---|---|
| count | 34.0 | 34.0 | 34.0 | 34.0 |
| mean | 0.56708 | 1.566481 | 0.0 | 0.999401 |
| std | 0.41523 | 0.216463 | 0.0 | 0.426362 |
| min | 0.101041 | 0.633963 | 0.0 | 0.116095 |
| 25% | 0.227587 | 1.518892 | 0.0 | 0.797545 |
| 50% | 0.422938 | 1.59186 | 0.0 | 1.051586 |
| 75% | 0.785881 | 1.688899 | 0.0 | 1.330646 |
| max | 1.653002 | 1.840469 | 0.0 | 1.665261 |
