# In [4]:
import argparse
import random
import os
import re

import wandb

import pandas as pd
import numpy as np

from tqdm.auto import tqdm

import transformers
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer
import torch
import torchmetrics
import pytorch_lightning as pl
import torch.nn.functional as F
from pytorch_lightning.loggers import WandbLogger
from torch.utils.checkpoint import checkpoint

# Fix random seeds for reproducibility.
# torch.manual_seed seeds the RNG on all devices; the explicit CUDA calls are
# kept so the intent stays visible.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
random.seed(0)
# NumPy's global RNG was previously left unseeded. NumPy itself and pandas
# sampling without an explicit random_state draw from it, so seed it as well.
np.random.seed(0)

# In [None]:
def undersampling(data_path, index_min=1000, index_max=2000):
    '''
    Under-sample the label-0 rows so their count is closer to the other labels.

    Every row whose label is not 0 is kept. Of the label-0 rows, only those at
    positions [index_min:index_max] are kept; positions are counted among the
    label-0 rows only. The defaults keep positions 1000-1999, i.e. 1000 rows,
    which matches the previous hard-coded behavior. The default window does not
    overlap copy_sentence's default label-0 window (250:750).

    Args:
        data_path: path of the CSV file to read; it must have a 'label' column.
        index_min: first position (inclusive) among the label-0 rows to keep.
        index_max: end position (exclusive) among the label-0 rows to keep.

    Returns:
        A DataFrame containing the non-zero-label rows followed by the kept
        label-0 rows. Original index values are preserved.
    '''
    df = pd.read_csv(data_path)
    # .loc[mask].iloc[a:b] makes the positional slice explicit instead of
    # relying on chained df[mask][a:b] indexing.
    df_only_0 = df.loc[df['label'] == 0].iloc[index_min:index_max].copy()
    df_new = df.loc[df['label'] != 0].copy()
    return pd.concat([df_new, df_only_0])

def swap_sentence(data_path):
    '''
    Create augmented rows by exchanging sentence_1 and sentence_2.

    Intended to ease label imbalance: only rows whose label is not 0 are
    returned, each with its two sentence columns swapped. All other columns,
    the column order and the original index are kept.
    '''
    source = pd.read_csv(data_path)
    keep = source['label'] != 0
    swapped = source.loc[keep].copy()
    # Both new values are read from the untouched source frame, so the
    # second assignment never sees the result of the first.
    swapped['sentence_1'] = source.loc[keep, 'sentence_2']
    swapped['sentence_2'] = source.loc[keep, 'sentence_1']
    return swapped

def copy_sentence(data_path, index_min=250, index_max=750):
    '''
    Build synthetic label-5 rows by pairing a sentence with itself.

    Takes the label-0 rows at positions [index_min:index_max] (positions are
    counted among the label-0 rows only). It overwrites sentence_1 with
    sentence_2 so both sides hold identical text, and sets label to 5.0.
    The purpose is to augment the label-5 class.

    NOTE(review): every other column is copied unchanged from the source
    label-0 rows; confirm that no other column must agree with the new label.
    '''
    frame = pd.read_csv(data_path)
    zero_rows = frame.loc[frame['label'] == 0]
    window = zero_rows.iloc[index_min:index_max].copy()
    return window.assign(sentence_1=window['sentence_2'], label=5.0)

def concatenate(data_path, *dataframe):
    '''
    Stack the given DataFrames row-wise and write the result to a CSV file.

    Args:
        data_path: destination CSV path.
        *dataframe: DataFrames to concatenate, in the given order.

    Returns:
        None. The DataFrame index is not written to the file.
    '''
    pd.concat(dataframe).to_csv(data_path, index=False)

def augmentation(train_data_path, new_data_path):
    '''
    Build the augmented training CSV from the original training CSV.

    Combines, in this order, three parts: the under-sampled data
    (undersampling), the sentence-swapped non-zero-label rows (swap_sentence)
    and the synthetic label-5 rows (copy_sentence). The result is written to
    new_data_path via concatenate. Each helper reads train_data_path itself.
    '''
    parts = [
        undersampling(train_data_path),
        swap_sentence(train_data_path),
        copy_sentence(train_data_path),
    ]
    concatenate(new_data_path, *parts)

if __name__ == "__main__":
    # Read the original training set and write the augmented copy into the
    # same directory.
    augmentation(
        train_data_path='/data/ephemeral/home/data/train.csv',
        new_data_path='/data/ephemeral/home/data/yj_aug_train.csv',
    )