In [1]:
import os

In [2]:
import json

In [3]:
import torch

In [4]:
import logging

In [5]:
from typing import List, Dict, Optional

In [6]:
from PIL import Image

In [7]:
from tqdm import tqdm

In [8]:
from torch.cuda.amp import autocast,GradScaler

In [9]:
from datasets import load_dataset

In [10]:
from text_image_token_processor_1 import PaliGemmaProcessor

In [18]:
from torch.utils.data import Dataset, DataLoader

In [11]:
from decoder_1 import PaliGemmaForConditionalGeneration

In [12]:
from utils import load_hf_model

In [13]:
# Configure the root logger so INFO-level (and higher) records are shown in cell output.
logging.basicConfig(level=logging.INFO)

In [14]:
# Module-level logger used by main() for per-epoch loss messages.
# In this notebook __name__ resolves to "__main__" (see the log line printed below).
logger = logging.getLogger(__name__)

In [16]:
# Smoke test: confirms the logging configuration above actually emits output.
logger.info('dddd')

INFO:__main__:dddd


In [17]:
# List the working directory; the local modules imported above
# (decoder_1.py, text_image_token_processor_1.py, utils.py) should appear here.
!ls

[1m[36m__pycache__[m[m                     multi_modality.ipynb
decoder_1.py                    simple_model_state_dict.pth
final_data.json                 text_image_token_processor_1.py
fintune.ipynb                   tokens.json
gpt_from_scratch.ipynb          translate_from_scratch.ipynb
hg_learn.ipynb                  utils.py
inference.ipynb                 vision_model.ipynb
[1m[36mmarian-finetuned-kde4-en-to-fr[m[m  vision_transformer_1.py
mml_decoder.ipynb


In [20]:
class MultiModalDataset(Dataset):
    """Food-101 images paired with a short caption, preprocessed for the model.

    Every item is built from one row of the Hugging Face ``food101`` dataset:
    the row's image plus the caption ``"this is a picture of <class name>"``
    are passed through ``processor`` and the resulting ``input_ids``,
    ``attention_mask`` and ``pixel_values`` are returned with their leading
    (batch) dimension squeezed away.
    """

    def __init__(self,
                 processor: "PaliGemmaProcessor",
                 split: str = 'train',
                 max_length: int = 512,
                 max_samples: Optional[int] = None
                 ):
        """Load the requested split and optionally keep only its first rows.

        Args:
            processor: Callable invoked as
                ``processor(text=[...], images=[...], padding=..., truncation=...)``.
            split: Name of the ``food101`` split to load.
            max_length: Stored on the instance.
                NOTE(review): it is not forwarded to the processor call;
                confirm whether the processor accepts a length argument.
            max_samples: If truthy, keep at most this many leading rows.
                ``None`` or ``0`` keeps the whole split.
        """
        self.dataset = load_dataset('food101', split=split)
        if max_samples:
            # Clamp so that asking for more rows than the split holds does not
            # make ``select`` fail on out-of-range indices.
            keep = min(max_samples, len(self.dataset))
            self.dataset = self.dataset.select(range(keep))
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows kept from the split."""
        return len(self.dataset)

    def _label_name(self, label):
        """Turn a row's ``label`` value into human-readable caption text.

        The ``label`` column holds integer class ids; when the dataset exposes
        a label feature with ``int2str`` the id is mapped to its class name
        (underscores become spaces). Anything else is formatted as-is.
        """
        features = getattr(self.dataset, 'features', None) or {}
        label_feature = features.get('label')
        if isinstance(label, int) and hasattr(label_feature, 'int2str'):
            return label_feature.int2str(label).replace('_', ' ')
        return str(label)

    def __getitem__(self, idx):
        """Return the processed tensors for row ``idx``.

        Returns:
            dict with keys ``input_ids``, ``attention_mask`` and
            ``pixel_values``, each taken from the processor output with
            ``squeeze(0)`` applied.
        """
        item = self.dataset[idx]
        image = item['image']
        # NOTE(review): presumably the processor expects 3-channel input;
        # converting grayscale/CMYK/RGBA images avoids shape errors there.
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        text = f'this is a picture of {self._label_name(item["label"])}'

        inputs = self.processor(
            text=[text],
            images=[image],
            padding="max_length",
            truncation=True
        )

        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'pixel_values': inputs['pixel_values'].squeeze(0)
        }

In [23]:
class Trainer:
    """Minimal training/evaluation loop for a causal-LM style objective.

    The model is called with ``input_ids``, ``pixel_values`` and
    ``attention_mask`` and must return a mapping containing ``'logits'``.
    The loss is next-token cross entropy between ``logits[:, :-1]`` and
    ``input_ids[:, 1:]``; target positions whose attention mask is 0 are
    replaced by ``-100`` so they are ignored.

    Supports gradient accumulation, gradient-norm clipping and (on CUDA only)
    mixed precision through ``autocast`` / ``GradScaler``.
    """

    def __init__(self,
                 model: "PaliGemmaForConditionalGeneration",
                 train_dataloader: DataLoader,
                 val_dataloader: Optional[DataLoader],
                 optimizer: torch.optim.Optimizer,
                 device: str,
                 gradient_accumulation_steps: int = 1,
                 max_grad_norm: float = 1.0,
                 use_amp: bool = True,
                 ):
        """Store the training components.

        Args:
            model: Module to optimise.
            train_dataloader: Iterable of batches (dicts of tensors) with a length.
            val_dataloader: Optional iterable of validation batches.
            optimizer: Optimizer already bound to ``model``'s parameters.
            device: Device every batch tensor is moved to.
            gradient_accumulation_steps: Batches accumulated per optimizer
                step; values below 1 are treated as 1.
            max_grad_norm: Maximum gradient norm passed to ``clip_grad_norm_``.
            use_amp: Request mixed precision. It is only honoured when CUDA is
                available and ``device`` names a CUDA device, because the
                ``autocast`` / ``GradScaler`` in scope are the CUDA variants.
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optimizer
        self.device = device
        self.gradient_accumulation_steps = max(1, int(gradient_accumulation_steps))
        self.max_grad_norm = max_grad_norm
        self.use_amp = (
            bool(use_amp)
            and torch.cuda.is_available()
            and str(device).startswith('cuda')
        )

        self.scaler = GradScaler() if self.use_amp else None

    def _to_device(self, batch):
        """Move every tensor of ``batch`` to ``self.device``."""
        return {k: v.to(self.device) for k, v in batch.items()}

    def _masked_lm_loss(self, batch):
        """Run the model on ``batch`` and return the shifted cross-entropy loss."""
        outputs = self.model(
            input_ids=batch['input_ids'],
            pixel_values=batch['pixel_values'],
            attention_mask=batch['attention_mask']
        )

        shift_logits = outputs['logits'][:, :-1, :]
        shift_labels = batch['input_ids'][..., 1:]
        # Without this, ``ignore_index`` below never matches anything and the
        # loss is averaged over padding positions as well.
        shift_mask = batch['attention_mask'][..., 1:]
        shift_labels = shift_labels.masked_fill(shift_mask == 0, -100)

        return torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100
        )

    def _compute_loss(self, batch):
        """Compute the batch loss, under ``autocast`` when AMP is active."""
        if self.use_amp:
            with autocast():
                return self._masked_lm_loss(batch)
        return self._masked_lm_loss(batch)

    def _optimizer_step(self):
        """Clip gradients, apply one optimizer update and reset gradients."""
        if self.use_amp:
            # Gradients must be unscaled before their norm is measured.
            self.scaler.unscale_(self.optimizer)

        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        if self.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.zero_grad()

    def train_epoch(self, epoch: int):
        """Train for one pass over ``train_dataloader``.

        An optimizer step is taken every ``gradient_accumulation_steps``
        batches and also after the final batch, so a trailing partial window
        is not silently dropped. Each batch loss is divided by the size of
        the window it belongs to before ``backward``.

        Args:
            epoch: Epoch index, used only for the progress-bar description.

        Returns:
            Mean (unscaled) batch loss over the epoch, or ``0.0`` when the
            dataloader is empty.
        """
        self.model.train()
        accum = self.gradient_accumulation_steps
        num_batches = len(self.train_dataloader)
        total_loss = 0.0

        # Start from clean gradients in case a previous run left some behind.
        self.optimizer.zero_grad()

        pbar = tqdm(self.train_dataloader, desc=f"training epoch {epoch}")

        for step, batch in enumerate(pbar):
            batch = self._to_device(batch)

            # Size of the accumulation window this batch belongs to; the last
            # window may be shorter than ``accum``.
            window_start = (step // accum) * accum
            window_size = max(1, min(accum, num_batches - window_start))

            loss = self._compute_loss(batch)
            scaled_loss = loss / window_size

            if self.use_amp:
                self.scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            if (step + 1) % accum == 0 or (step + 1) == num_batches:
                self._optimizer_step()

            # Track the unscaled loss so the reported value does not depend on
            # the accumulation factor.
            total_loss += loss.item()
            pbar.set_postfix({'loss': total_loss / (step + 1)})

        return total_loss / num_batches if num_batches else 0.0

    @torch.no_grad()
    def evaluate(self):
        """Compute the mean validation loss.

        Returns:
            Mean batch loss over ``val_dataloader``, or ``None`` when no
            validation dataloader was given or it is empty.
        """
        if self.val_dataloader is None or len(self.val_dataloader) == 0:
            return None

        self.model.eval()
        total_loss = 0.0

        for batch in tqdm(self.val_dataloader, desc='Evaluating'):
            batch = self._to_device(batch)
            total_loss += self._compute_loss(batch).item()

        return total_loss / len(self.val_dataloader)

In [26]:
def _save_model(model, directory):
    """Write ``model`` into ``directory``, creating the directory if needed.

    Uses ``model.save_pretrained`` when the model provides it; otherwise the
    ``state_dict`` is written to ``pytorch_model.bin`` with ``torch.save``.
    NOTE(review): the model class comes from the project-local ``decoder_1``
    module, which is not visible here; confirm which of the two paths applies.
    """
    os.makedirs(directory, exist_ok=True)
    if hasattr(model, 'save_pretrained'):
        model.save_pretrained(directory)
    else:
        torch.save(model.state_dict(), os.path.join(directory, 'pytorch_model.bin'))


def main():
    """Fine-tune the model on a small Food-101 subset and save checkpoints.

    Steps: load the model and tokenizer, build the processor, build the train
    and validation datasets/dataloaders, then for each epoch train, optionally
    evaluate (saving ``best_model`` when the validation loss improves) and
    save a ``checkpoint-<epoch>`` directory.

    The model location can be overridden with the ``PALIGEMMA_MODEL_PATH``
    environment variable; the previous hard-coded path remains the default.
    """

    # Training configuration
    config = {
        'model_path': os.environ.get(
            'PALIGEMMA_MODEL_PATH',
            "/Users/liuchu/vision-launguage-model-from-scratch/paligemma-3b-pt-224",
        ),
        # NOTE(review): 'train_json', 'val_json' and 'image_dir' are not read
        # anywhere in this function; kept for compatibility.
        'train_json': "data/train.json",
        'val_json': "data/val.json",
        'image_dir': "data/images/",
        'output_dir': "trained_model/",
        'batch_size': 4,
        'gradient_accumulation_steps': 4,
        'learning_rate': 5e-5,
        'weight_decay': 0.01,
        'num_epochs': 1,
        'use_amp': True,
        'max_grad_norm': 1.0,
    }

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, tokenizer = load_hf_model(config['model_path'], device)
    processor = PaliGemmaProcessor(
        tokenizer,
        num_image_tokens=model.config.vision_config.num_image_tokens,
        image_size=model.config.vision_config.image_size
    )

    train_dataset = MultiModalDataset(
        processor=processor,
        split='train',
        max_samples=10
    )

    val_dataset = MultiModalDataset(
        processor=processor,
        split='validation',
        max_samples=1
    )

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just produces a warning.
    pin_memory = device == 'cuda'

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory
    ) if val_dataset else None

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        optimizer=optimizer,
        device=device,
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        use_amp=config['use_amp'],
        max_grad_norm=config['max_grad_norm']
    )

    best_val_loss = float('inf')
    for epoch in range(config['num_epochs']):
        train_loss = trainer.train_epoch(epoch)
        logger.info(f'Epoch {epoch} - Train loss: {train_loss:.4f}')

        if val_dataloader:
            val_loss = trainer.evaluate()
            # evaluate() may return None when there is nothing to evaluate.
            if val_loss is not None:
                logger.info(f'Epoch {epoch} - Val loss: {val_loss:.4f}')

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    _save_model(model, os.path.join(config['output_dir'], 'best_model'))

        # Save a checkpoint every epoch, whether or not validation ran.
        checkpoint_dir = os.path.join(config['output_dir'],
                                      f'checkpoint-{epoch}')
        _save_model(model, checkpoint_dir)
            

In [27]:
# Run the fine-tuning pipeline. MultiModalDataset calls load_dataset('food101'),
# which contacts the Hugging Face Hub (see the ConnectTimeout in the output below
# when huggingface.co is unreachable).
main()

ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/datasets/food101/revision/e06acf2a88084f04bce4d4a525165d68e0a36c38 (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x37f69c2b0>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: d9778e83-3942-48d6-b645-c6d053c4872c)')