In [1]:
# Name of the local directory the project repository is cloned into; it is
# used as the `git clone` target and as the directory added to sys.path.
GITHUB_DIR = "aspects-space"

In [None]:
!pip install transformers
!pip install stanza
!pip install transformers datasets evaluate
!pip install accelerate -U
!pip install pynvml
!pip install gdown

#download the model from my google drive
!gdown --folder --id "1-OIgf-F7lNsSkDv3fuqVN3tpCNv2LgiF"

In [None]:
repo_path = GITHUB_DIR
!git clone "https://github.com/katrinrohrb/aspects-space-dev.git" $repo_path

In [None]:
import sys
import os

# The repository is cloned relative to the current working directory (on
# Colab that is /content), so resolve GITHUB_DIR against the same place
# instead of hard-coding '/content'.
module_path = os.path.abspath(GITHUB_DIR)
# Guard against duplicate entries when this cell is re-run.
if module_path not in sys.path:
  sys.path.insert(0, module_path)

from katspace.core import RESULTS_DIR, DATA_DIR

In [None]:
import torch

from pathlib import Path
import json
import gc

from transformers import AutoTokenizer

# Aliased: a plain `logging` name would be silently replaced by the stdlib
# `logging` module imported further down in this cell.
from transformers import logging as hf_logging
from transformers import pipeline, AutoConfig

from pynvml import *  # NOTE(review): star import; the pynvml names actually used are not visible here
import katspace.core as katspace

import evaluate

accuracy = evaluate.load("accuracy")

# Derived from the package constant rather than the current RESULTS_DIR value,
# so re-running this cell does not keep appending "test" (test/test/...).
RESULTS_DIR = Path(katspace.RESULTS_DIR, "test")
TXT_DIR = Path(DATA_DIR, "txt")
MODEL_DIR = Path("checkpoint-286")

import logging
LOG_FILENAME = Path("session.log")
# force=True removes any handlers already attached to the root logger, so
# re-running this cell reconfigures logging instead of being a no-op.
logging.basicConfig(filename=LOG_FILENAME, encoding='utf-8', level=logging.INFO, force=True)
logging.info("Start logging")

In [None]:
# Local checkpoint directory of the fine-tuned classifier.
checkpoint = MODEL_DIR

# Map each class label to its integer class index, and build the inverse map.
label2id = {
    label: idx
    for idx, label in enumerate(
        ["perceived_space", "action_space", "visual_space", "descriptive_space", "no_space"]
    )
}
id2label = {idx: label for label, idx in label2id.items()}

config = AutoConfig.from_pretrained(checkpoint, label2id=label2id, id2label=id2label)

# The tokenizer is loaded from the base model on the Hub, not from `checkpoint`.
tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")

# Device index for the pipeline: first GPU if CUDA is available, else CPU (-1).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

pipe = pipeline(
    'text-classification',
    model=checkpoint,
    tokenizer=tokenizer,
    config=config,
    device=device,
    batch_size=120,
)

In [7]:
def data_generator(filename):
  """Lazily yield the lines of a UTF-8 text file, one line at a time.

  Each line is yielded exactly as read, including its trailing newline (the
  last line may lack one). The file is opened on the first ``next()`` call and
  closed once the generator is exhausted or closed.
  """
  source = Path(filename)
  with source.open(mode="r", encoding="utf8") as handle:
    yield from handle

In [None]:
dirname = TXT_DIR.as_posix()
!wget https://raw.githubusercontent.com/katrinrohrb/aspects-space-dev/refs/heads/colabtest/data/txt/Franz_Kafka_-_Der_Prozeß.txt Franz_Kafka_-_Der_Prozeß.txt
!mkdir $TXT_DIR ; cp Franz_Kafka_-_Der_Prozeß.txt $TXT_DIR

In [37]:
# To process a single book instead, use e.g.:
# filenames = [Path(TXT_DIR, "Franz_Kafka_-_Der_Prozeß.txt")]
# Sorted for a deterministic processing order. Sub-directories are skipped
# because data_generator can only open regular files.
filenames = sorted(path for path in TXT_DIR.iterdir() if path.is_file())
forbidden_list = []

In [19]:
def safe_enumerate(gen, errors, filename=None):
  """Enumerate ``gen`` like ``enumerate``, skipping items whose retrieval raises RuntimeError.

  A RuntimeError raised while fetching an item is logged, recorded in
  ``errors`` and followed by a CUDA cache flush and garbage collection.
  Iteration then continues with the next item.

  Args:
    gen: An iterable or iterator of results (e.g. a lazy pipeline iterator).
    errors: A list; ``(sent_num, exception)`` is appended for each skipped item.
    filename: Label used in the warning log. If omitted, the module-global
      ``filename`` (set by the processing loop) is used when it exists.

  Yields:
    ``(sent_num, item)`` pairs. ``sent_num`` counts every retrieval attempt,
    including failed ones, so skipped items leave gaps in the numbering.
    NOTE(review): if the wrapped pipeline loses a whole batch when one
    inference fails, sent_num will no longer match input line numbers — confirm.
  """
  if filename is None:
    # Earlier versions read the global loop variable directly; look it up
    # without risking a NameError inside the except-handler below.
    filename = globals().get("filename", "<unknown file>")
  # iter() is required: the previous code discarded the result of
  # gen.__iter__(), so plain iterables (e.g. lists) failed on next().
  iterator = iter(gen)
  sent_num = 0
  while True:
    try:
      item = next(iterator)
    except StopIteration:
      break
    except RuntimeError as e:
      logging.warning(f"Skipping sentence {sent_num} in {filename} because it caused a runtime error!")
      logging.warning(e)
      errors.append((sent_num, e))
      # Release cached GPU memory before trying the next item.
      torch.cuda.empty_cache()
      gc.collect()
    else:
      # Yield outside the try so only errors from fetching are caught here.
      yield sent_num, item
    finally:
      sent_num += 1

In [38]:
#filenames = [Path(txt_dir, "Jacob_Burckhardt_-_Die_Zeit_Constantins_des_Großen.txt")]

results_list = []

# Normalise manually excluded files to Path objects so they compare equal to
# the Path entries of `filenames`.
forbidden_list = [Path(filename) for filename in forbidden_list]
failed_list = []

# parents=True: RESULTS_DIR is nested (".../test"); exist_ok makes re-runs safe.
if RESULTS_DIR:
  Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

for filename in filenames:
  filename = Path(filename)
  result_path = Path(RESULTS_DIR, filename.stem + "-result.json")
  error_path = Path(RESULTS_DIR, filename.stem + "-errs.json")

  if filename in forbidden_list:
    logging.info(f"Skipping {filename} because it was manually excluded!")
    continue
  # An existing results file marks the book as done; partial files are
  # removed below if writing is interrupted.
  if result_path.exists():
    logging.info(f"Skipping {filename} because it already exists!")
    continue
  logging.info(f"Processing {filename}")

  results = []
  failed_sents = []
  # Index of the last sentence yielded for *this* file. Reset so the error
  # handler never reports a stale value from the previous file, or hits an
  # unbound name.
  sent_num = -1

  # pipe() over a generator is lazy: inference runs while iterating below.
  iterator = pipe(data_generator(filename))
  try:
    for sent_num, result in safe_enumerate(iterator, failed_sents):
      results.append(result)
  except RuntimeError as error:
    # safe_enumerate catches RuntimeErrors raised while fetching an item.
    # One escaping it ends the iteration, so the rest of this file is not
    # processed. The failure happened after the last yielded sentence.
    failed_at = sent_num + 1
    logging.warning(filename)
    logging.warning(error)
    failed_list.append((filename, failed_at, error))
    failed_sents.append((failed_at, error))
    print(f"Aborting {filename} at sentence {failed_at} because of a runtime error!")

  try:
    with result_path.open('w', encoding="utf-8") as f:
      json.dump(results, f)
    if failed_sents:
      with error_path.open('w', encoding="utf-8") as f:
        # Exceptions are not JSON-serialisable; store their message text.
        json.dump([(num, str(err)) for num, err in failed_sents], f)
  except BaseException as exc:
    # Remove half-written outputs; otherwise the next run would treat this
    # file as already processed and skip it.
    for partial in (result_path, error_path):
      if partial.exists():
        partial.unlink()
    if isinstance(exc, KeyboardInterrupt):
      raise KeyboardInterrupt(f"interrupted during writing of {filename}\nDeleting!") from exc
    raise
