In [2]:
import copy
import json
import re
import string
from random import Random
from random import shuffle

import nltk
from fastText import load_model, train_supervised
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from tqdm import tqdm

In [2]:
# Lazily filled cache holding the English stop-word set (see _english_stopwords).
_STOPWORD_CACHE = {}


def _english_stopwords():
    """Return the NLTK English stop words as a frozenset, loading them only once."""
    if "english" not in _STOPWORD_CACHE:
        _STOPWORD_CACHE["english"] = frozenset(stopwords.words('english'))
    return _STOPWORD_CACHE["english"]


def preprocess(sentence):
    r"""Normalize a caption into a space-separated string of content words.

    The sentence is lower-cased, split into ``\w+`` tokens (which drops
    punctuation) and every token found in the NLTK English stop-word list
    is removed.

    Parameters
    ----------
    sentence : str
        Raw caption text.

    Returns
    -------
    str
        The remaining tokens joined by single spaces; ``""`` when no token
        survives.
    """
    # The stop-word list is fetched once and kept as a set: calling
    # stopwords.words('english') inside the filter re-evaluated it for every
    # token and searched it linearly.
    stop_words = _english_stopwords()
    tokens = RegexpTokenizer(r'\w+').tokenize(sentence.lower())
    return " ".join(w for w in tokens if w not in stop_words)

## Loading annotations

In [4]:
# Load the annotation file; the context manager closes the file handle,
# which a bare json.load(open(...)) leaves open.
with open('pruebas/coco_noun_alpha_0.0.tags') as tags_file:
    data = json.load(tags_file)

In [5]:
# Pull out the two splits and report how many samples each one holds.
val, train = data["val2014"], data["train2014"]
for split_label, split_entries in (("Train", train), ("Val", val)):
    print("{} samples: {}".format(split_label, len(split_entries)))

Train samples: 82783
Val samples: 40504


## Making fasttext train file 

In [5]:
# Write one line per sample of both splits:
#   "__label__<tag> __label__<tag> ... <preprocessed captions>"
# Whitespace inside a tag is replaced by "_" so each label stays one token.
with open("pruebas/train.txt", "w") as train_file:
    for split_name in ("train2014", "val2014"):
        for entry in data[split_name].values():
            text = " ".join(preprocess(caption) for caption in entry["captions"])
            labels = " ".join(
                "__label__" + "_".join(tag.split()) for tag in entry["tags"]
            )
            train_file.write(labels + " " + text + "\n")

In [6]:
# Fixed seed so the shuffled training file, and therefore the model trained
# on it, can be reproduced on a fresh run. A private Random instance is
# used so the global random state is left untouched.
SHUFFLE_SEED = 42

with open("pruebas/train.txt", "r") as ifile, open("pruebas/shuffle_train.txt", "w") as ofile:
    lines = ifile.readlines()
    Random(SHUFFLE_SEED).shuffle(lines)
    ofile.writelines(lines)

## Training fasttext

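##### Training command (shell, run from inside `pruebas/` so that the model is written to `pruebas/out/`): `fasttext supervised -input ./shuffle_train.txt -epoch 20 -lr 0.2 -loss hs -wordNgrams 1 -verbose 2 -dim 100 -minCount 1 -output out/noun_model_100_1`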

## Make json with tags from fasttext

In [6]:
# Deep copy on purpose: dict.copy() is shallow, so the per-sample dicts
# would be shared with `data` and the prediction loop below would overwrite
# the original "tags" in `data` as well.
data_fast = copy.deepcopy(data)

In [11]:
# Binary model file; its name matches the -output of the training command above.
MODEL_PATH = "pruebas/out/noun_model_100_1.bin"
model = load_model(MODEL_PATH)

In [34]:
# Prediction settings: k is the number of labels requested per sample;
# th is passed as the third argument of predict (presumably the minimum
# probability -- confirm against the fastText API).
k = 10
th = 0.0
prefix = "__label__"

# Replace each sample's "tags" with the model's predictions and store the
# matching probabilities under "scores".
for set_ in ("train2014", "val2014"):
    for id_ in tqdm(data[set_].keys()):
        entry = data_fast[set_][id_]
        text = " ".join(preprocess(cap) for cap in entry["captions"])

        labels, probs = model.predict(text, k, th)

        entry["tags"] = [label.replace(prefix, "") for label in labels]
        entry["scores"] = list(probs)

100%|██████████| 82783/82783 [06:08<00:00, 224.80it/s]
100%|██████████| 40504/40504 [02:53<00:00, 233.61it/s]


In [35]:
# Persist the annotations with the predicted tags and scores as JSON.
with open("pruebas/noun_fasttext_th_0.0.tags", "w") as tags_out:
    tags_out.write(json.dumps(data_fast))

In [None]:
# Note (not executable code) -- label vocabulary sizes:
# 80 --> coco
# 10977 --> noun

In [15]:
# Build the model input for one training sample: every caption is
# preprocessed and the results are joined into a single string.
example_caps = data_fast["train2014"]["384029"]["captions"]
captions = " ".join(preprocess(cap) for cap in example_caps)
captions

'man preparing desserts kitchen covered frosting chef preparing decorating many small pastries baker prepares various types baked goods close person grabbing pastry container close hand touching various pastries'

In [22]:
# Same call as in the bulk prediction loop, but with a non-zero third argument.
example_k, example_th = 10, 0.01
lab, pro = model.predict(captions, example_k, example_th)

In [28]:
# Predicted labels with the "__label__" marker removed.
[label.replace("__label__", "") for label in lab]

['pastry',
 'close',
 'dessert',
 'type',
 'frosting',
 'person',
 'donut',
 'container',
 'man',
 'good']

In [29]:
# Scores returned alongside the labels, shown as a plain list.
[score for score in pro]

[0.11727694422006607,
 0.10713212192058563,
 0.07604345679283142,
 0.056507110595703125,
 0.05331281200051308,
 0.046535227447748184,
 0.033007536083459854,
 0.02365998737514019,
 0.021940473467111588,
 0.01893511787056923]

In [12]:
# Display the full stored record for training sample 384029.
example_entry = data_fast["train2014"]["384029"]
example_entry

{'captions': ['A man preparing desserts in a kitchen covered in frosting.',
  'A chef is preparing and decorating many small pastries.',
  'A baker prepares various types of baked goods.',
  'a close up of a person grabbing a pastry in a container',
  'Close up of a hand touching various pastries.'],
 'category_ids': [1, 60, 61],
 'category_names': ['cake', 'donut', 'person'],
 'file_name': 'COCO_train2014_000000384029.jpg',
 'scores': [0.9166666666666666,
  0.9090909090909091,
  0.9,
  0.8888888888888888,
  0.7272727272727273,
  0.5833333333333333,
  0.5555555555555556,
  0.5555555555555556,
  0.4545454545454546,
  0.33333333333333337,
  0.2222222222222222,
  0.18181818181818177,
  0.08333333333333337],
 'tags': ['close',
  'man',
  'chef',
  'baker',
  'dessert',
  'person',
  'hand',
  'type',
  'kitchen',
  'pastry',
  'good',
  'frosting',
  'container']}

In [31]:
# Display the full stored record for training sample 384029.
example_entry = data_fast["train2014"]["384029"]
example_entry

{'captions': ['A man preparing desserts in a kitchen covered in frosting.',
  'A chef is preparing and decorating many small pastries.',
  'A baker prepares various types of baked goods.',
  'a close up of a person grabbing a pastry in a container',
  'Close up of a hand touching various pastries.'],
 'category_ids': [1, 60, 61],
 'category_names': ['cake', 'donut', 'person'],
 'file_name': 'COCO_train2014_000000384029.jpg',
 'scores': [0.11727694422006607,
  0.10713212192058563,
  0.07604345679283142,
  0.056507110595703125,
  0.05331281200051308,
  0.046535227447748184,
  0.033007536083459854,
  0.02365998737514019,
  0.021940473467111588,
  0.01893511787056923],
 'tags': ['pastry',
  'close',
  'dessert',
  'type',
  'frosting',
  'person',
  'donut',
  'container',
  'man',
  'good']}

### Check data

In [14]:
# Reload the file written by the prediction step; the context manager
# closes the file handle, which a bare json.load(open(...)) leaves open.
with open('pruebas/noun_fasttext_th_0.0.tags') as tags_file:
    data = json.load(tags_file)

In [15]:
# Sanity checks on the loaded tags, per split:
#   * scores are expected in non-increasing order; the first sample that
#     violates this is printed and the scan of that split stops;
#   * a tag must not contain whitespace; the first offender is printed and
#     the cell aborts.
for set_ in ["train2014", "val2014"]:
    for id_, entry in data[set_].items():
        tags = entry["tags"]
        scores = entry["scores"]
        if sorted(scores, reverse=True) != scores:
            print(id_, tags, scores)
            break
        for t in tags:
            if len(t.split()) > 1:
                print(id_, entry)
                # A bare `raise` with no active exception only produced
                # "RuntimeError: No active exception to reraise"; raise the
                # same exception type with a message that names the problem.
                raise RuntimeError(
                    "tag {!r} of sample {} in {} contains whitespace".format(t, id_, set_)
                )

#### The scores seem to be sorted --> :/

### Amount of tags per sample

In [9]:
# Reload the annotation file; the context manager closes the file handle,
# which a bare json.load(open(...)) leaves open.
with open('pruebas/coco_noun_alpha_0.0.tags') as tags_file:
    data = json.load(tags_file)

In [10]:
# Average number of tags per sample, reported for each split.
for set_ in ("train2014", "val2014"):
    entries = data[set_]
    n = len(entries)
    # Sum as float (start value 0.0), then divide by the sample count.
    tag_total = sum((len(entry["tags"]) for entry in entries.values()), 0.0)
    print("{}: {}".format(set_, tag_total / n))

train2014: 10.096420762716983
val2014: 10.055574758048587


In [26]:
# How many labels come back for the single word "cat" when asking for up to 2000.
cat_labels = model.predict("cat", k=2000)[0]
len(cat_labels)

1364

In [27]:
# Size of the model's label vocabulary.
all_labels = model.get_labels()
len(all_labels)

10977