In [1]:
from auto_mm_bench.datasets import dataset_registry
import xgboost as xgb
import lightgbm as lgb
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, Dataset
from sklearn.metrics import precision_score, recall_score, roc_auc_score

  from .autonotebook import tqdm as notebook_tqdm


## Tab

In [2]:
dataset_name = 'imdb_genre_prediction'

train_dataset = dataset_registry.create(dataset_name, 'train')
test_dataset = dataset_registry.create(dataset_name, 'test')
train_dataset.data

label_cols = train_dataset.label_columns
tab_cols = ['Year','Runtime (Minutes)', 'Rating', 'Votes', 'Revenue (Millions)','Metascore', 'Rank']
text_cols = ['Description']

ds = load_dataset('james-burton/imdb_genre_prediction')

train_df = ds['train'].to_pandas()
test_df = ds['test'].to_pandas()
X_train_tab = train_df[tab_cols]
y_train = train_df[label_cols]
X_test_tab = test_df[tab_cols]
y_test = test_df[label_cols]

tab_model = lgb.LGBMClassifier(random_state=42)
tab_model.fit(X_train_tab,y_train)
y_pred = tab_model.predict(X_test_tab)
y_pred_probs = tab_model.predict_proba(X_test_tab)

print('Accuracy: ', np.mean(y_test.values.flatten() == y_pred))
print('ROC AUC: ', roc_auc_score(y_test, y_pred_probs[:,1]))

Found cached dataset parquet (/home/james/.cache/huggingface/datasets/james-burton___parquet/james-burton--imdb_genre_prediction-f183d7ab5d966777/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 3/3 [00:00<00:00, 799.52it/s]

Accuracy:  0.795
ROC AUC:  0.8273446101491343



  y = column_or_1d(y, warn=True)
  y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)


## Text

In [3]:
text_model = AutoModelForSequenceClassification.from_pretrained('../models/imdb_genre/true-smoke-9/checkpoint-11', num_labels=2)
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    
# Tokenize the dataset
def encode(examples):
    return {
        "labels": np.array([examples['Genre_is_Drama']]),
                **tokenizer(examples['Description'], truncation=True, padding="max_length")}
ds = ds.map(encode)
trainer = Trainer(model=text_model)
preds = trainer.predict(ds['test']).predictions


print('Accuracy: ', np.mean(np.argmax(preds, axis=1) == y_test.values.flatten()))
print('ROC AUC: ', roc_auc_score(y_test.values.flatten(), preds[:,1]))

Loading cached processed dataset at /home/james/.cache/huggingface/datasets/james-burton___parquet/james-burton--imdb_genre_prediction-f183d7ab5d966777/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-ce1bd941fbff1bcd.arrow
Loading cached processed dataset at /home/james/.cache/huggingface/datasets/james-burton___parquet/james-burton--imdb_genre_prediction-f183d7ab5d966777/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-515a0447fbeb31f7.arrow
The following columns in the test set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: Year, Description, Rank, Revenue (Millions), Title, Actors, Genre_is_Drama, Director, Rating, Votes, Runtime (Minutes), Metascore. If Year, Description, Rank, Revenue (Millions), Title, Actors, Genre_is_Drama, Director, Rating, Votes, Runtime (Minutes), Metascore are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ign

Accuracy:  0.685
ROC AUC:  0.7862075868281453





In [4]:
preds

array([[-0.33137363,  0.67620337],
       [ 0.7707363 , -0.43859714],
       [-0.15477103,  0.461668  ],
       [-0.51372033,  0.5515399 ],
       [ 0.33200806, -0.09582425],
       [ 0.30575514, -0.08119192],
       [-0.20814325,  0.45128378],
       [-0.11868256,  0.35581833],
       [ 0.64184636, -0.40104374],
       [ 0.18736653,  0.15220319],
       [-0.24325113,  0.24067873],
       [ 0.919999  , -0.72023445],
       [ 0.5940566 , -0.5238441 ],
       [ 0.31310907, -0.17255639],
       [ 0.87721115, -0.66871613],
       [ 0.53969944, -0.2651341 ],
       [ 0.34864676, -0.07072233],
       [ 0.69314694, -0.555075  ],
       [-0.40269408,  0.6778864 ],
       [ 0.44806248, -0.14066026],
       [ 0.8339696 , -0.5885786 ],
       [ 0.6659689 , -0.5506062 ],
       [-0.06016606,  0.24585232],
       [ 0.8024231 , -0.6064743 ],
       [-0.20577443,  0.40001944],
       [-0.20434828,  0.63057333],
       [ 0.09521575,  0.23511131],
       [ 0.10173157,  0.10188404],
       [-0.20979197,