In [None]:
import os
import copy
from dataclasses import dataclass, field
import json
import logging
import pathlib
from typing import Dict, Optional, Sequence, Union
import numpy as np


import torch

import transformers
import tokenizers

from llava.constants import (
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from torch.utils.data import Dataset
from llava.train.llava_trainer import LLaVATrainer

from llava import conversation as conversation_lib
from llava.model import *
from llava.mm_utils import tokenizer_image_token

In [None]:
# Load the base checkpoint handed in by Snakemake, with bfloat16 weights.
model = transformers.AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=snakemake.input[0],
    torch_dtype=torch.bfloat16,
)
# Bare trailing expression: renders the module tree as the cell output.
model

In [None]:
# Load the tokenizer from the same checkpoint path as the model.
# use_fast=False requests the slow (pure-Python) tokenizer implementation and
# padding_side="right" makes padding get appended after the sequence.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=snakemake.input[0],
    padding_side="right",
    use_fast=False,
)

In [None]:
# Precondition: this notebook only applies to a tokenizer that ships without
# an unknown token; stop right away if one is already defined.
assert tokenizer.unk_token is None

# add_special_tokens returns how many tokens were actually appended to the
# vocabulary. The following cells grow the input embedding and the lm_head by
# exactly one row, so anything other than 1 would leave the model and the
# tokenizer out of sync without any visible error.
num_added_tokens = tokenizer.add_special_tokens({"unk_token": "<unk>"})
assert num_added_tokens == 1, (
    f"expected exactly one new token for '<unk>', got {num_added_tokens}"
)

# Reuse the new token as the padding token and account for it in the config.
model.config.pad_token_id = tokenizer.unk_token_id
model.config.vocab_size += 1

# Keep a handle on the current input embedding; its weights are copied into a
# larger embedding in the next cell. The trailing expression shows its shape.
orig_embed = model.model.embed_tokens
orig_embed.weight.shape

In [None]:
# Swap in a replacement input embedding that has one additional row for the
# new token; width is taken from the original embedding, dtype and device from
# the embedding currently installed on the model. padding_idx is set to the id
# of the unk token added earlier.
model.model.embed_tokens = torch.nn.Embedding(
    num_embeddings=orig_embed.weight.shape[0] + 1,
    embedding_dim=orig_embed.weight.shape[1],
    padding_idx=tokenizer.unk_token_id,
    dtype=model.model.embed_tokens.weight.dtype,
    device=model.model.embed_tokens.weight.device,
)
# Every row except the last is filled from the original weights ...
model.model.embed_tokens.weight.data[:-1] = orig_embed.weight.data
# ... and the appended last row is zeroed in place (shown as the cell output).
model.model.embed_tokens.weight.data[-1:].zero_()

In [None]:
# Display the model again to inspect the module tree after the embedding swap.
model

In [None]:
# Point the config's recorded name/path at the Snakemake output location,
# then echo the value as the cell output to confirm the assignment.
model.config._name_or_path = snakemake.output[0]
model.config._name_or_path

In [None]:
# Show the config to review the fields changed above
# (pad_token_id, vocab_size, _name_or_path).
model.config


In [None]:
# Also adapt the lm_head: the output projection needs one extra row so that it
# matches the enlarged vocabulary recorded in model.config.vocab_size.
orig_lm_head = model.lm_head

# Mirror the original head's bias setting instead of hard-coding bias=False,
# so a checkpoint whose head carries a bias does not silently lose it.
lm_head_has_bias = getattr(orig_lm_head, "bias", None) is not None

model.lm_head = torch.nn.Linear(
    orig_lm_head.in_features,
    model.config.vocab_size,
    bias=lm_head_has_bias,
    device=orig_lm_head.weight.device,
    dtype=orig_lm_head.weight.dtype,
)
# Copy the existing output rows; only the last row belongs to the new token.
model.lm_head.weight.data[:-1] = orig_lm_head.weight.data
if lm_head_has_bias:
    # Same treatment for the bias: keep the old entries, zero the new one.
    model.lm_head.bias.data[:-1] = orig_lm_head.bias.data
    model.lm_head.bias.data[-1:].zero_()
# Zero the new token's weight row (the result is displayed as the cell output).
torch.nn.init.zeros_(model.lm_head.weight.data[-1:, :])

In [None]:
# Persist the patched model and the extended tokenizer side by side in the
# Snakemake output location. safe_serialization=True asks for the safetensors
# weight format.
model.save_pretrained(
    save_directory=snakemake.output[0],
    state_dict=None,
    safe_serialization=True,
)
tokenizer.save_pretrained(save_directory=snakemake.output[0])