# LoRA Comparison

In [1]:
import clip
import torch
from torch.utils.data import DataLoader
from clip_model import CaptionModel, generate
from dataset import InstagramDataset
from base_finetune import fine_tune
from transformers import GPT2Tokenizer



## Define BLEU score (adapted from the torchtext implementation, `torchtext.data.metrics.bleu_score`)

In [2]:
import torch
import collections
import math

def ngrams_iterator(token_list, ngrams):
    """Yield every token, then every n-gram of order 2..ngrams as one string.

    Args:
        token_list: sliceable sequence of string tokens.
        ngrams: highest n-gram order to produce.

    Yields:
        Each token unchanged, followed by the higher-order n-grams in
        increasing order, each rendered as its tokens joined by one space.
    """
    yield from token_list

    length = len(token_list)
    for size in range(2, ngrams + 1):
        # A sequence of `length` tokens holds `length - size + 1` windows of
        # width `size`; the range is empty when the sequence is too short.
        for start in range(length - size + 1):
            yield " ".join(token_list[start:start + size])

def _compute_ngram_counter(tokens, max_n):
    assert max_n > 0
    ngrams_counter = collections.Counter(tuple(x.split(" ")) for x in ngrams_iterator(tokens, max_n))

    return ngrams_counter


def bleu_score(candidate_corpus, references_corpus, max_n=4, weights=[0.25] * 4):
    """Compute the corpus-level BLEU score of candidates against references.

    Args:
        candidate_corpus: iterable of candidate translations, each a sequence
            of tokens.
        references_corpus: iterable with one entry per candidate; each entry
            is a sequence of reference translations, and each reference is
            itself a sequence of tokens.
        max_n: highest n-gram order taken into account.
        weights: one weight per n-gram order; must contain ``max_n`` values.

    Returns:
        The BLEU score as a float. ``0.0`` is returned when any n-gram order
        has no clipped match at all.
    """
    assert max_n == len(weights), 'Length of the "weights" list has be equal to max_n'
    assert len(candidate_corpus) == len(
        references_corpus
    ), "The length of candidate and reference corpus should be the same"

    matched = torch.zeros(max_n)    # clipped n-gram matches, indexed by order - 1
    possible = torch.zeros(max_n)   # n-grams present in the candidates, same indexing
    weight_tensor = torch.tensor(weights)

    hyp_total = 0.0
    ref_total = 0.0

    for hypothesis, ref_group in zip(candidate_corpus, references_corpus):
        hyp_len = len(hypothesis)
        hyp_total += hyp_len

        # The brevity penalty uses the reference whose length is closest to
        # the candidate's (the first one wins on ties).
        ref_lengths = [float(len(ref)) for ref in ref_group]
        ref_total += min(ref_lengths, key=lambda length: abs(hyp_len - length))

        # Element-wise maximum over all references: an n-gram may be credited
        # as often as it occurs in the single most generous reference.
        merged_refs = collections.Counter()
        for ref in ref_group:
            merged_refs |= _compute_ngram_counter(ref, max_n)

        overlap = _compute_ngram_counter(hypothesis, max_n) & merged_refs
        for ngram, count in overlap.items():
            matched[len(ngram) - 1] += count

        # A candidate of T tokens contains T - (N - 1) n-grams of order N.
        for order in range(1, max_n + 1):
            possible[order - 1] += max(hyp_len - (order - 1), 0)

    # A single order without matches makes the geometric mean zero.
    if min(matched) == 0:
        return 0.0

    precisions = matched / possible
    score = torch.exp(sum(weight_tensor * torch.log(precisions)))

    # Penalise candidates that are shorter than their references.
    brevity_penalty = math.exp(min(1 - ref_total / hyp_total, 0))

    return brevity_penalty * score.item()

## Compare different models

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Load the CLIP "ViT-B/32" checkpoint (non-JIT) on the CPU together with the
# second value clip.load returns (kept as `preprocess`), and the pretrained
# GPT-2 tokenizer.
# NOTE(review): CLIP is loaded with device="cpu" even though `device` may be
# CUDA — confirm that InstagramDataset handles device placement itself.
clip_model, preprocess = clip.load("ViT-B/32", device="cpu", jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [5]:
# Build the dataset from the CLIP model, its preprocessing callable and the
# GPT-2 tokenizer, then wrap it in a loader that yields one sample per batch.
train_data = InstagramDataset(clip_model, preprocess, tokenizer, device=device)
loader = DataLoader(dataset=train_data, batch_size=1)

In [6]:
# File names of the LoRA state dicts, one per rank being compared.
ranks = ["4", "16", "64", "256"]
models = [f"lora_{rank}.pt" for rank in ranks]
models

['lora_4.pt', 'lora_16.pt', 'lora_64.pt', 'lora_256.pt']

In [None]:
# Baseline evaluation: caption every sample with the weights stored in
# 'state_dicts/coco_weights.pt' and score the captions with BLEU.

# The same constant is passed to CaptionModel, subtracted from the mask sum and
# used to reshape the projected prefix.
# presumably the number of prefix embeddings per image — verify against CaptionModel
PREFIX_LEN = 10

captioner = CaptionModel(PREFIX_LEN)
captioner.load_state_dict(torch.load('state_dicts/coco_weights.pt', map_location="cpu"))
# assumes CaptionModel is a torch.nn.Module (it exposes load_state_dict):
# switch to inference mode so layers such as dropout do not perturb the scores.
captioner.eval()

candidate, reference = [], []
total = len(loader)
with torch.no_grad():  # evaluation only; no autograd graph is needed
    for i, (tokens, prefix, mask) in enumerate(loader):
        num_tokens = (mask.sum() - PREFIX_LEN).long().item()
        caption = tokenizer.decode(tokens[:, :num_tokens].squeeze())
        # bleu_score expects, for every candidate, a *list of references*, each
        # of which is a list of tokens. Appending the bare token list would make
        # it treat every word as a separate reference and compare characters.
        reference.append([caption.split()])

        prefix_embed = captioner.clip_project(prefix).reshape(1, PREFIX_LEN, -1)
        candidate.append(generate(captioner, tokenizer, embed=prefix_embed).split())
        print(f"{i+1}/{total}")
og_bleu = bleu_score(candidate, reference)

1/28360
2/28360
3/28360
4/28360
5/28360
6/28360
7/28360
8/28360
9/28360
10/28360
11/28360
12/28360
13/28360
14/28360
15/28360
16/28360
17/28360
18/28360
19/28360
20/28360
21/28360
22/28360
23/28360
24/28360
25/28360
26/28360
27/28360
28/28360
29/28360
30/28360
31/28360
32/28360
33/28360
34/28360
35/28360
36/28360
37/28360
38/28360
39/28360
40/28360
41/28360
42/28360
43/28360
44/28360
45/28360
46/28360
47/28360
48/28360
49/28360
50/28360
51/28360
52/28360
53/28360
54/28360
55/28360
56/28360
57/28360
58/28360
59/28360
60/28360
61/28360
62/28360
63/28360
64/28360
65/28360
66/28360
67/28360
68/28360
69/28360
70/28360
71/28360
72/28360
73/28360
74/28360
75/28360
76/28360
77/28360
78/28360
79/28360
80/28360
81/28360
82/28360
83/28360
84/28360
85/28360
86/28360
87/28360
88/28360
89/28360
90/28360
91/28360
92/28360
93/28360
94/28360
95/28360
96/28360
97/28360
98/28360
99/28360
100/28360
101/28360
102/28360
103/28360
104/28360
105/28360
106/28360
107/28360
108/28360
109/28360
110/28360
111/2836

In [None]:
# Print the BLEU score stored in og_bleu.
print(og_bleu)