In [None]:
import torch
import numpy as np

from muse import PipelineMuse

We load a proof-of-concept checkpoint trained on Conceptual Captions.

In [None]:
pipe = PipelineMuse.from_pretrained("openMUSE/muse-laiona6-uvit-clip-220k")

The pipeline contains the following components:
- `tokenizer`
- `text_encoder`
- `transformer`
- `vae`

We'll first attempt to convert the `transformer` component, the image generation module that predicts the VQ image tokens.

In [None]:
model = pipe.transformer.eval()

In [None]:
import coremltools as ct
ct.__version__

scikit-learn version 1.2.2 is not supported. Minimum required version: 0.17. Maximum required version: 1.1.2. Disabling scikit-learn conversion API.
Torch version 2.0.1+cu117 has not been tested with coremltools. You may run into unexpected errors. Torch 2.0.0 is the most recent version that has been tested.


'7.0b1'

This version of `coremltools` supports 6-bit palettization. We'll convert without it first, then apply it and measure any difference in quality.
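For intuition, here is a minimal NumPy sketch of what 6-bit palettization does conceptually: every weight is replaced by an index into a lookup table of 2**6 = 64 values, chosen here with a plain 1-D k-means. This is an illustration only and not the `coremltools` implementation (in the 7.x API the real step would go through `ct.optimize.coreml.palettize_weights`); the helper name `palettize_6bit` and its k-means details are assumptions.

```python
import numpy as np

def palettize_6bit(weights, n_iters=20, seed=0):
    """Hypothetical helper: 1-D k-means palettization with a 64-entry lookup table."""
    n_colors = 2 ** 6  # 6 bits -> 64 palette entries
    flat = weights.reshape(-1).astype(np.float32)
    rng = np.random.default_rng(seed)
    # Initialise the palette from randomly chosen weight values.
    lut = rng.choice(flat, size=n_colors, replace=False)
    for _ in range(n_iters):
        # Assign every weight to its nearest palette entry.
        indices = np.abs(flat[:, None] - lut[None, :]).argmin(axis=1)
        # Move each entry to the mean of its assigned weights (empty entries stay put).
        for k in range(n_colors):
            members = flat[indices == k]
            if members.size:
                lut[k] = members.mean()
    indices = np.abs(flat[:, None] - lut[None, :]).argmin(axis=1)
    # 6-bit indices fit in uint8; storage would pack them more tightly.
    return lut, indices.astype(np.uint8).reshape(weights.shape)

w = np.random.default_rng(1).normal(size=(128, 64)).astype(np.float32)
lut, idx = palettize_6bit(w)
w_hat = lut[idx]  # de-palettized weights, as used at inference time
print(lut.shape, np.abs(w - w_hat).mean())
```

The reconstruction error of `w_hat` against `w` is the kind of quality loss we want to measure on the real transformer weights.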

## Inputs

### Text conditioning

In [None]:
text_input_ids = pipe.tokenizer(
    "Labrador in the style of Vermeer",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=pipe.tokenizer.model_max_length,
).input_ids
text_input_ids.shape

torch.Size([1, 77])

As in Stable Diffusion, we do _not_ pass attention masks, so the text encoder processes all 77 positions, padding included.

In [None]:
encoder_hidden_states = pipe.text_encoder(text_input_ids).last_hidden_state
encoder_hidden_states.shape

torch.Size([1, 77, 768])

In [None]:
negative_input_ids = pipe.tokenizer(
    "ugly, bad anatomy",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=pipe.tokenizer.model_max_length,
).input_ids
negative_encoder_hidden_states = pipe.text_encoder(negative_input_ids).last_hidden_state

In [None]:
bs = 2  # classifier-free guidance: prompt + negative prompt

The conditioning input stacks `encoder_hidden_states` and `negative_encoder_hidden_states` along the batch dimension. We could sample it with `np.random.normal`, but we don't know the distribution of the text encoder's hidden states, so we reuse the real embeddings computed above.

In [None]:
sequence_length = pipe.tokenizer.model_max_length   # 77
embed_size = pipe.text_encoder.config.hidden_size   # 768

conditioning_shape = (bs, sequence_length, embed_size)

In [None]:
conditioning = np.concatenate((encoder_hidden_states.detach().numpy(), negative_encoder_hidden_states.detach().numpy()))

In [None]:
assert conditioning.shape == conditioning_shape

### Image input

Image input token ids. Each image is made of `model.config.num_vq_tokens` (256) tokens taken from a codebook of size `model.config.codebook_size` (8192). Random ids are sufficient for tracing and conversion.

In [None]:
input_ids_shape = (bs, model.config.num_vq_tokens)
input_ids = np.random.randint(0, model.config.codebook_size, input_ids_shape)
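The 256 tokens per image can be viewed as a square grid; the tracer output in the JIT section (`height, width = int(seq_length**0.5), ...`) suggests the model itself lays them out as `sqrt(seq_length)` x `sqrt(seq_length)`. A minimal sketch, assuming row-major (raster) token order, with the config values hard-coded and a separate `demo_ids` array so the notebook's `input_ids` is left untouched:

```python
import numpy as np

# Values from model.config, hard-coded here so the sketch runs on its own.
num_vq_tokens = 256
codebook_size = 8192

demo_ids = np.random.randint(0, codebook_size, (2, num_vq_tokens))
side = int(num_vq_tokens ** 0.5)  # 16
assert side * side == num_vq_tokens, "expected a square token grid"

# Assumes raster order: token i sits at row i // side, column i % side.
token_grid = demo_ids.reshape(-1, side, side)
print(token_grid.shape)  # (2, 16, 16)
```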

In [None]:
inputs = {
    "input_ids": input_ids,
    "encoder_hidden_states": conditioning,
}

### Model output (single step)

In [None]:
t_inputs = {
    "input_ids": torch.tensor(input_ids, dtype=torch.int32),
    "encoder_hidden_states": torch.tensor(conditioning),
}

In [None]:
outputs = model(**t_inputs)
outputs.shape

torch.Size([2, 256, 8192])

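The output holds both halves of the batch, in the same order as `conditioning`: `cond_logits` (prompt) followed by `uncond_logits` (negative prompt).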

**TODO** We could chunk them here for convenience. We could also apply some more post-processing inside a model wrapper.
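A minimal NumPy sketch of that post-processing: split the batch back into its conditional and unconditional halves (in the order they were concatenated) and combine them with the standard classifier-free guidance formula `uncond + scale * (cond - uncond)`. The helper name `apply_cfg` and the default guidance scale are assumptions for illustration; this is not necessarily what `PipelineMuse` does internally.

```python
import numpy as np

def apply_cfg(logits, guidance_scale=8.0):
    """Hypothetical helper: logits has shape (2, num_tokens, codebook_size)."""
    # Row 0 came from the prompt, row 1 from the negative prompt.
    cond_logits, uncond_logits = np.split(logits, 2, axis=0)
    return uncond_logits + guidance_scale * (cond_logits - uncond_logits)

logits = np.random.default_rng(0).normal(size=(2, 256, 8192)).astype(np.float32)
guided = apply_cfg(logits)
print(guided.shape)  # (1, 256, 8192)
```

Doing this inside a wrapper module before tracing would let the Core ML model return guided logits directly, at the cost of fixing the guidance formula at conversion time.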

## JIT

In [None]:
jit_inputs = list(t_inputs.values())

In [None]:
jitted_model = torch.jit.trace(model, jit_inputs)
jitted_model.eval();

  height, width = int(seq_length**0.5), int(seq_length**0.5)

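The line above appears to be the source line echoed by PyTorch tracer warnings (the warning text itself is not shown): calling `int()` on a size during `torch.jit.trace` records the value computed for the example input as a constant, so the traced graph is fixed to a 16 x 16 token grid. A toy reproduction of the pattern; `ToyTokenGrid` is hypothetical and not part of `muse`:

```python
import warnings
import torch

class ToyTokenGrid(torch.nn.Module):
    """Hypothetical module mimicking the line flagged by the tracer."""
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        seq_length = tokens.shape[1]
        # int() freezes the grid side into the trace instead of recording the computation.
        height, width = int(seq_length**0.5), int(seq_length**0.5)
        return tokens.reshape(-1, height, width)

example = torch.zeros(2, 256)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the tracer warning for this demo
    traced = torch.jit.trace(ToyTokenGrid(), example)

print(traced(example).shape)  # torch.Size([2, 16, 16])

try:
    traced(torch.zeros(2, 64))  # 64 tokens would need an 8 x 8 grid, but 16 x 16 is baked in
except (RuntimeError, torch.jit.Error) as err:
    print("grid size frozen at trace time:", type(err).__name__)
```

Since the Core ML inputs below use fixed shapes anyway, this constant is harmless here, but it matters if the token count ever changes.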

In [None]:
with torch.no_grad():
    output_jit = jitted_model(*jit_inputs)

In [None]:
(output_jit - outputs).abs().max()

tensor(8.3923e-05, grad_fn=<MaxBackward1>)

Close enough: a maximum absolute difference of about 8e-5 is consistent with floating-point rounding.
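The same comparison will be needed for the FP16 Core ML outputs and after palettization, where the maximum absolute difference alone says little about whether token predictions change. A sketch of a helper (the name `compare_logits` is hypothetical) that also reports how often the most likely token agrees:

```python
import numpy as np

def compare_logits(reference, candidate):
    """Hypothetical helper: summary stats for two (batch, tokens, codebook) logit arrays."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    abs_diff = np.abs(reference - candidate)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        # Fraction of positions where both models pick the same most likely token.
        "argmax_agreement": float((reference.argmax(-1) == candidate.argmax(-1)).mean()),
    }

# In this notebook: compare_logits(outputs.detach().numpy(), output_jit.numpy())
ref = np.random.default_rng(0).normal(size=(2, 16, 32)).astype(np.float32)
stats = compare_logits(ref, ref + 1e-4)
print(stats)
```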

## Core ML

In [None]:
def _get_coreml_inputs(sample_inputs):
    # One fixed-shape Core ML TensorType per example input, reusing its name, shape and dtype.
    return [
        ct.TensorType(
            name=k,
            shape=v.shape,
            dtype=v.numpy().dtype if isinstance(v, torch.Tensor) else v.dtype,
        ) for k, v in sample_inputs.items()
    ]

In [None]:
coreml_input_types = _get_coreml_inputs(t_inputs)
coreml_output_types = [ct.TensorType(name="logits")]  # Update when chunking/post-processing

In [None]:
coreml_model = ct.convert(
    jitted_model,
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.macOS13,
    inputs=coreml_input_types,
    outputs=coreml_output_types,
    compute_precision=ct.precision.FLOAT16,
)

Converting PyTorch Frontend ==> MIL Ops: 100%|█████████████████████████████████████████████████████████████████████████████████████████████▉| 3810/3811 [00:01<00:00, 2142.19 ops/s]
Running MIL frontend_pytorch pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 16.02 passes/s]
Running MIL default pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 64/64 [00:58<00:00,  1.09 passes/s]
Running MIL backend_mlprogram pipeline: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:00<00:00, 100.53 passes/s]


In [None]:
coreml_model.save("muse_transformer.mlpackage")

### Inference
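A sketch of what the inference check could look like, assuming it runs on macOS (Core ML predictions are not available on Linux) with the package saved above. The input names and dtypes mirror the `TensorType`s used for conversion (`int32` token ids, `float32` hidden states), and `logits` is the output name chosen there; `make_coreml_feed` and `run_coreml` are hypothetical helpers.

```python
import numpy as np

def make_coreml_feed(input_ids, conditioning):
    """Cast inputs to the dtypes declared at conversion time."""
    return {
        "input_ids": np.asarray(input_ids, dtype=np.int32),
        "encoder_hidden_states": np.asarray(conditioning, dtype=np.float32),
    }

def run_coreml(package_path, feed):
    """Load a saved .mlpackage and return its `logits` output (macOS only)."""
    import coremltools as ct  # imported lazily so the helpers can be defined anywhere
    mlmodel = ct.models.MLModel(package_path)
    return mlmodel.predict(feed)["logits"]

# Usage on a Mac, with the variables defined earlier in this notebook:
# coreml_logits = run_coreml("muse_transformer.mlpackage", make_coreml_feed(input_ids, conditioning))
# then compare coreml_logits against outputs.detach().numpy()
```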

## TODO

- Verify inference on Mac
- Convert text encoder, VAE
- Python pipeline
- Swift pipeline
- Palettization