## Import

In [1]:
import math
import logging
import logging.config
import re
import os
import random
import datetime
from pprint import pprint
from itertools import combinations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.utils.data import DataLoader

import faiss
import deepspeed
import pandas as pd
import numpy as np
import torch_optimizer as optim
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans, MiniBatchKMeans
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
    get_scheduler,
    set_seed
)
from datasets import load_dataset, load_metric
from accelerate import Accelerator, DeepSpeedPlugin

from tqdm.auto import tqdm
from rank_bm25 import BM25Okapi
import hyptorch.nn as hypnn
from pytorch_metric_learning import miners, losses, testers
from pytorch_metric_learning.utils.inference import InferenceModel, MatchFinder
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

## Hyper parameters

In [2]:
TRAIN_DATA = "open/bm25-graphcodebert-base-all_case_v3.csv"
VAL_DATA = "open/sample_train.csv"
SAMPLE_DATA = "open/sample_train.csv"
CODE_DATA_PATH = "open/code"
TEST_DATA = "open/test.csv"
SUBMISSION = 'open/sample_submission.csv'
PRETRAINED_MODEL = "microsoft/graphcodebert-base"
NUM_LABELS = 2
DIM = 768
MAX_LEN = 512
BATCH = 32
NUM_WORKERS = 4
GRADIENT_ACCUMULATION_STEPS = 4
GRADIENT_CHECKPOINTING = True
EPOCHS = 5
MAX_LR = 5e-3
MIN_LR = 5e-6
WD = 1e-2
SEED = 12361
TRAIN_SELECT_DATA = 2000
VAL_SELECT_DATA = int(TRAIN_SELECT_DATA*0.1)
TRAIN_TEST_SPLIT_RATIO = 0.1
OUTPUT_DIR = "./results"
SAVE_MODEL = f"{PRETRAINED_MODEL}_{datetime.datetime.now().strftime('%H:%M:%S:%m')}"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ['MASTER_ADDR'] = 'localhost'
# os.environ['MASTER_PORT'] = '29500' # modify if RuntimeError: Address already in use
# os.environ['RANK'] = "0"
# os.environ['LOCAL_RANK'] = "0"
# os.environ['WORLD_SIZE'] = "1"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

## Logging

In [3]:
config = {
    "version": 1,
    "formatters": {
        "simple": {"format": "[%(asctime)s] %(message)s", "datefmt": "%Y-%m-%d %H:%M:%S"},
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "level": "INFO",
        },
        "file": {
            "class": "logging.FileHandler",
            "filename": f"logs/{datetime.datetime.now().strftime('%H:%M:%S:%m')}.log",
            "formatter": "simple",
            "level": "INFO",
        },
    },
    "root": {"handlers": ["console", "file"], "level": "INFO"},
    "loggers": {"parent": {"level": "INFO"}, "parent.child": {"level": "DEBUG"},},
}

logging.config.dictConfig(config)
logger = logging.getLogger()

## Fix seed

In [4]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
os.environ["PYTHONHASHSEED"] = str(SEED)

## Class & Functions

### Utils

In [5]:
class AverageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def preprocess_function(examples):
    for i in range(1, 3):
        for j in range(len(examples[f"code{i}"])):
            examples[f"code{i}"][j] = re.sub("#.*", "", examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub('""".*?"""', "", examples[f"code{i}"][j], flags=re.S)
            examples[f"code{i}"][j] = re.sub("'''.*?'''", "", examples[f"code{i}"][j], flags=re.S)
            examples[f"code{i}"][j] = re.sub("b'.*?'", "b''", examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub('b".*?"', 'b""', examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub("^from .*? import .*?\n", "", examples[f"code{i}"][j], flags=re.MULTILINE) # TODO: 이거 포함시켜서 preprocess 하면 성능 향상 되는지 확인하기
            examples[f"code{i}"][j] = re.sub("^import .*?\n", "", examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub("@.*", "", examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub("^\n", "", examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub("^ *?\n", "", examples[f"code{i}"][j], flags=re.MULTILINE)
            examples[f"code{i}"][j] = re.sub("    ", "\t", examples[f"code{i}"][j], flags=re.MULTILINE)
            
    outputs = tokenizer(examples['code1'], examples['code2'], padding="max_length", max_length=MAX_LEN, truncation=True)
    if 'similar' in examples:
        outputs["labels"] = examples["similar"]
    return outputs

def save_model(save_name, model, optimizer, epoch, train_loss):
    torch.save({
        "epoch": epoch,
        "total_epoch": EPOCHS,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": train_loss,
    }, f"{save_name}.pt")

### Preprocess script to csv

In [6]:
# def make_train_dataset_from_codefolder(path):
#     scripts_list = []
#     problem_nums = []

#     for problem_folder in tqdm(os.listdir(path)):
#         scripts = os.listdir(os.path.join(path, problem_folder))
#         problem_num = scripts[0].split('_')[0]
#         for script in scripts:
#             script_file = os.path.join(path, problem_folder, script)
#             with open(script_file, 'r', encoding='utf-8') as file:
#                 lines = file.read()
#             lines = re.sub(r"#.*", "", lines, flags=re.MULTILINE)
#             lines = re.sub(r'""".*?"""', "", lines, flags=re.S)
#             lines = re.sub(r"^\n", "", lines, flags=re.MULTILINE)
#             lines = re.sub(r"^ *?\n", "", lines, flags=re.MULTILINE)
#             lines = re.sub(r"    ", "\t", lines, flags=re.MULTILINE)
#             scripts_list.append(lines)
#         problem_nums.extend([problem_num]*len(scripts))

#     df = pd.DataFrame(data = {'code':scripts_list, 'problem_num':problem_nums})
#     logger.info(f"Descirbe: \n{df.describe()}")
#     logger.info(f"Head: \n{df.head()}")
#     logger.info(f"Length: \n{len(df)}")

#     df['tokens'] = df['code'].apply(tokenizer.tokenize)
#     df['len'] = df['tokens'].apply(len)
#     logger.info(f"Tokens Describe: \n{df.describe()}")

#     ndf = df[df['len'] <= 512].reset_index(drop=True)
#     logger.info(f"Max Length Clipping Describe: \n{ndf.describe()}")

#     train_df, valid_df, train_label, valid_label = train_test_split(
#         ndf,
#         ndf['problem_num'],
#         random_state=42,
#         test_size=0.1,
#         stratify=ndf['problem_num'],
#     )

#     train_df = train_df.reset_index(drop=True)
#     valid_df = valid_df.reset_index(drop=True)
#     logger.info("Done!")
#     return train_df, valid_df
    
# # def preprocess_bm25(df, file_name="preprocess_bm25"):
# #     codes = df['code'].to_list()
# #     problems = df['problem_num'].unique().tolist()
# #     problems.sort()

# #     tokenized_corpus = [tokenizer.tokenize(code) for code in codes]
# #     bm25 = BM25Okapi(tokenized_corpus)

# #     total_positive_pairs = []
# #     total_negative_pairs = []

# #     for problem in tqdm(problems):
# #         solution_codes = df[df['problem_num'] == problem]['code']
# #         positive_pairs = list(combinations(solution_codes.to_list(),2))

# #         solution_codes_indices = solution_codes.index.to_list()
# #         negative_pairs = []

# #         first_tokenized_code = tokenizer.tokenize(positive_pairs[0][0])
# #         negative_code_scores = bm25.get_scores(first_tokenized_code)
# #         negative_code_ranking = negative_code_scores.argsort()[::-1] # 내림차순
# #         ranking_idx = 0

# #         for solution_code in solution_codes:
# #             negative_solutions = []
# #             while len(negative_solutions) < len(positive_pairs) // len(solution_codes):
# #                 high_score_idx = negative_code_ranking[ranking_idx]

# #                 if high_score_idx not in solution_codes_indices:
# #                     negative_solutions.append(df['code'].iloc[high_score_idx])
# #                 ranking_idx += 1

# #             for negative_solution in negative_solutions:
# #                 negative_pairs.append((solution_code, negative_solution))

# #         total_positive_pairs.extend(positive_pairs)
# #         total_negative_pairs.extend(negative_pairs)

# #     pos_code1 = list(map(lambda x:x[0],total_positive_pairs))
# #     pos_code2 = list(map(lambda x:x[1],total_positive_pairs))

# #     neg_code1 = list(map(lambda x:x[0],total_negative_pairs))
# #     neg_code2 = list(map(lambda x:x[1],total_negative_pairs))

# #     pos_label = [1]*len(pos_code1)
# #     neg_label = [0]*len(neg_code1)

# #     pos_code1.extend(neg_code1)
# #     total_code1 = pos_code1
# #     pos_code2.extend(neg_code2)
# #     total_code2 = pos_code2
# #     pos_label.extend(neg_label)
# #     total_label = pos_label
# #     pair_data = pd.DataFrame(data={
# #         'code1':total_code1,
# #         'code2':total_code2,
# #         'similar':total_label
# #     })
# #     pair_data = pair_data.sample(frac=1).reset_index(drop=True)
# #     pair_data.to_csv(f'open/{file_name}.csv',index=False)


# def preprocess_bm25(df, file_name="preprocess_bm25"):
#     problems = sorted(df['problem_num'].unique().tolist())
#     positive_pairs = []
#     negative_pairs = []
#     for problem in tqdm(problems):
#         positive_codes = df[df["problem_num"]==problem]["code"].to_list()
#         # negative_codes = df[df["problem_num"]!=problem]["code"].to_list()
#         positive_tokenized_corpus = [tokenizer.tokenize(code) for code in positive_codes]
#         # negative_tokenized_corpus = [tokenizer.tokenize(code) for code in negative_codes]
#         positive_bm25 = BM25Okapi(positive_tokenized_corpus)
#         # negative_bm25 = BM25Okapi(negative_tokenized_corpus)

#         # get positive_pairs
#         for idx, code in enumerate(positive_codes, start=1):
#             tokenized_code = tokenizer.tokenize(code)
#             positive_scores = positive_bm25.get_scores(tokenized_code)
#             # negative_scores = negative_bm25.get_scores(tokenized_code)
            
#             for _ in range(2):
#                 for i in range(len(positive_scores)):
#                     positive_bottom = positive_scores.argsort()[i]
#                     if (code, positive_bottom) not in positive_pairs and (positive_bottom, code) not in positive_pairs:
#                         positive_pairs.append((code, positive_bottom))
#                         break
            
#             if idx == len(positive_codes):
#                 for p in problems:
#                     if problem == p:
#                         continue
#                     negative_codes = df[df["problem_num"]==p]["code"].to_list()
#                     for negative_code in negative_codes:
#                         if (code, negative_code) not in negative_pairs and (negative_code, code) not in negative_pairs:
#                             negative_pairs.append((code, negative_code))
#                             break
            
#         # print("positive:", len(positive_pairs), "negative:", len(negative_pairs))
#             # for i in range(len(negative_scores)):
#             #     negative_bottom = negative_scores.argsort()[i]
#             #     if (code, negative_bottom) not in negative_pairs and (negative_bottom, code) not in negative_pairs:
#             #         negative_pairs.append((code, negative_bottom))
#             #         break

#     positive_labels = [1]*len(positive_pairs)
#     negative_labels = [0]*len(negative_pairs)

#     total_pairs = []
#     total_labels = []
#     total_pairs.extend(positive_pairs)
#     total_pairs.extend(negative_pairs)
#     total_labels.extend(positive_labels)
#     total_labels.extend(negative_labels)
#     total_code1 = list(map(lambda x:x[0], total_pairs))
#     total_code2 = list(map(lambda x:x[1], total_pairs))
        
#     pair_data = pd.DataFrame(data={
#         'code1':total_code1,
#         'code2':total_code2,
#         'similar':total_labels
#     })
#     pair_data = pair_data.sample(frac=1).reset_index(drop=True)
#     pair_data.to_csv(f'{file_name}.csv',index=False)

In [7]:
# train_df, val_df = make_train_dataset_from_codefolder(CODE_DATA_PATH)
# preprocess_bm25(train_df, file_name="preprocess_bm25_train")
# preprocess_bm25(val_df, file_name="preprocess_bm25_val")

### Trainer

In [8]:
def train(model, optimizer, dataloader):
    metric = load_metric("accuracy")
    train_loss = AverageMeter()
    model.train()
    with tqdm(dataloader, total=len(dataloader), unit="batch") as train_bar:
        for idx, batch in enumerate(train_bar, start=1):
            embeddings = model(batch)
            # hard_pairs = miner(embeddings, batch["labels"])
            # loss = loss_func(embeddings, batch["labels"], hard_pairs)
            loss = loss_func(embeddings, batch["labels"])
            # loss = sce_loss(embeddings, batch["labels"])
            loss /= GRADIENT_ACCUMULATION_STEPS
            accelerator.backward(loss)
            
            if (idx % GRADIENT_ACCUMULATION_STEPS == 0) or (idx == len(dataloader)):
                accelerator.clip_grad_norm_(model.parameters(), max_norm=1, norm_type=2)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                predictions = torch.argmax(loss_func.get_logits(embeddings), dim=-1)
                # predictions = torch.argmax(embeddings, dim=-1)
                metric.add_batch(predictions=predictions, references=batch["labels"])
                train_loss.update(loss.item(), BATCH)
                train_bar.set_postfix(train_loss=loss.item(), train_acc=metric.compute()["accuracy"])
            
            ########## DEEPSPEED #############
            # if (idx % GRADIENT_ACCUMULATION_STEPS == 0) or (idx == len(dataloader)):
            # model.backward(loss)
            # model.step()
            # train_loss.update(loss.item(), BATCH)
            # train_bar.set_postfix(train_loss=loss.item())
            ########## DEEPSPEED #############
    return train_loss.avg

def valid(model, dataloader):
    metric = load_metric("accuracy")
    val_loss = AverageMeter()
    model.eval()
    with tqdm(dataloader, total=len(dataloader), unit="batch") as val_bar:
        for idx, batch in enumerate(val_bar):
            with torch.no_grad():
                embeddings = model(batch)
            loss = loss_func(embeddings, batch["labels"])
            # loss = sce_loss(embeddings, batch["labels"])
            predictions = torch.argmax(loss_func.get_logits(embeddings), dim=-1)
            # predictions = torch.argmax(embeddings, dim=-1)
            metric.add_batch(predictions=predictions, references=batch["labels"])
            val_loss.update(loss.item(), BATCH)
            val_bar.set_postfix(val_loss=loss.item())

    return val_loss.avg, metric.compute()["accuracy"]

def predict(model, dataloader):
    pred_list = []
    model.eval()
    for batch in tqdm(dataloader):
        with torch.no_grad():
            embeddings = model(batch)
        predictions = torch.argmax(loss_func.get_logits(embeddings), dim=-1).tolist()
        pred_list.extend(predictions)
    return pred_list

# def valid_kmeans(model, dataloader):
#     metric = load_metric("accuracy")
#     val_loss = AverageMeter()
#     kmeans = MiniBatchKMeans(n_clusters=2, random_state=SEED)
#     embeds = []
#     labels = []
#     # tmp = []
#     # match_finder = MatchFinder(distance=CosineSimilarity(), threshold=0.7)
#     # inference_model = InferenceModel(model, match_finder=match_finder, normalize_embeddings=False)
#     model.eval()
    
    
#     ###########################################################
#     # tmp = []
#     # for batch in trainloader:
#     #     with torch.no_grad():
#     #         outputs = model(batch)
#     #     tmp.extend(outputs["embed"].cpu().numpy())
#     # kmeans.fit(tmp)
#     ###########################################################
    
    
#     with tqdm(dataloader, total=len(dataloader), unit="batch") as val_bar:
#         for idx, batch in enumerate(val_bar):
#             with torch.no_grad():
#                 # outputs = model(**batch)
#                 outputs = model(batch)
#                 # decision = inference_model.is_match(batch, batch["labels"])
#             # tmp.extend(decision)
#             loss = loss_func(outputs, batch["labels"])
#             embeds.extend(outputs.detach().cpu().numpy())
#             labels.extend(batch["labels"])
#             val_loss.update(loss.item(), BATCH)
#             val_bar.set_postfix(val_loss=loss.item())
#         klabel = kmeans.fit_predict(embeds)
#         # print(tmp)
#         # print("acc:", np.sum(tmp) / len(tmp))
#         metric.add_batch(predictions=klabel, references=labels)
#     return val_loss.avg, metric.compute()["accuracy"]

# def predict_kmeans(model, dataloader):
#     kmeans = MiniBatchKMeans(n_clusters=2, random_state=SEED)
#     embeds = []
#     model.eval()
#     for batch in tqdm(dataloader):
#         # batch = {k: v.to("cuda:0") for k, v in batch.items()}
#         with torch.no_grad():
#             outputs = model.predict(batch)
#         embeds.extend(outputs["embed"].detach().cpu().numpy())
#     klabel = kmeans.fit_predict(embeds)
#     return klabel


# def valid_metric(model, train_dataloader, val_dataloader):
#     metric = load_metric("accuracy")
#     val_loss = AverageMeter()
    
#     train_labels = []
#     predictions = []
#     index = faiss.IndexFlatL2(DIM)

#     model.eval()
#     for batch in tqdm(train_dataloader):
#         with torch.no_grad():
#             outputs = model(batch)
#         index.add(outputs["embed"].cpu().numpy())
#         train_labels.extend(batch["labels"])
    
#     with tqdm(val_dataloader, total=len(val_dataloader), unit="batch") as val_bar:
#         for idx, batch in enumerate(val_bar):
#             predictions = []
#             with torch.no_grad():
#                 outputs = model(batch)
#             d, i = index.search(outputs["embed"].cpu().numpy(), 3)
#             for idxs in i:
#                 prediction = [train_labels[idx].item() for idx in idxs]
#                 num_0 = prediction.count(0)
#                 num_1 = prediction.count(1)
#                 if num_0 > num_1:
#                     predictions.extend([0])
#                 else:
#                     predictions.extend([1])

#             # i = i.squeeze()
#             # prediction = [train_labels[idx].item() for idx in i]
            
#             metric.add_batch(predictions=predictions, references=batch["labels"])
#             val_loss.update(outputs["loss"], BATCH)
#             val_bar.set_postfix(val_loss=outputs["loss"].item())

#     return val_loss.avg, metric.compute()["accuracy"]

# def predict_metric(model, train_dataloader, test_dataloader):
#     metric = load_metric("accuracy")
#     val_loss = AverageMeter()
    
#     train_labels = []
#     predictions = []
#     index = faiss.IndexFlatL2(DIM)

#     model.eval()
#     for batch in tqdm(train_dataloader):
#         with torch.no_grad():
#             outputs = model(batch)
#         index.add(outputs["embed"].cpu().numpy())
#         train_labels.extend(batch["labels"])
    
#     for batch in tqdm(test_dataloader):
#         with torch.no_grad():
#             outputs = model.predict(batch)
#         d, i = index.search(outputs["embed"].cpu().numpy(), 3)
#         for idxs in i:
#             prediction = [train_labels[idx].item() for idx in idxs]
#             num_0 = prediction.count(0)
#             num_1 = prediction.count(1)
#             if num_0 > num_1:
#                 predictions.extend([0])
#             else:
#                 predictions.extend([1])
#     return predictions

## Load Train / Val / Test dataset

In [9]:
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)
tokenizer.truncation_side = "left"
train_dataset = load_dataset("csv", data_files=TRAIN_DATA)['train']
# train_dataset = train_dataset.select(range(TRAIN_SELECT_DATA))
train_dataset = train_dataset.map(
    preprocess_function,
    remove_columns=['code1', 'code2', 'similar'],
    load_from_cache_file=False,
    batched=True
)
train_dataset.set_format("torch")
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

val_dataset = load_dataset("csv", data_files=VAL_DATA)['train']
# val_dataset = val_dataset.select(range(VAL_SELECT_DATA))
val_dataset = val_dataset.map(
    preprocess_function,
    remove_columns=['code1', 'code2', 'similar'],
    load_from_cache_file=False,
    batched=True
)
val_dataset.set_format("torch")
val_dataloader = DataLoader(
    val_dataset,
    shuffle=False,
    batch_size=BATCH,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

Downloading and preparing dataset csv/default to /home/djlee/.cache/huggingface/datasets/csv/default-cfaa861a98c64ab5/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /home/djlee/.cache/huggingface/datasets/csv/default-cfaa861a98c64ab5/0.0.0/433e0ccc46f9880962cc2b12065189766fbb2bee57a221866138fb9203c83519. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2691 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/18 [00:00<?, ?ba/s]

## Define Model

In [10]:
# class Network(nn.Module):
#     def __init__(self):
#         super(Network, self).__init__()
#         self.pretrained_model = AutoModel.from_pretrained(PRETRAINED_MODEL)
#         self.embedding = nn.Linear(self.pretrained_model.config.hidden_size, DIM)
#         if GRADIENT_CHECKPOINTING:
#             self.pretrained_model.gradient_checkpointing_enable()

#     def forward(self, inputs):
#         outputs = self.get_embeddings(inputs)
#         hard_pairs = miner(outputs, inputs["labels"])
#         loss = loss_func(outputs, inputs["labels"], hard_pairs)
#         return {"embed": outputs, "loss": loss}
    
#     def predict(self, inputs):
#         outputs = self.get_embeddings(inputs)
#         return {"embed": outputs}
    
#     def get_embeddings(self, inputs):
#         outputs = self.pretrained_model(
#             input_ids=inputs.get("input_ids"),
#             token_type_ids=inputs.get("token_type_ids"),
#             attention_mask=inputs.get("attention_mask")
#         )
#         outputs = self.embedding(outputs["pooler_output"])
#         outputs = F.normalize(outputs)
#         return outputs

class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.pretrained_model = AutoModel.from_pretrained(PRETRAINED_MODEL)
        if GRADIENT_CHECKPOINTING:
            self.pretrained_model.gradient_checkpointing_enable()

    def forward(self, inputs):
        embeddings = self.pretrained_model(
            input_ids=inputs.get("input_ids"),
            token_type_ids=inputs.get("token_type_ids"),
            attention_mask=inputs.get("attention_mask")
        )
        embeddings = mean_pooling(embeddings, inputs.get("attention_mask"))
        embeddings = self.l2_norm(embeddings)
        
        return embeddings
    
    def l2_norm(self, inputs):
        inputs_size = inputs.size()
        buffer = torch.pow(inputs, 2)

        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)

        _outputs = torch.div(inputs, norm.view(-1, 1).expand_as(inputs))

        outputs = _outputs.view(inputs_size)

        return outputs
    
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

In [11]:
# deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
accelerator = Accelerator(
    fp16=True,
    # deepspeed_plugin=deepspeed_plugin,
)

model = Network()
    
# optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
optimizer = optim.Lamb(
    model.parameters(),
    lr=MIN_LR,
    betas=(0.9, 0.999),
    eps=1e-6,
    weight_decay=WD
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    steps_per_epoch=len(train_dataloader),
    epochs=EPOCHS
)

# TEST
# miner = miners.MultiSimilarityMiner()
# miner = miners.BatchEasyHardMiner()
# miner = miners.TripletMarginMiner(margin=0.01, type_of_triplets="semihard")
# loss_func = losses.SoftTripleLoss(
#     num_classes=NUM_LABELS,
#     embedding_size=DIM,
#     centers_per_class=10,
#     la=20,
#     gamma=0.1,
#     margin=0.01,
# )
loss_func = losses.ProxyAnchorLoss(
    num_classes=NUM_LABELS,
    embedding_size=DIM,
)

# scheduler = get_scheduler(
#     name="linear",
#     optimizer=optimizer,
#     num_warmup_steps=0,
#     num_training_steps=EPOCHS*len(train_dataloader)
# )

# config =  {
#     "train_batch_size": BATCH,
#     # "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS,
#     "optimizer": {
#         "type": "Adam",
#         "params": {
#             "lr": LR,
#             "weight_decay": WD
#         }
#     },
#     "fp16": {
#         "enabled": True
#     },
#     "zero_optimization": True
# }

# model, optimizer, _, _ = deepspeed.initialize(model=model,
#                                               config_params=config,
#                                               model_parameters=model.parameters())

model, optimizer, loss_func, scheduler, train_dataloader, val_dataloader = accelerator.prepare(
    model, optimizer, loss_func, scheduler, train_dataloader, val_dataloader
)

## Train

In [12]:
best_accuracy = 0
for epoch in range(1, EPOCHS + 1):
    train_loss = train(model, optimizer, train_dataloader)
    val_loss, val_accuracy = valid(model, val_dataloader)
    accelerator.wait_for_everyone()
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        save_model("results_accelerator/best", model, optimizer, epoch, train_loss)
        # unwrapped_model = accelerator.unwrap_model(model)
        # unwrapped_model.save_pretrained("./results_accelerator/best", save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
    
    save_model(f"results_accelerator/epoch{epoch}", model, optimizer, epoch, train_loss)
    # unwrapped_model = accelerator.unwrap_model(model)
    # unwrapped_model.save_pretrained(f"./results_accelerator/epoch{epoch}", save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
    logger.info(f"epoch:{epoch}/{EPOCHS} | train -> loss: {train_loss}")
    logger.info(
        f"epoch:{epoch}/{EPOCHS} | validation -> loss: {val_loss} | acc: {val_accuracy}"
    )

  0%|          | 0/84094 [00:00<?, ?batch/s]

  0%|          | 0/562 [00:00<?, ?batch/s]

[2022-06-10 13:33:19] epoch:1/5 | train -> loss: 0.6394323408846555
[2022-06-10 13:33:19] epoch:1/5 | validation -> loss: 2.434509139617796 | acc: 0.9939343350027824


  0%|          | 0/84094 [00:00<?, ?batch/s]

KeyboardInterrupt: 

## Predict

In [13]:
test_dataset = load_dataset("csv", data_files=TEST_DATA)['train']
test_dataset = test_dataset.map(
    preprocess_function,
    remove_columns=["pair_id", 'code1', 'code2'],
    load_from_cache_file=False, # TODO: 변경시 False
    batched=True
)
test_dataset.set_format("torch")
test_dataloader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=BATCH,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

test_dataloader = accelerator.prepare(test_dataloader)
checkpoint = torch.load("results_accelerator/best.pt")
model.load_state_dict(checkpoint['model_state_dict'])

predictions = predict(model, test_dataloader)

df = pd.read_csv(SUBMISSION)
df['similar'] = predictions
df.to_csv('./submission.csv', index=False)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?ba/s]

  0%|          | 0/5616 [00:00<?, ?it/s]

In [None]:
torch.cuda.empty_cache()