# Part 3: Deep Learning Model with Transformers
## Author: Brady Lamson
## Date: Fall 2023
# Overview and Motivation

This portion of the project has a couple distinct goals. 

Firstly, I shall utilize the `distilibert-base-uncased` model to attempt to predict the primary and secondary types of pokemon entirely from their text descriptions.

Secondly, this notebook will function as a detailed walkthrough to fine tuning a huggingface transformer. Much of the information for this task is scattered throughout documentation and articles of varying levels of utility, so compiling all of that information into one notebook will result in something that will hopefully be useful to me and anyone who may read this. 

This walkthrough will also utilize hyperparameter search using the `optuna` library, which many articles I have read online seem to lack. I hope this will give this walkthrough a useful niche that other guides have not filled.

## References

As I've never done multi-label classification using `transformers` before I'll be using [this guide by Ronak Patel](https://colab.research.google.com/github/rap12391/transformers_multilabel_toxic/blob/master/toxic_multilabel.ipynb#scrollTo=CQQ7CoOag_r7) that is featured in [towardsdatascience](https://towardsdatascience.com/transformers-for-multilabel-classification-71a1a0daf5e1). I won't be following it 1:1 but it's there to help me get some traction.

## Potential Limitations

Performance of this model will be sought after but is not the end goal. I fear that the dataset I am working with will put a cap on performance. Pokemon types are extremely varied, with 19 types existing in this dataset alone. On top of that, I am predicting on both primary and secondary types which turns this into a multi-label prediction problem. Thus, combinations of types become important and many combinations of types only appear once. This is a limitation that likely cannot be overcome without removing problematic rows from the training split or simply acquiring more data. 

A future improvement that is outside the scope of this project is to collect all of the pokemon descriptions from each game. There are many pokemon games, and using this would allow us to duplicate many pokemon and artifically make certain type combinations more frequent and inflate our dataset. This would also provide more descriptions to train on as they tend to be similar but not identical in every game. This is obviously not without its downsides as it would inflate the frequency of already frequent type combinations, but a variant of this plan with a bit more thought put into it may be worth considering if maximizing model performance is a priority.

# Data Loading and Preprocessing

Our goal here is to do the same pre-processing as in part 2. So we'll have a bit of a repeat of that content.
From there we'll need to convert our dataset to the transformers `DatasetDict` which will contain all of our splits. The big difference here is taking our same dataframe and doing what we need to do to it to get it working within the transforers framework.

In [1]:
import pandas as pd
import numpy as np
import os

# Set seed
np.random.seed(776)

In [2]:
data_path = "./data/pokemon.csv"
data_exists = os.path.isfile(data_path)

if not data_exists:
    # This part requires a kaggle api key. On linux this will be saved to your home directory in .kaggle/kaggle.json
    !kaggle datasets download -d cristobalmitchell/pokedex
    !unzip pokedex.zip -d data

df = (
    # load in the data
    pd.read_csv(data_path, sep='\t', encoding='utf-16-le')
    # select the relevant columns
    .loc[:, ['english_name', 'primary_type', 'secondary_type', 'description']]
    # Change the type columns into categories and handle NaNs in secondary typing
    .assign(
        primary_type=lambda x: x['primary_type'].astype("category"),
        secondary_type=lambda x: x['secondary_type'].fillna("none").astype("category")
    )
)
display(df.head())
display(df.info())
display(df.describe())

Unnamed: 0,english_name,primary_type,secondary_type,description
0,Bulbasaur,grass,poison,There is a plant seed on its back right from t...
1,Ivysaur,grass,poison,"When the bulb on its back grows large, it appe..."
2,Venusaur,grass,poison,Its plant blooms when it is absorbing solar en...
3,Charmander,fire,none,It has a preference for hot things. When it ra...
4,Charmeleon,fire,none,"It has a barbaric nature. In battle, it whips ..."


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 898 entries, 0 to 897
Data columns (total 4 columns):
 #   Column          Non-Null Count  Dtype   
---  ------          --------------  -----   
 0   english_name    898 non-null    object  
 1   primary_type    898 non-null    category
 2   secondary_type  898 non-null    category
 3   description     898 non-null    object  
dtypes: category(2), object(2)
memory usage: 17.3+ KB


None

Unnamed: 0,english_name,primary_type,secondary_type,description
count,898,898,898,898
unique,898,18,19,896
top,Bulbasaur,water,none,Although it’s alien to this world and a danger...
freq,1,123,429,3


## Vectorize Categorical Data

Here we'll do the same one-hot encoding as in part 2. Here we'll do it before the splits though as, in retrospect, doing this after the split made no sense. 

In [5]:
names = list(df.secondary_type.unique())
id2label = {i: label for i, label in enumerate(names)}
label2id = {label: i for i, label in enumerate(names)}

def map_types(row):
    type_encoding = [0] * len(names)
    primary_id = label2id[row['primary_type']]
    secondary_id = label2id[row['secondary_type']]

    type_encoding[primary_id] = 1
    type_encoding[secondary_id] = 1
    
    # return [primary_id, secondary_id]
    return type_encoding

df['labels'] = df.apply(lambda row: map_types(row), axis=1)
df

Unnamed: 0,english_name,primary_type,secondary_type,description,cleaned_text,labels
0,Bulbasaur,grass,poison,There is a plant seed on its back right from t...,plant seed back right day pokémon born seed sl...,"[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
1,Ivysaur,grass,poison,"When the bulb on its back grows large, it appe...",bulb back grows large appears lose ability sta...,"[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
2,Venusaur,grass,poison,Its plant blooms when it is absorbing solar en...,plant blooms absorbing solar energy stays move...,"[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ..."
3,Charmander,fire,none,It has a preference for hot things. When it ra...,preference hot things rains steam said spout t...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ..."
4,Charmeleon,fire,none,"It has a barbaric nature. In battle, it whips ...",barbaric nature battle whips fiery tail around...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ..."
...,...,...,...,...,...,...
893,Regieleki,electric,none,This Pokémon is a cluster of electrical energy...,pokémon cluster electrical energy said removin...,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
894,Regidrago,dragon,none,An academic theory proposes that Regidrago’s a...,academic theory proposes regidrago arms head a...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
895,Glastrier,ice,none,Glastrier emits intense cold from its hooves. ...,glastrier emits intense cold hooves also belli...,"[0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
896,Spectrier,ghost,none,It probes its surroundings with all its senses...,probes surroundings senses save one use sense ...,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [6]:
print(df.iloc[0]["primary_type"])
print(df.iloc[0]["secondary_type"])
print(df.iloc[0]["labels"])

grass
poison
[1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


## Sanity Check: Verify our Preprocessing and Mapping Worked

In [8]:
for index, row in df.head().iterrows():
    actual_primary = row["primary_type"]
    actual_secondary = row["secondary_type"]
    mapped_types = [i for i, x in enumerate(row["labels"]) if x == 1]
    mapped_primary = id2label[mapped_types[0]]
    mapped_secondary = id2label[mapped_types[1]]

    print(f"Actual: {actual_primary}, {actual_secondary}")
    print(f"Mapped: {mapped_primary}, {mapped_secondary}\n")

Actual: grass, poison
Mapped: poison, grass

Actual: grass, poison
Mapped: poison, grass

Actual: grass, poison
Mapped: poison, grass

Actual: fire, none
Mapped: none, fire

Actual: fire, none
Mapped: none, fire



## Split The Dataset and Create DatasetDict

In [9]:
from sklearn.model_selection import train_test_split

# This provides a 70/15/15 split
train, test = train_test_split(df, test_size=0.3, random_state=10)
test, val = train_test_split(test, test_size=0.5, random_state=10)

Now what we want to do is convert this into a format transformers can work with, the `Dataset` object. It's just another way of storing data is all, nothing scary. We use their `Dataset.from_pandas()` method to easily convert and provide some additional information. Really the only important part here is specifying that the labels are a "Sequence" of "ClassLabels". Or, in laymens terms, a list of class ids.

Then we just use a larger container `DatasetDict` to easily store all 3 of our splits. So we don't have to juggle 5000 million different objects. 

In [10]:
from datasets import (
    Dataset, DatasetDict, Features,
    ClassLabel, Value, Sequence
)
from transformers import AutoTokenizer

In [11]:
def create_dataset(split: pd.DataFrame, mapper: dict) -> Dataset:
    """
    Converts a pandas dataframe into a Dataset object
    keeps only a handful of the columns I care about
    """
    
    names = list(mapper.keys())
    ds = Dataset.from_pandas(
        df = split[['english_name', 'description', 'labels']],
        features = Features({
            "english_name": Value(dtype="string"),
            "description": Value(dtype="string"),
            'labels': Sequence(
                feature=Value(dtype="float32"),
                length=len(names)
            ),
            "__index_level_0__": Value(dtype="int64")
        })
    )

    return ds

In [12]:
ds = DatasetDict({
    "train": create_dataset(train, label2id),
    "test": create_dataset(test, label2id),
    "val": create_dataset(val, label2id)
})

ds

DatasetDict({
    train: Dataset({
        features: ['english_name', 'description', 'labels', '__index_level_0__'],
        num_rows: 628
    })
    test: Dataset({
        features: ['english_name', 'description', 'labels', '__index_level_0__'],
        num_rows: 135
    })
    val: Dataset({
        features: ['english_name', 'description', 'labels', '__index_level_0__'],
        num_rows: 135
    })
})

In [13]:
# Exampe row
ds["train"][0]

{'english_name': 'Wartortle',
 'description': 'It is recognized as a symbol of longevity. If its shell has algae on it, that Wartortle is very old.',
 'labels': [0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0],
 '__index_level_0__': 7}

## Tokenize Pokedex Description

In [14]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize_description(example):
    tokenized_desc = tokenizer(example['description'], padding='max_length', is_split_into_words=False, max_length=60)

    return tokenized_desc

ds = ds.map(
    lambda example: tokenize_description(example)
)

Map:   0%|          | 0/628 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

In [15]:
# Take note of the new features
ds["train"]

Dataset({
    features: ['english_name', 'description', 'labels', '__index_level_0__', 'input_ids', 'attention_mask'],
    num_rows: 628
})

### Verify tokenizer worked as intended

In [16]:
first_row = ds["train"][0]
print(tokenizer.decode(first_row["input_ids"], skip_special_tokens=True))
print(first_row["description"])

it is recognized as a symbol of longevity. if its shell has algae on it, that wartortle is very old.
It is recognized as a symbol of longevity. If its shell has algae on it, that Wartortle is very old.


### Verify equality of lengths

In [17]:
for index, example in enumerate(ds['train']):
    if index < 5:
        input_ids_length = len(example['input_ids'])
        attention_mask_length = len(example['attention_mask'])
        
        # Print the lengths for each feature in this row
        print(f'Input IDs Length: {input_ids_length}')
        print(f'Attention Mask Length: {attention_mask_length}\n')

lengths = [len(row["input_ids"]) for row in ds["train"]]
max(lengths)

Input IDs Length: 60
Attention Mask Length: 60

Input IDs Length: 60
Attention Mask Length: 60

Input IDs Length: 60
Attention Mask Length: 60

Input IDs Length: 60
Attention Mask Length: 60

Input IDs Length: 60
Attention Mask Length: 60



60

In [18]:
ds.set_format("torch")

# Building the Trainer

In [19]:
from transformers import (
    DataCollator, AutoModelForSequenceClassification, TrainingArguments, Trainer
)

Below is code modified from [this notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb). I had to tweak stuff like changing sigmoid to softmax and adding in my own custom function for extracting the top 2 predictions. Many examples utilize a threshold but due to the large number of classes I feel like that's not a good fit for this model. 

Gonna be honest here, mostly taking these metrics at face-value. Stuff gets weird in multi-label models and I had a lot of trouble figuring this out on my own. 

In [40]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
from torch.nn.functional import softmax
from torch import tensor, topk

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply softmax to our logits to get our probabilities
    probs = softmax(torch.Tensor(predictions))

    def tensor_to_indices(tensor):
        # Here this function takes a tensor of probabilities and returns the indices of the two highest probs
        _, indices = topk(tensor, 2)
        indices = indices.tolist()
    
        return sorted(indices)
    
    # next, use said indices to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    pred_indices = tensor_to_indices(probs)
    for index, row in enumerate(y_pred):
        row[pred_indices[index]] = 1
    
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics


def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

In [23]:
num_labels = len(label2id.keys())
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=len(names), problem_type="multi_label_classification", id2label=id2label, label2id=label2id
)
model

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.bias', 'classifier.weight', 'classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [41]:
training_arguments = TrainingArguments(
    output_dir="models",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.1,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    remove_unused_columns=False,
    load_best_model_at_end=True
)

cols_to_train_on = ["input_ids", "attention_mask", "labels"]

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=ds["train"].select_columns(cols_to_train_on),
    eval_dataset=ds["val"].select_columns(cols_to_train_on),
    compute_metrics=compute_metrics,
    tokenizer=tokenizer
)

In [42]:
trainer.train()

Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy
1,0.1659,0.269975,0.211503,0.559563,0.02963
2,0.1492,0.26779,0.267161,0.590711,0.059259
3,0.1375,0.268863,0.252319,0.582405,0.051852
4,0.1319,0.271124,0.252319,0.582405,0.044444
5,0.1231,0.272594,0.233766,0.572022,0.037037
6,0.1141,0.276763,0.244898,0.578251,0.044444
7,0.1079,0.279984,0.263451,0.588634,0.059259
8,0.1039,0.278569,0.270872,0.592787,0.066667
9,0.1003,0.279651,0.241187,0.576175,0.044444
10,0.0978,0.282006,0.244898,0.578251,0.037037


WOOOOO -----
tensor([[0.0345, 0.0159, 0.4387,  ..., 0.0185, 0.0384, 0.0385],
        [0.0682, 0.0367, 0.0263,  ..., 0.0094, 0.0447, 0.0099],
        [0.0317, 0.5222, 0.0031,  ..., 0.0071, 0.0362, 0.0030],
        ...,
        [0.0280, 0.3116, 0.0167,  ..., 0.0096, 0.0593, 0.0139],
        [0.0267, 0.1262, 0.1946,  ..., 0.0119, 0.0357, 0.1758],
        [0.0736, 0.3607, 0.0106,  ..., 0.0513, 0.0229, 0.0111]])
[[0, 8], [0, 8], [0, 8], [1, 0], [1, 2], [1, 2], [1, 3], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 6], [1, 6], [1, 6], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 11], [1, 12], [1, 12], [1, 12], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18],

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[0.0221, 0.0259, 0.2714,  ..., 0.0134, 0.0267, 0.0365],
        [0.0154, 0.0404, 0.0099,  ..., 0.0033, 0.0138, 0.0071],
        [0.0100, 0.3751, 0.0019,  ..., 0.0033, 0.0141, 0.0025],
        ...,
        [0.0103, 0.5907, 0.0073,  ..., 0.0041, 0.0385, 0.0100],
        [0.0153, 0.2264, 0.0997,  ..., 0.0071, 0.0267, 0.1659],
        [0.0126, 0.6907, 0.0020,  ..., 0.0090, 0.0069, 0.0052]])
[[0, 1], [0, 14], [1, 2], [1, 2], [1, 2], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 5], [1, 6], [1, 6], [1, 6], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 11], [1, 13], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 17], [1, 18], [1, 18], [1, 18], [1, 18],

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[0.0241, 0.0179, 0.4606,  ..., 0.0144, 0.0226, 0.0439],
        [0.0325, 0.0517, 0.0120,  ..., 0.0042, 0.0174, 0.0093],
        [0.0107, 0.6329, 0.0013,  ..., 0.0031, 0.0147, 0.0017],
        ...,
        [0.0147, 0.5064, 0.0091,  ..., 0.0047, 0.0497, 0.0104],
        [0.0181, 0.2602, 0.0946,  ..., 0.0073, 0.0215, 0.2199],
        [0.0118, 0.8682, 0.0012,  ..., 0.0093, 0.0036, 0.0049]])
[[0, 1], [0, 8], [1, 0], [1, 2], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 12], [1, 12], [1, 12], [1, 13], [1, 13], [1, 13], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 14], [1, 17], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], [1, 18], 

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[2.5240e-02, 2.4547e-02, 3.0166e-01,  ..., 1.4958e-02, 1.8843e-02,
         3.9302e-02],
        [2.1594e-02, 3.6881e-02, 3.8243e-03,  ..., 2.5235e-03, 1.0620e-02,
         3.3865e-03],
        [7.4693e-03, 6.2448e-01, 6.2874e-04,  ..., 2.0239e-03, 8.7271e-03,
         1.0144e-03],
        ...,
        [1.3736e-02, 4.6127e-01, 5.7586e-03,  ..., 4.3051e-03, 4.8369e-02,
         6.8296e-03],
        [2.0470e-02, 2.7769e-01, 5.0803e-02,  ..., 7.3598e-03, 2.1330e-02,
         1.4872e-01],
        [2.3912e-02, 7.2967e-01, 9.9070e-04,  ..., 1.4658e-02, 4.9046e-03,
         3.5821e-03]])
[[0, 1], [0, 8], [0, 8], [1, 0], [1, 2], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 5], [1, 5], [1, 6], [1, 6], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 11], [1, 12], [1, 

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[2.0763e-02, 1.3582e-02, 4.4781e-01,  ..., 1.2231e-02, 1.5822e-02,
         4.5300e-02],
        [1.9988e-02, 4.1928e-02, 6.2873e-03,  ..., 2.2963e-03, 1.0065e-02,
         6.3019e-03],
        [4.2946e-03, 7.4269e-01, 3.9157e-04,  ..., 1.0532e-03, 5.4863e-03,
         7.8093e-04],
        ...,
        [6.9822e-03, 6.5195e-01, 4.2548e-03,  ..., 2.2844e-03, 2.7320e-02,
         7.6285e-03],
        [1.6852e-02, 2.3673e-01, 7.4075e-02,  ..., 6.3528e-03, 1.6912e-02,
         2.2978e-01],
        [1.0621e-02, 8.6865e-01, 5.8674e-04,  ..., 9.0501e-03, 2.2443e-03,
         2.8467e-03]])
[[0, 1], [0, 8], [0, 17], [1, 0], [1, 0], [1, 0], [1, 2], [1, 2], [1, 2], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 11], [1, 12], [1, 12], 

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[2.5431e-02, 2.1328e-02, 3.2579e-01,  ..., 1.6165e-02, 1.2807e-02,
         4.2732e-02],
        [2.3765e-02, 7.0263e-02, 2.7846e-03,  ..., 2.2505e-03, 7.0133e-03,
         4.4207e-03],
        [3.5312e-03, 8.7157e-01, 2.2979e-04,  ..., 8.5105e-04, 3.9201e-03,
         4.9250e-04],
        ...,
        [6.8491e-03, 5.9550e-01, 3.1415e-03,  ..., 2.2123e-03, 2.3508e-02,
         5.4354e-03],
        [1.6147e-02, 4.3093e-01, 2.9454e-02,  ..., 5.7809e-03, 9.5838e-03,
         1.7313e-01],
        [9.7396e-03, 8.8490e-01, 4.3095e-04,  ..., 9.2850e-03, 1.5674e-03,
         2.5110e-03]])
[[0, 1], [0, 8], [0, 8], [1, 0], [1, 2], [1, 2], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 6], [1, 6], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10],

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[2.0062e-02, 2.5252e-02, 2.4382e-01,  ..., 1.2616e-02, 1.2432e-02,
         2.8876e-02],
        [1.4558e-02, 3.3585e-02, 2.2823e-03,  ..., 1.5373e-03, 6.3287e-03,
         2.1168e-03],
        [2.3379e-03, 8.3908e-01, 2.1540e-04,  ..., 7.0387e-04, 3.4231e-03,
         4.2053e-04],
        ...,
        [5.0951e-03, 7.2086e-01, 2.4376e-03,  ..., 1.7149e-03, 2.7081e-02,
         3.3581e-03],
        [1.5989e-02, 4.0206e-01, 3.1882e-02,  ..., 5.2081e-03, 1.4409e-02,
         1.1711e-01],
        [7.0865e-03, 8.7094e-01, 2.8643e-04,  ..., 6.3925e-03, 1.9132e-03,
         1.0855e-03]])
[[0, 1], [0, 8], [0, 17], [1, 0], [1, 0], [1, 2], [1, 2], [1, 2], [1, 3], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 5], [1, 5], [1, 6], [1, 6], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[1.8947e-02, 2.1494e-02, 3.0739e-01,  ..., 1.1300e-02, 9.5443e-03,
         4.4381e-02],
        [1.4089e-02, 3.0203e-02, 2.5458e-03,  ..., 1.4294e-03, 5.1985e-03,
         2.8772e-03],
        [1.7946e-03, 8.5959e-01, 1.6597e-04,  ..., 5.1501e-04, 2.3536e-03,
         4.3511e-04],
        ...,
        [4.4431e-03, 7.4887e-01, 2.2557e-03,  ..., 1.3481e-03, 2.1100e-02,
         4.0723e-03],
        [1.3395e-02, 3.5014e-01, 3.9321e-02,  ..., 4.6254e-03, 8.8922e-03,
         2.2077e-01],
        [3.9335e-03, 9.1509e-01, 2.3439e-04,  ..., 3.9714e-03, 1.0752e-03,
         1.6417e-03]])
[[0, 1], [0, 6], [0, 8], [0, 17], [1, 0], [1, 2], [1, 2], [1, 2], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 5], [1, 6], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[1.8704e-02, 1.9626e-02, 2.6497e-01,  ..., 1.3109e-02, 9.9849e-03,
         3.2812e-02],
        [1.2269e-02, 2.2652e-02, 1.9150e-03,  ..., 1.3335e-03, 4.9737e-03,
         1.8249e-03],
        [2.0349e-03, 8.4145e-01, 1.7579e-04,  ..., 6.6797e-04, 2.7862e-03,
         3.8491e-04],
        ...,
        [4.7168e-03, 7.0603e-01, 2.2381e-03,  ..., 1.6261e-03, 2.6719e-02,
         3.2898e-03],
        [1.4189e-02, 3.5915e-01, 3.5001e-02,  ..., 5.1912e-03, 9.6885e-03,
         1.8843e-01],
        [3.7227e-03, 9.2297e-01, 1.9348e-04,  ..., 4.9731e-03, 1.1536e-03,
         1.0588e-03]])
[[0, 1], [0, 8], [0, 17], [1, 0], [1, 0], [1, 0], [1, 2], [1, 2], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 5], [1, 6], [1, 6], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [

  probs = softmax(torch.Tensor(predictions))


WOOOOO -----
tensor([[1.8416e-02, 2.3541e-02, 3.1584e-01,  ..., 1.2036e-02, 9.6283e-03,
         4.1337e-02],
        [1.5202e-02, 4.0350e-02, 2.2235e-03,  ..., 1.3950e-03, 5.8749e-03,
         2.6865e-03],
        [1.3331e-03, 9.2172e-01, 1.1399e-04,  ..., 4.1046e-04, 1.8451e-03,
         3.0296e-04],
        ...,
        [3.8073e-03, 7.3136e-01, 1.9318e-03,  ..., 1.2849e-03, 2.0539e-02,
         3.4540e-03],
        [1.2178e-02, 4.3186e-01, 2.9150e-02,  ..., 4.0981e-03, 7.7948e-03,
         2.0184e-01],
        [2.5750e-03, 9.5197e-01, 1.3867e-04,  ..., 3.3844e-03, 7.5887e-04,
         9.4348e-04]])
[[0, 1], [0, 8], [0, 17], [1, 0], [1, 0], [1, 0], [1, 2], [1, 2], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 4], [1, 6], [1, 6], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1

  probs = softmax(torch.Tensor(predictions))


TrainOutput(global_step=400, training_loss=0.12316525280475617, metrics={'train_runtime': 171.577, 'train_samples_per_second': 36.602, 'train_steps_per_second': 2.331, 'total_flos': 97517281636800.0, 'train_loss': 0.12316525280475617, 'epoch': 10.0})

In [55]:
trainer.predict(ds["test"].select_columns(cols_to_train_on))

WOOOOO -----
tensor([[9.0379e-03, 2.5823e-01, 5.4908e-03,  ..., 3.2145e-03, 8.1895e-03,
         1.4389e-02],
        [3.9041e-01, 2.6214e-02, 3.2358e-03,  ..., 7.4139e-03, 3.2413e-02,
         8.6807e-04],
        [7.5333e-05, 9.9686e-01, 2.8050e-05,  ..., 9.5401e-05, 1.0841e-04,
         3.6687e-04],
        ...,
        [3.3390e-02, 2.7972e-01, 3.0716e-02,  ..., 5.3130e-02, 1.6841e-02,
         3.8278e-02],
        [4.6783e-04, 8.8663e-01, 1.2667e-03,  ..., 2.2241e-03, 1.4061e-03,
         3.3956e-03],
        [6.1318e-04, 9.3498e-01, 5.5742e-04,  ..., 1.2035e-03, 8.2783e-04,
         3.6331e-03]])
[[0, 3], [0, 14], [0, 17], [0, 17], [1, 0], [1, 0], [1, 0], [1, 0], [1, 3], [1, 3], [1, 3], [1, 4], [1, 4], [1, 4], [1, 4], [1, 5], [1, 5], [1, 6], [1, 6], [1, 6], [1, 7], [1, 7], [1, 7], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 8], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 9], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1, 10], [1,

  probs = softmax(torch.Tensor(predictions))


PredictionOutput(predictions=array([[-3.0452335 ,  0.30720177, -3.5435908 , ..., -4.0789914 ,
        -3.1438055 , -2.5802224 ],
       [ 0.5469722 , -2.1539192 , -4.2459393 , ..., -3.4168632 ,
        -1.9416597 , -5.5616994 ],
       [-4.7129016 ,  4.7775445 , -5.7008457 , ..., -4.4767337 ,
        -4.3489485 , -3.1298065 ],
       ...,
       [-3.69219   , -1.5666349 , -3.7756586 , ..., -3.2276976 ,
        -4.3766274 , -3.5555682 ],
       [-5.1851215 ,  2.3619444 , -4.1890297 , ..., -3.626136  ,
        -4.0846944 , -3.2029939 ],
       [-4.6511045 ,  2.6785204 , -4.746437  , ..., -3.9767478 ,
        -4.3509526 , -2.8719153 ]], dtype=float32), label_ids=array([[0., 1., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 1., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 1.]], dtype=float32), metrics={'test_loss': 0.2651369869709015, 'test_f1': 0.3, 'test_roc_auc': 0