In [None]:
import csv
import tempfile
from functools import partial
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
import jax
import jax.numpy as jnp
from flax.training.common_utils import shard, shard_prng_key
from flax.jax_utils import replicate
import wandb
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
from dalle_mini.text import TextNormalizer

In [None]:
# training runs (wandb run ids) whose model artifacts should be evaluated
wandb_runs = ['rjf3rycy']

# VQGAN checkpoint used to decode generated image tokens (None = default revision)
VQGAN_REPO = 'dalle-mini/vqgan_imagenet_f16_16384'
VQGAN_COMMIT_ID = None

# whether prompts go through TextNormalizer before tokenization
normalize_text = True

In [None]:
# generation / ranking settings
batch_size = 8          # prompts per batch; the sample list is padded to a multiple of this
num_images = 128        # images generated per prompt
top_k = 8               # best-scoring images kept per prompt
padding_item = 'NONE'   # placeholder caption/theme used for padding rows

# prompt normalizer is only instantiated when normalization is enabled
if normalize_text:
    text_normalizer = TextNormalizer()
else:
    text_normalizer = None

# a fresh random 32-bit seed is drawn on every execution of this cell
seed = random.randrange(2**32)
key = jax.random.PRNGKey(seed)

# client for the wandb public API (used to query runs and artifacts)
api = wandb.Api()

In [None]:
# VQGAN: decodes generated token sequences into images; params replicated for pmap
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
vqgan_params = replicate(vqgan.params)

# CLIP: scores generated images against their prompt; params replicated for pmap
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_params = replicate(clip.params)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode VQGAN code indices into images, one shard per device.

    Both arguments carry a leading device axis (pmap); `params` are the
    replicated VQGAN parameters.
    """
    decoded = vqgan.decode_code(indices, params=params)
    return decoded

@partial(jax.pmap, axis_name="batch")
def p_clip(inputs):
    """Run CLIP on one shard of inputs per device and return `logits_per_image`.

    NOTE(review): the model is called without explicit params, so it uses the
    parameters stored on `clip` rather than the replicated `clip_params`.
    """
    return clip(**inputs).logits_per_image

In [None]:
# read all prompt rows; each row is a dict keyed by the CSV header
with open('samples.csv', newline='', encoding='utf8') as f:
    reader = csv.DictReader(f)
    samples = list(reader)

# pad with placeholder rows so the number of samples is a multiple of batch_size
samples_to_add = [{'Caption': padding_item, 'Theme': padding_item}] * (-len(samples) % batch_size)
samples += samples_to_add

# split into consecutive batches of batch_size rows
samples = [samples[start:start + batch_size] for start in range(0, len(samples), batch_size)]

In [None]:
# TODO: iterate on runs
# for now only the first configured training run is processed
wandb_run = wandb_runs[0]
# NOTE(review): this flag is not read by any cell visible in this notebook
model_pmapped = False

In [None]:
def get_artifact_versions(run_id):
    """Return the `bart_model` artifact versions logged by training run `run_id`.

    Best effort: if the wandb query raises, the error is reported and an empty
    list is returned, so callers simply have nothing to process.
    """
    try:
        versions = api.artifact_versions(type_name='bart_model', name=f'dalle-mini/dalle-mini/model-{run_id}', per_page=10000)
    except Exception as e:
        # `except Exception` instead of a bare except so that KeyboardInterrupt
        # and SystemExit are not swallowed; the error is surfaced, not hidden
        print(f'Could not retrieve model artifacts for run {run_id}: {e}')
        versions = []
    return versions

In [None]:
def get_training_config(run_id):
    """Return the config of training run `run_id` in the dalle-mini/dalle-mini project."""
    return api.run(f'dalle-mini/dalle-mini/{run_id}').config

In [None]:
# retrieve inference run details
def get_last_inference_version(run_id):
    """Return the last step logged by the inference run paired with `run_id`.

    The inference run is looked up as `inference-{run_id}` and its summary
    `_step` is returned. Returns None when that value is absent or when the
    run cannot be retrieved (e.g. no inference has been logged yet).
    """
    try:
        inference_run = api.run(f'dalle-mini/dalle-mini/inference-{run_id}')
        return inference_run.summary.get('_step', None)
    except Exception as e:
        # `except Exception` instead of a bare except so that KeyboardInterrupt
        # and SystemExit still propagate; report why no version was found
        print(f'No previous inference run found for {run_id}: {e}')
        return None

In [None]:
# compile functions - needed only once per run
def pmap_model_function(model):
    """Build a pmapped sampling function bound to `model`.

    The returned callable takes `(tokenized_prompt, key, params)`, each with a
    leading device axis, and calls `model.generate` with sampling enabled and
    a single beam.
    """

    def _generate(tokenized_prompt, key, params):
        sampling_options = dict(do_sample=True, num_beams=1, prng_key=key, params=params)
        return model.generate(**tokenized_prompt, **sampling_options)

    return jax.pmap(_generate, axis_name="batch")

In [None]:
def log_run(run_id):
    """Process the model artifacts of training run `run_id` that were not logged yet.

    Artifacts are expected in version order: v0 when no inference has been
    logged so far, otherwise exactly the version following the last one
    logged. Versions that were already logged are skipped. The wandb inference
    run (`inference-{run_id}`) is started or resumed lazily, only once an
    artifact actually needs processing.

    Args:
        run_id: id of the training run in the `dalle-mini/dalle-mini` project.

    Raises:
        AssertionError: if an artifact is not the next expected version.
    """
    artifact_versions = get_artifact_versions(run_id)
    last_inference_version = get_last_inference_version(run_id)
    training_config = get_training_config(run_id)
    run = None
    p_generate = None
    model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']
    for artifact in artifact_versions:
        print(f'Processing artifact: {artifact.name}')
        # artifact versions look like "v12"; strip the leading "v"
        version = int(artifact.version[1:])
        if last_inference_version is None:
            # we should start from v0
            assert version == 0
        elif version <= last_inference_version:
            print(f'v{version} has already been logged (versions logged up to v{last_inference_version})')
            # nothing to do for this artifact: do not download or re-process it
            continue
        else:
            # check we are logging the correct version
            assert version == last_inference_version + 1

        # start/resume corresponding run
        if run is None:
            run = wandb.init(job_type='inference', config=training_config, id=f'inference-{run_id}', resume='allow')

        # work in temporary directory
        with tempfile.TemporaryDirectory() as tmp:

            # download model files
            artifact = run.use_artifact(artifact)
            for f in model_files:
                artifact.get_path(f).download(tmp)

            # load tokenizer and model
            tokenizer = BartTokenizer.from_pretrained(tmp)
            model = CustomFlaxBartForConditionalGeneration.from_pretrained(tmp)
            model_params = replicate(model.params)

            # pmap model function needs to happen only once per model config
            if p_generate is None:
                p_generate = pmap_model_function(model)

            # NOTE: only prompt preparation is implemented so far; generation,
            # scoring and logging of each batch still have to be added here
            for batch in tqdm(samples):
                prompts = [x['Caption'] for x in batch]
                processed_prompts = [text_normalizer(x) for x in prompts] if normalize_text else prompts

        # this version is done: the next artifact must be the following version
        last_inference_version = version
            

            
        
        

In [None]:
# process the artifacts of the currently selected training run
log_run(wandb_run)

In [None]:
def log_runs(runs):
    """Call `log_run` on every run id in `runs`, with a tqdm progress bar."""
    for run_id in tqdm(runs):
        log_run(run_id)

In [None]:
# TODO: loop over samples
# scratch: work on the first batch only
batch = samples[0]
prompts = [sample['Caption'] for sample in batch]
if normalize_text:
    processed_prompts = [text_normalizer(caption) for caption in prompts]
else:
    processed_prompts = prompts

In [None]:
# display the prompts as they will be fed to the tokenizer
processed_prompts

In [None]:
# repeat the whole prompt list once per device, so that after sharding each
# device presumably receives every prompt of the batch
repeated_prompts = processed_prompts * jax.device_count()

In [None]:
# NOTE(review): `tokenizer` is only assigned inside log_run in the visible
# cells, so on a fresh kernel this cell depends on hidden state
# tokenize to a fixed length of 128 and split the result across devices
tokenized_prompt = shard(
    tokenizer(
        repeated_prompts,
        return_tensors='jax',
        padding='max_length',
        truncation=True,
        max_length=128,
    ).data
)

In [None]:
# inspect the shape of the sharded token ids
tokenized_prompt['input_ids'].shape

In [None]:
# NOTE(review): `p_generate` and `model_params` are only assigned inside
# log_run in the visible cells, so this cell depends on hidden state
images = []
for _ in range(num_images // jax.device_count()):
    # draw a fresh sampling key for this round
    key, subkey = jax.random.split(key)

    # sample image token sequences and drop the first token of each sequence
    generated = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)
    encoded_images = generated.sequences[..., 1:]

    # decode to pixels, clamp to [0, 1] and flatten the device axis
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))

    # convert every decoded array to an 8-bit PIL image
    images.extend(
        Image.fromarray(np.asarray(img * 255, dtype=np.uint8)) for img in decoded_images
    )
    

In [None]:
# total number of generated images (all prompts of the batch)
len(images)

In [None]:
# preview the first generated image
images[0]

In [None]:
# preview the second generated image
images[1]

In [None]:
# build CLIP inputs from the original (non-normalized) captions and all generated
# images; text is padded/truncated to 77 tokens
clip_inputs = processor(text=prompts, images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data

In [None]:
# each shard will have one prompt
# (this only holds when batch_size equals the number of devices — TODO confirm)
clip_inputs['input_ids'].shape

In [None]:
# each shard needs to have the images corresponding to a specific prompt
# (at this point images are still in generation order, not grouped per prompt)
clip_inputs['pixel_values'].shape

In [None]:
# positions of the images generated for the first prompt of the batch: images are
# interleaved by prompt, so every batch_size-th image belongs to the same prompt
# (adding i to these indices selects the images of prompt i)
images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))
images_per_prompt_indices

In [None]:
# reorder so each shard will have correct images
# (group images by prompt: all images of prompt 0 first, then prompt 1, ...)
# NOTE(review): this overwrites pixel_values in place, so re-running the cell
# permutes already reordered data
clip_inputs['pixel_values'] = jnp.concatenate(
    [
        clip_inputs['pixel_values'][images_per_prompt_indices + offset]
        for offset in range(batch_size)
    ]
)

In [None]:
# split the CLIP inputs across devices
# NOTE(review): rebinding the same name makes this cell non-idempotent;
# re-running it would shard already sharded inputs
clip_inputs = shard(clip_inputs)

In [None]:
# score every image against the prompt held by the same device
logits = p_clip(clip_inputs)

In [None]:
# inspect the per-device logits shape returned by p_clip
logits.shape

In [None]:
# flatten to rows of num_images scores; presumably one row per prompt, given
# that each shard holds a single prompt — TODO confirm
logits = logits.reshape(-1, num_images)

In [None]:
# inspect the shape after the reshape
logits.shape

In [None]:
# display the score matrix
logits

In [None]:
# per row, keep the indices of the top_k highest scores ...
top_idx = logits.argsort()[:, -top_k:]
# ... and flip them so the best-scoring image comes first
top_idx = top_idx[..., ::-1]

In [None]:
# sanity check: number of generated images
len(images)

In [None]:
# table layout: caption, theme, then the top_k images followed by their scores
columns = ['Caption', 'Theme']
columns += [f'Image {rank}' for rank in range(1, top_k + 1)]
columns += [f'Score {rank}' for rank in range(1, top_k + 1)]

# bring the scores back to host memory before indexing them row by row
logits = jax.device_get(logits)

# rows of the results table, filled by the next cell
results = []

In [None]:
# build one table row per real prompt (padding rows are ignored)
# NOTE(review): rows are appended to `results`, so re-running this cell
# without re-running the previous one duplicates them
for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):
    if sample['Caption'] != padding_item:
        # images generated for prompt i, in the same order as its scores
        cur_images = [images[x] for x in images_per_prompt_indices + i]
        row = [sample['Caption'], sample['Theme']]
        row.extend(wandb.Image(cur_images[x]) for x in idx)
        row.extend(scores[x] for x in idx)
        results.append(row)

In [None]:
# assemble the results as a wandb table
# NOTE(review): no visible cell logs this table to the run — confirm whether a
# logging call is still missing
table = wandb.Table(columns=columns, data=results)

In [None]:
# close the active wandb run
wandb.finish()