In [1]:
# Shell escape: print the NVIDIA driver/GPU status visible to this kernel.
!nvidia-smi

Mon Feb 10 21:40:35 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:47:00.0 Off |                    0 |
| N/A   31C    P0             63W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                     

In [2]:
import os
from timeit import default_timer as timer
from typing import Any, Dict, List, Tuple

import pandas as pd
import pytorch_lightning as pl
import torch
from pandas import DataFrame
from transformers import MT5ForConditionalGeneration, MT5TokenizerFast

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# fork/parallelism warning -- confirm it is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
class LightningModel(pl.LightningModule):
    """Lightning module bundling a pretrained mT5 model with its tokenizer.

    ``hparam`` must expose a ``model_name_or_path`` attribute; that value is
    handed to ``from_pretrained`` for both the model and the tokenizer.
    """

    def __init__(self, hparam):
        super().__init__()
        self.hparam = hparam
        pretrained_source = hparam.model_name_or_path
        self.model = MT5ForConditionalGeneration.from_pretrained(pretrained_source)
        self.tokenizer = MT5TokenizerFast.from_pretrained(pretrained_source)

    def forward(self,
                input_ids,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                labels=None):
        """Forward all arguments to the wrapped model and return its result."""
        keyword_inputs = {
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "labels": labels,
        }
        return self.model(input_ids, **keyword_inputs)

In [5]:
# Relative path of the Lightning checkpoint to restore; it must exist under
# the kernel's working directory.
model_checkpoint_path = "logs/mT5/NER-PG/10K/k3/version_0/checkpoints/epoch=03-step=00003-val_loss=1.3927.ckpt"
# NOTE(review): the warning captured for this cell reports that loading goes
# through `torch.load` with `weights_only=False`, i.e. pickle-based
# deserialisation -- only load checkpoint files from a trusted source.
lightning_model = LightningModel.load_from_checkpoint(model_checkpoint_path)

/home/s81481/pseugc/lib/python3.9/site-packages/lightning_fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(checkpoint_file, map_location="cp

In [6]:
def predict(lightning_model: "LightningModel", input_text: str) -> str:
    """Generate the model's annotation string for a single input text.

    The text is tokenized (padded/truncated to 512 tokens), moved together
    with the model to the GPU when one is available (CPU otherwise), and
    decoded from one sampled generation (temperature 0.8, top-k 100).
    Because sampling is enabled, repeated calls may return different text.

    Args:
        lightning_model: Object exposing ``.model`` (with ``eval``, ``to`` and
            ``generate``) and ``.tokenizer`` (with ``batch_encode_plus`` and
            ``decode``).
        input_text: Raw text to run through the model.

    Returns:
        The decoded, whitespace-stripped generation for ``input_text``.

    Side effects:
        Prints the device name and the elapsed wall-clock time, puts the model
        in eval mode and moves it to the selected device.
    """
    start = timer()

    model = lightning_model.model
    tokenizer = lightning_model.tokenizer

    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Compare on `device.type`: a torch.device never compares equal to the
    # plain string "cpu", so the previous `device == "cpu"` check always fell
    # through to the CUDA query and crashed on machines without a GPU.
    device_stat = "CPU" if device.type == "cpu" else torch.cuda.get_device_name(device)
    print(f"device_stat: {device_stat}")

    tokenized_outputs = tokenizer.batch_encode_plus(
        [input_text],
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    input_ids = tokenized_outputs["input_ids"]
    attention_mask = tokenized_outputs["attention_mask"]

    model.to(device)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    outs = model.generate(input_ids=input_ids,
                          attention_mask=attention_mask,
                          max_length=512,
                          temperature=0.8,
                          do_sample=True,
                          top_k=100)
    dec = [
        tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False).strip()
        for ids in outs
    ]

    end = timer()

    print(f"inference_time: {round(end - start, 3)}s")

    # Only one text was encoded, so there is exactly one decoded sequence.
    return dec[0]

In [7]:
def get_annotation_df_with_input_text_and_predicted_text(input_text: str,
                                                         predicted_text: str,
                                                         labels: List[str]) -> DataFrame:
    """Turn the model's annotation string into a table of character spans.

    ``predicted_text`` is expected to look like
    ``"LABEL: token **pseudonym**; LABEL: token **pseudonym**; ..."``.
    Items are processed left to right and each token is searched for only
    after the end of the previously matched token, so items that appear out
    of order relative to ``input_text`` are dropped.

    Args:
        input_text: The text the prediction refers to.
        predicted_text: Annotation string produced by the model.
        labels: Label names that may start an item.

    Returns:
        DataFrame with columns ``Token_ID`` (``T1``, ``T2``, ...), ``Label``,
        ``Start``/``End`` (offsets into ``input_text``), ``Token`` (the text
        actually found at that span) and ``Pseudonym`` (empty string when the
        item carries none). Items without a known label, with an empty token,
        or whose token cannot be located are skipped.
    """
    item_delim = "; "
    token_delim = ": "
    pseudonym_delim = " **"
    pseudonym_end = pseudonym_delim.strip()

    tuples = []
    token_id = 0
    cursor = 0  # absolute offset in input_text where the next search starts

    for item in predicted_text.split(item_delim):
        # Match on "LABEL: " rather than on the bare label so that labels
        # sharing a prefix (STREET / STREETNO) resolve correctly regardless of
        # their order in `labels`.
        label = next(
            (candidate for candidate in labels if item.startswith(candidate + token_delim)),
            "",
        )
        if label == "":
            continue

        # Everything after the leading "LABEL: "; slicing (instead of
        # splitting) keeps values intact that contain "LABEL: " themselves.
        token_pseudonym = item[len(label) + len(token_delim):]

        token, pseudonym = token_pseudonym, ""
        if pseudonym_delim in token_pseudonym and token_pseudonym.endswith(pseudonym_end):
            # Cut at the first " **" only, then drop the closing "**".
            token, _, marked_pseudonym = token_pseudonym.partition(pseudonym_delim)
            pseudonym = marked_pseudonym[:-len(pseudonym_end)]

        if len(token.strip()) == 0:
            continue

        start = input_text.find(token, cursor)
        if start == -1 and ' ' in token:
            # Fallback for predictions containing extra spaces: anchor on the
            # first word and measure the span with the spaces removed.
            start = input_text.find(token.split(' ')[0], cursor)
            token = token.replace(' ', '')

        if start == -1:
            continue

        end = start + len(token)
        cursor = end
        token_id += 1

        tuples.append((
            'T' + str(token_id),
            label,
            start,
            end,
            input_text[start:end],
            pseudonym
        ))

    return pd.DataFrame(
        tuples,
        columns=["Token_ID", "Label", "Start", "End", "Token", "Pseudonym"]
    )

def get_pseudonymized_text(input_text: str, predicted_annotation_df: DataFrame) -> str:
    """Return ``input_text`` with every annotated span replaced by its pseudonym.

    Rows are applied in DataFrame order. A running shift tracks how much the
    text has grown or shrunk so far, so the ``Start``/``End`` offsets (which
    refer to the original text) still address the right place in the
    partially rewritten text.
    """
    result = input_text
    shift = 0
    for annotation in predicted_annotation_df.itertuples(index=False):
        cut_from = annotation.Start + shift
        cut_to = annotation.End + shift
        result = result[:cut_from] + annotation.Pseudonym + result[cut_to:]
        shift += len(annotation.Pseudonym) - len(annotation.Token)
    return result

In [16]:
def predict_for_sample_text(lightning_model: LightningModel, input_text: str) -> Tuple[DataFrame, str]:
    """Run prediction, span extraction and pseudonymization for one text.

    Args:
        lightning_model: Loaded model wrapper handed to ``predict``.
        input_text: Text to pseudonymize.

    Returns:
        A ``(annotation_df, pseudonymized_text)`` tuple: the span table built
        from the model's prediction, and ``input_text`` with those spans
        replaced by their pseudonyms. (The previous ``-> str`` annotation did
        not match this tuple.)

    Side effects:
        Prints the raw prediction, in addition to whatever ``predict`` prints.
    """
    # Entity labels that may start an item in the model's output.
    labels = ['CITY', 'DATE', 'EMAIL', 'FAMILY', 'FEMALE', 'MALE', 'ORG', 'PHONE', 'STREET', 'STREETNO', 'UFID', 'URL', 'USER', 'ZIP']
    predicted_text = predict(lightning_model, input_text)
    print(f"predicted_text: {predicted_text}")
    output_df = get_annotation_df_with_input_text_and_predicted_text(input_text, predicted_text, labels)
    output_text = get_pseudonymized_text(input_text, output_df)
    return output_df, output_text

In [17]:
# Sample German newsgroup-style posting (header line, body, signature with
# URLs) used as the input for the pseudonymization run in the cells that follow.
input_text = '''
Am Sun, 21 Nov 2013 20:50:06 +0100 schrieb Mr. TIMO:


Man passt sich im allgemeinen erst mal den üblichen Gepflogenheiten an wenn
man irgendwo neu dazu kommt.

Wenn das Signieren üblich wäre würden es die verbreiteten Newsreader in der
Darstellung filtern. Es nervt beim lesen jedes Postings.

Vincenzo

-- 
Mit unseren Sensoren ist der Administrator informiert, bevor es Probleme im 
Serverraum gibt: preiswerte Monitoring Hard- und Software-kostenloses Plugin 
auch für Nagios - Nachricht per e-mail,SMS und SNMP: http://qpr.azkmja.ye
Messwerte nachträgliche Wärmedämmung http://sch.zwkapb.af/jhnhjtewbyqgyh.rzh
'''

In [25]:
# Run the full pipeline on the sample text. Generation samples
# (do_sample=True in `predict`), so results can differ between runs.
output_df, output_text = predict_for_sample_text(lightning_model, input_text)

device_stat: NVIDIA A100-SXM4-40GB
inference_time: 1.737s
predicted_text: DATE: 21 Nov 2013 **07 Okt 2013**; MALE: TIMO **Heino**; MALE: Vincenzo **Italo**; ORG: Nagios **FDV**; URL: http://qpr.azkmja.ye **http://qgy.xmhbzq.nk**; URL: http://sch.zwkapb.af/jhnhjtewbyqgyh.rzh **http://vl.xkhdpt.qu/hjdtdtjspnvgjp.dxi**


In [26]:
# Display the extracted annotation spans (offsets refer to `input_text`).
output_df

Unnamed: 0,Token_ID,Label,Start,End,Token,Pseudonym
0,T1,DATE,9,20,21 Nov 2013,07 Okt 2013
1,T2,MALE,48,52,TIMO,Heino
2,T3,MALE,296,304,Vincenzo,Italo
3,T4,ORG,474,480,Nagios,FDV
4,T5,URL,518,538,http://qpr.azkmja.ye,http://qgy.xmhbzq.nk
5,T6,URL,576,615,http://sch.zwkapb.af/jhnhjtewbyqgyh.rzh,http://vl.xkhdpt.qu/hjdtdtjspnvgjp.dxi


In [27]:
# print() so the newlines of the pseudonymized text are rendered rather than escaped.
print(output_text)


Am Sun, 07 Okt 2013 20:50:06 +0100 schrieb Mr. Heino:


Man passt sich im allgemeinen erst mal den üblichen Gepflogenheiten an wenn
man irgendwo neu dazu kommt.

Wenn das Signieren üblich wäre würden es die verbreiteten Newsreader in der
Darstellung filtern. Es nervt beim lesen jedes Postings.

Italo

-- 
Mit unseren Sensoren ist der Administrator informiert, bevor es Probleme im 
Serverraum gibt: preiswerte Monitoring Hard- und Software-kostenloses Plugin 
auch für FDV - Nachricht per e-mail,SMS und SNMP: http://qgy.xmhbzq.nk
Messwerte nachträgliche Wärmedämmung http://vl.xkhdpt.qu/hjdtdtjspnvgjp.dxi

