In [None]:
# IPython autoreload: re-import edited modules (e.g. the local weavingtools package)
# before each cell executes, so changes to .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
from weavingtools.weaving_tools import *
import pandas as pd
import json
import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel

# Local output directory; fine-tuned checkpoints are saved beneath ./models/ later on.
models_dir = Path('models')
models_dir.mkdir(exist_ok=True)

In [None]:
# Define a custom dataset
class image_title_dataset(torch.utils.data.Dataset):
    """Map-style dataset of (title, image) pairs for CLIP fine-tuning.

    Each item is encoded with the notebook-level CLIP ``processor`` and moved to
    the notebook-level ``device``. Both are defined in later cells, so they must
    exist before the first item is fetched (e.g. when a DataLoader is iterated).

    Args:
        df: DataFrame with a ``name`` column (caption text) and an ``img_path``
            column (path to an image file that PIL can open).
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Encode row ``idx`` and return ``(input_ids, pixel_values)``.

        Order is text first, image second. ``input_ids`` keeps the processor's
        leading batch axis (one caption padded/truncated to 77 tokens), while
        ``pixel_values`` has its leading batch axis squeezed away.
        """
        row = self.df.iloc[idx]
        # The context manager closes the underlying file as soon as the processor
        # has read the pixels, instead of leaving it to the garbage collector.
        with Image.open(row['img_path']) as img:
            processed = processor(text=[row['name']], images=[img],
                                  return_tensors="pt",
                                  max_length=77, padding='max_length', truncation=True)
        # NOTE: tensors are moved to `device` here, so the DataLoader should load in
        # the main process (default num_workers=0).
        return processed['input_ids'].to(device), processed['pixel_values'].squeeze(0).to(device)

In [None]:
# Fixed seed so the shuffle below -- and therefore the train/eval split and the
# exported train.csv / eval.csv -- is identical on every run.
SEED = 42

df = pd.read_csv('data/communications_data_updated.csv')
# Keep only rows that have both a caption and an image path.
df = df[['name', 'img_path']].dropna().reset_index(drop=True)
# Same file name re-rooted under /content/img_data/ (looks like a Colab layout --
# confirm); this column is only written to the exported split files.
df['filepath'] = df['img_path'].apply(lambda x: '/content/img_data/' + x.split('/')[-1])
# Shuffle before the positional 90/10 split in the next cell.
df = df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)

In [None]:
# Positional 90/10 split of the already-shuffled frame. The (filepath, name) pairs
# of each split are exported too -- note these files are tab-separated despite the
# .csv extension, and include the index column.
threshold = int(len(df) * .9)
df_train, df_eval = df.iloc[:threshold], df.iloc[threshold:]
export_cols = ['filepath', 'name']
df_train[export_cols].to_csv('train.csv', sep='\t')
df_eval[export_cols].to_csv('eval.csv', sep='\t')

In [None]:
# from pathlib import Path
# imgs_path = Path('data/img_data')
# imgs_path.mkdir(exist_ok=True)
# for i,row in tqdm(df.iterrows(), total=len(df)):
#     img = Image.open(row['img_path'])
#     img.save(imgs_path / row['img_name'])

In [None]:
# Choose computation device: NVIDIA GPU if present, else Apple-silicon MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(device)

In [None]:
# Load the pretrained CLIP model and its matching processor (tokenizer + image
# preprocessing) for the same checkpoint. The model is moved to `device` as part of
# the assignment, so the cell does not echo the full model repr as output.
checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(checkpoint).to(device)
processor = CLIPProcessor.from_pretrained(checkpoint)

In [None]:
BATCH_SIZE = 32

dataset_train = image_title_dataset(df_train)
dataset_eval = image_title_dataset(df_eval)

# Training batches are reshuffled every epoch. The eval order is kept fixed: the loss
# is in-batch contrastive (it depends on which pairs share a batch), so a fixed order
# makes validation losses comparable across epochs when picking the best checkpoint.
# num_workers stays at its default (0) because the dataset calls .to(device) itself.
train_dataloader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
eval_dataloader = DataLoader(dataset_eval, batch_size=BATCH_SIZE, shuffle=False)


In [None]:

# Prepare the optimizer.
# AdamW applies weight decay decoupled from the adaptive gradient step (as CLIP was
# trained). Plain Adam instead folds weight_decay=0.2 into the gradient as an L2 term,
# which pulls the pretrained weights strongly toward zero.
# Parameters with fewer than 2 dims (biases, LayerNorm gains, the scalar logit_scale,
# ...) are not decayed, following the CLIP recipe of not decaying gains/biases.
decay_params = [p for p in model.parameters() if p.ndim >= 2]
no_decay_params = [p for p in model.parameters() if p.ndim < 2]
optimizer = torch.optim.AdamW(
    [
        {'params': decay_params, 'weight_decay': 0.2},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ],
    lr=5e-5,  # small LR: safer when fine-tuning on a new dataset
    betas=(0.9, 0.98),
    eps=1e-6,
)

# Loss functions: one per direction of the symmetric CLIP loss (image->text, text->image).
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()


In [None]:
# Train the model: fine-tune CLIP with the symmetric contrastive loss, keep the
# checkpoint with the lowest validation loss, and also save the final weights.
num_epochs = 20
lowest_loss = float('inf')  # best validation loss so far; any finite loss beats it


def clip_contrastive_loss(output):
    """Symmetric CLIP loss for one batch.

    The i-th image is the positive for the i-th text (and vice versa), so the
    target class for row i of both similarity matrices is i.

    Args:
        output: CLIPModel forward output with ``logits_per_image`` and
            ``logits_per_text``.

    Returns:
        Scalar tensor: mean of the image->text and text->image cross-entropies.
    """
    logits = output.logits_per_image
    ground_truth = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
    return (loss_img(output.logits_per_image, ground_truth)
            + loss_txt(output.logits_per_text, ground_truth)) / 2


def run_clip_epoch(dataloader, training):
    """Run one pass over ``dataloader`` and return the mean per-batch loss.

    With ``training=True`` the model is in train mode and weights are updated.
    Otherwise it runs in eval mode with gradients disabled, so no autograd graph
    is built during validation.
    """
    model.train(training)
    running_loss = 0.0
    pbar = tqdm(dataloader, total=len(dataloader), desc='train' if training else 'eval')
    with torch.set_grad_enabled(training):
        for input_ids, pixel_values in pbar:
            # image_title_dataset yields (input_ids, pixel_values): text first, image second.
            output = model(input_ids=input_ids.to(device), pixel_values=pixel_values.to(device))
            total_loss = clip_contrastive_loss(output)
            if training:
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
            batch_loss = total_loss.item()
            running_loss += batch_loss
            pbar.set_postfix(loss=f'{batch_loss:.4f}')
    return running_loss / len(dataloader)


for epoch in range(num_epochs):
    train_loss = run_clip_epoch(train_dataloader, training=True)
    current_loss = run_clip_epoch(eval_dataloader, training=False)

    if current_loss < lowest_loss:
        model.save_pretrained(f'./models/{checkpoint}-ft')
        lowest_loss = current_loss

    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f} Validation Loss {current_loss:.4f}")

model.save_pretrained(f'./models/{checkpoint}-ft-last')

In [None]:
# # Train the model
# num_epochs = 10
# for epoch in range(num_epochs):
#     running_loss = 0.0
#     pbar = tqdm(train_dataloader, total=len(train_dataloader))
#     for batch in pbar:
#         model.train()
#         optimizer.zero_grad()

#         images,texts = batch 
        
#         images= images.to(device)
#         texts = texts.to(device)

#         # Forward pass
#         #logits_per_image, logits_per_text = 
#         output = model(images, texts)
#         # Compute loss
#         ground_truth = torch.arange(len(images),dtype=torch.long,device=device)
#         total_loss = (loss_img(output.logits_per_image,ground_truth) + loss_txt(output.logits_per_text,ground_truth))/2

#         # Backward pass
#         total_loss.backward()
#         #if device == "cpu":
#         optimizer.step()
#         #else : 
#         #    convert_models_to_fp32(model)
#         #    optimizer.step()
#         #    clip.model.convert_weights(model)
#         running_loss += total_loss.item()
#         pbar.set_description()

    
#     print(f"Epoch {epoch}/{num_epochs}, Loss: {running_loss/len(train_dataloader):.4f}")