In [9]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from torch.optim import Adam
import wandb
import yaml
import argparse
from tqdm import tqdm
import numpy as np
import os
from set_seed import set_seed
from transformers import AutoTokenizer
import torchmetrics
import pandas as pd

from models import SBERT_base_Model, BERT_base_Model
from datasets import KorSTSDatasets, Collate_fn, bucket_pair_indices, KorSTSDatasets_for_BERT


In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)

parser = argparse.ArgumentParser()
parser.add_argument('--model_name', default='monologg/koelectra-base-v3-discriminator', type=str)
parser.add_argument('--model_type', default='BERT', type=str)
parser.add_argument('--model_path', default='results/monologg-koelectra-base-v3-discriminator.pt', type=str)
args = parser.parse_args(args=[])

test_datasets = KorSTSDatasets_for_BERT('NLP_dataset/han_processed_test.csv', args.model_name)
valid_datasets = KorSTSDatasets_for_BERT('NLP_dataset/han_processed_dev.csv', args.model_name)
collate_fn = Collate_fn(test_datasets.pad_id, args.model_name)


In [None]:

test_loader = DataLoader(
    test_datasets, 
    collate_fn=collate_fn,
    batch_size=64,
)
valid_loader = DataLoader(
    valid_datasets,
    collate_fn=collate_fn,
    batch_size=64,
)

In [40]:
model = BERT_base_Model(args.model_name)
model.load_state_dict(torch.load(args.model_path))
print("weights loaded from", args.model_path)
model.to(device)

Some weights of the model checkpoint at monologg/koelectra-base-v3-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias', 'discriminator_predictions.dense_prediction.weight']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at monologg/koelectra-base-v3-discriminator and are newly initialized: 

weights loaded from results/monologg-koelectra-base-v3-discriminator.pt


BERT_base_Model(
  (bert): ElectraForSequenceClassification(
    (electra): ElectraModel(
      (embeddings): ElectraEmbeddings(
        (word_embeddings): Embedding(35000, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): ElectraEncoder(
        (layer): ModuleList(
          (0): ElectraLayer(
            (attention): ElectraAttention(
              (self): ElectraSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): ElectraSelfOutput(
                (dense): Linear(in_features=768

In [11]:
test_datasets[0]

IndexError: list index out of range

In [41]:
val_predictions = []
val_labels = []
test_predictions = []
with torch.no_grad():
    model.eval()
    for i, data in enumerate(tqdm(valid_loader)):
        if args.model_type == "SBERT":
            s1, s2, label = data
            s1 = s1.to(device)
            s2 = s2.to(device)
            label = label.to(device)
            logits = model(s1, s2)
        else:
            s1, label = data
            s1 = s1.to(device)
            label = label.to(device)
            logits = model(s1)
        logits = logits.squeeze(-1)
        for logit in logits.to(torch.device("cpu")).detach():
            val_predictions.append(logit)
        for lab in label.to(torch.device("cpu")).detach(): 
            val_labels.append(lab)

    for i, data in enumerate(tqdm(test_loader)):
        if args.model_type == "SBERT":
            s1, s2, label = data
            s1 = s1.to(device)
            s2 = s2.to(device)
            logits = model(s1, s2)
        else:
            s1, label = data
            s1 = s1.to(device)
            logits = model(s1)
        logits = logits.squeeze(-1)
        print(logits)
        for logit in logits.to(torch.device("cpu")).detach():
            test_predictions.append(logit)

100%|██████████| 9/9 [00:00<00:00, 10.22it/s]
0it [00:00, ?it/s]


In [42]:
print(len(val_predictions))
print(len(val_labels))
pearson = torchmetrics.functional.pearson_corrcoef(torch.tensor(val_predictions), torch.tensor(val_labels))
print("valid pearson = ",pearson)
output = pd.read_csv('NLP_dataset/sample_submission.csv')
print(len(test_predictions))
#output['target'] = test_predictions
#output.to_csv('output.csv', index=False)


550
550
valid pearson =  tensor(0.9127)
0
