In [38]:
from datasets import load_dataset
import random
import pandas as pd
from transformers import AutoTokenizer
import pandas as pd
import re
from nltk.corpus import stopwords
import json
import enum

In [39]:
import pandas as pd

#pd.set_option('display.max_rows', 1000)  # allow up to 1000 rows to be displayed
#pd.set_option('display.max_colwidth', None)

In [40]:
def get_frequent_words(df, dataname,  min_count=10, language='english'):
    """Count non-stopword tokens in a text column and keep the frequent ones.

    Args:
        df: DataFrame holding the text; the column read is ``'query'`` when
            ``dataname == 'sql'`` and ``'question'`` for any other value.
        dataname: Dataset tag used only to pick the text column.
        min_count: Minimum number of occurrences a word needs to be returned.
        language: Language name passed to ``stopwords.words``.

    Returns:
        The ``value_counts`` Series (word -> occurrences) restricted to words
        occurring at least ``min_count`` times.
    """
    column = 'query' if dataname == 'sql' else 'question'
    ignored = set(stopwords.words(language))

    # Merge every row into one lowercase string, then drop all characters
    # that are neither word characters nor whitespace.
    corpus = re.sub(r'[^\w\s]', '', ' '.join(df[column]).lower())

    # Whitespace tokenisation followed by stopword filtering.
    kept = [token for token in corpus.split() if token not in ignored]

    # Frequency table, thresholded at min_count.
    freq = pd.Series(kept).value_counts()
    return freq[freq >= min_count]

In [41]:
import re

def annotate_question(question):
    """Label a question as 'Aggregation', 'Lookup' or 'Other' by keyword.

    The question is lowercased and tested against the rules below in order;
    the label of the first matching rule is returned, so aggregation cues
    take precedence over which/what/who. 'Other' is returned when no rule
    matches.
    """
    text = question.lower()

    # (pattern, label) pairs, checked top to bottom.
    rules = (
        (r'\bhow many\b', 'Aggregation'),
        (r'\b(total|sum|count|average)\b', 'Aggregation'),
        (r'\b(highest|lowest|most|least|fastest|slowest)\b', 'Aggregation'),
        (r'\b(which|what|who)\b', 'Lookup'),
    )
    for pattern, label in rules:
        if re.search(pattern, text):
            return label

    return 'Other'

In [42]:
# Load the WikiTableQuestions dataset; its splits are accessed by name
# ('train', 'validation', 'test') in the cells below.
# NOTE(review): the variable shares its name with the `datasets` package that
# `load_dataset` is imported from — consider a more specific name.
datasets = load_dataset("wikitablequestions")

In [43]:
# Handles to the three splits.
# NOTE(review): `train` keeps whatever object `datasets['train']` returns,
# whereas `val` and `test` are converted to DataFrames — confirm the
# asymmetry is intended.
train = datasets['train']
val = pd.DataFrame(datasets['validation'])
test = pd.DataFrame(datasets['test'])

In [44]:
# Number of examples in train / validation / test, one per line.
print(len(train), len(val), len(test), sep='\n')

11321
2831
4344


In [45]:
tokenizer = AutoTokenizer.from_pretrained("neulab/omnitab-large")

# Per-split token counts, one entry per sample: {split: [n_tokens, ...]}
results = {}
# Per-split samples whose encoded (table + question) length exceeds 1024 tokens
over_1024 = {}

# Measure every sample of every split.
for split in ["train", "test", "validation"]:
    data = datasets[split]

    token_counts = []
    over_1024[split] = []  # start each split with an empty list

    for sample in data:
        table = sample["table"]

        # The tokenizer is called with the table as a DataFrame plus the question.
        df_table = pd.DataFrame(table["rows"], columns=table["header"])

        # truncation=False so the full, untruncated length is measured.
        tokenized = tokenizer(table=df_table, query=sample["question"], truncation=False)

        # Always record a length. The previous fallback stored the raw
        # `input_ids` object itself whenever it was not a list, which is not
        # a count and made the `> 1024` comparison below meaningless.
        token_count = len(tokenized["input_ids"])
        token_counts.append(token_count)

        # Keep the samples that exceed 1024 tokens.
        if token_count > 1024:
            over_1024[split].append(sample)

    # Keep the per-split counts; previously `results` was never filled and
    # `token_counts` was discarded when the next split started.
    results[split] = token_counts

Token indices sequence length is longer than the specified maximum sequence length for this model (1381 > 1024). Running this sequence through the model will result in indexing errors


In [46]:
# How many over-length samples were collected for each split.
print({split_name: len(samples) for split_name, samples in over_1024.items()})

{'train': 1987, 'test': 607, 'validation': 496}


In [47]:
# Seed the global RNG so the draw below is repeatable, then sample 1000
# over-length training examples without replacement.
# NOTE: random.sample raises ValueError if fewer than 1000 samples are available.
random.seed(42)
over_train = random.sample(over_1024['train'], 1000)

In [48]:
# Persist the sampled over-length training examples as indented UTF-8 JSON.
with open('/home/eunji/workspace/kim-internship/Eunji/over_train.json', mode='w', encoding='utf-8') as f:
    json.dump(over_train, fp=f, ensure_ascii=False, indent=2)

In [49]:
# Persist the over-length validation examples as indented UTF-8 JSON.
with open('/home/eunji/workspace/kim-internship/Eunji/over_val.json', mode='w', encoding='utf-8') as f:
    json.dump(over_1024['validation'], fp=f, ensure_ascii=False, indent=2)

In [50]:
# Persist the over-length test examples as indented UTF-8 JSON.
with open('/home/eunji/workspace/kim-internship/Eunji/over_test.json', mode='w', encoding='utf-8') as f:
    json.dump(over_1024['test'], fp=f, ensure_ascii=False, indent=2)

In [51]:
# Read the sampled over-length training examples back from disk.
with open('/home/eunji/workspace/kim-internship/Eunji/over_train.json', mode='r', encoding='utf-8') as f:
    over_train = json.load(fp=f)

In [52]:
# Tabular views of the over-length samples. `overtrain` is built from the
# `over_train` sample, while `overval` / `overtest` are built directly from
# the in-memory `over_1024` lists.
overtrain = pd.DataFrame(over_train)
overval = pd.DataFrame(over_1024['validation'])
overtest = pd.DataFrame(over_1024['test'])
#over_query = overtrain['question'].tolist()

In [53]:
# 각 질문에 annotation 적용
overtrain['annotation'] = overtrain['question'].apply(annotate_question)
overval['annotation'] = overval['question'].apply(annotate_question)
overtest['annotation'] = overtest['question'].apply(annotate_question)

# 결과 출력
print(f'train : {overtrain["annotation"].value_counts()}')
print(f"validation : {overval['annotation'].value_counts()}")
print(f"test:{overtest['annotation'].value_counts()}")

train : annotation
Lookup         436
Aggregation    432
Other          132
Name: count, dtype: int64
validation : annotation
Lookup         224
Aggregation    214
Other           58
Name: count, dtype: int64
test:annotation
Lookup         278
Aggregation    244
Other           85
Name: count, dtype: int64


In [54]:
# Rows whose question received the fallback "Other" label, per split.
overtrain_others = overtrain.loc[overtrain['annotation'].eq("Other")]
overval_others = overval.loc[overval['annotation'].eq("Other")]
overtest_others = overtest.loc[overtest['annotation'].eq("Other")]

In [55]:
# Ten most frequent non-stopwords among the "Other" test questions.
get_frequent_words(overtest_others, dataname="tq", min_count=2).head(10)

number     19
long       15
first      12
one        10
name        8
route       7
list        6
last        6
game        6
unicode     5
Name: count, dtype: int64

In [56]:
# Ten most frequent non-stopwords among the "Other" validation questions.
get_frequent_words(overval_others, dataname="tq", min_count=2).head(10)

name        9
long        8
last        7
hospital    5
number      5
first       5
state       4
saros       4
county      4
largest     4
Name: count, dtype: int64

In [57]:
# Ten most frequent non-stopwords among the "Other" training questions.
get_frequent_words(overtrain_others, dataname="tq", min_count=2).head(10)

name      28
long      21
number    12
first     10
year       9
last       8
one        7
held       7
game       6
listed     6
Name: count, dtype: int64

In [58]:
#overval_others[overval_others['annotation']=="Other"][]

---

In [59]:
# Load the WikiSQL dataset; only its 'validation' split is used below.
wikisql = load_dataset("wikisql")

In [60]:
# Validation split as a DataFrame for column-wise access.
wikisql_val= pd.DataFrame(wikisql['validation'])

In [61]:
# Inspect which columns the validation frame exposes.
wikisql_val.columns

Index(['phase', 'question', 'table', 'sql'], dtype='object')

In [62]:
# Question strings and, for each row, the header list of its table.
val_query = wikisql_val['question'].to_list()
val_header = [tbl['header'] for tbl in wikisql_val['table']]

In [63]:
# Aggregation label names; a label's list position is used as its integer id.
labels = ["None", "Max", "Min", "Count", "Sum", "Average"]

In [64]:
# Number of labels plus the two lookup tables between integer id and label name.
Num_labels = len(labels)
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in enumerate(labels)}

In [65]:
# Aggregation id stored under 'agg' in each row's SQL record.
wikisql_val_agg = [sql['agg'] for sql in wikisql_val['sql']]

In [66]:
# Translate each aggregation id into its label name.
val_qt = list(map(id2label.__getitem__, wikisql_val_agg))

In [67]:
# One row per validation question: text, table header, aggregation id and label.
val_data = pd.DataFrame(
    dict(query=val_query, header=val_header, agg=wikisql_val_agg, agg_label=val_qt)
)

In [69]:
# Split the validation rows by aggregation id 0..5, in that order:
# none, max, min, count, sum, average.
none_val, max_val, min_val, count_val, sum_val, avg_val = (
    val_data[val_data['agg'] == agg_id] for agg_id in range(6)
)

In [70]:
# Row count of each aggregation subset, space-separated on one line.
print(*(len(part) for part in (none_val, max_val, min_val, count_val, sum_val, avg_val)))

6017 507 468 779 321 329


In [76]:
def get_frequent_words_for_none(df, text_col='query', min_count=10):
    """Count whitespace-separated tokens of a text column, keeping frequent ones.

    No lowercasing, punctuation stripping or stopword removal is applied, so
    tokens are counted exactly as they appear.

    Args:
        df: DataFrame containing the text column.
        text_col: Name of the column whose strings are joined and split.
        min_count: Minimum number of occurrences a token needs to be returned.

    Returns:
        The ``value_counts`` Series (token -> occurrences) restricted to
        tokens occurring at least ``min_count`` times.
    """
    tokens = pd.Series(' '.join(df[text_col]).split())
    freq = tokens.value_counts()
    frequent_enough = freq >= min_count
    return freq[frequent_enough]

In [83]:
# Ten most frequent raw tokens in questions without an aggregation.
get_frequent_words_for_none(none_val).head(10)

the     6385
is      3124
of      2985
What    2892
a       1718
was     1489
for     1033
and     1024
has     1001
when     934
Name: count, dtype: int64

In [79]:
# Ten most frequent non-stopwords in the average-aggregation questions.
# NOTE(review): this exact avg_val call is repeated in a later cell, and
# min_val is never inspected anywhere — possibly this cell was meant to use
# min_val; confirm.
get_frequent_words(avg_val, dataname="sql").head(10)

average    231
less        81
number      54
total       47
0           44
1           43
larger      34
smaller     34
points      33
rank        32
Name: count, dtype: int64

In [82]:
# Ten most frequent non-stopwords in the max-aggregation questions.
get_frequent_words(max_val, dataname="sql").head(10)

highest    213
number     101
less        76
larger      47
name        45
total       42
points      42
smaller     42
team        38
year        37
Name: count, dtype: int64

In [None]:
# Ten most frequent non-stopwords in the count-aggregation questions.
get_frequent_words(count_val, dataname="sql").head(10)

many       397
total      258
number     251
name        79
less        72
points      69
larger      51
smaller     44
goals       42
0           41
Name: count, dtype: int64

In [None]:
# Ten most frequent non-stopwords in the sum-aggregation questions.
get_frequent_words(sum_val, dataname="sql").head(10)

sum        94
total      84
many       71
less       64
number     40
points     33
0          32
1          32
greater    30
larger     30
Name: count, dtype: int64

In [None]:
# Ten most frequent non-stopwords in the average-aggregation questions.
get_frequent_words(avg_val, dataname="sql").head(10)

average    231
less        81
number      54
total       47
0           44
1           43
larger      34
smaller     34
points      33
rank        32
Name: count, dtype: int64

---