In [2]:
import pandas as pd
import torch
import numpy as np
import torch.nn as nn
#from sklearn.model_selection import train_test_split
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, ElectraForSequenceClassification, AdamW
from tqdm.notebook import tqdm

#device = torch.device("cuda:0")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class SOFDataset(Dataset):
  
  def __init__(self, csv_file):
    # 일부 값중에 NaN이 있음...
    self.dataset = pd.read_csv(csv_file).dropna(axis=0) 
    
    # 중복제거
    self.dataset.drop_duplicates(subset=['Body'], inplace=True)
    
    # Y 값 숫자로 레이블링
    
    self.dataset['Y'] = self.dataset['Y'].map({'LQ_CLOSE':0, 
                                      'LQ_EDIT':1,
                                      'HQ':2})
    
    self.tokenizer = AutoTokenizer.from_pretrained('google/electra-small-discriminator')

    print(self.dataset.describe())
  
  def __len__(self):
    return len(self.dataset)
  
  def __getitem__(self, idx):
    row = self.dataset.iloc[idx,[2,5]].values # 데이터셋 따라 idx 변경
    text = row[0]
    y = row[1]

    inputs = self.tokenizer(
        text, 
        return_tensors='pt',
        truncation=True,
        max_length=512,
        pad_to_max_length=True,
        add_special_tokens=True
        )
    
    input_ids = inputs['input_ids'][0]
    attention_mask = inputs['attention_mask'][0]

    return input_ids, attention_mask, y

In [4]:
train_dataset = SOFDataset("/home/harock96/hw/project/data/train.csv")
valid_dataset = SOFDataset("/home/harock96/hw/project/data/valid.csv")

                 Id             Y
count  4.500000e+04  45000.000000
mean   4.575616e+07      1.000000
std    7.120035e+06      0.816506
min    3.455266e+07      0.000000
25%    3.973593e+07      0.000000
50%    4.503563e+07      1.000000
75%    5.125584e+07      2.000000
max    6.046802e+07      2.000000
                 Id             Y
count  1.500000e+04  15000.000000
mean   4.576106e+07      1.000000
std    7.139021e+06      0.816524
min    3.455297e+07      0.000000
25%    3.980180e+07      0.000000
50%    4.502755e+07      1.000000
75%    5.130322e+07      2.000000
max    6.047032e+07      2.000000


In [5]:
#model

model = ElectraForSequenceClassification.from_pretrained('google/electra-small-discriminator', num_labels = 3)
model = nn.DataParallel(model)
model.to(device)

Some weights of the model checkpoint at google/electra-small-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-small-discriminator and are newly initialized: ['classifier

DataParallel(
  (module): ElectraForSequenceClassification(
    (electra): ElectraModel(
      (embeddings): ElectraEmbeddings(
        (word_embeddings): Embedding(30522, 128, padding_idx=0)
        (position_embeddings): Embedding(512, 128)
        (token_type_embeddings): Embedding(2, 128)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (embeddings_project): Linear(in_features=128, out_features=256, bias=True)
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output)

In [5]:
#text, attention_mask, y = train_dataset[0]
#model(text.unsqueeze(0).to(device), attention_mask=attention_mask.unsqueeze(0).to(device))

In [6]:
# parameter

epochs = 20
batch_size = 32

optimizer = AdamW(model.parameters(), lr=1e-5)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)

In [9]:
# Learn

losses = []
accuracies = []

for i in range(epochs):
  total_loss = 0.0
  correct = 0
  total = 0
  batches = 0

  model.train()

  for input_ids_batch, attention_masks_batch, y_batch in tqdm(train_loader):
    optimizer.zero_grad()
    y_batch = y_batch.to(device)
    y_pred = model(input_ids_batch.to(device), attention_mask=attention_masks_batch.to(device))[0]
    loss = F.cross_entropy(y_pred, y_batch)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    _, predicted = torch.max(y_pred, 1)
    correct += (predicted == y_batch).sum()
    total += len(y_batch)

    batches += 1
    if batches % 100 == 0:
      print("Batch Loss:", total_loss, "Accuracy:", correct.float() / total)
  
  losses.append(total_loss)
  accuracies.append(correct.float() / total)
  print(i+1, "Train Loss:", total_loss, "Accuracy:", correct.float() / total)

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 36.378212720155716 Accuracy: tensor(0.8547, device='cuda:0')
Batch Loss: 69.0741669088602 Accuracy: tensor(0.8619, device='cuda:0')
Batch Loss: 101.04212884604931 Accuracy: tensor(0.8643, device='cuda:0')
Batch Loss: 132.2782952412963 Accuracy: tensor(0.8658, device='cuda:0')
Batch Loss: 162.7401693239808 Accuracy: tensor(0.8668, device='cuda:0')
Batch Loss: 192.42131755501032 Accuracy: tensor(0.8686, device='cuda:0')
Batch Loss: 219.419739253819 Accuracy: tensor(0.8709, device='cuda:0')
Batch Loss: 248.96859212219715 Accuracy: tensor(0.8714, device='cuda:0')
Batch Loss: 276.86037239432335 Accuracy: tensor(0.8727, device='cuda:0')
Batch Loss: 303.42989887297153 Accuracy: tensor(0.8738, device='cuda:0')
Batch Loss: 330.01489901542664 Accuracy: tensor(0.8748, device='cuda:0')
Batch Loss: 357.93020837008953 Accuracy: tensor(0.8751, device='cuda:0')
Batch Loss: 385.09811802208424 Accuracy: tensor(0.8752, device='cuda:0')
Batch Loss: 411.10993602871895 Accuracy: tensor(0.8762, d

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 24.903156094253063 Accuracy: tensor(0.8991, device='cuda:0')
Batch Loss: 50.12509021162987 Accuracy: tensor(0.8945, device='cuda:0')
Batch Loss: 74.17527192085981 Accuracy: tensor(0.8957, device='cuda:0')
Batch Loss: 97.76129484921694 Accuracy: tensor(0.8977, device='cuda:0')
Batch Loss: 121.04342135041952 Accuracy: tensor(0.8993, device='cuda:0')
Batch Loss: 143.11332079023123 Accuracy: tensor(0.9012, device='cuda:0')
Batch Loss: 166.36610755324364 Accuracy: tensor(0.9017, device='cuda:0')
Batch Loss: 188.4811251387 Accuracy: tensor(0.9021, device='cuda:0')
Batch Loss: 210.72051333636045 Accuracy: tensor(0.9024, device='cuda:0')
Batch Loss: 233.2149366363883 Accuracy: tensor(0.9027, device='cuda:0')
Batch Loss: 256.47568825259805 Accuracy: tensor(0.9029, device='cuda:0')
Batch Loss: 279.6619789302349 Accuracy: tensor(0.9032, device='cuda:0')
Batch Loss: 300.4759212806821 Accuracy: tensor(0.9043, device='cuda:0')
Batch Loss: 322.1124983392656 Accuracy: tensor(0.9053, device

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 19.4992934204638 Accuracy: tensor(0.9216, device='cuda:0')
Batch Loss: 38.37700317054987 Accuracy: tensor(0.9219, device='cuda:0')
Batch Loss: 58.208843395113945 Accuracy: tensor(0.9218, device='cuda:0')
Batch Loss: 77.86080397851765 Accuracy: tensor(0.9214, device='cuda:0')
Batch Loss: 96.530870558694 Accuracy: tensor(0.9221, device='cuda:0')
Batch Loss: 117.43146347440779 Accuracy: tensor(0.9203, device='cuda:0')
Batch Loss: 136.64766944013536 Accuracy: tensor(0.9209, device='cuda:0')
Batch Loss: 155.91321725584567 Accuracy: tensor(0.9216, device='cuda:0')
Batch Loss: 176.81898847408593 Accuracy: tensor(0.9202, device='cuda:0')
Batch Loss: 196.75874837301672 Accuracy: tensor(0.9201, device='cuda:0')
Batch Loss: 216.29054719768465 Accuracy: tensor(0.9202, device='cuda:0')
Batch Loss: 234.66062288545072 Accuracy: tensor(0.9208, device='cuda:0')
Batch Loss: 253.06105976365507 Accuracy: tensor(0.9212, device='cuda:0')
Batch Loss: 271.529586719349 Accuracy: tensor(0.9215, devi

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 17.609430976212025 Accuracy: tensor(0.9287, device='cuda:0')
Batch Loss: 35.94367425143719 Accuracy: tensor(0.9267, device='cuda:0')
Batch Loss: 53.764523131772876 Accuracy: tensor(0.9288, device='cuda:0')
Batch Loss: 70.37184645049274 Accuracy: tensor(0.9296, device='cuda:0')
Batch Loss: 89.30865679122508 Accuracy: tensor(0.9271, device='cuda:0')
Batch Loss: 106.69580542109907 Accuracy: tensor(0.9273, device='cuda:0')
Batch Loss: 125.06562422774732 Accuracy: tensor(0.9271, device='cuda:0')
Batch Loss: 142.96553320623934 Accuracy: tensor(0.9269, device='cuda:0')
Batch Loss: 161.58422280289233 Accuracy: tensor(0.9264, device='cuda:0')
Batch Loss: 178.23074856586754 Accuracy: tensor(0.9271, device='cuda:0')
Batch Loss: 194.81189963780344 Accuracy: tensor(0.9274, device='cuda:0')
Batch Loss: 211.2493251543492 Accuracy: tensor(0.9283, device='cuda:0')
Batch Loss: 228.62199896760285 Accuracy: tensor(0.9283, device='cuda:0')
Batch Loss: 245.95419809035957 Accuracy: tensor(0.9289,

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 15.848745374009013 Accuracy: tensor(0.9362, device='cuda:0')
Batch Loss: 31.148051330819726 Accuracy: tensor(0.9397, device='cuda:0')
Batch Loss: 46.15767671726644 Accuracy: tensor(0.9397, device='cuda:0')
Batch Loss: 63.12989489547908 Accuracy: tensor(0.9375, device='cuda:0')
Batch Loss: 77.72773793712258 Accuracy: tensor(0.9382, device='cuda:0')
Batch Loss: 94.03611617162824 Accuracy: tensor(0.9380, device='cuda:0')
Batch Loss: 109.44699737615883 Accuracy: tensor(0.9379, device='cuda:0')
Batch Loss: 124.33246422186494 Accuracy: tensor(0.9382, device='cuda:0')
Batch Loss: 140.81883189454675 Accuracy: tensor(0.9373, device='cuda:0')
Batch Loss: 157.2132716178894 Accuracy: tensor(0.9366, device='cuda:0')
Batch Loss: 174.17990224063396 Accuracy: tensor(0.9359, device='cuda:0')
Batch Loss: 190.92478904128075 Accuracy: tensor(0.9352, device='cuda:0')
Batch Loss: 206.36383444443345 Accuracy: tensor(0.9354, device='cuda:0')
Batch Loss: 221.03681576624513 Accuracy: tensor(0.9358, 

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 13.892812365666032 Accuracy: tensor(0.9444, device='cuda:0')
Batch Loss: 27.61667506955564 Accuracy: tensor(0.9444, device='cuda:0')
Batch Loss: 42.91005669720471 Accuracy: tensor(0.9428, device='cuda:0')
Batch Loss: 57.15802917256951 Accuracy: tensor(0.9426, device='cuda:0')
Batch Loss: 70.59380289167166 Accuracy: tensor(0.9426, device='cuda:0')
Batch Loss: 84.77500891685486 Accuracy: tensor(0.9437, device='cuda:0')
Batch Loss: 99.93293638899922 Accuracy: tensor(0.9433, device='cuda:0')
Batch Loss: 113.31383410096169 Accuracy: tensor(0.9438, device='cuda:0')
Batch Loss: 127.14275669399649 Accuracy: tensor(0.9441, device='cuda:0')
Batch Loss: 142.8156818812713 Accuracy: tensor(0.9436, device='cuda:0')
Batch Loss: 158.7767598265782 Accuracy: tensor(0.9429, device='cuda:0')
Batch Loss: 172.9713108977303 Accuracy: tensor(0.9430, device='cuda:0')
Batch Loss: 187.34386711474508 Accuracy: tensor(0.9429, device='cuda:0')
Batch Loss: 202.3243422890082 Accuracy: tensor(0.9428, devic

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 12.300424614921212 Accuracy: tensor(0.9509, device='cuda:0')
Batch Loss: 24.350000127218664 Accuracy: tensor(0.9539, device='cuda:0')
Batch Loss: 36.93630543258041 Accuracy: tensor(0.9525, device='cuda:0')
Batch Loss: 48.6908105192706 Accuracy: tensor(0.9528, device='cuda:0')
Batch Loss: 60.17556328698993 Accuracy: tensor(0.9534, device='cuda:0')
Batch Loss: 72.29731773771346 Accuracy: tensor(0.9533, device='cuda:0')
Batch Loss: 85.2489193379879 Accuracy: tensor(0.9527, device='cuda:0')
Batch Loss: 97.88322128821164 Accuracy: tensor(0.9523, device='cuda:0')
Batch Loss: 110.08706947881728 Accuracy: tensor(0.9523, device='cuda:0')
Batch Loss: 123.52410671208054 Accuracy: tensor(0.9516, device='cuda:0')
Batch Loss: 136.88112787809223 Accuracy: tensor(0.9512, device='cuda:0')
Batch Loss: 150.08635599073023 Accuracy: tensor(0.9507, device='cuda:0')
Batch Loss: 164.0058828080073 Accuracy: tensor(0.9503, device='cuda:0')
Batch Loss: 176.31829942483455 Accuracy: tensor(0.9506, devi

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 10.989136166870594 Accuracy: tensor(0.9600, device='cuda:0')
Batch Loss: 21.756204136647284 Accuracy: tensor(0.9591, device='cuda:0')
Batch Loss: 32.296462340280414 Accuracy: tensor(0.9586, device='cuda:0')
Batch Loss: 43.40807876177132 Accuracy: tensor(0.9578, device='cuda:0')
Batch Loss: 55.233390813693404 Accuracy: tensor(0.9568, device='cuda:0')
Batch Loss: 66.66851504612714 Accuracy: tensor(0.9570, device='cuda:0')
Batch Loss: 78.13076629396528 Accuracy: tensor(0.9568, device='cuda:0')
Batch Loss: 89.586644154042 Accuracy: tensor(0.9570, device='cuda:0')
Batch Loss: 101.9172126268968 Accuracy: tensor(0.9560, device='cuda:0')
Batch Loss: 113.08610440790653 Accuracy: tensor(0.9561, device='cuda:0')
Batch Loss: 123.49144588410854 Accuracy: tensor(0.9563, device='cuda:0')
Batch Loss: 133.6645186925307 Accuracy: tensor(0.9568, device='cuda:0')
Batch Loss: 145.7320323260501 Accuracy: tensor(0.9565, device='cuda:0')
Batch Loss: 156.90296548046172 Accuracy: tensor(0.9567, devi

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 8.85404459759593 Accuracy: tensor(0.9641, device='cuda:0')
Batch Loss: 17.608056265395135 Accuracy: tensor(0.9655, device='cuda:0')
Batch Loss: 28.02081694966182 Accuracy: tensor(0.9638, device='cuda:0')
Batch Loss: 38.11743744648993 Accuracy: tensor(0.9633, device='cuda:0')
Batch Loss: 47.96537497360259 Accuracy: tensor(0.9634, device='cuda:0')
Batch Loss: 58.01044191792607 Accuracy: tensor(0.9630, device='cuda:0')
Batch Loss: 68.41587901674211 Accuracy: tensor(0.9625, device='cuda:0')
Batch Loss: 77.10506263561547 Accuracy: tensor(0.9633, device='cuda:0')
Batch Loss: 88.01095296349376 Accuracy: tensor(0.9628, device='cuda:0')
Batch Loss: 97.69324190309271 Accuracy: tensor(0.9631, device='cuda:0')
Batch Loss: 108.85984550230205 Accuracy: tensor(0.9629, device='cuda:0')
Batch Loss: 118.53908379562199 Accuracy: tensor(0.9631, device='cuda:0')
Batch Loss: 129.04543758463115 Accuracy: tensor(0.9631, device='cuda:0')
Batch Loss: 139.9717459450476 Accuracy: tensor(0.9629, device

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 7.676516512874514 Accuracy: tensor(0.9712, device='cuda:0')
Batch Loss: 15.463673420716077 Accuracy: tensor(0.9725, device='cuda:0')
Batch Loss: 24.10750490287319 Accuracy: tensor(0.9712, device='cuda:0')
Batch Loss: 32.67543426249176 Accuracy: tensor(0.9705, device='cuda:0')
Batch Loss: 41.99038269696757 Accuracy: tensor(0.9693, device='cuda:0')
Batch Loss: 51.061777336057276 Accuracy: tensor(0.9691, device='cuda:0')
Batch Loss: 60.578331905882806 Accuracy: tensor(0.9687, device='cuda:0')
Batch Loss: 68.34740996127948 Accuracy: tensor(0.9693, device='cuda:0')
Batch Loss: 76.64745203405619 Accuracy: tensor(0.9693, device='cuda:0')
Batch Loss: 86.21135571133345 Accuracy: tensor(0.9688, device='cuda:0')
Batch Loss: 94.02854804741219 Accuracy: tensor(0.9689, device='cuda:0')
Batch Loss: 102.10191636206582 Accuracy: tensor(0.9691, device='cuda:0')
Batch Loss: 110.78136647073552 Accuracy: tensor(0.9694, device='cuda:0')
Batch Loss: 119.22136450465769 Accuracy: tensor(0.9692, dev

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 6.30653291195631 Accuracy: tensor(0.9772, device='cuda:0')
Batch Loss: 13.734396295621991 Accuracy: tensor(0.9745, device='cuda:0')
Batch Loss: 21.17748061195016 Accuracy: tensor(0.9742, device='cuda:0')
Batch Loss: 29.81048355670646 Accuracy: tensor(0.9728, device='cuda:0')
Batch Loss: 37.02354456437752 Accuracy: tensor(0.9739, device='cuda:0')
Batch Loss: 43.52482303837314 Accuracy: tensor(0.9744, device='cuda:0')
Batch Loss: 51.34143751207739 Accuracy: tensor(0.9745, device='cuda:0')
Batch Loss: 59.648391623515636 Accuracy: tensor(0.9739, device='cuda:0')
Batch Loss: 66.55396603606641 Accuracy: tensor(0.9740, device='cuda:0')
Batch Loss: 74.54700234392658 Accuracy: tensor(0.9739, device='cuda:0')
Batch Loss: 82.0293931835331 Accuracy: tensor(0.9739, device='cuda:0')
Batch Loss: 88.79006615234539 Accuracy: tensor(0.9743, device='cuda:0')
Batch Loss: 96.01115062413737 Accuracy: tensor(0.9743, device='cuda:0')
Batch Loss: 102.6349013200961 Accuracy: tensor(0.9743, device='c

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 5.668959596659988 Accuracy: tensor(0.9800, device='cuda:0')
Batch Loss: 13.161183793097734 Accuracy: tensor(0.9778, device='cuda:0')
Batch Loss: 19.60829220688902 Accuracy: tensor(0.9773, device='cuda:0')
Batch Loss: 25.452220852021128 Accuracy: tensor(0.9777, device='cuda:0')
Batch Loss: 32.02819190383889 Accuracy: tensor(0.9780, device='cuda:0')
Batch Loss: 38.30642444896512 Accuracy: tensor(0.9781, device='cuda:0')
Batch Loss: 43.63116546836682 Accuracy: tensor(0.9784, device='cuda:0')
Batch Loss: 50.9511094680056 Accuracy: tensor(0.9780, device='cuda:0')
Batch Loss: 57.36935027316213 Accuracy: tensor(0.9776, device='cuda:0')
Batch Loss: 63.231621930375695 Accuracy: tensor(0.9775, device='cuda:0')
Batch Loss: 68.97908560559154 Accuracy: tensor(0.9777, device='cuda:0')
Batch Loss: 74.99064212990925 Accuracy: tensor(0.9778, device='cuda:0')
Batch Loss: 81.09817321621813 Accuracy: tensor(0.9777, device='cuda:0')
Batch Loss: 86.82045673322864 Accuracy: tensor(0.9778, device=

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 5.3080667103640735 Accuracy: tensor(0.9822, device='cuda:0')
Batch Loss: 10.920164150185883 Accuracy: tensor(0.9806, device='cuda:0')
Batch Loss: 15.783259107149206 Accuracy: tensor(0.9816, device='cuda:0')
Batch Loss: 22.051758130430244 Accuracy: tensor(0.9807, device='cuda:0')
Batch Loss: 26.544867528020404 Accuracy: tensor(0.9813, device='cuda:0')
Batch Loss: 32.05962885718327 Accuracy: tensor(0.9815, device='cuda:0')
Batch Loss: 36.862273346981965 Accuracy: tensor(0.9817, device='cuda:0')
Batch Loss: 42.377487263060175 Accuracy: tensor(0.9815, device='cuda:0')
Batch Loss: 48.87342945334967 Accuracy: tensor(0.9809, device='cuda:0')
Batch Loss: 53.13689935998991 Accuracy: tensor(0.9812, device='cuda:0')
Batch Loss: 58.77646720549092 Accuracy: tensor(0.9811, device='cuda:0')
Batch Loss: 64.18061913759448 Accuracy: tensor(0.9810, device='cuda:0')
Batch Loss: 70.60847287788056 Accuracy: tensor(0.9805, device='cuda:0')
Batch Loss: 77.20111133973114 Accuracy: tensor(0.9802, de

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 3.9913830950390548 Accuracy: tensor(0.9869, device='cuda:0')
Batch Loss: 8.030174874467775 Accuracy: tensor(0.9861, device='cuda:0')
Batch Loss: 12.069319942849688 Accuracy: tensor(0.9871, device='cuda:0')
Batch Loss: 15.985364180174656 Accuracy: tensor(0.9865, device='cuda:0')
Batch Loss: 20.243903709109873 Accuracy: tensor(0.9864, device='cuda:0')
Batch Loss: 24.083190066739917 Accuracy: tensor(0.9865, device='cuda:0')
Batch Loss: 30.120840467046946 Accuracy: tensor(0.9852, device='cuda:0')
Batch Loss: 35.58226076164283 Accuracy: tensor(0.9847, device='cuda:0')
Batch Loss: 40.119979734765366 Accuracy: tensor(0.9847, device='cuda:0')
Batch Loss: 44.63678141776472 Accuracy: tensor(0.9848, device='cuda:0')
Batch Loss: 50.00637284130789 Accuracy: tensor(0.9844, device='cuda:0')
Batch Loss: 54.75937247928232 Accuracy: tensor(0.9843, device='cuda:0')
Batch Loss: 59.41897334437817 Accuracy: tensor(0.9844, device='cuda:0')
Batch Loss: 64.44670431152917 Accuracy: tensor(0.9843, de

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 4.411635089898482 Accuracy: tensor(0.9847, device='cuda:0')
Batch Loss: 8.224202864803374 Accuracy: tensor(0.9864, device='cuda:0')
Batch Loss: 11.483223307179287 Accuracy: tensor(0.9873, device='cuda:0')
Batch Loss: 14.509611072135158 Accuracy: tensor(0.9882, device='cuda:0')
Batch Loss: 20.287119517452084 Accuracy: tensor(0.9868, device='cuda:0')
Batch Loss: 23.70126739551779 Accuracy: tensor(0.9870, device='cuda:0')
Batch Loss: 27.978679188061506 Accuracy: tensor(0.9867, device='cuda:0')
Batch Loss: 31.6266204995336 Accuracy: tensor(0.9871, device='cuda:0')
Batch Loss: 34.92487795121269 Accuracy: tensor(0.9873, device='cuda:0')
Batch Loss: 39.6123342703213 Accuracy: tensor(0.9869, device='cuda:0')
Batch Loss: 43.57816403330071 Accuracy: tensor(0.9870, device='cuda:0')
Batch Loss: 47.91541258589132 Accuracy: tensor(0.9868, device='cuda:0')
Batch Loss: 52.27656163251959 Accuracy: tensor(0.9867, device='cuda:0')
Batch Loss: 56.35823129967321 Accuracy: tensor(0.9865, device=

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 3.7147040823474526 Accuracy: tensor(0.9887, device='cuda:0')
Batch Loss: 6.889583157724701 Accuracy: tensor(0.9887, device='cuda:0')
Batch Loss: 10.418890409055166 Accuracy: tensor(0.9892, device='cuda:0')
Batch Loss: 14.77187432977371 Accuracy: tensor(0.9884, device='cuda:0')
Batch Loss: 17.96980230358895 Accuracy: tensor(0.9886, device='cuda:0')
Batch Loss: 21.899248445464764 Accuracy: tensor(0.9883, device='cuda:0')
Batch Loss: 25.09784190816572 Accuracy: tensor(0.9887, device='cuda:0')
Batch Loss: 29.212725163437426 Accuracy: tensor(0.9882, device='cuda:0')
Batch Loss: 33.72087454027496 Accuracy: tensor(0.9878, device='cuda:0')
Batch Loss: 36.29259208322037 Accuracy: tensor(0.9882, device='cuda:0')
Batch Loss: 41.19307138235308 Accuracy: tensor(0.9879, device='cuda:0')
Batch Loss: 45.257671451894566 Accuracy: tensor(0.9878, device='cuda:0')
Batch Loss: 49.117028098145965 Accuracy: tensor(0.9877, device='cuda:0')
Batch Loss: 53.04118486662628 Accuracy: tensor(0.9875, dev

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 3.3678112803027034 Accuracy: tensor(0.9872, device='cuda:0')
Batch Loss: 6.367810256895609 Accuracy: tensor(0.9897, device='cuda:0')
Batch Loss: 9.582831495907158 Accuracy: tensor(0.9899, device='cuda:0')
Batch Loss: 13.043125318188686 Accuracy: tensor(0.9895, device='cuda:0')
Batch Loss: 17.015612320334185 Accuracy: tensor(0.9888, device='cuda:0')
Batch Loss: 20.172071869543288 Accuracy: tensor(0.9888, device='cuda:0')
Batch Loss: 22.832087979710195 Accuracy: tensor(0.9893, device='cuda:0')
Batch Loss: 25.847469176806044 Accuracy: tensor(0.9895, device='cuda:0')
Batch Loss: 28.389661122928374 Accuracy: tensor(0.9897, device='cuda:0')
Batch Loss: 32.814406353398226 Accuracy: tensor(0.9891, device='cuda:0')
Batch Loss: 35.985608700546436 Accuracy: tensor(0.9889, device='cuda:0')
Batch Loss: 38.35167692450341 Accuracy: tensor(0.9892, device='cuda:0')
Batch Loss: 42.27108741673874 Accuracy: tensor(0.9891, device='cuda:0')
Batch Loss: 46.043867758184206 Accuracy: tensor(0.9890,

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 2.3556867600418627 Accuracy: tensor(0.9906, device='cuda:0')
Batch Loss: 4.623940799443517 Accuracy: tensor(0.9916, device='cuda:0')
Batch Loss: 7.284510224708356 Accuracy: tensor(0.9913, device='cuda:0')
Batch Loss: 9.850865548942238 Accuracy: tensor(0.9912, device='cuda:0')
Batch Loss: 12.381938728445675 Accuracy: tensor(0.9914, device='cuda:0')
Batch Loss: 15.143216425378341 Accuracy: tensor(0.9913, device='cuda:0')
Batch Loss: 18.02858223876683 Accuracy: tensor(0.9912, device='cuda:0')
Batch Loss: 21.538263688096777 Accuracy: tensor(0.9910, device='cuda:0')
Batch Loss: 24.57500802516006 Accuracy: tensor(0.9909, device='cuda:0')
Batch Loss: 27.258251017308794 Accuracy: tensor(0.9908, device='cuda:0')
Batch Loss: 30.045051293389406 Accuracy: tensor(0.9909, device='cuda:0')
Batch Loss: 32.2794152770075 Accuracy: tensor(0.9910, device='cuda:0')
Batch Loss: 34.763291630195454 Accuracy: tensor(0.9910, device='cuda:0')
Batch Loss: 38.54822902486194 Accuracy: tensor(0.9908, dev

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 2.192491597379558 Accuracy: tensor(0.9916, device='cuda:0')
Batch Loss: 4.682262751390226 Accuracy: tensor(0.9920, device='cuda:0')
Batch Loss: 6.580956151825376 Accuracy: tensor(0.9921, device='cuda:0')
Batch Loss: 9.088579487171955 Accuracy: tensor(0.9922, device='cuda:0')
Batch Loss: 11.770760982530192 Accuracy: tensor(0.9919, device='cuda:0')
Batch Loss: 14.374441502266563 Accuracy: tensor(0.9919, device='cuda:0')
Batch Loss: 16.813685771019664 Accuracy: tensor(0.9918, device='cuda:0')
Batch Loss: 19.17411237981287 Accuracy: tensor(0.9917, device='cuda:0')
Batch Loss: 22.326697670127032 Accuracy: tensor(0.9915, device='cuda:0')
Batch Loss: 26.273814432119252 Accuracy: tensor(0.9910, device='cuda:0')
Batch Loss: 29.02026687338366 Accuracy: tensor(0.9908, device='cuda:0')
Batch Loss: 31.753997957828688 Accuracy: tensor(0.9909, device='cuda:0')
Batch Loss: 34.715837049501715 Accuracy: tensor(0.9909, device='cuda:0')
Batch Loss: 37.45010257695685 Accuracy: tensor(0.9909, de

  0%|          | 0/1407 [00:00<?, ?it/s]

Batch Loss: 2.3841634393902496 Accuracy: tensor(0.9934, device='cuda:0')
Batch Loss: 4.32438068970805 Accuracy: tensor(0.9931, device='cuda:0')
Batch Loss: 6.7434532135375775 Accuracy: tensor(0.9927, device='cuda:0')
Batch Loss: 8.799369528249372 Accuracy: tensor(0.9929, device='cuda:0')
Batch Loss: 11.009234417113476 Accuracy: tensor(0.9929, device='cuda:0')
Batch Loss: 12.915206524310634 Accuracy: tensor(0.9931, device='cuda:0')
Batch Loss: 15.589087247033603 Accuracy: tensor(0.9926, device='cuda:0')
Batch Loss: 17.93497456255136 Accuracy: tensor(0.9927, device='cuda:0')
Batch Loss: 20.539291825785767 Accuracy: tensor(0.9923, device='cuda:0')
Batch Loss: 23.15426414652029 Accuracy: tensor(0.9924, device='cuda:0')
Batch Loss: 25.767018212587573 Accuracy: tensor(0.9924, device='cuda:0')
Batch Loss: 28.309997025702614 Accuracy: tensor(0.9922, device='cuda:0')
Batch Loss: 30.99888444028329 Accuracy: tensor(0.9921, device='cuda:0')
Batch Loss: 33.47585130634252 Accuracy: tensor(0.9920, de

In [10]:
# 모델 저장하기

torch.save(model.state_dict(), "model_electra_small.pt")

In [11]:
# 모델 불러오기

model.load_state_dict(torch.load("model_electra_small.pt"))

<All keys matched successfully>

In [12]:
# model.eval()

test_correct = 0
test_total = 0

for input_ids_batch, attention_masks_batch, y_batch in tqdm(valid_loader):
  y_batch = y_batch.to(device)
  y_pred = model(input_ids_batch.to(device), attention_mask=attention_masks_batch.to(device))[0]
  _, predicted = torch.max(y_pred, 1)
  test_correct += (predicted == y_batch).sum()
  test_total += len(y_batch)

print("Accuracy:", test_correct.float() / test_total)

  0%|          | 0/469 [00:00<?, ?it/s]

Accuracy: tensor(0.9134, device='cuda:0')
