In [15]:
from PIL import Image
import torch
from torch import nn, optim
import glob
import os
import pandas as pd
import json
import numpy as np
import clip
from torch.utils.data import Dataset, DataLoader, BatchSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import random
from matplotlib.pyplot import imshow
import torchtext
import nltk, re, string, collections
from nltk.util import ngrams
import collections
%matplotlib inline
BATCH_SIZE = 128
EPOCH = 5

In [16]:
torch.cuda.set_per_process_memory_fraction(0.5, 0)
torch.cuda.empty_cache()

# Preparing Model and Data

In [17]:
IMG_ROOT = "./input/meme-project-raw"
IMG_SEG_ROOT = "./input/meme-raw-seg"
JSON_ROOT = "./input/meme-project-clean-json"
img_paths = glob.glob(os.path.join(IMG_ROOT, "*.jpg"))
img_seg_paths = glob.glob(os.path.join(IMG_SEG_ROOT, "*.jpg"))

d = {}
d_seg = {}

for img_path in img_paths:
    name = img_path.split("/")[-1].split(".")[0]
    if not os.path.exists(os.path.join(IMG_ROOT, name + ".json")):
        continue
    with open(os.path.join(JSON_ROOT, name+".json"), "r") as f:
        if len(d) / 500 in range((len(d) // 500) + 1):
            print(f"{len(d) + 1}: {os.path.join(IMG_ROOT, name + '.jpg')}")
        captions = json.load(f)
        temp = []
        for cap in captions:
            if "http" not in (cap[0] + ' ' + cap[1]) and len(cap[0] + ' ' + cap[1]) >= 8 and len(cap[0] + ' ' + cap[1]) <= 72:
                temp.append(cap[0] + ' ' + cap[1])
        d[img_path] = temp
        
i = 0
for img_path in img_seg_paths:
    i += 1
    name = img_path.split("/")[-1].split(".")[0]
    if not os.path.exists(os.path.join(JSON_ROOT, name + ".json")):
        print(os.path.join(IMG_SEG_ROOT, name + ".json not found, skipping"))
        img_seg_paths.remove(os.path.join(IMG_SEG_ROOT, name + ".jpg"))
        continue
    with open(os.path.join(JSON_ROOT, name + ".json"), "r") as f:
        if len(d_seg) / 500 in range((len(d_seg) // 500) + 1):
            print(f"{len(d_seg) + 1}: {os.path.join(IMG_SEG_ROOT, name + '.jpg')}")
        captions = json.load(f)
        temp = []
        for cap in captions:
            if "http" not in (cap[0]+ ' '+cap[1]) and len(cap[0]+ ' '+cap[1]) >= 8 and len(cap[0]+ ' '+cap[1]) <= 72:
                temp.append(cap[0]+ ' '+cap[1])
        d_seg[img_path] = temp
    if len(d_seg.keys()) != i:
        print(f"{img_path} \t{name} failed to load json")
        
len(d), len(d_seg)

1: ./input/meme-project-raw/absent-minded-looch.jpg
501: ./input/meme-project-raw/oblivious-activist-goat.jpg
1001: ./input/meme-project-raw/typical-boy.jpg
1501: ./input/meme-project-raw/guess-who-you.jpg
2001: ./input/meme-project-raw/dr-steve-brule.jpg
2501: ./input/meme-project-raw/vip2-gayle.jpg
1: ./input/meme-raw-seg/absent-minded-looch.jpg
501: ./input/meme-raw-seg/neymarin.jpg
1001: ./input/meme-raw-seg/zayn-malik1.jpg
1501: ./input/meme-raw-seg/kd-you-the-real-mvp-f.jpg
2001: ./input/meme-raw-seg/svobodus-vulgaris.jpg
2501: ./input/meme-raw-seg/this-is-stas.jpg


(3000, 2535)

## Splitting 20% for Validation

In [18]:
train_img_paths, test_img_paths = train_test_split(img_paths, test_size=0.2, random_state=42)
train_img_seg_paths, test_img_seg_paths = train_test_split(img_seg_paths, test_size=0.2, random_state=42)

d_train = {k: d[k] for k in train_img_paths}
d_test = {k: d[k] for k in test_img_paths}

d_seg_train = {k: d_seg[k] for k in train_img_seg_paths}
d_seg_test = {k: d_seg[k] for k in test_img_seg_paths}

print(f"base split: {len(d_train)}, {len(d_test)}")
print(f"custom split: {len(d_seg_train)}, {len(d_seg_test)}")

base split: 2400, 600
custom split: 2028, 507


## Loading Pre-trained CLIP Model and Preprocessor

In [19]:
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [20]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
model_1e, preprocess_1e = clip.load("ViT-B/32", device=device, jit=False)
model_cs, preprocess_cs = clip.load("ViT-B/32", device=device, jit=False)

image = preprocess(Image.open("./input/meme-project-raw/-okay-.jpg")).unsqueeze(0).to(device)
image_1e = preprocess_1e(Image.open("./input/meme-project-raw/-okay-.jpg")).unsqueeze(0).to(device)
image_cs = preprocess_cs(Image.open("./input/meme-project-raw/-okay-.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
print(f"{image_1e.shape}, {image_cs.shape}, {text.shape}")

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 5.81 GiB total capacity; 4.27 GiB already allocated; 21.06 MiB free; 2.90 GiB allowed; 4.39 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
torch.cuda.is_available()

True

## MemeDataset

In [None]:
class MemeDataset(Dataset):
    def __init__(self, data, preprocess):
        self.preprocess = preprocess
        self.img_paths = []
        self.captions = []
        for img_path, captions in data.items():
            for cap in captions:
                self.img_paths.append(img_path)
                self.captions.append(cap)
        self.processed_cache = {}
        for img_path in data:
            self.processed_cache[img_path] = self.preprocess(Image.open(img_path))
        self.img_paths_set = list(data.keys())
        self.path2label = {path: self.img_paths_set.index(path) for path in self.img_paths_set}
        
    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        image = self.processed_cache[img_path]
        caption = self.captions[idx]
        label = self.path2label[img_path]
        return image, caption, label

train_dataset = MemeDataset(d_train, preprocess)
test_dataset = MemeDataset(d_test, preprocess)

train_seg_dataset = MemeDataset(d_seg_train, preprocess_cs)
test_seg_dataset = MemeDataset(d_seg_test, preprocess_cs)

train_1e_dataset = MemeDataset(d_seg_train, preprocess_1e)  # Лишнее?
test_1e_dataset = MemeDataset(d_seg_test, preprocess_1e)

print(f"{len(train_dataset)}, {len(test_dataset)}, {train_dataset[0]}")
print(f"{len(train_seg_dataset)}, {len(test_seg_dataset)}, {train_seg_dataset[0]}")
print(f"{len(train_1e_dataset)}, {len(test_1e_dataset)}, {train_1e_dataset[0]}")

238418, 60319, (tensor([[[ 0.1493, -0.4346, -1.2813,  ...,  1.9011,  1.2150,  0.6165],
         [ 0.9376,  0.3099, -0.6390,  ...,  1.4632,  0.5289, -0.1280],
         [ 1.8281,  1.1274,  0.2077,  ...,  0.5289, -0.3470, -1.1499],
         ...,
         [-1.3689, -0.8142,  0.0325,  ...,  0.3683,  1.3026,  1.8573],
         [-0.5514,  0.2223,  0.9814,  ..., -0.5514,  0.3391,  1.0252],
         [-0.0550,  0.7917,  1.7114,  ..., -1.2083, -0.2156,  0.5143]],

        [[ 0.6341,  1.6397,  2.0449,  ..., -1.2568, -0.8666,  0.0038],
         [-0.5815,  0.6041,  1.6997,  ..., -0.9867,  0.0789,  1.2945],
         [-1.2268, -0.7616,  0.4991,  ..., -0.0262,  1.3845,  1.9848],
         ...,
         [ 2.0599,  1.8348,  0.7992,  ...,  0.2439, -0.8967, -1.2568],
         [ 1.6697,  0.7842, -0.5965,  ...,  1.5946,  0.3040, -0.7016],
         [ 0.8593, -0.3864, -1.1818,  ...,  2.0299,  1.3995,  0.2139]],

        [[-1.2669, -1.1105, -1.3380,  ..., -1.1247, -1.3522, -1.2527],
         [-1.3522, -1.1532, -

In [None]:
i = 0
for k,v in train_seg_dataset.path2label.items():
    i+=1
    print(f"{i}) {k}, {v}")
    if i == 10:
        break
        
i = 0
for k,v in train_1e_dataset.path2label.items():
    i+=1
    print(f"{i}) {k}, {v}")
    if i == 10:
        break

1) ./input/meme-raw-seg/super-ssau.jpg, 0
2) ./input/meme-raw-seg/female-internet-troll.jpg, 1
3) ./input/meme-raw-seg/y-u-no-bren-iwsnt.jpg, 2
4) ./input/meme-raw-seg/how-tough-are-you.jpg, 3
5) ./input/meme-raw-seg/dinosaur-director.jpg, 4
6) ./input/meme-raw-seg/kanyetothe.jpg, 5
7) ./input/meme-raw-seg/arnaldo-tirone.jpg, 6
8) ./input/meme-raw-seg/creepy-ash.jpg, 7
9) ./input/meme-raw-seg/reborn-logic.jpg, 8
10) ./input/meme-raw-seg/hoodie-faggot.jpg, 9
1) ./input/meme-raw-seg/super-ssau.jpg, 0
2) ./input/meme-raw-seg/female-internet-troll.jpg, 1
3) ./input/meme-raw-seg/y-u-no-bren-iwsnt.jpg, 2
4) ./input/meme-raw-seg/how-tough-are-you.jpg, 3
5) ./input/meme-raw-seg/dinosaur-director.jpg, 4
6) ./input/meme-raw-seg/kanyetothe.jpg, 5
7) ./input/meme-raw-seg/arnaldo-tirone.jpg, 6
8) ./input/meme-raw-seg/creepy-ash.jpg, 7
9) ./input/meme-raw-seg/reborn-logic.jpg, 8
10) ./input/meme-raw-seg/hoodie-faggot.jpg, 9


## BalancedBatchSampler (ensures no same class per batch)

In [21]:
class BalancedBatchSampler(BatchSampler):
    """
    BatchSampler - from a MNIST-like dataset, samples n_classes and within these classes samples n_samples.
    Returns batches of size n_classes * n_samples
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        self.labels_set = list(set(self.labels.numpy()))
        self.label_to_indices = {label: np.where(self.labels.numpy() == label)[0]
                                 for label in self.labels_set}
        for l in self.labels_set:
            np.random.shuffle(self.label_to_indices[l])
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size < self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                indices.extend(self.label_to_indices[class_][
                               self.used_label_indices_count[class_]:self.used_label_indices_count[
                                                                         class_] + self.n_samples])
                self.used_label_indices_count[class_] += self.n_samples
                if self.used_label_indices_count[class_] + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return self.n_dataset // self.batch_size
    
train_labels = torch.tensor([item[2] for item in train_dataset])
train_sampler = BalancedBatchSampler(train_labels, BATCH_SIZE, 1)
train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler)

test_labels = torch.tensor([item[2] for item in test_dataset])
test_sampler = BalancedBatchSampler(test_labels, BATCH_SIZE, 1)
test_dataloader = DataLoader(test_dataset, batch_sampler=test_sampler)

train_seg_labels = torch.tensor([item[2] for item in train_seg_dataset])
train_seg_sampler = BalancedBatchSampler(train_seg_labels, BATCH_SIZE, 1)
train_seg_dataloader = DataLoader(train_seg_dataset, batch_sampler=train_seg_sampler)

test_seg_labels = torch.tensor([item[2] for item in test_seg_dataset])
test_seg_sampler = BalancedBatchSampler(test_seg_labels, BATCH_SIZE, 1)
test_seg_dataloader = DataLoader(test_seg_dataset, batch_sampler=test_seg_sampler)

train_1e_labels = torch.tensor([item[2] for item in train_1e_dataset])
train_1e_sampler = BalancedBatchSampler(train_1e_labels, BATCH_SIZE, 1)
train_1e_dataloader = DataLoader(train_1e_dataset, batch_sampler=train_1e_sampler)

test_1e_labels = torch.tensor([item[2] for item in test_1e_dataset])
test_1e_sampler = BalancedBatchSampler(test_1e_labels, BATCH_SIZE, 1)
test_1e_dataloader = DataLoader(test_1e_dataset, batch_sampler=test_1e_sampler)

In [22]:
for i, item in enumerate(train_seg_sampler):
    labels = []
    for idx in item:
        label = train_seg_dataset[idx][2]
        labels.append(label)
    break
    
for i, item in enumerate(train_1e_sampler):
    labels1e = []
    for idx in item:
        label = train_1e_dataset[idx][2]
        labels1e.append(label)
    break
    
print(f"{len(labels)}, {len(set(labels))}")
print(f"{len(labels1e)}, {len(set(labels1e))}")

128, 128
128, 128


In [23]:
print("seg:")
for batch in train_seg_dataloader:
    imgs, txts, labels = batch
    print(imgs.shape)
    print(len(txts))
    print(labels)
    print(labels.shape)
    print(torch.unique(labels).shape)
    break
    
print("\n1e:")
for batch in train_1e_dataloader:
    imgs, txts, labels = batch
    print(imgs.shape)
    print(len(txts))
    print(labels)
    print(labels.shape)
    print(torch.unique(labels).shape)
    break

seg:
torch.Size([128, 3, 224, 224])
128
tensor([ 309, 1052,  944,  577, 1449, 1614, 1491, 1279, 1049,  688, 1267,  855,
         415,  686, 1686, 1479, 1219,  971, 1566,  397, 1330, 1047, 1395,  586,
         177,  843,  930,  274,  326,  511, 1385,  488, 1380,  687, 1378, 1190,
        1936,  654,  950, 1748,  368, 2011,  232,  347,  278, 1407, 1698, 1468,
        1241, 2015, 1349, 1188,  866,  306,  926,  446,  775, 2023, 1605,  510,
         531, 1680, 1299,  393, 1909, 1773, 1597,   92, 1160,  621,   21,   69,
         160, 1558,  130, 1193, 1620, 1564, 1362, 1311, 1735, 1266, 1418, 1278,
        1465,  118, 1280, 1903, 1017, 1743,  254, 1083,  199, 1302,  951,  429,
         186, 1990, 1414, 1634,  502, 1112,  602,  471,  333,  994,  757,  784,
        1926, 1596,  884,  509,  165,  715, 1599,  494,  829,  230, 1980,  476,
        1423, 1677, 2004,  872, 1216, 1158, 1469, 1272])
torch.Size([128])
torch.Size([128])

1e:
torch.Size([128, 3, 224, 224])
128
tensor([1203, 1333,  483,  

# Training

In [24]:
def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

if device == "cpu":
    model.float()
    model_1e.float()
    model_cs.float()

loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=5e-5,betas=(0.9,0.98),eps=1e-6,weight_decay=0.2)
optimizer = optim.Adam(model.parameters(), lr=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_dataloader)*EPOCH)

optimizer_1e = optim.Adam(model_1e.parameters(), lr=1e-5)
scheduler_1e = optim.lr_scheduler.CosineAnnealingLR(optimizer_1e, len(train_1e_dataloader)*1)

optimizer_cs = optim.Adam(model_cs.parameters(), lr=1e-5)
scheduler_cs = optim.lr_scheduler.CosineAnnealingLR(optimizer_cs, len(train_seg_dataloader)*EPOCH)

In [25]:
# initTraining(model_cs, "cs", 4) => обучение модели model_cs, обученные файлы будут сохраняться с добавлением аппендикса "cs",
# обучение будет длиться 4 эпохи.
def initTraining(model, train_dataloader, test_dataloader, optimizer, scheduler, appendix="base", EPOCH=EPOCH):
    best_te_loss = 1e5
    best_ep = -1
    for epoch in range(EPOCH): # Здесь нужно обучать cs модель, посмотреть циклы ниже
        print(f"running epoch {epoch}, best test loss {best_te_loss} after epoch {best_ep}")
        step = 0
        tr_loss = 0
        model.train()
        pbar = tqdm(train_dataloader, leave=False)
        for batch in pbar:
            step += 1
            optimizer.zero_grad()

            images, texts, _ = batch
            images = images.to(device)
            texts = clip.tokenize(texts).to(device)
    #       print(images.shape, texts.shape)
            logits_per_image, logits_per_text = model(images, texts)
            ground_truth = torch.arange(BATCH_SIZE).to(device)

            total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
            total_loss.backward()
            tr_loss += total_loss.item()
            if device == "cpu":
                optimizer.step()
                scheduler.step()
            else:
                convert_models_to_fp32(model)
                optimizer.step()
                scheduler.step()
                clip.model.convert_weights(model)
            pbar.set_description(f"train batchCE: {total_loss.item()}", refresh=True)
        tr_loss /= step

        step = 0
        te_loss = 0
        with torch.no_grad():
            model.eval()
            test_pbar = tqdm(test_dataloader, leave=False)
            for batch in test_pbar:
                step += 1
                images, texts, _ = batch
                images = images.to(device)
                texts = clip.tokenize(texts).to(device)
                logits_per_image, logits_per_text = model(images, texts)
                ground_truth = torch.arange(BATCH_SIZE).to(device)

                total_loss = (loss_img(logits_per_image,ground_truth) + loss_txt(logits_per_text,ground_truth))/2
                te_loss += total_loss.item()
                test_pbar.set_description(f"test batchCE: {total_loss.item()}", refresh=True)
            te_loss /= step

        if te_loss < best_te_loss:
            best_te_loss = te_loss
            best_ep = epoch
            torch.save(model.state_dict(), f"best_{appendix}_model.pt")
        print(f"epoch {epoch}, tr_loss {tr_loss}, te_loss {te_loss}")
    torch.save(model.state_dict(), f"last_{appendix}_model.pt")

In [None]:
initTraining(model_cs, train_seg_dataloader, test_seg_dataloader, optimizer_cs, scheduler_cs, "cs", 5)

In [26]:
initTraining(model_1e, train_1e_dataloader, test_1e_dataloader, optimizer_1e, scheduler_1e, "1e", 1)

running epoch 0, best test loss 100000.0 after epoch -1


                                        

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 5.81 GiB total capacity; 4.17 GiB already allocated; 21.06 MiB free; 2.90 GiB allowed; 4.39 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# Evaluating Precision on Validation Set

In [22]:
model.load_state_dict(torch.load("./input/tuned-clips/best_model.pt"))
model_cs.load_state_dict(torch.load("./input/tuned-clips/best_cs_model.pt"))
model_1e.load_state_dict(torch.load("./input/tuned-clips/best_1e_model.pt"))
NUM_NEG = 127
NUM_TEST = 1000

In [23]:
def evaluatePrecision(model, preprocess, d_test, appendix="base"):
    n_correct = 0
    for i in tqdm(range(NUM_TEST)):
        empty = True

        while empty:
            img_path = random.choice(list(d_test.keys()))
            image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
            name = img_path.split('/')[-1].split('.')[0]
            caps = d_test[img_path]
            if len(caps) > 0:
                pos_txt = random.choice(caps)
                empty = False

        neg_i = 0
        neg_txts = []
        while neg_i < NUM_NEG:
            img_path = random.choice(list(d_test.keys()))
            neg_name = img_path.split('/')[-1].split('.')[0]
            if neg_name == name:
                continue
            caps = d_test[img_path]
            if len(caps) == 0:
                continue
            neg_txt = random.choice(caps)
            if neg_txt in neg_txts:
                continue
            neg_txts.append(neg_txt)
            neg_i += 1

        text = clip.tokenize([pos_txt]+neg_txts).to(device)

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            logits_per_image, logits_per_text = model(image, text)
            probs = logits_per_image.softmax(dim=-1).cpu().numpy()

        if np.argmax(probs) == 0:
            n_correct +=1
    print(f"Test precision on model_{appendix} {n_correct / NUM_TEST}")

In [28]:
evaluatePrecision(model, preprocess, d_test, appendix="base")
evaluatePrecision(model_1e, preprocess_1e, d_seg_test, appendix="1e")
evaluatePrecision(model_cs, preprocess_cs, d_seg_test, appendix="cs")

100%|██████████| 1000/1000 [02:57<00:00,  5.62it/s]


Test precision on model_base 0.546


100%|██████████| 1000/1000 [02:59<00:00,  5.56it/s]


Test precision on model_1e 0.547


100%|██████████| 1000/1000 [03:00<00:00,  5.55it/s]

Test precision on model_cs 0.559





# Evaluating BLEU and Word Diversity using Naive Sampling

## Sampling Captions for Validation Images According to CLIP Text-Image Proximity

In [29]:
def sample1Caption(img_path, corpus, preprocess, model, num_cand):
    image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
    i = 0
    txts = []
    while i < num_cand:
        txt = random.choice(corpus)
        if txt in txts:
            continue
        if len(txt.split())<5 or len(txt)>72:
            continue
        txts.append(txt)
        i += 1

    text = clip.tokenize(txts).to(device)

    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    return txts[np.argmax(probs)]

In [36]:
corpus = []
corpus_cs = []
corpus_1e = []
for txtlist in d_train.values():
    corpus += txtlist
for txtlist in d_seg_train.values():
    corpus_cs += txtlist
    corpus_1e += txtlist
print(len(corpus), corpus[0])
print(len(corpus_cs), corpus_cs[0])
print(len(corpus_1e), corpus_1e[0])

238418 I'M A TEACHER BUT I HAVE NO PUPILS
195683 WHAT DA FUCK IS THIS SHIT? 
195683 WHAT DA FUCK IS THIS SHIT? 


In [46]:
print("sampling for base model")
captions_base = {}
for img_path in tqdm(d_test.keys()):
    caption = sample1Caption(img_path, corpus, preprocess, model, 1000)
    captions_base[img_path] = caption
    
print("sampling for cs model")
captions_cs = {}
for img_path in tqdm(d_seg_test.keys()):
    caption = sample1Caption(img_path, corpus_cs, preprocess_cs, model_cs, 1000)
    captions_cs[img_path] = caption
    
print("sampling for 1e model")
captions_1e = {}
for img_path in tqdm(d_seg_test.keys()):
    caption = sample1Caption(img_path, corpus_1e, preprocess_1e, model_1e, 1000)
    captions_1e[img_path] = caption

sampling for base model


100%|██████████| 600/600 [07:14<00:00,  1.38it/s]


sampling for cs model


100%|██████████| 507/507 [06:13<00:00,  1.36it/s]


sampling for 1e model


100%|██████████| 507/507 [06:09<00:00,  1.37it/s]


## BLEU Score

In [47]:
def calculateBleuScore(captionst, d_t, appendix="base"):
    broken = 0
    for get_bleu in range(1,4):
        bleu_x_lst = []
        bleu_y_lst = []
        for p, caps in d_t.items():
            if not caps:
                continue
            if p not in captionst.keys():
                broken += 1
                continue
            bleu_x_lst.append(captionst[p].split())
            splittedcaps = [x.split() for x in caps]
            bleu_y_lst.append(splittedcaps)
        BLEU = torchtext.data.metrics.bleu_score(bleu_x_lst, bleu_y_lst, max_n=get_bleu, weights=[1/get_bleu]*get_bleu)
        print(f"{get_bleu}-gram BLEU score ({appendix}): {BLEU}, broken files: {broken}")

In [48]:
print(f"{len(captions_base)}, {len(captions_cs)}, {len(captions_1e)}")
print(f"{len(d_test.keys())}, {len(d_seg_test.keys())}")

600, 507, 507
600, 507


In [49]:
calculateBleuScore(captions_base, d_test, "base")
calculateBleuScore(captions_cs, d_seg_test, "cs")
calculateBleuScore(captions_1e, d_seg_test, "1e")

1-gram BLEU score (base): 0.4022921621799469, broken files: 0
2-gram BLEU score (base): 0.17008958756923676, broken files: 0
3-gram BLEU score (base): 0.09917695075273514, broken files: 0
1-gram BLEU score (cs): 0.3872121572494507, broken files: 0
2-gram BLEU score (cs): 0.16243566572666168, broken files: 0
3-gram BLEU score (cs): 0.08899804949760437, broken files: 0
1-gram BLEU score (1e): 0.38601332902908325, broken files: 0
2-gram BLEU score (1e): 0.15987643599510193, broken files: 0
3-gram BLEU score (1e): 0.08371224254369736, broken files: 0
