In [1]:
from datasets import load_dataset
import random
import pandas as pd
from transformers import AutoTokenizer
import pandas as pd
import re
from nltk.corpus import stopwords
import json
import os

  from .autonotebook import tqdm as notebook_tqdm


In [23]:
import pandas as pd

#pd.set_option('display.max_rows', 1000)  # 최대 1000개 행 출력 허용
pd.set_option('display.max_colwidth', None)

In [3]:
def get_frequent_words(df, dataname, min_count=10, language='english'):
    if dataname == 'sql':
        text_col='query'
    else: 
        text_col='question'
    stop_words = set(stopwords.words(language))

    # 모든 문장을 하나로 합친 뒤 소문자 변환 + 특수문자 제거
    text = ' '.join(df[text_col]).lower()
    text = re.sub(r'[^\w\s]', '', text)  # 특수문자 제거

    # 토큰화 후 불용어 제거
    words = [word for word in text.split() if word not in stop_words]

    # 단어 개수 세기
    word_counts = pd.Series(words).value_counts()
    return word_counts[word_counts >= min_count]

In [4]:
import re

def annotate_question(question):
    q_lower = question.lower()

    # 일반적인 'lookup' 조건 
    if re.search(r'\b(which|what|who|name)\b', q_lower):
        return 'Lookup'
    
    # WikiTQ 데이터셋에서 추가된
    # 'where'과 'first'가 동시에 있을 때
    if 'when' in q_lower and 'first' in q_lower:
        return 'Lookup'
    
    if re.search(r'\b(where)\b', q_lower):
        return 'Lookup'

    # 일반적인 'Aggregation' 조건 
    if re.search(r'\bhow many\b', q_lower) or \
       re.search(r'\b(total|sum|count|average)\b', q_lower):
        return 'Aggregation'
    
    # WikiTQ 데이터셋에서 추가된 Aggregation 조건 
    if re.search(r'\b(highest|lowest|most|least|fastest|slowest|largest|fewest)\b', q_lower):
        return 'Aggregation'
    
    if re.search(r'\bnumber of\b', q_lower):
        return 'Aggregation'
    
    if re.search(r'\b(is|were|was|does|did)\b', q_lower):
        return 'Aggregation'
    
    if re.search(r'\bmore or less\b', q_lower):
        return 'Aggregation'
    
    return 'Other'

In [6]:
# Load Datasets
datasets = load_dataset("wikitablequestions", trust_remote_code=True)

In [7]:
train = pd.DataFrame(datasets['train'])
val = pd.DataFrame(datasets['validation'])
test = pd.DataFrame(datasets['test'])

In [8]:
print(len(train))
print(len(test))
print(len(val))

11321
4344
2831


In [9]:
tokenizer = AutoTokenizer.from_pretrained("neulab/omnitab-large")

# 결과 저장용 딕셔너리
results = {}
over_1024 = {}   
under_1024 = {}

# 각 데이터셋(train, test, validation)에 대해 처리
for split in ["train", "test", "validation"]:
    data = datasets[split]

    # 각 테이블의 토큰 개수 계산
    token_counts = []
    under_1024[split] = []
    over_1024[split] = []  # split별 리스트 초기화

    for sample in data:
        table = sample["table"]  # 테이블 데이터 가져오기

        # Pandas DataFrame 변환
        df_table = pd.DataFrame(table["rows"], columns=table["header"])

        # TAPEX(OmniTab)는 DataFrame과 질문을 함께 입력받아야 함
        tokenized = tokenizer(table=df_table, query=sample["question"], truncation=False)

        # input_ids가 리스트인지 확인 후 길이 측정
        token_count = len(tokenized["input_ids"]) if isinstance(tokenized["input_ids"], list) else tokenized["input_ids"]
        token_counts.append(token_count)

        # 1024 초과인 경우 저장
        if token_count > 1024:
            over_1024[split].append(sample)
        else: 
            under_1024[split].append(sample)

Token indices sequence length is longer than the specified maximum sequence length for this model (1381 > 1024). Running this sequence through the model will result in indexing errors


In [11]:
print({k: len(v) for k, v in over_1024.items()})
print({k: len(v) for k, v in under_1024.items()})

{'train': 1987, 'test': 607, 'validation': 496}
{'train': 9334, 'test': 3737, 'validation': 2335}


In [47]:
random.seed(42)
over_train = random.sample(over_1024['train'], 1000)

In [34]:
output_dir = 'C:/Users/B6313/tableqa_wiki/wikiTQ_json'
os.makedirs(output_dir, exist_ok=True)

In [None]:
for split in ['train', 'validation', 'test']:
    # over_1024 저장
    over_path = os.path.join(output_dir, f'over_{split}.json')
    with open(over_path, 'w', encoding='utf-8') as f:
        json.dump(over_1024[split], f, ensure_ascii=False, indent=2)

    # under_1024 저장
    under_path = os.path.join(output_dir, f'under_{split}.json')
    with open(under_path, 'w', encoding='utf-8') as f:
        json.dump(under_1024[split], f, ensure_ascii=False, indent=2)

In [7]:
over_1024 ={}
under_1024 = {}

over_1024_df = {}
under_1024_df = {}

for split in ['train', 'validation', 'test']:
    # over_1024 불러오기
    over_path = os.path.join(output_dir, f'over_{split}.json')
    with open(over_path, 'r', encoding='utf-8') as f:
        over_1024[split] = json.load(f)
        over_1024_df[split] = pd.DataFrame(over_1024[split])
    
    # under_1024 불러오기
    under_path = os.path.join(output_dir, f'under_{split}.json')
    with open(under_path, 'r', encoding='utf-8') as f:
        under_1024[split] = json.load(f)
        under_1024_df[split] = pd.DataFrame(under_1024[split])

In [8]:
for split in ['train', 'validation', 'test']:
    under_1024_df[split]['annotation'] = under_1024_df[split]['question'].apply(annotate_question)

for split in ['train', 'validation', 'test']:
    print(f"[under_1024] {split}: \n{under_1024_df[split]['annotation'].value_counts()}")

[under_1024] train: 
annotation
Lookup         5917
Aggregation    3235
Other           182
Name: count, dtype: int64
[under_1024] validation: 
annotation
Lookup         1454
Aggregation     813
Other            68
Name: count, dtype: int64
[under_1024] test: 
annotation
Lookup         2328
Aggregation    1319
Other            90
Name: count, dtype: int64


In [21]:
under_1024_df['test'][['table', 'question', 'annotation']][:10]

Unnamed: 0,table,question,annotation
0,"{'header': ['Rank', 'Cyclist', 'Team', 'Time',...",which country had the most cyclists finish wit...,Lookup
1,"{'header': ['Description Losses', '1939/40', '...",how many people were murdered in 1940/41?,Aggregation
2,"{'header': ['Year', 'Division', 'League', 'Reg...",how long did it take for the new york american...,Aggregation
3,"{'header': ['Series #', 'Season #', 'Title', '...",alfie's birthday party aired on january 19. wh...,Lookup
4,"{'header': ['Date', 'Competition', 'Location',...",what is the number of 1st place finishes acros...,Lookup
5,"{'header': ['Year', 'Competition', 'Venue', 'P...",in which competition did hopley finish fist?,Lookup
6,"{'header': ['Year', 'Film', 'Role', 'Language'...",what is the total number of films with the lan...,Lookup
7,"{'header': ['Game', 'Day', 'Date', 'Kickoff', ...",what was the number of people attending the to...,Lookup
8,"{'header': ['Year', 'Kit Manufacturer', 'Shirt...",what time period had no shirt sponsor?,Lookup
9,"{'header': ['Year', 'Competition', 'Venue', 'P...",when was his first 1st place record?,Aggregation


In [38]:
# under_1024_df['train'][under_1024_df['train']['annotation'].str.contains('Other', case=False, na=False)][10:]

In [9]:
filtered_under_1024_df = {}

for split in ['train', 'validation', 'test']:
    df = under_1024_df[split]
    filtered_under_1024_df[split] = df[df['annotation'].isin(['Lookup', 'Aggregation'])]

for split in ['train', 'validation', 'test']:
    print(f"[filtered_under_1024_df] {split}: \n{filtered_under_1024_df[split]['annotation'].value_counts()}")

[filtered_under_1024_df] train: 
annotation
Lookup         5917
Aggregation    3235
Name: count, dtype: int64
[filtered_under_1024_df] validation: 
annotation
Lookup         1454
Aggregation     813
Name: count, dtype: int64
[filtered_under_1024_df] test: 
annotation
Lookup         2328
Aggregation    1319
Name: count, dtype: int64


---

1. 라벨 None을 'annotation'열에서 Lookup으로 바꾸기
2. Lookup을 제외한 나머지 라벨은 'annotation'열에서 Aggregation으로 바꾸기
3. 개수 차이나는지 확인해보기

In [10]:
wikisql = load_dataset("wikisql", trust_remote_code=True)

In [11]:
print({k: len(v) for k, v in wikisql.items()})

{'test': 15878, 'validation': 8421, 'train': 56355}


In [12]:
wikisql_df = {}

for split in ['train', 'validation', 'test']:
    wikisql_df[split] = pd.DataFrame(wikisql[split])
    wikisql_df[split]['annotation'] = wikisql_df[split]['sql'].apply(
        lambda x: 'Lookup' if x['agg'] == 0 else 'Aggregation'
    )

In [13]:
for split in ['train', 'validation', 'test']:
    print(f"[wikisql_df] {split}: \n{wikisql_df[split]['annotation'].value_counts()}")

[wikisql_df] train: 
annotation
Lookup         40606
Aggregation    15749
Name: count, dtype: int64
[wikisql_df] validation: 
annotation
Lookup         6017
Aggregation    2404
Name: count, dtype: int64
[wikisql_df] test: 
annotation
Lookup         11324
Aggregation     4554
Name: count, dtype: int64


4. 개수 차이가 나므로 비율 정해서 언더샘플링 

In [14]:
def agg_ratio(df):
    return df['annotation'].value_counts() / df['annotation'].value_counts().sum()

for split in ['train', 'validation', 'test']:
    print(f"[{split}]")
    print(agg_ratio(wikisql_df[split]), end="\n\n")

[train]
annotation
Lookup         0.720539
Aggregation    0.279461
Name: count, dtype: float64

[validation]
annotation
Lookup         0.714523
Aggregation    0.285477
Name: count, dtype: float64

[test]
annotation
Lookup         0.713188
Aggregation    0.286812
Name: count, dtype: float64



In [15]:
def undersample(df, target_col='annotation', target_class='Lookup', target_ratio=0.35, random_state=7):
    # 클래스별로 분리
    majority = df[df[target_col] == target_class]
    others = df[df[target_col] != target_class]

    # 목표 비율에 맞게 클래스 'Lookup'에서 일부만 샘플링
    target_n = int(len(df) * target_ratio)
    #print(f'traget_n :{target_n}')
    sampled_majority = majority.sample(n=target_n, random_state=random_state)

    # 합치기
    balanced_df = pd.concat([sampled_majority, others], axis=0).sample(frac=1, random_state=random_state).reset_index(drop=True) # frac=1 → 전체 행을 다 섞음 (shuffle)
    return balanced_df

In [16]:
wikisql_balanced_df = {}

for split in ['train', 'validation', 'test']:
    wikisql_balanced_df[split] = undersample(wikisql_df[split])

In [17]:
for split in ['train', 'validation', 'test']:
    print(f"[{split}]")
    print(wikisql_balanced_df[split]['annotation'].value_counts().round(2), "\n")

[train]
annotation
Lookup         19724
Aggregation    15749
Name: count, dtype: int64 

[validation]
annotation
Lookup         2947
Aggregation    2404
Name: count, dtype: int64 

[test]
annotation
Lookup         5557
Aggregation    4554
Name: count, dtype: int64 



5. wikisql 데이터와 wikiTableQuestion 데이터 join

In [18]:
print(filtered_under_1024_df['train'].columns)
print(wikisql_balanced_df['train'].columns)

Index(['id', 'question', 'answers', 'table', 'annotation'], dtype='object')
Index(['phase', 'question', 'table', 'sql', 'annotation'], dtype='object')


In [26]:
combined = {}

for split in ['train', 'validation', 'test']:
    df1 = filtered_under_1024_df[split][['question', 'table', 'annotation']].copy()
    df2 = wikisql_balanced_df[split][['question', 'table', 'annotation']].copy()

    # 필요하다면 source 구분
    df1['source'] = 'wikitq'
    df2['source'] = 'wikisql'

    for df in [df1, df2]:
        df['annotation_num'] = df['annotation'].map({'Lookup':0, 'Aggregation':1})

    combined[split] = pd.concat([df1, df2], axis=0).reset_index(drop=True)

In [40]:
for split in ['train', 'validation', 'test']:
    combined[split]['header'] = combined[split]['table'].apply(lambda x: x['header'])

In [42]:
for split in ['train', 'validation', 'test']:
    file_path = os.path.join(output_dir, f'combined_{split}.json')
    combined[split].to_json(file_path, force_ascii=False, orient='records', indent=2)

In [36]:
combined = {}

for split in ['train', 'validation', 'test']:
    file_path = os.path.join(output_dir, f'combined_{split}.json')
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
        combined[split] = pd.DataFrame(data)

In [39]:
combined['validation']

Unnamed: 0,question,table,annotation,source,annotation_num,header
0,which team won previous to crettyard?,"{'header': ['Team', 'County', 'Wins', 'Years w...",Lookup,wikitq,0,"[Team, County, Wins, Years won]"
1,how many more passengers flew to los angeles t...,"{'header': ['Rank', 'City', 'Passengers', 'Ran...",Aggregation,wikitq,1,"[Rank, City, Passengers, Ranking, Airline]"
2,after winning on four credits with a full hous...,"{'header': ['Hand', '1 credit', '2 credits', '...",Lookup,wikitq,0,"[Hand, 1 credit, 2 credits, 3 credits, 4 credi..."
3,which players played the same position as ardo...,"{'header': ['No.', 'Player', 'Birth Date', 'We...",Lookup,wikitq,0,"[No., Player, Birth Date, Weight, Height, Posi..."
4,what was the venue when he placed first?,"{'header': ['Year', 'Competition', 'Venue', 'P...",Lookup,wikitq,0,"[Year, Competition, Venue, Position, Notes]"
...,...,...,...,...,...,...
7613,"How much Gain has a Long of 29, and an Avg/G s...","{'header': ['Name', 'Gain', 'Loss', 'Long', 'A...",Aggregation,wikisql,1,"[Name, Gain, Loss, Long, Avg/G]"
7614,What is the chapter for Illinois Wesleyan?,"{'header': ['Chapter', 'Installation Date', 'I...",Lookup,wikisql,0,"[Chapter, Installation Date, Institution, Loca..."
7615,What is the score when the tie is 9?,"{'header': ['Tie no', 'Home team', 'Score', 'A...",Lookup,wikisql,0,"[Tie no, Home team, Score, Away team, Date]"
7616,Name the D 47 when it has a D 45 of d 32,"{'header': ['D 48', 'D 47', 'D 46', 'D 45', 'D...",Lookup,wikisql,0,"[D 48, D 47, D 46, D 45, D 44, D 43, D 42, D 41]"


In [40]:
# wikiTQ, sql 데이터 섞기
for split in ['train', 'validation', 'test']:
    combined[split] = combined[split].sample(frac=1, random_state=7).reset_index(drop=True)

6. 분류기 생성

In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from easydict import EasyDict
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
import pandas as pd
from easydict import EasyDict
import gzip
import json
from sklearn.metrics import accuracy_score
import numpy as np
import evaluate
from transformers import Trainer, TrainingArguments
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_validate
import os
from transformers import EarlyStoppingCallback
import random
from torch.utils.data import Dataset
from accelerate import Accelerator
from transformers import DataCollatorWithPadding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch import cuda
device = 'cuda' if cuda.is_available() else 'cpu'
device

'cuda'

In [45]:
# 토크나이저를 쉽게 처리하기 위해 json 파일로 저장
def convert_to_jsonl(df, out_path):
    with gzip.open(out_path, 'wt', encoding='utf-8') as f:
        for i in range(len(df)):

            # header안에 리스트인 경우 문자로 변환해서 *로 합쳐줘야함 
            header = df['header'][i]
            if isinstance(header, list):
                header = [str(h) for h in header]
            else:
                header = str(header)

            # 라벨값이 numpy이면 json.dumps가 처리하지 못함
            label = df['annotation_num'][i]
            if isinstance(label, (np.integer, np.int64, np.int32)):
                label = int(label)

            item = {
                'id': i,
                'query': df['question'][i],
                'header': ' * '.join(header),
                'label': label,
                'category' : df['annotation'][i]
            }

            f.write(json.dumps(item) + '\n')

In [46]:
# def preprocess(example):
#     return tokenizer(example['query'], example['header'], 
#                      #return_tensors='pt', 
#                      truncation=True, 
#                      padding='max_length', # 최대길이가 안되면 나머지 0으로 채움
#                      max_length=128) # 문장 최대 길이

def preprocess(example):
    encoding = tokenizer(example['query'], example['header'], truncation=True)
    encoding['labels'] = example['label']
    return  encoding

# def preprocess(example):
#     encoding = tokenizer(example['query'], truncation=True)
#     encoding['labels'] = example['label']
#     return  encoding

In [47]:
output_path = "C:/Users/B6313/tableqa_wiki/wiki_cls_dataset/"

for split in ['train', 'validation', 'test']:
    convert_to_jsonl(combined[split], os.path.join(output_path, f'{split}.jsonl.gz'))

In [48]:
output_path = "C:/Users/B6313/tableqa_wiki/wiki_cls_dataset/"
file_check = os.path.join(output_path + 'test.jsonl.gz')
with gzip.open(file_check, 'rt', encoding='utf-8') as f:
    for i, line in enumerate(f):
        print(line)
        if i > 2:
            break

{"id": 0, "query": "What was the air date of part 2 of the episode whose part 1 was aired on January 31, 2008?", "header": "Episode # * Title * Part 1 * Part 2 * Part 3 * Part 4 * Part 5 * Part 6", "label": 0, "category": "Lookup"}

{"id": 1, "query": "The candidate who received 133 votes in the Bronx won what percentage overall?", "header": "1921 * party * Manhattan * The Bronx * Brooklyn * Queens * Richmond [Staten Is.] * Total * %", "label": 0, "category": "Lookup"}

{"id": 2, "query": "What is the lowest 2004 population when there were 5158 households?", "header": "Name * Geographic code * Type * Households * Population (2004) * Foreign population * Moroccan population", "label": 1, "category": "Aggregation"}

{"id": 3, "query": "What is the average number of innings with more than 3 in the 100s category?", "header": "Player * Team * Matches * Innings * Runs * Average * Highest Score * 100s", "label": 1, "category": "Aggregation"}



In [49]:
dataset = load_dataset('json', data_files={
    'train': os.path.join(output_path, 'train.jsonl.gz'),
    'validation': os.path.join(output_path,'validation.jsonl.gz'),
    'test': os.path.join(output_path ,'test.jsonl.gz')
})

Generating train split: 44625 examples [00:00, 885776.69 examples/s]
Generating validation split: 7618 examples [00:00, 441810.92 examples/s]
Generating test split: 13758 examples [00:00, 959227.94 examples/s]


In [62]:
((44625/8) * 10) * 0.05

2789.0625

In [50]:
dataset['train'][0]

{'id': 0,
 'query': 'name someone else from scotland inducted before alan brazil.',
 'header': 'Season * Level * Name * Position * Nationality * International\\ncaps',
 'label': 0,
 'category': 'Lookup'}

In [51]:
model_path = 'google-bert/bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenized_dataset = dataset.map(preprocess, batched=True)

Map: 100%|██████████| 44625/44625 [00:01<00:00, 39864.62 examples/s]
Map: 100%|██████████| 7618/7618 [00:00<00:00, 36033.18 examples/s]
Map: 100%|██████████| 13758/13758 [00:00<00:00, 16387.55 examples/s]


In [28]:
print(tokenized_dataset['test'][:20])

{'id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], 'query': ['which country had the most cyclists finish within the top 10?', 'how many people were murdered in 1940/41?', 'how long did it take for the new york americans to win the national cup after 1936?', "alfie's birthday party aired on january 19. what was the airdate of the next episode?", 'what is the number of 1st place finishes across all events?', 'in which competition did hopley finish fist?', 'what is the total number of films with the language of kannada listed?', 'what was the number of people attending the toros mexico vs. monterrey flash game?', 'what time period had no shirt sponsor?', 'when was his first 1st place record?', 'does pat or john have the highest total?', 'what is the combined score of year end rankings before 2009?', 'how many more ships were wrecked in lake huron than in erie?', 'what was the total number of points scored by the tide in the last 3 games combined.', 'who came im

In [52]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [53]:
labels = ["Lookup", "Aggregation"]
Num_labels = len(labels)
id2label = {id:label for id, label in enumerate(labels)}
label2id = {label:id for id, label in enumerate(labels)}

In [54]:
id2label

{0: 'Lookup', 1: 'Aggregation'}

In [55]:
label2id

{'Lookup': 0, 'Aggregation': 1}

In [56]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_path, num_labels=2, id2label=id2label, label2id=label2id
)

model.to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e

In [57]:
# load metrics
accuracy = evaluate.load("accuracy")
f1_score = evaluate.load("f1")

def compute_metrics(eval_pred):
    # eval_pred = (predictions, labels)
    predictions, labels = eval_pred

    # 다중분류
    # 특정 i라벨의 확률 = 특정 i 라벨의 승산/모든 라벨의 승산 
    # predictions = [batch_size, num_labels]
    #probabilities = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)

    #positive_class_probs = probabilities[:, 1] # 클래스 1일 확률 

    # compute auc
    #auc = np.round(auc_score.compute(prediction_scores=positive_class_probs,
    #                reference=labels)['roc_auc'],3)

    # 가장 로짓이 큰 라벨 추출
    predicted_classes = np.argmax(predictions, axis=1)

    # compute accuracy
    acc = np.round(accuracy.compute(predictions=predicted_classes, 
                                     references=labels)['accuracy'],3)
    
    f1 = np.round(f1_score.compute(predictions=predicted_classes, references=labels, average='macro')['f1'], 3) #  라벨별 f1-score를 산술평균한 것 : 현재 라벨의 갯수가 같아서 이렇게 써도 된다고 판단
    
    return {"Accuracy": acc, "f1": f1}

In [None]:
#hyperparameters
lr = 3e-5
num_epochs = 10

training_args = TrainingArguments(
    output_dir='C:/Users/B6313/tableqa_wiki/bert-agg',
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=num_epochs,
    # Number of steps used for a linear warmup
    warmup_steps=2800,                
    weight_decay=0.01,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    report_to='none',
    fp16=True,
    #metric_for_best_model="f1",
    dataloader_num_workers=4,
)

In [32]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    tokenizer=tokenizer,
    data_collator = data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1
1,0.3253,0.29436,0.912,0.911
2,0.3182,0.29904,0.921,0.919
3,0.6654,0.68147,0.578,0.366


KeyboardInterrupt: 

In [None]:
preds = trainer.predict(tokenized_dataset['test'])

logits = preds.predictions
labels = pred.label_ids

metrics = compute_metrics((logits, labels))
print(metrics)