In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
import spacy
from sklearn.linear_model import LogisticRegression

In [4]:
def load_data(file_path):
    data = []
    with open(file_path) as f:
        for el in f:
            data.append(json.loads(el))
            
    return data

In [5]:
def create_features(data):
    nlp = spacy.load('en_core_web_sm')
    for el in data:
        doc = nlp(el['text'])
        el['vector'] = doc.vector
    
    return data

In [6]:
def format_data(data):
    labels = []
    for el in data:
        labels.append(el['label'])
    
    train_data = []
    for el in data:
        train_data.append(el['vector'])
    
    return labels, train_data

In [13]:
def train_model(data, model):
    labels, train_data = format_data(data)

    model.fit(train_data, labels)
    
    return model, labels, train_data

## Baseline Model 1

In [8]:
train_data = load_data('data/train.jsonl')
val_data = load_data('data/dev.jsonl')

In [9]:
train_data = create_features(train_data)
val_data = create_features(val_data)

In [14]:
model = LogisticRegression()
lr, train_labels, train_vectors = train_model(train_data, model)



In [15]:
print('Train Accuracy: {}'.format(lr.score(train_vectors, train_labels)))
val_labels, val_vectors = format_data(val_data)
print('Validation Accuracy: {}'.format(lr.score(val_vectors, val_labels)))

Train Accuracy: 0.646235294117647
Validation Accuracy: 0.512


## Baseline Model 2

In [30]:
image_data = pd.read_csv('data/cleaned_getty_data.csv', index_col = False)
image_data = image_data.drop(['Unnamed: 0'], axis = 1)

In [31]:
image_data.head()

Unnamed: 0,url,img_url,tags,caption,src,id,Color_Score,Key_Point_Score,image,color_score_rank,...,color_url,color_img_url,color_tags,color_caption,key_point_url,key_point_img_url,key_point_tags,key_point_caption,best_caption,best_tags
0,https://www.gettyimages.co.uk/detail/photo/dru...,https://media.gettyimages.com/photos/drugs-pic...,"Opioid Photos,Syringe Photos,Addiction Photos,...",Drugs,56482,1.0,14.358936,144301.065239,56482,1.0,...,https://www.gettyimages.co.uk/detail/photo/dru...,https://media.gettyimages.com/photos/drugs-pic...,"Opioid Photos,Syringe Photos,Addiction Photos,...",Drugs,https://www.gettyimages.co.uk/detail/photo/dru...,https://media.gettyimages.com/photos/drugs-pic...,"Opioid Photos,Syringe Photos,Addiction Photos,...",Drugs,Drugs,Opioid Syringe Addiction Despair Heroi...
1,https://www.gettyimages.co.uk/detail/photo/bui...,https://media.gettyimages.com/photos/buisnessm...,"Men Photos,20-29 Years Photos,Adult,Adults Onl...",Businessman standing outside of his office and...,60183,2.0,37.201224,176873.773315,60183,1.0,...,https://www.gettyimages.co.uk/detail/photo/bui...,https://media.gettyimages.com/photos/buisnessm...,"Men Photos,20-29 Years Photos,Adult,Adults Onl...",Businessman standing outside of his office and...,https://www.gettyimages.co.uk/detail/photo/he-...,https://media.gettyimages.com/photos/he-makes-...,"Adult,Adults Only Photos,Answering Photos,Brin...",Studio shot of a stylish young businessman usi...,Businessman standing outside of his office and...,Men 20-29 Years Adult Adults Only Agreem...
2,https://www.gettyimages.co.uk/detail/photo/rea...,https://media.gettyimages.com/photos/rear-view...,"Animal Photos,Animal Themes Photos,Animals In ...",,43905,1.0,2.254148,162676.233573,43905,1.0,...,https://www.gettyimages.co.uk/detail/photo/rea...,https://media.gettyimages.com/photos/rear-view...,"Animal Photos,Animal Themes Photos,Animals In ...",,https://www.gettyimages.co.uk/detail/photo/rea...,https://media.gettyimages.com/photos/rear-view...,"Animal Photos,Animal Themes Photos,Animals In ...",,,Animal Animal Themes Animals In The Wild ...
3,https://www.gettyimages.co.uk/detail/photo/sou...,https://media.gettyimages.com/photos/source-of...,"Active Volcano Photos,Akita Prefecture Photos,...","Higashi Naruse Village, Akita Prefecture, Japa...",7825,5.0,20.450264,180862.123245,7825,1.0,...,https://www.gettyimages.co.uk/detail/photo/sou...,https://media.gettyimages.com/photos/source-of...,"Active Volcano Photos,Akita Prefecture Photos,...","Higashi Naruse Village, Akita Prefecture, Japa...",https://www.gettyimages.co.uk/detail/photo/wil...,https://media.gettyimages.com/photos/wildebees...,"Adventure Photos,Animal Photos,Animal Themes P...",The Great Migration. Wildebeest and Zebra cros...,"Higashi Naruse Village, Akita Prefecture, Japa...",Active Volcano Akita Prefecture Autumn A...
4,https://www.gettyimages.co.uk/detail/photo/see...,https://media.gettyimages.com/photos/see-it-an...,"Spectacles Photos,Men Photos,Caucasian Appeara...",Studio shot of a handsome young man posing aga...,50413,2.0,24.906096,174126.235321,50413,1.0,...,https://www.gettyimages.co.uk/detail/photo/see...,https://media.gettyimages.com/photos/see-it-an...,"Spectacles Photos,Men Photos,Caucasian Appeara...",Studio shot of a handsome young man posing aga...,https://www.gettyimages.co.uk/detail/photo/por...,https://media.gettyimages.com/photos/portrait-...,"Men Photos,Human Face Photos,Portrait Photos,S...",Man with a serious expression,Studio shot of a handsome young man posing aga...,Spectacles Men Caucasian Appearance Port...


In [37]:
src_photos = set(image_data.src)

In [38]:
train_data[0]

{'id': 42953,
 'img': 'img/42953.png',
 'label': 0,
 'text': 'its their character not their color that matters',
 'vector': array([-0.39059347,  0.34201172,  0.79031336, -0.67663956,  1.2195603 ,
        -0.03345944, -1.9631115 ,  0.00423736, -0.20557106,  0.32930514,
         0.45759752,  0.95797217, -0.7769648 , -0.40923548, -0.24158663,
         0.96003276, -0.23099978, -0.63708985,  0.3379006 ,  0.8056612 ,
        -0.26968667,  0.48419315,  1.3091208 , -1.0345259 , -0.43772987,
         1.5756359 ,  1.1876926 , -0.25269243, -2.6762524 , -1.4061985 ,
        -1.12094   ,  0.45184684, -0.893475  , -0.44338486, -0.13037719,
        -0.9010774 , -0.6172734 , -0.71861506, -0.27039078, -0.43401954,
         1.7497404 ,  0.4978023 ,  0.12083992, -0.96050525,  0.8526076 ,
        -0.20645005, -0.32321572,  1.0391794 ,  0.8955164 ,  1.5223812 ,
        -0.32024488, -0.5678183 , -1.0750859 , -1.2296531 ,  0.2370046 ,
         1.1308669 ,  1.1394649 ,  0.44225276, -0.9677632 , -0.00342326,
 

In [39]:
new_train_data = []
for el in train_data:
    if el['id'] in src_photos:
        new_train_data.append(el)

In [41]:
len(new_train_data)

8483

In [42]:
new_val_data = []
for el in val_data:
    if el['id'] in src_photos:
        new_val_data.append(el)

In [43]:
len(new_val_data)

500

In [53]:
nlp = spacy.load('en_core_web_sm')
for el in new_train_data:
    image_desc = str(image_data[image_data['src'] == el['id']]['best_caption'].values[0]) + ' ' + str(image_data[image_data['src'] == el['id']]['best_tags'].values[0])
    doc = nlp(el['text'] + ' ' + image_desc)
    el['new_vector'] = doc.vector

In [54]:
nlp = spacy.load('en_core_web_sm')
for el in new_val_data:
    image_desc = str(image_data[image_data['src'] == el['id']]['best_caption'].values[0]) + ' ' + str(image_data[image_data['src'] == el['id']]['best_tags'].values[0])
    doc = nlp(el['text'] + ' ' + image_desc)
    el['new_vector'] = doc.vector

In [60]:
new_train_labels = []
for el in new_train_data:
    new_train_labels.append(el['label'])

new_train_vectors = []
for el in new_train_data:
    new_train_vectors.append(el['new_vector'])

In [56]:
new_val_vectors = []
for el in new_val_data:
    new_val_vectors.append(el['new_vector'])

In [63]:
new_lr = LogisticRegression()
new_lr.fit(new_train_vectors, new_train_labels)
print('Train Accuracy: {}'.format(new_lr.score(new_train_vectors, new_train_labels)))
print('Validation Accuracy: {}'.format(new_lr.score(new_val_vectors, val_labels)))



Train Accuracy: 0.6441117529175999
Validation Accuracy: 0.502
