In [None]:
!python -m pip install --upgrade pip -q
!pip install matplotlib -q -U

In [None]:
!pip install q- datasets
!pip install transformers -q -U
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/perft.git
!pip install hf_transfer -q -U
!pip install pickleshare -q

In [None]:
# Set HF_HUB_ENABLE_HF_TRANSFER=1 in the kernel's environment
# (the `hf_transfer` package is installed in the previous cell).
# The comment is kept on its own line so it cannot become part of the value.
%env HF_HUB_ENABLE_HF_TRANSFER=1

In [None]:
import os

if not os.path.isdir("LLaVA"):
    !git clone https://github.com/haotian-liu/LLaVA.git
else:
    print("LLaVA directory already exists. Skipping clone.")

In [None]:
import re

# Path to the LLaVA model-loading script that is patched in place.
file_path = 'LLaVA/llava/model/builder.py'

# Read the current content of the file.
with open(file_path, 'r') as file:
    content = file.read()

# Match everything from the `vision_tower = model.get_vision_tower()` line up to
# and including the `image_processor = vision_tower.image_processor` line.
# The leading whitespace of the first line is captured so the replacement can be
# emitted at exactly the indentation of the code it replaces.
pattern_block = (
    r'^(?P<indent>[ \t]*)vision_tower = model\.get_vision_tower\(\)\n'
    r'.*?'  # non-greedy: stop at the first image_processor assignment
    r'^[ \t]*image_processor = vision_tower\.image_processor'
)


def _render_vision_tower_block(match):
    """Build the replacement text for one matched vision-tower block.

    The first and last lines of the match are re-emitted unchanged; the lines
    between them are replaced by code that loads the vision tower when it is
    not loaded yet and moves it to bfloat16.

    The returned string is inserted literally (a callable replacement is not
    subject to backslash/group-reference processing by ``re``).

    NOTE(review): the injected code refers to `device_map` and `device`, which
    are assumed to be in scope at the patched location in builder.py - confirm
    against the checked-out LLaVA revision.
    """
    outer = match.group('indent')
    inner = outer + '    '
    lines = [
        outer + 'vision_tower = model.get_vision_tower()',
        outer + 'if not vision_tower.is_loaded:',
        inner + "print('vision_tower is not loaded so loading it now')",
        inner + 'vision_tower.load_model(device_map=device_map)',
        inner + 'vision_tower.to(device=device, dtype=torch.bfloat16)',
        outer + 'else:',
        inner + "print('vision_tower is loaded')",
        outer + 'image_processor = vision_tower.image_processor',
    ]
    return '\n'.join(lines)


# Replace the block. DOTALL lets `.*?` span lines; MULTILINE anchors `^` at each line.
content, patched_count = re.subn(
    pattern_block,
    _render_vision_tower_block,
    content,
    flags=re.DOTALL | re.MULTILINE,
)

if patched_count:
    # Write the modified content back to the file.
    with open(file_path, 'w') as file:
        file.write(content)
    print('The script has been updated successfully')
else:
    # Do not claim success (or rewrite the file) when nothing matched.
    print(f'No vision_tower block found in {file_path}; file left unchanged')

In [None]:
import re

# Path to the LLaVA model-loading script that is patched in place.
file_path = 'LLaVA/llava/model/builder.py'

# Read the file fresh so this cell does not depend on variables left over
# from other cells (the text was previously stored under a different name
# than the one used below).
with open(file_path, 'r') as file:
    content = file.read()

# Match 'float16' only when it is not already preceded by 'b', so existing
# 'bfloat16' occurrences are left alone and re-running the cell is harmless.
pattern = r'(?<!b)float16'

# Only rewrite the file when there is something to replace.
if re.search(pattern, content):
    # Replace 'float16' with 'bfloat16'.
    modified_content = re.sub(pattern, 'bfloat16', content)

    # Write the modified content back to the file.
    with open(file_path, 'w') as file:
        file.write(modified_content)

    print("All necessary instances of floats have been replaced with..")
else:
    print('No replacement needed. All instances of float16 already have.. ')

In [None]:
# Change the kernel's working directory into the cloned repository so that the
# `pip install -e .` cells below install from it.
# NOTE(review): the patch cells above open 'LLaVA/llava/model/builder.py' relative
# to the starting directory, so they presumably will not find the file if re-run
# after this cell without restarting the kernel - confirm.
%cd LLaVA

In [None]:
#can take up to 5 mins
!pip install -e . -q

In [None]:
# !git pull
# !pip install -e . -q

!pip install protobuf -q -U
!pip install --upgrade Pillow -q
!pip install -e ".[train]" -q
!pip install flash-attn --no-build-isolation -q 


In [1]:
# Libraries used for loading and fine-tuning the model.
# Fixes to the `transformers` import line:
# - `TrainingArgument` is not exported by transformers; the class is `TrainingArguments`.
# - the last name was cut off at `BitsA`; restored as `BitsAndBytesConfig`
#   (NOTE(review): presumed intended name - confirm).
import torch
import torchvision.transforms as transforms
import transformers
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from PIL import Image
from transformers import AutoProcessor, BitsAndBytesConfig, Trainer, TrainingArguments

ModuleNotFoundError: No module named 'torch'

In [None]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

# Pick the GPU when torch can see one, otherwise fall back to the CPU.
# Note: `device` is only defined here; it is not passed to the loader below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint to load (the alternative below is kept for reference).
model_path = 'liuhaotian/llava-v1.6-mistral-7b'
#model_path = "Trelis/llava-v1.6-mistral-7b-PATCHED"

# Model name derived from the checkpoint path by the LLaVA helper.
model_name = get_model_name_from_path(model_path)

# Arguments for the loader, gathered in one place for readability.
load_kwargs = dict(
    model_path=model_path,
    model_base=None,
    model_name=model_name,
    cache_dir='',
    use_flash_attn=True,
)

# The loader's result is unpacked into the four objects used by later cells.
tokenizer, model, image_processor, context_len = load_pretrained_model(**load_kwargs)
 

In [None]:
#print(model)

In [None]:
#print(processor)

In [None]:
#print(model.config)
#print(tokenizer.pad_token_id)
#print(tokenizer)

In [None]:
import torch

# List every module that holds at least one parameter whose dtype is not
# torch.bfloat16, together with the offending dtype(s).
# `module.parameters()` is used with its default arguments here, so the check
# covers whatever parameters that call yields for each module.
print("Modules not torch.bfloat16:")
for name, module in model.named_modules():
    # Collect the distinct non-bfloat16 dtypes; modules without parameters
    # simply produce an empty set and are skipped.
    offending_dtypes = sorted(
        {str(param.dtype) for param in module.parameters() if param.dtype != torch.bfloat16}
    )
    if offending_dtypes:
        print(f"{name}: {', '.join(offending_dtypes)}")