In [2]:
import logging
import torch
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, Trainer
from datasets import Dataset, ClassLabel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from train_test_split import select_eval_with_cluster
from preprocessing import preprocess
from evaluation import evaluate
from bert import tokenize, get_BERT, prepare_dataset, compute_metrics

[nltk_data] Downloading package stopwords to /Users/jonas/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /Users/jonas/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package wordnet to /Users/jonas/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [4]:
# Configure the root logger so INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)

In [5]:
# Pick the GPU when torch reports CUDA as available, otherwise stay on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
device

'cpu'

In [6]:
# Run configuration consumed by the cells below.
# MODEL is handed to get_BERT(); presumably a path to a saved checkpoint — TODO confirm.
MODEL = 'models/baseline'
# TOKENIZER is the name given to AutoTokenizer.from_pretrained(); the trailing
# comment keeps an alternative tokenizer name for reference.
TOKENIZER = 'bert-base-uncased' # 'cardiffnlp/twitter-roberta-base-sentiment-latest'
# PREPROCESSING is forwarded as prepare_dataset(preprocessing=...).
# NOTE(review): None presumably means "no preprocessing" — confirm in bert.prepare_dataset.
PREPROCESSING = None

In [7]:
# Read the table holding an `index` and a `cluster` column for the evaluation data.
# NOTE(review): the displayed frame also carries 'Unnamed: 0' and 'Unnamed: 0.1',
# which looks like a row index that was written to CSV twice, and `index` is
# parsed as float — confirm whether select_eval_with_cluster relies on either.
df_cluster_map = pd.read_csv('clustering+bert/eval.csv')
df_cluster_map

Unnamed: 0.1,Unnamed: 0,index,cluster
0,0,922648.0,0
1,1,944379.0,4
2,2,2182552.0,4
3,3,786886.0,4
4,4,1130778.0,3
...,...,...,...
1249995,1249995,1478680.0,2
1249996,1249996,1972646.0,4
1249997,1249997,1710597.0,5
1249998,1249998,1835784.0,4


In [9]:
# Distinct cluster ids as a NumPy array, in order of first appearance (not sorted).
CLUSTERS = df_cluster_map['cluster'].unique()
CLUSTERS

array([0, 4, 3, 5, 2, 1, 6])

In [None]:
# Build the model once at notebook level; evaluate_cluster() reads this global,
# so every cluster is scored against the same instance.
# NOTE(review): presumably get_BERT loads the checkpoint at MODEL onto `device` — confirm in bert.py.
model = get_BERT(MODEL, device)

In [None]:
def evaluate_cluster(cluster: int) -> dict:
  """Evaluate the notebook-level ``model`` on the evaluation rows of one cluster.

  Reads the globals ``df_cluster_map``, ``PREPROCESSING``, ``TOKENIZER`` and
  ``model`` defined in earlier cells.

  Args:
    cluster: Cluster id passed to ``select_eval_with_cluster`` to pick the rows
      to evaluate.

  Returns:
    The metrics mapping returned by ``Trainer.evaluate()`` (not a single float).
  """
  # Restrict the evaluation data to the requested cluster and build the dataset.
  df_eval = select_eval_with_cluster(df_cluster_map, cluster)
  dataset_eval = prepare_dataset(df_eval, preprocessing=PREPROCESSING)

  # The tokenizer is re-created on every call; it is also handed to the Trainer.
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
  eval_tokenized = tokenize(dataset_eval, tokenizer)

  trainer = Trainer(model, eval_dataset=eval_tokenized, tokenizer=tokenizer, compute_metrics=compute_metrics)
  return trainer.evaluate()

In [None]:
# Evaluate each cluster in turn and collect its metrics keyed by cluster id.
# A plain loop (rather than a comprehension) keeps the results gathered so far
# in `metrics` if a later cluster fails.
metrics = {}

for cluster in CLUSTERS:
  # Index with the loop variable `cluster`.
  metrics[cluster] = evaluate_cluster(cluster)