In [None]:
### only colab
# ! git clone 'https://github.com/kangjun205/Dacon_AuthorClassification.git'
# %cd Dacon_AuthorClassification

In [None]:
import datetime
import sys
import os

import numpy as np
import pandas as pd
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup

# ## only local
# sys.path.append('..')

from utils.clean import clean_texts
from utils.tokenizer import get_tokenizer
from utils.util import save_model, set_seed

from data.dataloader import get_dataloader
from data.datasets import TextDataset

from src.loss import MultiLabelLoss
from src.model import BertForMultiLabelClassification
from tqdm import tqdm

# Reading

In [None]:
## Load the training CSV and discard its 'index' column in one chained
## expression (no in-place mutation of the freshly read frame).
train = pd.read_csv('train.csv').drop(columns = ['index'])
## Last expression of the cell: preview the first rows.
train.head()

# Setting

In [None]:
## Shell escape: log the W&B CLI in before the sweep cells call wandb.init / wandb.sweep.
! wandb login

In [None]:
## NUM_LABELS: number of output labels handed to the model in the training cell.
## SEED: value passed to the project's set_seed helper for reproducibility.
NUM_LABELS, SEED = 5, 42
set_seed(SEED)

# Preprocessing

In [None]:
## train - valid split
## Positional row indices are split (rather than the columns themselves) so the
## text and label lists below stay aligned. random_state reuses SEED from the
## Setting cell instead of a second hard-coded 42, so the split follows the
## notebook-wide seed.
train_indices, valid_indices = train_test_split(range(len(train)), test_size=0.2, random_state=SEED)

## One positional selection per column replaces the per-row `.iloc[i]` loops.
## NOTE(review): `.tolist()` yields plain Python scalars for the labels, whereas
## the per-row version yielded numpy scalars — confirm the dataset code accepts both.
train_data = train['text'].iloc[train_indices].tolist() ## input text
train_target = train['author'].iloc[train_indices].tolist() ## target label

valid_data = train['text'].iloc[valid_indices].tolist() ## input text
valid_target = train['author'].iloc[valid_indices].tolist() ## target label

## cleaning
train_data = clean_texts(train_data)
valid_data = clean_texts(valid_data)

# Train

In [None]:
def train(config = None) :
    """Run one training run inside a W&B run and checkpoint the best model.

    Intended to be passed to ``wandb.agent`` as the sweep function. Reads the
    notebook-level ``train_data``, ``train_target``, ``valid_data``,
    ``valid_target`` and ``NUM_LABELS``.

    NOTE: defining this function rebinds the notebook-level name ``train``,
    which previously held the training DataFrame. Re-running the Preprocessing
    cell after this cell will therefore fail until the Reading cell is re-run.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. Inside the
            run, hyper-parameters are read from ``wandb.config`` using the keys
            ``MAX_LEN``, ``BATCH_SIZE``, ``NUM_HIDDEN``, ``LEARNING_RATE`` and
            ``EPOCHS``.

    Side effects:
        Logs ``train_loss`` and ``valid_loss`` (mean loss per batch) once per
        epoch to W&B, prints the validation loss, and saves/uploads a
        ``BERT_<ddHHMM>.pt`` checkpoint whenever the validation loss improves
        during the second half of training.
    """
    with wandb.init(config = config) :
        config = wandb.config

        ## resolve the device once and reuse it for the model and every batch
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        ## tokenizer
        tokenizer = get_tokenizer()

        ## dataloader
        train_dataloader = get_dataloader(train_data, train_target, tokenizer, config['MAX_LEN'], config['BATCH_SIZE'], shuffle = True)
        ## validation order does not affect the loss, so no shuffling is needed
        valid_dataloader = get_dataloader(valid_data, valid_target, tokenizer, config['MAX_LEN'], config['BATCH_SIZE'], shuffle = False)

        ## batch counts used to turn summed losses into per-batch means
        ## (max(..., 1) guards against an empty loader)
        num_train_batches = max(len(train_dataloader), 1)
        num_valid_batches = max(len(valid_dataloader), 1)

        ## model
        model = BertForMultiLabelClassification(NUM_LABELS, config['NUM_HIDDEN'])
        model.to(device)

        ## loss & optimizer
        criterion = MultiLabelLoss()
        ## only the classifier head's parameters are handed to the optimizer
        ## NOTE(review): presumably the encoder is meant to stay fixed — confirm
        optimizer = torch.optim.Adam(model.classifier.parameters(), lr = config['LEARNING_RATE'])

        ## learning rate scheduler: linear decay over every optimizer step, no warmup
        total_steps = len(train_dataloader) * config['EPOCHS']
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps = 0,
            num_training_steps = total_steps
        )

        ## best (lowest) mean validation loss seen at a checkpoint so far
        val_loss_min = float('inf')

        for epoch in range(config['EPOCHS']):
            model.train() ## training
            train_loss = 0.0
            for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{config['EPOCHS']} - Training"):
                input_ids = batch['input_ids'].to(device) ## input text
                attention_mask = batch['attention_mask'].to(device) ## mask for padding
                ## reshape(-1) keeps the batch dimension even when the last batch
                ## holds a single sample (a bare squeeze() would yield a 0-d tensor)
                ## assumes one integer class id per sample — TODO confirm
                labels = F.one_hot(batch['labels'].reshape(-1), num_classes = NUM_LABELS).to(device) ## target label

                optimizer.zero_grad()
                outputs = model(input_ids, attention_mask) ## batch_size X num_labels
                loss = criterion(outputs, labels.float())
                train_loss += loss.item()
                loss.backward()
                optimizer.step()
                scheduler.step()

            avg_train_loss = train_loss / num_train_batches

            model.eval() ## evaluating
            val_loss = 0.0
            with torch.no_grad():
                for batch in tqdm(valid_dataloader, desc=f"Epoch {epoch+1}/{config['EPOCHS']} - Validation"):
                    input_ids = batch['input_ids'].to(device) ## input text
                    attention_mask = batch['attention_mask'].to(device) ## mask for padding
                    labels = F.one_hot(batch['labels'].reshape(-1), num_classes = NUM_LABELS).to(device) ## target label

                    outputs = model(input_ids, attention_mask) ## batch_size X num_labels
                    loss = criterion(outputs, labels.float())
                    val_loss += loss.item()

            avg_val_loss = val_loss / num_valid_batches

            ## logging: one call per epoch so both metrics share the same W&B step,
            ## and both are per-batch means (not sums divided by the batch size)
            wandb.log({'train_loss' : avg_train_loss, 'valid_loss' : avg_val_loss})
            print(f"Epoch {epoch+1}/{config['EPOCHS']}, Validation Loss: {avg_val_loss}")

            ## check point
            if avg_val_loss < val_loss_min and epoch > config['EPOCHS']/2 :
                ## create a checkpoint when the validation loss reaches a new minimum
                ## only applied once more than half of the epochs have run
                val_loss_min = avg_val_loss
                now = datetime.datetime.now().strftime('%d%H%M')
                save_model(model, f'BERT_{now}.pt')
                wandb.save(f'BERT_{now}.pt')
                print(f'model BERT_{now}.pt is saved')

In [None]:
# Sweep configuration
sweep_config = {
    'method': 'random',
    'metric': {
        # Must match a key passed to wandb.log by the training function;
        # it logs 'train_loss' and 'valid_loss', so the former 'log loss'
        # name referred to a metric that is never reported.
        'name': 'valid_loss',
        'goal': 'minimize'
    },
    'parameters': {
        'BATCH_SIZE': {
            'values': [32, 64]
        },
        'NUM_HIDDEN': {
            'values': [32, 64, 128]
        },
        # NOTE(review): with only min/max given, the range is presumably sampled
        # uniformly, so most draws fall near the upper end — consider a
        # log-uniform distribution; confirm against the W&B sweep docs.
        'LEARNING_RATE': {
            'min': 0.0001,
            'max': 0.1
        },
        'EPOCHS': {
            'value': 10
        },
        'MAX_LEN': {
            'value': 128
        }
    }
}

# Initialize the sweep and run 5 trials with the `train` function defined above
sweep_id = wandb.sweep(sweep_config, project="Dacon_AuthorClassification")
wandb.agent(sweep_id, train, count=5)