In [1]:
!git clone https://github.com/deepseek-ai/Janus.git

Cloning into 'Janus'...
remote: Enumerating objects: 118, done.[K
remote: Counting objects: 100% (72/72), done.[K
remote: Compressing objects: 100% (36/36), done.[K
remote: Total 118 (delta 50), reused 36 (delta 36), pack-reused 46 (from 1)[K
Receiving objects: 100% (118/118), 7.18 MiB | 21.01 MiB/s, done.
Resolving deltas: 100% (56/56), done.


In [2]:
%cd Janus
!pip install -e .
!pip install flash-attn

/content/Janus
Obtaining file:///content/Janus
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting attrdict (from janus==1.0.0)
  Downloading attrdict-2.0.1-py2.py3-none-any.whl.metadata (6.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.1->janus==1.0.0)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.1->janus==1.0.0)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.0.1->janus==1.0.0)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.0.1

In [3]:
!pip install datasets

import os
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from datasets import Dataset

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 



In [11]:
#Notes:

#Janus does not support Flash Attention 2.0 yet
#Had a bug with dtype
#bug with preparing imputs

class JanusFormTrainer:
  def __init__(self, model_path="deepseek-ai/Janus-1.3B", form_image_path="/content/Janus/images/demoForm1.png"):

    form_image_path = os.path.join(os.getcwd(), form_image_path)

    #For image path validation
    if not os.path.exists(form_image_path):
        raise ValueError(f"Image path does not exist: {form_image_path}")
    print(f"Found image at: {form_image_path}")

    if torch.backends.mps.is_available():
      self.device = torch.device("mps")
    else:
      self.device = torch.device("cpu")
    if torch.cuda.is_available():
      self.device = torch.device("cuda")

    self.config = {
        "model_path": model_path,
        "form_image_path": form_image_path,
        "output_dir": "./janus_finetuned",
        "batch_size": 4,
        "num_epochs": 3,
        "learning_rate": 5e-5,
        "logging_dir": "./logs"
    }

    # Model & processor setup
    self.vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
    self.tokenizer = self.vl_chat_processor.tokenizer

    self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
    self.model = self.model.to(torch.bfloat16).to(self.device).eval()

    # Did this to laod the model with correct dtype
    self.model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
      model_path,
      trust_remote_code=True,
      torch_dtype=torch.float16,
      #attn_implementation="flash_attention_2"
    )
    self.model = self.model.to(self.device)

  def prepare_training_data(self):
    train_data = [
         {
            "conversation": [
                {
                    "role": "user",
                    "content": "Hey Quill, I am a first time user, can you help me fill out a form?",
                },
                {
                    "role": "assistant",
                    "content": "Of course I can! Please upload the form image."
                }
            ],
             "image_path": self.config["form_image_path"],
        },
        {
            "conversation": [
                {
                    "role": "user",
                    "content": "Hey Quill, can you identify what blank fields I need to fill out on this form?",
                },
                {
                    "role": "assistant",
                    "content": "Yes, I can see several blank fields that need to be completed. The main sections are:\n\n1. Property Location Information:\n- County\n- Township\n- PIN(s)\n- Address of Property\n- Square Feet of Building\n\n2. Property Status Fields:\n- Recent Purchase checkbox\n- Demolition and related fields (Demo Permit, Paid Receipt, Demo Affidavit)\n- Excess Land and Non-Buildable Area information\n\n3. Timing Information:\n- New Building Completion Date\n- Occupancy Permit status\n- Vacancy dates for 2024 and 2025\n\nWould you like me to explain any specific section in detail?"
                }
            ],
             "image_path": self.config["form_image_path"],
        },
        {
            "conversation": [
                {
                    "role": "user",
                    "content": "What are all the yes/no fields I need to fill out?"
                },
                {
                    "role": "assistant",
                    "content": "I can see several yes/no fields that need to be completed:\n\n1. Property Status Questions:\n- Demolition?\n- Demo Permit?\n- Paid Receipt?\n- Demo Affidavit?\n- Excess Land?\n\n2. Property Characteristics:\n- Have Occupancy Permit?\n- Is property being offered for sale? (YES/NO with price field if YES)\n- Is building in court or under citation for building code violations? (YES/NO)\n\nEach of these fields requires either a YES or NO response, with some requiring additional information if answered YES."
                }
            ],
            "image_path": self.config["form_image_path"],
        },
        {
            "conversation": [
                {
                    "role": "user",
                    "content": "Which fields are related to dates or timing?"
                },
                {
                    "role": "assistant",
                    "content": "The form includes several date-related fields:\n\n1. Building Timeline:\n- New Building Completed When? (requires specific completion date)\n\n2. Vacancy Periods:\n- Vacancy - 2024 (specify period)\n- Vacancy - 2025 (specify period)\n\n3. Implied Date References:\n- Recent Purchase? (indicates recent timing)\n- Any additions or improvements to the property within the last three years\n- The fair market value estimation specifically references January 1, 2025\n\nThese fields need appropriate date formats or time period specifications."
                }
            ],
            "image_path": self.config["form_image_path"],
        }
    ]
    return train_data

  def process_conversations(self, train_data):
    processed_data = []
    for data in train_data:
        # Combine all messages into a single conversation
        processed_data.append({
            "conversation": {  # Single conversation object
                "messages": data["conversation"]  # List of messages
            },
            "image_path": data["image_path"]
        })

    print(f"Processed {len(processed_data)} conversations")
    print("First conversation structure:", processed_data[0]["conversation"])
    return processed_data

  def load_pil_images(self, processed_data):
    try:

      print("Loading images...")
      pil_images = [Image.open(item["image_path"]).convert('RGB') for item in processed_data]
      print(f"Loaded {len(pil_images)} images")
      print("Images loaded!")

      print("Processing conversations...")
     # Extract conversations in the format the processor expects
      conversations = [item [["conversation"]["messages"]] for item in processed_data]
      print("Conversation format being sent to processor:", conversations[0])


      with torch.autocast(device_type=self.device.type):
          inputs = self.vl_chat_processor(
              conversations=conversations,
              images=pil_images,
              force_batchify=True,
              return_tensors="pt"
          )
      print("Imputs prepared")
      # Convert to dataset format
      dataset_dict = {
          "input_ids": inputs["input_ids"],
          "attention_mask": inputs["attention_mask"],
      }
      print("Dataset format")
      if "labels" in inputs:
          dataset_dict["labels"] = inputs["labels"]

      # Create dataset
      return Dataset.from_dict(dataset_dict)
      print("Dataset created")

    except Exception as e:
      print(f"Error in prepare_model_inputs: {e}")
      raise


  def train(self):
    try:
      print("Preparing training data...")
      train_data = self.prepare_training_data()
      processed_data = self.process_conversations(train_data)

      print("Preparing model inputs...")
      model_inputs = self.load_pil_images(processed_data)

      print("Settign up training args...")
      training_args = TrainingArguments(
        output_dir=self.config["output_dir"],
        per_device_train_batch_size=self.config["batch_size"],
        num_train_epochs=self.config["num_epochs"],
        learning_rate=self.config["learning_rate"],
        logging_dir=self.config["logging_dir"],
        logging_steps=50,
        save_strategy="epoch",
        load_best_model_at_end=True,
        fp16=True
      )

      print("Setting up trainer...")
      trainer = Trainer(
        model=self.model,
        args=training_args,
        train_dataset=model_inputs
      )

      print("Starting training...")
      trainer.train()
      print("Training complete!")

    except Exception as e:
      print(f"An error occurred during training: {e}")
      raise

  conversations = [item [["conversation"]["messages"]] for item in processed_data]


In [12]:
if __name__ == "__main__":
    trainer = JanusFormTrainer(
        model_path="deepseek-ai/Janus-1.3B",
        form_image_path="/content/Janus/images/demoForm1.png"
    )
    trainer.train()

Found image at: /content/Janus/images/demoForm1.png


Some kwargs in processor config are unused and will not have any effect: ignore_id, num_image_tokens, mask_prompt, add_special_token, image_tag, sft_format. 


Add image tag = <image_placeholder> to the tokenizer
Preparing training data...
Processed 4 conversations
First conversation structure: {'messages': [{'role': 'user', 'content': 'Hey Quill, I am a first time user, can you help me fill out a form?'}, {'role': 'assistant', 'content': 'Of course I can! Please upload the form image.'}]}
Preparing model inputs...
Loading images...
Loaded 4 images
Images loaded!
Processing conversations...
Error in prepare_model_inputs: list indices must be integers or slices, not str
An error occurred during training: list indices must be integers or slices, not str


TypeError: list indices must be integers or slices, not str