In [13]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from safetensors.torch import save_file
import os
from pytorch_lightning.callbacks import ModelCheckpoint

In [14]:
# Dataset class to handle CSV and image data
class CustomDataset(Dataset):
    def __init__(self, dataframe, processor, img_dir="AGIQA-3K"):
        self.data = dataframe
        self.img_dir = img_dir
        self.processor = processor    
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image_path = f"{self.img_dir}/{row['name']}"
        text = row['prompt']
        for column in ["adj1", "adj2", "style"]:
            if row[column] and not pd.isna(row[column]) and len(row[column].strip()) > 0:
                text += f", {row[column]}"
        label = (row[['mos_quality', 'mos_align']] / 5).tolist()

        # print(text, type(text))
        # print(label, type(label))
        # print(image_path, type(image_path))
        with Image.open(image_path) as image:
            image = image.convert("RGB")
            inputs = self.processor(
                text=[text], 
                images=image, 
                return_tensors="pt", 
                padding="max_length", 
                max_length=self.processor.tokenizer.model_max_length, 
                truncation=True
            )
        inputs['label'] = torch.tensor(label, dtype=torch.float)
        # print(inputs)
         
        return inputs

# LightningModule class for the training loop
class CLIPRegressionModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.clip_model.eval()  # Freeze CLIP model weights
        for param in self.clip_model.parameters():
            param.requires_grad = False
        
        self.text_projection = nn.Linear(self.clip_model.config.projection_dim, 512)
        self.image_projection = nn.Linear(self.clip_model.config.projection_dim, 512)
        self.norm_text_layer = nn.LayerNorm(512)
        self.norm_image_layer = nn.LayerNorm(512)
        self.final_layer = nn.Linear(512, 2)
        self.criterion = nn.L1Loss()
        self.learning_rate = learning_rate
    
    def forward(self, text_inputs, image_inputs):
        with torch.no_grad():
            text_features = self.clip_model.get_text_features(**text_inputs)
            image_features = self.clip_model.get_image_features(**image_inputs)
        
        text_proj = self.text_projection(text_features)
        image_proj = self.image_projection(image_features)
        dot_product = self.norm_text_layer(text_proj) * self.norm_image_layer(image_proj)
        output = F.sigmoid(self.final_layer(dot_product))
        
        return output
    
    def training_step(self, batch, batch_idx):
        text_inputs = {key: val.squeeze(1) for key, val in batch.items() if 'input_ids' in key}
        image_inputs = {key: val.squeeze(1) for key, val in batch.items() if 'pixel_values' in key}
        labels = batch['label']
        
        outputs = self(text_inputs, image_inputs)
        loss = self.criterion(outputs, labels)
        
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        text_inputs = {key: val.squeeze(1) for key, val in batch.items() if 'input_ids' in key}
        image_inputs = {key: val.squeeze(1) for key, val in batch.items() if 'pixel_values' in key}
        labels = batch['label']
        
        outputs = self(text_inputs, image_inputs)
        loss = self.criterion(outputs, labels)
        
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    
# Main training script
def train_model(csv_file, img_dir, batch_size=128, max_epochs=30):
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
    dataframe = pd.read_csv(csv_file)    
    train_df, test_df = train_test_split(dataframe, test_size=0.1, random_state=42)
    
    train_dataset = CustomDataset(train_df, processor, img_dir=img_dir)
    test_dataset = CustomDataset(test_df, processor, img_dir=img_dir)
    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=4)
    
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        mode='min',
        save_top_k=3,
        filename='{epoch:02d}-{val_loss:.3f}',
        verbose=True
    )

    model = CLIPRegressionModel()
    trainer = pl.Trainer(max_epochs=max_epochs, devices=[0], accelerator="gpu", callbacks=[checkpoint_callback])
    trainer.fit(model, train_dataloader, test_dataloader)
    return checkpoint_callback

In [19]:
## Run train
csv_file = "AGIQA-3k-Database/data.csv"
img_dir = "AGIQA-3K"
checkpoint_callback = train_model(csv_file, img_dir)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name             | Type      | Params
-----------------------------------------------
0 | clip_model       | CLIPModel | 427 M 
1 | text_projection  | Linear    | 393 K 
2 | image_projection | Linear    | 393 K 
3 | norm_text_layer  | LayerNorm | 1.0 K 
4 | norm_image_layer | LayerNorm | 1.0 K 
5 | final_layer      | Linear    | 1.0 K 
6 | criterion        | L1Loss    | 0     
-----------------------------------------------
790 K     Trainable params
427 M     Non-trainable params
428 M     Total params
1,713.628 Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (21) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 0, global step 21: 'val_loss' reached 0.10066 (best 0.10066), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=00-val_loss=0.101.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 1, global step 42: 'val_loss' reached 0.08894 (best 0.08894), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=01-val_loss=0.089.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 63: 'val_loss' reached 0.08378 (best 0.08378), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=02-val_loss=0.084.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 84: 'val_loss' reached 0.08245 (best 0.08245), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=03-val_loss=0.082.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 105: 'val_loss' reached 0.08157 (best 0.08157), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=04-val_loss=0.082.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 126: 'val_loss' reached 0.08075 (best 0.08075), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=05-val_loss=0.081.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 147: 'val_loss' reached 0.07913 (best 0.07913), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=06-val_loss=0.079.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 168: 'val_loss' reached 0.07881 (best 0.07881), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=07-val_loss=0.079.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 189: 'val_loss' reached 0.07722 (best 0.07722), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=08-val_loss=0.077.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 210: 'val_loss' reached 0.07645 (best 0.07645), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=09-val_loss=0.076.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 231: 'val_loss' reached 0.07671 (best 0.07645), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=10-val_loss=0.077.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 252: 'val_loss' reached 0.07637 (best 0.07637), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=11-val_loss=0.076.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 273: 'val_loss' reached 0.07528 (best 0.07528), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=12-val_loss=0.075.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 294: 'val_loss' reached 0.07595 (best 0.07528), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=13-val_loss=0.076.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 315: 'val_loss' reached 0.07477 (best 0.07477), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=14-val_loss=0.075.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 336: 'val_loss' reached 0.07513 (best 0.07477), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=15-val_loss=0.075.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 357: 'val_loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 17, global step 378: 'val_loss' reached 0.07436 (best 0.07436), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=17-val_loss=0.074.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 18, global step 399: 'val_loss' reached 0.07474 (best 0.07436), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=18-val_loss=0.075.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 19, global step 420: 'val_loss' reached 0.07399 (best 0.07399), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=19-val_loss=0.074.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 441: 'val_loss' reached 0.07396 (best 0.07396), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=20-val_loss=0.074.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 462: 'val_loss' reached 0.07382 (best 0.07382), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=21-val_loss=0.074.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 483: 'val_loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 504: 'val_loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 525: 'val_loss' reached 0.07356 (best 0.07356), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=24-val_loss=0.074.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 546: 'val_loss' reached 0.07318 (best 0.07318), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=25-val_loss=0.073.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 567: 'val_loss' reached 0.07303 (best 0.07303), saving model to '/home/23_Train_Aesthetics_Model/lightning_logs/version_4/checkpoints/epoch=26-val_loss=0.073.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 588: 'val_loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 609: 'val_loss' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 29, global step 630: 'val_loss' was not in top 3
`Trainer.fit` stopped: `max_epochs=30` reached.


In [17]:
checkpoint_callback.best_model_path

'/home/23_Train_Aesthetics_Model/lightning_logs/version_3/checkpoints/epoch=29-val_loss=0.071.ckpt'

In [18]:
# Find the best checkpoint and save it as safetensors
checkpoint = torch.load(checkpoint_callback.best_model_path, map_location=torch.device('cpu'))
model_state_dict = checkpoint['state_dict']
save_file(model_state_dict, "agiqa_3k_clip_vitl14.safetensors")