In [1]:
import torch
from transformers import PreTrainedModel
from loader.model_loader import load_vision_model, load_llm
from vision.projector import load_vision_projector
from vision.feature_select import feature_select
from vision.learned_encoding import load_learned_encoding
from image_handling.padding import resize_with_padding, load_images
from image_handling.slice import split_image
from transformers import BitsAndBytesConfig
import math
import requests
from PIL import Image
from io import BytesIO

# Run on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# bitsandbytes quantization config intended for the LLM (8-bit loading).
# NOTE(review): bnb_4bit_compute_dtype is documented as a 4-bit (load_in_4bit) option;
# with load_in_8bit=True it presumably has no effect — confirm which mode is intended.
quantization_config = BitsAndBytesConfig(load_in_8bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
# Vision tower and its matching image preprocessor.
vision_model , image_processor = load_vision_model("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", device = device )
# NOTE(review): the recorded output of this cell ends with
# "TypeError: load_llm() got an unexpected keyword argument 'quantization_config'";
# load_llm (loader.model_loader) must accept and forward this config, or the argument must be dropped.
llm, tokenizer = load_llm("llama3/8B-instruct", device = device, quantization_config = quantization_config)
vision_projector = load_vision_projector()
# Hidden sizes of the LLM and of the vision tower; both are passed to the learned encoding below.
llm_dim = llm.config.hidden_size
vision_dim = vision_model.config.hidden_size
# Learned positional-encoding module ("linear" variant).
# NOTE(review): presumably maps vision_dim -> llm_dim; confirm in vision.learned_encoding.
learned_positional = load_learned_encoding(vision_dim, llm_dim, "linear")

def get_positional_encoding(max_seq_len, embedding_dim):
    """Return the fixed sinusoidal positional-encoding table.

    Entry ``[p, 2i]`` is ``sin(p * 10000**(-2i/d))`` and entry ``[p, 2i+1]`` is
    ``cos(p * 10000**(-2i/d))``, where ``d = embedding_dim``.

    Args:
        max_seq_len: Number of positions (rows of the table).
        embedding_dim: Width of each encoding (columns); odd widths are supported.

    Returns:
        A tensor of shape ``(max_seq_len, embedding_dim)`` on the CPU, in torch's
        default floating dtype.
    """
    position_encoding = torch.zeros(max_seq_len, embedding_dim)
    position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2, dtype=torch.float32) * (-math.log(10000.0) / embedding_dim)
    )
    angles = position * div_term  # (max_seq_len, ceil(embedding_dim / 2))
    position_encoding[:, 0::2] = torch.sin(angles)
    # For odd embedding_dim there is one fewer cosine column than sine column,
    # so only the first embedding_dim // 2 angles are used here.
    position_encoding[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
    return position_encoding

def prepare(inputs, images):
    """Load ``images`` so they can be encoded.

    Args:
        inputs: Currently unused. NOTE(review): presumably the text inputs that
            accompany the images — confirm the intended use.
        images: Image sources accepted by ``load_images``, or ``None``.

    Returns:
        ``0`` when ``images`` is ``None``; otherwise the value returned by
        ``load_images(images)``.
    """
    if images is None:
        return 0
    # Hand the loaded images back to the caller instead of discarding them.
    return load_images(images)

def encode_images_positional_encoding(images, padding=True, learned_encoding=True):
    """Encode images with the vision tower, add positional encodings, and project them.

    Args:
        images: Images in a form ``image_processor.preprocess`` accepts. They are
            preprocessed here, so pass raw images rather than pixel tensors.
        padding: If True and the batch has fewer than ``MAX_LEN`` (8) images,
            append all-zero pixel tensors until it has ``MAX_LEN`` entries.
        learned_encoding: If True, also add the output of the module-level
            ``learned_positional`` layer to the features.

    Returns:
        The result of ``vision_projector`` applied to the encoded features, on
        both the learned and the non-learned path.
    """
    MAX_LEN = 8

    image_tensors = image_processor.preprocess(images, return_tensors='pt')['pixel_values'].to(device)

    # Pad up to MAX_LEN with blank images in a single allocation rather than
    # one torch.cat per missing image.
    num_missing = MAX_LEN - image_tensors.shape[0]
    if padding and num_missing > 0:
        blanks = image_tensors.new_zeros((num_missing, *image_tensors.shape[1:]))
        image_tensors = torch.cat((image_tensors, blanks), dim=0)

    with torch.no_grad():
        batch_features = vision_model(image_tensors, output_hidden_states=True)
        image_features = batch_features.hidden_states[-1]
        image_features = feature_select(image_features, "patch")
        # Fixed sinusoidal encoding indexed along dim 1 of the selected features.
        # It is matched to their device/dtype and added out of place, so the
        # tensor returned by feature_select (possibly a view of hidden_states)
        # is not mutated.
        max_seq_len = image_features.shape[1]
        pos_encoding = get_positional_encoding(max_seq_len, image_features.shape[-1]).to(
            device=image_features.device, dtype=image_features.dtype
        )
        image_features = image_features + pos_encoding

    if learned_encoding:
        # NOTE(review): learned_positional is built from (vision_dim, llm_dim, "linear");
        # this sum assumes its output has the same last dimension as image_features — confirm.
        image_features = image_features + learned_positional(image_features)

    return vision_projector(image_features)

def images_uhd_positional_encoding(image):
    """Split a high-resolution image into tiles and encode the tiles.

    Args:
        image: The image passed to ``split_image``.

    Returns:
        The output of ``encode_images_positional_encoding`` for the tiles.
    """
    # TODO(review): a 336-px padded global view (resize_with_padding(image, 336))
    # was computed here but never used; add it to the encoded batch if it is wanted.
    splits, _h, _w = split_image(image)
    return encode_images_positional_encoding(splits)

def imaged_uhd_arranged(image):
    """Split ``image`` into an ``h`` x ``w`` grid of tiles, encode them, and print the layout.

    The LLM input embedding(s) of the newline token are also looked up, and the
    grid position of each tile is printed.

    Args:
        image: The image passed to ``split_image``.

    Returns:
        A tuple ``(features, newline_embeddings)``:
        - ``features`` is the value of ``encode_images_no_positional_encoding`` for the tiles.
        - ``newline_embeddings`` is the LLM input embedding of the token(s) ``"\\n"``
          tokenizes to.
    """
    # TODO(review): a 336-px padded global view (resize_with_padding(image, 336))
    # was computed here but never used; add it if it is wanted.
    splits, h, w = split_image(image)

    # tokenizer.tokenize returns token strings; the embedding layer needs integer ids.
    # TODO(review): a "," separator embedding may also be wanted between tiles in a row.
    newline_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("\n"))
    embedding_layer = llm.get_input_embeddings()
    newline_embeddings = embedding_layer(
        torch.tensor(newline_ids, device=embedding_layer.weight.device)
    )

    # encode_images_no_positional_encoding is called with the tiles only.
    features = encode_images_no_positional_encoding(splits)
    for i in range(h):
        for j in range(w):
            print(f"Image {i*w+j} at position {i},{j}")
    return features, newline_embeddings

def encode_images_no_positional_encoding(image, padding=True):
    """Placeholder for encoding images without positional encoding.

    Not implemented yet: the inputs are ignored and ``0`` is always returned.

    Args:
        image: Images to encode (currently ignored).
        padding: Accepted for parity with ``encode_images_positional_encoding``
            so callers may pass it by keyword; currently ignored.

    Returns:
        ``0``.
    """
    return 0

  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)


TypeError: load_llm() got an unexpected keyword argument 'quantization_config'