# MAGMA Inference Demo

Copyright (c) 2023 Graphcore Ltd.

This notebook provides a basic interactive MAGMA inference application, based on the freely available checkpoint published by Aleph Alpha. The MAGMA model autoregressively generates text from arbitrary combinations of visual and textual input.
Note that the provided checkpoint is only a demo meant to help users understand how the model works.

MAGMA (Multimodal Augmentation of Generative Models through Adapter-based Finetuning) is a multimodal Vision-Language model developed by Aleph Alpha and researchers from [Heidelberg University](https://www.cl.uni-heidelberg.de "Computational Linguistics at Heidelberg University"). For all the details, please refer to the [research paper](https://arxiv.org/abs/2112.05253) and the [official GitHub repository](https://github.com/Aleph-Alpha/magma). 

MAGMA aims at obtaining a joint image-text representation, which can then be used for a variety of tasks such as image captioning, visual question answering and visual entailment. You can feed the model with a textual prompt and an image and use the textual prompt to make queries about the image.

The main components of MAGMA are:
- A pretrained autoregressive **language model**. This component is **frozen**: it is not trained. The chosen language model is [GPT-J 6B](https://github.com/kingoflolz/mesh-transformer-jax).
- A **visual encoder** (namely, variants of ViT or ResNet) that takes an image input and produces image features. This is a **trainable** component. In the paper, several variants are discussed and compared.
- An **image prefix** that projects image features into a sequence of embedding vectors, suitable as inputs to the language model transformer. This is a **trainable** component.
- **Adapter** layers, which are **trainable** layers added in different layouts to the language model decoder block. In the paper, several configurations are discussed and compared.
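
To make the adapter idea concrete, here is a minimal sketch of one common adapter design: a bottleneck MLP whose output is added back to its input. The sizes, the ReLU activation and the names `linear` and `adapter` are illustrative assumptions; this is not the configuration used in the checkpoint.

```python
def linear(x, weight, bias):
    """Apply y = W x + b to a vector x, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def adapter(hidden, w_down, b_down, w_up, b_up):
    """Bottleneck adapter: project down, apply ReLU, project up, add the input back."""
    bottleneck = [max(0.0, v) for v in linear(hidden, w_down, b_down)]
    update = linear(bottleneck, w_up, b_up)
    return [h + u for h, u in zip(hidden, update)]


# Toy sizes (assumed): hidden dimension 4, bottleneck dimension 2.
hidden = [1.0, -2.0, 0.5, 3.0]
w_down = [[0.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.2]]
b_down = [0.0, 0.0]
# With a zero up-projection the adapter is the identity, so the frozen
# language model's output is unchanged until the adapter is trained.
w_up_zero = [[0.0, 0.0] for _ in range(4)]
b_up = [0.0] * 4

identity_output = adapter(hidden, w_down, b_down, w_up_zero, b_up)
print(identity_output == hidden)
```

Only the adapter weights would be updated during training in this picture; the surrounding language model weights stay fixed.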

The image below, taken from the Aleph Alpha [paper](https://arxiv.org/abs/2112.05253), summarises the model structure.

![MAGMA model architecture](images/MagmaStructure.png)

- The inputs of the model are an image and an accompanying text.
- Text is encoded by the language model embedding **E**, obtaining text embeddings.
- In the ImagePrefix:
    - Image features are produced by the visual encoder **V<sup>e</sup>**
    - Image features are projected to the embedding dimension by a linear layer **V<sup>p</sup>**, obtaining image embeddings.
- Image embeddings and text embeddings are concatenated. Concatenation order is `(image, text)`, with the image tokens always at the beginning of the sequence. Because the image tokens come first, the causal mask never hides them from the text tokens: every text token can attend to all of them.
- The resulting sequence is fed to the language model transformer endowed with adapters.
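
The effect of the `(image, text)` ordering under a causal mask can be checked with a small sketch. The toy sizes and the helper names `causal_mask` and `build_sequence` are assumptions made for illustration only.

```python
def causal_mask(seq_len):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def build_sequence(image_embeddings, text_embeddings):
    """Concatenate in (image, text) order."""
    return image_embeddings + text_embeddings


# Toy sizes (assumed): 3 image positions, 2 text positions, embedding dimension 2.
image_embeddings = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
text_embeddings = [[1.0, 1.1], [1.2, 1.3]]
sequence = build_sequence(image_embeddings, text_embeddings)
n_image = len(image_embeddings)
mask = causal_mask(len(sequence))

# Every text position can attend to every image position.
text_sees_all_images = all(
    mask[i][j] for i in range(n_image, len(sequence)) for j in range(n_image)
)
print(text_sees_all_images)
```

Had the order been `(text, image)`, the same mask would hide the image positions from the text positions that precede them.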

This application implements only the model choices of the [publicly available checkpoint](https://bit.ly/aleph_alpha_magma_download). As the authors state in the model repository, this checkpoint is meant only as a demo and is not the one used in their actual applications:
- The visual encoder is a CLIP modified ResNet50-x16 ([implementation](https://github.com/openai/CLIP/blob/main/clip/model.py)), without the final attention pool layer.
- Adapters are only added to feed-forward blocks. The MAGMA paper reports that also adding adapters to the attention layers gives better results.

## Inference

**Note:** MAGMA uses GPT-J, but the code contains references to GPTNeo, which can be confusing. The model is GPT-J 6B, but it is not taken from the official Hugging Face Transformers library. Instead, the authors use [finetuneanon/transformers](https://github.com/finetuneanon/transformers). In this library, [GPTNeo](https://github.com/finetuneanon/transformers/tree/gpt-neo-localattention3-rp-b/src/transformers/models/gpt_neo) with the options `jax=True`, `rotary=True` and `global` attention corresponds to GPT-J.

### Input preprocessing
#### Textual input
The input text is tokenized using the GPT-2 tokenizer.
It is then padded up to `sequence_len - 144`, where `sequence_len` is specified in the configs and 144 is the number of image tokens generated by the image encoder.
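
The padding arithmetic can be sketched as follows. The token ids, the pad id, right-side padding and the value `sequence_len=500` (suggested by the `magma_v1_500` config name) are assumptions for illustration; the helper `pad_text_tokens` is hypothetical and not part of the application code.

```python
NUM_IMAGE_TOKENS = 144  # number of image positions, as stated above


def pad_text_tokens(token_ids, sequence_len, pad_id):
    """Right-pad token ids to the text budget sequence_len - NUM_IMAGE_TOKENS."""
    text_len = sequence_len - NUM_IMAGE_TOKENS
    if len(token_ids) > text_len:
        raise ValueError(f"prompt has {len(token_ids)} tokens, limit is {text_len}")
    return token_ids + [pad_id] * (text_len - len(token_ids))


# Hypothetical token ids and pad id, for illustration only.
padded = pad_text_tokens([32, 4286, 286], sequence_len=500, pad_id=0)
print(len(padded))  # 356
```

With a sequence length of 500, this leaves 356 positions for text.
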
#### Image input
The input image is resized, cropped and normalised using the MAGMA [clip_preprocess](https://github.com/Aleph-Alpha/magma/blob/master/magma/transforms.py) function.


### Next token generation
Given an image and a text prompt, the model outputs logits for the sequence.
Logits corresponding to the last token in the sequence are used to predict the next token.
Several heuristics are supported:
- argmax: pick the token with the highest probability. This is a deterministic method.
- top-p: keep the smallest set of most likely tokens whose cumulative probability exceeds a threshold p, and redistribute the probability among them. The next token is then sampled from this distribution (categorical sampling, non-deterministic).
- top-k: keep the K most likely tokens and redistribute the probability among them. The next token is then sampled from this distribution (categorical sampling, non-deterministic).
- temperature: logits are scaled by a factor 1/T (T between 0 and 1) before applying the softmax. Low temperatures make the distribution more peaked, and high temperatures make it broader. A zero temperature corresponds to a deterministic choice (argmax), while the sampled output becomes more random as the temperature increases.

If you are not familiar with these concepts, this [Hugging Face article](https://huggingface.co/blog/how-to-generate) can help you visualise them.
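
The heuristics above can be combined in a few lines of plain Python. This is a minimal sketch, not the code used by `run_inference`: the function name `sample_next_token`, the convention that `top_k=0` and `top_p=1.0` disable the respective filter, and the order in which the filters are applied are all assumptions.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, rng, temperature=1.0, top_k=0, top_p=1.0):
    """Pick the next token id from the last-position logits."""
    if temperature == 0:
        # Zero temperature: deterministic argmax.
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax([x / temperature for x in logits])
    # Token ids sorted from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    if top_p < 1.0:
        # Keep the shortest prefix whose cumulative probability reaches top_p.
        kept, cumulative = [], 0.0
        for i in order:
            kept.append(i)
            cumulative += probs[i]
            if cumulative >= top_p:
                break
        order = kept
    # random.choices renormalises the weights of the surviving tokens.
    return rng.choices(order, weights=[probs[i] for i in order], k=1)[0]


rng = random.Random(0)
logits = [2.0, 1.0, 0.1, -1.0]
print(sample_next_token(logits, rng, temperature=0))  # 0
```

Passing a seeded `random.Random` instance keeps the sampling reproducible across runs.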

The new token is added to the input sequence, and the process goes on until `max_out_tokens` are generated or the model outputs the `EOS` token.
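
The stopping rule can be sketched as a short loop. The callable `next_token_fn` is a hypothetical stand-in for one model call plus the sampling step; the toy function below just returns the last token plus one.

```python
def generate(next_token_fn, prompt_ids, max_out_tokens, eos_id):
    """Append tokens one at a time until EOS or max_out_tokens is reached."""
    sequence = list(prompt_ids)
    output = []
    for _ in range(max_out_tokens):
        token = next_token_fn(sequence)
        if token == eos_id:
            break
        sequence.append(token)
        output.append(token)
    return output


def toy_next_token(sequence):
    """Toy stand-in for the model: emits the last token id plus one."""
    return sequence[-1] + 1


print(generate(toy_next_token, [1, 2], max_out_tokens=6, eos_id=5))  # [3, 4]
```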

### Execution scheme
The model is run using phased execution.
This means that the model is partitioned into a set of smaller graphs that are executed in series on the IPU, using remote memory to store variables and the input/output tensors (activations) passed between calls.
We recommend going through the tutorial [Phased Execution in MNIST example](https://github.com/graphcore/examples/tree/master/tutorials/tutorials/popxl/6_phased_execution) to better understand this execution scheme.
- The first phase corresponds to the image prefix. It produces the image embeddings.
- The following phase is the GPT-J embedding phase, which produces text embeddings.
- Image embeddings and text embeddings are concatenated and fed to GPT-J blocks. Each block constitutes a different phase.
- The final phase is the LM head, producing logits to be used for the next token prediction.
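
As a purely conceptual illustration, the phase order can be mimicked in plain Python with a dictionary standing in for remote memory. This does not use the PopXL API: all function names are hypothetical, strings stand in for tensors, and the concatenation is shown as a separate step only for readability.

```python
def run_phases(phases, inputs):
    """Run phases one after another, passing activations through a shared store."""
    remote_memory = dict(inputs)  # stand-in for off-chip storage
    for phase in phases:
        # Each phase reads what it needs, computes, and stores its outputs.
        remote_memory.update(phase(remote_memory))
    return remote_memory


def image_prefix(mem):
    return {"image_emb": [f"img({p})" for p in mem["image"]]}


def text_embedding(mem):
    return {"text_emb": [f"emb({t})" for t in mem["tokens"]]}


def concatenate(mem):
    return {"hidden": mem["image_emb"] + mem["text_emb"]}


def make_block(index):
    def block(mem):
        return {"hidden": [f"b{index}({h})" for h in mem["hidden"]]}
    return block


def lm_head(mem):
    return {"logits": [f"logit({h})" for h in mem["hidden"]]}


phases = [image_prefix, text_embedding, concatenate, make_block(0), make_block(1), lm_head]
result = run_phases(phases, {"image": ["p0"], "tokens": ["t0", "t1"]})
print(result["logits"][-1])  # logit(b1(b0(emb(t1))))
```

Only one phase is "resident" at a time in this picture, which is what allows a model larger than on-chip memory to run.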

The GPT-J model makes use of tensor parallelism, spanning across 4 IPUs. More details about the implementation are available in the [GPT-J Notebook]().
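
The idea behind tensor parallelism can be illustrated with a matrix-vector product whose output columns are split across 4 shards. This is a generic sketch under assumed toy sizes; the helper names are hypothetical and the sharding scheme of the actual GPT-J implementation is not described here.

```python
def matvec(x, weight_columns):
    """Compute x @ W, where W is given as a list of columns."""
    return [sum(a * b for a, b in zip(x, col)) for col in weight_columns]


def shard_columns(weight_columns, n_shards):
    """Split the output columns of W evenly across n_shards devices."""
    per_shard = len(weight_columns) // n_shards
    return [weight_columns[i * per_shard:(i + 1) * per_shard] for i in range(n_shards)]


def tensor_parallel_matvec(x, weight_columns, n_shards=4):
    """Each shard computes its slice of the output; the slices are then gathered."""
    partial = [matvec(x, shard) for shard in shard_columns(weight_columns, n_shards)]
    return [value for part in partial for value in part]


x = [1.0, 2.0]
# 8 output columns, each of length 2 (the input dimension); 8 is divisible by 4.
weight_columns = [[float(i), float(i + 1)] for i in range(8)]
sharded = tensor_parallel_matvec(x, weight_columns)
print(sharded == matvec(x, weight_columns))
```

Each shard holds only a quarter of the weight matrix, at the cost of a gather step to reassemble the full output.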


## Environment setup

The best way to run this demo is on Paperspace Gradient’s cloud IPUs because everything is already set up for you. To improve your experience, we preload datasets and pre-install packages. This can take a few minutes. If you experience errors immediately after starting a session, please try restarting the kernel before contacting support. If a problem persists or you want to give us feedback on the content of this notebook, please reach out through our community of developers using our [Slack channel](https://www.graphcore.ai/join-community).

[![Run on Gradient](../../gradient-badge.svg)](https://console.paperspace.com/github/<runtime-repo>?machine=Free-IPU-POD4&container=<dockerhub-image>&file=<path-to-file-in-repo>)  

To run the demo using other IPU hardware, you need to have the Poplar SDK enabled. Refer to the [Getting Started guide](https://docs.graphcore.ai/en/latest/getting-started.html#getting-started) for your system for details on how to do this. Also refer to the [Jupyter Quick Start guide](https://docs.graphcore.ai/projects/jupyter-notebook-quick-start/en/latest/index.html) for how to set up Jupyter to be able to run this notebook on a remote IPU machine.


In order to improve usability and support for future users, Graphcore would like to collect information about the applications and code being run in this notebook. The following information will be anonymised before being sent to Graphcore:

- User progression through the notebook
- Notebook details: number of cells, code being run and the output of the cells
- Environment details

You can disable logging at any time by running `%unload_ext graphcore_cloud_tools.notebook_logging.gc_logger` from any cell.

Uncomment the cell below to install `libopenmpi-dev` if it is not available in your Docker container.

In [None]:
# %env sets the variable for the `!` shell commands that follow
#%env DEBIAN_FRONTEND=noninteractive
#!apt-get update
#!apt-get install -y libopenmpi-dev

In [None]:
%pip install -r requirements.txt --ignore-requires-python
%load_ext graphcore_cloud_tools.notebook_logging.gc_logger

In [None]:
from run_inference import run_inference, init_inference_session
import sys, os, os.path

executable_cache_dir = os.path.join(
    os.getenv("POPLAR_EXECUTABLE_CACHE_DIR", "./exe_cache"), "magma"
)
os.environ["POPXL_CACHE_DIR"] = executable_cache_dir

## Compile and load the model
Compilation takes around 3 minutes.

In [None]:
from magma.magma import Magma
from configs import CONFIG_DIR, MagmaConfig

The upstream `Magma.from_checkpoint`, which is called in `init_inference_session()`, uses `gdown`, which appears to be broken at the moment. Either download the checkpoint manually, or allow `wget` to download it.

In [None]:
checkpoint_path = os.path.join(
    os.getenv("PUBLIC_DATASETS_DIR", "."), "magma", "mp_rank_00_model_states.pt"
)

checkpoint_path
#if not os.path.exists(checkpoint_path):
#    !wget 'https://bit.ly/aleph-alpha-magma-download' -O {checkpoint_path}

Note that a sequence length 1024 variant is also available: `magma_v1_1024`.

In [None]:
session, config, tokenizer = init_inference_session(
    "magma_v1_500", checkpoint_path=checkpoint_path
)

## Run demo

In [None]:
from PIL import Image
import requests
from io import BytesIO
import ipywidgets as ipw


def answer_int(image_url, text, seed, top_p, top_k, temperature, max_out_tokens):
    """Run MAGMA on an image and a prompt, and return a widget showing the result."""
    # Read the raw image bytes for display, from a URL or from a local file.
    if image_url.startswith("http"):
        response = requests.get(image_url)
        response.raise_for_status()
        image_bytes = response.content
    else:
        with open(image_url, "rb") as f:
            image_bytes = f.read()
    img = ipw.Image(value=image_bytes, width=384, height=384)
    prompt = ipw.Label(value=f"Prompt: {text}", style={"font_size": "16px"})
    # run_inference is given image_url directly, not the bytes read above.
    output = run_inference(
        session, config, tokenizer, image_url, text, seed, top_p, top_k, temperature, max_out_tokens
    )
    answer = ipw.Label(value=f"Answer: `{output}`", style={"font_size": "16px"})

    return ipw.VBox(
        [img, prompt, answer], layout=ipw.Layout(display="flex", align_items="center")
    )

Since the next token selection is done on the CPU, you can change the parameters used for generation (`top_p`, `top_k`, `max_out_tokens`, `temperature`, explained below) without triggering recompilation.
Specifying a `seed` makes the sampled results reproducible.
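
The role of the seed can be illustrated with Python's standard random generator. How `run_inference` uses the seed internally is not shown here, so this is only an analogy; `sample_tokens` and the weights are hypothetical.

```python
import random


def sample_tokens(seed, weights, n):
    """Draw n token ids from a categorical distribution with a fixed seed."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=n)


weights = [0.5, 0.3, 0.2]
first_run = sample_tokens(0, weights, 10)
second_run = sample_tokens(0, weights, 10)
print(first_run == second_run)  # True
```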

Given an image and a text prompt, the model outputs logits for the sequence.
Logits corresponding to the last token in the sequence are used to predict the next token.
Several heuristics are supported:
- top-p: keep the smallest set of most likely tokens whose cumulative probability exceeds a threshold p, and redistribute the probability among them. The next token is then sampled from this distribution (categorical sampling, non-deterministic).
- top-k: keep the K most likely tokens and redistribute the probability among them. The next token is then sampled from this distribution (categorical sampling, non-deterministic).
- temperature: logits are scaled by a factor 1/T (T between 0 and 1) before applying the softmax. Low temperatures make the distribution more peaked, and high temperatures make it broader. A zero temperature corresponds to a deterministic choice (argmax), while the sampled output becomes more random as the temperature increases.
- max_out_tokens: maximum number of tokens in the answer.

This [Hugging Face blog post](https://huggingface.co/blog/how-to-generate) provides more information.


Set the generation parameters and the input image manually:

In [None]:
# image_url = "demo_example_images/cantaloupe_popsicle.jpg"
# image_url="demo_example_images/circles.jpg"
# image_url="demo_example_images/circles_square.jpg"
# image_url="demo_example_images/korea.jpg"
# image_url="demo_example_images/matterhorn.jpg"
# image_url="demo_example_images/mushroom.jpg"
# image_url="demo_example_images/people.jpg"
# image_url="demo_example_images/playarea.jpg"
image_url = "demo_example_images/popsicle.png"
# image_url="demo_example_images/rainbow_popsicle.jpeg"
# image_url="demo_example_images/table_tennis.jpg"

text = "A picture of "
seed = 0  # 0 to 300
top_p = 0.9  # 0 to 1.0
top_k = 0  # 0 to 10
temperature = 0.7  # 0 to 1.0
max_out_tokens = 6  # 1 to 356

answer_int(image_url, text, seed, top_p, top_k, temperature, max_out_tokens)

## Next steps

Areas for further exploration:
* Experiment with the sequence length 1024 checkpoint
* Investigate the [GPT-J notebook]() to better understand tensor parallelism on the IPU