In [1]:
from transformers import GPT2Model
from transformers import GPT2LMHeadModel
from transformers import PreTrainedTokenizerFast
import torch
from torch.utils.data import Dataset, DataLoader
import urllib
import pandas as pd
from zeus.monitor import ZeusMonitor

In [2]:
def get_naver_review_examples():
    urllib.request.urlretrieve("https://raw.githubusercontent.com/e9t/nsmc/master/ratings_test.txt", filename="ratings_test.txt")

    test_data = pd.read_table('ratings_test.txt')
    
    return test_data

In [3]:
class NaverReviewDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
          text,
          add_special_tokens=True,
          max_length=self.max_len,
          return_token_type_ids=False,
          padding='max_length',
          return_attention_mask=True,
          return_tensors='pt',
          truncation=True,
        )

        return {
          'text': text,
          'input_ids': encoding['input_ids'].flatten(),
          'attention_mask': encoding['attention_mask'].flatten(),
          'labels': torch.tensor(label, dtype=torch.long)
        }
    
    def __len__(self):
        return len(self.texts)

In [4]:
tokenizer = PreTrainedTokenizerFast.from_pretrained("skt/kogpt2-base-v2", bos_token='</s>', eos_token='</s>', unk_token='<unk>', pad_token='<pad>', mask_token='<mask>', padding_side='left') 

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT2Tokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [5]:
batch_size = 8

naver_data = get_naver_review_examples()

dataset = NaverReviewDataset(naver_data['document'], naver_data['label'], tokenizer, 100)
train_set, valid_set, test_set = torch.utils.data.random_split(dataset, [40000, 5000, 5000])

train_dataloader = DataLoader(train_set, batch_size=batch_size,shuffle=True)

valid_dataloader = DataLoader(valid_set, batch_size=batch_size,shuffle=True)

test_dataloader = DataLoader(test_set, batch_size=batch_size,shuffle=True)

In [6]:
class GPT2SentimentClassifier(torch.nn.Module):

    def __init__(self, n_classes):
        super(GPT2SentimentClassifier, self).__init__()

        self.gpt_model = GPT2Model.from_pretrained('skt/kogpt2-base-v2')
        self.drop = torch.nn.Dropout(p=0.1)
        self.out = torch.nn.Linear(self.gpt_model.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.gpt_model(
          input_ids=input_ids,
          attention_mask=attention_mask
        )
        last_hidden_state = hidden_states[0]
        
        output = self.drop(last_hidden_state[:, -1, :])

        return self.out(output)

In [16]:
gpt_clf = GPT2SentimentClassifier(n_classes=1)
gpt_clf.train()

learning_rate = 5e-5
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(gpt_clf.parameters(), lr=learning_rate)

device = 'cuda'

epochs = 11
count = 0

In [17]:
def cal_correct_num(predicts, labels):
    predicts_ = predicts >= 0.5
    correct_num = torch.sum(predicts_ == labels)
        
    return correct_num

In [18]:
def custom_loader(dataset,max_batch_size,shuffle=True):
    if batch_size>max_batch_size:
        maximized=True
        return DataLoader(dataset,batch_size=batch_size-1,shuffle=True)
    return DataLoader(dataset, batch_size=batch_size,shuffle=True)

In [19]:
tot_train_loss = 0.0
tot_valid_loss = 0.0

train_correct_num = 0
valid_correct_num = 0

prev_valid_loss = 10000
traindone=False
print('KoGPT-2 Training Start!')
monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()])
#plo = GlobalPowerLimitOptimizer(monitor)
batch_size=4
batch_effect={}
maximized=False
for epoch in range(epochs):
    monitor.begin_window("epoch")
    #plo.on_epoch_begin()
    measurements = []
    print(f'current batch size: {batch_size}')
    train_dataloader = custom_loader(train_set, max_batch_size=14,shuffle=True)
    valid_dataloader = custom_loader(valid_set, max_batch_size=14,shuffle=True)
    test_dataloader = custom_loader(test_set, max_batch_size=14,shuffle=True)
    for batch, train_data in enumerate(train_dataloader):
        monitor.begin_window("step")
        #plo.on_step_begin()
        gpt_clf.to(device)
        train_inputs = train_data['input_ids'].to(device)
        train_masks = train_data['attention_mask'].to(device)
        train_labels = train_data['labels'].to(device)
        
        train_outputs = gpt_clf(train_inputs, train_masks)
        
        train_loss = criterion(train_outputs.view(-1), train_labels.float())
        
            
        valid_data = next(iter(valid_dataloader))

        gpt_clf.to(device)
        valid_inputs = valid_data['input_ids'].to(device)    
        valid_masks = valid_data['attention_mask'].to(device)
        valid_labels = valid_data['labels'].to(device)
        
        valid_outputs = gpt_clf(valid_inputs, valid_masks)
        
        valid_loss = criterion(valid_outputs.view(-1), valid_labels.float())
        
        gpt_clf.to(device)
        
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        
        tot_train_loss += train_loss.item()
        tot_valid_loss += valid_loss.item()
        
        train_correct_num += cal_correct_num(torch.sigmoid(train_outputs.view(-1)), train_labels.float())
        valid_correct_num += cal_correct_num(torch.sigmoid(valid_outputs.view(-1)), valid_labels.float())
        result = monitor.end_window("step")
        #plo.on_step_end()
        measurements.append(result)
        if count % 200 == 0:
            cnt = ((count+1) * batch_size)
            current_train_loss = tot_train_loss / cnt
            current_valid_loss = tot_valid_loss / cnt
            
            train_acc = train_correct_num / cnt
            valid_acc = valid_correct_num / cnt
            
            #print(f'epoch : %5d | batch : %5d | train_loss : %.5f | valid_loss : %.5f | train_acc : %.5f | valid_acc : %.5f' %(epoch+1, batch+1, current_train_loss, current_valid_loss, train_acc, valid_acc))
            
            tot_train_loss = 0.0
            tot_valid_loss = 0.0
            
            train_correct_num = 0
            valid_correct_num = 0
            
            count = 0
            
            if prev_valid_loss > current_valid_loss:
                prev_valid_loss = current_valid_loss
                torch.save(gpt_clf.state_dict(), f'./KoGPT-Classifier-model.pth')
            if valid_acc>0.9:
                traindone=True
                break
        count += 1
    eres = monitor.end_window("epoch")
    print(f"Epoch {epoch} consumed {eres.time} s and {eres.total_energy} J.")
    #plo.on_epoch_end()

    avg_time = sum(map(lambda m: m.time, measurements)) / len(measurements)
    avg_energy = sum(map(lambda m: m.total_energy, measurements)) / len(measurements)
    batch_effect[batch_size]=eres.total_energy
    print(f"One step took {avg_time} s and {avg_energy} J on average.")
    if traindone==True:
        break
    if maximized ==False:
        batch_size+=1
    else:
        batch_size=min(batch_effect,key=batch_effect.get)
print(f'optimal batch size: {batch_size}')

KoGPT-2 Training Start!
[2023-12-12 20:51:56,387] [zeus.monitor.energy](energy.py:157) Monitoring GPU indices [0].
current batch size: 4
Epoch 0 consumed 1056.1363623142242 s and 86853.98999999999 J.
One step took 0.10301281726360322 s and 8.496719100000005 J on average.
current batch size: 5
Epoch 1 consumed 921.3956432342529 s and 76666.826 J.
One step took 0.11245977991819382 s and 9.381312374999958 J on average.
current batch size: 6
Epoch 2 consumed 866.1784181594849 s and 70673.657 J.
One step took 0.12727255339884982 s and 10.414939253037264 J on average.
current batch size: 7
Epoch 3 consumed 843.7583587169647 s and 68242.13100000005 J.
One step took 0.14472319197466993 s and 11.69603797025359 J on average.
current batch size: 8
Epoch 4 consumed 760.541571855545 s and 65775.38 J.
One step took 0.14906626534461975 s and 12.930494799999963 J on average.
current batch size: 9
Epoch 5 consumed 721.1694252490997 s and 63994.64199999999 J.
One step took 0.15941453639782455 s and 14.1