# Prepare data

In [17]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before each cell runs,
# so edits to local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
## Prepare data for export and training in Colab

import pandas as pd
# Read the full catalogue export; the trailing expression shows (rows, columns).
catalogue_csv = 'data/heritage_weaver_data.csv'
df = pd.read_csv(catalogue_csv)
df.shape

(31153, 12)

In [12]:
!mkdir data/siglip-training

mkdir: data/siglip-training: File exists


In [13]:

!mkdir data/siglip-training/images

In [14]:
# Unique image paths that actually point at a file.
# Empty CSV cells are read by pandas as NaN, and NaN != '' is True, so missing
# values must be dropped explicitly before filtering out empty strings.
img_paths = df['img_path'].dropna()
downloaded_images = list(img_paths[img_paths != ''].unique())
len(downloaded_images)

30130

In [15]:
# Sanity check: show the first ten paths together with the total count.
(downloaded_images[:10], len(downloaded_images))

(['smg_imgs/204|255|medium_cd0620_049_100527_2005_86_35_Professional_audio_cassette_tape_used_by_Radio_Manchester.jpg',
  'smg_imgs/477|975|medium_SMG00247371.jpg',
  'smg_imgs/58|255|medium_1982_1712__0001_.jpg',
  'smg_imgs/58|256|medium_1982_1712__0002_.jpg',
  'smg_imgs/58|257|medium_1982_1712__0003_.jpg',
  'smg_imgs/212|509|medium_cd0404_011_080808_2002_19_100_LCM_Speed_recorder_opened.jpg',
  'smg_imgs/212|510|medium_cd0404_012_080808_2002_19_100_LCM_Speed_recorder.jpg',
  'smg_imgs/247|329|medium_cd0098_006_050329_GG_1991_126_11_Light_shade.jpg',
  'smg_imgs/209|376|medium_cd0472_015_081216_1996_10_507_Ferranti_9E_Radio_component.jpg',
  'smg_imgs/247|310|medium_cd0097_009_050329_GG_1990_25_3_Gas_light.jpg'],
 30130)

In [16]:
# Copy every downloaded image into the flat data/siglip-training/images folder
import shutil
import os
from tqdm import tqdm

target_dir = 'data/siglip-training/images'
# Create the destination here as well, so this cell does not rely on the mkdir cells having succeeded.
os.makedirs(target_dir, exist_ok=True)

# NOTE: files are flattened to their base name, so two source paths that share
# a file name would silently overwrite each other in the destination folder.
for img in tqdm(downloaded_images):
    shutil.copy(img, os.path.join(target_dir, os.path.basename(img)))

100%|██████████| 30130/30130 [01:37<00:00, 307.80it/s]


In [18]:
import json
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader
from transformers import SiglipProcessor, SiglipModel

In [19]:
# Make sure the output folder for fine-tuned checkpoints exists (no error if it already does)
models_dir = Path('models')
models_dir.mkdir(exist_ok=True)

In [None]:
# Load SigLIP model and processor
checkpoint = "google/siglip-base-patch16-224"

# `device` is read by this cell, by the dataset's __getitem__ and by the training loop,
# but was never defined; pick the GPU when one is available and fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = SiglipProcessor.from_pretrained(checkpoint)
model = SiglipModel.from_pretrained(checkpoint)
model.to(device)



In [None]:
# Load the catalogue from the Colab working directory and keep only rows that have
# a name, a description and an image path.
df = pd.read_csv('/content/heritage_weaver_data.csv', index_col=0)
df = df[['name', 'description', 'img_path']].dropna().reset_index(drop=True)
# Only the file name of img_path is kept: images are looked up in the flat /content/images/ folder.
df['filepath'] = df['img_path'].apply(lambda x: '/content/images/' + x.split('/')[-1])
# Flag rows whose image file is actually present on disk.
df['downloaded'] = df['filepath'].apply(lambda x: Path(x).is_file()).astype(bool)
# Shuffle with a fixed seed so the train/eval split derived from this order is reproducible.
df = df[df['downloaded']].sample(frac=1.0, random_state=42).reset_index(drop=True)
df.shape

In [None]:
# 90/10 train/eval split on the (already shuffled) frame; each split is also
# exported as a tab-separated file holding the image path and its caption.
threshold = int(len(df)*.9)
df_train, df_eval = df.iloc[:threshold], df.iloc[threshold:]
for split_frame, split_file in ((df_train, 'train.csv'), (df_eval, 'eval.csv')):
    split_frame[['filepath','name']].to_csv(split_file, sep='\t') # name | description
(df_train.shape, df_eval.shape)

In [None]:
# Define a custom dataset
class image_title_dataset():
    """Map-style dataset pairing each image with one text column of a DataFrame.

    Each item is a ``(input_ids, pixel_values)`` tuple produced by the
    module-level ``processor`` and moved to the module-level ``device``.

    Args:
        df: frame with a ``filepath`` column and the text column named by ``column``.
        column: name of the text column used as the caption ('name' or 'description').
    """
    def __init__(self, df, column='name'):  # description | name
        # Initialize data
        self.df = df
        self.column = column

    def __len__(self):
        """Number of (image, text) pairs."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the processed ``(input_ids, pixel_values)`` pair for row ``idx``."""
        row = self.df.iloc[idx]
        # Open inside a context manager so the file handle is released right away,
        # and normalise to RGB so grayscale / RGBA / CMYK files all reach the
        # processor as 3-channel images.
        with Image.open(row.filepath) as img_file:
            image = img_file.convert('RGB')
        # Preprocess text and image using SigLIP's preprocessing function;
        # text is padded/truncated to a fixed 64 tokens so items can be batched.
        processed = processor(text=[row[self.column]],
                              images=[image],
                              return_tensors="pt",
                              max_length=64,
                              padding='max_length', truncation=True)
        # pixel_values drops its leading dimension via squeeze(0); input_ids is returned as produced.
        return processed['input_ids'].to(device), processed['pixel_values'].squeeze(0).to(device)

In [None]:
# Wrap both splits in the custom dataset (default caption column).
dataset_train, dataset_eval = map(image_title_dataset, (df_train, df_eval))

In [None]:

# Prepare the optimizer
# The learning rate is kept small, which is safer when fine-tuning on a new dataset.
# NOTE(review): torch.optim.Adam applies weight_decay as an L2 penalty added to the gradient;
# if decoupled weight decay was intended, torch.optim.AdamW is the matching optimizer — confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(dataset_train, batch_size=32, shuffle=True)
# Evaluation batches are kept fixed: the contrastive loss depends on which items share a
# batch, so shuffling here would make eval losses from different epochs incomparable.
eval_dataloader = DataLoader(dataset_eval, batch_size=32, shuffle=False)

# Specify the loss function (one cross-entropy per direction: image->text and text->image)
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()


In [None]:
# handy utility at https://github.com/wenwei202/pytorch-examples/blob/ecbb7beb0fac13133c0b09ef980caf002969d315/imagenet/main.py#L296
class AverageMeter(object):
    """Tracks the latest value of a metric together with its running weighted mean."""

    def __init__(self):
        # A freshly built meter starts from the cleared state.
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted sum and sample count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, counted as `n` samples, and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
# Train the model
# Each epoch runs one pass over the training loader with optimizer updates, followed by
# one gradient-free pass over the eval loader. The loss is the mean of two cross-entropies
# (image->text and text->image) where the matching pair of item i sits at index i.
lowest_loss = 999
num_epochs = 10

losses_train = AverageMeter()
losses_eval = AverageMeter()

for epoch in range(num_epochs):
    # Start every epoch with cleared meters so the printed running averages
    # describe this epoch only instead of blending all previous epochs.
    losses_train.reset()
    losses_eval.reset()

    model.train()

    for idx, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()

        texts, images = batch

        # Forward pass
        output = model(texts, images)
        # Compute loss: targets are 0..batch-1, i.e. the diagonal of the similarity matrix
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(output.logits_per_image, ground_truth) + loss_txt(output.logits_per_text, ground_truth)) / 2
        losses_train.update(total_loss.item(), len(images))
        # Backward pass
        total_loss.backward()
        optimizer.step()

        if idx % 10 == 0:
            print('Epoch: [{0}]\t'
                  'Training Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch, loss=losses_train,))

    model.eval()

    # No gradients are needed for evaluation; disabling autograd avoids building
    # a computation graph for every eval batch.
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(eval_dataloader)):
            texts, images = batch
            # Forward pass
            output = model(texts, images)
            # Compute loss
            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
            total_loss = (loss_img(output.logits_per_image, ground_truth) + loss_txt(output.logits_per_text, ground_truth)) / 2
            losses_eval.update(total_loss.item(), len(images))

            if idx % 10 == 0:
                print('Epoch: [{0}]\t'
                      'Eval Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                          epoch, loss=losses_eval))

    # Remember the best epoch-level eval loss seen so far (skipped when the eval loader is empty).
    if losses_eval.count:
        lowest_loss = min(lowest_loss, losses_eval.avg)
    

In [None]:
# Log in to the Hugging Face Hub from the shell; presumably needed before the push_to_hub call further down.
!huggingface-cli login

In [None]:
# Persist the weights from the final epoch under ./models/, named after the base checkpoint.
last_checkpoint_dir = f'./models/{checkpoint}-ft-last'
model.save_pretrained(last_checkpoint_dir)

In [None]:
# Upload the fine-tuned model to the Hugging Face Hub under the given repo id.
# NOTE(review): the repo id says "clip" while the checkpoint loaded in this notebook is SigLIP — confirm the name is intended.
model.push_to_hub("Kaspar/clip-heritage-weaver-name")