In [4]:
import json

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report, label_ranking_average_precision_score, dcg_score

from fos.settings import ASSETS_DIR

meta = pd.read_pickle(ASSETS_DIR / "fields/fos.pkl.gz")
meta.index = meta.index.astype(int)

true = pd.read_json("autumn_cs_annotations.jsonl", lines=True)
true.set_index("merged_id", inplace=True)
true.head()

Unnamed: 0_level_0,title,abstract,text,options,meta,html,_input_hash,_task_hash,_view_id,accept,config,answer,_timestamp
merged_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
carticle_0165770296,Networked measurement and control system based...,The application of networked measurement and c...,networked measurement and control system based...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0165770296'},Networked measurement and control system based...,-943446144,529049400,blocks,[Networks],{'choice_style': 'multiple'},accept,1647991522
carticle_0065264564,Hybrid Cloud Rendering System for Massive CAD ...,The recent advances in cloud services enable a...,hybrid cloud rendering system for massive cad ...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0065264564'},Hybrid Cloud Rendering System for Massive CAD ...,-180950055,1850390516,blocks,[High Performace Computing],{'choice_style': 'multiple'},accept,1647991549
carticle_0085809617,To Shorten Total Trip Time for Passengers by R...,,to shorten total trip time for passengers by r...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0085809617'},To Shorten Total Trip Time for Passengers by R...,410318024,1731964067,blocks,[AI/ML],{'choice_style': 'multiple'},accept,1647991567
carticle_0188819773,Connectivity of Wireless Sensor Complex Networks,,connectivity of wireless sensor complex networks,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0188819773'},Connectivity of Wireless Sensor Complex Networks,1677727149,1084729738,blocks,[Networks],{'choice_style': 'multiple'},accept,1647991571
carticle_0021247370,Proceedings of the Second International Worksh...,,proceedings of the second international worksh...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0021247370'},Proceedings of the Second International Worksh...,224564397,1929177218,blocks,[IR & Knowledge management],{'choice_style': 'multiple'},accept,1647991622


In [5]:
true['answer'].value_counts()

accept    220
reject     31
ignore      2
Name: answer, dtype: int64

In [6]:
true = true.query("answer == 'accept'").copy()

In [7]:
labels = [x['id'] for x in true.iloc[0]['options']]
labels

['Algorithms',
 'AI/ML',
 'Computer engineering',
 'Graphics',
 'Networks',
 'Security & Privacy',
 'IR & Knowledge management',
 'Software',
 'Theoretical computer science',
 'High Performace Computing',
 'Human–computer interaction']

In [8]:
y_true = {}
for merged_id, row in true.iterrows():
    y_true[merged_id] = {label: label in row['accept'] for label in labels}

In [9]:
from fos.model import FieldModel

field_model = FieldModel("en")

In [10]:
from fos.util import preprocess_text
true['text'] = true.apply(preprocess_text, axis=1)
true['scores'] = true['text'].\
    apply(field_model.embed).\
    apply(field_model.score).\
    apply(lambda x: x.average())
true.head()

Unnamed: 0_level_0,title,abstract,text,options,meta,html,_input_hash,_task_hash,_view_id,accept,config,answer,_timestamp,scores
merged_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
carticle_0165770296,Networked measurement and control system based...,The application of networked measurement and c...,networked measurement and control system based...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0165770296'},Networked measurement and control system based...,-943446144,529049400,blocks,[Networks],{'choice_style': 'multiple'},accept,1647991522,"[0.6252872995068798, 0.6508370191531042, 0.624..."
carticle_0065264564,Hybrid Cloud Rendering System for Massive CAD ...,The recent advances in cloud services enable a...,hybrid cloud rendering system for massive cad ...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0065264564'},Hybrid Cloud Rendering System for Massive CAD ...,-180950055,1850390516,blocks,[High Performace Computing],{'choice_style': 'multiple'},accept,1647991549,"[0.6416180819090423, 0.6759123919423775, 0.652..."
carticle_0085809617,To Shorten Total Trip Time for Passengers by R...,,to shorten total trip time for passengers by r...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0085809617'},To Shorten Total Trip Time for Passengers by R...,410318024,1731964067,blocks,[AI/ML],{'choice_style': 'multiple'},accept,1647991567,"[0.6006255887318293, 0.6235317604846059, 0.616..."
carticle_0188819773,Connectivity of Wireless Sensor Complex Networks,,connectivity of wireless sensor complex networks,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0188819773'},Connectivity of Wireless Sensor Complex Networks,1677727149,1084729738,blocks,[Networks],{'choice_style': 'multiple'},accept,1647991571,"[0.2208237385978619, 0.2508910754504358, 0.224..."
carticle_0021247370,Proceedings of the Second International Worksh...,,proceedings of the second international worksh...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0021247370'},Proceedings of the Second International Worksh...,224564397,1929177218,blocks,[IR & Knowledge management],{'choice_style': 'multiple'},accept,1647991622,"[0.2863368251756025, 0.3098823845821171, 0.274..."


In [11]:
true['field_scores'] = true['scores'].apply(
    lambda x: {
        meta.loc[int(field_id), 'display_name']: score
        for field_id, score in zip(field_model.index, x)
    })
true.head()

Unnamed: 0_level_0,title,abstract,text,options,meta,html,_input_hash,_task_hash,_view_id,accept,config,answer,_timestamp,scores,field_scores
merged_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
carticle_0165770296,Networked measurement and control system based...,The application of networked measurement and c...,networked measurement and control system based...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0165770296'},Networked measurement and control system based...,-943446144,529049400,blocks,[Networks],{'choice_style': 'multiple'},accept,1647991522,"[0.6252872995068798, 0.6508370191531042, 0.624...",{'Industrial organization': 0.6252872995068798...
carticle_0065264564,Hybrid Cloud Rendering System for Massive CAD ...,The recent advances in cloud services enable a...,hybrid cloud rendering system for massive cad ...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0065264564'},Hybrid Cloud Rendering System for Massive CAD ...,-180950055,1850390516,blocks,[High Performace Computing],{'choice_style': 'multiple'},accept,1647991549,"[0.6416180819090423, 0.6759123919423775, 0.652...",{'Industrial organization': 0.6416180819090423...
carticle_0085809617,To Shorten Total Trip Time for Passengers by R...,,to shorten total trip time for passengers by r...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0085809617'},To Shorten Total Trip Time for Passengers by R...,410318024,1731964067,blocks,[AI/ML],{'choice_style': 'multiple'},accept,1647991567,"[0.6006255887318293, 0.6235317604846059, 0.616...",{'Industrial organization': 0.6006255887318293...
carticle_0188819773,Connectivity of Wireless Sensor Complex Networks,,connectivity of wireless sensor complex networks,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0188819773'},Connectivity of Wireless Sensor Complex Networks,1677727149,1084729738,blocks,[Networks],{'choice_style': 'multiple'},accept,1647991571,"[0.2208237385978619, 0.2508910754504358, 0.224...",{'Industrial organization': 0.2208237385978619...
carticle_0021247370,Proceedings of the Second International Worksh...,,proceedings of the second international worksh...,"[{'id': 'Algorithms', 'text': 'Algorithms'}, {...",{'merged_id': 'carticle_0021247370'},Proceedings of the Second International Worksh...,224564397,1929177218,blocks,[IR & Knowledge management],{'choice_style': 'multiple'},accept,1647991622,"[0.2863368251756025, 0.3098823845821171, 0.274...",{'Industrial organization': 0.2863368251756025...


In [12]:
cs_fields = """\
Library science
Algorithm
Pattern recognition
Computer vision
Computer network
Knowledge management
World Wide Web
Data science
Human–computer interaction
Data mining
Software engineering
Computer security
Real-time computing
Information retrieval
Machine learning
Distributed computing
Multimedia
Internet privacy
Computer hardware
Natural language processing
Theoretical computer science
Telecommunications
Programming language
Simulation
Speech recognition
Embedded system
Operating system
Artificial intelligence
Computer graphics (images)
Database
Parallel computing
Computer engineering
Computational science
Computer architecture""".split('\n')
cs_fields

['Library science',
 'Algorithm',
 'Pattern recognition',
 'Computer vision',
 'Computer network',
 'Knowledge management',
 'World Wide Web',
 'Data science',
 'Human–computer interaction',
 'Data mining',
 'Software engineering',
 'Computer security',
 'Real-time computing',
 'Information retrieval',
 'Machine learning',
 'Distributed computing',
 'Multimedia',
 'Internet privacy',
 'Computer hardware',
 'Natural language processing',
 'Theoretical computer science',
 'Telecommunications',
 'Programming language',
 'Simulation',
 'Speech recognition',
 'Embedded system',
 'Operating system',
 'Artificial intelligence',
 'Computer graphics (images)',
 'Database',
 'Parallel computing',
 'Computer engineering',
 'Computational science',
 'Computer architecture']

In [13]:
true['field_scores'] = true['field_scores'].\
    apply(lambda x: {k: v for k, v in x.items() if k in cs_fields})

In [14]:
true['field_scores'] = true['field_scores'].\
    apply(lambda x: sorted(x.items(), key=lambda x: x[1], reverse=True))

In [15]:
true['top_pred'] = true['field_scores'].apply(lambda x: x[0][0])
true['top_pred']

merged_id
carticle_0165770296              Computer network
carticle_0065264564    Computer graphics (images)
carticle_0085809617                     Algorithm
carticle_0188819773              Computer network
carticle_0021247370                  Data science
                                  ...            
carticle_0171028584               Embedded system
carticle_0023217678    Human–computer interaction
carticle_0163850398              Operating system
carticle_0000293880              Computer network
carticle_0189168626    Human–computer interaction
Name: top_pred, Length: 220, dtype: object

In [24]:
agg = {
    "AI/ML": "Artificial intelligence	Data mining	Pattern recognition	Machine learning	Data science".split('\t'),
    "Computer engineering": "Computer engineering	Computer architecture	Computer hardware".split('\t'),
    "Networks": "Computer network	World Wide Web".split('\t'),
    'IR & Knowledge management': "Knowledge management	Database	Information retrieval	Library science".split('\t'),
    'Software': "Software engineering	Programming language	Operating system	Embedded system".split('\t'),
    'Security & Privacy': "Computer security	Internet privacy".split('\t'),
    'High Performace Computing': "Real-time computing	Distributed computing	Parallel computing".split('\t'),
}
y_top_pred = {}
for i, row in true.iterrows():
    y_top_pred[i] = row['top_pred']
    for cat, subcats in agg.items():
        if row['top_pred'] in subcats:
            y_top_pred[i] = cat
            break

y_pred = {}
for i, row['field_scores'] in true.iterrows():

    result =
    for cat, subcats in agg.items():

    row['field_scores']
    break

row['field_scores']

# y_pred = {}
# for i, row in true.iterrows():


[('Computer network', 0.7003316057060335),
 ('Distributed computing', 0.6647598722399581),
 ('Operating system', 0.6645782615319036),
 ('Embedded system', 0.6608830854579071),
 ('Simulation', 0.6608791964571538),
 ('Telecommunications', 0.6559305557162505),
 ('Computer security', 0.6539448830040847),
 ('Theoretical computer science', 0.6526612826182231),
 ('Computer architecture', 0.6518544687688257),
 ('Computational science', 0.6508370191531042),
 ('Real-time computing', 0.647983174449315),
 ('Computer vision', 0.6479559182596587),
 ('Computer engineering', 0.6460526312980773),
 ('Parallel computing', 0.6441418055653203),
 ('Database', 0.6437524455930167),
 ('Multimedia', 0.6430017792477815),
 ('Human–computer interaction', 0.6418413537056039),
 ('Speech recognition', 0.6408775341946321),
 ('Machine learning', 0.6392496141537373),
 ('Computer graphics (images)', 0.6390952197815649),
 ('Information retrieval', 0.638900683007743),
 ('Computer hardware', 0.6370022895252111),
 ('Artifici

In [18]:
y_true = dict(sorted(y_true.items()))
y_pred = dict(sorted(y_top_pred.items()))

In [19]:
def true_to_array(true):
    arrays = []
    for row in true.values():
        sorted_row = dict(sorted(row.items()))
        arrays.append(np.array(list(sorted_row.values())).astype(int))
    return np.array(arrays)

def top_k_accuracy(true_values, pred_values, top_k=5, proportion=True):
    true_values = list(true_values.values())
    pred_values = list(pred_values.values())
    assert len(true_values) == len(pred_values)
    n = 0
    correct = {i: 0 for i in range(top_k)}
    for true, pred in zip(true_values, pred_values):
        true_fields = {k for k, v in true.items() if v}
        if not len(true_fields):
            # If we've restricted the labels for evalutaion to STEM fields, discard the docs with no positive labels
            continue
        # Ensure that fields are in order of score descending
        pred = [k for k, v in sorted(pred.items(), key=lambda x: x[1], reverse=True)]
        # This is a hack for having annotated multilabel -- if there are 2 true labels, we're taking them
        # interchangeably. We look at the top-scoring label, and check whether it's among the true labels. If so, we
        # count that as accurate @ 1. If not, we check the 2nd-scoring label. If among the true labels, that's accurate
        # @ 2
        for i in range(top_k):
            if pred[i] in true_fields:
                correct[i] += 1
                break
        n += 1
    # The above gave us counts for how often the top-scoring field is correct; if not, how often the 2nd-ranked field
    # label is correct ... and so on to the 5th-ranked field. Taking a cumulative sum over these counts gives us how
    # often a correct label (*the* true label if we manually chose a single label, or one of the correct labels if we
    # chose more than one ... in practice it's mostly 1 label; some 2 labels; rarely 3 labels) is the top label; among
    # the top 2 labels; among the top 3 labels ...
    for i in reversed(range(top_k)):
        for j in range(i):
            correct[i] += correct[j]

    if proportion:
        return {k: x / n for k, x in correct.items()}
    return correct

def to_arrays(doc_scores):
    # Structure scores for sklearn.metrics
    arrays = []
    for doc_id, scores in doc_scores.items():
        arrays.append([v for k, v in sorted(scores.items())])
    return arrays

def pred_to_array(pred, ranks=False):
    arrays = []
    for row in pred.values():
        sorted_row = dict(sorted(row.items()))
        row_array = np.array(list(sorted_row.values()))
        if ranks:
            arrays.append(np.array(list(reversed(row_array.argsort()))) + 1)
        else:
            arrays.append(row_array)
    return np.array(arrays)

def dichotomize(scores, top_k=1):
    """Dichotomize an array of continuous field scores such that the top_k fields are true and the rest false."""
    output = {}
    ordered_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    for i, (label, score) in enumerate(ordered_scores, 1):
        output[label] = i <= top_k
    return output




In [20]:
# top_k_accuracy(y_true, y_pred)
y_true

{'carticle_0000255204': {'Algorithms': False,
  'AI/ML': False,
  'Computer engineering': True,
  'Graphics': False,
  'Networks': False,
  'Security & Privacy': False,
  'IR & Knowledge management': False,
  'Software': False,
  'Theoretical computer science': False,
  'High Performace Computing': False,
  'Human–computer interaction': False},
 'carticle_0000293880': {'Algorithms': False,
  'AI/ML': False,
  'Computer engineering': False,
  'Graphics': False,
  'Networks': True,
  'Security & Privacy': True,
  'IR & Knowledge management': False,
  'Software': False,
  'Theoretical computer science': False,
  'High Performace Computing': False,
  'Human–computer interaction': False},
 'carticle_0001990826': {'Algorithms': False,
  'AI/ML': False,
  'Computer engineering': False,
  'Graphics': False,
  'Networks': True,
  'Security & Privacy': False,
  'IR & Knowledge management': False,
  'Software': False,
  'Theoretical computer science': False,
  'High Performace Computing': False,


In [21]:
y_pred

{'carticle_0000255204': 'Simulation',
 'carticle_0000293880': 'Networks',
 'carticle_0001990826': 'Networks',
 'carticle_0002436541': 'Software',
 'carticle_0002565774': 'Theoretical computer science',
 'carticle_0003044952': 'Simulation',
 'carticle_0004118963': 'Computational science',
 'carticle_0004834905': 'Theoretical computer science',
 'carticle_0006211110': 'AI/ML',
 'carticle_0007064058': 'AI/ML',
 'carticle_0009278952': 'Algorithm',
 'carticle_0009850090': 'Computer engineering',
 'carticle_0011291905': 'Computer vision',
 'carticle_0012543572': 'IR & Knowledge management',
 'carticle_0013350185': 'Networks',
 'carticle_0013603907': 'Simulation',
 'carticle_0014121826': 'Simulation',
 'carticle_0017394439': 'Telecommunications',
 'carticle_0017641391': 'Networks',
 'carticle_0018137783': 'High Performace Computing',
 'carticle_0018187318': 'Theoretical computer science',
 'carticle_0019307725': 'Algorithm',
 'carticle_0020037141': 'Computer vision',
 'carticle_0021247370': '

In [22]:
y_pred = {merged_id: {k: v for k, v in pred.items() if k in labels} for merged_id, pred in y_pred.items()}
# y_pred_array = list({k: v for k, v in pred.items()} for pred in y_pred.values())
# y_pred_array[0]
y_pred

AttributeError: 'str' object has no attribute 'items'

In [None]:
binary_pred = {k: dichotomize(v, top_k=1) for k, v in y_pred.items()}
# binary_pred
# classification_report(to_arrays(y_true), to_arrays(binary_pred))
# y_pred
# to_arrays(binary_pred)
# classification_report(list(y_true.values()), to_arrays(binary_pred))
# len(to_arrays(binary_pred)[0])
# list(y_true.values())

In [None]:
def top_k_accuracy(true_values, pred_values, top_k=5, proportion=True):
    true_values = list(true_values.values())
    pred_values = list(pred_values.values())
    assert len(true_values) == len(pred_values)
    n = 0
    correct = {i: 0 for i in range(top_k)}
    for true, pred in zip(true_values, pred_values):
        true_fields = {k for k, v in true.items() if v}
        if not len(true_fields):
            # If we've restricted the labels for evalutaion to STEM fields, discard the docs with no positive labels
            continue
        # Ensure that fields are in order of score descending
        pred = [k for k, v in sorted(pred.items(), key=lambda x: x[1], reverse=True)]
        # This is a hack for having annotated multilabel -- if there are 2 true labels, we're taking them
        # interchangeably. We look at the top-scoring label, and check whether it's among the true labels. If so, we
        # count that as accurate @ 1. If not, we check the 2nd-scoring label. If among the true labels, that's accurate
        # @ 2
        for i in range(top_k):
            if pred[i] in true_fields:
                correct[i] += 1
                break
        n += 1
    # The above gave us counts for how often the top-scoring field is correct; if not, how often the 2nd-ranked field
    # label is correct ... and so on to the 5th-ranked field. Taking a cumulative sum over these counts gives us how
    # often a correct label (*the* true label if we manually chose a single label, or one of the correct labels if we
    # chose more than one ... in practice it's mostly 1 label; some 2 labels; rarely 3 labels) is the top label; among
    # the top 2 labels; among the top 3 labels ...
    for i in reversed(range(top_k)):
        for j in range(i):
            correct[i] += correct[j]

    if proportion:
        return {k: x / n for k, x in correct.items()}
    return correct


In [None]:
top_k_accuracy(y_true, y_pred)