In [1]:
from utils.utils import *
from load_data import load_data
from trainer import trainer

import torch.optim as optim
import torch.nn as nn
import torch

from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
)



In [2]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import ast
class IMDB(Dataset):
    """Dataset of pre-tokenized examples read from a tab-separated file.

    The file must start with a header row that contains the columns
    ``input_ids``, ``input_mask``, ``input_type_ids`` and ``label_ids``.
    The first three columns hold Python list literals (for example
    ``[101, 2023, 102]``); within a column every row must have the same
    length so the rows can be stacked into a single 2-D tensor.
    ``label_ids`` holds one integer class id per row.

    Args:
        file: Path of the TSV file to load. It is opened as UTF-8 text.

    Attributes:
        tensors: List of four ``torch.long`` tensors, in the order
            ``input_ids``, ``input_mask``, ``input_type_ids``, ``label_ids``.
    """

    def __init__(self, file) -> None:
        super().__init__()
        feature_columns = ('input_ids', 'input_mask', 'input_type_ids')
        label_column = 'label_ids'

        # The context manager closes the handle even if parsing raises, and
        # avoids rebinding the ``file`` argument to the opened handle.
        with open(file, 'r', encoding='utf-8') as handle:
            data = pd.read_csv(handle, sep='\t')

        # Each feature cell is the string repr of a list; literal_eval parses it
        # safely (no arbitrary code execution). ``tolist()`` hands torch a plain
        # nested Python list instead of a pandas Series of objects.
        self.tensors = [
            torch.tensor(data[column].apply(ast.literal_eval).tolist(), dtype=torch.long)
            for column in feature_columns
        ]
        # Class ids are stored as long explicitly, which is the target dtype
        # expected by nn.CrossEntropyLoss for class-index labels.
        self.tensors.append(torch.tensor(data[label_column].tolist(), dtype=torch.long))

    def __len__(self):
        """Return the number of examples (rows of the first tensor)."""
        return self.tensors[0].size(0)

    def __getitem__(self, index):
        """Return ``(input_ids, input_mask, input_type_ids, label_ids)`` for one example."""
        return tuple(tensor[index] for tensor in self.tensors)
    

In [42]:
# Training data for this run. The supervised *test* split is used on purpose:
# the notebook only checks that the model is able to fit it.
file = 'data/imdb_sup_test.txt'
dataset = IMDB(file)
dataiter = DataLoader(dataset, batch_size=32, shuffle=True)

# Validation data. Built from `valid_file` / `valid_dataset` (previously the
# training file and training dataset were reused here by mistake, so the
# validation loader never read this file).
valid_file = 'data/sani.txt'
valid_dataset = IMDB(valid_file)
# Order does not affect accuracy, so no shuffling is needed for evaluation.
valid_dataiter = DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [43]:
# load model
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = AutoConfig.from_pretrained("bert-base-uncased")
# NOTE: `from_config` only builds the architecture described by `config`; the
# weights are randomly initialised. To start from the pretrained checkpoint
# weights instead, use:
#   model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased').to(device)
model = AutoModelForSequenceClassification.from_config(config=config).to(device)

In [44]:
# Loss: cross-entropy between the classifier logits and the class-index labels.
criterion = nn.CrossEntropyLoss()
# Optimiser: Adam over every parameter of the model, learning rate 2e-5.
optimizer = optim.Adam(params=model.parameters(), lr=2e-5)

In [41]:
# Sanity check: verify that the model is actually able to fit the test split.
import logging

logging.basicConfig(level=logging.INFO)
total_cor = 0   # correctly classified training samples, accumulated over the whole run
total_try = 0   # training samples seen, accumulated over the whole run
epochs = 100
loss = {'sup': []}  # per-step supervised loss history (consumed by the plotting cell)

for n in range(epochs):
    # ---- training -------------------------------------------------------
    # Validation below switches to eval mode, so re-enable train mode once
    # at the start of every epoch (not once per batch).
    model.train()
    for step, batch in enumerate(dataiter):
        optimizer.zero_grad()
        sup_input_ids, sup_input_mask, sup_input_type_ids, label_ids = (t.to(device) for t in batch)
        sup_outputs = model(sup_input_ids, sup_input_mask, sup_input_type_ids)
        sup_loss = criterion(sup_outputs.logits, label_ids)
        sup_loss.backward()
        optimizer.step()

        ## logging
        # Store a detached CPU copy: appending `sup_loss` itself would keep the
        # autograd graph of every step alive for the whole run.
        loss['sup'].append(sup_loss.detach().cpu())
        pred = torch.argmax(sup_outputs.logits, dim=1)
        batch_cor = torch.sum(pred == label_ids).item()
        acc = batch_cor / len(label_ids)
        total_cor += batch_cor
        total_try += len(label_ids)
        if step % 4 == 3:
            logging.info(f'Currnent   Step: {step+1}/{len(dataiter)} of epoch {n+1}/{epochs}')
            logging.info(f'Currnent  Loss : {sup_loss.item()}')
            logging.info(f'Currnent   Acc : {acc :.3f}')
            logging.info(f'Total      Acc : {total_cor/total_try :.3f}')
    # Running training accuracy, logged right before validation starts.
    # max(..., 1) avoids a ZeroDivisionError when the training loader is empty.
    logging.info(f'Valid     ON : {total_cor/max(total_try, 1) :.3f}')

    # ---- validation -----------------------------------------------------
    model.eval()
    v_cor = 0   # correctly classified validation samples this epoch
    v_try = 0   # validation samples seen this epoch
    with torch.no_grad():  # no gradients needed for evaluation
        for v_batch in valid_dataiter:
            sup_input_ids, sup_input_mask, sup_input_type_ids, label_ids = (t.to(device) for t in v_batch)
            sup_outputs = model(sup_input_ids, sup_input_mask, sup_input_type_ids)
            v_pred = torch.argmax(sup_outputs.logits, dim=1)
            # Compare the *validation* predictions with the validation labels
            # (the previous code compared the last training batch's `pred`).
            v_cor += torch.sum(v_pred == label_ids).item()
            v_try += len(label_ids)
    # Sample-level accuracy: correct predictions / number of validation samples.
    v_acc = v_cor / max(v_try, 1)
    logging.info(f'Valid   Acc : {v_acc :.3f}')


INFO:root:Currnent   Step: 1/782 of epoch 1/100
INFO:root:Currnent  Loss : 2.405747652053833
INFO:root:Currnent  Pred : tensor([1, 1, 1, 0, 1, 0, 1, 1], device='cuda:0')
INFO:root:Currnent   Acc : 0.438
INFO:root:Total      Acc : 0.438
INFO:root:Currnent   Step: 2/782 of epoch 1/100
INFO:root:Currnent  Loss : 2.646862268447876
INFO:root:Currnent  Pred : tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
INFO:root:Currnent   Acc : 0.656
INFO:root:Total      Acc : 0.547
INFO:root:Currnent   Step: 3/782 of epoch 1/100
INFO:root:Currnent  Loss : 4.196529865264893
INFO:root:Currnent  Pred : tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
INFO:root:Currnent   Acc : 0.438
INFO:root:Total      Acc : 0.510
INFO:root:Currnent   Step: 4/782 of epoch 1/100
INFO:root:Currnent  Loss : 2.573439598083496
INFO:root:Currnent  Pred : tensor([0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')
INFO:root:Currnent   Acc : 0.625
INFO:root:Total      Acc : 0.539
INFO:root:Currnent   Step: 5/782 of epoch 1/100
INFO:roo

KeyboardInterrupt: 

In [None]:
# Save a figure of the loss history recorded during training.
from utils.utils import save_fig
# Location handed to save_fig as its second argument.
fig_PATH = "sani/results/"
# NOTE(review): the meaning of the third positional argument (True) is not
# visible from this notebook — confirm against utils.utils.save_fig.
save_fig(loss, fig_PATH, True)