In [1]:
import torch
from models import Showo, MAGVITv2
from training.prompting_utils import UniversalPrompting, create_attention_mask_for_mmu, create_attention_mask_for_mmu_vit
from training.utils import get_config, flatten_omega_conf, image_transform
from transformers import AutoTokenizer
from models.clip_encoder import CLIPVisionTower
from transformers import CLIPImageProcessor
import llava.llava.conversation as conversation_lib

conversation_lib.default_conversation = conversation_lib.conv_templates["phi1.5"]


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
torch.cuda.empty_cache()
#del model

In [3]:
from omegaconf import DictConfig, ListConfig, OmegaConf
import os

config = OmegaConf.load('configs/fashionm3_training.yaml')

# Fix paths for Windows
base_path = os.getcwd()  # Current directory
config.model.vq_model.pretrained_model_path = os.path.join(base_path, "models",
"magvitv2", "pytorch_model.safetensors")
config.model.showo.pretrained_model_path = os.path.join(base_path, "models",
"show-o-512x512-wo-llava-tuning")

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory /
1024**3:.1f}GB")

Using device: cuda:0
GPU: NVIDIA GeForce RTX 3070 Ti Laptop GPU
GPU Memory: 8.0GB


In [5]:

# show o tokenizer setup and adding special tokens to universal prompting
# llm model : 'microsoft/phi-1_5'
tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side ="left")
uni_prompting = UniversalPrompting(tokenizer, max_text_len=config.dataset.preprocessing.max_seq_length,
                                       special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                       ignore_id=-100, cond_dropout_prob=config.training.cond_dropout_prob)


In [6]:
import os
import torch

magvit_path = config.model.vq_model.pretrained_model_path
print(f"File path: {magvit_path}")
print(f"File exists: {os.path.exists(magvit_path)}")
if os.path.exists(magvit_path):
    file_size = os.path.getsize(magvit_path) / (1024**3)  # GB
    print(f"File size: {file_size:.2f} GB")

    # Try to peek at the file content
    with open(magvit_path, 'rb') as f:
        first_bytes = f.read(20)
        print(f"First 20 bytes: {first_bytes}")

File path: c:\Users\jonat\Desktop\StyleAI\backend\fashionm3\Show-o\models\magvitv2\pytorch_model.safetensors
File exists: True
File size: 0.36 GB
First 20 bytes: b'x\x9a\x00\x00\x00\x00\x00\x00{"__metadata'


In [7]:
print("Loading MAGVIT-v2 from local path...")

try:
    # Try safetensors loading first
    from safetensors.torch import load_file

    vq_model = MAGVITv2().to(device)
    print(f"Loading from: {config.model.vq_model.pretrained_model_path}")

    # Load with safetensors
    state_dict = load_file(config.model.vq_model.pretrained_model_path)

    # The state_dict might need to be wrapped in a 'model' key or used directly
    if 'model' in state_dict:
        vq_model.load_state_dict(state_dict['model'])
    else:
        vq_model.load_state_dict(state_dict)

    print("MAGVIT-v2 loaded with safetensors")

except Exception as e:
    print(f"Safetensors loading failed: {e}")
    print("Trying HuggingFace fallback...")

    # Fallback to HuggingFace
    vq_model = MAGVITv2.from_pretrained(config.model.vq_model.vq_model_name).to(device)
    print("MAGVIT-v2 loaded from HuggingFace")

vq_model.requires_grad_(False)
vq_model.eval()
print(f"MAGVIT-v2 parameters: {sum(p.numel() for p in
vq_model.parameters()):,}")

Loading MAGVIT-v2 from local path...
Working with z of shape (1, 13, 16, 16) = 3328 dimensions.
Look-up free quantizer with codebook size: 8192
Loading from: c:\Users\jonat\Desktop\StyleAI\backend\fashionm3\Show-o\models\magvitv2\pytorch_model.safetensors
MAGVIT-v2 loaded with safetensors
MAGVIT-v2 parameters: 95,387,004


In [8]:
# setting up vision tower: clip-vit
vision_tower_name ="openai/clip-vit-large-patch14-336"
vision_tower = CLIPVisionTower(vision_tower_name).to(device)
clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower_name)




In [9]:
#model = Showo.from_pretrained("showlab/show-o")
#model = model.to(device)
#model.eval()

In [10]:
# setting up the showo model 
model = Showo.from_pretrained(config.model.showo.pretrained_model_path, w_clip_vit=True,  low_cpu_mem_usage=True, device_map = "auto", torch_dtype=torch.float32).to(device)
model.eval()

The config attributes {'mask_token_id': 58497} were passed to Showo, but are not expected and will be ignored. Please verify your config.json configuration file.


attention implementation:  sdpa


  if self.w_clip_vit:


ValueError: Showo does not support `device_map='auto'`. To implement support, the model class needs to implement the `_no_split_modules` attribute.

In [None]:
# setting up the parameters
temperature = 0.8  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 1  # retain only the top_k most likely tokens, clamp others to have 0 probability
SYSTEM_PROMPT = "A chat between a curious user and an artificial intelligence assistant. " \
                "The assistant gives helpful, detailed, and polite answers to the user's questions."
SYSTEM_PROMPT_LEN = 28


## Inference 

In [None]:
import os
from IPython.display import Image

# Available fashion images (extracted from our training data)
fashion_images = {
    "black_trousers": "./0000000.jpg",
    "pink_jacket": "./0000001.jpg",
    "fashion_item_3": "./0000002.jpg"
}

# Choose which image to test
selected_image = "pink_jacket"  # or "black_trousers"
image_path = fashion_images[selected_image]

print(f"Testing with: {selected_image}")
print(f"Path: {image_path}")

# Display the image
if os.path.exists(image_path):
    display(Image(filename=image_path))
    print(f"Fashion image loaded: {selected_image}")
else:
    print(f"Image not found: {image_path}")

In [None]:
# inference
from PIL import Image
## arguments
input_image_path = fashion_images[selected_image]
questions ='Please describe this image in detail. *** What style elements do you notice? *** What fashion advice would you give?'

## processing
questions = questions.split('***')
image_ori = Image.open(input_image_path).convert("RGB")
# tranforming the image to the required resolution:256x256
image = image_transform(image_ori, resolution = config.dataset.preprocessing.resolution).to(device)
image = image.unsqueeze(0)
print(f"image shape: {image.shape}") # torch.Size([1, 3, 256, 256])
pixel_values = clip_image_processor.preprocess(image_ori,return_tensors="pt")['pixel_values'][0]
print(f"pixel values shape: {pixel_values.shape}")
image_tokens = vq_model.get_code(image) + len(uni_prompting.text_tokenizer)
print(f"image tokens shape: {image_tokens.shape}") # torch.Size([1, 256])
batch_size = 1

## inference
for question in questions: 
  conv = conversation_lib.default_conversation.copy()
  print(f"conversation: {conv}")
  conv.append_message(conv.roles[0], question)
  conv.append_message(conv.roles[1], None)
  prompt_question = conv.get_prompt()
  # print(prompt_question)
  question_input = []
  question_input.append(prompt_question.strip())
  print(f"system prompt: {SYSTEM_PROMPT}")
  input_ids_system = [uni_prompting.text_tokenizer(SYSTEM_PROMPT, return_tensors="pt", padding="longest").input_ids for _ in range(batch_size)]
  print(f"system prompt input ids: {input_ids_system}")
  input_ids_system = torch.stack(input_ids_system, dim=0)
  assert input_ids_system.shape[-1] == 28
  print(f"after torch stacking: {input_ids_system}")
  input_ids_system = input_ids_system.clone().detach().to(device)
  # inputs_ids_system = input_ids_system.to(device)
#   inputs_ids_system = torch.tensor(input_ids_system).to(device).squeeze(0)
  
  print(f"after moving to device: {input_ids_system}")
  input_ids_system = input_ids_system[0]
  print(f"after indexing 0: {input_ids_system}")
  
  
  print(f"question input: {question_input}")
  input_ids = [uni_prompting.text_tokenizer(prompt, return_tensors="pt", padding="longest").input_ids for prompt in question_input]
  print(f"after tokenizing the question: {input_ids}")
  input_ids = torch.stack(input_ids)
  print(f"after torch stacking: {input_ids}")
  input_ids = torch.nn.utils.rnn.pad_sequence(
                        input_ids, batch_first=True, padding_value=uni_prompting.text_tokenizer.pad_token_id
                )
  print(f"after padding: {input_ids}")
  # input_ids = torch.tensor(input_ids).to(device).squeeze(0)
  input_ids = input_ids.clone().detach().to(device).squeeze(0)
  print(f"after moving to device: {input_ids}")
  input_ids_llava = torch.cat([
                          (torch.ones(input_ids.shape[0], 1) *uni_prompting.sptids_dict['<|mmu|>']).to(device),
                          input_ids_system,
                          (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|soi|>']).to(device),
                          # place your img embedding here
                          (torch.ones(input_ids.shape[0], 1) * uni_prompting.sptids_dict['<|eoi|>']).to(device),
                          input_ids,
                  ], dim=1).long()
  print(input_ids_llava)
  
  images_embeddings = vision_tower(pixel_values[None])
  print(f"images embeddings shape: {images_embeddings.shape}")# torch.Size([1, 576, 1024])
  images_embeddings = model.mm_projector(images_embeddings)
  print(f"images embeddings shape after projection: {images_embeddings.shape}") 

  text_embeddings = model.showo.model.embed_tokens(input_ids_llava)

  #full input seq
  part1 = text_embeddings[:, :2+SYSTEM_PROMPT_LEN,:]
  part2 = text_embeddings[:, 2+SYSTEM_PROMPT_LEN:,:]
  input_embeddings = torch.cat((part1,images_embeddings,part2),dim=1)

  attention_mask_llava = create_attention_mask_for_mmu_vit(input_embeddings,system_prompt_len=SYSTEM_PROMPT_LEN)

  cont_toks_list = model.mmu_generate(
    input_embeddings = input_embeddings,
    attention_mask = attention_mask_llava[0].unsqueeze(0),
    max_new_tokens = 100,
    top_k = top_k,
    eot_token = uni_prompting.sptids_dict['<|eov|>']
  )
  
  cont_toks_list = torch.stack(cont_toks_list).squeeze()[None]
  text = uni_prompting.text_tokenizer.batch_decode(cont_toks_list,skip_special_tokens=True)
  print(f"User: {question}, \nAnswer: {text[0]}")


