In [2]:
!git clone https://github.com/ypeleg/llama.git



In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import cv2
import random

import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms


import torch.nn as nn
import torchvision.models as models
import torch.optim as optim

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"  
torch.cuda.device_count()

3

In [4]:
os.makedirs("flickr8", exist_ok=True)
os.chdir("flickr8")

In [5]:
# !wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
# !wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
# !unzip -qq Flickr8k_Dataset.zip
# !unzip -qq Flickr8k_text.zip
# !rm Flickr8k_Dataset.zip Flickr8k_text.zip

In [6]:
!ls -lh

total 1.1G
-rw-r--r-- 1 qblocks qblocks 2.8M Oct 14  2013 CrowdFlowerAnnotations.txt
-rw-r--r-- 1 qblocks qblocks 339K Oct 14  2013 ExpertAnnotations.txt
drwxr-xr-x 1 qblocks qblocks 392K Oct  3  2012 Flicker8k_Dataset
-rw-r--r-- 1 qblocks qblocks 3.1M Feb 16  2012 Flickr8k.lemma.token.txt
-rw-r--r-- 1 qblocks qblocks 3.3M Oct 14  2013 Flickr8k.token.txt
-rw-r--r-- 1 qblocks qblocks 1.1G Dec  6  2021 Flickr8k_Dataset.zip
-rw-r--r-- 1 qblocks qblocks 2.3M Dec  6  2021 Flickr8k_text.zip
-rw-r--r-- 1 qblocks qblocks  26K Oct 10  2013 Flickr_8k.devImages.txt
-rw-r--r-- 1 qblocks qblocks  26K Oct 10  2013 Flickr_8k.testImages.txt
-rw-r--r-- 1 qblocks qblocks 152K Oct 10  2013 Flickr_8k.trainImages.txt
drwxrwxr-x 1 qblocks qblocks  124 Jul 15 14:49 __MACOSX
drwxr-xr-x 1 qblocks qblocks   90 Jul 17 02:29 checkpoint
drwxr-xr-x 1 qblocks qblocks  542 Jul 17 18:58 logs
-rw-r--r-- 1 qblocks qblocks 1.8K Oct 14  2013 readme.txt
drwxr-xr-x 1 qblocks qblocks    0 Jul 17 18:35 results


In [7]:
!cat readme.txt

If you use this corpus / data:

Please cite: M. Hodosh, P. Young and J. Hockenmaier (2013) "Framing Image Description as a Ranking Task: Data, Models and Evaluation Metrics", Journal of Artifical Intellegence Research, Volume 47, pages 853-899
http://www.jair.org/papers/paper3994.html


Captions, Dataset Splits, and Human Annotations :


Flickr8k.token.txt - the raw captions of the Flickr8k Dataset . The first column is the ID of the caption which is "image address # caption number"

Flickr8k.lemma.txt - the lemmatized version of the above captions 

Flickr_8k.trainImages.txt - The training images used in our experiments
Flickr_8k.devImages.txt - The development/validation images used in our experiments
Flickr_8k.testImages.txt - The test images used in our experiments


ExpertAnnotations.txt is the expert judgments.  The first two columns are the image and caption IDs.  Caption IDs are <image file name>#<0-4>.  The next three columns are the expert judgments for that image-caption pai

In [8]:
df = pd.read_csv("Flickr8k.token.txt", sep="\t", header=None)
df.columns = ["image", "caption"]
df['num'] = df.apply(lambda row: row['image'].split("#")[-1], axis=1)
df['image'] = df['image'].apply(lambda x: x.split("#")[0])
df

Unnamed: 0,image,caption,num
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...,0
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,1
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,2
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...,3
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...,4
...,...,...,...
40455,997722733_0cb5439472.jpg,A man in a pink shirt climbs a rock face,0
40456,997722733_0cb5439472.jpg,A man is rock climbing high in the air .,1
40457,997722733_0cb5439472.jpg,A person in a red shirt climbing up a rock fac...,2
40458,997722733_0cb5439472.jpg,A rock climber in a red shirt .,3


In [9]:
train_images, validation_images, test_images = [], [], []

with open("Flickr_8k.trainImages.txt", "r") as f:
    for line in f:
        train_images.append(line.strip())

with open("Flickr_8k.devImages.txt", "r") as f:
    for line in f:
        validation_images.append(line.strip())

with open("Flickr_8k.testImages.txt", "r") as f:
    for line in f:
        test_images.append(line.strip())

len(train_images), len(validation_images), len(test_images)

(6000, 1000, 1000)

In [10]:
train_images = train_images[:100]
validation_images = validation_images[:20]
test_images = test_images[:20]

In [11]:
img2idx = {img: idx + 1 for idx, img in enumerate(train_images + validation_images + test_images)}
img2idx["<PAD>"] = 0

In [12]:
!ls Flicker8k_Dataset | head

1000268201_693b08cb0e.jpg
1001773457_577c3a7d70.jpg
1002674143_1b742ab4b8.jpg
1003163366_44323f5815.jpg
1007129816_e794419615.jpg
1007320043_627395c3d8.jpg
1009434119_febe49276a.jpg
1012212859_01547e3f17.jpg
1015118661_980735411b.jpg
1015584366_dfcec3c85a.jpg
ls: write error: Broken pipe


In [13]:
def show_image(image, title=None):
    image[0] = image[0] * 0.229
    image[1] = image[1] * 0.224
    image[2] = image[2] * 0.225
    image[0] += 0.485
    image[1] += 0.456
    image[2] += 0.406

    image = image.numpy().transpose((1, 2, 0))


    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

In [14]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])

# **Encoder**

In [15]:
class EncoderCNN(nn.Module):
    def __init__(self):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False) # upto you if you want to make this True or False

        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        
    def forward(self, images):
        features = self.resnet(images)
        features = features.view(features.size(0), -1)
        return features

In [16]:
device = torch.device("cuda")
encoder = EncoderCNN().to(device)



# **Decoder**

In [17]:
from datasets import load_dataset

import torch
import time
import evaluate
import pandas as pd
import numpy as np

2023-07-17 19:00:22.082715: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-17 19:00:23.605133: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/qblocks/.local/lib/python3.8/site-packages/cv2/../../lib64:
2023-07-17 19:00:23.605287: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /home/qblocks/.local/lib/python3.8/site-packages/cv2/../../lib64

In [18]:
%cd ../llama/
import llama
%cd ../flickr8/

/home/qblocks/nifty/flicker/llama
/home/qblocks/nifty/flicker/flickr8


In [19]:
MODEL = 'decapoda-research/llama-7b-hf'

tokenizer = llama.LLaMATokenizer.from_pretrained(MODEL)
tokenizer.pad_token_id = 0

original_model = llama.LLaMAForCausalLM.from_pretrained(MODEL)
original_model.to(device)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

LLaMAForCausalLM(
  (model): LLaMAModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LLaMADecoderLayer(
        (self_attn): LLaMAAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (mlp): LLaMAMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (lm_h

In [19]:
original_model

LLaMAForCausalLM(
  (model): LLaMAModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LLaMADecoderLayer(
        (self_attn): LLaMAAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (mlp): LLaMAMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
     

In [19]:
from tqdm import tqdm


class Flickr8kDataset(data.Dataset):
    def __init__(self, image_dir, image_names, transform=None):
        self.image_dir = image_dir
        self.image_ids = image_names
        self.transform = transform

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_file = f'{self.image_dir}/{image_id}'
        image = Image.open(image_file).convert('RGB')
        image = image.resize((512, 512))

        if self.transform:
            image = self.transform(image)

        return image, image_id

    def __len__(self):
        return len(self.image_ids)

    
dataset = Flickr8kDataset("Flicker8k_Dataset", list(img2idx.keys())[:-1], transform=transform)
dataloader = data.DataLoader(dataset, batch_size=128, shuffle=False)

encoder.eval()
img_emb = [torch.zeros((1, 2048))]

with torch.no_grad():
    for images, name in tqdm(dataloader):
        images = images.to(device)

        out = encoder(images)
        img_emb.append(out.cpu())

img_emb = torch.cat(img_emb)


# change Llama Embedding
class CustomEmbedding(nn.Module):
    def __init__(self):
        super(CustomEmbedding, self).__init__()
        self.llama_emb = nn.Embedding(32000, 4096)
        self.llama_emb.weight.data.copy_(original_model.model.embed_tokens.weight.data)

        self.img_emb = nn.Embedding.from_pretrained(img_emb, freeze=True)
        self.fc = nn.Linear(2048, 4096) # resnet.fc.in_features = 2048
        
    def forward(self, x):
        """
        x - [batch, seq_len]
        """
        llama = self.llama_emb(x[:, 1:])
        img = self.fc(self.img_emb(x[:, :1]))
        
        return torch.cat([img, llama], dim=1)
    
original_model.model.embed_tokens = CustomEmbedding().to(device)

100%|██████████| 63/63 [02:02<00:00,  1.94s/it]


In [20]:
original_model

LLaMAForCausalLM(
  (model): LLaMAModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=31999)
    (layers): ModuleList(
      (0-31): 32 x LLaMADecoderLayer(
        (self_attn): LLaMAAttention(
          (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (mlp): LLaMAMLP(
          (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)
          (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
     

In [None]:
image, caption = dataset[0]
show_image(image, caption)

In [None]:
image, caption = dataset[0]
encoder(image[None, :].to(device)).shape

# Prepare Dataset

In [1]:
import torch
import pandas as pd
from torch.utils.data import Dataset, random_split


class TextDataset(Dataset):
    def __init__(self, data, tokenizer, max_length):
        self.labels = []
        self.input_ids = []
        self.attn_masks = []
        
        prompt = "Write a caption for the image. \nCaption: \n"
        
        for image_name, caption in data.values:
            encodings_dict = tokenizer(prompt + caption, 
                                       truncation=True,
                                       max_length=max_length, 
                                       padding="max_length")
            
            input_ids = torch.tensor(encodings_dict['input_ids'])
            input_ids[0] = img2idx[image_name]
            
            self.input_ids.append(input_ids)
            self.attn_masks.append(torch.tensor(encodings_dict['attention_mask']))
            
    def __len__(self): 
        return len(self.input_ids)

    def __getitem__(self, idx): 
        return self.input_ids[idx], self.attn_masks[idx]

In [None]:
# max length
max_length = max([len(tokenizer.encode(caption)) for image, caption, _ in df.values])

# train and test dataset
image_data = lambda images: pd.DataFrame(images, columns=["image"]).merge(df.drop("num", axis=1), how="inner")

train_data = TextDataset(image_data(train_images), 
                         tokenizer, 
                         max_length=max_length)


test_data = TextDataset(image_data(test_images), 
                        tokenizer, 
                        max_length=max_length)

# Train

In [None]:
training_args = TrainingArguments(
                                  save_steps = 5000,
                                  warmup_steps = 10,
                                  logging_steps = 100,
                                  weight_decay = 0.05,
                                  num_train_epochs = 1,
                                  logging_dir = './logs',
                                  output_dir = './results',
                                  per_device_eval_batch_size = 1,
                                  per_device_train_batch_size = 1)

Trainer(model=original_model,
        args=training_args,
        eval_dataset=val_dataset,
        train_dataset=train_dataset,
        data_collator=lambda data: {'input_ids': torch.stack([f[0] for f in data]), 'attention_mask': torch.stack([f[1] for f in data]), 'labels': torch.stack([f[0] for f in data])}).train()