# Questions to answer

    Maybe attention is the problem!!

In [1]:
from transformers import AutoModelWithLMHead, AutoTokenizer,VisionEncoderDecoderModel, ViTFeatureExtractor,ViTImageProcessor
import torch
from PIL import Image
import os
import json
import matplotlib.pyplot as plt
import random

In [2]:
from torch import nn

In [3]:
import cv2
import numpy as np
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader

## 0. Functions

In [4]:
max_length = 100
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step(image_paths):
    images = []
    for image_path in image_paths:

        try:
            i_image = Image.open(image_path)
        except:
            return None
            
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
    
        images.append(i_image)
    
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    print(pixel_values.shape)
    pixel_values = pixel_values.to(device)
    
    output_ids = model.generate(pixel_values, **gen_kwargs)

    print(output_ids)
    
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    
    return preds

In [78]:
def predict_step2(image_paths, encoder, decoder):
    
    preds_all = []
    prev_pix = None
    for idx, image_path in enumerate(image_paths):

        try:
            i_image = Image.open(image_path)
        except:
            return None
            
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")

        org_pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values.to(device)

        curr_pix = encoder(org_pixel_values).last_hidden_state
        
        if prev_pix != None:
            pixel_values = torch.concat((prev_pix, curr_pix), 1)
        else:
            pixel_values = torch.concat((curr_pix, curr_pix), 1)
        prev_pix = curr_pix
            
        pixel_values = pixel_values.to(device)
        print(pixel_values.shape)
        
        output_ids = decoder.generate(pixel_values = pixel_values, **gen_kwargs)
        print(output_ids)
    
        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        preds = [pred.strip() for pred in preds]

        print(preds)
        preds_all.append(preds)
            
    return preds_all

## 1. Load baseline model + basic performance check

In [5]:
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

In [6]:
for name, child in model.named_children():
        for x, y in child.named_children():
            print(name,x)

encoder embeddings
encoder encoder
encoder layernorm
encoder pooler
decoder transformer
decoder lm_head


In [7]:
model.encoder

ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation(

In [8]:
model.decoder

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (crossattention): GPT2Attention(
          (c_attn): Conv1D()
          (q_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_cross_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation(

In [9]:
device

device(type='cuda')

In [10]:
tokenizer

GPT2TokenizerFast(name_or_path='nlpconnect/vit-gpt2-image-captioning', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

## 2. Load VIST images

    Loading is slightly different from the previous notebook, as the entire set of images for one story together

## Train

In [11]:
ann = os.listdir("/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/annotations")
labels = open("/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/annotations/StoriesFin.json")
labels = json.load(labels)
labels["16"]

{'sent_ids': ['80', '81', '82', '83', '84'],
 'img_ids': ['181647714', '181626113', '181645575', '181635518', '181640606'],
 'album_id': '72157594187037689',
 'text': ['we took a nice hike into the forest today .',
  'we were lucky enough to see some wildlife , like this deer .',
  'this guy was friendly . he must hit up all the hikers for food .',
  "i 'm glad we spotted this snake before we got too close !",
  'the end of our hike rewarded us with an amazing view of the falls !']}

In [12]:
len(labels.keys())

5149

In [13]:
labels['16']

{'sent_ids': ['80', '81', '82', '83', '84'],
 'img_ids': ['181647714', '181626113', '181645575', '181635518', '181640606'],
 'album_id': '72157594187037689',
 'text': ['we took a nice hike into the forest today .',
  'we were lucky enough to see some wildlife , like this deer .',
  'this guy was friendly . he must hit up all the hikers for food .',
  "i 'm glad we spotted this snake before we got too close !",
  'the end of our hike rewarded us with an amazing view of the falls !']}

In [14]:
# stories = random.sample(list(labels.keys()),100) # choose 100 random stories

In [15]:
labels['16']

{'sent_ids': ['80', '81', '82', '83', '84'],
 'img_ids': ['181647714', '181626113', '181645575', '181635518', '181640606'],
 'album_id': '72157594187037689',
 'text': ['we took a nice hike into the forest today .',
  'we were lucky enough to see some wildlife , like this deer .',
  'this guy was friendly . he must hit up all the hikers for food .',
  "i 'm glad we spotted this snake before we got too close !",
  'the end of our hike rewarded us with an amazing view of the falls !']}

In [16]:
train = {}
for s in labels:
    for idx, im in enumerate(labels[s]['img_ids']):
        image_path = f"/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/{im}.jpeg"
        key = f"story_{s}"
        if key not in train:
            train[key] = {}
        train[key][idx] = {}
        train[key][idx]['image_path'] = image_path
        train[key][idx]['text'] = labels[s]['text'][idx]

    # train[s] = {}
    # train[s]["image_paths"] = []
    # train[s]["img_ids"] = labels[s]['img_ids']
    # train[s]["text"] = labels[s]['text']    
    # for im in train[s]["img_ids"]:
    #     train[s]["image_paths"].append(f"/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/{im}.jpeg")

In [17]:
train['story_16']

{0: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181647714.jpeg',
  'text': 'we took a nice hike into the forest today .'},
 1: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181626113.jpeg',
  'text': 'we were lucky enough to see some wildlife , like this deer .'},
 2: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181645575.jpeg',
  'text': 'this guy was friendly . he must hit up all the hikers for food .'},
 3: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181635518.jpeg',
  'text': "i 'm glad we spotted this snake before we got too close !"},
 4: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181640606.jpeg',
  'text': 'the end of our hike rewarded us with an amazing view of the falls !'}}

In [18]:
len(train.keys())

5149

In [19]:
print(list(train.keys())[2])
train[list(train.keys())[2]]

story_19


{0: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181647714.jpeg',
  'text': 'giant sequoia tree and red woods in the forest .'},
 1: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181626113.jpeg',
  'text': 'a young deer scampers about in the woods .'},
 2: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181645575.jpeg',
  'text': 'grey squirrel holding some food with his paws .'},
 3: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181635518.jpeg',
  'text': 'the snake slithers quietly through the underbrush .'},
 4: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181640606.jpeg',
  'text': 'beautiful picture taken of a river running through the valley .'}}

In [20]:
os.path.exists('/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/181640606.jpeg')

True

## Test

In [21]:
ann_test = os.listdir("/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/annotations")
labels_test = open("/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/annotations/TestStoriesFin.json")
labels_test = json.load(labels_test)

In [22]:
labels_test['45531']

{'sent_ids': ['227655', '227656', '227657', '227658', '227659'],
 'img_ids': ['1741625', '1741640', '1741639', '1741633', '1741630'],
 'album_id': '44277',
 'text': ['i was so excited to be heading to the crafts fair .',
  'when i arrived i saw a great booth with a variety of great crafts .',
  "i stopped at chatted at my friend [female] 's booth for a bit .",
  'there were even booths set up for all of the kids .',
  "i found some awesome crafts at the fair , i 'm really happy that i went ."]}

In [23]:
test = {}
for s in labels_test:
    for idx, im in enumerate(labels_test[s]['img_ids']):
        image_path = f"/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/{im}.jpeg"
        key = f"story_{s}"
        if key not in test:
            test[key] = {}
        test[key][idx] = {}
        test[key][idx]['image_path'] = image_path
        test[key][idx]['text'] = labels_test[s]['text'][idx]


In [24]:
len(test)

2269

In [25]:
print(list(test.keys())[0])
test[list(test.keys())[0]]

story_45531


{0: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741625.jpeg',
  'text': 'i was so excited to be heading to the crafts fair .'},
 1: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741640.jpeg',
  'text': 'when i arrived i saw a great booth with a variety of great crafts .'},
 2: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741639.jpeg',
  'text': "i stopped at chatted at my friend [female] 's booth for a bit ."},
 3: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741633.jpeg',
  'text': 'there were even booths set up for all of the kids .'},
 4: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741630.jpeg',
  'text': "i found some awesome crafts at the fair , i 'm really happy that i went ."}}

## Val

In [26]:
ann_val = os.listdir("/home/jay.je/IMspiredStoryTelling/datasets/VIST/val/annotations")
labels_val = open("/home/jay.je/IMspiredStoryTelling/datasets/VIST/val/annotations/ValStoriesFin.json")
labels_val = json.load(labels_val)

In [27]:
labels_val[list(labels_val.keys())[0]]

{'sent_ids': ['202350', '202351', '202352', '202353', '202354'],
 'img_ids': ['693397887', '695160730', '694227508', '693397865', '694227468'],
 'album_id': '72157600601428727',
 'text': ['my sister arrived early to help me with the family bar bq .',
  'every one else arrived soon after .',
  'dad manned the grill .',
  'there was so much food and it was all delicious .',
  'we ended the day shooting off some fireworks .']}

In [28]:
val = {}
for s in labels_val:
    for idx, im in enumerate(labels_val[s]['img_ids']):
        image_path = f"/home/jay.je/IMspiredStoryTelling/datasets/VIST/val/images/{im}.jpeg"
        key = f"story_{s}"
        if key not in val:
            val[key] = {}
        val[key][idx] = {}
        val[key][idx]['image_path'] = image_path
        val[key][idx]['text'] = labels_val[s]['text'][idx]


In [29]:
len(val)

2223

# 2. Fine-tune on the VIST dataset

    Actually, I think we do not need to fine-tune the feature extractor.
    Just fine-tune the decoder based on the features produced by the extractor <<


    SOMEHOW have the dataloader load 5 images (so it's 5 x 3 x H x W)

    

In [46]:
train['story_26929']

{0: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/258058601.jpeg',
  'text': 'the party has many pictures .'},
 1: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/258058806.jpeg',
  'text': 'the guys meet [male] .'},
 2: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/258059074.jpeg',
  'text': 'they have subway .'},
 3: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/258059285.jpeg',
  'text': 'they go to the museum .'},
 4: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/258059887.jpeg',
  'text': 'they are flying back home .'}}

In [47]:
class CustomImageDataset(Dataset):
    
    def __init__(self, stories_annotations, img_dir=None, transform=None, target_transform=None):
        # self.img_labels = annotations_file
        # self.img_dir = img_dir
        self.stories = stories_annotations
        self.stories_keys = list(stories_annotations.keys())
        self.transform = transform
        self.target_transform = target_transform
    
    def __len__(self):
        """Returns the number of samples in our dataset
        """
        return len(list(self.stories.keys()))

    def __getitem__(self, idx):

        if idx >= self.__len__():
            return
        story_num = self.stories_keys[idx]
        
        # get label
        # label = self.stories[story_num]['text']
        # label = tokenizer(label, padding="max_length").input_ids # omitted max_target_lengths
        # label = torch.tensor(label).squeeze()
        inputs = {}
        imgs = []
        labs = []
        atts = []

        try:
            prev_img = None
            
            for seq in self.stories[story_num]:
                img_source = self.stories[story_num][seq]['image_path']
                img = read_image(img_source)
                img = feature_extractor(images=img, return_tensors="pt").pixel_values.squeeze()
                imgs.append(img.clone().detach())
                
                lab_source = self.stories[story_num][seq]['text']
                tok_out = tokenizer(lab_source, padding="max_length")
                lab = tok_out.input_ids
                lab = torch.tensor(lab).squeeze()
                labs.append(lab.clone().detach())

                # att = tok_out.attention_mask
                # att = torch.tensor(att).squeeze()
                # atts.append(att.clone().detach())

                
            inputs['pixel_values'] = torch.stack(imgs)
            inputs['labels'] = torch.stack(labs)
            # inputs['attention_mask'] = torch.stack(atts)
            
        except:
            return self.__getitem__(idx+1)
        
        return inputs

In [48]:
print(len(train))
print(len(val))

5149
2223


In [49]:
train_idx = random.sample(list(train.keys()), 1000)
val_idx = random.sample(list(val.keys()), 200)

In [50]:
train_fin = {}
for k in train_idx:
    train_fin[k] = train[k]

val_fin = {}
for k in val_idx:
    val_fin[k] = val[k]

In [52]:
custom_train_data = CustomImageDataset(train_fin)
train_dataloader = DataLoader(custom_train_data, batch_size=1, shuffle=True) # , collate_fn=my_collate

In [53]:
custom_val_data = CustomImageDataset(val_fin)
val_dataloader = DataLoader(custom_val_data, batch_size=1, shuffle=True) # , collate_fn=my_collate

In [54]:
output = next(iter(train_dataloader))
print(output['pixel_values'].shape)
print(output['labels'].shape)

torch.Size([1, 5, 3, 224, 224])
torch.Size([1, 5, 1024])


In [56]:
output = next(iter(val_dataloader))
print(output['pixel_values'].shape)
print(output['labels'].shape)

torch.Size([1, 5, 3, 224, 224])
torch.Size([1, 5, 1024])


## Now try to train the model

In [57]:
from transformers import Trainer, Seq2SeqTrainer, Seq2SeqTrainingArguments
import evaluate

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [58]:
metric = evaluate.load("rouge")

In [59]:
ignore_pad_token_for_loss = True

In [60]:
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels as we can't decode them.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result

In [61]:
def train_one_epoch(epoch_index, tb_writer = None):
    """
    Idea 1: The next sample gets the previously generated model output in the model decoder part only (previous sentence)
    Idea 2: The next sample gets the previously generated feature extractor (previous image)
    """

    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(train_dataloader):
        # Every data instance is an input + label pair
        inputs = data

        # for each image, can iterate!
        # eg shape: (1, 5, 3, H, W), (1, 5, 2056)
        
        # Zero your gradients for every batch!
        last_pix = None
        for k in range(5):
            
            optimizer.zero_grad()
            
            pix = inputs['pixel_values'][0][k].unsqueeze(0).to(device)
            lab = inputs['labels'][0][k].unsqueeze(0).to(device)
            
            # Make predictions for this batch
            curr_pix = encoder(pix, return_dict=True).last_hidden_state
            if last_pix != None:
                pix = torch.concat((last_pix, curr_pix), 1)
            else:
                pix = torch.concat((curr_pix, curr_pix), 1)
            last_pix = curr_pix.detach()

            outputs = decoder(input_ids = lab, labels = lab, encoder_hidden_states = pix, output_hidden_states = True
                             )
            # Compute the loss and its gradients
            loss = outputs.loss
            # loss = loss_fn(outputs, inputs['labels'])
            loss.backward()
    
            # Adjust learning weights
            optimizer.step()

            # Gather data and report
            running_loss += loss.item()
        if i % 100 == 0:
            last_loss = running_loss / (100) # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_dataloader) + i + 1
            # tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.

    return last_loss

In [62]:
print(torch.__version__)
print(torch.version.cuda)

2.0.1+cu118
11.8


In [63]:
encoder = model.base_model.encoder
decoder = model.base_model.decoder
encoder.to(device)
decoder.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (crossattention): GPT2Attention(
          (c_attn): Conv1D()
          (q_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_cross_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation(

In [64]:
EPOCHS = 1
epoch_number = 0
best_vloss = 1_000_000

optimizer = torch.optim.SGD(decoder.parameters(), lr=0.001, momentum=0.9)

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    decoder.train()
    avg_loss = train_one_epoch(epoch_number) #writer)

    running_vloss = 0.0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    # print("Got to eval")
    decoder.eval()
    vrunning_vloss = 0
    with torch.no_grad():
        for i, vdata in enumerate(val_dataloader):
  
            vinputs = vdata
            last_pix = None
            
            for k in range(5):
                pix = vinputs['pixel_values'][0][k].unsqueeze(0).to(device)
                lab = vinputs['labels'][0][k].unsqueeze(0).to(device)
                
                # Make predictions for this batch
                curr_pix = encoder(pix, return_dict=True).last_hidden_state
                if last_pix != None:
                    pix = torch.concat((last_pix, curr_pix), 1)
                else:
                    pix = torch.concat((curr_pix, curr_pix), 1)
                last_pix = curr_pix.detach()
                
                voutputs = decoder(input_ids = lab, labels = lab, encoder_hidden_states = pix, output_hidden_states = True
                                  # attention_mask = att
                                 )

                vloss = voutputs.loss
                vrunning_vloss += vloss.item()
            
            if i % 100 == 0:
                last_loss = vrunning_vloss / (100) # loss per batch
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                vrunning_vloss = 0.
    
            avg_vloss = running_vloss / (i + 1)
            print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    print(f"At epoch {epoch} because avg loss was {avg_vloss} and best loss was {best_vloss}")
    if avg_vloss < best_vloss:
        # print(f"Saving...")
        # encoder.save_pretrained(f'./model_epoch{epoch}_encoder_VIST_5000/')
        # decoder.save_pretrained(f'./model_epoch{epoch}_decoder_2imgs_VIST_5000/')
        # best_vloss = avg_vloss
        

EPOCH 1:


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


  batch 1 loss: 0.038397381529211995
  batch 101 loss: 0.4802251105569303


Corrupt JPEG data: 6299 extraneous bytes before marker 0xd9
Corrupt JPEG data: 5957 extraneous bytes before marker 0xd9
Corrupt JPEG data: 7571 extraneous bytes before marker 0xd9


  batch 201 loss: 0.4178759583551437
  batch 301 loss: 0.39035042848438023
  batch 401 loss: 0.3832549136132002
  batch 501 loss: 0.3683503209985793
  batch 601 loss: 0.3890676255710423




  batch 701 loss: 0.35347841528244317
  batch 801 loss: 0.34859455096535386
  batch 901 loss: 0.3511508968565613




  batch 1 loss: 0.0023934620060026644
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LOSS train 0.3511508968565613 valid 0.0
LO

KeyboardInterrupt: 

In [92]:
sample = test['story_45531']

In [93]:
sample

{0: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741625.jpeg',
  'text': 'i was so excited to be heading to the crafts fair .'},
 1: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741640.jpeg',
  'text': 'when i arrived i saw a great booth with a variety of great crafts .'},
 2: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741639.jpeg',
  'text': "i stopped at chatted at my friend [female] 's booth for a bit ."},
 3: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741633.jpeg',
  'text': 'there were even booths set up for all of the kids .'},
 4: {'image_path': '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741630.jpeg',
  'text': "i found some awesome crafts at the fair , i 'm really happy that i went ."}}

In [94]:
imgs = []
for k in sample:
    imgs.append(sample[k]['image_path'])

In [95]:
imgs

['/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741625.jpeg',
 '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741640.jpeg',
 '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741639.jpeg',
 '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741633.jpeg',
 '/home/jay.je/IMspiredStoryTelling/datasets/VIST/test/images/1741630.jpeg']

In [88]:
model_test = VisionEncoderDecoderModel(encoder = encoder, decoder = decoder)

In [99]:
def get_model_generation(model, custom_loader):

    ls_out = []

    output = next(iter(custom_loader))

    prev_enc = None
    for i in range(5):
        pix = output['pixel_values'][0][i].unsqueeze(0).to(device)
        # att = output['attention_mask'][0][i].unsqueeze(0).to(device)

        print(f"{i}th image")
        print(pix.shape)
        # print(att.shape)

        enc_out = model.encoder(pix)
        
        if prev_enc != None:
            enc_out.last_hidden_state = torch.concat((prev_enc, enc_out.last_hidden_state),1)
        
        else:
            enc_out.last_hidden_state = torch.concat((enc_out.last_hidden_state, enc_out.last_hidden_state),1)

        print("at prev enc stage")
        prev_enc = enc_out.last_hidden_state

        print("outout state")
        outputids = model.generate(encoder_outputs=enc_out, **gen_kwargs)# attention_mask = att,

        print(tokenizer.batch_decode(outputids))
        ls_out.append(outputids) 
    
    return ls_out

In [100]:
get_model_generation(model_test, val_dataloader)

0th image
torch.Size([1, 3, 224, 224])
at prev enc stage
outout state
['<|endoftext|>.<|endoftext|>']
1th image
torch.Size([1, 3, 224, 224])
at prev enc stage
outout state
['<|endoftext|>.<|endoftext|>']
2th image
torch.Size([1, 3, 224, 224])
at prev enc stage
outout state
['<|endoftext|>.<|endoftext|>']
3th image
torch.Size([1, 3, 224, 224])
at prev enc stage
outout state
['<|endoftext|>.<|endoftext|>']
4th image
torch.Size([1, 3, 224, 224])
at prev enc stage
outout state
['<|endoftext|>.<|endoftext|>']


[tensor([[50256,   764, 50256]], device='cuda:0'),
 tensor([[50256,   764, 50256]], device='cuda:0'),
 tensor([[50256,   764, 50256]], device='cuda:0'),
 tensor([[50256,   764, 50256]], device='cuda:0'),
 tensor([[50256,   764, 50256]], device='cuda:0')]

## See model output

In [105]:
decoder.generate # decoder has generate method

<bound method GenerationMixin.generate of GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (crossattention): GPT2Attention(
          (c_attn): Conv1D()
          (q_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_cross_attn): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Co

In [107]:
sample = labels['20262']
sample['image_paths'] = []
for im in sample['img_ids']:
    sample['image_paths'].append(f'/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/{im}.jpeg')

In [118]:
def predict_step2(image_paths):
    
    preds_all = []
    prev_pix = None
    for idx, image_path in enumerate(image_paths):

        try:
            i_image = Image.open(image_path)
        except:
            return None
            
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")

        org_pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values.to(device)

        curr_pix = encoder(org_pixel_values).last_hidden_state
        
        if prev_pix != None:
            pixel_values = torch.concat((prev_pix, curr_pix), 1)
        else:
            pixel_values = torch.concat((curr_pix, curr_pix), 1)
        prev_pix = curr_pix
            
        pixel_values = pixel_values.to(device)
        print(pixel_values.shape)
        
        output_ids = decoder.generate(pixel_values = pixel_values, **gen_kwargs)
        print(output_ids)
    
        preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        preds = [pred.strip() for pred in preds]

        print(preds)
        preds_all.append(preds)
            
    return preds_all

In [119]:
predict_step2(sample['image_paths']) # train images

torch.Size([1, 394, 768])


ValueError: The following `model_kwargs` are not used by the model: ['pixel_values'] (note: typos in the generate arguments will also show up in this list)

In [85]:
labels_test['45531']

{'sent_ids': ['227655', '227656', '227657', '227658', '227659'],
 'img_ids': ['1741625', '1741640', '1741639', '1741633', '1741630'],
 'album_id': '44277',
 'text': ['i was so excited to be heading to the crafts fair .',
  'when i arrived i saw a great booth with a variety of great crafts .',
  "i stopped at chatted at my friend [female] 's booth for a bit .",
  'there were even booths set up for all of the kids .',
  "i found some awesome crafts at the fair , i 'm really happy that i went ."]}

In [87]:
sample

{'sent_ids': ['227655', '227656', '227657', '227658', '227659'],
 'img_ids': ['1741625', '1741640', '1741639', '1741633', '1741630'],
 'album_id': '44277',
 'text': ['i was so excited to be heading to the crafts fair .',
  'when i arrived i saw a great booth with a variety of great crafts .',
  "i stopped at chatted at my friend [female] 's booth for a bit .",
  'there were even booths set up for all of the kids .',
  "i found some awesome crafts at the fair , i 'm really happy that i went ."]}

In [81]:
sample['image_paths'] = []
for im in sample['img_ids']:
    sample['image_paths'].append(f'/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/{im}.jpeg')

In [82]:
sample

{'sent_ids': ['101310', '101311', '101312', '101313', '101314'],
 'img_ids': ['6631123887',
  '6631122551',
  '6631123221',
  '6643755711',
  '6631124129'],
 'album_id': '72157628706301801',
 'text': ['my boyfriend is a great guy . he decided to take me on a tour of our city to see the sights . this is a picture of him at the start .',
  'our first stop was to this old antique store . out front was this little toy . my favorite animal a tiger !',
  'walking further on , we came across some amazing street music . a dance or two later and we were still having out with them .',
  'leaving the music behind we headed south . the most beautiful church loomed in the distance . the day was coming to an end and started to become chilly .',
  'our solution , was coffee !'],
 'image_paths': ['/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/6631123887.jpeg',
  '/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/6631122551.jpeg',
  '/home/jay.je/IMspiredStoryTelling/datasets/

In [230]:
os.path.exists('/home/jay.je/IMspiredStoryTelling/datasets/VIST/train/images/6631123887.jpeg')

True

In [242]:
predict_step2(sample['image_paths']) # train images

tensor([[50256,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764,
           764,   764,   764,   764,   764,   764,   764,   764,   764,   764]],
       device='cuda:0')
['...................................................................................................']
tensor([[50256,   764,   764,   764,   764,   764,   764,   764,   764,   764,
 

[['...................................................................................................'],
 ['...................................................................................................'],
 ['.......................................................................................'],
 ['......................................................................................'],
 ['.......................................................................................']]

Previous output 

    ['a man with a beard is looking at the camera',
     'a toy bear sitting on top of a box',
     'a crowd of people standing in front of a building',
     'a large building with a clock on the front of it',
     'a man holding a cup of coffee in front of his face']

In [79]:
predict_step(c_test['story_13']['image_paths'])

tensor([[50256,   732,   389,  2045,   379,   281,  1468,  4590,   286,   257,
          1448,   286,  4695,   287,   257,  2214,   286,  8701,   290,  7150,
           764, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256],
        [50256,   732,   460,   766,   326,   612,   373,   257,  1256,   286,
          3404,   287,   262,  4286,   764, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 

['we are looking at an old photo of a group of animals in a field of grass and trees.',
 'we can see that there was a lot of stuff in the picture.',
 'we are looking at a woman who is dressed in a costume....',
 'we see a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a bikini and a painting of a woman in a',
 'we are looking at a painting of a group of animals.']

Original input

    ['a painting of a castle with a bunch of animals on top of it',
     'a close up image of a close up image of a bird',
     'a painting of a woman holding a pink umbrella',
     'a painting of a woman in a bathing suit',
     'a painting of a cartoon character on a wall']