In [None]:
from PIL import Image
from transformers import AutoProcessor, AutoModel
import torch
import matplotlib.pyplot as plt

# model = AutoModel.from_pretrained("google/siglip-base-patch16-224")

# processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")

# img = plt.imread("/root/multiview-robust-clip/data/objaverse/renderings/0a0c75dab0844e7fa5b299d4af858bec/004.png")[:,:, :3]

# texts = ["a photo of 2 cats", "a photo of 2 dogs"]

# inputs = processor(text=texts, images=img, padding="max_length", return_tensors="pt")

# with torch.no_grad():
#     outputs = model(**inputs)

In [None]:
import torch
import torch.nn as nn
from typing import List

from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer, AutoModel, AutoProcessor, AutoTokenizer

class VLM(nn.Module):
    """Common embedding interface over a CLIP or SigLIP checkpoint.

    Depending on ``vlm_name`` the constructor loads the Hugging Face checkpoint
    ``openai/clip-vit-base-patch32`` (``'clip'``) or
    ``google/siglip-base-patch16-224`` (``'siglip'``) together with the matching
    processor and tokenizer. Every embedding method runs under
    ``torch.no_grad()``, so the returned tensors carry no gradient history.
    """

    def __init__(self, vlm_name: str = 'clip'):
        """Load model, processor and tokenizer for the requested backend.

        Args:
            vlm_name: Either ``'clip'`` or ``'siglip'``.

        Raises:
            ValueError: If ``vlm_name`` is not one of the supported names.
        """
        super().__init__()
        if vlm_name == 'clip':
            checkpoint = "openai/clip-vit-base-patch32"
            self.processor = CLIPProcessor.from_pretrained(checkpoint)
            self.model = CLIPModel.from_pretrained(checkpoint)
            self.tokenizer = CLIPTokenizer.from_pretrained(checkpoint)
        elif vlm_name == 'siglip':
            checkpoint = "google/siglip-base-patch16-224"
            self.model = AutoModel.from_pretrained(checkpoint)
            self.processor = AutoProcessor.from_pretrained(checkpoint)
            self.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        else:
            # ValueError is still an Exception, so existing broad handlers keep
            # working; echoing the bad value makes typos easy to spot.
            raise ValueError(
                f"Provide a valid VLM name [clip | siglip], got {vlm_name!r}"
            )

    @torch.no_grad()
    def embed_image(self, x):
        """Return image features for ``x``.

        Args:
            x: Image input forwarded unchanged to the processor as ``images``.
                NOTE(review): float arrays already scaled to [0, 1] are
                presumably rescaled again by the processor's defaults -
                confirm against the installed transformers version.

        Returns:
            Whatever ``self.model.get_image_features`` returns for the
            processed input.
        """
        inputs = self.processor(images=x, return_tensors="pt")
        # The processor builds CPU tensors; follow the model if it was moved.
        inputs = inputs.to(self.model.device)
        return self.model.get_image_features(**inputs)

    @torch.no_grad()
    def embed_text(self, x: List[str]):
        """Return text features for the strings in ``x``.

        Args:
            x: Texts to tokenize; they are padded with ``padding="max_length"``.

        Returns:
            Whatever ``self.model.get_text_features`` returns for the tokens.
        """
        tokens = self.tokenizer(x, padding="max_length", return_tensors="pt")
        # The tokenizer builds CPU tensors; follow the model if it was moved.
        tokens = tokens.to(self.model.device)
        return self.model.get_text_features(**tokens)

    @torch.no_grad()
    def embed(self, texts: List[str], image):
        """Run the full model on texts and image in one forward pass.

        Args:
            texts: Texts passed to the processor as ``text``.
            image: Image input passed to the processor as ``images``.

        Returns:
            A dict with keys ``"text_embed"`` and ``"image_embed"`` holding the
            model output's ``text_embeds`` and ``image_embeds`` entries.
        """
        inputs = self.processor(text=texts, images=image, padding="max_length", return_tensors="pt")
        # The processor builds CPU tensors; follow the model if it was moved.
        inputs = inputs.to(self.model.device)
        outputs = self.model(**inputs)
        return {
            "text_embed": outputs['text_embeds'],
            "image_embed": outputs['image_embeds']
        }


In [None]:
from src.models.vlm import VLM

In [None]:
# Build a VLM instance with the 'clip' backend name.
clip = VLM('clip')

In [None]:
# Build a VLM instance with the 'siglip' backend name (the name must match exactly;
# anything other than 'clip' or 'siglip' is rejected by the constructor).
siglip = VLM('siglip')

In [None]:
# Load one rendering, drop any alpha channel (keep the first three channels), and
# convert to uint8 in 0-255. plt.imread returns PNG data as floats in [0, 1];
# the Hugging Face image processors rescale by 1/255 by default, so passing the
# float array straight through would scale the pixels twice.
img = (
    plt.imread("/root/multiview-robust-clip/data/objaverse/renderings/0a0c75dab0844e7fa5b299d4af858bec/004.png")[:, :, :3]
    * 255
).round().astype("uint8")

# Candidate captions used for the text side of the embedding cells.
texts = ["a photo of 2 cats", "a photo of 2 dogs"]

In [None]:
# Display the shape of the loaded image array.
img.shape

In [None]:
# Embed `img` with the `siglip` wrapper and display the shape of the result.
siglip.embed_image(img).shape

In [None]:
# Embed the captions and the image in one call and display the returned value.
siglip.embed(texts, img)

In [None]:
# Embed the captions with the `clip` wrapper and display the shape of the result.
# The wrapper's text method is `embed_text`; it defines no `forward_text`.
clip.embed_text(texts).shape

In [None]:
# NOTE(review): in the cells visible here, `outputs` is only assigned in commented-out
# code and in a later cell, so this lookup presumably fails on a fresh kernel - confirm.
outputs['text_embeds']

In [None]:
# NOTE(review): in the cells visible here, `inputs` is only assigned in commented-out
# code, so this presumably relies on leftover kernel state - confirm.
inputs['pixel_values'].shape

In [None]:
# Compute image features without tracking gradients and keep them in `outputs`.
# NOTE(review): in the cells visible here, `model` and `inputs` are only assigned in
# commented-out code, so this presumably relies on leftover kernel state - confirm.
with torch.no_grad():
    outputs = model.get_image_features(inputs['pixel_values'])

In [None]:
# Display the shape of the current `outputs` value.
outputs.shape

In [None]:
# NOTE(review): this reads `.image_embeds` from `outputs`, while the preceding cells
# assign `outputs` from get_image_features(...) and read `.shape` on it directly;
# the two uses presumably expect different objects - confirm which one is intended.
outputs.image_embeds.shape

In [None]:
# List the keys of `inputs`.
# NOTE(review): `inputs` is only assigned in commented-out code in the cells visible
# here - confirm it is defined on a fresh kernel.
inputs.keys()

In [None]:
# Display the image array's shape again (repeats an earlier display cell).
img.shape

In [None]:
from pathlib import Path

In [None]:
# Join a directory and a file name with pathlib's `/` operator and display the
# resulting Path; nothing is read from or written to disk here.
Path("/root/home/") / "image.png"

In [None]:
import json
from pathlib import Path

import numpy as np
import pandas as pd

# Seed NumPy's global RNG so the shuffle below, and therefore the split, is repeatable.
np.random.seed(42)

# Single place to change the dataset location; every file below lives in this directory.
DATA_DIR = Path('/root/multiview-robust-clip/data/objaverse')


def _save_uids(uids, filename):
    """Write `uids` as a one-column ('uid') CSV named `filename` in DATA_DIR; return the frame."""
    frame = pd.DataFrame(uids, columns=['uid'])
    frame.to_csv(DATA_DIR / filename, index=False)
    return frame


# Load the JSON file containing object names (explicit UTF-8: do not depend on the
# platform's default text encoding).
with open(DATA_DIR / 'uid_to_name.json', 'r', encoding='utf-8') as file:
    uid_to_name = json.load(file)

# Collect every UID (the mapping's keys) and shuffle them in place.
all_uids = list(uid_to_name.keys())
np.random.shuffle(all_uids)

# 80% train, 10% validation, and the remainder for test; int() truncates, so any
# leftover items from rounding end up in the test split.
num_total = len(all_uids)
num_train = int(num_total * 0.8)
num_val = int(num_total * 0.1)
num_test = num_total - num_train - num_val

# Consecutive, non-overlapping slices of the shuffled list.
train_uids = all_uids[:num_train]
val_uids = all_uids[num_train:num_train + num_val]
test_uids = all_uids[num_train + num_val:]

# The three slices must account for every UID.
assert len(test_uids) == num_test
assert len(train_uids) + len(val_uids) + len(test_uids) == num_total

# Main splits.
train_df = _save_uids(train_uids, 'train.csv')
val_df = _save_uids(val_uids, 'val.csv')
test_df = _save_uids(test_uids, 'test.csv')

# Debugging subsets taken from the start of the training split:
# a single UID for overfitting, and 8 UIDs for a small-batch run.
train_overfit_df = _save_uids(train_uids[:1], 'train_overfit.csv')
train_batch_df = _save_uids(train_uids[:8], 'train_batch.csv')

print("CSV files for overfitting and small batch training have been saved.")