In [1]:
import torch
import torchvision.models as models
from PIL import Image
from torchvision import transforms
import json
import os
import numpy as np

In [4]:
# Input data root (holds article.json and all_images/) and the output directory for results.
BASE_DIR, WORKING_DIR = './newsdata', './newsmodels'

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Class index used to turn predicted indices into names: keys are str(index),
# and element [1] of each value is the class name (presumably [wnid, name] — confirm).
with open('imagenet_class_index.json') as fp:
    labels = json.load(fp)

# .to(device): keep the weights on the same device the inputs are sent to,
# otherwise a CUDA machine raises a device-mismatch error.
# .eval(): disable dropout / training-only behaviour so inference is deterministic.
vgg16 = models.vgg16(weights='IMAGENET1K_V1').to(device).eval()

# Articles keyed by article id; generate_ner_dicts reads each entry's 'ner' mapping.
with open(BASE_DIR + '/article.json') as f:
    articles = json.load(f)

In [8]:
def classify_one_image(img_path, model, class_index=None, top_k=3, min_prob=0.1):
    """Return the class names predicted for a single image file.

    The image is converted to RGB, resized to 256, center-cropped to 224 and
    normalized with the ImageNet mean/std before being passed through ``model``.
    Among the ``top_k`` most probable classes, every class whose softmax
    probability exceeds ``min_prob`` is returned. If none does, the single most
    probable class is returned instead, so the result is never empty.

    Args:
        img_path: Path to the image file.
        model: Classifier producing one logit per class index (assumed to be a
            torch ``nn.Module`` with at least one parameter — confirm).
        class_index: Mapping ``str(class_idx) -> sequence`` whose element [1] is
            the class name. Defaults to the module-level ``labels``.
        top_k: Number of highest-probability classes to consider (>= 1).
        min_prob: Probability a class must exceed to be kept.

    Returns:
        List of class names, ordered from lowest to highest probability.

    Raises:
        ValueError: if ``top_k`` < 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be >= 1")
    if class_index is None:
        class_index = labels

    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # Grayscale/RGBA/palette images would give a channel count that does not match
    # the 3-channel Normalize above; `with` closes the underlying file handle.
    with Image.open(img_path) as img:
        rgb_img = img.convert('RGB')

    # Send the input to whatever device the model lives on, so a model left on the
    # CPU still works on a CUDA machine.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else device
    image = transform(rgb_img).unsqueeze(0).to(model_device)

    # eval() disables dropout so the prediction is deterministic without reseeding
    # the global RNG; the caller's train/eval mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits = model(image)
    finally:
        model.train(was_training)

    # .cpu() is required before .numpy() when the model runs on a GPU.
    prob = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()
    top_idx = prob.argsort()[-top_k:]  # ascending probability order
    class_labels = [class_index[str(i)][1] for i in top_idx if prob[i] > min_prob]
    if not class_labels:
        class_labels.append(class_index[str(top_idx[-1])][1])
    return class_labels


In [9]:
# Entity-type buckets created for every image class, in this order. These match
# spaCy's OntoNotes label set — presumably the tagger that produced the 'ner' data.
NER_LABELS = ('PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT',
              'WORK_OF_ART', 'LAW', 'LANGUAGE', 'DATE', 'TIME', 'PERCENT',
              'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL')


def generate_ner_dicts(images, articles, model, indices, image_dir=None, classifier=None):
    """Group the named entities of each image's article by predicted image class.

    For each position ``i`` in ``indices``, the first 24 characters of
    ``images[i]`` are used as the article id. Images whose id is not in
    ``articles`` are skipped. Otherwise the image is classified, and every
    entity in that article's ``'ner'`` mapping is appended to its type's bucket
    under each predicted class.

    Args:
        images: Sequence of image file names.
        articles: Mapping article id -> dict with a ``'ner'`` mapping of
            entity text -> entity type.
        model: Forwarded unchanged to ``classifier``.
        indices: Positions into ``images`` to process.
        image_dir: Prefix prepended to each file name. Defaults to
            ``BASE_DIR + '/all_images/'``.
        classifier: Callable ``(path, model) -> list of class names``.
            Defaults to ``classify_one_image``.

    Returns:
        dict mapping class name -> dict mapping entity type -> list of entity
        strings. Each class dict has a (possibly empty) list for every type in
        ``NER_LABELS``. Types outside that set get an extra bucket instead of
        raising ``KeyError``. An entity is repeated once per matching image.
    """
    if image_dir is None:
        image_dir = BASE_DIR + '/all_images/'
    if classifier is None:
        classifier = classify_one_image

    ner_dict = {}
    for i in indices:
        img_name = images[i]
        # NOTE(review): assumes file names start with a 24-char article id
        # (looks like a MongoDB ObjectId) — confirm.
        article_id = img_name[:24]
        if article_id not in articles:
            continue
        ners = articles[article_id]['ner']
        for class_name in classifier(image_dir + img_name, model):
            if class_name not in ner_dict:
                ner_dict[class_name] = {label: [] for label in NER_LABELS}
            buckets = ner_dict[class_name]
            for entity, entity_type in ners.items():
                buckets.setdefault(entity_type, []).append(entity)
    return ner_dict

In [13]:
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the
# seeded sample below selects the same images on every machine.
all_images = sorted(os.listdir(BASE_DIR + "/all_images/"))

N_SAMPLES = 10000  # images to classify; capped at the number actually available
np.random.seed(123)
# replace=False raises if size exceeds the population, so cap the sample size.
indices = np.random.choice(len(all_images), size=min(N_SAMPLES, len(all_images)), replace=False)
ner_dict = generate_ner_dicts(all_images, articles, vgg16, indices)

In [391]:
# Create the output directory if needed so a long classification run is not lost
# to a FileNotFoundError at save time.
os.makedirs(WORKING_DIR, exist_ok=True)
with open(WORKING_DIR+'/ner_for_classes.json', 'w') as fp:
    json.dump(ner_dict, fp)