# Imports

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoProcessor, set_seed, AutoModelWithLMHead
from transformers import T5Tokenizer, T5ForConditionalGeneration
import pandas as pd
from PIL import Image
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import requests
from torchvision import transforms
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, random_split

# Load the encoder, the decoder, and the device (CPU or GPU)

In [3]:
# Florence-2 is loaded through the causal-LM auto class; the rest of the
# notebook only calls its `_encode_image` method to get image features.
image_model_name = 'microsoft/Florence-2-base-ft'
image_encoder, processor = (
    AutoModelForCausalLM.from_pretrained(image_model_name, trust_remote_code=True),
    AutoProcessor.from_pretrained(image_model_name, trust_remote_code=True),
)

In [4]:
text_model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
text_decoder = AutoModelForCausalLM.from_pretrained(text_model_name)
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)

# Add a dedicated padding token so captions can be padded to a fixed length,
# and record its id on the model config so downstream code (e.g. loss masking
# in `train_model`) can discover it from the model alone.
text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
text_decoder.config.pad_token_id = text_tokenizer.pad_token_id
# Grow the embedding matrix to cover the newly added token (last expression,
# so the resized embedding module is displayed).
text_decoder.resize_token_embeddings(len(text_tokenizer))

Embedding(50258, 768)

In [5]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

# Functions

In [6]:
class ImageCaptionDataset(Dataset):
    """Pairs image files on disk with tokenized captions taken from a DataFrame.

    Rows whose caption is missing (``notna()`` is False) are dropped at
    construction time. Each item is ``(image, caption_ids)``, where
    ``caption_ids`` is a 1-D tensor of token ids padded/truncated to
    ``max_length``.

    Args:
        dataframe: table with one row per image.
        caption_column: name of the column holding the caption text.
        image_column: name of the column holding the image file name,
            relative to ``image_dir``.
        image_dir: directory containing the image files.
        transform: optional callable applied to the RGB PIL image.
        tokenizer: tokenizer used to encode captions. When omitted, the
            notebook-level ``text_tokenizer`` is used.
        max_length: fixed caption length after padding/truncation.
    """

    def __init__(self, dataframe, caption_column, image_column, image_dir,
                 transform=None, tokenizer=None, max_length=1024):
        self.dataframe = dataframe[dataframe[caption_column].notna()]
        # Column names are kept on the instance so __getitem__ does not depend
        # on module-level variables that happen to share these names.
        self.caption_column = caption_column
        self.image_column = image_column
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_name = row[self.image_column]
        caption = row[self.caption_column]

        image_path = os.path.join(self.image_dir, image_name)
        # Context manager releases the file handle immediately; convert()
        # returns a loaded copy, so the image stays valid after the file closes.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        tokenizer = self.tokenizer if self.tokenizer is not None else text_tokenizer
        caption_ids = tokenizer.encode(
            caption,
            return_tensors='pt',
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
        )

        # encode(..., return_tensors='pt') yields a batch of one; drop that dim.
        return image, caption_ids.squeeze(0)

In [7]:
def _caption_batch_loss(image_encoder, text_decoder, projection_layer, images,
                        caption_ids, device, pad_token_id, encoder_grad):
    """Run one batch through encoder -> projection -> decoder and return the LM loss."""
    images = images.to(device)
    caption_ids = caption_ids.to(device)

    # Gradients reach the encoder only when it is being fine-tuned.
    with torch.set_grad_enabled(encoder_grad):
        image_features = image_encoder._encode_image(images)

    image_features = projection_layer(image_features.flatten(1))

    # NOTE(review): the same projected image vector is fed at every position and
    # the decoder never sees the caption tokens as input; a prefix-conditioning
    # scheme (image embedding followed by token embeddings) may train better.
    seq_length = caption_ids.size(1)
    image_features = image_features.unsqueeze(1).repeat(1, seq_length, 1)

    labels = caption_ids
    if pad_token_id is not None:
        # -100 is the ignore index of the Hugging Face causal-LM loss, so
        # padding positions no longer dominate the loss.
        labels = caption_ids.masked_fill(caption_ids == pad_token_id, -100)

    outputs = text_decoder(inputs_embeds=image_features, labels=labels)
    return outputs.loss


def train_model(
    image_encoder,
    text_decoder,
    projection_layer,
    optimizer,
    train_dataloader,
    val_dataloader,
    num_epochs,
    device,
    save_path="model.pth",
    train_encoder=False,
    pad_token_id=None,
):
    """Train the projection layer and text decoder on image/caption batches.

    After each epoch the validation loss is computed. Whenever it is the best
    seen so far, the decoder, projection and optimizer state dicts are saved to
    ``save_path``.

    Args:
        image_encoder: model exposing ``_encode_image(images)``.
        text_decoder: causal LM accepting ``inputs_embeds`` and ``labels``.
        projection_layer: maps flattened image features to the decoder width.
        optimizer: optimizer over the parameters to train. When
            ``train_encoder`` is True it must also include the encoder params.
        train_dataloader / val_dataloader: iterables of ``(images, caption_ids)``.
        num_epochs: number of training epochs.
        device: device batches are moved to.
        save_path: checkpoint file path.
        train_encoder: if True, the encoder gets gradients and runs in train mode.
        pad_token_id: label id to ignore in the loss. Defaults to
            ``text_decoder.config.pad_token_id`` when available.

    Returns:
        ``(train_losses, val_losses)``: per-epoch average losses.
    """
    if pad_token_id is None:
        pad_token_id = getattr(getattr(text_decoder, "config", None), "pad_token_id", None)

    train_losses = []
    val_losses = []
    best_val_loss = float("inf")

    for param in image_encoder.parameters():
        param.requires_grad = train_encoder

    for epoch in range(num_epochs):
        text_decoder.train()
        projection_layer.train()
        image_encoder.train(train_encoder)

        total_train_loss = 0.0
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} [Training]", leave=False)
        for images, caption_ids in progress_bar:
            loss = _caption_batch_loss(
                image_encoder, text_decoder, projection_layer, images, caption_ids,
                device, pad_token_id, encoder_grad=train_encoder,
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            progress_bar.set_postfix({"Training Loss": loss.item()})

        # max(..., 1) guards against an empty loader.
        avg_train_loss = total_train_loss / max(len(train_dataloader), 1)
        train_losses.append(avg_train_loss)

        text_decoder.eval()
        projection_layer.eval()
        image_encoder.eval()

        total_val_loss = 0.0
        with torch.no_grad():
            for images, caption_ids in val_dataloader:
                loss = _caption_batch_loss(
                    image_encoder, text_decoder, projection_layer, images, caption_ids,
                    device, pad_token_id, encoder_grad=False,
                )
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / max(len(val_dataloader), 1)
        val_losses.append(avg_val_loss)

        # Compare against the best loss seen *before* this epoch; lower is better.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'text_decoder_state_dict': text_decoder.state_dict(),
                'projection_layer_state_dict': projection_layer.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, save_path)
            print(f"Val Loss improved. Model is saved at {save_path}")

        print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    return train_losses, val_losses

# Training

In [8]:
# Resize every image to 224x224 and convert it to a float tensor.
# NOTE(review): no mean/std normalisation is applied; confirm this matches what
# the Florence-2 vision encoder expects (its `processor` is loaded but unused).
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Caption annotations (presumably Russian, per the `_ru` suffix) and the image folder.
dataframe = pd.read_csv("/kaggle/input/semart/semart_topic_annotated_train_ru.csv")
image_dir = "/kaggle/input/semart/SemArt/SemArt/Images"

In [9]:
# DataFrame columns holding the caption text and the image file name.
caption_column, image_column = 'ru_content', 'annotations/img'

In [10]:
dataset = ImageCaptionDataset(dataframe, caption_column, image_column, image_dir, transform=transform)

# 80/20 train/validation split. A seeded generator makes the split identical
# across kernel restarts; without it random_split uses the global RNG.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 2
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [11]:
# Probe the encoder once with a dummy image to discover the flattened feature
# size, then build a linear projection into the decoder's embedding width.
# The dummy tensor is created on the encoder's own device: the encoder is only
# moved to `device` in a later cell, so sending the probe to `device` here
# would cause a device mismatch on GPU.
encoder_device = next(image_encoder.parameters()).device
dummy_image = torch.randn(1, 3, 224, 224, device=encoder_device)
with torch.no_grad():
    dummy_features = image_encoder._encode_image(dummy_image)
input_dim = dummy_features.flatten(1).shape[-1]
hidden_size = text_decoder.config.hidden_size
projection_layer = torch.nn.Linear(input_dim, hidden_size).to(device)

In [12]:
# Only the decoder and projection parameters are handed to the optimizer;
# the image encoder's parameters are not included.
optimizer = torch.optim.AdamW(
    [*text_decoder.parameters(), *projection_layer.parameters()],
    lr=5e-5,
)
# NOTE(review): `criterion` is not used by `train_model`, which takes the loss
# from the decoder's own `labels=` path.
criterion = torch.nn.CrossEntropyLoss(ignore_index=text_tokenizer.pad_token_id)

In [13]:
num_epochs = 5

# Place every model component on the selected device before training.
for component in (image_encoder, text_decoder, projection_layer):
    component.to(device)

train_losses = []
val_losses = []

In [14]:
# Positional arguments follow train_model's signature order; the encoder stays frozen.
train_losses, val_losses = train_model(
    image_encoder, text_decoder, projection_layer, optimizer,
    train_dataloader, val_dataloader, num_epochs, device,
    save_path="florence_rugpt_model.pth",
    train_encoder=False,
)

                                                                                             

KeyboardInterrupt: 

In [None]:
# Plot per-epoch losses. The x-axis is derived from the recorded losses rather
# than `num_epochs`, so the cell still works if training stopped early.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss", marker="o")
ax.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss")
ax.legend()
ax.grid(True)
plt.show()

# Inference

In [None]:
# Inference section: reload Florence-2 (same checkpoint as used for training).
image_model_name = 'microsoft/Florence-2-base-ft'
image_encoder, processor = (
    AutoModelForCausalLM.from_pretrained(image_model_name, trust_remote_code=True),
    AutoProcessor.from_pretrained(image_model_name, trust_remote_code=True),
)

In [None]:
text_model_name = "sberbank-ai/rugpt3small_based_on_gpt2"
text_tokenizer = AutoTokenizer.from_pretrained(text_model_name)

# Re-add the padding token used during training so the vocabulary size
# matches the fine-tuned embedding matrix.
text_tokenizer.add_special_tokens(dict(pad_token='[PAD]'))

In [None]:
def encode_image(image_path):
    """Load an image from disk and return Florence-2 vision features for it.

    Preprocessing is the same as the training transform (resize to 224x224,
    then ``ToTensor``). The tensor is moved to the encoder's device, so this
    works whether the encoder lives on CPU or GPU.

    Args:
        image_path: path to an image file readable by PIL.

    Returns:
        The raw output of ``image_encoder._encode_image`` for a batch of one
        image (not flattened or projected).
    """
    # Context manager closes the file; convert() returns a loaded copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    encoder_device = next(image_encoder.parameters()).device
    image_tensor = transform(image).unsqueeze(0).to(encoder_device)

    with torch.no_grad():
        image_features = image_encoder._encode_image(image_tensor)

    return image_features


def generate_caption(image_features, max_length=1024, projection=None):
    """Generate a caption from pre-computed image features.

    The features are flattened and passed through the trained projection, the
    same way ``train_model`` prepares them. The resulting single embedding is
    used as the decoder prompt.

    Args:
        image_features: output of ``encode_image`` (a batch of one image).
        max_length: maximum number of new tokens to generate.
        projection: module mapping flattened features to the decoder width.
            Defaults to the notebook-level ``projection_layer`` whose weights
            are loaded from the checkpoint.

    Returns:
        The decoded caption of the first generated sequence.
    """
    # Use the trained projection; a freshly initialised layer would map the
    # features to random embeddings unrelated to what the decoder learned.
    if projection is None:
        projection = projection_layer

    with torch.no_grad():
        # (B, ...) -> flatten -> (B, input_dim) -> project -> (B, hidden) -> (B, 1, hidden)
        prefix_embeds = projection(image_features.flatten(1)).unsqueeze(1)
        # NOTE(review): training feeds this embedding repeated at every position,
        # while here it is a single-position prompt; outputs may degrade.
        output_ids = text_decoder.generate(
            inputs_embeds=prefix_embeds,
            max_new_tokens=max_length,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
            top_k=50,
            top_p=0.95,
            do_sample=True,
            temperature=0.7,
            pad_token_id=text_tokenizer.eos_token_id,
            bos_token_id=text_tokenizer.bos_token_id,
            eos_token_id=text_tokenizer.eos_token_id,
        )

    caption = text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Decoder embedding width (768, matching the resized embedding shown after
# loading the decoder in the training section).
hidden_size = 768

# Probe the encoder on its own device; it is moved to `device` only in a later
# cell, so sending the probe to `device` here would fail on GPU.
encoder_device = next(image_encoder.parameters()).device
dummy_image = torch.randn(1, 3, 224, 224, device=encoder_device)
with torch.no_grad():
    dummy_features = image_encoder._encode_image(dummy_image)
input_dim = dummy_features.flatten(1).shape[-1]

projection_layer = torch.nn.Linear(input_dim, hidden_size)

In [None]:
# Checkpoint written by `train_model`. torch.load unpickles the file, so only
# load checkpoints you produced yourself.
model_path = "/kaggle/working/florence_rugpt_model.pth"
checkpoint = torch.load(f=model_path, map_location=device)

In [None]:
text_decoder = AutoModelForCausalLM.from_pretrained(text_model_name)
# The checkpoint was saved after the [PAD] token was added, i.e. with an
# embedding matrix of size len(text_tokenizer). Resize the fresh model first,
# otherwise load_state_dict fails with a size mismatch.
text_decoder.resize_token_embeddings(len(text_tokenizer))
text_decoder.load_state_dict(checkpoint['text_decoder_state_dict'])
projection_layer.load_state_dict(checkpoint['projection_layer_state_dict'])

In [None]:
# Place all inference components on the selected device.
for component in (image_encoder, text_decoder):
    component.to(device)
projection_layer.to(device)

In [None]:
img_path = "/kaggle/input/semart/SemArt/SemArt/Images/42791-1sacris.jpg"  # sample image for inference

In [None]:
# Encode the sample image, generate a caption, and print it
# (the printed label is Russian for "Image description:").
image_features = encode_image(img_path)
caption = generate_caption(image_features=image_features)
print("Описание изображения:", caption)