<a href="https://colab.research.google.com/github/mobarakol/tutorial_captioning/blob/main/Explaining_Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

src: https://shap.readthedocs.io/en/latest/example_notebooks/image_examples/image_captioning/Image%20Captioning%20using%20Open%20Source.html

In [2]:
!git clone https://github.com/ruotianluo/ImageCaptioning.pytorch.git

Cloning into 'ImageCaptioning.pytorch'...
remote: Enumerating objects: 2268, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 2268 (delta 0), reused 4 (delta 0), pack-reused 2264[K
Receiving objects: 100% (2268/2268), 1.37 MiB | 14.93 MiB/s, done.
Resolving deltas: 100% (1601/1601), done.


In [1]:
%cd /content/ImageCaptioning.pytorch

/content/ImageCaptioning.pytorch


In [5]:
!curl -L -s -o model-best.pth 'https://drive.google.com/drive/folders/1OsB_jLDorJnzKz6xsOfk1n493P3hwOP0'
!curl -L -s -o infos_fc_nsc-best.pkl 'https://drive.google.com/drive/folders/1OsB_jLDorJnzKz6xsOfk1n493P3hwOP0'
!curl -L -s -o resnet101 'https://drive.google.com/drive/folders/0B7fNdx_jAqhtbVYzOURMdDNHSGM'
!mv resnet101 data/imagenet_weights 

In [6]:
!pip -q install shap yacs lmdbdict transformers
!pip -q install git+https://github.com/ruotianluo/meshed-memory-transformer.git

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/ruotianluo/meshed-memory-transformer.git
  Cloning https://github.com/ruotianluo/meshed-memory-transformer.git to /tmp/pip-req-build-hzwguqik
  Running command git clone -q https://github.com/ruotianluo/meshed-memory-transformer.git /tmp/pip-req-build-hzwguqik
Building wheels for collected packages: meshed-memory-transformer
  Building wheel for meshed-memory-transformer (setup.py) ... [?25l[?25hdone
  Created wheel for meshed-memory-transformer: filename=meshed_memory_transformer-0.0.1-py3-none-any.whl size=39551 sha256=ee5bc636e364409df9e29bed6dc2df233992386bf52dc14131a0ff30a865b635
  Stored in directory: /tmp/pip-ephem-wheel-cache-7bwcj143/wheels/89/75/b0/1210778401d564ce5daceae0bed8ee6089fe7b65c7224d7e78
Successfully built meshed-memory-transformer
Installing collected packages: meshed-memory-transformer
Successfully installed meshed-memory-transfor

In [7]:
import os
import shap
from shap.utils.image import *

In [8]:
#change PREFIX to have absolute path of cloned directory of ImageCaptioning.pytorch
PREFIX = r"/content/ImageCaptioning.pytorch"
os.chdir(PREFIX)

# directory of images to be explained
DIR = './test_images/'
# creates or empties directory if it already exists
make_dir(DIR)
add_sample_images(DIR)

# directory for saving masked images
DIR_MASKED = './masked_images/'

In [9]:
import captioning
import captioning.models as models
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
import captioning.modules.losses as losses
from captioning.data.dataloader import *
from captioning.data.dataloaderraw import *
import gc
import sys
import torch
from transformers import AutoTokenizer,AutoModelForSeq2SeqLM

# to suppress verbose output from open source model
from contextlib import contextmanager
@contextmanager
def suppress_stdout():
    with open(os.devnull, "w") as devnull:
        old_stdout = sys.stdout
        sys.stdout = devnull
        old_stderr = sys.stderr
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

In [10]:
class ImageCaptioningPyTorchModel:
    """
    Wrapper class to get image captions using Resnet model from setup above.
    Note: This class is being used instead of tools/eval.py to get predictions (captions).
    To get more context for this class, please refer to tools/eval.py file.
    """

    def __init__(self, model_path, infos_path, cnn_model = "resnet101", device = "cuda"):
        """
        Initializing the class by loading torch model and vocabulary at path given and using Resnet weights stored in data/imagenet_weights.
        This is done to speeden the process of getting image captions and avoid loading the model every time captions are needed.
        Parameters
        ----------
        model_path  : pre-trained model path
        infos_path  : pre-trained infos (vocab) path
        cnn_model   : resnet model weights to use; options: "resnet101" (default), "resnet152"
        device      : "cpu" or "cuda" (default)
        """

        # load infos
        with open(infos_path, 'rb') as f:
            print(f, infos_path)
            infos = utils.pickle_load(f)
        opt = infos['opt']

        # setup the model
        opt.model = model_path
        opt.cnn_model = cnn_model
        opt.device = device
        opt.vocab = infos['vocab'] # ix -> word mapping
        model = models.setup(opt)
        del infos
        del opt.vocab
        model.load_state_dict(torch.load(opt.model, map_location='cpu'))
        model.to(opt.device)
        model.eval()
        crit = losses.LanguageModelCriterion()

        # setup class variables for call function
        self.opt = opt
        self.model = model
        self.crit = crit
        self.infos_path = infos_path

        # free memory
        torch.cuda.empty_cache()
        gc.collect()


    def __call__(self, image_folder, batch_size):
        """
        Function to get captions for images placed in image_folder.
        Parameters
        ----------
        image_folder: folder of images for which captions are needed
        batch_size  : number of images to be evaluated at once
        Output
        -------
        captions    : list of captions for images in image_folder (will return a string if there is only one image in folder)
        """

        # setting eval options
        opt = self.opt
        opt.batch_size = batch_size
        opt.image_folder = image_folder
        opt.coco_json = ""
        opt.dataset = opt.input_json
        opt.verbose_loss = 0
        opt.verbose = False
        opt.dump_path = 0
        opt.dump_images = 0
        opt.num_images = -1
        opt.language_eval = 0

        # loading vocab
        with open(self.infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        opt.vocab = infos['vocab']

        # creating Data Loader instance to load images
        if len(opt.image_folder) == 0:
            loader = DataLoader(opt)
        else:
            loader = DataLoaderRaw({'folder_path': opt.image_folder,
                                    'coco_json': opt.coco_json,
                                    'batch_size': opt.batch_size,
                                    'cnn_model': opt.cnn_model})

        # when evaluating using provided pretrained model, vocab may be different from what is in cocotalk.json.
        # hence, setting vocab from infos file.
        loader.dataset.ix_to_word = opt.vocab
        del infos
        del opt.vocab

        # getting caption predictions
        _, split_predictions, _ = eval_utils.eval_split(self.model, self.crit, loader, vars(opt))
        captions = []
        for line in split_predictions:
            captions.append(line['caption'])

        # free memory
        del loader
        torch.cuda.empty_cache()
        gc.collect()

        return captions if len(captions) > 1 else captions[0]


# create instance of ImageCaptioningPyTorchModel
osmodel = ImageCaptioningPyTorchModel(model_path = "model-best.pth",
                        infos_path = "infos_fc_nsc-best.pkl",
                        cnn_model = "resnet101",
                        device = "cpu")

# create function to get caption using model created above
def get_caption(model, image_folder, batch_size):
    return model(image_folder, batch_size)

<_io.BufferedReader name='infos_fc_nsc-best.pkl'> infos_fc_nsc-best.pkl


UnpicklingError: ignored

In [26]:
torch.__version__

'1.12.1+cu113'

In [11]:
import pickle
data = pickle.load("infos_fc_nsc-best.pkl")

TypeError: ignored

In [17]:
with open("infos_fc_nsc-best.pkl", 'rb') as pickle_file:
    pickle_file.encoding = 'latin-1'
    content = pickle.load(pickle_file)



UnpicklingError: ignored

In [None]:
with open('mnist.pkl', 'rb') as f: 
    #file = pickle._Unpickler(f)
    f.encoding = 'latin1' 
    file = pickle.load(f) 
    #train_set, valid_set, test_set = pickle.load(file), 
    file.close()