<a href="https://colab.research.google.com/github/nicolaiberk/_rrpviol_med/blob/master/_sc/_mig_clsfr/BERT_estimates.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive; the model directory and the output CSV used below live under drive/MyDrive.
from google.colab import drive

DRIVE_MOUNTPOINT = '/content/drive'
drive.mount(DRIVE_MOUNTPOINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install transformers



In [3]:
import csv
import os
import subprocess
from datetime import datetime
from itertools import islice

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

In [4]:
# Download and unpack the full set of news articles, unless the articles directory already has files.
ARTICLES_DIR = 'newspapers'
# dl=1 makes Dropbox serve the shared folder as a zip download instead of an HTML preview page.
ARTICLES_URL = 'https://www.dropbox.com/sh/bbf0655w9931xbk/AADQNpkipxBENPk4Gp5j1UaDa?dl=1'

if not os.path.isdir(ARTICLES_DIR) or not os.listdir(ARTICLES_DIR):
    # check=True: stop here if the download fails instead of unzipping a missing/garbage file.
    subprocess.run(['wget', '-O', 'articles.zip', ARTICLES_URL], check=True)
    # -o: overwrite without an interactive prompt. unzip can exit non-zero on mere warnings,
    # so the extraction is verified below rather than with check=True.
    subprocess.run(['unzip', '-o', 'articles.zip', '-d', ARTICLES_DIR])
    os.remove('articles.zip')
    if not os.path.isdir(ARTICLES_DIR) or not os.listdir(ARTICLES_DIR):
        raise RuntimeError(f'No articles were extracted to {ARTICLES_DIR!r}')

In [5]:
# Fine-tuned migration classifier (stored on Google Drive) plus the base German BERT tokenizer,
# wrapped in a text-classification pipeline.
model_name = "drive/MyDrive/Bild/mig_clsfr_BERT_torch"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained('bert-base-german-cased', model_max_length = 512)
# Pipeline convention: device=0 is the first GPU, device=-1 is the CPU.
# Fall back to the CPU so the notebook still runs on a runtime without a GPU.
DEVICE = 0 if torch.cuda.is_available() else -1
classifier = pipeline('text-classification', model=model, tokenizer=tokenizer, device=DEVICE)

Sanity check: the model should label the two migration-related sentences `LABEL_1` and the unrelated one `LABEL_0`:

In [6]:
# Clearly migration-related sentence: expect LABEL_1 (the positive class of the migration classifier).
migration_example = classifier('Flüchtlinge haben es oft schwer, sich in Deutschland zu integrieren. Die Bundesregierung will nun mit weiteren Abschiebungen in Kriegsgebiete das Leben der Asylsuchenden noch schwerer machen.')
assert migration_example[0]['label'] == 'LABEL_1', migration_example
migration_example

[{'label': 'LABEL_1', 'score': 0.9998746514320374}]

In [7]:
# Factual news item about suspending deportations to Afghanistan: expect LABEL_1 (positive class).
deportation_example = classifier('Deutschland schiebt vorerst keine Menschen mehr nach Afghanistan ab. "Der Bundesinnenminister hat aufgrund der aktuellen Entwicklungen der Sicherheitslage entschieden, Abschiebungen nach Afghanistan zunächst auszusetzen", sagte ein Sprecher des Bundesinnenministeriums am Mittwoch der Deutschen Presse-Agentur.')
assert deportation_example[0]['label'] == 'LABEL_1', deportation_example
deportation_example

[{'label': 'LABEL_1', 'score': 0.9998152852058411}]

In [8]:
# Unrelated foreign-policy item (Merkel/Putin): expect LABEL_0 (negative class).
unrelated_example = classifier('19-mal reiste Merkel nach Russland. Die Bundeskanzlerin hält den russischen Präsidenten für stets latent beleidigt, Putin schreckt nicht davor zurück, Merkel gelegentlich zu beleidigen.')
assert unrelated_example[0]['label'] == 'LABEL_0', unrelated_example
unrelated_example

[{'label': 'LABEL_0', 'score': 0.9999224543571472}]

Looking good — let's run the classifier on all 2.4M newspaper articles:

In [9]:
# Classify every article in every newspaper CSV in batches and append one row per article
# (title, link, date, topic, est, label) to a single output CSV on Google Drive.
batch_size = 100
INPUT_DIR = 'newspapers'
OUTPUT_PATH = 'drive/MyDrive/Bild/BERT_estimates.csv'
OUTPUT_HEADER = ['title', 'link', 'date', 'topic', 'est', 'label']
# Columns that must appear in the header row of every newspaper CSV.
REQUIRED_COLUMNS = ['title', 'url', 'date', 'text', 'topic']


def find_columns(header):
    """Map each required column name to its (first) index in the CSV header row.

    Raises ValueError if a column is missing; np.argmax over an all-False list would
    silently return 0 and write the wrong field instead.
    """
    missing = [name for name in REQUIRED_COLUMNS if name not in header]
    if missing:
        raise ValueError(f'Missing columns {missing} in header {header}')
    return {name: header.index(name) for name in REQUIRED_COLUMNS}


def classify_and_write(writer, rows, cols):
    """Classify the article texts of `rows` in one pipeline call and write one output row each.

    `est` is the pipeline score mapped onto LABEL_1 (1 - score when the top label is not
    LABEL_1); `label` is True when the predicted label is LABEL_1. Empty `rows` is a no-op.
    """
    if not rows:
        return
    texts = [row[cols['text']] for row in rows]
    predictions = classifier(texts, padding='max_length', truncation=True, return_tensors='pt')
    for row, pred in zip(rows, predictions):
        is_label_1 = pred['label'] == 'LABEL_1'
        est = pred['score'] if is_label_1 else 1 - pred['score']
        writer.writerow([row[cols['title']], row[cols['url']], row[cols['date']],
                         row[cols['topic']], est, is_label_1])


# Write the header only when starting a new output file, not once per newspaper file or re-run.
write_header = not os.path.isfile(OUTPUT_PATH) or os.path.getsize(OUTPUT_PATH) == 0

# newline='' is what the csv module expects for both reading and writing.
with open(OUTPUT_PATH, mode='a', newline='', encoding='utf-8') as fo:
    writer = csv.writer(fo)
    if write_header:
        writer.writerow(OUTPUT_HEADER)

    for paper in sorted(os.listdir(INPUT_DIR)):
        if not paper.endswith('.csv'):
            continue  # skip stray non-CSV entries
        filename = os.path.join(INPUT_DIR, paper)
        print(f'Processing file {paper}')
        with open(filename, 'r', newline='', encoding='utf-8') as csvfile:
            reader = csv.reader(csvfile)
            header = next(reader, None)
            if header is None:
                print(f'\tSkipping empty file {paper}.')
                continue
            cols = find_columns(header)

            batch = []
            for row in tqdm(reader, desc=paper):
                batch.append(row)
                if len(batch) >= batch_size:
                    classify_and_write(writer, batch, cols)
                    batch = []
            # Classify the final partial batch too, so the last < batch_size articles
            # get their own predictions.
            classify_and_write(writer, batch, cols)
        print(f'\tFinished file {paper}.')

    

0it [00:00, ?it/s]

Processing file _bild_articles_2019.csv


1601it [00:16, 96.02it/s]

KeyboardInterrupt: ignored