In [34]:
import pandas as pd
import requests
import re

from collections import Counter
from typing import List, Dict

from transformers import BertModel, BertTokenizerFast
import torch

from consts import *

In [12]:
df = pd.read_csv('kns_csv_files/kns_committee.csv')
df = df[df['KnessetNum'] >= 25]
df = df[df['CategoryID'].isin([MONEY_COM_CATEGORY_ID, DEFENSE_COM_CATEGORY_ID, LAW_ORDER_COM_CATEGORY_ID, MESADERET_COM_CATEGORY_ID, KNESSET_COM_CATEGORY_ID])]
commitee_ids = df['CommitteeID'].to_list()

In [13]:
def get_meeting_protocol_text(text_path):
    # Define the URL to fetch
    base_url = 'https://production.oknesset.org/pipelines/data/committees/meeting_protocols_text/'
    # Send GET request to the URL
    response = requests.get(base_url + text_path)

    # Check if the request was successful (status code 200)
    if response.status_code == 200:
        response.encoding = 'utf-8'
        # Retrieve the content of the file
        return response.text
    else:
        raise ValueError(f"Failed to retrieve content. Status code: {response.status_code}")


In [14]:
com_session_df = pd.read_csv('kns_csv_files/kns_committeesession.csv')
com_session_df = com_session_df[com_session_df['CommitteeID'].isin(commitee_ids)]

com_session_df.dropna(subset=['text_parsed_filename'], inplace=True)
text_paths = com_session_df['text_parsed_filename'].to_list()
texts = [get_meeting_protocol_text(path) for path in text_paths]


In [26]:
knesset_members_df = pd.read_csv('kns_csv_files/kns_person.csv')
first_names, last_names = knesset_members_df['FirstName'].to_list(), knesset_members_df['LastName'].to_list()
knesset_members = [' '.join([first_name, last_name]) for first_name, last_name in zip(first_names, last_names)]

warnings = {mem: [0, 0, 0] for mem in knesset_members}

# handle members with a middle name or a nickname
new_first_names, new_last_names = [], []

for fn, ln in zip(first_names, last_names):
    names = re.findall('\w+', fn)
    
    for name in names:
        warnings[name + ' ' + ln] = warnings[fn + ' ' + ln]
        new_first_names.append(name)
        new_last_names.append(ln)

# update first and last names
first_names = new_first_names
last_names = new_last_names

knesset_members = [' '.join([first_name, last_name]) for first_name, last_name in zip(first_names, last_names)]

In [27]:
def get_meeting_warnings(text, warnings, knesset_members) -> None:
    """
    Return warnings from the meeting protocol text.

    Parameters
    ----------
    text : str
        Meeting protocol text.

    warnings: Dict[str, List[int]]
        Number of warnings for each Knesset member.

    knesset_members: List[str]
        List of Knesset members.
    """

    # find all warnings
    matches = re.findall(WARNING_REGEX, text, flags=re.MULTILINE)
    print(len(matches))
    for i, match in enumerate(matches):
        print(f'match #{i}:')
        print(match)
        sentences = match.split('\n')
        first_sentence, last_sentence = sentences[0], sentences[-1]
        for kns_member in knesset_members:
            if kns_member in first_sentence:
                word2idx = {'ראש': 0, 'שני': 1, 'שליש': 2}
                for word, idx in word2idx.items():
                    if word in last_sentence:
                        warnings[kns_member][idx] += 1
                        break
    
    

In [None]:
alephbert_tokenizer = BertTokenizerFast.from_pretrained('onlplab/alephbert-base')
alephbert = BertModel.from_pretrained('onlplab/alephbert-base')

In [44]:
# if not finetuning - disable dropout
alephbert.eval()
text = 'מי אתה בכלל שתרים את הקול שלך עליי'
encoding = alephbert_tokenizer(text, return_tensors='pt', padding=True, truncation=True)
with torch.no_grad():
    output = alephbert(**encoding)
    print('output:', output)
    logits = output[0]
    print('logists shape:', logits.shape)
    print('logits:', logits)
    pred = torch.argmax(logits, dim=1)
    print(pred.item())

output: BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-7.5732e-01,  5.3082e-02, -6.2342e-01,  ..., -3.4851e-01,
           1.5200e-01,  5.9104e-02],
         [ 2.2266e-01, -8.5340e-01, -2.5723e-02,  ..., -6.8251e-01,
           4.8405e-01, -2.7740e-01],
         [ 1.4334e+00, -2.0554e-02,  3.7486e-01,  ...,  1.6602e-01,
           1.2668e+00, -1.6225e-01],
         ...,
         [ 1.2939e+00, -5.1341e-01,  3.2570e-01,  ...,  4.7693e-05,
           9.7235e-01, -5.4676e-01],
         [ 1.3832e+00, -1.2884e+00, -2.6654e-01,  ...,  1.4298e+00,
           3.7618e-01,  5.9848e-01],
         [ 7.1438e-01, -6.4273e-01, -5.7363e-01,  ..., -3.5674e-01,
           3.1811e-01, -3.2872e-01]]]), pooler_output=tensor([[ 0.0143,  0.1606, -0.1559, -0.1836,  0.4954, -0.2131,  0.6380,  0.5410,
         -0.3370, -0.3516,  0.3491, -0.5018, -0.1477, -0.0547,  0.6812, -0.0139,
          0.4014,  0.4610, -0.1625, -0.1118, -0.2276,  0.0291, -0.5252,  0.6775,
         -0.2995, -0.4915

RuntimeError: a Tensor with 768 elements cannot be converted to Scalar

In [37]:
type(output)

transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions

In [36]:
dir(output)

['__annotations__',
 '__class__',
 '__contains__',
 '__dataclass_fields__',
 '__dataclass_params__',
 '__delattr__',
 '__delitem__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__post_init__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__reversed__',
 '__setattr__',
 '__setitem__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 'attentions',
 'clear',
 'copy',
 'cross_attentions',
 'fromkeys',
 'get',
 'hidden_states',
 'items',
 'keys',
 'last_hidden_state',
 'move_to_end',
 'past_key_values',
 'pooler_output',
 'pop',
 'popitem',
 'setdefault',
 'to_tuple',
 'update',
 'values']

In [42]:
output[0]

tensor([[[-7.5732e-01,  5.3082e-02, -6.2342e-01,  ..., -3.4851e-01,
           1.5200e-01,  5.9104e-02],
         [ 2.2266e-01, -8.5340e-01, -2.5723e-02,  ..., -6.8251e-01,
           4.8405e-01, -2.7740e-01],
         [ 1.4334e+00, -2.0554e-02,  3.7486e-01,  ...,  1.6602e-01,
           1.2668e+00, -1.6225e-01],
         ...,
         [ 1.2939e+00, -5.1341e-01,  3.2570e-01,  ...,  4.7693e-05,
           9.7235e-01, -5.4676e-01],
         [ 1.3832e+00, -1.2884e+00, -2.6654e-01,  ...,  1.4298e+00,
           3.7618e-01,  5.9848e-01],
         [ 7.1438e-01, -6.4273e-01, -5.7363e-01,  ..., -3.5674e-01,
           3.1811e-01, -3.2872e-01]]])