In [None]:
### Colab only: uncomment the lines below to clone the project repo and cd into it
# ! git clone 'https://github.com/kangjun205/Dacon_AuthorClassification.git'
# %cd Dacon_AuthorClassification

In [None]:
import datetime
import sys

import numpy as np
import pandas as pd
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup

from utils.clean import clean_texts
from utils.tokenizer import get_tokenizer
from utils.util import save_model, set_seed

from data.dataloader import get_dataloader
from data.datasets import TextDataset

from src.loss import MultiLabelLoss
from src.model import BertForMultiLabelClassification

# Reading

In [None]:
## read the training table and drop its 'index' column in a single expression
train = pd.read_csv('train.csv').drop(columns=['index'])
train.head()

# Config

In [None]:
## hyperparameters read by the later cells:
##   NUM_LABELS          -> size passed to the classification model
##   MAX_LEN, BATCH_SIZE -> forwarded to get_dataloader
##   EPOCHS              -> number of passes in the training loop (also sizes the LR schedule)
##   LEARNING_RATE       -> Adam learning rate
##   SEED                -> value handed to set_seed below
params = dict(
    NUM_LABELS=5,
    MAX_LEN=128,
    BATCH_SIZE=32,
    EPOCHS=5,
    LEARNING_RATE=1e-3,
    SEED=42,
)

## prefix used for the checkpoint filename when the model is saved
model_name = 'bert'

## seed value setting
set_seed(params['SEED'])

# Preprocessing

In [None]:
## train - valid split
## 80/20 split over row positions. The split is seeded from params['SEED'] so
## that changing the configured seed also changes the split (it used to be a
## second, independent hard-coded 42).
train_indices, valid_indices = train_test_split(
    range(len(train)),
    test_size=0.2,
    random_state=params['SEED'],
)

text_column = train['text']      ## input text
author_column = train['author']  ## target label

train_data = [text_column.iloc[i] for i in train_indices]
train_target = [author_column.iloc[i] for i in train_indices]

valid_data = [text_column.iloc[i] for i in valid_indices]
valid_target = [author_column.iloc[i] for i in valid_indices]

## cleaning
train_data = clean_texts(train_data)
valid_data = clean_texts(valid_data)

## tokenization
tokenizer = get_tokenizer()

train_dataloader = get_dataloader(
    train_data, train_target, tokenizer,
    params['MAX_LEN'], params['BATCH_SIZE'], shuffle=True,
)
## Validation batches are only averaged into a mean loss, so shuffling them
## buys nothing; a fixed order keeps evaluation deterministic.
## NOTE(review): assumes get_dataloader forwards `shuffle` to the DataLoader — confirm.
valid_dataloader = get_dataloader(
    valid_data, valid_target, tokenizer,
    params['MAX_LEN'], params['BATCH_SIZE'], shuffle=False,
)

# Training

In [None]:
## model: build the classifier and move it to the GPU when one is available
model = BertForMultiLabelClassification(params['NUM_LABELS'])
if torch.cuda.is_available():
    model.to('cuda')
else:
    model.to('cpu')

## loss & optimizer
## Only the parameters of `model.classifier` are handed to Adam, so the
## optimizer never updates any other part of the model.
criterion = MultiLabelLoss()
optimizer = torch.optim.Adam(
    model.classifier.parameters(),
    lr=params['LEARNING_RATE'],
)

## learning rate scheduler
## One scheduler step is taken per optimizer step, so the schedule length is
## (batches per epoch) x (epochs); no warmup steps are used.
total_steps = params['EPOCHS'] * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

## logging: start the wandb run and record the hyperparameters with it
wandb.init(project='Dacon_AuthorClassification')
wandb.config.update(params)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _unpack_batch(batch):
    """Move one dataloader batch to `device` and one-hot encode its labels.

    Returns a tuple ``(input_ids, attention_mask, labels)`` where ``labels`` is
    a float tensor with ``params['NUM_LABELS']`` columns, ready for `criterion`.
    """
    input_ids = batch['input_ids'].to(device)            ## input text
    attention_mask = batch['attention_mask'].to(device)  ## mask for padding
    ## reshape(-1) instead of squeeze(): a bare squeeze() turns a final batch
    ## of size 1 into a 0-d tensor, and one_hot would then return (num_labels,)
    ## rather than (1, num_labels).
    ## NOTE(review): assumes exactly one integer class id per example — confirm
    ## against TextDataset.
    class_ids = batch['labels'].reshape(-1)
    labels = F.one_hot(class_ids, num_classes=params['NUM_LABELS']).to(device).float()  ## target label
    return input_ids, attention_mask, labels


for epoch in range(params['EPOCHS']):
    model.train() ## training
    train_loss = 0.0
    for batch in train_dataloader:
        input_ids, attention_mask, labels = _unpack_batch(batch)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask) ## batch_size X num_labels
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  ## per-batch step, matching total_steps = batches * epochs
        ## accumulate the batch loss; without this the logged train_loss is always 0
        train_loss += loss.item()

    model.eval() ## evaluating
    val_loss = 0.0
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids, attention_mask, labels = _unpack_batch(batch)

            outputs = model(input_ids, attention_mask) ## batch_size X num_labels
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    avg_train_loss = train_loss / len(train_dataloader)
    avg_val_loss = val_loss / len(valid_dataloader)

    ## logging
    ## A single wandb.log call per epoch keeps both losses on the same wandb
    ## step, so their curves line up; two separate calls advance the step twice.
    wandb.log({
        'epoch' : epoch + 1,
        'train_loss' : avg_train_loss,
        'valid_loss' : avg_val_loss,
    })
    print(f"Epoch {epoch+1}/{params['EPOCHS']}, Validation Loss: {avg_val_loss}")

## model saving
## The checkpoint name is '<model_name>_<day-of-month><hour>.pt'; it is built
## once and reused for both the local save and the wandb upload.
now = datetime.datetime.now().strftime('%d%H')
checkpoint_name = f'{model_name}_{now}.pt'
save_model(model, checkpoint_name)
wandb.save(checkpoint_name)