In [1]:
import os
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn.functional as F

import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch import optim
import torchvision.transforms as transforms

import clip

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
CONTEXT_LENGTH = 77
SAMPLE_SIZE = None
TRAIN_PROP = 0.7
VAL_PROP = 0.2
BATCH_SIZE = 64
NUM_WORKERS = 0
EPOCHS = 4
LEARNING_RATE = 5e-6

# Params same as paper
BETAS = (0.9,0.98)
EPS = 1e-6
WEIGHT_DECAY = 0.2

In [3]:
models = clip.available_models()
print(models)
model, preprocess = clip.load('RN50', device, jit=False)

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


In [4]:
print(model)

CLIP(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
     

## Create Dataset

In [5]:
class TextImg_Dataset(Dataset):
    def __init__(self, data_dir, images_dir, images, titles, augmented_text=False, augmented_image=False):
        self.data_dir = data_dir
        self.images_dir = images_dir
        self.images_paths = images
        self.titles = titles
        assert (len(titles)==len(images))
        self.augmented_text = augmented_text
        self.augmented_image = augmented_image

    def __len__(self):
        return len(self.images_paths)
    
    def transform_img(self, img):
        if img.mode != 'RGB':
            img = img.convert('RGB')
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
            transforms.RandomRotation(30),  # Randomly rotate the image by up to 30 degrees
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Randomly change brightness, contrast, saturation, and hue
        ])
        return transform(img)
    
    def __getitem__(self, index):
        txt = self.titles[index]
        img_path = self.images_paths[index]
        text_tokens = clip.tokenize(txt, context_length=CONTEXT_LENGTH, truncate=True).squeeze().to(device) # torch.Size([77])
        # print(type(text_tokens), text_tokens.shape)
        path = os.path.join(self.data_dir, self.images_dir, img_path)
        img = Image.open(path)
        if self.augmented_image:
            img = self.transform_img(img)
        image = preprocess(img).to(device) # torch.Size([3, 224, 224])
        # print(type(image), image.shape)
        return text_tokens, image


In [6]:
summary_bert_path = "data/Summaries/Summary_Bert_65.csv"
summary_bert_path = "data/Summaries/export_summary_bert_with_ingredients.csv"
data_dir = "data"
images_dir = "Food Images/Food Images"

In [7]:
summary_df = pd.read_csv(summary_bert_path)
liste_images = summary_df["Image_Name"].tolist()
liste_textes = summary_df["summary_with_ingredients"].tolist()
if SAMPLE_SIZE != None:
    liste_images=liste_images[:SAMPLE_SIZE]
    liste_textes=liste_textes[:SAMPLE_SIZE]
liste_images = [image + ".jpg" for image in liste_images]

In [8]:
liste_images[0], liste_textes[0]

('miso-butter-roast-chicken-acorn-squash-panzanella.jpg',
 'Roast chicken in a large cast-iron skillet until an instant-read thermometer inserted into the thickest part of breast registers 155°F, 50–60 minutes. Meanwhile, roast squash on lower rack until mostly tender, about 25 minutes.Whole chicken')

In [9]:
assert(len(liste_textes) == len(liste_images))
length = len(liste_images)
print(length)
train_size = int(length*TRAIN_PROP)
val_size = int(length*VAL_PROP)

1221


In [10]:
train_dataset = TextImg_Dataset(data_dir, images_dir, liste_images[:train_size], liste_textes[:train_size])
train_dataset_augmented = TextImg_Dataset(data_dir, images_dir, liste_images[:train_size], liste_textes[:train_size], augmented_text=True, augmented_image=True)
val_dataset = TextImg_Dataset(data_dir, images_dir, liste_images[train_size:train_size+val_size], liste_textes[train_size:train_size+val_size])
test_dataset = TextImg_Dataset(data_dir, images_dir, liste_images[train_size+val_size:], liste_textes[train_size+val_size:])

In [11]:
train_dl = DataLoader(train_dataset_augmented, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dl = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_dl = DataLoader(test_dataset, batch_size=500, shuffle=False, num_workers=NUM_WORKERS)

In [12]:
len(train_dl), next(iter(train_dl))

(14,
 [tensor([[49406,  1983,   320,  ...,     0,     0,     0],
          [49406,  4741,   753,  ...,     0,     0,     0],
          [49406, 12919,   518,  ...,     0,     0,     0],
          ...,
          [49406,   622,   516,  ...,     0,     0,     0],
          [49406,  6803,  5066,  ...,     0,     0,     0],
          [49406,   679,  4403,  ...,     0,     0,     0]], device='cuda:0',
         dtype=torch.int32),
  tensor([[[[-1.3981, -1.3981, -1.3981,  ...,  1.6822,  1.6822,  1.6822],
            [-1.3981, -1.3981, -1.3981,  ...,  1.6822,  1.6822,  1.6822],
            [-1.3981, -1.3981, -1.3981,  ...,  1.6822,  1.6822,  1.6822],
            ...,
            [ 1.4778,  1.4778,  1.4778,  ..., -1.3981, -1.3981, -1.3981],
            [ 1.4778,  1.4778,  1.4778,  ..., -1.3981, -1.3981, -1.3981],
            [ 1.4778,  1.4778,  1.4778,  ..., -1.3981, -1.3981, -1.3981]],
  
           [[-1.3469, -1.3469, -1.3469,  ...,  1.8198,  1.8198,  1.8198],
            [-1.3469, -1.3469, -1.

## Fine Tune CLIP

In [None]:
def compute_loss(text_embeddings, image_embeddings):
    logits = (text_embeddings @ image_embeddings.T)
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax(
        (images_similarity + texts_similarity) / 2, dim=-1
    )
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
    return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = torch.nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
    
def get_accuracy(text_features, image_features):
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    ground_truth = torch.arange(similarity.shape[0]).to(device)
    max_indices = torch.argmax(similarity, dim=1)
    return torch.mean((max_indices == ground_truth).float()).item() *100

def convert_models_to_fp32(model): 
    for p in model.parameters():
        p.data = p.data.float()
        print(type(p.data), p.data)
        p.grad.data = p.grad.data.float() 

In [None]:
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE,betas=BETAS,eps=EPS,weight_decay=WEIGHT_DECAY)
train_losses = []
val_losses = []
val_accuracies = []

In [None]:
clip.model.convert_weights(model)

for ep in range(EPOCHS):
    model.train()
    train_loss = 0
    for batch in tqdm(train_dl): 
        optimizer.zero_grad()
        text, img = batch
        # print(text.shape, img.shape)
        text_features = model.encode_text(text)
        image_features = model.encode_image(img)
        loss = compute_loss(text_features, image_features)
        train_loss+=loss.item()
        loss.backward()
        optimizer.step()
    train_losses.append(train_loss)

    # VALIDATION LOOP
    with torch.no_grad():
        model.eval()
        val_loss = 0
        val_acc = 0
        for b_idx, batch in enumerate(tqdm(val_dl)): 
            text, img = batch
            # print(text.shape, img.shape)
            text_features = model.encode_text(text)
            image_features = model.encode_image(img)
            loss = compute_loss(text_features, image_features)
            val_loss+=loss.item()
            val_acc += get_accuracy(text_features, image_features)
        val_accuracies.append(val_acc/b_idx+1)
        val_losses.append(val_loss)

torch.save(model.state_dict(), "data/checkpoints/trained_model.pt")    

100%|██████████| 14/14 [00:10<00:00,  1.39it/s]
100%|██████████| 4/4 [00:00<00:00,  4.10it/s]
100%|██████████| 14/14 [00:09<00:00,  1.49it/s]
100%|██████████| 4/4 [00:00<00:00,  4.17it/s]
100%|██████████| 14/14 [00:09<00:00,  1.50it/s]
100%|██████████| 4/4 [00:00<00:00,  4.11it/s]


RuntimeError: Parent directory data/checkpoints does not exist.

In [None]:
train_losses, val_losses, val_accuracies

([54.931640625,
  46.724609375,
  41.876953125,
  36.494140625,
  31.0263671875,
  23.572265625,
  17.1455078125,
  12.544921875,
  8.808837890625],
 [14.71875,
  14.4921875,
  14.236328125,
  14.173828125,
  14.13671875,
  14.130859375,
  14.23046875,
  14.60546875,
  14.662109375],
 [12.658654113610586,
  13.700320780277252,
  15.903846383094788,
  14.100961595773697,
  17.98717971642812,
  16.70512826244036,
  17.346154113610584,
  17.58653865257899,
  16.544871985912323])

## Run Inference on test_set

In [None]:
len(test_dl), train_size+val_size, length

(1, 1098, 1221)

In [None]:
with torch.no_grad():
    for b in test_dl:
        text, img = b
        print(text.shape, img.shape)
        text_features = model.encode_text(text)
        image_features = model.encode_image(img)

        accuracy = get_accuracy(text_features, image_features)

        # clip_score = torch.matmul(image_features, text_features.T)
        print(f'Accuracy: {accuracy:.2f} %')

torch.Size([123, 77]) torch.Size([123, 3, 224, 224])
Accuracy: 36.59 %
