# CLIP Fine-Tuning Demo

*Based on https://github.com/huggingface/transformers/tree/main/examples/pytorch/contrastive-image-text*
*and https://github.com/clip-italian/clip-italian*

Before running, download the [TrashNet dataset](https://github.com/garythung/trashnet/raw/master/data/dataset-resized.zip) and extract it to the 'dataset' folder.

In [None]:
%pip install huggingface
%pip install datasets
%pip install pillow
%pip install transformers
%pip install scikit-learn
%pip install transformers[torch]
%pip install tensorboardX
# PyTorch + CUDA should be installed manually

## Dataset Processing / Model Initialization

In [1]:
from transformers import CLIPProcessor, CLIPConfig, CLIPImageProcessor, CLIPModel

caption_map = {
    "cardboard": "a photo of cardboard",
    "glass": "a photo of glass",
    "metal": "a photo of metal",
    "paper": "a photo of paper",
    "plastic": "a photo of plastic",
    "trash": "a photo of trash"
}

model_name_or_path = 'openai/clip-vit-base-patch32'
config = CLIPConfig.from_pretrained(model_name_or_path)
processor = CLIPProcessor.from_pretrained(model_name_or_path)
image_processor = CLIPImageProcessor.from_pretrained(model_name_or_path)
model = CLIPModel.from_pretrained(model_name_or_path, config=config).to('cuda')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets import load_dataset

# Load dataset
ds = load_dataset("imagefolder", data_dir="dataset", split="train")
ds = ds.train_test_split(test_size=0.2, seed=512)
labels = ds['train'].features['label'].names

Resolving data files: 100%|██████████| 2527/2527 [00:00<00:00, 84235.42it/s]


In [3]:
from torchvision.transforms import Compose, RandomResizedCrop, RandomHorizontalFlip, ToTensor, Normalize, Resize, CenterCrop, ConvertImageDtype, InterpolationMode, ColorJitter
import torch

# Setup image transforms
_train_transform = Compose([
    RandomResizedCrop(config.vision_config.image_size, interpolation=InterpolationMode.BICUBIC, antialias=None),
    RandomHorizontalFlip(),
    ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    ToTensor(),
    ConvertImageDtype(torch.float),
    Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std
    )
])

_val_transform = Compose([
    Resize([config.vision_config.image_size], interpolation=InterpolationMode.BICUBIC, antialias=None),
    CenterCrop(config.vision_config.image_size),
    ToTensor(),
    ConvertImageDtype(torch.float),
    Normalize(
        mean=image_processor.image_mean,
        std=image_processor.image_std
    )
])

def train_transform(example_batch):
    example_batch['pixel_values'] = [
        _train_transform(img.convert('RGB')) for img in example_batch['image']
    ]
    return example_batch

def val_transform(example_batch):
    example_batch['pixel_values'] = [
        _val_transform(img.convert('RGB')) for img in example_batch['image']
    ]
    return example_batch

# Setup text transform (tokenization)
max_target_length = 32
def tokenize_captions(examples):
    captions = [caption_map[labels[l]] for l in examples['label']]
    text_inputs = processor.tokenizer(captions, padding=True, truncation=True, max_length=max_target_length)
    examples['input_ids'] = text_inputs['input_ids']
    examples['attention_mask'] = text_inputs['attention_mask']
    return examples

# Perform tokenization and image transformation
prepared_ds = ds.map(tokenize_captions, batched=True, remove_columns=['label'])
prepared_ds['train'].set_transform(train_transform)
prepared_ds['test'].set_transform(val_transform)

## Training

In [4]:
import torch

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    input_ids = torch.tensor([example["input_ids"] for example in examples], dtype=torch.long)
    attention_mask = torch.tensor([example["attention_mask"] for example in examples], dtype=torch.long)
    return {
        "pixel_values": pixel_values,
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "return_loss": True,
    }

In [5]:
import numpy as np
from datasets import load_metric

# Setup evaluation model input
eval_input = {}
eval_input['pixel_values']  = torch.stack([example["pixel_values"] for example in prepared_ds['test']]).to('cuda')

eval_text = [
    "a photo of cardboard",
    "a photo of glass",
    "a photo of metal",
    "a photo of paper",
    "a photo of plastic",
    "a photo of trash"
]
eval_tokens = processor.tokenizer(eval_text, padding=True, truncation=True, max_length=max_target_length)
eval_input['input_ids'] = torch.tensor(eval_tokens['input_ids']).to('cuda')
eval_input['attention_mask'] = torch.tensor(eval_tokens['attention_mask']).to('cuda')

metric = load_metric("accuracy")
def compute_metrics(p):
    with torch.no_grad():
        output = model(**eval_input)
    logits_per_image = output.logits_per_image
    probs = logits_per_image.softmax(dim=1).cpu().numpy()
    predictions = np.argmax(probs, axis=1)
    references = ds['test']['label']
    return metric.compute(predictions=predictions, references=references)

  metric = load_metric("accuracy")


In [6]:
# Freeze text and vision model
for param in model.text_model.parameters():
    param.requires_grad = False

for param in model.vision_model.parameters():
    param.requires_grad = False

print(f"num params:", model.num_parameters())
print(f"num trainable params:", model.num_parameters(only_trainable=True))

num params: 151277313
num trainable params: 655361


In [7]:
from transformers import TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
import os

if os.path.exists('./clip-trash'):
    last_checkpoint = get_last_checkpoint("./clip-trash")
else:
    last_checkpoint = None
  
training_args = TrainingArguments(
  output_dir="./clip-trash",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=64,
  #fp16=True,
  save_steps=200,
  eval_steps=200,
  logging_steps=50,
  learning_rate=1e-4,
  save_total_limit=1,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
  resume_from_checkpoint=last_checkpoint,
)

In [8]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
)

In [9]:
train_results = trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
processor.save_pretrained("./clip-trash")
image_processor.save_pretrained("./clip-trash")

 99%|█████████▉| 8050/8128 [00:07<00:00, 5779.51it/s]

{'loss': 1.4595, 'learning_rate': 9.596456692913387e-07, 'epoch': 63.39}


100%|█████████▉| 8100/8128 [00:14<00:00, 454.34it/s] 

{'loss': 1.4976, 'learning_rate': 3.44488188976378e-07, 'epoch': 63.78}


100%|██████████| 8128/8128 [00:18<00:00, 436.55it/s]


{'train_runtime': 18.5818, 'train_samples_per_second': 6960.794, 'train_steps_per_second': 437.418, 'train_loss': 0.022948608154386985, 'epoch': 64.0}
***** train metrics *****
  epoch                    =       64.0
  train_loss               =     0.0229
  train_runtime            = 0:00:18.58
  train_samples_per_second =   6960.794
  train_steps_per_second   =    437.418


['./clip-trash\\preprocessor_config.json']

In [10]:
metrics = trainer.evaluate(prepared_ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

100%|██████████| 64/64 [00:23<00:00,  2.76it/s]

***** eval metrics *****
  epoch                   =       64.0
  eval_accuracy           =     0.9209
  eval_loss               =     0.9143
  eval_runtime            = 0:00:23.21
  eval_samples_per_second =     21.792
  eval_steps_per_second   =      2.756





## Second Stage Training

I am not sure yet how helpful this is, as CLIP's visual and text models are already trained on a dataset of 600 million images :/

In [11]:
# Unfreeze text and vision model
for param in model.text_model.parameters():
    param.requires_grad = True

for param in model.vision_model.parameters():
    param.requires_grad = True

# Freeze embedding layer
for param in model.vision_model.embeddings.parameters():
    param.requires_grad = False

for param in model.text_model.embeddings.parameters():
    param.requires_grad = False

print(f"num params:", model.num_parameters())
print(f"num trainable params:", model.num_parameters(only_trainable=True))

num params: 151277313
num trainable params: 123542529


In [12]:
if os.path.exists('./clip-trash-s2'):
    last_checkpoint = get_last_checkpoint("./clip-trash-s2")
else:
    last_checkpoint = None

training_args = TrainingArguments(
  output_dir="./clip-trash-s2",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=8,
  #fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=20,
  learning_rate=1e-8,
  save_total_limit=1,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
  resume_from_checkpoint=last_checkpoint,
  #weight_decay=1e-2
)

In [13]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
)

In [14]:
train_results = trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
processor.save_pretrained("./clip-trash-s2")
image_processor.save_pretrained("./clip-trash-s2")

100%|██████████| 1016/1016 [00:05<00:00, 202.83it/s] 


{'train_runtime': 4.8886, 'train_samples_per_second': 3307.306, 'train_steps_per_second': 207.832, 'train_loss': 0.022579181851364497, 'epoch': 8.0}
***** train metrics *****
  epoch                    =        8.0
  train_loss               =     0.0226
  train_runtime            = 0:00:04.88
  train_samples_per_second =   3307.306
  train_steps_per_second   =    207.832


['./clip-trash-s2\\preprocessor_config.json']

In [16]:
trainer._load_best_model()
metrics = trainer.evaluate(prepared_ds['test'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

100%|██████████| 64/64 [00:28<00:00,  2.21it/s]

***** eval metrics *****
  epoch                   =        8.0
  eval_accuracy           =     0.9209
  eval_loss               =      0.914
  eval_runtime            = 0:00:28.99
  eval_samples_per_second =     17.454
  eval_steps_per_second   =      2.208



