# Image Captioning with BLIP

## 1. Introduction
In this notebook I will load a snapshot of the DynamoDB table where I have stored all of the Reddit posts. For each post containing an image, I will use BLIP (*Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation*) to generate a caption, so that the fine-tuned LLM can interpret the image.

In [1]:
!pip install transformers

In [3]:
import pandas as pd
import boto3
import io
import gzip

## 2. Load and Prep Data
Read in the DynamoDB export from S3. This is a `json.gz` file, so I will need to decompress it first.

In [7]:
# Initialize S3 client
s3_client = boto3.client('s3')
bucket_name = 'sagemaker-us-east-1-513033806411'
object_key = 'reddit/funny/AWSDynamoDB/01707665158792-b0452d79/data/lsr2eo7idm6prgqdcvn72k7joa.json.gz'

# Get the object from S3
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
content = response['Body'].read()

# Decompress and read into a pandas DataFrame
with gzip.GzipFile(fileobj=io.BytesIO(content)) as gzipfile:
    content = gzipfile.read()

df = pd.read_json(io.BytesIO(content), lines=True)  # DynamoDB exports are newline-delimited JSON, one item per line

display(df.head())

|   | Item |
|---|------|
| 0 | {'submissionId': {'S': '136cmyb'}, 'topComment... |
| 1 | {'submissionId': {'S': '16ok566'}, 'topComment... |
| 2 | {'submissionId': {'S': 'poe4fo'}, 'topComment'... |
| 3 | {'submissionId': {'S': '13xn4b9'}, 'topComment... |
| 4 | {'submissionId': {'S': 'mprm3j'}, 'topComment'... |
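For reference, the same decompress-and-parse step can be done with only the standard library, which also makes the export's line format explicit. This is a minimal sketch: `read_dynamodb_export` is a hypothetical helper name, and it assumes every line has a top-level `Item` key, as in the output above.

```python
import gzip
import json


def read_dynamodb_export(raw_bytes):
    """Decompress a DynamoDB JSON export (.json.gz) and return one dict per item.

    Assumes the newline-delimited layout shown above, where each line looks like
    {"Item": {"attr": {"S": "..."}, ...}}.
    """
    text = gzip.decompress(raw_bytes).decode("utf-8")
    # Skip blank lines (e.g. a trailing newline at the end of the file)
    return [json.loads(line)["Item"] for line in text.splitlines() if line.strip()]


# Usage with a tiny synthetic export (made-up data, not from the real table)
sample = b'{"Item": {"submissionId": {"S": "abc123"}, "score": {"N": "42"}}}\n'
items = read_dynamodb_export(gzip.compress(sample))
print(items[0]["submissionId"])  # {'S': 'abc123'}
```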


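Since this is coming from DynamoDB, which doesn't enforce a strict schema, each exported item carries its own types: every attribute value is wrapped in a type descriptor such as `{'S': '136cmyb'}` (`S` for string, `N` for number). I will extract the raw values into new columns in the DataFrame with the function below.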

In [9]:
# Function to extract values from the DynamoDB format,
# e.g. {'S': '136cmyb'} -> '136cmyb' (numbers stay strings: {'N': '608'} -> '608')
def extract_values(row):
    return {k: list(v.values())[0] for k, v in row.items()}

# Apply the transformation to each row in the DataFrame
df_transformed = df['Item'].apply(extract_values)

# Convert the series of dictionaries into a DataFrame
df_flat = pd.json_normalize(df_transformed)

display(df_flat.head())

|   | submissionId | topComment | numComments | topCommentScore | createdUtc | score | url | title | body |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 136cmyb | No seriously though, what the hell is happenin... | 608 | 2989 | 1683095752 | 41173 | https://v.redd.it/u050avdqrlxa1 | "So what are your intentions with my daughter?" | |
| 1 | 16ok566 | Oh Deere!\n\nThis is one of the few times I ca... | 645 | 3557 | 1695313196 | 37380 | https://i.redd.it/hsp7ro7jwmpb1.jpg | I wonder whose decision it was | |
| 2 | poe4fo | She did this dress so she could send her body ... | 1530 | 5810 | 1631661684 | 73128 | https://i.redd.it/xtbn4vzdyjn71.jpg | The is me, circa 2014, on the left. Kim Kardas... | |
| 3 | 13xn4b9 | Beth, you are a *horse* surgeon. | 1697 | 8146 | 1685639071 | 50443 | https://v.redd.it/c1g35aiuce3b1 | It's never a veterinarian that they are lookin... | |
| 4 | mprm3j | This wouldn’t happen to be in Christiansburg V... | 973 | 3074 | 1618275332 | 90033 | https://i.redd.it/icbg029w9us61.jpg | A local music store in my town has had this si... | |
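`extract_values` keeps each value exactly as it was serialized, so any attribute stored as a DynamoDB number (for example, if `numComments` or `score` use the `N` type) comes back as a string and needs casting before arithmetic. A type-aware unwrap is sketched below as a hypothetical alternative; `unwrap_attribute` and `unwrap_item` are illustrative names, not part of any library. boto3 also provides `boto3.dynamodb.types.TypeDeserializer`, which covers the full type set but returns `Decimal` for numbers.

```python
from decimal import Decimal


def unwrap_attribute(value):
    """Convert one DynamoDB-typed value (e.g. {'N': '608'}) to a plain Python value.

    Handles the common type descriptors S, N, BOOL, NULL, L and M; anything else
    (binary, sets) is returned unchanged.
    """
    (type_code, raw), = value.items()  # each descriptor has exactly one key
    if type_code == "S":
        return raw
    if type_code == "N":
        # DynamoDB serializes numbers as strings; keep whole numbers as int
        number = Decimal(raw)
        return int(number) if number == number.to_integral_value() else float(number)
    if type_code == "BOOL":
        return raw
    if type_code == "NULL":
        return None
    if type_code == "L":
        return [unwrap_attribute(v) for v in raw]
    if type_code == "M":
        return {k: unwrap_attribute(v) for k, v in raw.items()}
    return value


def unwrap_item(item):
    """Apply unwrap_attribute to every attribute of an exported item."""
    return {k: unwrap_attribute(v) for k, v in item.items()}


row = {"submissionId": {"S": "136cmyb"}, "numComments": {"N": "608"}}
print(unwrap_item(row))  # {'submissionId': '136cmyb', 'numComments': 608}
```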


## 3. Generate Image Captions
### a) Load the `Salesforce/blip-image-captioning-large` model from Hugging Face

In [6]:
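import torch
from transformers import pipeline
from PIL import Image

# Initialize the pipeline on the GPU if one is available (device=-1 means CPU)
device = 0 if torch.cuda.is_available() else -1
image_to_text = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=device)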


In [10]:
image_count = len(df_flat[df_flat['imageS3Url'] != "N/A"])
print(f"Generating captions for {image_count} images")

Generating captions for 1328 images


### b) Iterate through the images in S3
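The loop below recovers each object key by stripping the `s3://{bucket_name}/` prefix, which assumes every image lives in the same bucket as the export. A slightly more defensive sketch parses the URI with the standard library instead; `split_s3_uri` is a hypothetical helper and the example URI is made up.

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split 's3://bucket/path/to/key' into ('bucket', 'path/to/key').

    Returns None for values that are not S3 URIs (e.g. the 'N/A' placeholder).
    """
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        return None
    return parsed.netloc, parsed.path.lstrip("/")


print(split_s3_uri("s3://my-bucket/reddit/funny/images/abc.jpg"))
# ('my-bucket', 'reddit/funny/images/abc.jpg')
print(split_s3_uri("N/A"))  # None
```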

In [15]:
def generate_caption_from_s3(bucket_name, object_key):
    """Generate a caption for the image using BLIP; returns None if anything fails"""
    
    try:
        # Get the object from S3
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
        image_content = response['Body'].read()

        # Load the image (convert to RGB in case it is a palette or RGBA image)
        image = Image.open(io.BytesIO(image_content)).convert('RGB')

        # Generate caption from BLIP
        result = image_to_text(image, max_new_tokens=50)
        return result[0]['generated_text'] if result else "Caption not generated"
    
    except Exception as e:
        print(f"{object_key}: {e}")
        return None
        
# Initialize new description column for BLIP response
df_flat['blip_image_description'] = ''

# Extract the object key from the S3 URL
df_flat['object_key'] = df_flat['imageS3Url'].apply(lambda x: x.replace(f's3://{bucket_name}/', ''))

# Generate a description for each row that has an image; other rows get 'N/A'
for i in range(df_flat.shape[0]):
    object_key = df_flat.loc[i, 'object_key']
    if object_key != 'N/A':
        blip_description = generate_caption_from_s3(bucket_name, object_key)
        print(f"{i}: {blip_description}")
    else:
        blip_description = 'N/A'
    # assign value
    df_flat.loc[i, 'blip_image_description'] = blip_description


1: there are two people standing in a room with a tractor
2: a close up of a person in a black dress and a black man in a black suit
4: arafed image of a store front with a sign for a super shoes store
6: there is a pile of wood sitting in front of a house
7: someone is holding a banana in their hand on a tile floor
8: cartoon of a man with a computer on his head and a computer on his head
9: there is a picture of a toy on a scooter and a picture of a man on a scooter
12: there is a cat that is laying inside of a refrigerator
13: a close up of a baby sitting in a wooden box
15: someone is holding a credit card and a game controller
18: a cartoon of a dog looking out a window at a neighbor
...
2662: a cartoon of a comic strip with a man and a woman talking
2663: there is a car that has a sticker on the window
2665: there are many men posing for a picture together with one man pointing at the camera
2666: cartoon of a group of blue alien with a message saying,'this is our oldest liquid '

You might notice that some of these captions contain nonsense words like `araffe` and `arafed`. These typically appear where you would expect the article `a`. For example:
- ***araffeed*** *cup of ice cream with a cow on it*
- ***araffe*** *milk container with a message written on it*
- ***arafed*** *man standing on a muddy bank next to a car*

This is a known issue with this model and is believed to come from the Hugging Face dataset it was trained on. If it were a significant issue, I could add logic to remove these typos from the training set, as well as in production, where BLIP interprets images for the system. However, LLMs handle typos well, so this shouldn't affect the fine-tuned model's ability to understand what the image depicts.
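If the artifacts ever did need removing, a regex pass would be enough. This is a minimal sketch: the pattern is an assumption fitted only to the three variants listed above and may need widening for other spellings, and `clean_blip_caption` is a hypothetical helper. It always substitutes `a`, so some captions become "a image", which is harmless for the LLM.

```python
import re

# Hypothetical pattern for BLIP's "araffe"/"arafed"/"araffeed" artifacts:
# "ara", one or more "f", one or more "e", optional trailing "d", as a whole word.
ARAFFE_PATTERN = re.compile(r"\bara+f+e+d?\b", re.IGNORECASE)


def clean_blip_caption(caption):
    """Replace the known BLIP artifact tokens with the article 'a'."""
    cleaned = ARAFFE_PATTERN.sub("a", caption)
    return " ".join(cleaned.split())  # normalize whitespace


print(clean_blip_caption("araffeed cup of ice cream with a cow on it"))
# a cup of ice cream with a cow on it
```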

### c) Use SageMaker Batch Transform (optional)
Iterating through ~1,300 images with this model works fine here on an `ml.p3.2xlarge` instance. At greater scale, however, SageMaker Batch Transform would be a better option, since it distributes the workload across instances.

In [None]:
import sagemaker
import boto3
from sagemaker.huggingface import HuggingFaceModel

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

# Hub Model configuration. https://huggingface.co/models
hub = {
    'HF_MODEL_ID':'Salesforce/blip-image-captioning-large',
    'HF_TASK':'image-to-text'
}

# create Hugging Face Model Class
huggingface_model = HuggingFaceModel(
    transformers_version='4.26.0',
    pytorch_version='1.13.1',
    py_version='py39',
    env=hub,
    role=role, 
)

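# Register the model in SageMaker so the transform job can reference it by name
huggingface_model.create(instance_type='ml.p3.2xlarge')
model_name = huggingface_model.name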


In [None]:
sm_client = boto3.client('sagemaker')  # separate name so the `sagemaker` SDK module isn't shadowed

# Start a transform job
response = sm_client.create_transform_job(
    TransformJobName='blip-image-captioning-transform-job',
    ModelName=model_name,  # Use the model name from the model creation step
    MaxConcurrentTransforms=4,
    MaxPayloadInMB=6,
    BatchStrategy='SingleRecord',  # each image file is sent as its own request
    TransformInput={
        'DataSource': {
            'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-513033806411/reddit/funny/posts/'}
        },
        'ContentType': 'application/x-image',  # Ensure this matches the model's expected content type
    },
    TransformOutput={
        'S3OutputPath': 's3://sagemaker-us-east-1-513033806411/reddit/funny/captions/',
    },
    TransformResources={
        'InstanceType': 'ml.p3.2xlarge',
        'InstanceCount': 1
    }
)
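Batch Transform does not write into the DataFrame; it writes one object per input file under `S3OutputPath`, named after the input object with `.out` appended, so joining captions back to posts means parsing those objects. The sketch below assumes the Hugging Face inference container returns the pipeline's usual `[{"generated_text": ...}]` JSON; `parse_transform_output` and the example key are hypothetical.

```python
import json
import posixpath


def parse_transform_output(output_key, body):
    """Map one Batch Transform output object back to (input file name, caption).

    Assumes the output object is named '<input file>.out' and its body is the
    pipeline's JSON list, e.g. '[{"generated_text": "a cat in a fridge"}]'.
    """
    file_name = posixpath.basename(output_key)
    if file_name.endswith(".out"):
        file_name = file_name[: -len(".out")]
    predictions = json.loads(body)
    caption = predictions[0]["generated_text"] if predictions else None
    return file_name, caption


key = "reddit/funny/captions/abc123.jpg.out"  # hypothetical output key
print(parse_transform_output(key, '[{"generated_text": "a cat in a fridge"}]'))
# ('abc123.jpg', 'a cat in a fridge')
```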

## 4. Save Generated Captions
Now that the image captions have been generated, I'm going to update the DynamoDB table with the captions and save the DataFrame to S3 for later use.
### a) Update DynamoDB items

In [11]:
# Initialize the DynamoDB client
dynamodb = boto3.client('dynamodb', region_name='us-east-1') 

def update_dynamodb_table_with_blip(df):

    for index, row in df.iterrows():
        dynamodb.update_item(
            TableName='funny-reddit-posts',
            Key={
                'submissionId': {'S': str(row['submissionId'])},
            },
            UpdateExpression='SET blipCaption = :val',
            ExpressionAttributeValues={
                ':val': {'S': str(row['blip_image_description'])},
            }
        )

# Update items with BLIP caption (rows where captioning failed hold None and are skipped)
blip_df = df_flat[~df_flat['blip_image_description'].isna()].reset_index(drop=True)
update_dynamodb_table_with_blip(blip_df)
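One caveat with `update_item`: if the key is not already in the table, DynamoDB creates a new item instead of raising an error, so a bad `submissionId` would quietly add a stray row. A condition expression guards against that. The sketch below only builds the request arguments so it can be checked without AWS access; `build_caption_update` is a hypothetical helper.

```python
def build_caption_update(table_name, submission_id, caption):
    """Build keyword arguments for DynamoDB's update_item call.

    The ConditionExpression makes the write fail (ConditionalCheckFailedException)
    instead of silently creating a new item when the key is not in the table.
    """
    return {
        "TableName": table_name,
        "Key": {"submissionId": {"S": str(submission_id)}},
        "UpdateExpression": "SET blipCaption = :val",
        "ConditionExpression": "attribute_exists(submissionId)",
        "ExpressionAttributeValues": {":val": {"S": str(caption)}},
    }


# Usage (requires AWS credentials): dynamodb.update_item(**build_caption_update(...))
kwargs = build_caption_update("funny-reddit-posts", "136cmyb", "a cat in a fridge")
print(kwargs["ConditionExpression"])  # attribute_exists(submissionId)
```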

In [12]:
# Convert DataFrame to CSV string
csv_buffer = io.StringIO()
df_flat.to_csv(csv_buffer)

# Initialize S3 client
s3_client = boto3.client('s3')

# Specify your bucket name and the desired key (path + filename in the bucket)
bucket_name = 'sagemaker-us-east-1-513033806411'
object_key = 'reddit/funny/data/blip_descriptions.csv'

# Upload the CSV string to S3
s3_client.put_object(Bucket=bucket_name, Body=csv_buffer.getvalue(), Key=object_key)

## 5. Conclusion
After this processing, I have image captions for about half of the Reddit posts in my dataset. In the next notebook I will impute captions for the posts without a usable image (those in GIF or video format).