In [None]:
import pandas as pd
import numpy as np
import os
# Run from the project root so relative paths such as 'artifacts/data/...' resolve
# (and, presumably, so the `src.jigsaw` imports below are found on the cwd).
# Set the JIGSAW_ROOT environment variable instead of editing this machine-specific path.
os.chdir(os.environ.get("JIGSAW_ROOT", "/Users/morizin/Documents/Code/jigsaw-competition"))

import re
from src.jigsaw.utils.common import read_csv
from src.jigsaw import logger

In [None]:
class CFG:
    fold : int = 4
    dataset : list[str] = ['artifacts/data/folded_cleaned_raw']
    files: list[str] = ['train.csv']
    url2sem : bool = False
    features : list[str] = ['body', 'subreddit', 'rule', 'positive_example_1', 'positive_example_2', 'negative_example_1', 'negative_example_2']
    labels : list[str] | str = 'rule_violation'


    truncation : bool | str = True
    model_name : str = 'microsoft/deberta-v3-small'
    padding : bool | str = 'max_length'
    max_length : int = 2048

    outdir : str = './model'
    nepochs : int = 1
    learning_rate : float = 2e-5
    batch_size: int = 4
    gradient_accumulation_step : int = 1
    weight_decay : float = 0.01
    warmup_ratio: float = 0.1

config = CFG()


In [None]:
from pandas.core.frame import DataFrame
import os

def get_datas(config: CFG) -> DataFrame:
    """Load every ``config.files`` file from every ``config.dataset`` directory and stack them.

    Each CSV is read with the project ``read_csv`` helper. A file is kept only
    when it contains every column listed in ``config.features``; otherwise an
    error naming the file and its missing columns is logged and the file is
    skipped.

    Args:
        config: Experiment configuration; ``dataset``, ``files`` and
            ``features`` are read.

    Returns:
        The kept frames concatenated row-wise with a fresh 0..n-1 index, so
        rows coming from different files never share an index label.

    Raises:
        ValueError: If no file passed the column check.
    """
    required = config.features
    frames = []
    tried = []
    for dataset in config.dataset:
        for file in config.files:
            path = os.path.join(dataset, file)
            tried.append(path)
            frame = read_csv(path)
            missing = [col for col in required if col not in frame.columns]
            if missing:
                logger.error(
                    f"The dataset {path} can't be included as it is missing columns {missing}; "
                    f"found columns {list(frame.columns)}"
                )
                continue
            frames.append(frame)

    if not frames:
        # pd.concat([]) would fail with an opaque "No objects to concatenate".
        raise ValueError(
            f"No dataset containing all required columns {required} was found among {tried}"
        )
    # ignore_index: otherwise each file's own RangeIndex is kept, producing duplicate labels.
    return pd.concat(frames, axis=0, ignore_index=True)


data = get_datas(config)

In [None]:
from torch.utils.data import Dataset
from ensure import ensure_annotations
from pandas.core.frame import DataFrame
from transformers import AutoTokenizer
from transformers.models.deberta_v2.tokenization_deberta_v2 import DebertaV2Tokenizer
from transformers.models.deberta_v2.tokenization_deberta_v2_fast import (
    DebertaV2TokenizerFast,
)
from pandas.api.types import is_string_dtype
from src.jigsaw.utils.data import build_prompt, url_to_semantics
from src.jigsaw.utils.common import read_csv
from typing import Dict
import torch


class ClassifierDataset(Dataset):
    """Map-style torch dataset of tokenized prompts with optional labels.

    Every row of the input frame is rendered to a prompt with ``build_prompt``
    and the whole list is tokenized once in ``__init__``; ``__getitem__`` only
    slices that pre-computed encoding. Items hold the tokenizer outputs as
    tensors plus, when label columns are present, a float ``'label'`` tensor.
    """

    @ensure_annotations
    def __init__(
        self,
        config: CFG,
        data: DataFrame | str,
        tokenizer: DebertaV2Tokenizer | DebertaV2TokenizerFast | str,
    ):
        """Load the data, optionally enrich string columns, build prompts and tokenize.

        Args:
            config: Reads ``url2sem``, ``truncation``, ``padding``,
                ``max_length`` and ``labels``. ``config`` is not modified.
            data: A DataFrame, or a CSV path loaded with ``read_csv``. The
                caller's frame is not modified; a copy is rewritten when
                ``config.url2sem`` is set.
            tokenizer: A DeBERTa-v2 tokenizer instance, or a model name/path
                passed to ``AutoTokenizer.from_pretrained``.

        Raises:
            Exception: If ``data`` or ``tokenizer`` does not resolve to a supported type.
            KeyError: If only some of the configured label columns exist in ``data``.
        """
        if isinstance(data, str):
            data = read_csv(data)
        if not isinstance(data, DataFrame):
            error = f"'data' can be either str or pd.DataFrame. 'data' has type {type(data).__name__}"
            logger.error(error)
            raise Exception(error)

        if isinstance(tokenizer, str):
            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        if not isinstance(tokenizer, (DebertaV2Tokenizer, DebertaV2TokenizerFast)):
            error = f"'tokenizer' can be either str, DebertaV2Tokenizer, DebertaV2TokenizerFast. 'tokenizer' has type {type(tokenizer).__name__}"
            logger.error(error)
            raise Exception(error)
        self.tokenizer = tokenizer

        if config.url2sem:
            # Rewrite a copy: mutating the caller's frame would make a cell re-run
            # append the URL semantics to the text a second time.
            data = data.copy()
            for col, dtype in data.dtypes.items():
                if is_string_dtype(dtype):
                    data[col] = data[col] + data[col].apply(url_to_semantics)
        self.data = data

        self.completion = data.apply(build_prompt, axis=1).to_list()
        self.encoding = self.tokenizer(
            self.completion,
            truncation=config.truncation,
            padding=config.padding,
            max_length=config.max_length,
        )
        self.labels = self._extract_labels(data, config.labels)

    @staticmethod
    def _extract_labels(data: DataFrame, labels: list[str] | str) -> np.ndarray | None:
        """Return the label columns as a float32 array of shape (n_rows, n_labels), or None.

        ``labels`` may be a single column name or a list of names. None means
        none of them is present (unlabeled data, e.g. a test split). Targets
        are float because each item's label is a vector (one entry per label
        column), the layout used by BCE-with-logits / MSE style losses.
        """
        label_cols = [labels] if isinstance(labels, str) else list(labels)
        missing = [col for col in label_cols if col not in data.columns]
        if len(missing) == len(label_cols):
            return None
        if missing:
            raise KeyError(f"Label columns {missing} are missing from the data")
        return data[label_cols].to_numpy(dtype=np.float32)

    def __len__(self) -> int:
        """Number of encoded rows; checked against the labels when they exist."""
        n_rows = len(self.encoding['input_ids'])
        if self.labels is not None:
            assert n_rows == len(self.labels), f"Input and Output length mismatch {n_rows} != {len(self.labels)}"
        return n_rows

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return the tokenizer outputs for row ``idx`` (plus ``'label'`` when labeled) as tensors."""
        items = {key: torch.tensor(value[idx]) for key, value in self.encoding.items()}
        if self.labels is not None:
            items['label'] = torch.tensor(self.labels[idx].flatten())
        return items


# Build the training dataset; the tokenizer is resolved by name from the configured model.
tokenizer_source = config.model_name
train_dataset = ClassifierDataset(config, data, tokenizer_source)

In [None]:
# from torch.utils.data import DataLoader

# train_dataloader = DataLoader(
#     train_dataset,
#     batch_size=16,
#     shuffle = True,
#     num_workers=0,
#     pin_memory=True
# )

# for i in train_dataloader:
#     print(i['input_ids'].shape)

In [None]:
import torch.nn as nn
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Single-logit binary head. `num_labels=1` sizes the freshly initialised classifier
# to one output and keeps `config.num_labels` consistent with it, and
# `problem_type="multi_label_classification"` selects BCEWithLogitsLoss(logits, labels),
# which expects float labels shaped like the logits, i.e. (batch, 1).
model = AutoModelForSequenceClassification.from_pretrained(
    config.model_name,
    num_labels=1,
    problem_type="multi_label_classification",
    trust_remote_code=True,
)
# Disable head dropout with a no-op module. Deleting the attribute instead would
# break forward(), which still calls `self.dropout(...)`.
model.dropout = nn.Identity()

In [None]:
# Display the model architecture, e.g. to check the classification head configured above.
model

In [None]:
# Optimisation settings, all taken from `config`.
optim_kwargs = dict(
    num_train_epochs=config.nepochs,
    per_device_train_batch_size=config.batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_step,
    learning_rate=config.learning_rate,
    weight_decay=config.weight_decay,
    warmup_ratio=config.warmup_ratio,
)

# No checkpoints are written (save_strategy='no') and no tracker is used (report_to='none').
training_args = TrainingArguments(
    output_dir=config.outdir,
    overwrite_output_dir=True,
    do_train=True,
    report_to='none',
    save_strategy='no',
    **optim_kwargs,
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

In [None]:
# Run fine-tuning and show the result returned by trainer.train() as the cell output.
train_output = trainer.train()
train_output