In [1]:
# Print the NVIDIA driver/GPU status for this machine (IPython shell escape) so the
# hardware the notebook ran on is recorded alongside the results.
!nvidia-smi

Mon Feb 24 18:41:40 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          On  |   00000000:87:00.0 Off |                    0 |
| N/A   34C    P0             55W /  400W |       1MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import os
import pandas as pd
import pytorch_lightning as pl
import torch

from codealltag_data_processor_v2025 import CodealltagDataProcessor
from pandas import DataFrame
from tqdm import tqdm
from transformers import MT5ForConditionalGeneration, MT5TokenizerFast
from typing import Any, Dict, List, Tuple

  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Set TOKENIZERS_PARALLELISM to "false" for this process before any tokenizer is created.
# NOTE(review): presumably this is here to silence the Hugging Face tokenizers
# fork/parallelism warning — confirm it is still needed.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [4]:
# Directory handed to `from_pretrained(..., cache_dir=cache_dir)` when the model and
# tokenizer are loaded. NOTE: this is a machine-specific absolute path.
cache_dir = os.path.join('/home', 's81481', '.huggingface')

In [5]:
# One data processor per data version, both configured from the same YAML file.
# `cdp_2022` is the one used later to read e-mails and to name the output CSV;
# `cdp_2020` is only passed to the dataset split call.
# NOTE(review): the version strings look like YYYYMMDD release dates — confirm.
cdp_2022 = CodealltagDataProcessor(data_version='20220513', config_path=['codealltag_data_processor.yml'])
cdp_2020 = CodealltagDataProcessor(data_version='20200518', config_path=['codealltag_data_processor.yml'])

In [6]:
# Experiment selection. Both values are passed to the dataset split call and are also
# embedded in the checkpoint directory ("<sample_size//1000>K/k<k>") and the output CSV name.
sample_size = 10_000
# NOTE(review): presumably a fold/split index — confirm against CodealltagDataProcessor.
k = 5

In [7]:
# Ask the 2022 processor for the train/dev/test splits of this (sample_size, k) configuration.
# The 2020 processor is passed in as well; how it is used is defined inside
# CodealltagDataProcessor and is not visible in this notebook.
dataset = cdp_2022.get_train_dev_test_datasetdict_for_sample_size(cdp_2020, sample_size, k)

In [8]:
# Only the "test" split is used in this notebook; convert it to pandas so it can be
# iterated row by row (its `FilePath` column identifies each e-mail).
test_df = dataset["test"].to_pandas()

In [9]:
class LightningModel(pl.LightningModule):
    """Lightning module that bundles a pretrained mT5 model with its tokenizer.

    Both are loaded with `from_pretrained` using `hparam.model_name_or_path`
    and the notebook-level `cache_dir`.
    """

    def __init__(self, hparam):
        """Store `hparam` and load the model and tokenizer it names.

        Args:
            hparam: object exposing a `model_name_or_path` attribute.
        """
        super().__init__()
        self.hparam = hparam
        pretrained_name = hparam.model_name_or_path
        self.model = MT5ForConditionalGeneration.from_pretrained(pretrained_name, cache_dir=cache_dir)
        self.tokenizer = MT5TokenizerFast.from_pretrained(pretrained_name, cache_dir=cache_dir)

    def forward(self,
                input_ids,
                attention_mask=None,
                decoder_input_ids=None,
                decoder_attention_mask=None,
                labels=None):
        """Pass all arguments straight through to the wrapped mT5 model.

        Returns:
            Whatever `self.model(...)` returns for the given inputs.
        """
        # `input_ids` stays positional; everything else is forwarded by keyword.
        optional_inputs = {
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "labels": labels,
        }
        return self.model(input_ids, **optional_inputs)

In [10]:
# Locate the trained checkpoint for this (sample_size, k) configuration.
model_dir = f"logs/mT5/NER-PG/{sample_size//1000}K/k{k}/version_0/checkpoints/"

# `os.listdir` returns entries in arbitrary order and may include non-checkpoint files,
# so restrict to "*.ckpt" and take the lexicographically first name to make the choice
# deterministic when several checkpoints exist.
ckpt_name = min((name for name in os.listdir(model_dir) if name.endswith(".ckpt")), default=None)
if ckpt_name is None:
    # Fail with a clear message instead of the TypeError that os.path.join(model_dir, None) raises.
    raise FileNotFoundError(f"No .ckpt file found in {model_dir!r}")

model_path = os.path.join(model_dir, ckpt_name)
# Last expression of the cell: display the resolved checkpoint path.
model_path

'logs/mT5/NER-PG/10K/k5/version_0/checkpoints/epoch=02-step=00002-val_loss=1.3786.ckpt'

In [11]:
# Rebuild the LightningModel from the checkpoint file selected above.
# NOTE: as the warning captured below this cell states, the checkpoint is read through
# `torch.load` with `weights_only=False`, i.e. via pickle — only load checkpoints you trust.
lightning_model = LightningModel.load_from_checkpoint(model_path)

/home/s81481/pseugc/lib/python3.9/site-packages/lightning_fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(checkpoint_file, map_location="cp

In [12]:
def predict(lightning_model: "LightningModel",
            input_text: str,
            max_input_length: int = 512,
            max_output_length: int = 512,
            temperature: float = 0.8,
            top_k: int = 100,
            do_sample: bool = True) -> str:
    """Generate one output text for `input_text` with the wrapped mT5 model.

    With the default `do_sample=True` the result is sampled, so repeated calls on
    the same input can return different texts.

    Args:
        lightning_model: object exposing `.model` (with `eval`, `to`, `generate`)
            and `.tokenizer`. The annotation is quoted so this cell can be defined
            without `LightningModel` being in scope yet.
        input_text: the text to feed to the model.
        max_input_length: tokenizer `max_length`; the input is truncated and padded
            to exactly this many tokens.
        max_output_length: `max_length` passed to `generate`.
        temperature: sampling temperature passed to `generate`.
        top_k: `top_k` passed to `generate`.
        do_sample: `do_sample` passed to `generate`.

    Returns:
        The decoded first generated sequence, special tokens removed and
        surrounding whitespace stripped.
    """
    model = lightning_model.model
    tokenizer = lightning_model.tokenizer

    # Calling the tokenizer directly replaces the deprecated `batch_encode_plus`;
    # a one-element list keeps the batch dimension that `generate` expects.
    # NOTE(review): padding a single sequence to `max_length` makes the encoder
    # process `max_input_length` tokens every time — presumably kept to match
    # training; confirm before relaxing it.
    encoded = tokenizer(
        [input_text],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    model.to(device)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    generated = model.generate(input_ids=input_ids,
                               attention_mask=attention_mask,
                               max_length=max_output_length,
                               temperature=temperature,
                               do_sample=do_sample,
                               top_k=top_k)

    # Exactly one input was submitted, so only the first generated sequence is decoded.
    return tokenizer.decode(
        generated[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    ).strip()

In [13]:
def create_predicted_text_df(test_df: DataFrame, lightning_model: LightningModel, repeat: int = 5) -> DataFrame:
    """Generate `repeat` texts for every e-mail listed in `test_df`.

    For each `FilePath` the e-mail is read through the notebook-level `cdp_2022`
    processor, and element 1 of what `read_email` returns is fed to `predict`
    `repeat` times.

    Args:
        test_df: frame with a `FilePath` column.
        lightning_model: model wrapper handed to `predict`.
        repeat: number of generations per e-mail.

    Returns:
        A DataFrame with one row per input row and columns
        `FilePath`, `V1`, ..., `V<repeat>`.
    """
    version_columns = [f'V{number}' for number in range(1, repeat + 1)]
    records: List[Tuple] = []

    with tqdm(total=len(test_df), position=0, leave=True) as bar:
        for file_path in test_df["FilePath"]:
            email_text = cdp_2022.read_email(file_path)[1]
            generations = [
                predict(lightning_model=lightning_model, input_text=email_text)
                for _ in range(repeat)
            ]
            records.append((file_path, *generations))
            bar.update(1)

    return pd.DataFrame(records, columns=["FilePath", *version_columns])

In [14]:
# Run generation over the whole test split with the default number of repeats per e-mail.
# Long-running: the progress bar captured below shows 3h22m for 2000 rows.
predicted_text_df = create_predicted_text_df(test_df=test_df, lightning_model=lightning_model)

100%|█████████████████████████████████████| 2000/2000 [3:22:40<00:00,  6.08s/it]


In [15]:
# Persist the generations to a CSV in the current working directory; the file name
# encodes the data version, the sample size in thousands and k.
# `index` is not overridden, so the DataFrame index is written as the first column.
predicted_text_df.to_csv(f'PredictedText_DF_{cdp_2022.get_data_version()}_{sample_size // 1000}K_k{k}.csv')