In [None]:
!pip install datasets accelerate matplotlib -U
!pip install torch torchvision pillow


In [None]:
!pip install transformers

In [None]:
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor, AutoTokenizer, AutoImageProcessor
import torch
import torch.nn.functional as F

# Load the text tokenizer and the image pre-processor, then wrap both in a
# single processor so captions and images can be prepared with one call.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
processor = VisionTextDualEncoderProcessor(
    image_processor=image_processor,
    tokenizer=tokenizer,
)

In [None]:
###############################
##### Contrastive Dataset 
###############################
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
from torch.optim import Adam
from torch.nn import functional as F
import torch
import random

# Mapping from class ID to class name.
# The position of a name in the list below is its class ID (0-based).
class_names = dict(enumerate([
    'DUMPSTER',
    'VEHICLE',
    'SKID_STEER',
    'EXCAVATOR',
    'VAN',
    'LUMBER_BUNDLE',
    'CONE',
    'TRUCK',
    'GARBAGE_CONTAINER',
    'LADDER',
    'POWER_GENERATOR',
    'TELESCOPIC_HANDLER',
    'CONCRETE_BUCKET',
    'BOOMLIFT',
    'PLYWOOD',
    'TOILET_CABIN',
    'FORMWORK_PROP_BUNDLE',
    'CONDUIT_ROLL',
    'FORMWORK_PANEL',
    'CONCRETE_COLUMN',
    'PLATE_COMPACTOR',
    'TROWEL_POWER',
    'SLAB_SLEEVES',
    'MINI_EXCAVATOR',
    'CONTAINER',
    'SCISSORLIFT',
    'PICKUP_TRUCK',
    'MOBILE_CRANE',
    'EQUIPMENT',
    'TIEBACK_RIG',
    'TOWER_CRANE',
    'CONCRETE_PUMP',
    'DRILLRIG',
    'LOADER',
    'OFFICE_TRAILER',
    'DOZER',
    'BUS',
    'ROLLER',
    'CONCRETE_RIDE',
    'BACKHOE_LOADER',
    'FORKLIFT',
    'GRADER',
    'HAND_ROLLER',
    'HOIST_CABIN',
]))

class CustomDataset(Dataset):
    """Image/caption pairs for contrastive training, built from detection labels.

    Each ``.jpg``/``.jpeg`` file in ``image_dir`` is paired with the label file of
    the same stem in ``text_dir``. A label file holds one annotation per line as
    whitespace-separated numbers; field 0 is the class ID and fields 3 and 4 are
    multiplied to rank annotations (presumably box width and height -- confirm
    against the label format).

    Args:
        image_dir: Directory containing the image files.
        text_dir: Directory containing the matching ``.txt`` label files.
    """

    def __init__(self, image_dir, text_dir):
        self.image_dir = image_dir
        self.text_dir = text_dir
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same sample on every run.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith(('.jpg', '.jpeg'))
        )

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return the processed model inputs for sample ``idx``.

        The caption is the class names of the (up to) two highest-ranked
        annotations joined by a space, or ``"unknown"`` when the label file has
        no annotations.

        Returns:
            dict mapping each processor output name (e.g. ``input_ids``,
            ``attention_mask``, ``pixel_values``) to a tensor for this single
            sample, without a leading batch dimension.

        Raises:
            FileNotFoundError: If the image or its label file does not exist.
        """
        image_filename = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_filename)
        text_filename = os.path.splitext(image_filename)[0] + '.txt'
        text_path = os.path.join(self.text_dir, text_filename)

        # Load and parse annotations; blank lines are skipped so they cannot
        # produce an empty row that breaks the ranking below.
        with open(text_path, 'r') as file:
            parsed_annotations = [
                list(map(float, line.split())) for line in file if line.strip()
            ]

        # Convert to RGB so grayscale / RGBA / palette files all reach the image
        # processor with three channels; the context manager closes the file.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Caption = class names of the two highest-ranked annotations
        correct_caption = "unknown"
        if parsed_annotations:
            parsed_annotations.sort(key=lambda ann: ann[3] * ann[4], reverse=True)
            correct_caption = ' '.join(
                class_names[int(ann[0])] for ann in parsed_annotations[:2]
            )

        inputs = processor(text=[correct_caption], images=[image], return_tensors="pt", padding="max_length")

        # The processor returns tensors with a leading batch dimension of 1.
        # Drop it so the DataLoader's default collate produces (batch, ...)
        # tensors rather than (batch, 1, ...), which the model cannot consume.
        return {key: value.squeeze(0) for key, value in inputs.items()}

image_dir = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/train'
text_dir = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/labels/train'

dataset = CustomDataset(image_dir=image_dir, text_dir=text_dir)

# To use with a DataLoader
from torch.utils.data import DataLoader

# Shuffle so every epoch sees different batch compositions: with an in-batch
# contrastive loss the other samples of a batch act as the negatives.
dataloader = DataLoader(dataset, batch_size=50, shuffle=True)

# Peek at one batch: show tensor shapes rather than dumping the raw values
for batch in dataloader:
    print({key: tuple(value.shape) for key, value in batch.items()})
    break


In [None]:
# Show the shape of the pre-processed image tensor for one sample (index 343).
dataset[343]["pixel_values"].shape

In [None]:

# Build the ViT + BERT dual-encoder model
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor, TrainingArguments, Trainer
import torch
import torch.nn.functional as F

# Load the pre-trained models
# Checkpoint names for the vision and text towers; these are the same
# checkpoints the image processor and tokenizer were loaded from earlier.
vision_model_name = "google/vit-base-patch16-224"
text_model_name = "google-bert/bert-base-uncased"

model = VisionTextDualEncoderModel.from_vision_text_pretrained(
    vision_model_name, text_model_name
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [None]:
import os

# `!export ...` runs in a throw-away subshell, so the variable never reaches the
# kernel process. Set it on the kernel's own environment instead.
# NOTE: run this before the first CUDA operation for it to take effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
###############################
##### Contrastive Training  
###############################
import torch
from tqdm import tqdm

# Fine-tune the dual encoder with a symmetric in-batch contrastive loss:
# the i-th image of a batch should match the i-th caption and vice versa.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = "cpu"
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-7,betas=(0.9,0.98),eps=1e-6,weight_decay=0.001)

num_epochs = 3

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    total_loss = 0

    tqdm_object = tqdm(dataloader, total=len(dataloader),position=0, leave=True)
    nb_batches = len(dataloader)
    for i, batch in enumerate(tqdm_object):
        # Move inputs to the GPU if available.
        # The loss is computed explicitly below, so the model is not asked to
        # compute its own loss as well.
        correct_inputs = {
            "input_ids": batch["input_ids"].to(device),
            "attention_mask": batch["attention_mask"].to(device),
            "pixel_values": batch["pixel_values"].to(device),
        }
        # Forward pass for correct pairs
        correct_outputs = model(**correct_inputs)
        logits_per_image_correct = correct_outputs.logits_per_image
        logits_per_text_correct = correct_outputs.logits_per_text

        # Contrastive loss: the matching pair sits on the diagonal, so the
        # target class of row k is k.
        labels = torch.arange(logits_per_image_correct.size(0), device=device)
        loss_image_correct = F.cross_entropy(logits_per_image_correct, labels)
        loss_text_correct = F.cross_entropy(logits_per_text_correct, labels)

        loss = (loss_image_correct + loss_text_correct) / 2

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        # Report the learning rate the optimizer is actually using rather than
        # a hard-coded number that can drift from the optimizer settings.
        tqdm_object.set_postfix(
            batch="{}/{}".format(i+1,nb_batches),
            train_loss=batch_loss,
            lr=optimizer.param_groups[0]["lr"]
        )

    average_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}")

print("Training complete.")
model.save_pretrained("./vit-bert")
processor.save_pretrained("./vit-bert")

In [None]:
# Save the model and the processor to ./vit-bert.
# The training cell already ends with these same two calls; this cell repeats them.
model.save_pretrained("./vit-bert")
processor.save_pretrained("./vit-bert")

In [None]:
# Display the model's repr as the cell output.
model

In [None]:
# Image used by the inference cell below; the commented-out paths are
# alternative samples that can be swapped in.
image_path = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/val/2018-11-22 11.02.10.jpg'
#image_path = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/val/01_4A799364-FADA-4722-BC9C-59D4C913B168.jpeg'
#image_path = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/train/01_002AFA69-A930-41C6-8982-20000E50EF97.jpeg'
#image_path = '/mnt/data/ypatel/ObjectDetection/Dataset/Dataset_D8/images/val/20210325175112_production_2746262081.jpeg'

# Open the file as RGB and display it as the cell output.
image = Image.open(image_path).convert("RGB")
image

In [None]:
###############################
#### Inference with Contrastive 
###############################
from transformers import VisionTextDualEncoderModel, BertTokenizer
from PIL import Image
import torch 
from torchvision import transforms

# Pick the device first, then restore the saved checkpoint straight onto it in
# evaluation mode, together with the processor saved alongside it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = "cpu"
model = VisionTextDualEncoderModel.from_pretrained("./vit-bert").to(device).eval()
processor = VisionTextDualEncoderProcessor.from_pretrained("./vit-bert")

# Resize-and-tensorize pipeline (not referenced by the code in this cell)
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Define a function for inference
def predict(image_path, captions):
    """Score one image against a list of candidate captions.

    Uses the module-level ``processor``, ``model`` and ``device``.

    Args:
        image_path: Path of the image file to score.
        captions: List of caption strings to compare the image with.

    Returns:
        Tuple ``(logits_per_image, logits_per_text)`` taken from the model
        output. Per the transformers documentation these are the image-to-text
        and text-to-image similarity scores.
    """
    # Preprocess the image; the context manager closes the file handle once the
    # pixels have been decoded into the RGB copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    inputs = processor(text=captions, images=[image], return_tensors="pt", padding=True)

    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)
    pixel_values = inputs.pixel_values.to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)

    return outputs.logits_per_image, outputs.logits_per_text

# Candidate captions are the class names, in dictionary order
captions = list(class_names.values())

logits_per_image, logits_per_text = predict(image_path, captions)
correct_probs = logits_per_image.softmax(dim=-1)
print(logits_per_image)

# Indices of the five highest-probability captions, mapped back to class names
correct_label_indices = correct_probs.topk(5).indices.squeeze().tolist()
correct_classes = [class_names[label] for label in correct_label_indices]

print(correct_classes)


In [None]:
# Show the indices of the five largest entries of `correct_probs` as a plain list.
torch.topk(correct_probs, 5).indices.squeeze().tolist()

In [None]:
#################################
### Calculate Accuracy
#################################
from tqdm.notebook import tqdm

def extract_top(scores, n):
    """Return the keys of ``scores`` with the ``n`` highest values, best first.

    Args:
        scores: Mapping from caption to its score.
        n: Maximum number of captions to return.

    Returns:
        List of at most ``n`` keys ordered by descending score; keys with equal
        scores keep their order from ``scores``.
    """
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:n]

# Top-5 accuracy: an image counts as a hit when any of the five highest-scoring
# class names is one of its ground-truth classes.
# NOTE(review): `dataset` is built from the training split in the cells above,
# so this is not a held-out metric.
captions = list(class_names.values())

matched = 0
total = min(200, len(dataset))  # never index past the end of the dataset
for i in range(total):
    # CustomDataset yields model inputs only, so recover the image path and the
    # ground-truth classes from its file listing and the matching label file.
    image_filename = dataset.image_files[i]
    image_path = os.path.join(dataset.image_dir, image_filename)
    label_path = os.path.join(
        dataset.text_dir, os.path.splitext(image_filename)[0] + '.txt'
    )
    with open(label_path, 'r') as file:
        rows = [list(map(float, line.split())) for line in file if line.strip()]

    # Same ground-truth rule as the training captions: class names of the two
    # annotations with the largest field-3 x field-4 product.
    rows.sort(key=lambda row: row[3] * row[4], reverse=True)
    original_captions = [class_names[int(row[0])] for row in rows[:2]]

    logits_per_image, logits_per_text = predict(image_path, captions)
    correct_probs = F.softmax(logits_per_image, dim=-1)
    correct_label_indices = torch.topk(correct_probs, 5).indices.squeeze().tolist()
    correct_classes = [class_names[label] for label in correct_label_indices]

    # Count each image at most once so the ratio stays within [0, 1]. List
    # membership is an exact class-name match (no 'TRUCK' in 'PICKUP_TRUCK').
    if any(clazz in original_captions for clazz in correct_classes):
        matched += 1

    print(f"original: {original_captions} vs predicted {correct_classes}")

accuracy = matched / total if total else 0.0
    

In [None]:
# Display the accuracy computed by the previous cell.
accuracy