In [80]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer

checkpoint = "distilbert/distilbert-base-uncased"

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
from datasets import load_dataset

raw_dataset = load_dataset("json", data_files="./movesets-train.json")

Generating train split: 467 examples [00:00, 16802.83 examples/s]


In [4]:
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['input_text', 'output_text'],
        num_rows: 467
    })
})

In [25]:
raw_dataset = raw_dataset.rename_column("output_text", "labels")
raw_dataset = raw_dataset.rename_column("input_text", "text")

In [87]:
raw_dataset = raw_dataset.class_encode_column("labels")



Casting to class labels: 100%|██████████| 467/467 [00:00<00:00, 38470.78 examples/s]


In [89]:
raw_dataset["train"].features

{'text': Value(dtype='string', id=None),
 'labels': ClassLabel(names=['other', 'physical', 'special'], id=None)}

In [90]:
def tokenize_data(dataset):
    return tokenizer(dataset["text"], padding="max_length", truncation=True)

In [91]:
tokenized_dataset = raw_dataset.map(tokenize_data, batched=True)


Map:   0%|          | 0/467 [00:00<?, ? examples/s]

Map: 100%|██████████| 467/467 [00:00<00:00, 6945.27 examples/s]


In [92]:
# tokenized_dataset.rename_column("output_text", "labels")
# tokenized_dataset.rename_column("input_text", "text")
tokenized_dataset = tokenized_dataset.remove_columns(["text"])

In [93]:
tokenized_dataset["train"][0]

{'labels': 1,
 'input_ids': [101,
  2054,
  2828,
  1997,
  2693,
  3503,
  2023,
  6412,
  1024,
  1996,
  5310,
  9530,
  14876,
  26698,
  1996,
  22277,
  2007,
  3177,
  1010,
  2059,
  18296,
  2229,
  1012,
  1996,
  2886,
  4915,
  2302,
  8246,
  1012,
  102,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
 

In [72]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [95]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments("first-try-poke")

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset["train"],
    data_collator=data_collator,
)

In [96]:
trainer.train()

  0%|          | 0/177 [08:44<?, ?it/s]
100%|██████████| 177/177 [03:50<00:00,  1.30s/it]

{'train_runtime': 230.8881, 'train_samples_per_second': 6.068, 'train_steps_per_second': 0.767, 'train_loss': 0.4812613449527719, 'epoch': 3.0}





TrainOutput(global_step=177, training_loss=0.4812613449527719, metrics={'train_runtime': 230.8881, 'train_samples_per_second': 6.068, 'train_steps_per_second': 0.767, 'total_flos': 185590135194624.0, 'train_loss': 0.4812613449527719, 'epoch': 3.0})

In [97]:
model = AutoModelForSequenceClassification.from_pretrained("./first-try-poke/checkpoint-177")

In [101]:
from transformers import pipeline

text = "What type of move matches this description: The user attacks by swinging its tail as if it were a vicious wave in a raging storm."

classifier = pipeline("text-classification", model=model, tokenizer=tokenizer, device="mps")

In [107]:
text = "what's the move that throws rocks on the field"
move = classifier(text)

type = move[0]["label"]

types = {
    "LABEL_0": "Other",
    "LABEL_1": "Physical",
    "LABEL_2": "Special"
}

print(types[type])

Physical
