In [None]:
!pip install wandb onnx -Uq

In [None]:
###############################
##### inference with out of box model
###############################
import requests
from PIL import Image

from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

# Baseline: an off-the-shelf ViT-GPT2 captioner, used for comparison only.
# The next cell rebinds model/tokenizer/image_processor to the ViT-BERT model
# that is fine-tuned below.
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

# Alternative public sample image (not used below).
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image_path = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/val/01_4A799364-FADA-4722-BC9C-59D4C913B168.jpeg'

# Force 3-channel RGB, as the training dataset does; grayscale/CMYK/RGBA files
# may otherwise reach the processor with an unexpected channel layout.
# The `with` block closes the file once the pixels have been converted.
with Image.open(image_path) as img:
    image = img.convert('RGB')
pixel_values = image_processor(image, return_tensors="pt").pixel_values

# Autoregressively generate a caption using the model's default generation settings.
generated_ids = model.generate(pixel_values)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)

In [None]:
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel

# ViT encoder + BERT decoder, warm-started from separate pretrained checkpoints.
image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
)

# The BERT decoder starts from [CLS] and pads with [PAD].
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# generate() reads model.generation_config. Set the same ids there so the saved
# generation_config.json carries them. Also make [SEP] the end-of-sequence token,
# so decoding stops after the caption instead of running to max_length.
model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.eos_token_id = tokenizer.sep_token_id


In [None]:
#################################
###Initialize Dataset 
#################################

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor, AdamW, DataCollatorWithPadding
from torch.optim import Adam
from torch.nn import functional as F
import torch
import random

# Detector class ids -> class names. The ids are the list positions (0..43),
# so the order of this list must match the label files.
class_names = dict(enumerate([
    'DUMPSTER', 'VEHICLE', 'SKID_STEER', 'EXCAVATOR', 'VAN',                       # 0-4
    'LUMBER_BUNDLE', 'CONE', 'TRUCK', 'GARBAGE_CONTAINER', 'LADDER',               # 5-9
    'POWER_GENERATOR', 'TELESCOPIC_HANDLER', 'CONCRETE_BUCKET', 'BOOMLIFT',        # 10-13
    'PLYWOOD', 'TOILET_CABIN', 'FORMWORK_PROP_BUNDLE', 'CONDUIT_ROLL',             # 14-17
    'FORMWORK_PANEL', 'CONCRETE_COLUMN', 'PLATE_COMPACTOR', 'TROWEL_POWER',        # 18-21
    'SLAB_SLEEVES', 'MINI_EXCAVATOR', 'CONTAINER', 'SCISSORLIFT',                  # 22-25
    'PICKUP_TRUCK', 'MOBILE_CRANE', 'EQUIPMENT', 'TIEBACK_RIG',                    # 26-29
    'TOWER_CRANE', 'CONCRETE_PUMP', 'DRILLRIG', 'LOADER',                          # 30-33
    'OFFICE_TRAILER', 'DOZER', 'BUS', 'ROLLER', 'CONCRETE_RIDE',                   # 34-38
    'BACKHOE_LOADER', 'FORKLIFT', 'GRADER', 'HAND_ROLLER', 'HOIST_CABIN',          # 39-43
]))

def to_captions(name):
    """Turn UPPER_SNAKE class label(s) into lowercase caption words.

    Example: "SKID_STEER" -> "skid steer".
    """
    lowered = name.lower()
    return lowered.replace('_', ' ')

class CustomImageTextDataset(Dataset):
    """Image/caption pairs built from per-image detection label files.

    Each ``.jpg``/``.jpeg`` image in ``image_dir`` is paired with the ``.txt``
    file of the same stem in ``text_dir``. Each non-blank line of that file is
    parsed as whitespace-separated floats: column 0 is the class id, and
    columns 3 * 4 are used to rank boxes by area. (Presumably YOLO
    ``class cx cy w h`` -- TODO confirm.) The caption names the classes of
    the ``max_objects`` largest boxes, e.g. "an image of dumpster skid steer".
    If there are no annotations, the caption is "an image of unknown".

    Items are dicts with keys ``pixel_values``, ``labels`` (caption token ids),
    ``image_path`` and ``captions`` (the caption string).
    """

    IMAGE_EXTENSIONS = ('.jpg', '.jpeg')

    def __init__(self, image_dir, text_dir, image_processor=None, tokenizer=None, max_objects=4):
        """
        Args:
            image_dir: directory containing the images.
            text_dir: directory containing the per-image ``.txt`` label files.
            image_processor: processor producing ``pixel_values``. Defaults to the
                notebook-global ``image_processor``, looked up at access time.
            tokenizer: tokenizer producing ``labels``. Defaults to the
                notebook-global ``tokenizer``, looked up at access time.
            max_objects: how many of the largest boxes are named in the caption.
        """
        self.image_dir = image_dir
        self.text_dir = text_dir
        self._image_processor = image_processor
        self._tokenizer = tokenizer
        self.max_objects = max_objects
        # Sorted so index -> file is reproducible across runs and machines.
        # Matching is case-insensitive so ".JPG"/".JPEG" files are not dropped.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def _build_caption(self, text_path):
        """Return the caption text for one label file."""
        with open(text_path, 'r') as file:
            rows = [line.split() for line in file]
        # Skip blank lines; they would otherwise yield [] and break the area sort.
        annotations = [list(map(float, fields)) for fields in rows if fields]
        if not annotations:
            return "an image of unknown"
        # Largest box area first.
        annotations.sort(key=lambda ann: ann[3] * ann[4], reverse=True)
        names = ' '.join(class_names[int(ann[0])] for ann in annotations[:self.max_objects])
        return f"an image of {to_captions(names)}"

    def __getitem__(self, idx):
        image_filename = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_filename)
        text_path = os.path.join(self.text_dir, os.path.splitext(image_filename)[0] + '.txt')

        # convert() loads the pixels, so the file can be closed right away.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        correct_caption = self._build_caption(text_path)

        processor = self._image_processor if self._image_processor is not None else image_processor
        text_tokenizer = self._tokenizer if self._tokenizer is not None else tokenizer
        pixel_values = processor(images=image, return_tensors='pt').pixel_values.squeeze()
        labels = text_tokenizer(correct_caption, return_tensors='pt').input_ids.squeeze()
        return {'pixel_values': pixel_values, 'labels': labels, 'image_path': image_path, 'captions': correct_caption}

# Training split: image files and the matching per-image label .txt files.
image_dir, text_dir = (
    '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/train',
    '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/labels/train',
)
train_dataset = CustomImageTextDataset(image_dir, text_dir)

from torch.utils.data import DataLoader


def collate_fn(batch, label_pad_value=-100):
    """Collate dataset items into a model-ready batch.

    Stacks the per-image ``pixel_values`` and right-pads the variable-length
    ``labels`` to a common length.

    Padding defaults to -100, because transformers' VisionEncoderDecoderModel
    ignores label positions equal to -100 when computing the loss. The original
    padding with the tokenizer's pad id made the model also learn to predict
    [PAD] on every padded position.

    Args:
        batch: list of dataset items with ``pixel_values`` and ``labels`` tensors.
        label_pad_value: value used to pad ``labels``.

    Returns:
        dict with ``pixel_values`` (stacked) and ``labels`` (batch_first, padded).
    """
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels_padded = torch.nn.utils.rnn.pad_sequence(
        [item['labels'] for item in batch], batch_first=True, padding_value=label_pad_value
    )
    return {
        'pixel_values': pixel_values,
        'labels': labels_padded,
    }

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Sanity-check one batch. Show tensor shapes rather than dumping raw tensors.
batch = next(iter(train_dataloader), None)
if batch is None:
    print("train_dataloader is empty")
else:
    print({name: tuple(tensor.shape) for name, tensor in batch.items()})

In [None]:
# Authenticate with Weights & Biases; the training cell logs to it via wandb.init().
import wandb

wandb.login()

In [None]:
from tqdm import tqdm
import torch
import wandb

# Hyperparameters.
batch_size = 4  # informational: train_dataloader above is built with batch_size=4
learning_rate = 5e-5
num_epochs = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model before constructing the optimizer, as the PyTorch optim docs recommend.
model.to(device)
model.train()

# Use torch.optim.AdamW instead of the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6, 0.0), because
# torch's own defaults differ (1e-8, 0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)

######## Log to Weights & Biases
config = dict(epochs=num_epochs, learning_rate=learning_rate, batch_size=batch_size)
wandb.init(project="vision-encoder-decoder-finetuning", config=config)
# wandb.watch's second positional parameter is `criterion`, not the optimizer.
wandb.watch(model, log="all", log_freq=10)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    epoch_loss = 0.0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values=pixel_values, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Log and accumulate a Python float, not the (possibly CUDA) tensor.
        loss_value = loss.item()
        wandb.log({"epoch": epoch, "loss": loss_value})
        epoch_loss += loss_value

    # max(..., 1) guards against an empty dataloader.
    print(f"Epoch {epoch + 1} loss: {epoch_loss / max(len(train_dataloader), 1)}")

# Close the W&B run so its data is flushed; the model is saved in the next cell.
wandb.finish()


In [None]:
# Persist the fine-tuned model together with the preprocessing components it was
# trained with. The inference cell reloads each one from these directories.
for component, save_dir in (
    (model, 'vision_encoder_decoder/model'),
    (image_processor, 'vision_encoder_decoder/feature_extractor'),
    (tokenizer, 'vision_encoder_decoder/tokenizer'),
):
    component.save_pretrained(save_dir)

In [None]:
#################################
### Single Image
#################################
import requests
from PIL import Image

# Sample training image, captioned further below. Other candidates in the same folder:
#   01_001--Optima--09-02-2018-0631.jpg
#   01_002AFA69-A930-41C6-8982-20000E50EF97.jpeg
#   01_003--Optima--12-08-2017.jpg
image_path = "/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/train/01_001--Optima--14-11-2017-5846.jpg"

image = Image.open(image_path).convert('RGB')
half_size = (image.width // 2, image.height // 2)
smaller_image = image.resize(half_size)

# Preview at half resolution.
display(smaller_image)

In [None]:
########################################
###Inference with trained model 
########################################
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from PIL import Image

# Reload the model, image processor and tokenizer written by the save cell.
model = VisionEncoderDecoderModel.from_pretrained("vision_encoder_decoder/model")
# ViTImageProcessor replaces the deprecated ViTFeatureExtractor. The processor is
# bound to `image_processor` because preprocess_image() and the dataset read that
# global; previously the reloaded processor was assigned but never used.
image_processor = ViTImageProcessor.from_pretrained("vision_encoder_decoder/feature_extractor")
feature_extractor = image_processor  # backward-compatible alias
tokenizer = AutoTokenizer.from_pretrained("vision_encoder_decoder/tokenizer")

model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# generate() reads generation_config. Also stop decoding at [SEP] so captions are
# not padded with extra words up to max_length.
model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
model.generation_config.eos_token_id = tokenizer.sep_token_id

def preprocess_image(image_path, processor=None):
    """Load an image from disk and turn it into model-ready ``pixel_values``.

    Args:
        image_path: path to the image file.
        processor: image processor to apply. Defaults to the notebook-global
            ``image_processor``, looked up at call time.

    Returns:
        The ``pixel_values`` tensor the processor returns for this single image.
    """
    if processor is None:
        processor = image_processor
    # convert() loads the pixels, so the file can be closed right away.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    return processor(images=image, return_tensors='pt').pixel_values

def generate_caption(pixel_values, model, tokenizer, max_length=64, num_beams=4):
    """Generate a caption for one preprocessed image.

    Args:
        pixel_values: processor output for a single image (batch of one).
        model: captioning model exposing ``generate``.
        tokenizer: tokenizer used to decode the generated ids.
        max_length: maximum generated sequence length.
        num_beams: beam width; 1 means greedy decoding.

    Returns:
        The decoded caption of the first sequence, with special tokens removed.
    """
    # Keep inputs on the model's device (e.g. when the model lives on a GPU).
    device = getattr(model, "device", None)
    if device is not None:
        pixel_values = pixel_values.to(device)
    # Previously max_length/num_beams were accepted but never forwarded to generate().
    output_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

# Caption the single image selected in the preview cell, using the reloaded model.
pixel_values = preprocess_image(image_path)
caption = generate_caption(pixel_values=pixel_values, model=model, tokenizer=tokenizer)
print("Generated Caption:", caption)



In [None]:
#################################
### Calculate Accuracy, MRR and Top K 
#################################
import time
from torch.utils.data import DataLoader
from PIL import Image
from tqdm.notebook import tqdm

# Validation split, scored by calculate_metrics() below. Relies on
# CustomImageTextDataset, preprocess_image, generate_caption, model and tokenizer
# from earlier cells.
image_dir, text_dir = (
    '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/val',
    '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/labels/val',
)
val_dataset = CustomImageTextDataset(image_dir, text_dir)

# Function to calculate metrics
def calculate_metrics(val_dataset, model, tokenizer, preprocess_image, k=3, total=200,
                      caption_fn=None, verbose=True):
    """Score generated captions against ground-truth captions, word by word.

    Both captions have their first three words (the fixed "an image of" prefix)
    dropped. Metrics are computed over the remaining words:

    * accuracy: fraction of samples where any predicted word is in the ground truth;
    * MRR: mean of 1 / (1-based position of the first matching predicted word),
      or 0 when nothing matches;
    * top-k accuracy: fraction of samples with a match among the first ``k``
      predicted words.

    Args:
        val_dataset: indexable dataset whose items have "captions" and "image_path".
        model, tokenizer: forwarded to ``caption_fn``.
        preprocess_image: callable mapping an image path to model inputs.
        k: cut-off for top-k accuracy.
        total: maximum number of samples, clamped to ``len(val_dataset)``;
            None evaluates the whole dataset.
        caption_fn: callable ``(pixel_values, model, tokenizer) -> str``.
            Defaults to the notebook-global ``generate_caption``.
        verbose: print the per-sample comparison.

    Returns:
        (accuracy, mrr, top_k_accuracy, average_inference_time_seconds)

    Raises:
        ValueError: if there are no samples to evaluate.
    """
    if caption_fn is None:
        caption_fn = generate_caption
    num_samples = len(val_dataset) if total is None else min(total, len(val_dataset))
    if num_samples <= 0:
        raise ValueError("calculate_metrics: no samples to evaluate")

    matched = 0
    top_k_matched = 0
    mrr_total = 0.0
    total_inference_time = 0.0

    for i in range(num_samples):
        # Index the dataset once: each access re-reads the image and re-tokenizes.
        sample = val_dataset[i]
        pixel_values = preprocess_image(sample["image_path"])

        # Time only the caption generation.
        start_time = time.perf_counter()
        caption = caption_fn(pixel_values, model, tokenizer)
        total_inference_time += time.perf_counter() - start_time

        # Keep only the words after "an image of".
        original_labels = sample["captions"].split()[3:]
        predicted_labels = caption.split()[3:]

        # 1-based position of the first predicted word found in the ground truth.
        first_hit = next(
            (rank for rank, word in enumerate(predicted_labels, start=1) if word in original_labels),
            None,
        )
        if first_hit is not None:
            matched += 1
            mrr_total += 1.0 / first_hit
            if first_hit <= k:
                top_k_matched += 1

        if verbose:
            print(f"original: {original_labels} vs predicted {predicted_labels}")

    accuracy = matched / num_samples
    mrr = mrr_total / num_samples
    top_k_accuracy = top_k_matched / num_samples
    average_inference_time = total_inference_time / num_samples

    return accuracy, mrr, top_k_accuracy, average_inference_time

# Evaluate the first 200 validation images, with top-k cut-off TOP_K.
TOP_K = 5
accuracy, mrr, top_k_accuracy, average_inference_time = calculate_metrics(
    val_dataset, model, tokenizer, preprocess_image, k=TOP_K, total=200
)

print(f"Accuracy: {accuracy}")
print(f"Mean Reciprocal Rank (MRR): {mrr}")
# The label is derived from TOP_K so it cannot drift from the k actually evaluated.
print(f"Top-{TOP_K} Accuracy: {top_k_accuracy}")
print(f"Average Inference Time: {average_inference_time} seconds")
