<a href="https://colab.research.google.com/github/gekoramy/uni.deep-learning/blob/standard-finetuning/standard_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Standard finetuning
In this notebook we propose a straightforward approach to fine-tuning CLIP. The general idea is to add one or more linear layers on top of CLIP's 1024-dimensional visual features.

Dependencies

In [1]:
%%shell
# Write the Python dependencies of this notebook to requirements.txt.
# `tee` also echoes the list to the cell output.
# NOTE(review): no versions are pinned, so a fresh install may resolve to
# different releases over time — consider pinning for reproducibility.
tee requirements.txt << END
ftfy
jaxtyping
jupyter
matplotlib
pydantic
regex
torch
torchvision
tqdm
END

# Install the listed packages, then CLIP directly from its GitHub repository,
# then the ultralytics package.
pip install -q -r requirements.txt
pip install -q git+https://github.com/openai/CLIP.git
pip install -q ultralytics

ftfy
jaxtyping
jupyter
matplotlib
pydantic
regex
torch
torchvision
tqdm
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m121.9/121.9 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.9/84.9 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m55.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for clip (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m617.5/617.5 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25h



In [2]:
import clip
import json
import os
import pickle
import re
import torch
import torch.nn.functional as F
import torchvision

from datetime import datetime
from jaxtyping import Float, UInt, Int
from pydantic.dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.io import read_image
from typing import Literal, Callable, Mapping, TypeVar
from tqdm import tqdm

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device: Literal['cpu', 'cuda'] = 'cuda' if torch.cuda.is_available() else 'cpu'
# Make `device` the default placement for tensors created without an explicit device.
torch.set_default_device(device)
# Bare expression: the notebook shows the selected device as the cell output.
device

'cuda'

## Load the dataset
First of all we have to load the dataset.

In [4]:
%%shell
if ! [ -d dataset ]; then
  mkdir dataset &&
  gdown 1P8a1g76lDJ8cMIXjNDdboaRR5-HsVmUb &&
  tar -xf refcocog.tar.gz -C dataset &&
  rm refcocog.tar.gz
fi

Downloading...
From: https://drive.google.com/uc?id=1P8a1g76lDJ8cMIXjNDdboaRR5-HsVmUb
To: /content/refcocog.tar.gz
100% 13.5G/13.5G [02:36<00:00, 86.0MB/s]




Folder paths

In [5]:
# Root of the extracted dataset; the trailing '' makes the path end with a separator.
root = os.path.join('dataset', 'refcocog', '')
# JSON file parsed into `Instances` (images, annotations, categories, ...).
data_instances = os.path.join(root, 'annotations', 'instances.json')
# Pickled list of referring expressions, parsed into `Ref` objects.
data_refs = os.path.join(root, 'annotations', 'refs(umd).p')
# Directory holding the image files referenced by `Ref.file_name`.
data_images = os.path.join(root, 'images', '')

Type declaration

In [6]:
# Generic placeholders used by CocoDataset for the outputs of its three transforms:
# I = transformed image, P = transformed prompts, B = transformed bounding box.
I = TypeVar('I')
P = TypeVar('P')
B = TypeVar('B')
T = TypeVar('T')  # not referenced in the visible part of the notebook

# Unsigned-integer image tensor as read from disk.
# NOTE(review): the axes are labelled 'C W H', while torchvision's read_image documents
# a (channels, height, width) layout — confirm the intended order.
Img = UInt[torch.Tensor, 'C W H']
# Unsigned-integer tensor holding the four coordinates of a bounding box.
BBox = UInt[torch.Tensor, '4']
# Name of a dataset partition.
Split = Literal['train', 'test', 'val']

@dataclass
class Info:
    """Dataset-level metadata: the `info` entry of the instances file."""
    description: str  # example: This is stable 1.0 version of the 2014 MS COCO dataset.
    url: str  # example: http://mscoco.org/
    version: str  # example: 1.0
    year: int  # example: 2014
    contributor: str  # example: Microsoft COCO group
    date_created: datetime  # example: 2015-01-27 09:11:52.357475

@dataclass
class Image:
    """Metadata record of one image: an element of `Instances.images` (not the pixel data)."""
    license: int  # each image has an associated license id
    file_name: str  # file name of the image
    coco_url: str  # example: http://mscoco.org/images/131074
    height: int
    width: int
    flickr_url: str  # example: http://farm9.staticflickr.com/8308/7908210548_33e
    id: int  # id of the image
    date_captured: datetime  # example: '2013-11-21 01:03:06'

@dataclass
class License:
    """License record: an element of `Instances.licenses`."""
    url: str  # example: http://creativecommons.org/licenses/by-nc-sa/2.0/
    id: int  # id of the license
    name: str  # example: 'Attribution-NonCommercial-ShareAlike License'

@dataclass
class Annotation:
    """Object annotation: an element of `Instances.annotations`, looked up by `Ref.ann_id`.

    `area` and `bbox` are declared as floats because COCO-style annotation files
    store them with fractional values; declaring them as `int` makes validation
    reject such values on pydantic versions that refuse lossy float-to-int
    conversion. Consumers that need integer pixel coordinates cast explicitly
    (CocoDataset builds its box tensor with dtype=torch.int).
    """
    # segmentation: list[list[float]]  # description of the mask; example [[44.17, 217.83, 36.21, 219.37, 33.64, 214.49, 31.08, 204.74, 36.47, 202.68, 44.17, 203.2]]
    area: float  # area of the described object, in pixels (may be fractional)
    iscrowd: Literal[1, 0]  # Crowd annotations (iscrowd=1) are used to label large groups of objects (e.g. a crowd of people)
    image_id: int  # id of the target image
    bbox: tuple[float, float, float, float]  # bounding box coordinates [xmin, ymin, width, height]
    category_id: int
    id: int  # annotation id

@dataclass
class Category:
    """Object category: an element of `Instances.categories`."""
    supercategory: str  # example: 'vehicle'
    id: int  # category id
    name: str  # example: 'airplane'

@dataclass
class Instances:
    """Top-level structure of the instances JSON file (built with `Instances(**raw)`)."""
    info: Info  # dataset-level metadata
    images: list[Image]  # per-image metadata records
    licenses: list[License]
    annotations: list[Annotation]  # object annotations, including bounding boxes
    categories: list[Category]

@dataclass
class Sentence:
    """One referring expression: an element of `Ref.sentences`."""
    tokens: list[str]  # tokenized version of the referring expression
    raw: str  # unprocessed referring expression
    sent: str  # referring expression with mild processing: lower case, spell correction, etc.
    sent_id: int  # unique referring expression id

@dataclass
class Ref:
    """One referred object: an annotated object plus the sentences that describe it."""
    image_id: int  # unique image id
    split: Split  # dataset partition this reference belongs to
    sentences: list[Sentence]  # referring expressions describing the object
    file_name: str  # file name of image relative to img_root
    category_id: int  # object category label
    ann_id: int  # id of object annotation in instances.json
    sent_ids: list[int]  # same ids as nested sentences[...][sent_id]
    ref_id: int  # unique id for the referring expression

In [7]:
@dataclass
class Prediction:
  """Record of one prediction made by the evaluation loop."""
  # NOTE(review): `Image` here resolves to the COCO metadata dataclass declared in this
  # notebook, while the evaluation loop stores a PIL image in this field — confirm the intended type.
  image: Image  # image on which the prediction has been computed
  description: str  # natural language description of the area of interest
  ground_truth_bbox: tuple[int, int, int, int] # ground truth bounding box
  output_bbox: tuple[int, int, int, int] # predicted bounding box

Read the dataset annotations

In [8]:
def fix_ref(x: Ref) -> Ref:
    """
    Normalise the image file name stored in a reference.

    :param x: reference whose ``file_name`` still carries the annotation id suffix
    :return:  the very same object (mutated in place), with ``file_name`` rewritten by ``fix_filename``
    """
    cleaned_name = fix_filename(x.file_name)
    x.file_name = cleaned_name
    return x


def fix_filename(x: str) -> str:
    """
    Strip the trailing annotation id from a reference file name.

    Only the last ``_<digits>`` group right before the ``.jpg`` extension is removed;
    names that do not end that way are returned unchanged. The function is not
    idempotent: applying it to an already fixed name strips the image id as well.

    :param x: COCO_..._[image_id]_[annotation_id].jpg
    :return:  COCO_..._[image_id].jpg
    """
    # Raw string: `\d` and `\.` are regex escapes, not (invalid) Python string escapes.
    return re.sub(r'_\d+\.jpg$', '.jpg', x)

In [9]:
# Load the pickled referring expressions; the next cell iterates over `raw`
# and expands each element into a `Ref` with `**`.
# NOTE: pickle.load can execute arbitrary code — only use it on a trusted archive.
with open(data_refs, 'rb') as f:
    raw = pickle.load(f)

In [10]:
# Parse every raw entry into a `Ref` and normalise its image file name
# (fix_ref drops the annotation id suffix).
refs: list[Ref] = [
    fix_ref(Ref(**ref))
    for ref in raw
]

In [11]:
# Read the instances JSON; note that this rebinds `raw`, which previously held
# the unpickled references.
with open(data_instances, 'r') as f:
    raw = json.load(f)

In [None]:
# Parse the JSON dictionary into the typed `Instances` structure.
instances: Instances = Instances(**raw)

In [14]:
# Index the annotations by their id so a `Ref.ann_id` can be resolved directly.
# Requires `instances` from the previous cell to be defined.
id2annotation: Mapping[int, Annotation] = {
    x.id: x
    for x in instances.annotations
}

NameError: ignored

Define custom dataset

In [None]:
class CocoDataset(Dataset[tuple[I, P, B]]):
    """
    Map-style dataset over the referring expressions of one split.

    Each sample is the triple (image, prompts, bounding box), every component
    being passed through its own user-supplied transform. The entries are
    collected once, at construction time, from the module-level ``refs`` list
    and ``id2annotation`` mapping; images are read from disk lazily on access.
    """

    def __init__(
        self,
        split: Split,
        img_transform: Callable[[Img], I] = lambda x: x,
        prompt_transform: Callable[[list[Sentence]], P] = lambda ps: [ p.sent for p in ps ],
        bb_transform: Callable[[UInt[torch.Tensor, '4']], B] = lambda x: x
    ):
        """
        :param split: train, test or val
        :param img_transform: transformation applied to the image tensor read from disk
        :param prompt_transform: transformation applied to the list of sentences
        :param bb_transform: transformation applied to the bounding box tensor
        """
        self.img_transform = img_transform
        self.prompt_transform = prompt_transform
        self.bb_transform = bb_transform

        # One entry per reference of the requested split:
        # (image path, referring expression objects, bounding box tensor).
        collected: list[tuple[str, list[Sentence], UInt[torch.Tensor, '4']]] = []
        for ref in refs:
            if ref.split != split:
                continue
            image_path = os.path.join(data_images, ref.file_name)
            sentences = ref.sentences
            # The box is taken from the annotation the reference points to.
            box = torch.tensor(id2annotation[ref.ann_id].bbox, dtype=torch.int)
            collected.append((image_path, sentences, box))
        self.items: list[tuple[str, list[Sentence], UInt[torch.Tensor, '4']]] = collected


    def __len__(self) -> int:
        """Number of references in the selected split."""
        return len(self.items)


    def __getitem__(self, item: int) -> tuple[I, P, B]:
        """Read the image of entry ``item`` from disk and return the transformed triple."""
        image_path, sentences, box = self.items[item]
        image = self.img_transform(read_image(image_path))
        prompts = self.prompt_transform(sentences)
        target = self.bb_transform(box)
        return image, prompts, target

## Training-free CLIP results
To allow a fair comparison with the fine-tuned variants below, we first evaluate the training-free CLIP baseline on the same portion of the dataset.

Load yolo model

In [None]:
# Load the 'yolov5s' model from the ultralytics/yolov5 repository through torch.hub;
# it proposes the candidate bounding boxes that CLIP ranks later on.
yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
# Move it to the selected device and switch to evaluation mode; being the last
# expression of the cell, the returned module is also displayed as output.
yolo_model.to(device=device).eval()

Load CLIP model

In [None]:
# Load the CLIP 'RN50' model together with its image preprocessing pipeline.
clip_model, preprocess = clip.load('RN50')
# Keep the model on the selected device and in evaluation mode.
clip_model = clip_model.to(device=device).eval()

Baseline evaluation

In [None]:
# Test split with the default transforms: the image tensor as read from disk,
# the list of `sent` strings, and the bounding box tensor.
test_dataset: Dataset[tuple[Img, list[str], UInt[torch.Tensor, '4']]] = CocoDataset(split='test')

In [None]:
# The evaluation loop unpacks exactly one sample from every batch, and the dataset
# yields images that are not resized plus a variable number of sentences per sample,
# which the default collate function cannot stack into larger batches.
# Hence samples are loaded one at a time.
BATCH_SIZE = 1
# os.cpu_count() may return None; fall back to loading in the main process.
NUM_WORKERS = os.cpu_count() or 0
print(f"Creating DataLoader's with batch size {BATCH_SIZE} and {NUM_WORKERS} workers.")

# Keep the original order (no shuffling) so evaluation runs are comparable.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False
)

In [None]:
MAX_ITER: int = 50  # maximum number of samples to evaluate
current_iteration: int = 0

stored_predictions: list[Prediction] = []

ious: list[float] = []  # IoU between ground-truth and predicted box, one value per sample
coss: list[float] = []  # cosine similarity between the CLIP embeddings of the two crops
euds: list[float] = []  # Euclidean distance between the same two embeddings

batch: tuple[UInt[torch.Tensor, '1 C W H'], tuple[list[tuple[str]]], UInt[torch.Tensor, '1 4']]

# Tensor -> PIL converter, built once instead of once per sample.
to_pil = transforms.ToPILImage()


def embed_crop(crop) -> Float[torch.Tensor, '1 D']:
    """Encode one PIL crop with the CLIP image encoder and return the embedding as float32."""
    pixels = preprocess(crop).unsqueeze(0).to(device=device)
    # The encoder may produce float16; cast so that the metrics are computed in float32.
    return clip_model.encode_image(pixels).float()


with torch.no_grad():
    # enumerate keeps `current_iteration` in step with the samples, so the MAX_ITER cap is enforced.
    progress = tqdm(test_dataloader, total=min(MAX_ITER, len(test_dataloader)))
    for current_iteration, batch in enumerate(progress):

        if current_iteration >= MAX_ITER:
            print("Stop iteration: MAX_ITER = "+str(MAX_ITER))
            break

        # Each batch holds exactly one sample (the loader must use batch_size=1).
        [img], prompts, [true_xywh] = batch

        [true_xyxy] = torchvision.ops.box_convert(true_xywh.unsqueeze(0), in_fmt='xywh', out_fmt='xyxy')

        img_pil = to_pil(img)

        # yolo bboxes
        predictions = yolo_model(img_pil)

        # xmin,      ymin,      xmax,      ymax,      confidence, class
        # 274.06390, 231.20389, 392.66345, 372.59018, 0.93251,    23.00000
        bboxes: Float[torch.Tensor, 'X 6'] = predictions.xyxy[0]

        # if empty, fall back to a single box covering the whole image;
        # PIL reports size as (width, height), i.e. (xmax, ymax)
        if len(bboxes) == 0:
            width, height = img_pil.size
            bboxes = torch.tensor([[0, 0, width, height, 0, 0]], dtype=torch.float)

        # from yolo bboxes to cropped images
        crops = [
            img_pil.crop((xmin, ymin, xmax, ymax))
            for xmin, ymin, xmax, ymax in bboxes[:, :4].tolist()
        ]

        # clip preprocess on cropped images
        preprocess_crops: Float[torch.Tensor, 'X 3 224 224'] = torch.stack([
            preprocess(crop)
            for crop in crops
        ]).to(device=device)

        # format each available prompt with every template;
        # after collation every sentence arrives wrapped in a 1-tuple
        prompts_tokens: Int[torch.Tensor, 'P 77'] = clip.tokenize([
            template.format(prompt)
            for template in ["{}", "A photo of {}", "We can see {}"]
            for (prompt,) in prompts
        ])

        # clip scores: similarity of every prompt (rows) with every crop (columns)
        logits_per_prompt: Float[torch.Tensor, 'P X']
        _, logits_per_prompt = clip_model(preprocess_crops, prompts_tokens)

        # final prediction: the crop with the highest score over all prompts
        best_match: int = torch.argmax(torch.max(logits_per_prompt, 0).values).item()
        prediction_bbox: Float[torch.Tensor, '4'] = bboxes[best_match][:4]

        # metrics
        # bring the ground truth onto the dtype/device of the prediction before comparing
        iou: float = torchvision.ops.box_iou(
            true_xyxy.to(prediction_bbox).unsqueeze(0),
            prediction_bbox.unsqueeze(0),
        ).item()
        ious.append(iou)

        true_rectangle: tuple[int, ...] = tuple(true_xyxy.tolist())
        ground_truth_crop = img_pil.crop(true_rectangle)

        # integer pixel coordinates (fractional part truncated)
        predicted_rectangle: tuple[int, ...] = tuple(prediction_bbox.to(torch.int).tolist())
        prediction_crop = img_pil.crop(predicted_rectangle)

        X: Float[torch.Tensor, '1 D'] = embed_crop(ground_truth_crop)
        Y: Float[torch.Tensor, '1 D'] = embed_crop(prediction_crop)

        cos: float = F.cosine_similarity(X, Y).item()
        coss.append(cos)

        eud: float = torch.cdist(X, Y, p=2).item()
        euds.append(eud)

        # store the prediction
        # Prediction is a validating dataclass whose `image` field is declared as the
        # COCO `Image` metadata record, which is not available here; the instance is
        # therefore created without running the validating constructor and filled field by field.
        pred = Prediction.__new__(Prediction)
        pred.image = img_pil
        pred.description = ' | '.join(sentence for (sentence,) in prompts)
        pred.ground_truth_bbox = true_rectangle
        pred.output_bbox = predicted_rectangle
        stored_predictions.append(pred)

        torch.cuda.empty_cache()

In [None]:
# Number of predictions collected by the evaluation loop (at most MAX_ITER).
len(stored_predictions)

## Linear layers on top of image encoder

## Linear layers on top of text encoder

## Linear layers on top of image encoder and on top of text encoder

## Bottleneck