In [1]:
import open_clip
import torch
import os
from PIL import Image

model, _, transform = open_clip.create_model_and_transforms(
  model_name="coca_ViT-L-14",
  pretrained="mscoco_finetuned_laion2B-s13B-b90k"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


CoCa(
  (text): TextTransformer(
    (token_embedding): Embedding(49408, 768)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_drop

In [2]:

root_image_path = "yars_data/photos/"
root_path = "yars_data"
with open('out/paths.txt', 'r') as file:
    file_names = file.readlines()
    
image_paths = [os.path.join(root_image_path + file_name.strip()) for file_name in file_names[1:]]

# Display the combined paths (normally you would load images using a library like PIL or OpenCV)
image_paths[:5]


['yars_data/photos/pve7D6NUrafHW3EAORubyw.jpg',
 'yars_data/photos/Le9rMdT8YFlvqr431LctIQ.jpg',
 'yars_data/photos/9kVdBkGWcKfCFzSwUXjQyw.jpg',
 'yars_data/photos/e0dD0np3hY3F8LoUtrNoPw.jpg',
 'yars_data/photos/xiyqMEgTl4B4ux047E_zqw.jpg']

In [3]:
def gather_photos(label: str, no_captions: bool):
    """
    takes the label (category of images to fetch) and filters photos.json for images with that label
    filters for images that 
    returns: list of photo objects (photo_id, label, caption, etc.)
    """ 
    with open(os.path.join(root_path, 'photos.json'), 'r') as file:
        all_photos = []
        for line in file:
            line = line.rstrip()
            try:
                photo_record = json.loads(line)
                # filers for only ones wiht captions by default
                has_captions = (len(photo_record['caption']) > 0) != no_captions
                if photo_record['label'] == label and has_captions:
                    all_photos.append(root_image_path + photo_record["photo_id"] + ".jpg")
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e} at line: {line}")
        return all_photos
no_caption_photos = gather_photos('food', no_captions=True)
caption_photos = gather_photos('food', no_captions=False)
print(no_caption_photos[:1])

['yars_data/photos/H52Er-uBg6rNrHcReWTD2w.jpg']


In [4]:
# continue generating captions
def read_last_line(filename):
    with open(filename, 'rb') as f:
        f.seek(-2, 2)  # Jump to the second last byte.
        while f.read(1) != b'\n':  # Until EOL is found...
            f.seek(-2, 1)
        last_line = f.readline().decode()
    return last_line

last_line = read_last_line('captions.txt')

def extract_image_path(last_line):
    # Extract everything before the colon which includes the path and image_id
    image_path = last_line.split(':')[0].strip()
    return image_path

def slice_image_ids(image_ids, start_id):
    if start_id in image_ids:
        start_index = image_ids.index(start_id)
        return image_ids[start_index + 1:]
    else:
        return []  # or return image_ids to start from the beginning if ID not found


image_id = extract_image_path(last_line)
print(image_id)
no_caption_photos = slice_image_ids(no_caption_photos, image_id)[1:]
print("Filtered Image IDs:", no_caption_photos[:5])


yars_data/photos/4Zia9NkAfQNjMfcIDhwJ-g.jpg
Filtered Image IDs: []


In [5]:
captions = []
total_photos = len(no_caption_photos)

print(total_photos)

with open("captions.txt", "a") as file:
    for i, image_path in enumerate(no_caption_photos):
        im = Image.open(image_path).convert("RGB")
        im = transform(im).unsqueeze(0)
        im = im.to(device)

        with torch.no_grad(), torch.cuda.amp.autocast():
          generated = model.generate(im, generation_type='top_k')
        
        caption = open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", "")
        captions.append({"image": image_path, "caption": caption})
        file.write(f"{image_path}: {caption}\n")
        file.flush()
        
        print(f"Processed {i+1}/{total_photos}: {image_path}", end='\r', flush=True)

0


In [6]:
with open("captions.txt", "r") as file:
    for line in file:
        image_data = line.split(":")
        img_id = image_data[0].split("/")[-1]
        
# ignore for now
def label_photos(photos, captions):
    """
    takes the list of (previously uncaptioned) pictures, apply the captions to them, and create a new image set
    returns: None
    """ 
    with open(os.path.join(root_path, 'photos.json'), 'r') as file:
        
        for line in file:
            line = line.rstrip()
            try:
                photo_record = json.loads(line)
                # filers for only ones wiht captions by default
                has_captions = len(photo_record['caption']) > 0 != no_captions
                if photo_record['label'] == label and has_captions:
                    all_photos.append(root_image_path + photo_record["photo_id"] + ".jpg")
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e} at line: {line}")
        return all_photos