In [1]:
# imdb  영화평 - 텍스트

In [2]:
import random
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.metrics import f1_score, confusion_matrix
from torchsummary import summary
from tensorflow.keras.datasets import imdb


In [3]:
# 하이퍼 파라메터
CFG = {
    'MY_NUM':10000,             #  사전의 최대 단어수
    'MY_LEN':80,                #  각 문장의 단어수
    'MY_EMBED':30,              #  단어 임베딩의 출력 차원
    'MY_HIDDEN':100,            #  LSTM 은닉층
    'MY_EPOCHS':200,
    'SEED':41
}

In [4]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

In [5]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [6]:
(x_train,y_train),(x_test,y_test) =  imdb.load_data(num_words=CFG['MY_NUM'])  

In [7]:
x_train.shape,y_train.shape

((25000,), (25000,))

In [8]:
# 0 부정  1 긍정
print(len(x_train[0]), y_train[0])


218 1


In [9]:
for i in range(10):
  print(len(x_train[i]), y_train[i])

218 1
189 0
141 0
550 1
147 0
43 0
123 1
562 0
233 1
130 0


In [10]:
#  단어 -> 숫자
word_to_id = imdb.get_word_index()
print(f"총 단어수 : {len(word_to_id)}")
print(f"단어 the 는 {word_to_id['the']} 번째 수")

총 단어수 : 88584
단어 the 는 1 번째 수


In [11]:
# 숫자된 영화평을 문장으로 전환
# 원본.. 특수문자 3개를 처리하기위해서 각 단어에 3을 더한 값
# 숫자 - > 단어
id_to_word = { value-3:key for key,value in  word_to_id.items()}

In [12]:
temp = ''
for id in x_train[1]:
  temp += id_to_word[id]+' '
temp  

'of between impressive between respond people minutes in it henry stereotype domino go all br bring this know favourite movie often are with off previous enough films in point acting alike his movie we big this they he i br woman man too br over that grow funniest in change br did that her awarded br feel that short pieces br know that br scene 1984 you br ideas in then an plenty good just br becomes that what some about flying candy perfect br jason all corman in find as colorful to br impressive that between effort i respond to wrong dean old career to melbourne but your still than to corman but wrong few buffs think in br minutes that some to tone but instead ever in ever down as whole want something up that similarly minutes respond in to talking day roy in br raft while 10 tries not to so addressed images movie that it remember people have english can loves that this again though he br mid but film br spoiler in by it time another seemed from then people year film though any '

In [13]:
# 데이터의 길이를 통일
X_train =  pad_sequences(
    x_train,
    maxlen = CFG['MY_LEN']
)
X_test =  pad_sequences(
    x_test,
    maxlen = CFG['MY_LEN']
)

In [14]:
# 문장길이 통일 확인
for i in range(10):
  print(len(X_train[i]), y_train[i])

80 1
80 0
80 0
80 1
80 0
80 0
80 1
80 0
80 1
80 0


In [15]:
X_train.shape

(25000, 80)

In [16]:
# LSTM 을 위한 클래스 설계 - 모델
# 토치에서 모델을 상속받아서 설계.. 
# 생성자에서 각 층을 객체로 만들고.. forward에서 조립  이때 forward는 우리가 호출.. 모델에서 자동으로 training 할때 호출되는
# callback 함수
class MyLstm(nn.Module):
    def __init__(self):
      super(MyLstm, self).__init__()
      self.embedding = nn.Embedding(CFG['MY_NUM'], CFG['MY_EMBED'])
      self.lstm = nn.LSTM(CFG['MY_EMBED'], CFG['MY_HIDDEN'],batch_first=True)
      self.linear = nn.Linear(CFG['MY_HIDDEN'], 1)
      self.sigmoid = nn.Sigmoid()
    def forward(self,x):
      x = self.embedding(x)
      lstm_out,_ =  self.lstm(x)

      # lstm의 최종 은닉층 3차원 출력
      x = lstm_out[:,-1,:]
      x = self.linear(x)
      x = self.sigmoid(x)
      return x

In [17]:
model = MyLstm()
x = torch.zeros((80,1),dtype = torch.long).to(device)
model.to(device)
# summary(model,x)

MyLstm(
  (embedding): Embedding(10000, 30)
  (lstm): LSTM(30, 100, batch_first=True)
  (linear): Linear(in_features=100, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [18]:
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-2)
criterion = nn.BCELoss().to(device)

In [19]:
# numpy -> torch.tensor
X_train = torch.LongTensor(X_train)
y_train = torch.FloatTensor(y_train)
type(X_train), type(y_train)

(torch.Tensor, torch.Tensor)

In [20]:
X_train = X_train.to(device)

In [21]:
# 데이터 셋 생성
class MyDataset(Dataset):
  def __init__(self,x,y):
    super(MyDataset, self).__init__()
    self.x = x
    self.y = y
  def __getitem__(self, index):
    return self.x[index], self.y[index]
  def __len__(self):
    return self.x.shape[0]

X_train_dataset = MyDataset(X_train,y_train)

In [22]:
# gpu를 사용할때는 미니배치개념이 중요함... 
# 데이터의 수에 따라서 할당되는 gup 메모리가 달라짐
X_train_loader = DataLoader(X_train_dataset, batch_size = 64, shuffle=True, num_workers=0)
# X_train_dataset = MyDataset(X_train,y_train)

In [23]:
# 훈련
from tqdm import tqdm
best_lowest_loss = 100
best_model = None
for epoch in range(CFG['MY_EPOCHS']):
  model.train()    
  train_loss = []
  for data,target in tqdm(X_train_loader):
    output = model(data)
    # (00000, 1)  - > (00000)
    output = torch.squeeze(output)      
    output = output.to(device)      
    target = target.to(device)
    loss = criterion(output, target)    
    train_loss.append(loss.to('cpu').detach().numpy() )
    # 역전파
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()          
  print(f"epoch:{epoch} mean loss:{np.mean(train_loss)} best_lowest_loss : {best_lowest_loss} ")  
  if best_lowest_loss > np.mean(train_loss):
    best_lowest_loss = np.mean(train_loss)    
    best_model = model
    print("find best model")

100%|██████████| 391/391 [00:02<00:00, 186.70it/s]


epoch:0 mean loss:0.5363430976867676 best_lowest_loss : 100 
find best model


100%|██████████| 391/391 [00:01<00:00, 241.44it/s]


epoch:1 mean loss:0.3315993845462799 best_lowest_loss : 0.5363430976867676 
find best model


100%|██████████| 391/391 [00:01<00:00, 236.69it/s]


epoch:2 mean loss:0.2383839637041092 best_lowest_loss : 0.3315993845462799 
find best model


100%|██████████| 391/391 [00:01<00:00, 236.52it/s]


epoch:3 mean loss:0.16756398975849152 best_lowest_loss : 0.2383839637041092 
find best model


100%|██████████| 391/391 [00:01<00:00, 236.25it/s]


epoch:4 mean loss:0.11930021643638611 best_lowest_loss : 0.16756398975849152 
find best model


100%|██████████| 391/391 [00:01<00:00, 237.54it/s]


epoch:5 mean loss:0.10068994015455246 best_lowest_loss : 0.11930021643638611 
find best model


100%|██████████| 391/391 [00:01<00:00, 242.31it/s]


epoch:6 mean loss:0.0878651812672615 best_lowest_loss : 0.10068994015455246 
find best model


100%|██████████| 391/391 [00:01<00:00, 239.36it/s]


epoch:7 mean loss:0.07555929571390152 best_lowest_loss : 0.0878651812672615 
find best model


100%|██████████| 391/391 [00:01<00:00, 239.66it/s]


epoch:8 mean loss:0.07400677353143692 best_lowest_loss : 0.07555929571390152 
find best model


100%|██████████| 391/391 [00:01<00:00, 240.20it/s]


epoch:9 mean loss:0.05868653208017349 best_lowest_loss : 0.07400677353143692 
find best model


100%|██████████| 391/391 [00:01<00:00, 236.46it/s]


epoch:10 mean loss:0.06079646944999695 best_lowest_loss : 0.05868653208017349 


100%|██████████| 391/391 [00:01<00:00, 240.10it/s]


epoch:11 mean loss:0.06926116347312927 best_lowest_loss : 0.05868653208017349 


100%|██████████| 391/391 [00:01<00:00, 239.77it/s]


epoch:12 mean loss:0.05426999554038048 best_lowest_loss : 0.05868653208017349 
find best model


100%|██████████| 391/391 [00:01<00:00, 238.95it/s]


epoch:13 mean loss:0.059562742710113525 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 203.15it/s]


epoch:14 mean loss:0.06917260587215424 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:02<00:00, 195.45it/s]


epoch:15 mean loss:0.09452667832374573 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 196.23it/s]


epoch:16 mean loss:0.09758453816175461 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.68it/s]


epoch:17 mean loss:0.08271224051713943 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.69it/s]


epoch:18 mean loss:0.08289337158203125 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.42it/s]


epoch:19 mean loss:0.0803719237446785 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.38it/s]


epoch:20 mean loss:0.07304444909095764 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.65it/s]


epoch:21 mean loss:0.07203034311532974 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.16it/s]


epoch:22 mean loss:0.08729254454374313 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 240.54it/s]


epoch:23 mean loss:0.09811209887266159 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 244.31it/s]


epoch:24 mean loss:0.09763257205486298 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.30it/s]


epoch:25 mean loss:0.08079967647790909 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.24it/s]


epoch:26 mean loss:0.0884014442563057 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.21it/s]


epoch:27 mean loss:0.11181827634572983 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.76it/s]


epoch:28 mean loss:0.12731683254241943 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.09it/s]


epoch:29 mean loss:0.16946037113666534 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 240.77it/s]


epoch:30 mean loss:0.20301435887813568 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 240.42it/s]


epoch:31 mean loss:0.25496914982795715 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.01it/s]


epoch:32 mean loss:0.28326216340065 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.85it/s]


epoch:33 mean loss:0.31050944328308105 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 240.65it/s]


epoch:34 mean loss:0.3228571116924286 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 209.34it/s]


epoch:35 mean loss:0.31261470913887024 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 242.32it/s]


epoch:36 mean loss:0.3186124861240387 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.34it/s]


epoch:37 mean loss:0.3373197913169861 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.90it/s]


epoch:38 mean loss:0.34488698840141296 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.86it/s]


epoch:39 mean loss:0.35251283645629883 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.15it/s]


epoch:40 mean loss:0.348102331161499 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.80it/s]


epoch:41 mean loss:0.3861736059188843 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.52it/s]


epoch:42 mean loss:0.42289191484451294 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.88it/s]


epoch:43 mean loss:0.4617714583873749 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.34it/s]


epoch:44 mean loss:0.48037368059158325 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.35it/s]


epoch:45 mean loss:0.5114127397537231 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.83it/s]


epoch:46 mean loss:0.5331630110740662 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.66it/s]


epoch:47 mean loss:0.5443791747093201 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.86it/s]


epoch:48 mean loss:0.5428581237792969 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.79it/s]


epoch:49 mean loss:0.5443335175514221 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.84it/s]


epoch:50 mean loss:0.5413349866867065 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.77it/s]


epoch:51 mean loss:0.5474277138710022 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.36it/s]


epoch:52 mean loss:0.5522992014884949 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.67it/s]


epoch:53 mean loss:0.5573684573173523 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.45it/s]


epoch:54 mean loss:0.555294930934906 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 231.25it/s]


epoch:55 mean loss:0.5606123208999634 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.44it/s]


epoch:56 mean loss:0.5591400861740112 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.46it/s]


epoch:57 mean loss:0.5573805570602417 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.65it/s]


epoch:58 mean loss:0.5626792907714844 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.58it/s]


epoch:59 mean loss:0.5627918243408203 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.35it/s]


epoch:60 mean loss:0.564849853515625 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 240.65it/s]


epoch:61 mean loss:0.5674629807472229 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.69it/s]


epoch:62 mean loss:0.5688655376434326 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.82it/s]


epoch:63 mean loss:0.570963442325592 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.05it/s]


epoch:64 mean loss:0.5730898380279541 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.63it/s]


epoch:65 mean loss:0.5752657055854797 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.76it/s]


epoch:66 mean loss:0.5842142105102539 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.61it/s]


epoch:67 mean loss:0.5917394757270813 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.56it/s]


epoch:68 mean loss:0.5934751629829407 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.47it/s]


epoch:69 mean loss:0.5958806276321411 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.94it/s]


epoch:70 mean loss:0.5974552035331726 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 230.26it/s]


epoch:71 mean loss:0.6014338731765747 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.63it/s]


epoch:72 mean loss:0.599331259727478 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.12it/s]


epoch:73 mean loss:0.5984025001525879 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.81it/s]


epoch:74 mean loss:0.5989043712615967 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 243.42it/s]


epoch:75 mean loss:0.6035981178283691 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.43it/s]


epoch:76 mean loss:0.60789954662323 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.11it/s]


epoch:77 mean loss:0.608431339263916 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.55it/s]


epoch:78 mean loss:0.6076055765151978 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.26it/s]


epoch:79 mean loss:0.6061476469039917 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.96it/s]


epoch:80 mean loss:0.6073417067527771 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.80it/s]


epoch:81 mean loss:0.6101340651512146 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.66it/s]


epoch:82 mean loss:0.6115065813064575 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.33it/s]


epoch:83 mean loss:0.6103021502494812 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.27it/s]


epoch:84 mean loss:0.6088371276855469 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.86it/s]


epoch:85 mean loss:0.6107240915298462 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.10it/s]


epoch:86 mean loss:0.6119206547737122 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.40it/s]


epoch:87 mean loss:0.6103634834289551 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 227.03it/s]


epoch:88 mean loss:0.6119706034660339 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 231.74it/s]


epoch:89 mean loss:0.6102657318115234 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.29it/s]


epoch:90 mean loss:0.6116487383842468 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.05it/s]


epoch:91 mean loss:0.6127033233642578 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.02it/s]


epoch:92 mean loss:0.6129622459411621 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.84it/s]


epoch:93 mean loss:0.6097033023834229 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.98it/s]


epoch:94 mean loss:0.6098147630691528 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.62it/s]


epoch:95 mean loss:0.608677089214325 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.10it/s]


epoch:96 mean loss:0.6069501638412476 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.89it/s]


epoch:97 mean loss:0.6076716184616089 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.65it/s]


epoch:98 mean loss:0.6060487031936646 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.65it/s]


epoch:99 mean loss:0.6056524515151978 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.27it/s]


epoch:100 mean loss:0.6066384315490723 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.54it/s]


epoch:101 mean loss:0.6070865988731384 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.69it/s]


epoch:102 mean loss:0.6078152656555176 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.06it/s]


epoch:103 mean loss:0.6102375984191895 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.63it/s]


epoch:104 mean loss:0.6121296286582947 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.59it/s]


epoch:105 mean loss:0.6111844182014465 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.27it/s]


epoch:106 mean loss:0.6127248406410217 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.26it/s]


epoch:107 mean loss:0.6114368438720703 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.28it/s]


epoch:108 mean loss:0.6118085384368896 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.94it/s]


epoch:109 mean loss:0.6141446232795715 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.36it/s]


epoch:110 mean loss:0.6144466996192932 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.44it/s]


epoch:111 mean loss:0.613489031791687 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.00it/s]


epoch:112 mean loss:0.6120343208312988 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.50it/s]


epoch:113 mean loss:0.6119363903999329 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.45it/s]


epoch:114 mean loss:0.6120277047157288 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.91it/s]


epoch:115 mean loss:0.6140310764312744 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.99it/s]


epoch:116 mean loss:0.6141181588172913 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.94it/s]


epoch:117 mean loss:0.6136948466300964 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.34it/s]


epoch:118 mean loss:0.6112073659896851 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.18it/s]


epoch:119 mean loss:0.6121492981910706 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.46it/s]


epoch:120 mean loss:0.6108007431030273 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.46it/s]


epoch:121 mean loss:0.6120307445526123 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.43it/s]


epoch:122 mean loss:0.6122766137123108 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.47it/s]


epoch:123 mean loss:0.6075945496559143 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.38it/s]


epoch:124 mean loss:0.6056029796600342 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.25it/s]


epoch:125 mean loss:0.6070494055747986 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 230.89it/s]


epoch:126 mean loss:0.6078812479972839 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.95it/s]


epoch:127 mean loss:0.6071263551712036 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.16it/s]


epoch:128 mean loss:0.6051967740058899 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.14it/s]


epoch:129 mean loss:0.6057068705558777 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.35it/s]


epoch:130 mean loss:0.6083521842956543 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.69it/s]


epoch:131 mean loss:0.6096652150154114 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.01it/s]


epoch:132 mean loss:0.6113539338111877 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.73it/s]


epoch:133 mean loss:0.611762523651123 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 240.92it/s]


epoch:134 mean loss:0.6105018258094788 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.34it/s]


epoch:135 mean loss:0.6131231784820557 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.60it/s]


epoch:136 mean loss:0.6162039041519165 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.80it/s]


epoch:137 mean loss:0.6139471530914307 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:02<00:00, 188.86it/s]


epoch:138 mean loss:0.6133155822753906 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:02<00:00, 168.35it/s]


epoch:139 mean loss:0.6142781972885132 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 219.35it/s]


epoch:140 mean loss:0.6142399311065674 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 231.99it/s]


epoch:141 mean loss:0.6122880578041077 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.39it/s]


epoch:142 mean loss:0.611605703830719 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.05it/s]


epoch:143 mean loss:0.6165420413017273 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.81it/s]


epoch:144 mean loss:0.6169773936271667 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.78it/s]


epoch:145 mean loss:0.6155837178230286 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.73it/s]


epoch:146 mean loss:0.6155028939247131 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.37it/s]


epoch:147 mean loss:0.6180755496025085 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.51it/s]


epoch:148 mean loss:0.6200528144836426 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.85it/s]


epoch:149 mean loss:0.6225769519805908 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.16it/s]


epoch:150 mean loss:0.6214280128479004 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.63it/s]


epoch:151 mean loss:0.6243590116500854 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.43it/s]


epoch:152 mean loss:0.6234400868415833 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.05it/s]


epoch:153 mean loss:0.6234668493270874 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.59it/s]


epoch:154 mean loss:0.6235256195068359 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.11it/s]


epoch:155 mean loss:0.624459445476532 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 232.36it/s]


epoch:156 mean loss:0.6237500905990601 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 231.41it/s]


epoch:157 mean loss:0.6260935068130493 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.25it/s]


epoch:158 mean loss:0.6270498037338257 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.11it/s]


epoch:159 mean loss:0.6246073842048645 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.23it/s]


epoch:160 mean loss:0.6269392967224121 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.56it/s]


epoch:161 mean loss:0.6255053877830505 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 229.89it/s]


epoch:162 mean loss:0.6245137453079224 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.23it/s]


epoch:163 mean loss:0.6268295049667358 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.13it/s]


epoch:164 mean loss:0.6274781227111816 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.13it/s]


epoch:165 mean loss:0.628315269947052 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.64it/s]


epoch:166 mean loss:0.6268520355224609 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.41it/s]


epoch:167 mean loss:0.6266345381736755 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.75it/s]


epoch:168 mean loss:0.6254844665527344 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.61it/s]


epoch:169 mean loss:0.6257376670837402 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.95it/s]


epoch:170 mean loss:0.6255725622177124 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.67it/s]


epoch:171 mean loss:0.6261512637138367 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.19it/s]


epoch:172 mean loss:0.6265878677368164 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 243.41it/s]


epoch:173 mean loss:0.6260915398597717 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.39it/s]


epoch:174 mean loss:0.625899076461792 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.62it/s]


epoch:175 mean loss:0.6262540221214294 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.63it/s]


epoch:176 mean loss:0.6268375515937805 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.97it/s]


epoch:177 mean loss:0.6296141743659973 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.66it/s]


epoch:178 mean loss:0.6290524005889893 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.40it/s]


epoch:179 mean loss:0.6301996111869812 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.60it/s]


epoch:180 mean loss:0.6314772367477417 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 233.48it/s]


epoch:181 mean loss:0.6342065334320068 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.96it/s]


epoch:182 mean loss:0.6351812481880188 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.71it/s]


epoch:183 mean loss:0.6333587169647217 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.12it/s]


epoch:184 mean loss:0.6338216066360474 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.59it/s]


epoch:185 mean loss:0.6334739923477173 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 236.13it/s]


epoch:186 mean loss:0.6329032182693481 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.92it/s]


epoch:187 mean loss:0.632870078086853 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 241.68it/s]


epoch:188 mean loss:0.6306843757629395 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.57it/s]


epoch:189 mean loss:0.6300164461135864 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.70it/s]


epoch:190 mean loss:0.6311831474304199 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.70it/s]


epoch:191 mean loss:0.6294879913330078 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.63it/s]


epoch:192 mean loss:0.6280730366706848 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 234.21it/s]


epoch:193 mean loss:0.6276441812515259 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 237.37it/s]


epoch:194 mean loss:0.6289032101631165 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.75it/s]


epoch:195 mean loss:0.6300305724143982 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 238.61it/s]


epoch:196 mean loss:0.6319049596786499 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.05it/s]


epoch:197 mean loss:0.6308540105819702 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 239.20it/s]


epoch:198 mean loss:0.6315211057662964 best_lowest_loss : 0.05426999554038048 


100%|██████████| 391/391 [00:01<00:00, 235.11it/s]

epoch:199 mean loss:0.6333111524581909 best_lowest_loss : 0.05426999554038048 





In [25]:
# 평가를위해서 평가용데이터 torch변환(long)
X_test = torch.LongTensor(X_test)
y_test = torch.FloatTensor(y_test)

In [30]:
# batch 로 전환
X_test_dataset = MyDataset(X_test,y_test)
X_test_loader = DataLoader(X_test_dataset, batch_size = 64, num_workers=0)

In [40]:
threshold = 0.5
f1_score_list = []
with torch.no_grad():
  best_model.eval()  
  for data,target in tqdm(X_test_loader):
    data = data.to(device)
    # target = target.to(device)

    pred = best_model(data)
    pred = torch.squeeze(pred)
    pred = (pred > threshold).float()    
    pred = pred.to('cpu').detach().numpy()    
    f1_score_list.append(f1_score(target,pred))

100%|██████████| 391/391 [00:01<00:00, 357.00it/s]


In [41]:
np.mean(f1_score_list)

0.5865803725438976