In [1]:
from datetime import datetime, timedelta
from newspaper import Article
from pathlib import Path
from typing import List
from zipfile import ZipFile
import concurrent.futures

from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
)
from transformers import RealmTokenizer, RealmEmbedder
from tqdm.notebook import tqdm
import datatable as dt
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch

# Load data

In [2]:
event_meta_df = pd.read_csv("20220201.csv")

In [3]:
event_meta_df

Unnamed: 0,GLOBALEVENTID,SQLDATE,MonthYear,Year,FractionDate,Actor1Code,Actor1Name,Actor1CountryCode,Actor1KnownGroupCode,Actor1EthnicCode,...,ActionGeo_Type,ActionGeo_FullName,ActionGeo_CountryCode,ActionGeo_ADM1Code,ActionGeo_ADM2Code,ActionGeo_Lat,ActionGeo_Long,ActionGeo_FeatureID,DATEADDED,SOURCEURL
0,1025959174,20210201,202102,2021,2021.0849,CVL,TRAVELLER,,,,...,4,"Poblacion, Camarines Sur, Philippines",RP,RP16,24242,13.80000,123.01700,-2446409,20220201000000,https://www.ttrweekly.com/site/2022/02/berjaya...
1,1025959175,20210201,202102,2021,2021.0849,CVL,IMMIGRANT,,,,...,1,Afghanistan,AF,AF,,33.00000,66.00000,AF,20220201000000,https://www.wksu.org/health-science/2022-01-31...
2,1025959176,20210201,202102,2021,2021.0849,GOV,STATE OFFICIAL,,,,...,3,"Cape Cod Bay, Massachusetts, United States",US,USMA,,42.03340,-70.41610,617926,20220201000000,https://www.msn.com/en-us/news/us/despite-thre...
3,1025959177,20210201,202102,2021,2021.0849,GOV,STATE OFFICIAL,,,,...,2,"Massachusetts, United States",US,USMA,,42.23730,-71.53140,MA,20220201000000,https://www.msn.com/en-us/news/us/despite-thre...
4,1025959178,20210201,202102,2021,2021.0849,LEG,CONGRESS,,,,...,2,"Illinois, United States",US,USIL,,40.33630,-89.00220,IL,20220201000000,https://www.cities929.com/2022/01/31/restauran...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
122309,1026147681,20220201,202202,2022,2022.0849,idg,INDIGENOUS,,,idg,...,4,"Fort Frances, Ontario, Canada",CA,CA08,30027,48.61670,-93.41670,-564798,20220201234500,https://fftimes.com/news/canada/indigenous-del...
122310,1026147682,20220201,202202,2022,2022.0849,idg,INDIGENOUS,,,idg,...,4,"Fort Frances, Ontario, Canada",CA,CA08,30027,48.61670,-93.41670,-564798,20220201234500,https://fftimes.com/news/canada/indigenous-del...
122311,1026147683,20220201,202202,2022,2022.0849,idg,INDIGENOUS,,,idg,...,4,"Fort Frances, Ontario, Canada",CA,CA08,30027,48.61670,-93.41670,-564798,20220201234500,https://fftimes.com/news/canada/indigenous-del...
122312,1026147684,20220201,202202,2022,2022.0849,idg,INDIGENOUS,,,idg,...,4,"Fort Frances, Ontario, Canada",CA,CA08,30027,48.61670,-93.41670,-564798,20220201234500,https://fftimes.com/news/canada/indigenous-del...


# REALM

In [4]:
class NewsEmbedder:
    def __init__(self):
        self.tokenizer = RealmTokenizer.from_pretrained("google/realm-cc-news-pretrained-embedder")
        self.model = RealmEmbedder.from_pretrained("google/realm-cc-news-pretrained-embedder")
        self.model.eval()
        
    def inference(self, title, text, with_title=True):
        with torch.no_grad():
            if with_title:
                inputs = self.tokenizer(title, text, return_tensors="pt", max_length=512, truncation=True)
            else:
                inputs = self.tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
            outputs = self.model(**inputs)
            scores = outputs.projected_score
            # scores = F.normalize(scores)
        
        return scores
    
    def embedding_from_url(self, url, with_title=True):
        article = Article(url)
        try:
            article.download()
            article.parse()
            return self.inference(article.title, article.text, with_title)
        except:
            return None
        
    def embedding_from_text(self, title=None, text=None, with_title=True):
        try:
            return self.inference(title, text, with_title)
        except:
            return None

In [5]:
ne = NewsEmbedder()

# Write to Milvus

In [7]:
class MilvusClient:
    def __init__(self, reset_ds=False):
        connections.connect("default", host="127.0.0.1", port="19530")
        if not utility.has_collection("realm_news"):
            self.create_table()
        elif reset_ds:
            utility.drop_collection("realm_news")
            self.create_table()
        self.collection = Collection("realm_news")
        
    def create_table(self):
        fields = [
            FieldSchema(name="global_event_id", dtype=DataType.INT64, is_primary=True, auto_id=False),
            FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=128)
        ]
        schema = CollectionSchema(fields, "realm_news stores the REALM embeddings of GDELT news")
        realm_news = Collection("realm_news", schema, consistency_level="Strong")
    
    def insert(self, global_event_id, embedding):
        if isinstance(embedding, torch.Tensor):
            embedding = embedding.detach().cpu().numpy()
        embedding = embedding.flatten()
        insert_result = self.collection.insert([
            [global_event_id], [embedding]
        ])
        return insert_result

In [8]:
milvus_client = MilvusClient(reset_ds=True)

# Multi-process

In [9]:
def download_news(url):
    article = Article(url)
    try:
        article.download()
        article.parse()
        return article.title, article.text
    except:
        return None, None

In [12]:
def run(ne, milvus_client, event_meta_df, with_title=False):
    ids = event_meta_df.GLOBALEVENTID.values
    urls = event_meta_df.SOURCEURL.values
    with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
        for global_event_id, (title, text) in tqdm(zip(ids, executor.map(download_news, urls)), total=len(ids)):
            embedding = ne.embedding_from_text(self, title, text, with_title=with_title)
            if embedding is not None:
                milvus_client.insert(global_event_id, embedding)

In [13]:
run(ne, milvus_client, event_meta_df, with_title=False)

Process SpawnProcess-11:
Traceback (most recent call last):
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/concurrent/futures/process.py", line 233, in _process_worker
    call_item = call_queue.get(block=True)
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/multiprocessing/queues.py", line 116, in get
    return _ForkingPickler.loads(res)
AttributeError: Can't get attribute 'download_news' on <module '__main__' (built-in)>
Process SpawnProcess-12:
Traceback (most recent call last):
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/Users/maoxin/opt/anaconda3/env

  0%|          | 0/122314 [00:00<?, ?it/s]

Exception in thread QueueManagerThread:
Traceback (most recent call last):
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/concurrent/futures/process.py", line 394, in _queue_management_worker
    work_item.future.set_exception(bpe)
  File "/Users/maoxin/opt/anaconda3/envs/n2_one_click/lib/python3.8/concurrent/futures/_base.py", line 547, in set_exception
    raise InvalidStateError('{}: {!r}'.format(self._state, self))
concurrent.futures._base.InvalidStateError: CANCELLED: <Future at 0x7f80c7caf430 state=cancelled>


BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

In [None]:
# for idx, row in tqdm(event_meta_df.iterrows(), total=len(event_meta_df)):
#     global_event_id = row.GLOBALEVENTID
#     url = row.SOURCEURL
#     embedding = ne.embedding_from_url(url, with_title=False)
#     if embedding is not None:
#         milvus_client.insert(global_event_id, embedding)