## 0. Chargement des données d'entraînement

Les images sont issues de COCO, et ont été filtrées pour (tenter de) ne retenir que des images représentant des paysages.

In [None]:
# !wget -P ../image_data/ https://minio.lab.sspcloud.fr/cthiounn2/archive_val.zip
# !unzip -o ../image_data/archive_val.zip -d ../image_data/

In [None]:
# Verbosity level: the tokenisation cell prints the text tokens when this is >= 2.
verbose = 0
# Length (in tokens) passed to the tokenizer as `max_length`, with padding up to that length.
text_token_length = 255


## 1. Entraînement

L'entraînement consiste en:

1. Encoder le texte en tokens-text
2. Encoder les images en tokens-image
3. Modéliser l'ensemble de façon auto-régressive

### 1.1. Encodage du texte

Nous utilisons `BartTokenizer` pour l'encodage du texte, comme mini Dall-E. D'après l'article, le Dall-E original utilise un "_BPE-encoding_" (_byte-pair encoding_, c'est-à-dire, à strictement parler, un encodage par paires de caractères), ce qui peut s'interpréter comme l'utilisation du modèle GPT-3, qui repose lui aussi sur un encodage proche d'un _BPE encoding_. Malheureusement, GPT-3 n'est pas disponible au grand public.

In [None]:
# TO DO:
# - put all the parameters at the beginning of the notebook
# - save the weights locally to avoid downloading them each time
# - see if it is possible to use GPT-3 with a free OpenAI account
# - enable usage of the TPUs offered by Google

from transformers import BartTokenizer, BartForConditionalGeneration
import torch

# https://huggingface.co/transformers/v2.11.0/model_doc/bart.html

caption = "A Emperor penguin standing on the ice"

# TODO: try a smaller checkpoint
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')

# Encode the caption as a fixed-length sequence of token ids.
# Short captions are padded up to `text_token_length`; `truncation=True` makes
# `max_length` also cut longer captions, which would otherwise keep their full
# length and break the fixed-length assumption.
text_tokens = tokenizer(
    caption,
    max_length=text_token_length,
    padding='max_length',
    truncation=True,
)['input_ids']
text_tokens = torch.as_tensor(text_tokens)  # BPE-encoding
if verbose >= 2:
    print(text_tokens)


Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

### 1.2 Encodage des images

Nous utilisons l'encodeur pré-entraîné publié par OpenAI (paquet `dall_e`) pour convertir les images en tokens-image.

In [None]:
# `%pip` installs into the environment of the running kernel, which `!pip` does not guarantee.
%pip install dall_e

You should consider upgrading via the '/root/venv/bin/python -m pip install --upgrade pip' command.[0m


In [None]:
## TRAINING 
import io
import os, sys
import requests
import PIL

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from dall_e          import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

# Device handed to load_model below (CPU).
dev = torch.device('cpu')


# get token ids for images (= encode)

# NOTE(review): load_model is given a remote URL to a .pkl file, so the weights are
# presumably fetched again on every run and un-pickled — consider caching them
# locally, and only load pickles from a trusted source.
encoder = load_model("https://cdn.openai.com/dall-e/encoder.pkl", dev)


# Side length, in pixels, of the square crop produced by preprocess().
target_image_size = 256

# scale images down to 256x256 (cropping the uneven dimension)
# we might get problems with some images from the COCO datasets
# ignore these images as a first approximation
# or reduce the image resolution ?

def preprocess(img):
    """Turn a PIL image into the tensor fed to the image encoder.

    The image is converted to RGB, resized so that its shortest side equals
    ``target_image_size``, center-cropped to a square of that size, converted
    to a tensor with a leading batch dimension of 1 and passed through
    ``map_pixels``.

    Args:
        img: PIL image whose shortest side is at least ``target_image_size``.

    Returns:
        The result of ``map_pixels`` applied to the batched image tensor.

    Raises:
        ValueError: if the shortest side is smaller than ``target_image_size``.
    """
    shortest_side = min(img.size)

    if shortest_side < target_image_size:
        raise ValueError(f'min dim for image {shortest_side} < {target_image_size}')

    # Grayscale, palette or RGBA files would otherwise produce a tensor whose
    # channel count is not 3 (the pretrained encoder is assumed to take RGB input).
    img = img.convert('RGB')

    scale = target_image_size / shortest_side
    # PIL reports (width, height) whereas TF.resize expects (height, width).
    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    # Use the InterpolationMode enum: passing the integer PIL constant triggers
    # torchvision's "should be of type InterpolationMode instead of int" warning.
    img = TF.resize(img, new_size, interpolation=T.InterpolationMode.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)

# replace by direct reading from disk
# persist images to disk in the first place
def download_image(url, timeout=30):
    """Download an image over HTTP and open it with PIL.

    Args:
        url: address of the image file.
        timeout: number of seconds to wait for the server before giving up.
            Without a timeout ``requests.get`` may block indefinitely.

    Returns:
        The PIL image opened from the response body.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))

# The encoder is only used here to tokenise the image, so there is no need to
# record an autograd graph for its forward pass.
with torch.no_grad():
    z_logits = encoder(x)
# For every position keep the index of the highest logit along dim 1.
z = torch.argmax(z_logits, dim=1)
#z = F.one_hot(z, num_classes=encoder.vocab_size).permute(0, 3, 1, 2).float()

# pad text to fixed length with an additional id and bind text and image tokens together

# model the sequence with a transformer model



  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:

# Collapse the encoder's argmax indices `z` into a single 1-D sequence of image-token ids.
image_tokens = z.flatten()

# def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
#     """
#     Shift input ids one token to the right.
#     """
#     shifted_input_ids = np.zeros(input_ids.shape)
#     shifted_input_ids[:, 1:] = input_ids[:, :-1]
#     shifted_input_ids[:, 0] = decoder_start_token_id
#     return shifted_input_ids


    # dataset.preprocess(
    #     tokenizer=tokenizer,
    #     decoder_start_token_id=model.config.decoder_start_token_id,
    #     normalize_text=model.config.normalize_text,
    #     max_length=model.config.max_text_length,
    # )


# all_tokens =  torch.cat( (text_tokens,image_tokens) )

# if verbose > 2:
#     print(all_tokens.shape)

In [None]:
# `!cd` runs in its own sub-shell and does not affect the next command, so the
# destination is passed to `git clone` directly; the directory test makes the
# cell safe to re-run once the repository has been cloned.
![ -d model/bart-large-cnn ] || git clone https://huggingface.co/facebook/bart-large-cnn model/bart-large-cnn

/bin/bash: line 0: cd: model: No such file or directory
fatal: destination path 'bart-large-cnn' already exists and is not an empty directory.


In [None]:
# QUESTIONS:
# - Is it a problem that there is a potential overlap between the indices of image and text tokens?

import torch
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')
# bart-large is limited to sequences of 1024 tokens; check whether this also holds for bart-large-cnn
# TODO: try a model smaller than bart-large-cnn

# The model works on batches: give both sequences a leading batch dimension of 1.
encoder_input_ids = text_tokens.reshape(1, -1)
target_ids = image_tokens.reshape(1, -1)

# Shift the image tokens one step to the right and put the decoder start token
# first. A new tensor is built instead of assigning `image_tokens[1:] = image_tokens[:-1]`:
# that assignment reads and writes overlapping memory, shifts the tokens once
# more every time the cell is re-run, and overwrites the tokens that are meant
# to serve as prediction targets.
start_ids = target_ids.new_full((1, 1), model.config.decoder_start_token_id)
decoder_input_ids = torch.cat([start_ids, target_ids[:, :-1]], dim=1)

# Captions are padded to a fixed length; mask the padding so the encoder ignores it.
attention_mask = encoder_input_ids.ne(model.config.pad_token_id).long()

predict = model(
    input_ids = encoder_input_ids,
    attention_mask = attention_mask,
    decoder_input_ids = decoder_input_ids,
) # returns an object of the same size as decoder_input_ids

# loss = cross-entropy
# torch.nn.crossEntropy()
# predict vs. image_tokens

# alternativement on peut changer la dernière couche de Bart
# nn.Linear(size_embedding, num_vocab_img)

# import torch
# import torch.nn as nn
# class RNN(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size):
#         super(RNN, self).__init__()
#         self.hidden_size = hidden_size
#         self.i2h = nn.Linear(n_categories + input_size + hidden_size, hidden_size)
#         self.i2o = nn.Linear(n_categories + input_size + hidden_size, output_size)
#         self.o2o = nn.Linear(hidden_size + output_size, output_size)
#         self.dropout = nn.Dropout(0.1)
#         self.softmax = nn.LogSoftmax(dim=1)
#     def forward(self, category, input, hidden):
#         input_combined = torch.cat((category, input, hidden), 1)
#         hidden = self.i2h(input_combined)
#         output = self.i2o(input_combined)
#         output_combined = torch.cat((hidden, output), 1)
#         output = self.o2o(output_combined)
#         output = self.dropout(output)
#         output = self.softmax(output)
#         return output, hidden

KernelInterrupted: Execution interrupted by the Jupyter kernel.

In [None]:
# Show every image token except the last one.
image_tokens[:-1]

tensor([7522,  741, 5973,  ..., 6231, 5016, 1144])

In [None]:
## INFERENCE

# get token ids for texts (= encode)

# generate the next terms in the sequence with a random seed

# get image from token ids (= decode)

In [None]:
# https://colab.research.google.com/drive/14oChMr8KZVS7DzcbsuJix0JQKUTGO64j

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=4f3692ed-5f27-49a4-899a-82a03e72232c' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>