In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pytorch_transformers import BertTokenizer, BertForSequenceClassification
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


In [4]:
# Read the three tab-separated splits (train / validation / test) in one pass.
# NOTE(review): these are absolute, machine-specific paths; consider a configurable data directory.
train_df, valid_df, test_df = (
    pd.read_csv(split_path, sep="\t")
    for split_path in (
        'D:/kim/kim/DP/data/training.txt',
        'D:/kim/kim/DP/data/validing.txt',
        'D:/kim/kim/DP/data/testing.txt',
    )
)

In [5]:
# Keep a 10% random subsample of every split, drawn with a fixed seed (500).
# NOTE(review): the frames are rebound to their own subsamples, so re-running
# this cell without reloading the data shrinks them again.
train_df, valid_df, test_df = (
    split_frame.sample(frac=0.1, random_state=500)
    for split_frame in (train_df, valid_df, test_df)
)

In [6]:
class Datasets(Dataset):
    """Map-style dataset that serves ``(text, label)`` pairs from a DataFrame.

    Rows and columns are addressed by integer position (``DataFrame.iloc``),
    not by index label or column name.

    Args:
        df: Frame holding one example per row.
        text_col: Positional index of the column returned as the text.
            Defaults to 1, the position this class originally hard-coded.
        label_col: Positional index of the column returned as the label.
            Defaults to 2, the position this class originally hard-coded.
    """

    def __init__(self, df, text_col=1, label_col=2):
        self.df = df
        # Column positions are configurable so frames with another layout can be reused.
        self.text_col = text_col
        self.label_col = label_col

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``(text, label)`` pair stored at positional row ``idx``."""
        text = self.df.iloc[idx, self.text_col]
        label = self.df.iloc[idx, self.label_col]
        return text, label

In [7]:
# Wrap every split in the Dataset adapter first ...
train_dataset = Datasets(train_df)
valid_dataset = Datasets(valid_df)
test_dataset = Datasets(test_df)

# ... then build one loader per split, all with the same batching options.
# NOTE(review): shuffle=True is also applied to the validation and test loaders;
# confirm that evaluation does not rely on a stable example order.
train_loader, valid_loader, test_loader = (
    DataLoader(split_dataset, batch_size=2, shuffle=True, num_workers=0)
    for split_dataset in (train_dataset, valid_dataset, test_dataset)
)

In [8]:
# The tokenizer and the classifier are both built from the same pretrained checkpoint name.
pretrained_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
model = BertForSequenceClassification.from_pretrained(pretrained_name)
# Kept as the cell's final expression so the notebook still renders the module summary.
model.to(device)

100%|██████████| 231508/231508 [00:00<00:00, 360222.72B/s]
100%|██████████| 433/433 [00:00<00:00, 432001.34B/s]
100%|██████████| 440473133/440473133 [00:47<00:00, 9212631.30B/s] 


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [9]:
def save_checkpoint(save_path, model, valid_loss):
    """Write a model's weights and its validation loss to ``save_path``.

    Args:
        save_path: Destination handed to ``torch.save``. When it is ``None``
            nothing is written.
        model: Object exposing ``state_dict()``.
        valid_loss: Validation loss stored next to the weights under the
            ``'valid_loss'`` key.

    Returns:
        None.
    """
    # Identity check: `== None` defers to the object's __eq__, which a path or
    # buffer type may override and thereby skip the save by accident.
    if save_path is None:
        return
    checkpoint = {'model_state_dict': model.state_dict(),
                  'valid_loss': valid_loss}
    torch.save(checkpoint, save_path)
    print(f'Model saved to ==> {save_path}')

def load_checkpoint(load_path, model):
    """Restore weights saved by ``save_checkpoint`` into ``model``.

    Args:
        load_path: Source handed to ``torch.load``. When it is ``None``
            nothing is loaded.
        model: Object exposing ``load_state_dict()``; it is updated in place
            from the ``'model_state_dict'`` entry of the checkpoint.

    Returns:
        The value stored under ``'valid_loss'``, or ``None`` when
        ``load_path`` is ``None``.
    """
    # Identity check: `== None` defers to the object's __eq__ and can misfire.
    if load_path is None:
        return
    # NOTE: torch.load is pickle-based; only open checkpoints from trusted sources.
    # Tensors are remapped onto the module-level `device`.
    checkpoint = torch.load(load_path, map_location = device)
    print(f'Model loaded from <== {load_path}')
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['valid_loss']

def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Write the recorded loss curves and their step indices to ``save_path``.

    Args:
        save_path: Destination handed to ``torch.save``. When it is ``None``
            nothing is written.
        train_loss_list: Stored under the ``'train_loss_list'`` key.
        valid_loss_list: Stored under the ``'valid_loss_list'`` key.
        global_steps_list: Stored under the ``'global_steps_list'`` key.

    Returns:
        None.
    """
    # Identity check: `== None` defers to the object's __eq__ and can misfire.
    if save_path is None:
        return
    metrics = {'train_loss_list' : train_loss_list,
               'valid_loss_list' : valid_loss_list,
               'global_steps_list' : global_steps_list}
    torch.save(metrics, save_path)
    # This function stores metrics, not a model, so the log line says so.
    print(f'Metrics saved to ==> {save_path}')

def load_metrics(load_path):
    """Read back the metrics written by ``save_metrics``.

    Args:
        load_path: Source handed to ``torch.load``. When it is ``None``
            nothing is loaded.

    Returns:
        A ``(train_loss_list, valid_loss_list, global_steps_list)`` tuple, or
        ``None`` when ``load_path`` is ``None``. Callers that unpack the
        result into three names must therefore pass a real path.
    """
    # Identity check: `== None` defers to the object's __eq__ and can misfire.
    if load_path is None:
        return
    # NOTE: torch.load is pickle-based; only open files from trusted sources.
    metrics = torch.load(load_path, map_location = device)
    # This function reads metrics, not a model, so the log line says so.
    print(f'Metrics loaded from <== {load_path}')
    return metrics['train_loss_list'], metrics['valid_loss_list'], metrics['global_steps_list']