In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import peft


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.chdir(r"C:\Users\ethan\Desktop\scam_detection")

In [3]:
data_df = pd.read_csv('spam.csv', encoding='latin1')
data_df.drop(['Unnamed: 2', 'Unnamed: 3', 'Unnamed: 4'], axis=1, inplace=True)
data_df = data_df.rename(columns={'v1': 'label', 'v2': 'text'})

data_df.head()

Unnamed: 0,label,text
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."


In [4]:
from sklearn import model_selection
train_df , val_df = model_selection.train_test_split(data_df, test_size=0.3, random_state=42)

In [5]:
import datasets
train_ds = datasets.Dataset.from_pandas(train_df)
val_ds = datasets.Dataset.from_pandas(val_df)

dataset_dict = datasets.DatasetDict({'train': train_ds, 'val': val_ds})

for split in dataset_dict.keys():
    dataset_dict[split] = dataset_dict[split].remove_columns("__index_level_0__")

dataset_dict


DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 3900
    })
    val: Dataset({
        features: ['label', 'text'],
        num_rows: 1672
    })
})

In [6]:

def label_to_int(label):
    if label == 'spam':
        return 1
    elif label == 'ham':
        return 0
def convert_labels(example):
    example['label'] = label_to_int(example['label'])
    return example

dataset_dict = dataset_dict.map(convert_labels)

print(dataset_dict['train']['label'][:5])

Map: 100%|██████████| 3900/3900 [00:00<00:00, 50592.24 examples/s]
Map: 100%|██████████| 1672/1672 [00:00<00:00, 55831.00 examples/s]

[1, 0, 0, 0, 0]





In [7]:
np.array(dataset_dict['train']['label']).sum()/len(dataset_dict['train']['label'])

0.13538461538461538

In [18]:
import transformers
model_checkpoint = "roberta-base"
id2label = {0: "ham", 1: "spam"}
label2id = {"ham": 0, "spam": 1}

config = transformers.AutoConfig.from_pretrained(
    model_checkpoint,
    num_labels=2,
    id2label=id2label,
    label2id=label2id
)

model = transformers.AutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
model

RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
             

In [10]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=False)

def tokenize_function(examples):
    return tokenizer(examples["text"], return_tensors="pt", padding="max_length", truncation=True, max_length=512)

if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
print(type(dataset_dict))

tokenized_datasets = dataset_dict.map(tokenize_function, batched=True)
tokenized_datasets


<class 'datasets.dataset_dict.DatasetDict'>


Map: 100%|██████████| 3900/3900 [00:00<00:00, 6969.05 examples/s]
Map: 100%|██████████| 1672/1672 [00:00<00:00, 7299.76 examples/s]


DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'input_ids', 'attention_mask'],
        num_rows: 3900
    })
    val: Dataset({
        features: ['label', 'text', 'input_ids', 'attention_mask'],
        num_rows: 1672
    })
})

In [11]:
from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [12]:
import evaluate
accuracy = evaluate.load("accuracy")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy.compute(predictions=predictions,references=labels)}

In [13]:
example_list = ["Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...", 
                "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's", 
                "Nah I don't think he goes to usf, he lives around here though",
                "Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030"]

print("Untrained Model Predictions:")
for text in example_list:
    inputs = tokenizer.encode(text, return_tensors="pt")
    logits = model(inputs).logits
    predictions = torch.argmax(logits)
    print(text + " - " + id2label[predictions.tolist()])
    

Untrained Model Predictions:
Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat... - ham
Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's - ham
Nah I don't think he goes to usf, he lives around here though - ham
Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030 - ham


In [14]:
model.train()
model.gradient_checkpointing_enable()
model = peft.prepare_model_for_kbit_training(model)


In [20]:
config = peft.LoraConfig(
    r=4,
    lora_alpha=32,
    target_modules=["query"],
    lora_dropout=0.05,
    bias = "none",
    task_type="SEQ_CLS"
)

model = peft.get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 665,858 || all params: 125,313,028 || trainable%: 0.531355766137899


In [21]:
lr = 1e-3
batch_size = 4
num_epochs = 10

training_args = transformers.TrainingArguments(
    output_dir = model_checkpoint + "-lora-text-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)




In [23]:
trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

                                                 
  1%|          | 91/9750 [01:05<12:49, 12.55it/s] 

{'loss': 0.1071, 'grad_norm': 3.4779591260303278e-06, 'learning_rate': 0.0009487179487179487, 'epoch': 0.51}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 

[A[A                                           
  1%|          | 91/9750 [01:57<12:49, 12.55it/s] 
[A
[A

{'eval_loss': 0.05097267031669617, 'eval_accuracy': {'accuracy': 0.9904306220095693}, 'eval_runtime': 14.3926, 'eval_samples_per_second': 116.171, 'eval_steps_per_second': 29.043, 'epoch': 1.0}


                                                 
  1%|          | 91/9750 [02:00<12:49, 12.55it/s] 

{'loss': 0.0913, 'grad_norm': 8.695931319380179e-05, 'learning_rate': 0.0008974358974358974, 'epoch': 1.03}


                                                 
  1%|          | 91/9750 [02:40<12:49, 12.55it/s]  

{'loss': 0.0546, 'grad_norm': 0.0005916114314459264, 'learning_rate': 0.0008461538461538462, 'epoch': 1.54}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [03:30<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.03613193333148956, 'eval_accuracy': {'accuracy': 0.993421052631579}, 'eval_runtime': 14.2891, 'eval_samples_per_second': 117.012, 'eval_steps_per_second': 29.253, 'epoch': 2.0}


                                                 
  1%|          | 91/9750 [03:34<12:49, 12.55it/s]  

{'loss': 0.0575, 'grad_norm': 0.009487445466220379, 'learning_rate': 0.0007948717948717948, 'epoch': 2.05}


                                                 
  1%|          | 91/9750 [04:14<12:49, 12.55it/s]  

{'loss': 0.0506, 'grad_norm': 0.00041246708133257926, 'learning_rate': 0.0007435897435897436, 'epoch': 2.56}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [05:02<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.04127059876918793, 'eval_accuracy': {'accuracy': 0.9940191387559809}, 'eval_runtime': 14.2279, 'eval_samples_per_second': 117.516, 'eval_steps_per_second': 29.379, 'epoch': 3.0}


                                                 
  1%|          | 91/9750 [05:09<12:49, 12.55it/s]  

{'loss': 0.0552, 'grad_norm': 0.10144289582967758, 'learning_rate': 0.0006923076923076923, 'epoch': 3.08}


                                                 
  1%|          | 91/9750 [05:48<12:49, 12.55it/s]  

{'loss': 0.0373, 'grad_norm': 0.0009473948739469051, 'learning_rate': 0.0006410256410256411, 'epoch': 3.59}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [06:34<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.05171268433332443, 'eval_accuracy': {'accuracy': 0.9922248803827751}, 'eval_runtime': 14.2904, 'eval_samples_per_second': 117.001, 'eval_steps_per_second': 29.25, 'epoch': 4.0}


                                                 
  1%|          | 91/9750 [06:43<12:49, 12.55it/s]  

{'loss': 0.032, 'grad_norm': 7.141285459510982e-05, 'learning_rate': 0.0005897435897435898, 'epoch': 4.1}


                                                 
  1%|          | 91/9750 [07:23<12:49, 12.55it/s]  

{'loss': 0.028, 'grad_norm': 6.17435434833169e-05, 'learning_rate': 0.0005384615384615384, 'epoch': 4.62}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [08:07<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.0692494735121727, 'eval_accuracy': {'accuracy': 0.9922248803827751}, 'eval_runtime': 14.2523, 'eval_samples_per_second': 117.314, 'eval_steps_per_second': 29.329, 'epoch': 5.0}


                                                 
  1%|          | 91/9750 [08:17<12:49, 12.55it/s]  

{'loss': 0.046, 'grad_norm': 5.060219336883165e-06, 'learning_rate': 0.0004871794871794872, 'epoch': 5.13}


                                                 
  1%|          | 91/9750 [08:57<12:49, 12.55it/s]  

{'loss': 0.0226, 'grad_norm': 2.3519652003756164e-08, 'learning_rate': 0.0004358974358974359, 'epoch': 5.64}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [09:41<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.04549763351678848, 'eval_accuracy': {'accuracy': 0.993421052631579}, 'eval_runtime': 16.1029, 'eval_samples_per_second': 103.832, 'eval_steps_per_second': 25.958, 'epoch': 6.0}


                                                 
  1%|          | 91/9750 [09:54<12:49, 12.55it/s]  

{'loss': 0.0544, 'grad_norm': 2.320124821153513e-07, 'learning_rate': 0.00038461538461538467, 'epoch': 6.15}


                                                 
  1%|          | 91/9750 [10:34<12:49, 12.55it/s]  

{'loss': 0.0165, 'grad_norm': 0.00027446559397503734, 'learning_rate': 0.0003333333333333333, 'epoch': 6.67}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [11:14<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.05892323702573776, 'eval_accuracy': {'accuracy': 0.9940191387559809}, 'eval_runtime': 14.23, 'eval_samples_per_second': 117.498, 'eval_steps_per_second': 29.375, 'epoch': 7.0}


                                                 
  1%|          | 91/9750 [11:29<12:49, 12.55it/s]  

{'loss': 0.0266, 'grad_norm': 2.2728667559146487e-10, 'learning_rate': 0.00028205128205128203, 'epoch': 7.18}


                                                 
  1%|          | 91/9750 [12:09<12:49, 12.55it/s]  

{'loss': 0.0274, 'grad_norm': 0.00173016672488302, 'learning_rate': 0.0002307692307692308, 'epoch': 7.69}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [12:48<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.06508441269397736, 'eval_accuracy': {'accuracy': 0.9952153110047847}, 'eval_runtime': 14.4931, 'eval_samples_per_second': 115.365, 'eval_steps_per_second': 28.841, 'epoch': 8.0}


                                                 
  1%|          | 91/9750 [13:04<12:49, 12.55it/s]  

{'loss': 0.0125, 'grad_norm': 1.369333858747268e-06, 'learning_rate': 0.0001794871794871795, 'epoch': 8.21}


                                                 
  1%|          | 91/9750 [13:45<12:49, 12.55it/s]  

{'loss': 0.0078, 'grad_norm': 2.5170487916170714e-10, 'learning_rate': 0.0001282051282051282, 'epoch': 8.72}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [14:22<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.07613217085599899, 'eval_accuracy': {'accuracy': 0.9952153110047847}, 'eval_runtime': 14.865, 'eval_samples_per_second': 112.479, 'eval_steps_per_second': 28.12, 'epoch': 9.0}


                                                 
  1%|          | 91/9750 [14:41<12:49, 12.55it/s]  

{'loss': 0.0191, 'grad_norm': 9.259161743102595e-05, 'learning_rate': 7.692307692307693e-05, 'epoch': 9.23}


                                                 
  1%|          | 91/9750 [15:22<12:49, 12.55it/s]  

{'loss': 0.0114, 'grad_norm': 3.2196811883267173e-10, 'learning_rate': 2.564102564102564e-05, 'epoch': 9.74}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                                

  1%|          | 91/9750 [15:58<12:49, 12.55it/s]
[A
[A

{'eval_loss': 0.07166410982608795, 'eval_accuracy': {'accuracy': 0.9952153110047847}, 'eval_runtime': 14.9773, 'eval_samples_per_second': 111.635, 'eval_steps_per_second': 27.909, 'epoch': 10.0}


                                                 
100%|██████████| 9750/9750 [15:34<00:00, 10.44it/s]

{'train_runtime': 934.138, 'train_samples_per_second': 41.75, 'train_steps_per_second': 10.437, 'train_loss': 0.03918485010587252, 'epoch': 10.0}





TrainOutput(global_step=9750, training_loss=0.03918485010587252, metrics={'train_runtime': 934.138, 'train_samples_per_second': 41.75, 'train_steps_per_second': 10.437, 'total_flos': 1.0341106274304e+16, 'train_loss': 0.03918485010587252, 'epoch': 10.0})

In [24]:
model.to("cuda:0")

print("Trained Model Predictions:")
for text in example_list:
    inputs = tokenizer.encode(text, return_tensors="pt").to("cuda:0")
    logits = model(inputs).logits
    predictions = torch.argmax(logits)
    print(text + " - " + id2label[predictions.tolist()])


Trained Model Predictions:
Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat... - ham
Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's - spam
Nah I don't think he goes to usf, he lives around here though - ham
Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free! Call The Mobile Update Co FREE on 08002986030 - spam
